Main algorithm function for the variational Bayes (VB) Gaussian mixture model (GMM), fitted by coordinate ascent variational inference (CAVI)
Usage
vb_gmm_cavi(
X,
k,
prior = NULL,
delta = 1e-06,
maxiters = 5000,
init = "kmeans",
initParams = NULL,
stopIfELBOReverse = FALSE,
verbose = FALSE,
logDiagnostics = FALSE,
logFilename = "vb_gmm_log.txt",
progressbar = TRUE
)
Arguments
- X
n x p data matrix (or data frame that will be converted to a matrix).
- k
guess for the number of mixture components.
- prior
Prior for the GMM parameters.
- delta
change in ELBO that triggers algorithm stopping.
- maxiters
maximum iterations to run if delta does not stop the algorithm already.
- init
method used to initialize the clusters; one of "random", "kmeans" or "dbscan".
- initParams
initialization parameters for dbscan. NULL if dbscan not selected for init.
- stopIfELBOReverse
stop the run if the ELBO at iteration t is detected to have reversed from iteration t-1.
- verbose
print out information per iteration to track progress in case of long-running experiments.
- logDiagnostics
log detailed diagnostics. If TRUE, a diagnostics RDS file will be created using the path specified in logFilename.
- logFilename
the filename of the diagnostics log.
- progressbar
A visual progress bar to indicate iterations (on by default).
Value
A list containing:
error - An error message if convergence failed; otherwise the number of iterations taken to reach convergence.
z_post - A vector of posterior cluster mappings.
q_post - A list of the fitted posterior Q family. q_post includes:
alpha - The Dirichlet prior for the mixing coefficient
beta - The proportionality of the precision for the Normal-Wishart prior
m - The mean vector for the Normal part of the Normal-Wishart prior
v - The degrees of freedom of the Wishart gamma ensuring it is well defined
U - The W hyperparameter of the Wishart conjugate prior for the precision of the mixtures. k sets of D x D symmetric, positive definite matrices.
logW - The logW term used to calculate the expectation of the mixture component precision. A vector of k.
R - The responsibilities. An n x k matrix.
logR - The log of responsibilities. An n x k matrix.
elbo - A vector of the ELBO at each iteration.
elbo_d - A vector of ELBO differences between each iteration.
Examples
# \donttest{
# Use Old Faithful data to show the effect of VB GMM Priors,
# stopping on ELBO reverse parameter or delta threshold
# ------------------------------------------------------------------------------
# Scratch directory that receives the generated figures
gen_path <- tempdir()

# Old Faithful observations; P is the number of columns in X
data("faithful")
X <- faithful
P <- ncol(X)

# Run with 4 presumed components for demonstration purposes
k <- 4
# ------------------------------------------------------------------------------
# Plotting
# ------------------------------------------------------------------------------
#' Plots the GMM components with centroids
#'
#' Draws the clustered observations, overlays the posterior component means
#' and one filled ellipse per cluster, and saves the figure as a PNG file.
#' The observations are read from the object `X` in the enclosing
#' environment (it is not an argument), and the plot expects columns named
#' `eruptions` and `waiting`.
#'
#' @param i Column index into `grid`; selects the values that label this plot
#' @param gmm_result Results from the VB GMM run; `z_post` and `q_post$m`
#'   are used
#' @param var_name GMM hyperparameter name, used as the file name prefix
#' @param grid Grid whose `i`-th column is joined with "_" to build the plot
#'   file name
#' @param fig_path Path to the directory where the plots should be stored
#'
#' @returns The value of `ggplot2::ggsave()`; called for its side effect of
#'   writing the PNG file.
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 scale_color_discrete
#' @importFrom ggplot2 scale_fill_discrete
#' @importFrom ggplot2 stat_ellipse
#' @export
do_plots <- function(i, gmm_result, var_name, grid, fig_path) {
  # Observations plus their posterior cluster assignment, as a factor so the
  # colour and fill scales are discrete
  dd <- as.data.frame(cbind(X, cluster = gmm_result$z_post))
  dd$cluster <- as.factor(dd$cluster)

  # The group means
  # NOTE(review): assumes the transposed means carry the eruptions/waiting
  # column names used in the aesthetics below -- confirm against q_post$m
  mu <- as.data.frame(t(gmm_result$q_post$m))

  # Plot the posterior mixture groups
  cols <- c("#1170AA", "#55AD89", "#EF6F6A", "#D3A333", "#5FEFE8", "#11F444")
  p <- ggplot2::ggplot() +
    ggplot2::geom_point(
      data = dd,
      mapping = ggplot2::aes(x = eruptions, y = waiting, color = cluster)
    ) +
    # The palette must go in through `type`; passed positionally it is taken
    # as the scale name and the colours are silently ignored.
    ggplot2::scale_color_discrete(type = cols, guide = "none") +
    # Same palette for the ellipse fill so ellipses match their points
    ggplot2::scale_fill_discrete(type = cols) +
    ggplot2::geom_point(
      data = mu,
      mapping = ggplot2::aes(x = eruptions, y = waiting),
      color = "black", shape = 7, size = 2
    ) +
    ggplot2::stat_ellipse(
      data = dd,
      geom = "polygon",
      mapping = ggplot2::aes(x = eruptions, y = waiting, fill = cluster),
      alpha = 0.25
    )

  # File name is <var_name>_<grid values of column i joined by "_">.png
  grids <- paste(grid[, i], collapse = "_")
  ggplot2::ggsave(
    filename = paste0(var_name, "_", grids, ".png"),
    plot = p,
    path = fig_path,
    width = 12, height = 12, units = "cm", dpi = 600,
    create.dir = TRUE
  )
}
# ------------------------------------------------------------------------------
# Dirichlet alpha
# ------------------------------------------------------------------------------
# Each column of alpha_grid is one Dirichlet alpha setting (one value per
# component) to try.
alpha_grid <- data.frame(
  c1 = c(1, 1, 1, 1),
  c2 = c(1, 92, 183, 183),
  c3 = c(183, 92, 183, 183)
)
init <- "kmeans"
# seq_len() rather than 1:ncol() so an empty grid yields zero iterations
for (i in seq_len(ncol(alpha_grid))) {
  prior <- list(
    alpha = as.integer(alpha_grid[, i]) # per-component alpha from column i
  )
  gmm_result <- vb_gmm_cavi(
    X = X, k = k, prior = prior, delta = 1e-8, init = init,
    verbose = FALSE, logDiagnostics = FALSE
  )
  do_plots(i, gmm_result, "alpha", alpha_grid, gen_path)
}
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
# ------------------------------------------------------------------------------
# Normal-Wishart beta for precision proportionality
# ------------------------------------------------------------------------------
# Each column of beta_grid is one beta setting (one value per component)
# to try.
beta_grid <- data.frame(
  c1 = c(0.1, 0.1, 0.1, 0.1),
  c2 = c(0.9, 0.9, 0.9, 0.9)
)
init <- "kmeans"
# seq_len() rather than 1:ncol() so an empty grid yields zero iterations
for (i in seq_len(ncol(beta_grid))) {
  prior <- list(
    beta = as.numeric(beta_grid[, i])
  )
  gmm_result <- vb_gmm_cavi(
    X = X, k = k, prior = prior, delta = 1e-8, init = init,
    verbose = FALSE, logDiagnostics = FALSE
  )
  do_plots(i, gmm_result, "beta", beta_grid, gen_path)
}
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|== | 4%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 9%
|
|======= | 10%
|
|======= | 11%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========= | 14%
|
|========== | 14%
|
|========== | 15%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
# ------------------------------------------------------------------------------
# Normal-Wishart W0 (assuming variance matrix) & logW
# ------------------------------------------------------------------------------
# Each column of w_grid holds the diagonal entries of one W0 setting to try.
w_grid <- data.frame(
  w1 = c(0.001, 0.001),
  w2 = c(2.001, 2.001)
)
init <- "kmeans"
# Iterate over columns: the body indexes w_grid[, i], so the bound must be
# ncol(), not nrow() (the two only coincide because this grid is 2 x 2).
for (i in seq_len(ncol(w_grid))) {
  # P x P diagonal matrix built from column i
  w0 <- diag(w_grid[, i], P)
  prior <- list(
    W = w0,
    # -2 * sum(log(diag(chol(w0)))) equals -log(det(w0))
    logW = -2 * sum(log(diag(chol(w0))))
  )
  gmm_result <- vb_gmm_cavi(
    X = X, k = k, prior = prior, delta = 1e-8, init = init,
    verbose = FALSE, logDiagnostics = FALSE
  )
  do_plots(i, gmm_result, "w", w_grid, gen_path)
}
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|== | 4%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 9%
|
|======= | 10%
|
|======= | 11%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========= | 14%
|
|========== | 14%
|
|========== | 15%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
# }