Introduction to bkmr and bkmrhat

bkmr is a package that implements Bayesian kernel machine regression (BKMR) using Markov chain Monte Carlo (MCMC). Notably, bkmr lacks some key features for Bayesian inference and MCMC diagnostics: 1) no facility for running multiple chains in parallel, 2) no inference across multiple chains, 3) limited posterior summaries of parameters, and 4) limited diagnostics. The bkmrhat package is a lightweight set of functions that fills each of these gaps by enabling post-processing of bkmr output in other packages and by providing a small framework for parallel processing.

How to use the bkmrhat package

  1. Fit a BKMR model for a single chain using the kmbayes function from bkmr, or fit multiple parallel chains using the kmbayes_parallel function from bkmrhat
  2. Perform diagnostics on single or multiple chains using the kmbayes_diag function (uses functions from the rstan package) OR convert the BKMR fit(s) to mcmc (one chain) or mcmc.list (multiple chains) objects from the coda package using as.mcmc or as.mcmc.list from the bkmrhat package. The coda package has a whole host of inference and diagnostic procedures (but may lag behind some of the diagnostics functions from rstan).
  3. Perform posterior summaries using coda functions, or combine chains from a kmbayes_parallel fit using comb_bkmrfits. Final posterior inference can be done on the combined object, which enables use of bkmr package functions for visual summaries of independent and joint effects of the exposures in the bkmr model. (A minimal end-to-end sketch of this workflow follows this list.)
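
For orientation, the whole workflow can be compressed to a few lines. This is a minimal sketch (not run here) that assumes an outcome y, exposure matrix Z, and covariate matrix X are already defined; the function names are those used throughout this vignette:

# Step 1: fit multiple parallel chains
fits <- kmbayes_parallel(nchains = 4, y = y, Z = Z, X = X, iter = 2000)

# Step 2: diagnostics, either directly or via coda objects
kmbayes_diag(fits)                 # rstan-based diagnostics
fits_coda <- as.mcmc.list(fits)    # coda-based diagnostics
gelman.diag(fits_coda)

# Step 3: combine the chains and summarize with bkmr functions
fit_comb <- comb_bkmrfits(fits)
summary(fit_comb)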

First, simulate some data using the SimData function from the bkmr package

library("bkmr")
library("bkmrhat")
library("coda")
Sys.setenv(R_FUTURE_SUPPORTSMULTICORE_UNSTABLE="quiet") # for future package

set.seed(111)
dat <- bkmr::SimData(n = 50, M = 5, ind=1:3, Zgen="realistic")
y <- dat$y
Z <- dat$Z
X <- cbind(dat$X, rnorm(50))
head(cbind(y,Z,X))
##               y          z1          z2          z3          z4         z5
## [1,]  4.1379128 -0.06359282 -0.02996246 -0.14190647 -0.44089352 -0.1878732
## [2,] 12.0843607 -0.07308834  0.32021690  1.08838691  0.29448354 -1.4609837
## [3,]  7.8859254  0.59604857  0.20602329  0.46218114 -0.03387906 -0.7615902
## [4,]  1.1609768  1.46504863  2.48389356  1.39869461  1.49678590  0.2837234
## [5,]  0.4989372 -0.37549639  0.01159884  1.17891641 -0.05286516 -0.1680664
## [6,]  5.0731242 -0.36904566 -0.49744932 -0.03330522  0.30843805  0.6814844
##                           
## [1,]  1.0569172 -1.0503824
## [2,]  4.8158570  0.3251424
## [3,]  2.6683461 -2.1048716
## [4,] -0.7492096 -0.9551027
## [5,] -0.5428339 -0.5306399
## [6,]  1.6493251  0.8274405

Example 1: Single vs. multiple chains

There is some overhead in parallel processing when using the future package, so the payoff from parallelization will vary by problem. Here it is about a 2-4x speedup, but the benefit grows with the number of iterations. Note that parallel chains may not yield as many usable iterations as a single large chain if a substantial burn-in period is needed, but they do enable useful convergence diagnostics. Note also that the future package can implement sequential processing, which effectively turns kmbayes_parallel into a loop but retains all of the other advantages of multiple chains; a sketch of the available strategies follows.
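
For reference, the execution strategy is controlled entirely by the future plan set before calling kmbayes_parallel. A brief sketch using two standard strategies from the future package (multisession for background R sessions, sequential for in-process looping):

library("future")

# run chains in parallel background R sessions
plan(multisession)

# or run chains one at a time in the current session (a simple loop)
plan(sequential)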

# enable parallel processing
future::plan(strategy = future::multiprocess)

# single run of 4000 iterations using the bkmr package
set.seed(111)
system.time(kmfit <- suppressMessages(kmbayes(y = y, Z = Z, X = X, iter = 4000, verbose = FALSE, varsel = FALSE)))
##    user  system elapsed 
##  13.973   0.258  14.502
# 4 parallel runs of 1000 iterations each using the bkmrhat package
set.seed(111)
system.time(kmfit5 <- suppressMessages(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = FALSE)))
## Chain 1 
## Chain 2 
## Chain 3 
## Chain 4
##    user  system elapsed 
##   0.087   0.005   6.487

Example 2: Diagnostics

The diagnostics from the rstan package come from the monitor function (see the help files for that function in the rstan package).

# Using rstan functions (warmup/burnin set to zero here for comparability with the coda numbers
#  given later; in practice, posterior summaries should exclude the warmup/burnin iterations)
singlediag = kmbayes_diag(kmfit, warmup=0, digits_summary=2)
## Single chain
## Inference for the input samples (1 chains: each with iter = 4000; warmup = 0):
## 
##            Q5  Q50  Q95 Mean  SD  Rhat Bulk_ESS Tail_ESS
## beta1     1.9  2.0  2.1  2.0 0.0  1.00     2820     3194
## beta2     0.0  0.1  0.3  0.1 0.1  1.00     3739     3535
## lambda    3.9 10.0 22.3 11.2 5.9  1.00      346      222
## r1        0.0  0.0  0.1  0.0 0.1  1.01      129      173
## r2        0.0  0.0  0.1  0.0 0.1  1.00      182      181
## r3        0.0  0.0  0.0  0.0 0.0  1.01      158      112
## r4        0.0  0.0  0.1  0.0 0.1  1.03      176      135
## r5        0.0  0.0  0.0  0.0 0.1  1.00      107      114
## sigsq.eps 0.2  0.3  0.5  0.4 0.1  1.00     1262     1563
## 
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of 
## effective sample size for bulk and tail quantities respectively (an ESS > 100 
## per chain is considered good), and Rhat is the potential scale reduction 
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# Using rstan functions (multiple chains enable R-hat)
multidiag = kmbayes_diag(kmfit5, warmup=0, digits_summary=2)
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 0):
## 
##            Q5  Q50  Q95 Mean  SD  Rhat Bulk_ESS Tail_ESS
## beta1     1.9  2.0  2.1  2.0 0.0  1.00     1847     1287
## beta2     0.0  0.1  0.3  0.1 0.1  1.00     2647     2995
## lambda    4.0 10.6 23.7 11.9 6.6  1.01      325      281
## r1        0.0  0.0  0.1  0.0 0.1  1.02      137       94
## r2        0.0  0.0  0.1  0.1 0.1  1.03      155      109
## r3        0.0  0.0  0.2  0.1 0.2  1.03       98       41
## r4        0.0  0.0  0.2  0.1 0.2  1.03      113       78
## r5        0.0  0.0  0.1  0.0 0.1  1.05       81       80
## sigsq.eps 0.2  0.3  0.5  0.3 0.1  1.00      710      442
## 
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of 
## effective sample size for bulk and tail quantities respectively (an ESS > 100 
## per chain is considered good), and Rhat is the potential scale reduction 
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# using coda functions, not using any burnin (for demonstration only)
kmfitcoda = as.mcmc(kmfit, iterstart = 1)
kmfit5coda = as.mcmc.list(kmfit5, iterstart = 1)

# single chain trace plot
traceplot(kmfitcoda)

[Figure: trace plots of all parameters, single chain]

The trace plots look typical and fine, but trace plots alone don't give a full picture of convergence. Note the apparent quick convergence for a couple of parameters, demonstrated by movement away from the starting value followed by concentration of the remaining samples within a narrow band.

Seeing visual evidence that different chains are sampling from the same marginal distributions is reassuring about the stability of the results.

# multiple chain trace plot
traceplot(kmfit5coda)

[Figure: trace plots of all parameters, four parallel chains]

Trace plots can be limited, and are sometimes difficult to use effectively with scale parameters (of which bkmr has many). Thus, more formal diagnostics are helpful.

Gelman's r-hat gives an interpretable diagnostic: roughly, the factor by which the uncertainty in the posterior means would shrink if the chains could be run to infinite length. It gives some idea of when it is reasonable to stop sampling. Rules of thumb for using r-hat as a stopping criterion are available from several authors (for example, consult the help files for rstan and coda).

Effective sample size is also useful: it estimates the amount of information in the chain, expressed as the number of independent posterior samples that would carry the same information (e.g., if we could sample from the posterior directly). Both quantities can also be computed by hand, as sketched below.
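
To make these quantities concrete, they can be computed directly from the raw posterior draws. This is a minimal sketch assuming, as in bkmr, that each fitted object stores the fixed-effect draws in an iterations-by-coefficients matrix named beta; rstan's Rhat and ess_bulk functions accept a matrix with one column per chain:

# iterations x chains matrix of draws for the first fixed-effect coefficient
beta1_draws <- sapply(kmfit5, function(fit) fit$beta[, 1])

rstan::Rhat(beta1_draws)      # rank-normalized split r-hat
rstan::ess_bulk(beta1_draws)  # bulk effective sample size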

# Gelman's r-hat using coda estimator (will differ from rstan implementation)
gelman.diag(kmfit5coda)
## Potential scale reduction factors:
## 
##           Point est. Upper C.I.
## beta1           1.00       1.01
## beta2           1.00       1.00
## lambda          1.01       1.02
## r1              1.03       1.07
## r2              1.07       1.18
## r3              1.04       1.06
## r4              1.04       1.05
## r5              1.21       1.40
## sigsq.eps       1.00       1.00
## 
## Multivariate psrf
## 
## 1.06
# effective sample size
effectiveSize(kmfitcoda)
##      beta1      beta2     lambda         r1         r2         r3         r4 
## 2411.61878 2865.78299  431.49158   87.11091  260.29419  328.45388  181.61903 
##         r5  sigsq.eps 
##  123.06100 1719.70679
effectiveSize(kmfit5coda)
##     beta1     beta2    lambda        r1        r2        r3        r4        r5 
## 1580.8891 3069.8553  407.7873  178.5527  139.9949  131.5215  122.8238  117.5702 
## sigsq.eps 
## 1062.4836

Example 3: Posterior summaries

Posterior kernel marginal densities, 1 chain

# posterior kernel marginal densities using a coda `mcmc` object
densplot(kmfitcoda)

[Figure: posterior marginal density plots of all parameters, single chain]

Posterior kernel marginal densities, multiple chains combined. Look for multiple modes, which may indicate non-convergence of some chains (a per-chain overlay is sketched after the plot below).

# posterior kernel marginal densities using a coda `mcmc.list` object
densplot(kmfit5coda)

[Figure: posterior marginal density plots of all parameters, four chains combined]
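
Because densplot pools the chains, multimodality driven by a single stray chain can be masked. A minimal base-R sketch of a per-chain overlay for one parameter, assuming (as in coda) that each element of the mcmc.list can be indexed as a matrix with named columns:

# overlay per-chain posterior densities for the r1 parameter
plot(density(kmfit5coda[[1]][, "r1"]), main = "r1, by chain")
for (ch in 2:length(kmfit5coda)) {
  lines(density(kmfit5coda[[ch]][, "r1"]), col = ch)
}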

Many other diagnostics from the coda package can be applied to these objects; one example is sketched below.
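
For example, the Geweke diagnostic, a z-score comparing the means of the early and late portions of each chain, applies directly to these objects. A sketch using coda's geweke.diag with its default window fractions:

# z-scores comparing the first 10% and last 50% of each chain (the coda defaults)
geweke.diag(kmfitcoda)
geweke.diag(kmfit5coda)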

Finally, the chains from the original kmbayes_parallel fit can be combined into a single chain (see the help files for how burn-in is handled; the default in bkmr is to discard the first half of the chain, which is respected here). The comb_bkmrfits function first combines the burn-in iterations across chains and then combines the post-burn-in iterations, so that the burn-in rules of subsequent functions in the bkmr package are respected: with 4 chains of 1000 iterations each, the combined object holds 4000 iterations, of which the first 2000 are the pooled burn-ins. Note that, unlike the as.mcmc.list function, this function combines all iterations into a single chain, so trace plots are not good diagnostics for the combined object; it should be used only once you are assured that all chains have converged and the burn-in is acceptable.

With this combined set of samples, you can use any of the post-processing functions from the bkmr package, which are described here: https://jenfb.github.io/bkmr/overview.html. For example, see below the estimation of the posterior mean difference along a series of joint quantiles of all exposures in Z.

# posterior summaries using coda `mcmc` and `mcmc.list` objects
summary(kmfitcoda)
## 
## Iterations = 1:4000
## Thinning interval = 1 
## Number of chains = 1 
## Sample size per chain = 4000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##               Mean      SD  Naive SE Time-series SE
## beta1      1.98349 0.04392 0.0006945      0.0008944
## beta2      0.11999 0.08532 0.0013490      0.0015937
## lambda    11.18052 5.90540 0.0933725      0.2842909
## r1         0.03010 0.06385 0.0010095      0.0068408
## r2         0.03656 0.05690 0.0008997      0.0035270
## r3         0.02131 0.04065 0.0006427      0.0022429
## r4         0.02855 0.06641 0.0010500      0.0049278
## r5         0.02904 0.09543 0.0015089      0.0086023
## sigsq.eps  0.35283 0.08228 0.0013009      0.0019840
## 
## 2. Quantiles for each variable:
## 
##               2.5%     25%      50%      75%    97.5%
## beta1      1.90086 1.95474  1.98287  2.01177  2.07218
## beta2     -0.04513 0.06141  0.12001  0.17716  0.28713
## lambda     3.23941 7.01814 10.00979 14.12075 25.93332
## r1         0.01022 0.01232  0.01807  0.02767  0.09818
## r2         0.01018 0.01433  0.02172  0.04049  0.12353
## r3         0.01015 0.01180  0.01488  0.02198  0.05655
## r4         0.01040 0.01299  0.01670  0.02582  0.08533
## r5         0.01025 0.01219  0.01532  0.01951  0.07021
## sigsq.eps  0.22855 0.29302  0.34057  0.39833  0.54838
summary(kmfit5coda)
## 
## Iterations = 1:1000
## Thinning interval = 1 
## Number of chains = 4 
## Sample size per chain = 1000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##               Mean      SD  Naive SE Time-series SE
## beta1      1.98504 0.04677 0.0007395       0.001239
## beta2      0.11575 0.09000 0.0014230       0.001707
## lambda    11.92057 6.56256 0.1037632       0.348427
## r1         0.04414 0.10754 0.0017004       0.010208
## r2         0.05647 0.13704 0.0021668       0.016506
## r3         0.05954 0.18967 0.0029990       0.028864
## r4         0.05837 0.18214 0.0028798       0.016738
## r5         0.04633 0.14204 0.0022458       0.014624
## sigsq.eps  0.34664 0.08395 0.0013273       0.003062
## 
## 2. Quantiles for each variable:
## 
##               2.5%     25%      50%      75%   97.5%
## beta1      1.89885 1.95420  1.98425  2.01385  2.0776
## beta2     -0.05946 0.05566  0.11620  0.17516  0.2922
## lambda     3.58754 7.15603 10.55348 15.18081 27.7287
## r1         0.01030 0.01382  0.01910  0.03149  0.3753
## r2         0.01024 0.01354  0.02344  0.04398  0.4713
## r3         0.01026 0.01217  0.01674  0.02392  0.8794
## r4         0.01015 0.01225  0.01611  0.02697  0.7717
## r5         0.01018 0.01172  0.01490  0.02179  0.6456
## sigsq.eps  0.21068 0.28928  0.33687  0.39156  0.5442
# highest posterior density intervals using coda `mcmc` and `mcmc.list` objects
HPDinterval(kmfitcoda)
##                 lower       upper
## beta1      1.90086937  2.07231225
## beta2     -0.04051413  0.29012992
## lambda     2.68866745 22.62238656
## r1         0.01002094  0.06754473
## r2         0.01002597  0.09683241
## r3         0.01003504  0.04392585
## r4         0.01000511  0.06510500
## r5         0.01005733  0.04755863
## sigsq.eps  0.21411781  0.52229778
## attr(,"Probability")
## [1] 0.95
HPDinterval(kmfit5coda)
## [[1]]
##                 lower       upper
## beta1      1.90095064  2.07791956
## beta2     -0.05862274  0.28118915
## lambda     2.91386257 25.86857310
## r1         0.01013501  0.25079609
## r2         0.01000428  0.51753332
## r3         0.01014805  0.05769139
## r4         0.01015497  0.15973022
## r5         0.01009127  0.16953688
## sigsq.eps  0.21170393  0.57226636
## attr(,"Probability")
## [1] 0.95
## 
## [[2]]
##                 lower       upper
## beta1      1.90226494  2.07365948
## beta2     -0.05062520  0.29951571
## lambda     3.12950912 22.91155191
## r1         0.01032702  0.10587560
## r2         0.01004702  0.09449772
## r3         0.01000304  0.04779422
## r4         0.01005963  0.07642952
## r5         0.01017053  0.46628446
## sigsq.eps  0.16957238  0.49457057
## attr(,"Probability")
## [1] 0.95
## 
## [[3]]
##                 lower       upper
## beta1      1.89034527  2.06965257
## beta2     -0.05466157  0.30755041
## lambda     3.58754416 24.92685271
## r1         0.01079946  0.07476612
## r2         0.01005008  0.11222992
## r3         0.01026009  0.99871122
## r4         0.01007529  0.61692323
## r5         0.01021641  0.06880410
## sigsq.eps  0.21487500  0.52810961
## attr(,"Probability")
## [1] 0.95
## 
## [[4]]
##                 lower       upper
## beta1      1.90946455  2.07985932
## beta2     -0.05277738  0.29022701
## lambda     3.17405148 23.18505389
## r1         0.01018462  0.11921053
## r2         0.01005280  0.14463180
## r3         0.01018742  0.22776253
## r4         0.01009111  0.15711542
## r5         0.01018262  0.04220932
## sigsq.eps  0.21722643  0.50853234
## attr(,"Probability")
## [1] 0.95
# combine multiple chains into a single chain
fitkmccomb = comb_bkmrfits(kmfit5)


# For example:
summary(fitkmccomb)
## Fitted object of class 'bkmrfit'
## Iterations: 4000 
## Outcome family: gaussian  
## Model fit on: 2020-09-08 10:58:39 
## Running time:  4.16858 secs 
## 
## Acceptance rates for Metropolis-Hastings algorithm:
##    param      rate
## 1 lambda 0.4341085
## 2     r1 0.2038010
## 3     r2 0.2780695
## 4     r3 0.1690423
## 5     r4 0.1815454
## 6     r5 0.1315329
## 
## Parameter estimates (based on iterations 2001-4000):
##       param     mean      sd    q_2.5   q_97.5
## 1     beta1  1.98265 0.04236  1.90102  2.06799
## 2     beta2  0.11380 0.08595 -0.05622  0.28000
## 3 sigsq.eps  0.35111 0.08172  0.22696  0.55196
## 4        r1  0.02634 0.02277  0.01110  0.09901
## 5        r2  0.03426 0.02907  0.01025  0.11333
## 6        r3  0.01846 0.01029  0.01023  0.04512
## 7        r4  0.02105 0.01871  0.01023  0.06984
## 8        r5  0.01810 0.01714  0.01078  0.04059
## 9    lambda 11.82551 6.55547  3.33260 28.09170
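
As a quick sanity check on the combination, the combined object should contain the sum of the per-chain iterations, with the pooled burn-ins placed first. A sketch assuming, as in bkmr, that the posterior draws are stored in an iterations-by-parameters matrix named beta:

# per-chain iteration counts and the combined total
sapply(kmfit5, function(fit) nrow(fit$beta))  # 1000 each
nrow(fitkmccomb$beta)                         # 4000 in the combined fit
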
mean.difference <- suppressWarnings(OverallRiskSummaries(fit = fitkmccomb, y = y, Z = Z, X = X, 
                                      qs = seq(0.25, 0.75, by = 0.05), 
                                      q.fixed = 0.5, method = "exact"))
mean.difference
##    quantile        est         sd
## 1      0.25 -0.7219873 0.12001580
## 2      0.30 -0.5823791 0.09730798
## 3      0.35 -0.3910603 0.08027052
## 4      0.40 -0.2740483 0.04692598
## 5      0.45 -0.1530667 0.02687483
## 6      0.50  0.0000000 0.00000000
## 7      0.55  0.2173967 0.04225033
## 8      0.60  0.3347273 0.05184714
## 9      0.65  0.5160224 0.08424381
## 10     0.70  0.8844215 0.14840458
## 11     0.75  0.9761952 0.15743255
with(mean.difference, {
  plot(quantile, est, pch=19, ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)), 
       axes=FALSE, ylab= "Mean difference", xlab = "Joint quantile")
  segments(x0=quantile, x1=quantile, y0 = est - 1.96*sd, y1 = est + 1.96*sd)
  abline(h=0)
  axis(1)
  axis(2)
  box(bty='l')
})

[Figure: posterior mean difference (with 95% intervals) by joint exposure quantile]

Example 4: Diagnostics and inference when variable selection is used (Bayesian model averaging over the scale parameters of the kernel function)

These results parallel those of the previous sections and are given here without extended comment, other than to note that it is useful to check the posterior inclusion probabilities to ensure they are similar across chains (a quick numeric check is sketched after the per-chain output below).

set.seed(111)
system.time(kmfitbma.list <- suppressWarnings(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = TRUE)))
## Chain 1
## Iteration: 100 (10% completed; 0.16832 secs elapsed)
## Iteration: 200 (20% completed; 0.33918 secs elapsed)
## Iteration: 300 (30% completed; 0.52428 secs elapsed)
## Iteration: 400 (40% completed; 0.67745 secs elapsed)
## Iteration: 500 (50% completed; 0.85374 secs elapsed)
## Iteration: 600 (60% completed; 1.01676 secs elapsed)
## Iteration: 700 (70% completed; 1.19156 secs elapsed)
## Iteration: 800 (80% completed; 1.37779 secs elapsed)
## Iteration: 900 (90% completed; 1.57305 secs elapsed)
## Iteration: 1000 (100% completed; 1.75086 secs elapsed)
## Chain 2
## Iteration: 100 (10% completed; 0.16259 secs elapsed)
## Iteration: 200 (20% completed; 0.33395 secs elapsed)
## Iteration: 300 (30% completed; 0.49694 secs elapsed)
## Iteration: 400 (40% completed; 0.65646 secs elapsed)
## Iteration: 500 (50% completed; 0.82854 secs elapsed)
## Iteration: 600 (60% completed; 0.99429 secs elapsed)
## Iteration: 700 (70% completed; 1.16818 secs elapsed)
## Iteration: 800 (80% completed; 1.3513 secs elapsed)
## Iteration: 900 (90% completed; 1.5465 secs elapsed)
## Iteration: 1000 (100% completed; 1.71671 secs elapsed)
## Chain 3
## Iteration: 100 (10% completed; 0.16648 secs elapsed)
## Iteration: 200 (20% completed; 0.34762 secs elapsed)
## Iteration: 300 (30% completed; 0.51031 secs elapsed)
## Iteration: 400 (40% completed; 0.67377 secs elapsed)
## Iteration: 500 (50% completed; 0.84787 secs elapsed)
## Iteration: 600 (60% completed; 1.01307 secs elapsed)
## Iteration: 700 (70% completed; 1.18325 secs elapsed)
## Iteration: 800 (80% completed; 1.37914 secs elapsed)
## Iteration: 900 (90% completed; 1.57388 secs elapsed)
## Iteration: 1000 (100% completed; 1.74193 secs elapsed)
## Chain 4
## Iteration: 100 (10% completed; 0.1611 secs elapsed)
## Iteration: 200 (20% completed; 0.33424 secs elapsed)
## Iteration: 300 (30% completed; 0.49566 secs elapsed)
## Iteration: 400 (40% completed; 0.6476 secs elapsed)
## Iteration: 500 (50% completed; 0.8273 secs elapsed)
## Iteration: 600 (60% completed; 0.99273 secs elapsed)
## Iteration: 700 (70% completed; 1.15981 secs elapsed)
## Iteration: 800 (80% completed; 1.35441 secs elapsed)
## Iteration: 900 (90% completed; 1.54774 secs elapsed)
## Iteration: 1000 (100% completed; 1.71861 secs elapsed)
##    user  system elapsed 
##   0.059   0.004   1.889
bmadiag = kmbayes_diag(kmfitbma.list)
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 500):
## 
##            Q5  Q50  Q95 Mean  SD  Rhat Bulk_ESS Tail_ESS
## beta1     1.9  2.0  2.1  2.0 0.0  1.00      944     1687
## beta2     0.0  0.1  0.3  0.1 0.1  1.00     1262     1748
## lambda    4.6 11.0 25.7 12.7 6.9  1.02      125      133
## r1        0.0  0.0  0.1  0.0 0.0  1.22       19       51
## r2        0.0  0.0  0.1  0.0 0.0  1.11       24       68
## r3        0.0  0.0  0.0  0.0 0.0  1.07       71      103
## r4        0.0  0.0  0.0  0.0 0.0  1.04       80      141
## r5        0.0  0.0  0.0  0.0 0.0  1.02       95       66
## sigsq.eps 0.3  0.4  0.5  0.4 0.1  1.00      553     1377
## 
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of 
## effective sample size for bulk and tail quantities respectively (an ESS > 100 
## per chain is considered good), and Rhat is the potential scale reduction 
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# posterior inclusion probabilities from each chain
lapply(kmfitbma.list, function(x) t(ExtractPIPs(x)))
## [[1]]
##          [,1]    [,2]    [,3]    [,4]    [,5]   
## variable "z1"    "z2"    "z3"    "z4"    "z5"   
## PIP      "0.866" "0.808" "0.234" "0.792" "0.468"
## 
## [[2]]
##          [,1]    [,2]    [,3]    [,4]    [,5]   
## variable "z1"    "z2"    "z3"    "z4"    "z5"   
## PIP      "0.940" "0.638" "0.470" "0.774" "0.344"
## 
## [[3]]
##          [,1]    [,2]    [,3]    [,4]    [,5]   
## variable "z1"    "z2"    "z3"    "z4"    "z5"   
## PIP      "0.964" "0.664" "0.338" "0.604" "0.470"
## 
## [[4]]
##          [,1]    [,2]    [,3]    [,4]    [,5]   
## variable "z1"    "z2"    "z3"    "z4"    "z5"   
## PIP      "0.934" "0.774" "0.570" "0.776" "0.442"
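
To put a number on between-chain agreement, one can compute the spread of each variable's PIP across chains. A minimal sketch, assuming (as shown above) that ExtractPIPs returns a data frame with variable and PIP columns:

# range (max - min) of each variable's PIP across the 4 chains
pips <- sapply(kmfitbma.list, function(fit) ExtractPIPs(fit)$PIP)
rownames(pips) <- ExtractPIPs(kmfitbma.list[[1]])$variable
apply(pips, 1, function(p) diff(range(p)))
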
kmfitbma.comb = comb_bkmrfits(kmfitbma.list)
summary(kmfitbma.comb)
## Fitted object of class 'bkmrfit'
## Iterations: 4000 
## Outcome family: gaussian  
## Model fit on: 2020-09-08 10:58:45 
## Running time:  1.75132 secs 
## 
## Acceptance rates for Metropolis-Hastings algorithm:
##               param      rate
## 1            lambda 0.4701175
## 2 r/delta (overall) 0.3478370
## 3 r/delta  (move 1) 0.4701754
## 4 r/delta  (move 2) 0.2263868
## 
## Parameter estimates (based on iterations 2001-4000):
##       param     mean      sd    q_2.5   q_97.5
## 1     beta1  1.97732 0.04520  1.88702  2.06416
## 2     beta2  0.11712 0.08902 -0.05683  0.29331
## 3 sigsq.eps  0.37596 0.08888  0.23714  0.58000
## 4        r1  0.02817 0.02885  0.00000  0.08851
## 5        r2  0.03060 0.03590  0.00000  0.12385
## 6        r3  0.00778 0.01207  0.00000  0.03544
## 7        r4  0.01574 0.02210  0.00000  0.05836
## 8        r5  0.00779 0.01116  0.00000  0.03828
## 9    lambda 12.68386 6.88441  3.97781 30.02420
## 
## Posterior inclusion probabilities:
##   variable    PIP
## 1       z1 0.9260
## 2       z2 0.7210
## 3       z3 0.4030
## 4       z4 0.7365
## 5       z5 0.4310
ExtractPIPs(kmfitbma.comb) # posterior inclusion probabilities
##   variable    PIP
## 1       z1 0.9260
## 2       z2 0.7210
## 3       z3 0.4030
## 4       z4 0.7365
## 5       z5 0.4310
mean.difference2 <- suppressWarnings(OverallRiskSummaries(fit = kmfitbma.comb, y = y, Z = Z, X = X, 
                                      qs = seq(0.25, 0.75, by = 0.05), 
                                      q.fixed = 0.5, method = "exact"))
mean.difference2
##    quantile        est         sd
## 1      0.25 -0.6260508 0.12849958
## 2      0.30 -0.5019690 0.10378009
## 3      0.35 -0.3208190 0.09577422
## 4      0.40 -0.2399183 0.05374843
## 5      0.45 -0.1453802 0.02760506
## 6      0.50  0.0000000 0.00000000
## 7      0.55  0.2020261 0.04706467
## 8      0.60  0.3039065 0.05854309
## 9      0.65  0.4764684 0.09470680
## 10     0.70  0.8316744 0.16714990
## 11     0.75  0.9045896 0.16891285
with(mean.difference2, {
  plot(quantile, est, pch=19, ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)), 
       axes=FALSE, ylab= "Mean difference", xlab = "Joint quantile")
  segments(x0=quantile, x1=quantile, y0 = est - 1.96*sd, y1 = est + 1.96*sd)
  abline(h=0)
  axis(1)
  axis(2)
  box(bty='l')
})

[Figure: posterior mean difference (with 95% intervals) by joint exposure quantile, variable selection model]

Example 5: Parallel posterior summaries as diagnostics

bkmrhat also includes ported versions of the native bkmr posterior summarization functions that compare how those summaries vary across parallel chains. These should serve as diagnostics; final posterior inference should be done on the combined chain. The easiest of these functions to demonstrate is OverallRiskSummaries_parallel, which simply runs OverallRiskSummaries (from the bkmr package) on each chain and combines the results. Notably, this function fixes the y-axis at zero for the median, so it under-represents overall predictive variation across chains but captures variation in the effect estimates. Ideally that variation is negligible: if you see differences between chains that would lead to different interpretations, you should re-fit the model with more iterations. In this example the results are reasonably consistent across chains, but one might want to run more iterations if, say, the differences across the upper error bounds were of a practically meaningful magnitude (a simple numeric summary of this variation is sketched after the output below).

set.seed(111)
system.time(kmfitbma.list <- suppressWarnings(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = TRUE)))
## Chain 1
## Iteration: 100 (10% completed; 0.17407 secs elapsed)
## Iteration: 200 (20% completed; 0.34649 secs elapsed)
## Iteration: 300 (30% completed; 0.52999 secs elapsed)
## Iteration: 400 (40% completed; 0.70884 secs elapsed)
## Iteration: 500 (50% completed; 0.88093 secs elapsed)
## Iteration: 600 (60% completed; 1.05573 secs elapsed)
## Iteration: 700 (70% completed; 1.23029 secs elapsed)
## Iteration: 800 (80% completed; 1.4171 secs elapsed)
## Iteration: 900 (90% completed; 1.60029 secs elapsed)
## Iteration: 1000 (100% completed; 1.78528 secs elapsed)
## Chain 2
## Iteration: 100 (10% completed; 0.17944 secs elapsed)
## Iteration: 200 (20% completed; 0.35169 secs elapsed)
## Iteration: 300 (30% completed; 0.52609 secs elapsed)
## Iteration: 400 (40% completed; 0.70891 secs elapsed)
## Iteration: 500 (50% completed; 0.8761 secs elapsed)
## Iteration: 600 (60% completed; 1.04731 secs elapsed)
## Iteration: 700 (70% completed; 1.22273 secs elapsed)
## Iteration: 800 (80% completed; 1.40487 secs elapsed)
## Iteration: 900 (90% completed; 1.59858 secs elapsed)
## Iteration: 1000 (100% completed; 1.78209 secs elapsed)
## Chain 3
## Iteration: 100 (10% completed; 0.17872 secs elapsed)
## Iteration: 200 (20% completed; 0.3526 secs elapsed)
## Iteration: 300 (30% completed; 0.53121 secs elapsed)
## Iteration: 400 (40% completed; 0.71663 secs elapsed)
## Iteration: 500 (50% completed; 0.8825 secs elapsed)
## Iteration: 600 (60% completed; 1.0566 secs elapsed)
## Iteration: 700 (70% completed; 1.22646 secs elapsed)
## Iteration: 800 (80% completed; 1.40671 secs elapsed)
## Iteration: 900 (90% completed; 1.59582 secs elapsed)
## Iteration: 1000 (100% completed; 1.78005 secs elapsed)
## Chain 4
## Iteration: 100 (10% completed; 0.18301 secs elapsed)
## Iteration: 200 (20% completed; 0.35704 secs elapsed)
## Iteration: 300 (30% completed; 0.53534 secs elapsed)
## Iteration: 400 (40% completed; 0.71203 secs elapsed)
## Iteration: 500 (50% completed; 0.87501 secs elapsed)
## Iteration: 600 (60% completed; 1.03775 secs elapsed)
## Iteration: 700 (70% completed; 1.22166 secs elapsed)
## Iteration: 800 (80% completed; 1.40107 secs elapsed)
## Iteration: 900 (90% completed; 1.73358 secs elapsed)
## Iteration: 1000 (100% completed; 1.89462 secs elapsed)
##    user  system elapsed 
##   0.063   0.004   1.948
meandifference_par = OverallRiskSummaries_parallel(kmfitbma.list, y = y, Z = Z, X = X, 
                                      qs = seq(0.25, 0.75, by = 0.05), q.fixed = 0.5, method = "exact")
## Chain 1 
## Chain 2 
## Chain 3 
## Chain 4
head(meandifference_par)
##   quantile        est         sd chain
## 1     0.25 -0.5993038 0.11647532     1
## 2     0.30 -0.4796568 0.09186142     1
## 3     0.35 -0.3035896 0.09261565     1
## 4     0.40 -0.2250368 0.04800464     1
## 5     0.45 -0.1343526 0.02434849     1
## 6     0.50  0.0000000 0.00000000     1
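
One simple way to quantify the between-chain variation described above is the range of point estimates across chains at each quantile; a minimal sketch using the combined data frame returned above:

# spread (max - min) of the estimated mean difference across chains,
# by joint exposure quantile
aggregate(est ~ quantile, data = meandifference_par,
          FUN = function(x) diff(range(x)))
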
nchains = length(unique(meandifference_par$chain))

with(meandifference_par, {
  plot.new()
  plot.window(ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)), 
              xlim=c(min(quantile), max(quantile)),
       ylab= "Mean difference", xlab = "Joint quantile")
  for(cch in seq_len(nchains)){
    width = diff(quantile)[1]
    jit = runif(1, -width/5, width/5)
   points(jit+quantile[chain==cch], est[chain==cch], pch=19, col=cch) 
   segments(x0=jit+quantile[chain==cch], x1=jit+quantile[chain==cch], y0 = est[chain==cch] - 1.96*sd[chain==cch], y1 = est[chain==cch] + 1.96*sd[chain==cch], col=cch)
  }
  abline(h=0)
  axis(1)
  axis(2)
  box(bty='l')
  legend("bottom", col=1:nchains, pch=19, lty=1, legend=paste("chain", 1:nchains), bty="n")
})

regfuns_par = PredictorResponseUnivar_parallel(kmfitbma.list, y = y, Z = Z, X = X, 
                                      qs = seq(0.25, 0.75, by = 0.05), q.fixed = 0.5, method = "exact")
## Chain 1 
## Chain 2 
## Chain 3 
## Chain 4
head(regfuns_par)
##   variable         z        est        se chain
## 1       z1 -2.186199 -1.2002239 0.7935128     1
## 2       z1 -2.082261 -1.1505627 0.7619830     1
## 3       z1 -1.978323 -1.1003333 0.7304029     1
## 4       z1 -1.874385 -1.0495803 0.6987936     1
## 5       z1 -1.770446 -0.9983499 0.6671784     1
## 6       z1 -1.666508 -0.9466894 0.6355833     1
nchains = length(unique(regfuns_par$chain))

# single variable
with(regfuns_par[regfuns_par$variable=="z1",], {
  plot.new()
  plot.window(ylim=c(min(est - 1.96*se), max(est + 1.96*se)), 
              xlim=c(min(z), max(z)),
       ylab= "Predicted Y", xlab = "Z")
  pc = c("#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")
  pc2 = c("#0000001A", "#E69F001A", "#56B4E91A", "#009E731A", "#F0E4421A", "#0072B21A", "#D55E001A", "#CC79A71A", "#9999991A")
  for(cch in seq_len(nchains)){
   ribbonX = c(z[chain==cch], rev(z[chain==cch]))
   ribbonY = c(est[chain==cch] + 1.96*se[chain==cch], rev(est[chain==cch] - 1.96*se[chain==cch]))
   polygon(x=ribbonX, y = ribbonY, col=pc2[cch], border=NA)
   lines(z[chain==cch], est[chain==cch], pch=19, col=pc[cch]) 
  }
  axis(1)
  axis(2)
  box(bty='l')
  legend("bottom", col=1:nchains, pch=19, lty=1, legend=paste("chain", 1:nchains), bty="n")
})

[Figures: overall risk summaries by chain; univariate predictor-response function for z1 by chain]

Acknowledgments

Thanks to Haotian “Howie” Wu for invaluable feedback on early versions of the package.