Fit a variational Bayes (VB) Gaussian mixture model (GMM) using coordinate ascent variational inference (CAVI).

vb_gmm_cavi(
  X,
  k,
  prior = NULL,
  delta = 1e-06,
  maxiters = 5000,
  init = "kmeans",
  initParams = NULL,
  stopIfELBOReverse = FALSE,
  verbose = FALSE,
  logDiagnostics = FALSE,
  logFilename = "vb_gmm_log.txt"
)

Arguments

X

n x p data matrix (or data frame that will be converted to a matrix).

k

The assumed number of mixture components (an initial guess).

prior

Optional list of prior hyperparameters for the GMM (the examples below set alpha, beta, W and logW). Defaults to NULL.

delta

Convergence threshold on the change in ELBO between successive iterations; the algorithm stops once the change falls below this value.

maxiters

Maximum number of iterations to run if the delta criterion has not already stopped the algorithm.

init

Cluster initialization method: one of "random", "kmeans" or "dbscan".

initParams

Initialization parameters for dbscan; leave as NULL unless init = "dbscan".

stopIfELBOReverse

If TRUE, stop the run when the ELBO at iteration t is lower than at iteration t-1 (i.e. the ELBO has reversed direction).

verbose

If TRUE, print progress information at each iteration, which helps track long-running experiments.

logDiagnostics

If TRUE, log detailed diagnostics: a diagnostics file in RDS format is created at the path given by logFilename.

logFilename

the filename of the diagnostics log.

Value

A list containing:

  • error - Either an error message (if convergence failed) or the number of iterations taken to converge.

  • z_post - A vector of posterior cluster mappings.

  • q_post - A list of the fitted posterior Q family. q_post includes:

    • alpha - The Dirichlet concentration parameters for the mixing coefficients

    • beta - The scaling factor applied to the precision in the Normal part of the Normal-Wishart prior

    • m - The mean vector for the Normal part of the Normal-Wishart prior

    • v - The degrees of freedom of the Wishart distribution; these must exceed D - 1 for the Wishart to be well defined

    • U - The W hyperparameter of the Wishart conjugate prior for the precision of the mixtures. k sets of D x D symmetric, positive definite matrices.

    • logW - The logW term used to calculate the expectation of the mixture component precision. A vector of k.

    • R - The responsibilities. An n x k matrix.

    • logR - The log of responsibilities. An n x k matrix.

  • elbo - A vector of the ELBO at each iteration.

  • elbo_d - A vector of ELBO differences between each iteration.

Examples

if (FALSE) {
  # Use the Old Faithful data to show the effect of the VB GMM priors, along
  # with the choice of initialiser (kmeans vs dbscan), the stopIfELBOReverse
  # option and the delta convergence threshold.
  # ----------------------------------------------------------------------------

  library(ggplot2)

  # Write figures under the session's temporary directory so the example
  # never writes to a hard-coded, user-specific location.
  gen_path <- file.path(tempdir(), "vb_gmm_plots")
  data("faithful")
  X <- faithful
  P <- ncol(X)

  # Run with 4 presumed components for demonstration purposes
  k <- 4

  # ----------------------------------------------------------------------------
  # Plotting
  # ----------------------------------------------------------------------------

  # Plot the posterior cluster assignments together with the component means
  # and per-cluster ellipses, save the figure as a PNG and return the plot.
  #
  # i          Row of `grid` used for this run (also used in the file name).
  # gmm_result Result list returned by vb_gmm_cavi().
  # var_name   Name of the hyperparameter being varied (file name prefix).
  # grid       data.frame of hyperparameter settings, one run per row.
  # fig_path   Directory the PNG is written to (created if missing).
  # data       The data the model was fitted to; defaults to `X`.
  #
  # Returns the ggplot object invisibly; the caller stores it. (Assigning
  # into a list from inside this function would only modify a local copy.)
  do_plots <- function(i, gmm_result, var_name, grid, fig_path, data = X) {
    dd <- as.data.frame(cbind(data, cluster = gmm_result$z_post))
    dd$cluster <- as.factor(dd$cluster)

    # Component means, one row per component. Assumes q_post$m holds one
    # column per component with rows named after the columns of `data`
    # (eruptions, waiting) -- confirm against vb_gmm_cavi().
    mu <- as.data.frame(t(gmm_result$q_post$m))

    # A manual scale is needed to use a custom palette; applying it to both
    # colour and fill keeps points and ellipses in matching colours and
    # merges them into a single "cluster" legend.
    cols <- c("#1170AA", "#55AD89", "#EF6F6A", "#D3A333", "#5FEFE8", "#11F444")
    p <- ggplot() +
      geom_point(data = dd,
                 mapping = aes(x = eruptions, y = waiting, colour = cluster)) +
      geom_point(data = mu, mapping = aes(x = eruptions, y = waiting),
                 colour = "black", shape = 7, size = 2) +
      stat_ellipse(data = dd, geom = "polygon",
                   mapping = aes(x = eruptions, y = waiting, fill = cluster),
                   alpha = 0.25) +
      scale_colour_manual(values = cols, aesthetics = c("colour", "fill"))

    grid_label <- paste(unlist(grid[i, ]), collapse = "_")
    ggsave(filename = paste0(var_name, "_", grid_label, ".png"), plot = p,
           path = fig_path, width = 12, height = 12, units = "cm", dpi = 600,
           create.dir = TRUE)

    invisible(p)
  }

  # Fit one VB GMM per row of `grid`. make_prior() turns that row's values
  # (a plain numeric vector) into the `prior` list for vb_gmm_cavi(). Each
  # fit is plotted with do_plots().
  #
  # Returns a list with `z` (table of posterior cluster sizes for each run)
  # and `plots` (the ggplot object for each run).
  fit_prior_grid <- function(grid, var_name, make_prior, init = "kmeans") {
    n_runs <- nrow(grid)
    z <- vector(mode = "list", length = n_runs)
    plots <- vector(mode = "list", length = n_runs)

    for (i in seq_len(n_runs)) {
      # unlist() turns the one-row data.frame into a plain vector of values
      prior <- make_prior(as.numeric(unlist(grid[i, ])))

      gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-9,
                                init = init, verbose = FALSE,
                                logDiagnostics = FALSE)
      z[[i]] <- table(gmm_result$z_post)
      plots[[i]] <- do_plots(i, gmm_result, var_name, grid, gen_path)
    }

    list(z = z, plots = plots)
  }

  # ----------------------------------------------------------------------------
  # Dirichlet alpha
  # ----------------------------------------------------------------------------
  # Each row is one prior and columns c1..c4 are the concentrations of the
  # k = 4 components: a symmetric prior (1, 1, 1, 1) followed by two uneven
  # ones. Dirichlet concentrations are real-valued, so they are passed as
  # numeric.
  alpha_grid <- data.frame(c1 = c(1, 1, 183),
                           c2 = c(1, 92, 92),
                           c3 = c(1, 183, 198),
                           c4 = c(1, 198, 50))
  alpha_res <- fit_prior_grid(alpha_grid, "alpha",
                              function(alpha) list(alpha = alpha))

  # ----------------------------------------------------------------------------
  # Normal-Wishart beta for precision proportionality
  # ----------------------------------------------------------------------------
  # Two runs: every component's beta set to 0.1, then to 0.9.
  beta_grid <- data.frame(c1 = c(0.1, 0.9),
                          c2 = c(0.1, 0.9),
                          c3 = c(0.1, 0.9),
                          c4 = c(0.1, 0.9))
  beta_res <- fit_prior_grid(beta_grid, "beta",
                             function(beta) list(beta = beta))

  # ----------------------------------------------------------------------------
  # Normal-Wishart W0 (assuming variance matrix) & logW
  # ----------------------------------------------------------------------------
  # Each row gives the diagonal of the P x P matrix W0.
  w_grid <- data.frame(w1 = c(0.001, 2.001),
                       w2 = c(0.001, 2.001))
  w_res <- fit_prior_grid(w_grid, "w", function(w) {
    w0 <- diag(w, P)
    # 2 * sum(log(diag(chol(w0)))) is log|w0|, so logW = -log|w0|
    list(W = w0, logW = -2 * sum(log(diag(chol(w0)))))
  })
}