
library("SASxport")
library("matrixcalc")


## The cellFlagger method
#########################


cellFlagger <- function(X, mu, Sigma, quant = 0.99, 
                        rule = "residualsFromG") {
  # This is the cellFlagger algorithm starting from a given mu/Sigma.
  # For every row it orders the cells along a LAR path (findCellPath),
  # flags cells whose entry into the path gives a significant drop in the
  # residual sum of squares, and imputes the flagged cells by their
  # conditional mean given the unflagged cells of that row under (mu, Sigma).
  #
  # Arguments:
  #   X: data matrix (n x d)
  #   mu: center (length d)
  #   Sigma: covariance matrix (d x d)
  #   quant: quantile used for flagging cells
  #   rule: "allSignificantDelta" or "residualsFromG"
  # Returns a list with
  #   Ximp: imputed data
  #   W: matrix indicating which cells were flagged/imputed (1 = flagged & imputed)
  #   Zres: Zres_num / Zres_denom for flagged cells, 0 elsewhere
  #   cellPaths: row i holds the order in which the cells of row i enter the path
  #   cellOrder: row i holds, per cell, its position in that path
  #   Zres_num: residual X - Ximp of the flagged cells
  #   Zres_denom: conditional standard deviation of the flagged cells given
  #               the unflagged cells (marginal sd when the whole row is flagged)
  
  if (!rule %in% c("allSignificantDelta", "residualsFromG")) {
    stop("invalid rule argument")
  }
  n <- dim(X)[1]
  d <- dim(X)[2]
  inv.out    <- mpinv(Sigma)
  Sigmai     <- inv.out$Inv
  Sigmaisqrt <- inv.out$InvSqrt
  scales     <- sqrt(diag(Sigma))
  predictors <- Sigmaisqrt
  
  # Conditional standard deviations of the cells `bad` given all other cells:
  # square root of the diagonal of the Schur complement of Sigma.
  # Only valid for 0 < length(bad) < d.
  condSD <- function(bad) {
    sqrt(diag(Sigma[bad, bad] - Sigma[bad, -bad] %*%
                solve(Sigma[-bad, -bad]) %*% Sigma[-bad, bad]))
  }
  
  Ximp <- X
  W <- matrix(0, n, d)
  Zres <- matrix(0, n, d)
  cellPaths <- cellOrder <- Zres_num <- Zres_denom <- matrix(0, n, d)
  #
  for (i in seq_len(n)) {
    x        <- X[i, ] 
    response <- Sigmaisqrt %*% (x - mu)
    weights  <- huberweights(x = (x - mu) / scales, b = 1.5)
    
    larOut   <- findCellPath(predictors = predictors, response = response,
                             weights = weights,
                             Sigmai = Sigmai)
    cellPaths[i, ] <- larOut$ordering
    cellOrder[i, larOut$ordering] <- seq_len(d)
    diffRSS  <- abs(diff(larOut$RSS))
    pvals    <- 1 - pchisq(diffRSS, 1)
    badCells <- which(pvals < 1 - quant)
    if (length(badCells) == 0) {
      next
    }
    #
    if (rule == "allSignificantDelta") {
      # allSignificantDelta rule: every significant step flags its cell
      badinds <- larOut$ordering[badCells]
    } else {
      # residualsFromG rule: take the path up to the last significant step
      # (maxDelta rule), then keep only the cells whose standardized
      # residual exceeds the cutoff.
      badinds <- larOut$ordering[1:max(badCells)]
      if (length(badinds) == d) {
        stdresid <- (x - mu) / scales
      } else {
        stdresid <- rep(0, d)
        stdresid[badinds] <- abs(larOut$beta[length(badinds) + 1, badinds]) /
          condSD(badinds)
      }
      badinds <- which(abs(stdresid) > sqrt(qchisq(quant, 1)))
      # All candidate residuals may fall below the cutoff. With an empty
      # badinds, `-badinds` would select nothing and solve() a 0x0 matrix,
      # so leave this row unflagged.
      if (length(badinds) == 0) {
        next
      }
    }
    W[i, badinds] <- 1
    
    if (length(badinds) == d) {
      Ximp[i, ] <- mu
      Zres_num[i, ] <- (X[i, ] - mu)
      Zres_denom[i, ] <- scales
      Zres[i, ] <-  Zres_num[i, ] / Zres_denom[i, ]
    } else {
      # conditional mean of the flagged cells given the unflagged ones
      replacement <- X[i, ]
      replacement[badinds] <- mu[badinds] +  Sigma[badinds, -badinds] %*%
        solve(Sigma[-badinds, -badinds]) %*% (replacement[-badinds] - mu[-badinds])
      Ximp[i, ] <- replacement
      residual  <- X[i, ] - replacement
      Zres_num[i, badinds]   <- residual[badinds]
      Zres_denom[i, badinds] <- condSD(badinds)
      Zres[i, badinds] <- Zres_num[i, badinds] / Zres_denom[i, badinds] 
    }
  }
  return(list(Ximp = Ximp, W = W, Zres = Zres,
              cellPaths = cellPaths, cellOrder = cellOrder,
              Zres_num = Zres_num, Zres_denom = Zres_denom))
}



findCellPath <- function(predictors, response, weights, Sigmai){
  # Orders the d cells of one observation along a least angle regression
  # path (via lar_ols) and collects, for every path length, the OLS
  # coefficients, residual sum of squares and bias matrix.
  #
  # Arguments:
  #   predictors: should be Sigmaisqrt
  #   response:   should be Sigmaisqrt %*% (z - mu)
  #   weights:    per-variable weights; the predictor columns and the Gram
  #               matrix are divided by them before running LAR, and the
  #               outputs are rescaled back afterwards
  #   Sigmai:     inverse of Sigma, used to build the Gram matrix
  # Returns a list with
  #   beta:     lar_ols betaOLS with columns divided by weights
  #   RSS:      lar_ols RSS_sw (residual sums of squares along the path)
  #   biasmat:  lar_ols biasMat with every slice divided by weights on both sides
  #   ordering: the order in which the d cells enter the path
  d <- dim(predictors)[2]
  #
  # first we scale the predictors
  x       <- scale(predictors, FALSE, weights)
  y       <- response
  gramMat <- scale(t(scale(t(Sigmai), FALSE, weights)),
                   FALSE, weights)
  #
  # perform lar regression
  lar.out <- lar_ols(x = x, y = y, type = "lar", normalize = FALSE,
                     intercept = FALSE, positive = FALSE,
                     Gram = gramMat)
  path    <- unlist(lar.out$actions) # save the path
  if (dim(lar.out$beta)[1] == 1) {
    path <- c()
  }
  if (length(lar.out$ignores) > 0) { # this happens in some collinear situations
    # Collinear variables are recorded as negative actions: drop them and
    # append the ignored variables at the end, using the full OLS fit.
    # (path[-which(path < 0)] would empty path when nothing is negative.)
    path <- path[path > 0]
    missingVars <- which(!(1:d) %in% path)
    path <- c(path, missingVars)
    if (length(missingVars) > 0) {
      lar.out$betaOLS[(d + 2 - length(missingVars)):(d + 1), ] <- 
        matrix(mpinv(gramMat)$Inv %*% t(x) %*% y,
               nrow = length(missingVars),
               ncol = d, byrow = TRUE)
    }
  }
  missingVars <- which(!(1:d) %in% path)
  nMissing    <- length(missingVars)
  if (nMissing > 0) { 
    # happens when original x has components (almost) same as mu:
    # LAR stops early, so repeat the last fitted step for the missing cells
    path     <- c(path, missingVars)
    lastRow  <- dim(lar.out$betaOLS)[1] - nMissing
    padRows  <- (d + 2 - nMissing):(d + 1)
    lar.out$betaOLS[padRows, ] <- matrix(lar.out$betaOLS[lastRow, ],
                                         nrow = nMissing,
                                         ncol = d, byrow = TRUE)
    lar.out$RSS_sw <- c(lar.out$RSS_sw, rep(lar.out$RSS_sw[length(lar.out$RSS_sw)],
                                            nMissing))
    # copy the last d x d bias slice into every padded row; building one
    # array(slice, c(nMissing, d, d)) would interleave the entries wrongly
    # as soon as nMissing > 1
    lastBias <- lar.out$biasMat[lastRow, , ]
    for (r in padRows) {
      lar.out$biasMat[r, , ] <- lastBias
    }
  }
  # undo the weighting: coefficients are divided by the weights, bias slices
  # by the weights on both sides
  beta <- scale(lar.out$betaOLS, FALSE, weights)
  RSS  <- lar.out$RSS_sw
  bmat <- t(apply(lar.out$biasMat, c(1), function(y) t(y/(weights))/(weights)))
  dim(bmat) <- dim(lar.out$biasMat)
  return(list(beta = beta,
              RSS = RSS,
              biasmat = bmat,
              ordering = path))
}


mpinv <- function(X, tol = sqrt(.Machine$double.eps)) {
  ## Moore-Penrose generalized inverse of a matrix and its "square root".
  ## Based on MASS::ginv.
  ##
  ## Arguments:
  ##   X:   numeric matrix
  ##   tol: singular values <= tol * (largest singular value) are treated
  ##        as zero
  ## Returns a list with
  ##   Inv:     V D^-1 U' over the retained singular values
  ##   InvSqrt: V D^-1/2 U' over the same singular values (for a symmetric
  ##            positive semi-definite X this satisfies InvSqrt %*% InvSqrt == Inv)
  ## When no singular value is retained, both are zero matrices with the
  ## transposed dimensions of X.
  s    <- svd(X)
  keep <- s$d > tol * s$d[1]
  if (!any(keep)) {
    zero <- array(0, dim(as.matrix(X))[2:1])
    return(list(Inv = zero, InvSqrt = zero))
  }
  # drop = FALSE keeps these as matrices when only one value is retained
  V  <- s$v[, keep, drop = FALSE]
  Ut <- t(s$u[, keep, drop = FALSE])
  # dividing the k x n matrix Ut by a length-k vector scales its rows
  list(Inv     = V %*% (Ut / s$d[keep]),
       InvSqrt = V %*% (Ut / sqrt(s$d[keep])))
}


huberweights <- function(x, b) {
  # Huber-type weights: entries with |x| <= b get weight 1, entries with
  # |x| > b get |b / x|. NA entries keep weight 1.
  # Returns a plain numeric vector of length(x).
  outside <- abs(x) > b
  outside[is.na(outside)] <- FALSE
  w <- rep.int(1, length(x))
  w[outside] <- abs(b / x[outside])
  w
}


KLdiv <- function(T1,T2,d) {
  # Divergence between two covariance matrices:
  #   tr(T1 T2^+) - log(det(T1 T2^+)) - d,
  # where T2^+ is the Moore-Penrose inverse of T2 from mpinv().
  # DI() uses it as the convergence criterion between consecutive
  # covariance estimates.
  #
  # Arguments:
  #   T1, T2: square matrices of the same dimension
  #   d: the constant subtracted at the end (the number of variables)
  # Returns a single number; log() yields NaN (with a warning) if
  # det(T1 T2^+) is negative.
  #
  # The product is formed once, so T2 is decomposed only once per call.
  ratio <- T1 %*% mpinv(T2, tol = sqrt(.Machine$double.eps))$Inv
  result <- matrixcalc::matrix.trace(ratio) - log(det(ratio)) - d
  return(result)
}


# Least angle regression path, adapted from lars::lars (same arguments, same
# core loop, returns an object of class "lars"). In addition to the usual
# lars fields it records, after every step, the ordinary least squares fit
# on the current active set:
#   betaOLS: matrix truncated to m + 1 rows; row length(active) + 1 holds the
#            OLS coefficients on the active variables (other entries 0)
#   RSS_sw:  residual sums of squares of those OLS fits, starting from
#            sum(y^2); one entry is appended per element of each step's action
#   biasMat: array of chol2inv(R) slices (the inverse of R'R, R being the
#            Cholesky factor maintained via lars::updateR) on the active set
#   ignores: variables dropped for good (collinear or, with normalize = TRUE,
#            zero variance)
# NOTE(review): beta_OLS is created with dim(x)[1] columns but indexed by
# variable, so this assumes n == m (true in findCellPath, where x is square)
# — confirm before using it elsewhere.
lar_ols <- function(x, y, type = c("lasso", "lar", "forward.stagewise", 
                                   "stepwise"), trace = FALSE, normalize = FALSE, intercept = TRUE, 
                    Gram, eps = sqrt(.Machine$double.eps), max.steps, use.Gram = TRUE,
                    positive=FALSE){
  call <- match.call()
  type <- match.arg(type)
  TYPE <- switch(type, lasso = "LASSO", lar = "LAR", forward.stagewise = "Forward Stagewise", 
                 stepwise = "Forward Stepwise")
  if (trace) 
    cat(paste(TYPE, "sequence\n"))
  nm <- dim(x)
  n <- nm[1]
  m <- nm[2]
  im <- inactive <- seq(m)
  one <- rep(1, n)
  vn <- dimnames(x)[[2]]
  if (intercept) {
    meanx <- drop(one %*% x)/n
    x <- scale(x, meanx, FALSE)
    mu <- mean(y)
    y <- drop(y - mu)
  }
  else {
    meanx <- rep(0, m)
    mu <- 0
    y <- drop(y)
  }
  if (normalize) {
    normx <- sqrt(drop(one %*% (x^2)))
    nosignal <- normx/sqrt(n) < eps
    if (any(nosignal)) {
      ignores <- im[nosignal]
      inactive <- im[-ignores]
      normx[nosignal] <- eps * sqrt(n)
      if (trace) 
        cat("LARS Step 0 :\t", sum(nosignal), "Variables with Variance < eps; dropped for good\n")
    }
    else ignores <- NULL
    names(normx) <- NULL
    x <- scale(x, FALSE, normx)
  }
  else {
    normx <- rep(1, m)
    ignores <- NULL
  }
  if (use.Gram & missing(Gram)) {
    if (m > 500 && n < m) 
      cat("There are more than 500 variables and n<m;\nYou may wish to restart and set use.Gram=FALSE\n")
    if (trace) 
      cat("Computing X'X .....\n")
    Gram <- t(x) %*% x
  }
  Cvec <- drop(t(y) %*% x)
  ssy <- sum(y^2)
  residuals <- y
  if (missing(max.steps)) 
    max.steps <- 8 * min(m, n - intercept)
  beta <- matrix(0, max.steps + 1, m)
  lambda = double(max.steps)
  Gamrat <- NULL
  arc.length <- NULL
  R2 <- 1
  RSS <- ssy
  first.in <- integer(m)
  active <- NULL
  actions <- as.list(seq(max.steps))
  drops <- FALSE
  Sign <- NULL
  R <- NULL
  k <- 0
  # Extra bookkeeping (not in lars::lars): OLS coefficients, their residual
  # sums of squares and the inverse active-set Gram matrices along the path.
  beta_OLS    <- matrix(0, nrow = max.steps + 1, ncol = dim(x)[1]) # matrix of coefficients
  RSS_sw      <- sum(residuals ^ 2) # initial MD^2
  biasMat     <- array(0, c(max.steps + 1, dim(x)[2], dim(x)[2]))
  
  while ((k < max.steps) & (length(active) < min(m - length(ignores), 
                                                 n - intercept))) {
    action <- NULL
    C <- Cvec[inactive]
    
    if (positive) {
      Cmax <- max(C)
    } else {
      Cmax <- max(abs(C))
    }
    
    # all remaining correlations are (numerically) zero: stop early
    if (Cmax < eps^2 * 100) {
      if (trace) 
        cat("Max |corr| = 0; exiting...\n")
      break
    }
    k <- k + 1
    lambda[k] = Cmax
    if (!any(drops)) {
      # every inactive variable within eps of the maximal correlation enters
      if (positive) {
        new <- C >= Cmax - eps
      } else {
        new <- abs(C) >= Cmax - eps
      }
      
      C <- C[!new]
      new <- inactive[new]
      for (inew in new) {
        if (use.Gram) {
          R <- lars::updateR(Gram[inew, inew], R, drop(Gram[inew, 
                                                            active]), Gram = TRUE, eps = eps)
        } else {
          R <- lars::updateR(x[, inew], R, x[, active], Gram = FALSE, 
                             eps = eps)
        }
        # rank did not increase: inew is collinear with the active set and is
        # recorded as a negative action
        if (attr(R, "rank") == length(active)) {
          nR <- seq(length(active))
          R <- R[nR, nR, drop = FALSE]
          attr(R, "rank") <- length(active)
          ignores <- c(ignores, inew)
          action <- c(action, -inew)
          if (trace) 
            cat("LARS Step", k, ":\t Variable", inew, 
                "\tcollinear; dropped for good\n")
        } else {
          if (first.in[inew] == 0) {
            first.in[inew] <- k
          }
          active <- c(active, inew)
          
          if (positive) {
            Sign <- c(Sign, 1)
          } else {
            Sign <- c(Sign, sign(Cvec[inew]))
          }
          
          action <- c(action, inew)
          if (trace){
            cat("LARS Step", k, ":\t Variable", inew, 
                "\tadded\n")
          }
        }
      }
    }
    else {
      action <- -dropid
    }
    Gi1 <- backsolve(R, lars::backsolvet(R, Sign))
    
    # OLS refit on the new active set (the added bookkeeping)
    newModelSize <- length(active)
    # NOTE(review): length(action) also counts negative (collinear) entries,
    # so on such steps this is smaller than the previous active-set size —
    # confirm this is intended.
    oldModelSize <- length(active) - length(action)
    if (length(action) > 1) {
      # several variables entered at once: intermediate rows repeat the
      # previous OLS fit, the last row gets the new one
      oldcoefs <- rep(0, m)
      oldcoefs[active]  <- beta_OLS[oldModelSize + 1, active]
      newcoefs <- rep(0, m)
      newcoefs[active] <- backsolve(R, backsolve(R, t(x[, active]),transpose = TRUE)) %*% y
      
      
      beta_OLS[(oldModelSize + 2):(newModelSize),] <-
        matrix(oldcoefs, length(action) - 1, m, byrow = TRUE)
      beta_OLS[newModelSize + 1, ] <-  newcoefs
      residuals <- as.vector(y -  beta_OLS[newModelSize + 1, active] %*% t(x[, active]))
      RSS_sw <- c(RSS_sw, rep(RSS_sw[length(RSS_sw)],
                              length(action) - 1), sum(residuals ^ 2))
      rho_sw <- drop(t(residuals) %*% x)
    } else {
      newcoefs <- rep(0, m)
      newcoefs[active] <- backsolve(R, backsolve(R, t(x[, active]),transpose = TRUE)) %*% y
      
      beta_OLS[newModelSize + 1, ] <- newcoefs
      residuals <- as.vector(y -  beta_OLS[newModelSize + 1, active] %*% t(x[, active]))
      RSS_sw <- c(RSS_sw, sum(residuals ^ 2))
      rho_sw <- drop(t(residuals) %*% x)
    }
    
    dropouts <- NULL
    if (type == "forward.stagewise") {
      directions <- Gi1 * Sign
      if (!all(directions > 0)) {
        if (use.Gram) {
          nnls.object <- lars::nnls.lars(active, Sign, R, directions, 
                                         Gram[active, active], trace = trace, use.Gram = TRUE, 
                                         eps = eps)
        }  else {
          nnls.object <- lars::nnls.lars(active, Sign, R, directions, 
                                         x[, active], trace = trace, use.Gram = FALSE, 
                                         eps = eps)
        }
        positive <- nnls.object$positive
        dropouts <- active[-positive]
        action <- c(action, -dropouts)
        active <- nnls.object$active
        Sign <- Sign[positive]
        Gi1 <- nnls.object$beta[positive] * Sign
        R <- nnls.object$R
        C <- Cvec[-c(active, ignores)]
      }
    }
    A <- 1/sqrt(sum(Gi1 * Sign))
    w <- A * Gi1
    if (!use.Gram) 
      u <- drop(x[, active, drop = FALSE] %*% w)
    if ((length(active) >= min(n - intercept, m - length(ignores))) | 
        type == "stepwise") {
      gamhat <- Cmax/A
    }  else {
      if (use.Gram) {
        a <- drop(w %*% Gram[active, -c(active, ignores), 
                             drop = FALSE])
      }  else {
        a <- drop(u %*% x[, -c(active, ignores), drop = FALSE])
      }
      
      if (positive) {
        gam <- c((Cmax - C)/(A - a))
      } else {
        gam <- c((Cmax - C)/(A - a), (Cmax + C)/(A + a))
      }
      
      gamhat <- min(gam[gam > eps], Cmax/A)
    }
    if (type == "lasso") {
      dropid <- NULL
      b1 <- beta[k, active]
      z1 <- -b1/w
      zmin <- min(z1[z1 > eps], gamhat)
      if (zmin < gamhat) {
        gamhat <- zmin
        drops <- z1 == zmin
      } else {
        drops <- FALSE
      }
    }
    beta[k + 1, ] <- beta[k, ]
    beta[k + 1, active] <- beta[k + 1, active] + gamhat * 
      w
    if (use.Gram) {
      Cvec <- Cvec - gamhat * Gram[, active, drop = FALSE] %*% 
        w
    }  else {
      residuals <- residuals - gamhat * u
      Cvec <- drop(t(residuals) %*% x)
    }
    Gamrat <- c(Gamrat, gamhat/(Cmax/A))
    arc.length <- c(arc.length, gamhat)
    if (type == "lasso" && any(drops)) {
      dropid <- seq(drops)[drops]
      for (id in rev(dropid)) {
        if (trace) 
          cat("Lasso Step", k + 1, ":\t Variable", active[id], 
              "\tdropped\n")
        R <- lars::downdateR(R, id)
      }
      dropid <- active[drops]
      beta[k + 1, dropid] <- 0
      active <- active[!drops]
      Sign <- Sign[!drops]
    }
    if (!is.null(vn)) 
      names(action) <- vn[abs(action)]
    actions[[k]] <- action
    inactive <- im[-c(active, ignores)]
    if (type == "stepwise") 
      Sign = Sign * 0
    # NOTE(review): k + length(action) equals length(active) + 1 (the
    # beta_OLS row filled above) only when every earlier step added exactly
    # one variable; after tied or collinear steps the two row indices differ
    # — confirm intended.
    biasMat[k + length(action), active, active] <-
      chol2inv(R) # inverse of R'R, i.e. of the active-set Gram matrix
    
  }
  beta <- beta[seq(k + 1), , drop = FALSE]
  # keep m + 1 rows (path lengths 0..m); assumes max.steps >= m
  beta_OLS <- beta_OLS[seq(m + 1), , drop = FALSE]
  biasMat <- biasMat[seq(m + 1), , , drop = FALSE]
  
  lambda = lambda[seq(k)]
  dimnames(beta) <- list(paste(0:k), vn)
  if (trace) 
    cat("Computing residuals, RSS etc .....\n")
  residuals <- y - x %*% t(beta)
  beta <- scale(beta, FALSE, normx)
  RSS <- apply(residuals^2, 2, sum)
  R2 <- 1 - RSS/RSS[1]
  actions = actions[seq(k)]
  netdf = sapply(actions, function(x) sum(sign(x)))
  df = cumsum(netdf)
  if (intercept) 
    df = c(Intercept = 1, df + 1)
  else df = c(Null = 0, df)
  rss.big = rev(RSS)[1]
  df.big = n - rev(df)[1]
  if (rss.big < eps | df.big < eps) 
    sigma2 = NaN
  else sigma2 = rss.big/df.big
  Cp <- RSS/sigma2 - n + 2 * df
  attr(Cp, "sigma2") = sigma2
  attr(Cp, "n") = n
  object <- list(call = call, type = TYPE, df = df, lambda = lambda, 
                 R2 = R2, RSS = RSS, Cp = Cp, actions = actions[seq(k)], 
                 entry = first.in, Gamrat = Gamrat, arc.length = arc.length, 
                 Gram = if (use.Gram) Gram else NULL, beta = beta, mu = mu, 
                 normx = normx, meanx = meanx, RSS_sw = RSS_sw, betaOLS = beta_OLS,
                 biasMat = biasMat, ignores = ignores)
  class(object) <- "lars"
  object
}


## Initial estimators:

DDCW <- function(X, maxCol = 0.25) {
  # Initial estimator: DDC (with a cap on flagged cells per column) followed
  # by a wrapping-based covariance estimate and a rowwise rejection step.
  #
  # Arguments:
  #   X: data matrix
  #   maxCol: maximal fraction of cells DDC may flag in any column
  # Returns a list with
  #   locScale: list(loc, scale) from DDC, used to standardize X
  #   mu: zero vector (center of the standardized data)
  #   cov: covariance estimate of the standardized data
  #   Z: standardized data with the rows in indrows removed
  #   indrows: rows flagged by the RR step (removed from Z)
  
  # Functions only used here:
  
  DDC_controlled <- function(X, tolProbCell, maxCol = 0.25) {
    # Executes DDC with a given tolProbCell, which ensures that
    # no more than floor(maxCol * n) cells are flagged in any variable
    n <- dim(X)[1]
    d <- dim(X)[2]
    DDCout <- cellWise::DDC(X, list(fastDDC = FALSE, silent = TRUE,
                                    tolProbCell = tolProbCell, standType = "wrap"))
    Wna <- matrix(0, n, d)
    Wna[DDCout$indcells] <- 1
    maxFlags <- floor(maxCol * n)
    overflag <- which(colSums(Wna) > maxCol * n)
    if (length(overflag) > 0) {
      for (ind in overflag) {
        # keep only the maxFlags cells with the largest |stdResid|;
        # seq_len() keeps none when maxFlags == 0 (1:0 would keep one)
        top <- order(abs(DDCout$stdResid[, ind]),
                     decreasing = TRUE)[seq_len(maxFlags)]
        replacement <- rep(0, n)
        replacement[top] <- 1
        Wna[, ind] <- replacement
      }
      DDCout$indcells <- which(Wna == 1)
      DDCout$Ximp <- X
      DDCout$Ximp[DDCout$indcells] <- DDCout$Xest[DDCout$indcells]
    }
    return(DDCout)
  }
  
  wrapCov <- function(Zimp) {
    # Orthogonalize with the eigenvectors of the classical covariance,
    # wrap every projected coordinate with its robust location/scale, and
    # rotate the covariance of the wrapped data back.
    eigenvectors  <- eigen(cov(Zimp), symmetric = TRUE)$vectors
    Zimp_proj     <- Zimp %*% eigenvectors
    locscale_proj <- cellWise::estLocScale(Zimp_proj)
    Zimp_proj_w   <- cellWise::wrap(Zimp_proj, locscale_proj$loc,
                                    locscale_proj$scale)$Xw
    eigenvectors %*% cov(Zimp_proj_w) %*% t(eigenvectors)
  }
  
  iDDC9I.O.Wrap <- function(X, maxCol = 0.25) {
    DDCout   <- DDC_controlled(X, tolProbCell = 0.9, maxCol = maxCol)
    locScale <- list(loc = DDCout$locX, scale = DDCout$scaleX)
    Zorig    <- scale(X, locScale$loc, locScale$scale)
    Zimporig <- scale(DDCout$Ximp, locScale$loc, locScale$scale)
    Znaorig  <- Zorig
    Znaorig[DDCout$indcells] <- NA
    
    # rows flagged by DDC do not take part in the covariance estimate
    Z    <- Zorig
    Zimp <- Zimporig
    if (length(DDCout$indrows) > 0) {
      Z    <- Z[-DDCout$indrows, ]
      Zimp <- Zimp[-DDCout$indrows, ]
    }
    
    return(list(locScale = locScale,
                mu = rep(0, dim(X)[2]),
                cov = wrapCov(Zimp),
                Z = Z,
                Zorig = Zorig,
                Zimporig = Zimporig,
                Znaorig = Znaorig,
                indrows = DDCout$indrows))
  }
  
  RR <- function(Z, cov, b = 2, quant = 0.99) {
    # rows whose (clipped, median-rescaled) Mahalanobis distance exceeds
    # the chi-square quantile
    d <- dim(Z)[2]
    MDs <- mahalanobis(pmin(pmax(Z, -b), b), rep(0, d), cov)
    rowinds <- which(MDs / median(MDs) * qchisq(0.5, d) > qchisq(quant, d))
    return(rowinds)
  }
  
  iDDC9I.O.Wrap.RR <- function(X, maxCol = 0.25) {
    result  <- iDDC9I.O.Wrap(X, maxCol = maxCol)
    covInit <- result$cov
    Z       <- result$Zorig
    Zimp    <- result$Zimporig
    
    rowinds <- RR(Z, covInit, 2, 0.99)
    if (length(rowinds) > 0) {
      Z    <- Z[-rowinds, ]
      Zimp <- Zimp[-rowinds, ]
    }
    
    return(list(locScale = result$locScale,
                mu = rep(0, dim(X)[2]),
                cov = covInit,
                Z = Z,
                Zimp = Zimp,
                Zna = result$Znaorig,
                indrows = rowinds))
  }
  
  result <- iDDC9I.O.Wrap.RR(X, maxCol = maxCol)
  # final estimate: same wrapping procedure on the imputed data without the
  # rows rejected by RR
  return(list(locScale = result$locScale,
              mu = rep(0, dim(X)[2]),
              cov = wrapCov(result$Zimp),
              Z = result$Z, 
              indrows = result$indrows))
}


TwoSGS <- function(X) {
  # Initial estimator based on GSE::TSGS, returning the same structure as
  # DDCW (without indrows):
  #   locScale: list(loc = TSGS center, scale = square roots of the TSGS variances)
  #   mu: zero vector (center of the standardized data)
  #   cov: the TSGS scatter matrix converted to a correlation matrix
  #   Z: X standardized with locScale
  fit    <- GSE::TSGS(X)
  center <- fit@mu
  spread <- sqrt(diag(fit@S))
  list(
    locScale = list(loc = center, scale = spread),
    mu       = rep(0, dim(X)[2]),
    cov      = cov2cor(fit@S),
    Z        = scale(X, center, spread)
  )
}


## The Detection Imputation (DI) method:


DI <- function(X,
               initEst = DDCW,
               crit = 0.01,
               maxits = 50,
               quant = 0.99,
               implosionguard = "none",
               maxCol = 0.25,
               deltaCrit = "maxDelta") {
  # Detection-Imputation (DI): computes a covariance matrix on data with
  # possibly both cellwise and casewise outliers. It alternates between
  # flagging/imputing cells along per-row LAR paths and re-estimating the
  # center and covariance (plus a bias correction) from the imputed data.
  # 
  # Arguments:
  #   X: data matrix
  #   initEst: the initial estimator, either a function of X returning a
  #            list with locScale, mu, cov, Z (and optionally indrows), like
  #            DDCW or TwoSGS, or a list with elements mu and Sigma
  #   crit: criterion for convergence; iterating stops once KLdiv between
  #         consecutive covariance estimates is <= crit
  #   maxits: maximum number of iterations
  #   quant: quantile used for flagging cells
  #   implosionguard: "cor" (rescale to a correlation matrix), "trace"
  #                   (rescale to trace d) or "none"
  #   maxCol: maximal fraction of imputed cells per variable
  #   deltaCrit: "maxDelta" or "minDelta"
  # Returns a list with
  #   center: final estimate of the center
  #   cov: final estimate of the covariance matrix
  #   center_init: initial estimate of the center
  #   cov_init: initial estimate of the covariance matrix
  #   allSigmas: 3-dim array with the covariance estimates of every
  #              iteration (first slice = initial estimate)
  #   Ximp: imputed data (rows in the initial estimator's indrows are set to
  #         center)
  #   nbimps: final number of imputed cells per row of the standardized data
  #           used in the iterations
  #   W: matrix indicating which cells were imputed
  #
  # DI uses its own version of cellFlagger since it needs to take 
  # into account maxCol, bias corrections, etc.
  
  mS_cov <- function(X, distances, nbimps) {
    return(list(mu = colMeans(X), cov = cov(X)))
  }
  
  if (!deltaCrit %in% c("maxDelta", "minDelta")) {
    stop("invalid deltaCrit argument")
  }
  
  # Step 1a: initial estimate & standardization
  if (is.list(initEst)) { # initEst is list with mu and Sigma
    locScale_init <- list(loc = initEst$mu, scale = sqrt(diag(initEst$Sigma)))
    mu_init  <- rep(0, dim(X)[2])
    cov_init <- cov2cor(initEst$Sigma)
    Z <- scale(X, initEst$mu, sqrt(diag(initEst$Sigma)))
    out_init <- list()
  } else {
    out_init      <- initEst(X)
    locScale_init <- out_init$locScale # locScale of X
    mu_init       <- out_init$mu # initial location of Z (should be zeroes)
    cov_init      <- out_init$cov # initial cov of Z
    Z             <- out_init$Z
  }
  
  # Step 1b: initialization
  nbits    <- 0
  convcrit <- 1
  n        <- dim(Z)[1]
  d        <- dim(Z)[2]
  M        <- floor(maxCol * n) # max number of imputed cells per variable
  mu       <- mu_init
  Sigma    <- cov_init
  invOut   <- mpinv(cov_init)
  Sigmai   <- invOut$Inv
  Sigmaisqrt <- invOut$InvSqrt
  Zimp <- Z
  # Defaults for the outputs, so that they exist even when no iteration is
  # run (maxits = 0 or crit >= 1).
  finalW      <- matrix(0, n, d)
  finalNbimps <- rep(0, n)
  
  # Step 1c: containers for simulation
  Sigmas        <- array(0, dim = c(maxits + 1, d, d))
  Sigmas[1, , ] <- cov_init
  
  # Step 2: iteration step
  while ((nbits < maxits) && (convcrit > crit)) {
    # Step 2a: flag 
    
    # 2a Stage 1: one LAR path per row
    predictors <- Sigmaisqrt
    betamat    <- array(0, c(n, d + 1, d))
    Bmat       <- array(0, c(n, d + 1, d, d)) # matrix containing bias terms
    orderings  <- matrix(0, n, d)
    distances  <- matrix(0, n, d + 1)
    pvals      <- matrix(1, n, d)
    for (i in seq_len(n)) {
      z        <- Z[i, ] 
      response <- Sigmaisqrt %*% (z - mu)
      weights  <- huberweights(x = z - mu, b = 1.5)
      
      larOut   <- findCellPath(predictors = predictors,
                               response = response,
                               weights = weights,
                               Sigmai = Sigmai)
      diffRSS  <- abs(diff(larOut$RSS))
      
      pvals[i, larOut$ordering] <- 1 - pchisq(diffRSS, 1)
      # make the p-values monotone along the path
      if (deltaCrit == "maxDelta") {# maxdelta rule
        pvals[i, larOut$ordering] <- rev(cummin(rev(pvals[i, larOut$ordering])))
      } else {# minDelta rule
        pvals[i, larOut$ordering] <- cummax(pvals[i, larOut$ordering])
      } 
      
      betamat[i, , ] <- larOut$beta
      Bmat[i, , , ]  <- larOut$biasmat
      distances[i, ] <- larOut$RSS
      orderings[i, ] <- larOut$ordering
    }
    
    # 2a stage 2: sort and iterate through p-values
    # to determine the actual flagged cells
    # Goal of this stage is to take maxCol into account
    
    # ties in the p-values are broken by row, then by position in the path
    tiebraker <- t(apply(orderings, 1, function(y) order(y))) + 1:n * d
    pvals_order <- order(pvals, tiebraker, # order of pvalues in increasing order
                         decreasing = FALSE) # second argument of order() solves ties
    
    NBimps_col   <- rep(0, d)
    cutpoints    <- rep(1, n) # where to stop the paths (1 = no imputes)
    droppedPaths <- rep(0, n) # which paths are dropped (="locked")
    for (idx in pvals_order) {
      pval  <- pvals[idx]
      rownb <- (idx - 1) %% n + 1
      if (pval < 1 - quant) {
        if (!droppedPaths[rownb]) {# check if case still has an open path
          colnb <- ((idx - 1) %/% n) + 1
          if (NBimps_col[colnb] < M) {
            cutpoints[rownb]  <- cutpoints[rownb] + 1
            NBimps_col[colnb] <- NBimps_col[colnb] + 1
          } else {
            # column already holds M imputations: lock this row's path
            droppedPaths[rownb] <- 1
          }
        }
      } else {
        droppedPaths[rownb] <- 1
      }
    }
    
    #  Step 2B: impute cells
    finalBetas     <- matrix(0, n, d)
    finalBias      <- matrix(0, d, d)
    finalDistances <- rep(0, n)
    finalNbimps    <- cutpoints - 1
    finalW         <- matrix(0, n, d)
    for (i in seq_len(n)) {
      finalBetas[i, ]   <- betamat[i, cutpoints[i], ]
      finalBias         <- finalBias + Bmat[i, cutpoints[i], , ]
      finalDistances[i] <- distances[i, cutpoints[i]]
      finalW[i, ]       <- (abs(betamat[i, cutpoints[i], ]) > 1e-10) + 0
    }
    
    Zimp <- Z - finalBetas
    
    # Step 2C: re-estimate the covariance matrix
    muSigmaNew <- mS_cov(Zimp, finalDistances, finalNbimps) 
    mu         <- muSigmaNew$mu
    Sigma      <- muSigmaNew$cov + finalBias / n # add bias matrix
    
    if (implosionguard == "cor") {
      Sigma    <- cov2cor(Sigma)
    } else if (implosionguard == "trace") {
      Sigma    <- Sigma / sum(diag(Sigma)) * d
    }
    
    # Step 2D: bookkeeping and setting up for next iteration
    Sigmas[nbits + 2, , ] <- Sigma
    invOut     <- mpinv(Sigma)
    Sigmai     <- invOut$Inv
    Sigmaisqrt <- invOut$InvSqrt
    
    convcrit <- KLdiv(Sigmas[nbits + 2, , ],
                      Sigmas[nbits + 1, , ], d)
    nbits <- nbits + 1
  } # end of Step2: iteration
  
  # unstandardize and clean:
  # drop = FALSE keeps the 3-dim layout when no iteration was run
  Sigmas <- Sigmas[seq_len(nbits + 1), , , drop = FALSE]
  
  Ximp   <- scale(Zimp, FALSE, 1 / locScale_init$scale)
  Ximp   <- scale(Ximp, -locScale_init$loc, FALSE)
  
  center_out  <- locScale_init$loc + mu * locScale_init$scale 
  cov_out     <- t(t(Sigma) * locScale_init$scale) * locScale_init$scale
  icenter_out <- locScale_init$loc
  icov_out    <- t(t(cov_init) * locScale_init$scale) * locScale_init$scale
  
  Sigmas      <- apply(Sigmas, 1,
                       function(y) t(t(y) * locScale_init$scale) * locScale_init$scale)
  dim(Sigmas) <- c(d, d, nbits + 1)
  Sigmas      <- aperm(Sigmas, perm = c(3, 1, 2))
  
  # rows removed by the initial estimator are imputed by the center and
  # fully flagged
  if (length(out_init$indrows) > 0) {
    Ximp2 <- X
    Ximp2[-out_init$indrows, ] <- Ximp
    Ximp2[out_init$indrows, ] <- matrix(center_out, length(out_init$indrows), d, byrow = TRUE)
    Ximp <- Ximp2
    finalW2 <- matrix(1, dim(X)[1], dim(X)[2])
    finalW2[-out_init$indrows, ] <- finalW
    finalW <- finalW2
  }
  
  return(list(center = center_out,
              cov = cov_out,
              center_init = icenter_out,
              cov_init = icov_out,
              allSigmas = Sigmas,
              Ximp = Ximp,
              nbimps = finalNbimps,
              W = finalW))
}


## Modified version of cellMap without circles to the right:
############################################################

library(cellWise)
library(ggplot2)
library(reshape2)
library(gridExtra)
library(scales)

# cellWise::cellMap
cellMap2 = function (D, R, indcells = NULL, indrows = NULL, standOD = NULL, 
          showVals = NULL, rowlabels = "", columnlabels = "", mTitle = "", 
          rowtitle = "", columntitle = "", showrows = NULL, showcolumns = NULL, 
          nrowsinblock = 1, ncolumnsinblock = 1, autolabel = TRUE, 
          columnangle = 90, sizetitles = 1.1, adjustrowlabels = 1, 
          adjustcolumnlabels = 1, colContrast = 1, outlyingGrad = TRUE, 
          darkestColor = sqrt(qchisq(0.999, 1)), drawCircles = TRUE) 
{
  funcSqueeze = function(Xin, n, d, ncolumnsinblock, nrowsinblock, 
                         colContrast) {
    Xblock = matrix(0, nrow = n, ncol = d)
    Xblockgrad = matrix(0, nrow = n, ncol = d)
    for (i in 1:n) {
      for (j in 1:d) {
        Xsel = Xin[(1 + ((i - 1) * nrowsinblock)):(i * 
                                                     nrowsinblock), (1 + ((j - 1) * ncolumnsinblock)):(j * 
                                                                                                         ncolumnsinblock)]
        seltable = tabulate(Xsel, nbins = 4)
        if (sum(seltable) > 0) {
          indmax = which(seltable == max(seltable))[1]
          cntmax = seltable[indmax]
          gradmax = (cntmax/(ncolumnsinblock * nrowsinblock))^(1/colContrast)
        }
        else {
          indmax = 0
          gradmax = 1
        }
        Xblock[i, j] = indmax
        Xblockgrad[i, j] = gradmax
      }
    }
    return(list(X = Xblock, Xgrad = Xblockgrad))
  }
  variable <- rownr <- rescaleoffset <- x <- y <- NULL
  type = "cell"
  n = nrow(R)
  d = ncol(R)
  blockMap = FALSE
  if (ncolumnsinblock > 1 | nrowsinblock > 1) {
    blockMap = TRUE
    if (ncolumnsinblock > d) 
      stop("Input argument ncolumnsinblock cannot be larger than d")
    if (nrowsinblock > n) 
      stop("Input argument nrowsinblock cannot be larger than n")
    if (!is.null(showVals)) 
      warning("The option showVals=D or showVals=R cannot\n                                    be combined with ncolumnsinblock or nrowsinblock greater than 1,\n                                    so showVals is set to NULL here.")
    showVals = NULL
  }
  if (!blockMap) {
    if (!all(dim(R) == dim(D))) 
      stop("Dimensions of D and R must match")
  }
  if (!(blockMap & autolabel == FALSE)) {
    if (length(columnlabels) > 0 & length(columnlabels) != 
        d) {
      stop(paste("Number of columnlabels does not match d = ", 
                 d, sep = ""))
    }
    if (length(rowlabels) > 0 & length(rowlabels) != n) {
      stop(paste("Number of rowlabels does not match n = ", 
                 n, sep = ""))
    }
  }
  if (!is.null(showVals)) {
    if (!showVals %in% c("D", "R")) {
      stop(paste("Invalid \"showVals\" argument. Should be one of: NULL, \"D\", \"R\""))
    }
  }
  if (is.null(indcells)) 
    indcells = which(abs(R) > sqrt(qchisq(0.99, 1)))
  if (!(is.null(showcolumns) & is.null(showrows))) {
    if (is.null(showcolumns)) {
      showcolumns = 1:d
    }
    else {
      if (!(all(showcolumns %in% 1:d))) 
        stop(" showcolumns goes out of bounds")
    }
    if (is.null(showrows)) {
      showrows = 1:n
    }
    else {
      if (!(all(showrows %in% 1:n))) 
        stop(" showrows goes out of bounds")
    }
    tempMat = matrix(0, n, d)
    tempMat[indcells] = 1
    tempMat = tempMat[showrows, showcolumns]
    indcells = which(tempMat == 1)
    tempVec = rep(0, n)
    tempVec[indrows] = 1
    tempVec = tempVec[showrows]
    indrows = which(tempVec == 1)
    rm(tempMat, tempVec)
    if (!blockMap) 
      D = D[showrows, showcolumns]
    R = R[showrows, showcolumns]
    if (!(blockMap & autolabel == FALSE)) 
      columnlabels = columnlabels[showcolumns]
    if (!(blockMap & autolabel == FALSE)) 
      rowlabels = rowlabels[showrows]
    n = nrow(R)
    d = ncol(R)
    if (!is.null(standOD)) 
      standOD = standOD[showrows]
  }
  if (type == "residual") 
    outlyingGrad = 1
  X = matrix(0, n, d)
  Xrow = matrix(0, n, 1)
  Xrow[indrows, 1] = 3
  if (type == "cell" | blockMap) {
    pcells = indcells[indcells %in% which(R >= 0)]
    ncells = indcells[indcells %in% which(R < 0)]
  }
  else {
    pcells = which(R >= 0)
    ncells = which(R < 0)
  }
  X[ncells] = 1
  X[pcells] = 2
  X[is.na(R)] = 4
  if (blockMap) {
    n = floor(n/nrowsinblock)
    d = floor(d/ncolumnsinblock)
    result = funcSqueeze(X, n, d, ncolumnsinblock, nrowsinblock, 
                         colContrast)
    X = result$X
    Xgrad = result$Xgrad
    result = funcSqueeze(Xrow, n, 1, 1, nrowsinblock, colContrast)
    Xrowgrad = result$Xgrad
    Xrowgrad[result$X == 0] = 0
    if (autolabel == TRUE) {
      if (ncolumnsinblock > 1 & length(columnlabels) > 
          0) {
        labx = columnlabels
        columnlabels = rep(0, d)
        for (ind in 1:d) {
          columnlabels[ind] = paste(labx[(1 + ((ind - 
                                                  1) * ncolumnsinblock))], "-", labx[(ind * 
                                                                                        ncolumnsinblock)], sep = "")
        }
      }
      if (nrowsinblock > 1 & length(rowlabels) > 0) {
        laby = rowlabels
        rowlabels = rep(0, n)
        for (ind in 1:n) {
          rowlabels[ind] = paste(laby[(1 + ((ind - 1) * 
                                              nrowsinblock))], "-", laby[(ind * nrowsinblock)])
        }
      }
    }
    else {
      if (length(columnlabels) > 0 & length(columnlabels) != 
          d) {
        stop(paste(" autolabel=FALSE and number of columnlabels is ", 
                   length(columnlabels), " but should be ", d, 
                   sep = ""))
      }
      if (length(rowlabels) > 0 & length(rowlabels) != 
          n) {
        stop(paste(" autolabel=FALSE and number of rowlabels is ", 
                   length(rowlabels), " but should be ", n, sep = ""))
      }
    }
    Xdf = data.frame(cbind(seq(1, n, 1), X))
    colnames(Xdf) = c("rownr", seq(1, d, 1))
    rownames(Xdf) = NULL
    Xdf$rownr = with(Xdf, reorder(rownr, seq(n, 1, -1)))
    mX = melt(Xdf, id.var = "rownr", value.name = "CatNr")
    Xgraddf = data.frame(cbind(seq(1, n, 1), Xgrad))
    colnames(Xgraddf) = c("rownr", seq(1, d, 1))
    rownames(Xgraddf) = NULL
    Xgraddf$rownr = with(Xgraddf, reorder(rownr, seq(n, 1, 
                                                     -1)))
    mXgrad = melt(Xgraddf, id.var = "rownr", value.name = "grad")
    mX$grad = mXgrad$grad
    mX$rescaleoffset = mXgrad$grad + 10 * mX$CatNr
    mXrow = data.frame(rownr = 1:n, rescaleoffset = Xrowgrad + 
                         10 * 3)
    scalerange = c(0, 1)
    gradientends = scalerange + rep(c(0, 10, 20, 30, 40), 
                                    each = 2)
    if (type == "cell") 
      colorends = c("yellow", "yellow", "yellow", "blue", 
                    "yellow", "red", "white", "black", "yellow", 
                    "white")
    if (type == "residual") 
      colorends = c("white", "white", "white", "blue", 
                    "white", "red", "white", "black", "white", "white")
  }
  else {
    Ddf = data.frame(cbind(seq(1, n, 1), D))
    colnames(Ddf) = c("rownr", seq(1, d, 1))
    rownames(Ddf) = NULL
    Ddf$rownr = with(Ddf, reorder(rownr, seq(n, 1, -1)))
    mD = melt(Ddf, id.var = "rownr")
    Rdf = data.frame(cbind(seq(1, n, 1), R))
    colnames(Rdf) = c("rownr", seq(1, d, 1))
    rownames(Rdf) = NULL
    Rdf$rownr = with(Rdf, reorder(rownr, seq(n, 1, -1)))
    mR = melt(Rdf, id.var = "rownr")
    Xdf = data.frame(cbind(seq(1, n, 1), X))
    colnames(Xdf) = c("rownr", seq(1, d, 1))
    rownames(Xdf) = NULL
    Xdf$rownr = with(Xdf, reorder(rownr, seq(n, 1, -1)))
    mX = melt(Xdf, id.var = "rownr", value.name = "CatNr")
    if (!is.null(showVals)) {
      if (showVals == "D") 
        mX$data = mD$value
      if (showVals == "R") 
        mX$data = mR$value
    }
    if (!outlyingGrad) {
      mX$rescaleoffset = 10 * mX$CatNr
      scalerange = c(0, 1)
      gradientends = scalerange + rep(c(0, 10, 20, 30, 
                                        40), each = 2)
      gradientends
      colorends = c("yellow", "yellow", "blue", "blue", 
                    "red", "red", "white", "black", "white", "white")
    }
    else {
      Xgrad = matrix(NA, n, d)
      if (type == "cell") {
        Xgrad[indcells] = abs(R[indcells])
        limL = sqrt(qchisq(0.9, 1))
      }
      else {
        Xgrad = abs(R)
        limL = 0
      }
      limH = darkestColor
      Xgrad[Xgrad > limH] = limH
      Xgrad = ((Xgrad - limL)/(limH - limL))^colContrast
      Xgrad[is.na(Xgrad)] = 0
      Xgraddf = data.frame(cbind(seq(1, n, 1), Xgrad))
      colnames(Xgraddf) = c("rownr", seq(1, d, 1))
      rownames(Xgraddf) = NULL
      Xgraddf$rownr = with(Xgraddf, reorder(rownr, seq(n, 
                                                       1, -1)))
      mXgrad = melt(Xgraddf, id.var = "rownr", value.name = "grad")
      mX$grad = mXgrad$grad
      mX$rescaleoffset = mXgrad$grad + 10 * mX$CatNr
      scalerange = c(0, 1)
      gradientends = scalerange + rep(c(0, 10, 20, 30, 
                                        40), each = 2)
      if (type == "cell") 
        colorends = c("yellow", "yellow", "yellow", "blue", 
                      "yellow", "red", "white", "black", "white", 
                      "white")
      if (type == "residual") 
        colorends = c("white", "white", "white", "blue", 
                      "white", "red", "white", "black", "white", 
                      "white")
    }
    tempVec = rep(0, n)
    tempVec[indrows] = 1
    mXrow = data.frame(rownr = 1:n, rescaleoffset = 40 - 
                         (10 * tempVec))
    rm(tempVec)
    if (is.null(standOD)) {
      mXrow$rescaleoffset[indrows] = mXrow$rescaleoffset[indrows] + 
        1
    }
    else {
      limL = 1
      limH = 3
      standOD[standOD > limH] = limH
      standOD = ((standOD - limL)/(limH - limL))^colContrast
      mXrow$rescaleoffset[indrows] = mXrow$rescaleoffset[indrows] + 
        standOD[indrows]
    }
  }
  rowlabels = rev(rowlabels)
  base_size = 10
  columnlabels = c(columnlabels, "", "")
  circleFun <- function(centerx, centery, r, npoints) {
    tt <- seq(0, 2 * pi, length.out = npoints)
    xx <- centerx + r * cos(tt)
    yy <- centery + r * sin(tt)
    return(c(xx, yy))
  }
  if (drawCircles){
    centerx = d + 1
    centery = n:1
    radius = 0.4
    npoints = 100
    circlePoints = mapply(circleFun, centerx, centery, radius,
                          npoints)
    positions <- data.frame(rownr = rep(1:n, each = npoints),
                          x = c(circlePoints[1:npoints, ]), y = c(circlePoints[(npoints +
                                                                                  1):(2 * npoints), ]))
    datapoly <- merge(mXrow, positions, by = c("rownr"))
  }
  # drawCircles = F
  ggp = ggplot(data = mX, aes(variable, rownr)) + {
    if (blockMap) 
      geom_tile(aes(fill = rescale(rescaleoffset, from = range(gradientends))), 
                color = "white")
  } + {
    if (!blockMap & outlyingGrad) 
      geom_tile(aes(fill = rescale(rescaleoffset, from = range(gradientends))), 
                color = "white")
  } + {
    if (!blockMap & !outlyingGrad) 
      geom_tile(aes(fill = rescale(rescaleoffset, from = range(gradientends))), 
                colour = "white")
  } + {
    if (drawCircles)
     geom_polygon(data = datapoly, aes(x = x, y = y, fill = rescale(rescaleoffset, 
                                                                      from = range(gradientends)), group = rownr), colour = "black") } + 
    scale_fill_gradientn(colours = colorends, values = rescale(gradientends), 
                         rescaler = function(x, ...) x, oob = scales::squish) + 
    ggtitle(mTitle) + coord_fixed() + theme_classic(base_size = base_size * 
                                                      1) + labs(x = columntitle, y = rowtitle) + scale_x_discrete(expand = c(0, 
                                                                                                                             0), limits = seq(1, d + 2, 1), labels = columnlabels) + 
    scale_y_discrete(expand = c(0, 0), labels = rowlabels) + 
    theme(legend.position = "none", axis.ticks = element_blank(), 
          plot.title = element_text(size = base_size * 2, hjust = 0.5, 
                                    vjust = 1, face = "bold"), axis.text.x = element_text(size = base_size * 
                                                                                            1.8, angle = columnangle, hjust = adjustcolumnlabels, 
                                                                                          vjust = 0.5, colour = "black"), axis.text.y = element_text(size = base_size * 
                                                                                                                                                       1.8, angle = 0, hjust = adjustrowlabels, colour = "black"), 
          axis.title.x = element_text(colour = "black", size = base_size * 
                                        sizetitles, vjust = 1), axis.title.y = element_text(colour = "black", 
                                                                                            size = base_size * sizetitles, vjust = 0), axis.line.x = element_blank(), 
          panel.border = element_blank()) + annotate(geom = "segment", 
                                                     x = 0.5, xend = d + 0.5, y = 0.5, yend = 0.5) + annotate(geom = "segment", 
                                                                                                              x = 0.5, xend = d + 0.5, y = n + 0.5, yend = n + 0.5) + 
    annotate(geom = "segment", x = d + 0.5, xend = d + 0.5, 
             y = 0.5, yend = n + 0.5)
  if (!is.null(showVals)) {
    txtcol = mX$CatNr
    txtcol[txtcol == 0] = "black"
    txtcol[txtcol == 1] = "white"
    txtcol[txtcol == 2] = "white"
    if (type == "residual") {
      txtcol[] = "black"
      txtcol[mXgrad$grad > 0.5] = "white"
    }
    txtcol[txtcol == 4] = "black"
    ggp = ggp + geom_text(aes(label = ifelse(is.na(data), 
                                             sprintf("%1.0f", data), round(data, 1))), size = base_size * 
                            0.5, colour = txtcol, na.rm = TRUE)
  }
  return(ggp)
}



cutCellmap <- function(rowsToShow, D, R, W, drawCircles = FALSE) {
  # Draw a cell map (via cellMap2) restricted to a subset of the rows.
  #
  # Arguments:
  #   rowsToShow:  row selector (indices or logical) applied to D, R and W
  #   D:           data matrix
  #   R:           residual matrix, same dimensions as D
  #   W:           indicator matrix with the same dimensions as D; cells equal
  #                to 1 are passed to cellMap2 as flagged cells (e.g. the W
  #                returned by cellFlagger)
  #   drawCircles: forwarded unchanged to cellMap2
  # Returns:
  #   whatever cellMap2 returns for the selected rows.
  #
  # drop = FALSE keeps a single selected row as a 1 x d matrix; without it
  # the subsets collapse to vectors and nrow(D) becomes NULL.
  D <- D[rowsToShow, , drop = FALSE]
  R <- R[rowsToShow, , drop = FALSE]
  W <- W[rowsToShow, , drop = FALSE]
  # Linear (column-major) positions of the flagged cells within the subset.
  indcells <- which(W == 1)
  # Rows are relabelled 1..k within the subset; seq_len() also handles k = 0.
  cellMap2(D, R, indcells = indcells,
           columnlabels = colnames(D),
           rowlabels = seq_len(nrow(D)),
           drawCircles = drawCircles)
}



offdiag <- function(x) {
  # Return the off-diagonal entries of x, read row by row.
  #
  # Works for non-square matrices too: an entry x[i, j] is dropped only when
  # i == j. The result is a plain vector with the same type as x.
  #
  # Transposing first means R's column-major extraction walks x row-wise.
  transposed <- t(x)
  keep <- matrix(TRUE, nrow = ncol(x), ncol = nrow(x))
  diag(keep) <- FALSE
  transposed[keep]
}

