Variational inference for Bayesian logistic regression using CAVI algorithm
Source:R/logistic.R
logit_CAVI.Rd
Variational inference for Bayesian logistic regression using the CAVI algorithm
Usage
logit_CAVI(
X,
y,
prior,
delta = 1e-16,
maxiters = 10000,
verbose = FALSE,
progressbar = TRUE
)
Arguments
- X
The input design matrix. Note the intercept column vector is assumed included.
- y
The binary response.
- prior
Prior for the logistic parameters.
- delta
The ELBO difference tolerance for convergence.
- maxiters
The maximum iterations to run if convergence is not achieved.
- verbose
A diagnostics flag (off by default).
- progressbar
A visual progress bar to indicate iterations (on by default).
Value
A list containing:
error - An error message if convergence failed or the number of iterations to achieve convergence.
mu - A vector of posterior means.
Sigma - A vector of posterior variances.
Convergence - A vector of the ELBO at each iteration.
LBDifference - A vector of ELBO differences between each iteration.
xi - A vector of log-odds per X row.
Examples
# \donttest{
# Old Faithful walkthrough showing the effect of the VB GMM priors; runs stop
# on the ELBO reverse parameter or on the delta threshold
# ------------------------------------------------------------------------------
# Deliberately presume 4 mixture components for demonstration purposes
k <- 4
# Generated plots are written beneath the session's temporary directory
gen_path <- tempdir()
data("faithful")
X <- faithful
# Number of columns (variables) in the data
P <- dim(X)[2L]
# ------------------------------------------------------------------------------
# Plotting
# ------------------------------------------------------------------------------
#' Plot the GMM components with their centroids and save the figure as a PNG
#'
#' The observations are taken from the `X` object defined earlier in this
#' example script (it is not passed in as an argument), so `X` must exist in
#' the calling environment and contain `eruptions` and `waiting` columns.
#'
#' @param i Column index into `grid`; the values in that column are pasted
#'   into the output file name.
#' @param gmm_result Results from the VB GMM run. `z_post` is used as the
#'   cluster label of each row of `X` and `q_post$m` as the group means.
#' @param var_name Name of the GMM hyperparameter being varied; used as the
#'   file name prefix.
#' @param grid Grid of hyperparameter values; column `i` is used in the plot
#'   file name.
#' @param fig_path Path to the directory where the plots should be stored
#'   (created if missing).
#'
#' @returns The value of [ggplot2::ggsave()]. Called for its side effect of
#'   writing `<var_name>_<grid values>.png` into `fig_path`.
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 scale_color_manual
#' @importFrom ggplot2 scale_fill_manual
#' @importFrom ggplot2 stat_ellipse
#' @importFrom ggplot2 ggsave
#' @export
do_plots <- function(i, gmm_result, var_name, grid, fig_path) {
  dd <- as.data.frame(cbind(X, cluster = gmm_result$z_post))
  dd$cluster <- as.factor(dd$cluster)

  # The group means
  # ---------------------------------------------------------------------------
  mu <- as.data.frame(t(gmm_result$q_post$m))

  # Plot the posterior mixture groups
  # ---------------------------------------------------------------------------
  cols <- c("#1170AA", "#55AD89", "#EF6F6A", "#D3A333", "#5FEFE8", "#11F444")
  p <- ggplot2::ggplot() +
    ggplot2::geom_point(
      data = dd,
      mapping = ggplot2::aes(x = eruptions, y = waiting, color = cluster)
    ) +
    # The palette must be supplied through `values`: passed positionally to
    # scale_color_discrete() it was taken as the scale name and never applied.
    ggplot2::scale_color_manual(values = cols, guide = "none") +
    # Same palette for the ellipse fill so ellipses match their points
    ggplot2::scale_fill_manual(values = cols) +
    ggplot2::geom_point(
      data = mu,
      mapping = ggplot2::aes(x = eruptions, y = waiting),
      color = "black", shape = 7, size = 2
    ) +
    ggplot2::stat_ellipse(
      data = dd, geom = "polygon",
      mapping = ggplot2::aes(x = eruptions, y = waiting, fill = cluster),
      alpha = 0.25
    )

  # File name encodes the hyperparameter name and the grid column values
  grids <- paste(grid[, i], collapse = "_")
  ggplot2::ggsave(
    filename = paste0(var_name, "_", grids, ".png"), plot = p,
    path = fig_path, width = 12, height = 12, units = "cm", dpi = 600,
    create.dir = TRUE
  )
}
# ------------------------------------------------------------------------------
# Dirichlet alpha
# ------------------------------------------------------------------------------
# Each column of the grid is one Dirichlet alpha setting across the k components
alpha_grid <- data.frame(
  c1 = rep(1, 4),
  c2 = c(1, 92, 183, 183),
  c3 = c(183, 92, 183, 183)
)
init <- "kmeans"
# Fit and plot once per grid column
for (i in seq_along(alpha_grid)) {
  # set most of the weight on one component
  prior <- list(alpha = as.integer(alpha_grid[[i]]))
  gmm_result <- vb_gmm_cavi(
    X = X, k = k, prior = prior, init = init, delta = 1e-8,
    verbose = FALSE, logDiagnostics = FALSE
  )
  do_plots(
    i = i, gmm_result = gmm_result, var_name = "alpha",
    grid = alpha_grid, fig_path = gen_path
  )
}
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
# ------------------------------------------------------------------------------
# Normal-Wishart beta for precision proportionality
# ------------------------------------------------------------------------------
# Each column of the grid is one Normal-Wishart beta setting across the k
# components
beta_grid <- data.frame(
  c1 = rep(0.1, 4),
  c2 = rep(0.9, 4)
)
init <- "kmeans"
# Fit and plot once per grid column
for (i in seq_along(beta_grid)) {
  prior <- list(beta = as.numeric(beta_grid[[i]]))
  gmm_result <- vb_gmm_cavi(
    X = X, k = k, prior = prior, init = init, delta = 1e-8,
    verbose = FALSE, logDiagnostics = FALSE
  )
  do_plots(
    i = i, gmm_result = gmm_result, var_name = "beta",
    grid = beta_grid, fig_path = gen_path
  )
}
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|== | 4%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 9%
|
|======= | 10%
|
|======= | 11%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========= | 14%
|
|========== | 14%
|
|========== | 15%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
# ------------------------------------------------------------------------------
# Normal-Wishart W0 (assuming variance matrix) & logW
# ------------------------------------------------------------------------------
# Each column of the grid holds the diagonal entries of one W0 setting
w_grid <- data.frame(
  w1 = c(0.001, 0.001),
  w2 = c(2.001, 2.001)
)
init <- "kmeans"
# Iterate over grid COLUMNS: the body indexes w_grid[, i], so looping over the
# row count was only correct while the grid happened to be square.
for (i in seq_len(ncol(w_grid))) {
  # P x P diagonal matrix built from the i-th grid column
  w0 <- diag(w_grid[, i], P)
  prior <- list(
    W = w0,
    # -2 * sum(log(diag(chol(w0)))) equals -log(det(w0))
    logW = -2 * sum(log(diag(chol(w0))))
  )
  gmm_result <- vb_gmm_cavi(
    X = X, k = k, prior = prior, delta = 1e-8, init = init,
    verbose = FALSE, logDiagnostics = FALSE
  )
  do_plots(i, gmm_result, "w", w_grid, gen_path)
}
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|== | 4%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 9%
|
|======= | 10%
|
|======= | 11%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========= | 14%
|
|========== | 14%
|
|========== | 15%
#>
|
| | 0%
|
| | 1%
|
|= | 1%
|
|= | 2%
# }