#########################################
##                                     ##
##          SUPPORT FUNCTIONS          ##
##                                     ##
#########################################

# Fit a finite mixture of multivariate CUB models by EM.
#
# Within component k, item j is a CUB mixture: with probability pi[k, j] a
# shifted binomial (dbinom(R - 1, m[j] - 1, 1 - xi[k, j])), otherwise the
# discrete uniform 1 / m[j]. Item densities are multiplied within a component
# and components are weighted by omega[k].
#
# The algorithm is started from EM.iter random initialisations. A run that
# throws an error (e.g. a NaN log-likelihood makes the convergence test NA) is
# retried from fresh random values up to 10 times; an initialisation whose
# retries all fail is skipped with a warning. The fit with the largest final
# log-likelihood is returned.
#
# Args:
#   R         n x J matrix (or data frame) of ratings; presumably integers in
#             1..m[j] for column j, since they enter dbinom() as R - 1.
#   m         vector whose j-th entry is the number of categories of item j.
#   K         number of mixture components.
#   tol       a run stops once the log-likelihood gain drops below tol.
#   EM.iter   number of random initialisations.
#   max_iter  maximum number of EM iterations per run.
#
# Returns: a list with the parameter paths (params_table, xi.conv, pi.conv,
#   omega.conv), the final estimates (xi.est, pi.est from the last K x J
#   slice; omega.est as a 1 x K matrix), posterior memberships (taus, n x K),
#   hard assignments (class), the log-likelihood path and final value, the
#   number of iterations, AIC and BIC.
EM <- function(R, m, K,
               tol = 1e-10, EM.iter = 20,
               max_iter = 1500) {

  #------------------------------
  # Loglikelihood function
  #------------------------------
  # Evaluates the mixture at parameter slice `t`. Returns the omega-weighted
  # per-unit/per-component densities (n x K), their row sums, the per-unit
  # log-likelihoods and the total log-likelihood.
  LogLik.function <- function(omega.mat, pi.tens, t, R, m, xi.tens,
                              K, J, n, uniform.j) {
    CUB.mix <- array(NA_real_, c(K, J, n))
    for (k in seq_len(K)) {
      for (j in seq_len(J)) {
        shifted.binom <- dbinom(R[, j] - 1, m[j] - 1, 1 - xi.tens[k, j, t])
        CUB.mix[k, j, ] <- pi.tens[k, j, t] * shifted.binom +
          (1 - pi.tens[k, j, t]) * uniform.j[, j]
      }
    }

    # Product over items j -> n x K matrix
    multiCUB.j <- apply(CUB.mix, c(3, 1), prod)

    # Weight column k by omega_k
    multiCUB.jk <- sweep(multiCUB.j, 2, omega.mat[t, ], `*`)

    MLCCUBk <- rowSums(multiCUB.jk)
    MLCCUBi <- log(MLCCUBk)

    list(
      multiCUB.w_k = multiCUB.jk,
      MLCCUBk = MLCCUBk,
      MLCCUBi = MLCCUBi,
      LL = sum(MLCCUBi)
    )
  }

  #------------------------------------------
  # Function for computing eta
  #------------------------------------------
  # Posterior probability, per unit/component/item, that the response came
  # from the shifted-binomial part of the CUB mixture (a K x J x n array).
  eta_ijk <- function(omega.mat, pi.tens, t, R, m, xi.tens,
                      K, J, n, uniform.j) {
    eta.ijk <- array(NA_real_, c(K, J, n))
    for (k in seq_len(K)) {
      for (j in seq_len(J)) {
        binom.part <- pi.tens[k, j, t] *
          dbinom(R[, j] - 1, m[j] - 1, 1 - xi.tens[k, j, t])
        eta.ijk[k, j, ] <- binom.part /
          (binom.part + (1 - pi.tens[k, j, t]) * uniform.j[, j])
      }
    }
    eta.ijk
  }

  # Append NA slices along dimension 3 until `a` has at least t slices.
  # Arrays are column-major, so concatenating the data appends whole slices.
  grow.slices <- function(a, t) {
    d <- dim(a)
    if (d[3] >= t) return(a)
    array(c(a, rep(NA_real_, d[1] * d[2] * (t - d[3]))), dim = c(d[1], d[2], t))
  }

  # Append NA rows until matrix `x` has at least t rows (keeps colnames).
  grow.rows <- function(x, t) {
    if (nrow(x) >= t) return(x)
    rbind(x, matrix(NA, nrow = t - nrow(x), ncol = ncol(x)))
  }

  n <- nrow(R)
  J <- ncol(R)
  max_retries <- 10
  # NOTE(review): counts all K mixing weights although they sum to one
  # (only K - 1 are free) -- confirm this is the intended parameter count.
  n.par <- K * J * 2 + K

  # Discrete uniform component: constant 1 / m[j] in column j
  uniform.j <- matrix(NA_real_, nrow = n, ncol = J)
  for (j in seq_len(J)) uniform.j[, j] <- rep(1 / m[j], n)

  # One complete EM run from a random start. All state is local and returned,
  # so a failed attempt leaves nothing stale behind for the next retry.
  # Errors propagate to the caller, which decides whether to retry.
  run.once <- function() {
    t <- 2          # slice/row 1 stays NA (LL[1] = -Inf) so t - 1 always exists
    iter <- 0
    Niter <- 0

    #--------------------#
    #   INITIALIZATION   #
    #--------------------#
    pi.tens <- array(NA_real_, dim = c(K, J, t))
    pi.tens[, , t] <- matrix(runif(K * J), nrow = K, ncol = J)
    pi.mat <- matrix(NA_real_, nrow = t, ncol = K * J)
    pi.mat[t, ] <- as.vector(pi.tens[, , t])
    colnames(pi.mat) <- paste0("pi.", rep(seq_len(J), each = K), rep(seq_len(K), times = J))

    xi.tens <- array(NA_real_, dim = c(K, J, t))
    xi.tens[, , t] <- matrix(runif(K * J), nrow = K, ncol = J)
    xi.mat <- matrix(NA_real_, nrow = t, ncol = K * J)
    xi.mat[t, ] <- as.vector(xi.tens[, , t])
    colnames(xi.mat) <- paste0("xi.", rep(seq_len(J), each = K), rep(seq_len(K), times = J))

    omega.mat <- matrix(NA_real_, nrow = t, ncol = K)
    omega.mat[t, ] <- rep(1 / K, K)

    LL <- -Inf
    LL.list <- LogLik.function(omega.mat, pi.tens, t, R, m, xi.tens, K, J, n, uniform.j)
    LL[t] <- LL.list$LL
    tau.mat <- NULL

    #-------------------------#
    #      EM ALGORITHM       #
    #-------------------------#
    # A NaN/NA log-likelihood makes this condition NA, which errors and
    # triggers a retry in the caller.
    while ((LL[t] - LL[t - 1]) >= tol && iter < max_iter) {
      Niter <- t - 1
      iter <- iter + 1

      # EXPECTATION: row i of tau.mat divided by unit i's mixture density
      tau.mat <- LL.list$multiCUB.w_k / LL.list$MLCCUBk
      eta.tens <- eta_ijk(omega.mat, pi.tens, t, R, m, xi.tens, K, J, n, uniform.j)
      t <- t + 1

      # MAXIMIZATION: pi
      pi.tens <- grow.slices(pi.tens, t)
      weighted.eta <- array(NA_real_, dim = c(K, J, n))
      for (i in seq_len(n)) weighted.eta[, , i] <- tau.mat[i, ] * eta.tens[, , i]
      # K x J sums divided by a length-K vector: row k is divided by colSums[k]
      pi.tens[, , t] <- apply(weighted.eta, c(1, 2), sum) / colSums(tau.mat)
      pi.mat <- grow.rows(pi.mat, t)
      pi.mat[t, ] <- as.vector(pi.tens[, , t])

      # MAXIMIZATION: xi
      xi.tens <- grow.slices(xi.tens, t)
      for (j in seq_len(J)) {
        for (k in seq_len(K)) {
          w <- tau.mat[, k] * eta.tens[k, j, ]
          xi.tens[k, j, t] <- sum(w * (m[j] - R[, j])) / sum(w * (m[j] - 1))
        }
      }
      xi.mat <- grow.rows(xi.mat, t)
      xi.mat[t, ] <- as.vector(xi.tens[, , t])

      # MAXIMIZATION: omega
      omega.mat <- grow.rows(omega.mat, t)
      omega.mat[t, ] <- colSums(tau.mat) / n

      # LogLikelihood update
      LL.list <- LogLik.function(omega.mat, pi.tens, t, R, m, xi.tens, K, J, n, uniform.j)
      LL[t] <- LL.list$LL
    }

    # No E-step ran (e.g. max_iter = 0): use memberships at the starting values.
    if (is.null(tau.mat)) tau.mat <- LL.list$multiCUB.w_k / LL.list$MLCCUBk

    list(pi.tens = pi.tens, pi.mat = pi.mat, xi.tens = xi.tens, xi.mat = xi.mat,
         omega.mat = omega.mat, LL = LL, tau.mat = tau.mat,
         iter = iter, Niter = Niter)
  }

  #------------------------------------------
  # EM Iterations
  #------------------------------------------
  results.list <- list()
  max.LL <- rep(NA_real_, EM.iter)

  for (em in seq_len(EM.iter)) {
    # Retry logic in case of no convergence: the handler returns the condition
    # object instead of trying to mutate this frame from inside a closure.
    retry_count <- 0
    repeat {
      outcome <- tryCatch(run.once(), error = function(e) e)
      if (!inherits(outcome, "error")) break
      retry_count <- retry_count + 1
      cat("No convergence - Retry - N. of tries: ", retry_count, "\n")
      if (retry_count >= max_retries) break
    }

    if (inherits(outcome, "error")) {
      warning("EM initialization ", em, " failed after ", max_retries,
              " attempts (last error: ", conditionMessage(outcome),
              "); skipping it.", call. = FALSE)
      next
    }

    omega.mat <- outcome$omega.mat
    colnames(omega.mat) <- paste0("w.", seq_len(K))
    tau.mat <- outcome$tau.mat
    colnames(tau.mat) <- paste0("tau.", seq_len(K))
    xi.tens <- outcome$xi.tens
    pi.tens <- outcome$pi.tens
    LL <- outcome$LL
    final.LL <- tail(LL, 1)

    results.list[[em]] <- list(
      params_table = cbind(outcome$pi.mat, outcome$xi.mat, omega.mat, LL),
      xi.conv = outcome$xi.mat,
      pi.conv = outcome$pi.mat,
      omega.conv = omega.mat,
      xi.est = as.matrix(xi.tens[, , dim(xi.tens)[3]]),
      pi.est = as.matrix(pi.tens[, , dim(pi.tens)[3]]),
      omega.est = as.matrix(tail(omega.mat, 1)),
      taus = tau.mat,
      class = max.col(tau.mat),
      LogLik_vec = LL,
      LogLik = final.LL,
      Niter = outcome$iter,
      AIC = 2 * n.par - 2 * final.LL,
      BIC = n.par * log(n) - 2 * final.LL
    )

    max.LL[em] <- final.LL
    cat("Number of clusters:", K, "- EM initialization number:", em, "- N. of iterations:", tail(outcome$Niter, 1), "\n")
  }

  if (all(is.na(max.LL))) {
    stop("EM: all ", EM.iter, " initializations failed.", call. = FALSE)
  }

  results.list[[which.max(max.LL)]]
}

#------------------------------------------
# Bootstrap iteration function
#------------------------------------------
bootstrap_iteration <- function(data, m, K, tol, EM.iter, max_iter) {
  # One bootstrap replicate: resample the rows of `data` with replacement,
  # refit EM() on the resample, and pair each resampled row index with its
  # estimated cluster.
  #
  # Returns a list with
  #   model.boot     the EM() fit on the bootstrap sample;
  #   clusters.sort  a data frame (index, class) sorted by original row index.

  # NOTE(review): none of these packages is called directly in this function;
  # the calls are kept so they are still attached wherever it runs.
  library(combinat)
  library(dplyr)
  library(magrittr)
  library(mclust)

  n <- nrow(data)

  boot.indices <- sample(seq_len(n), replace = TRUE)
  # drop = FALSE keeps single-column data as a 1-column matrix/data frame,
  # so nrow()/ncol() still work inside EM().
  boot.data <- data[boot.indices, , drop = FALSE]

  model.boot <- EM(boot.data, m = m, K = K, tol = tol, EM.iter = EM.iter, max_iter = max_iter)

  res <- data.frame(index = boot.indices, class = model.boot$class)
  res.sort <- res[order(res$index, decreasing = FALSE), ]

  list(model.boot = model.boot, clusters.sort = res.sort)
}

######################################
##                                  ##
##          MAIN FUNCTIONS          ##
##                                  ##
######################################

EM.MLCCUB <- function(R, m, K.vec, tol = 1e-5, EM.iter = 20, max_iter = 1500) {
  # Sequentially fit EM() once for every number of clusters in K.vec.
  # Returns an unnamed list holding one EM() result per element of K.vec,
  # in the same order.
  fit_for_k <- function(n.clusters) {
    EM(R = R, m = m, K = n.clusters, tol = tol, EM.iter = EM.iter, max_iter = max_iter)
  }
  lapply(unname(K.vec), fit_for_k)
}

EM.MLCCUB_parallel <- function(R, m, K.vec, tol = 1e-5, EM.iter = 20, max_iter = 1500) {
  # Parallel counterpart of EM.MLCCUB(): fit EM() once per number of clusters
  # in K.vec on a local cluster with at most one worker per K value.
  # Returns an unnamed list of EM() results in the same order as K.vec.
  library(parallel)

  if (length(K.vec) == 0) return(list())

  # detectCores() may return NA on some platforms
  n.cores <- detectCores()
  if (is.na(n.cores)) n.cores <- 1L

  cl <- makeCluster(min(n.cores, length(K.vec)))
  # Release the workers even if an export or a fit fails.
  on.exit(stopCluster(cl), add = TRUE)

  # Workers only need EM() itself. R, m, tol, EM.iter and max_iter travel with
  # the closure below, whose environment (this call frame) is serialized with it.
  clusterExport(cl, "EM", envir = environment())

  # parLapply already splits the input across workers and preserves order.
  parLapply(cl, unname(K.vec), function(k) {
    EM(R = R, m = m, K = k, tol = tol, EM.iter = EM.iter, max_iter = max_iter)
  })
}

bootstrap <- function(data, n.parallel, n.boot = 100, m, K, tol, EM.iter, max_iter, cluster_export_list = NULL) {
  # Run n.boot bootstrap replicates of bootstrap_iteration() on a local cluster
  # of n.parallel workers.
  #
  # cluster_export_list: optional extra object names to export (looked up in
  #   the global environment, like the default exports).
  #
  # Returns a list with
  #   model.boot.list     the EM() fit of each replicate;
  #   clusters.sort.list  the sorted (index, class) data frame of each replicate.
  library(parallel)

  cl <- makeCluster(n.parallel)
  # Release the workers even if an export or a replicate fails.
  on.exit(stopCluster(cl), add = TRUE)

  # c(x, NULL) is x, so a NULL cluster_export_list adds nothing.
  export_list <- c("bootstrap_iteration", "EM", "EM.MLCCUB", cluster_export_list)
  clusterExport(cl, export_list)

  # seq_len() so n.boot = 0 runs no replicates (1:0 would run two).
  results <- parLapply(cl, seq_len(n.boot), function(b) {
    bootstrap_iteration(data = data, m = m, K = K, tol = tol, EM.iter = EM.iter, max_iter = max_iter)
  })

  list(
    model.boot.list = lapply(results, `[[`, "model.boot"),
    clusters.sort.list = lapply(results, `[[`, "clusters.sort")
  )
}
