Skip to contents

Variational inference for Bayesian logistic regression using the CAVI algorithm

Usage

logit_CAVI(
  X,
  y,
  prior,
  delta = 1e-16,
  maxiters = 10000,
  verbose = FALSE,
  progressbar = TRUE
)

Arguments

X

The input design matrix. Note the intercept column vector is assumed included.

y

The binary response.

prior

Prior for the logistic parameters.

delta

The ELBO difference tolerance for convergence.

maxiters

The maximum iterations to run if convergence is not achieved.

verbose

A diagnostics flag (off by default).

progressbar

A visual progress bar to indicate iterations (on by default).

Value

A list containing:

  • error - An error message if convergence failed; otherwise the number of iterations taken to achieve convergence.

  • mu - A vector of posterior means.

  • Sigma - A vector of posterior variances.

  • Convergence - A vector of the ELBO at each iteration.

  • LBDifference - A vector of ELBO differences between each iteration.

  • xi - A vector of log-odds per X row.

Examples

# \donttest{

  # Old Faithful demonstration of how the VB GMM priors affect the fit; each
  # run stops on the ELBO reverse parameter or on the delta threshold
  # ------------------------------------------------------------------------------

  # Data set and its number of columns
  data("faithful")
  X <- faithful
  P <- ncol(X)

  # Presume 4 mixture components, purely for demonstration purposes
  k <- 4

  # Plots are written to the session's temporary directory
  gen_path <- tempdir()

  # ------------------------------------------------------------------------------
  # Plotting
  # ------------------------------------------------------------------------------

  #' Plots the GMM components with centroids
  #'
  #' Draws the observations coloured by posterior cluster assignment, overlays
  #' the component means and a filled ellipse per cluster, and saves the figure
  #' as `<var_name>_<grid values>.png` under `fig_path`. The observations are
  #' read from the `X` object defined earlier in this example rather than
  #' being passed in as an argument.
  #'
  #' @param i Index of the column of `grid` that produced this run
  #' @param gmm_result Results from the VB GMM run; `z_post` and `q_post$m`
  #'   are the elements used here
  #' @param var_name Variable to hold the GMM hyperparameter name
  #' @param grid Grid element used in the plot file name
  #' @param fig_path Path to the directory where the plots should be stored
  #'
  #' @returns The value of `ggplot2::ggsave()`; the function is called for its
  #'   side effect of writing the plot file.
  #' @importFrom ggplot2 ggplot
  #' @importFrom ggplot2 aes
  #' @importFrom ggplot2 geom_point
  #' @importFrom ggplot2 scale_color_discrete
  #' @importFrom ggplot2 stat_ellipse
  #' @export

  do_plots <- function(i, gmm_result, var_name, grid, fig_path) {
    dd <- as.data.frame(cbind(X, cluster = gmm_result$z_post))
    dd$cluster <- as.factor(dd$cluster)

    # The group means
    # ---------------------------------------------------------------------------
    mu <- as.data.frame(t(gmm_result$q_post$m))

    # Plot the posterior mixture groups
    # ---------------------------------------------------------------------------
    cols <- c("#1170AA", "#55AD89", "#EF6F6A", "#D3A333", "#5FEFE8", "#11F444")
    # A manual scale errors when it has fewer values than factor levels, so
    # fall back to a generated palette if there are more clusters than colours.
    n_clusters <- nlevels(dd$cluster)
    if (n_clusters > length(cols)) {
      cols <- grDevices::hcl.colors(n_clusters)
    }

    # `cols` must be supplied through `values =` of a manual scale: passed
    # positionally to scale_color_discrete() it is not used as the palette.
    # The same palette is applied to the fill so that points and ellipses of
    # a cluster share a colour.
    p <- ggplot2::ggplot() +
      ggplot2::geom_point(
        dd,
        mapping = ggplot2::aes(x = eruptions, y = waiting, color = cluster)
      ) +
      ggplot2::scale_color_manual(values = cols, guide = "none") +
      ggplot2::scale_fill_manual(values = cols) +
      ggplot2::geom_point(
        mu,
        mapping = ggplot2::aes(x = eruptions, y = waiting),
        color = "black", pch = 7, size = 2
      ) +
      ggplot2::stat_ellipse(
        dd,
        geom = "polygon",
        mapping = ggplot2::aes(x = eruptions, y = waiting, fill = cluster),
        alpha = 0.25
      )

    # File name suffix: the values of grid column i joined by underscores
    grids <- paste(grid[, i], collapse = "_")
    ggplot2::ggsave(
      filename = paste0(var_name, "_", grids, ".png"), plot = p,
      path = fig_path, width = 12, height = 12, units = "cm", dpi = 600,
      create.dir = TRUE
    )
  }

  # ------------------------------------------------------------------------------
  # Dirichlet alpha
  # ------------------------------------------------------------------------------
  # Each column of alpha_grid is one Dirichlet alpha setting to try
  alpha_grid <- data.frame(c1 = c(1, 1, 1, 1),
                           c2 = c(1, 92, 183, 183),
                           c3 = c(183, 92, 183, 183))
  init <- "kmeans"

  # seq_len() instead of 1:ncol() so an empty grid yields zero iterations
  for (i in seq_len(ncol(alpha_grid))) {
    prior <- list(
      alpha = as.integer(alpha_grid[, i]) # alpha values taken from grid column i
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    do_plots(i, gmm_result, "alpha", alpha_grid, gen_path)
  }
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
  |                                                                            
  |==                                                                    |   2%
  |                                                                            
  |==                                                                    |   3%
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
  |                                                                            
  |==                                                                    |   2%

  # ------------------------------------------------------------------------------
  # Normal-Wishart beta for precision proportionality
  # ------------------------------------------------------------------------------
  # Each column of beta_grid is one beta setting to try
  beta_grid <- data.frame(c1 = c(0.1, 0.1, 0.1, 0.1),
                          c2 = c(0.9, 0.9, 0.9, 0.9))
  init <- "kmeans"

  # seq_len() instead of 1:ncol() so an empty grid yields zero iterations
  for (i in seq_len(ncol(beta_grid))) {
    prior <- list(
      beta = as.numeric(beta_grid[, i])
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    do_plots(i, gmm_result, "beta", beta_grid, gen_path)
  }
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
  |                                                                            
  |==                                                                    |   2%
  |                                                                            
  |==                                                                    |   3%
  |                                                                            
  |==                                                                    |   4%
  |                                                                            
  |===                                                                   |   4%
  |                                                                            
  |===                                                                   |   5%
  |                                                                            
  |====                                                                  |   5%
  |                                                                            
  |====                                                                  |   6%
  |                                                                            
  |=====                                                                 |   6%
  |                                                                            
  |=====                                                                 |   7%
  |                                                                            
  |=====                                                                 |   8%
  |                                                                            
  |======                                                                |   8%
  |                                                                            
  |======                                                                |   9%
  |                                                                            
  |=======                                                               |   9%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |=======                                                               |  11%
  |                                                                            
  |========                                                              |  11%
  |                                                                            
  |========                                                              |  12%
  |                                                                            
  |=========                                                             |  12%
  |                                                                            
  |=========                                                             |  13%
  |                                                                            
  |=========                                                             |  14%
  |                                                                            
  |==========                                                            |  14%
  |                                                                            
  |==========                                                            |  15%
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
  |                                                                            
  |==                                                                    |   2%
  |                                                                            
  |==                                                                    |   3%

  # ------------------------------------------------------------------------------
  # Normal-Wishart W0 (assuming variance matrix) & logW
  # ------------------------------------------------------------------------------

  # Each column of w_grid is one W0 setting to try
  w_grid <- data.frame(w1 = c(0.001, 0.001),
                       w2 = c(2.001, 2.001))
  init <- "kmeans"

  # Iterate over columns, since the body indexes w_grid[, i]. The original
  # looped over 1:nrow(w_grid), which only matched because the grid is 2 x 2.
  for (i in seq_len(ncol(w_grid))) {
    # P x P diagonal matrix built from grid column i
    w0 <- diag(w_grid[, i], P)
    prior <- list(
      W = w0,
      # -2 * sum(log(diag(chol(w0)))) equals -log(det(w0))
      logW = -2 * sum(log(diag(chol(w0))))
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    do_plots(i, gmm_result, "w", w_grid, gen_path)
  }
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
  |                                                                            
  |==                                                                    |   2%
  |                                                                            
  |==                                                                    |   3%
  |                                                                            
  |==                                                                    |   4%
  |                                                                            
  |===                                                                   |   4%
  |                                                                            
  |===                                                                   |   5%
  |                                                                            
  |====                                                                  |   5%
  |                                                                            
  |====                                                                  |   6%
  |                                                                            
  |=====                                                                 |   6%
  |                                                                            
  |=====                                                                 |   7%
  |                                                                            
  |=====                                                                 |   8%
  |                                                                            
  |======                                                                |   8%
  |                                                                            
  |======                                                                |   9%
  |                                                                            
  |=======                                                               |   9%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |=======                                                               |  11%
  |                                                                            
  |========                                                              |  11%
  |                                                                            
  |========                                                              |  12%
  |                                                                            
  |=========                                                             |  12%
  |                                                                            
  |=========                                                             |  13%
  |                                                                            
  |=========                                                             |  14%
  |                                                                            
  |==========                                                            |  14%
  |                                                                            
  |==========                                                            |  15%
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |                                                                      |   1%
  |                                                                            
  |=                                                                     |   1%
  |                                                                            
  |=                                                                     |   2%
# }