vivid: Variable Importance and Variable Interaction Displays

Introduction

Variable importance (VImp), variable interaction measures (VInt) and partial dependence plots (PDPs) are important summaries in the interpretation of statistical and machine learning models. In this vignette we describe new visualization techniques for exploring these model summaries. We construct heatmap and graph-based displays showing variable importance and interaction jointly, which are carefully designed to highlight important aspects of the fit. We describe a new matrix-type layout showing all single and bivariate partial dependence plots, and an alternative layout based on graph Eulerians focusing on key subsets. Our new visualisations are model-agnostic and are applicable to regression and classification supervised learning settings. They enhance interpretation even in situations where the number of variables is large and the interaction structure complex. Our R package vivid (variable importance and variable interaction displays) provides an implementation. When referring to VImp and VInt together, we use the shorthand VIVI.

Install instructions

Some of the plots used by vivid are built upon the zenplots package, which requires the graph package from Bioconductor. To install the graph and zenplots packages use:

if (!requireNamespace("graph", quietly = TRUE)) {
  install.packages("BiocManager")
  BiocManager::install("graph")
}
install.packages("zenplots")

Now we can install vivid by using:

install.packages("vivid")

Alternatively you can install the latest development version of the package in R with the commands:

if (!require(remotes)) install.packages('remotes')
remotes::install_github('AlanInglis/vividPackage')

We then load the required packages: vivid to create the visualizations, and some other packages to create the various model fits.

library(vivid) # for visualisations 
library(randomForest) # for model fit
library(ranger)       # for model fit

Section 1: Data and model fits

Data used in this vignette:

The data used in the following examples is simulated from the Friedman benchmark problem1. This benchmark problem is commonly used for testing purposes. The output is created according to the equation:

\[y = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + e\]

For the following examples we set the number of features to 9 and the number of samples to 350, and fit a randomForest random forest model with \(y\) as the response. As the features \(x_1\) to \(x_5\) are the only variables appearing in the equation, \(x_6\) to \(x_9\) are noise variables. As can be seen from the equation above, the only interaction is between \(x_1\) and \(x_2\).

Create the data:

set.seed(101)
genFriedman <- function(noFeatures = 10,
                        noSamples = 100,
                        sigma = 1
) {
  # Set Values
  n <- noSamples # no of rows
  p <- noFeatures # no of variables
  e <- rnorm(n, sd = sigma)

  # Create matrix of values
  xValues <- matrix(runif(n * p, 0, 1), nrow = n) # Create matrix
  colnames(xValues) <- paste0("x", 1:p) # Name columns
  df <- data.frame(xValues) # Create dataframe

  # Equation:
  # y = 10sin(πx1x2) + 20(x3−0.5)^2 + 10x4 + 5x5 + ε
  y <- 10 * sin(pi * df$x1 * df$x2) + 20 * (df$x3 - 0.5)^2 + 10 * df$x4 + 5 * df$x5 + e
  # Adding y to df
  df$y <- y
  df
}

myData <- genFriedman(noFeatures = 9, noSamples = 350, sigma = 1)

Model fit

Here we create a random forest regression fit using the randomForest package (a ranger random forest is used later for the classification example in Section 4).

set.seed(1701)
rf <- randomForest(y ~ ., data = myData, importance = TRUE)

Note that for a randomForest model, if importance = TRUE, then multiple importance metrics are returned. In order to choose a specific metric for use with vivid, it is necessary to specify one of the importance metrics returned by the random forest as the argument for the importanceType parameter in the vivi function (as shown below).
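As a quick check (not part of the original vignette), the metrics stored in the fit can be listed; for a regression forest fitted with importance = TRUE these are typically "%IncMSE" and "IncNodePurity":

# List the importance metrics stored in the randomForest fit;
# any of these names can be passed to vivi() via importanceType.
colnames(randomForest::importance(rf))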

vivi function

To utilize vivid, the initial step involves computing variable importance and interactions for a fitted model. The vivi function performs this calculation, producing a square, symmetrical matrix that contains variable importance on the diagonal and variable interactions on the off-diagonal. To calculate the pairwise interaction strength, Friedman’s model-agnostic, unnormalized \(H\)-statistic2 is used. The unnormalized version of the \(H\)-statistic was chosen to allow a more direct comparison of interaction effects across pairs of variables; its values are on the scale of the response (for regression). For the importance, either a selected embedded importance measure or a model-agnostic permutation method3 can be used (examples of which are shown below).

The vivi function requires three inputs: a fitted machine learning model, a data frame used in the model’s training, and the name of the response variable for the fit. The resulting matrix will have importance and interaction values for all variables in the data frame, excluding the response variable.

Any variables that are not used by the supplied model will have their importance and interaction values set to NA. While the viviHeatmap and viviNetwork visualization functions (seen below) are tailored for displaying the results of vivi calculations, they can also work with any square matrix that has identical row and column names. (Note that the symmetry assumption is not required for viviHeatmap, and viviNetwork uses interaction values from the lower-triangular part of the matrix only.)
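For illustration (this example is not from the original vignette and the values are arbitrary), a small hand-built matrix of this form can be passed directly to the plotting functions:

# A toy 3 x 3 VIVI-style matrix with identical row and column names:
# importance on the diagonal, interactions off the diagonal
m <- matrix(c(0.9, 0.2, 0.1,
              0.2, 0.5, 0.0,
              0.1, 0.0, 0.3),
            nrow = 3, byrow = TRUE,
            dimnames = list(c("x1", "x2", "x3"), c("x1", "x2", "x3")))
viviHeatmap(mat = m)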

This function works with multiple model fits and results in a matrix which can be supplied to the plotting functions. The predict function argument uses condvis2::CVpredict by default, which works for many fit classes. To see a description of all function arguments use: ?vivid::vivi()

set.seed(1701)
viviRf  <- vivi(fit = rf, 
                data = myData, 
                response = "y",
                gridSize = 50,
                importanceType = "%IncMSE",
                nmax = 500,
                reorder = TRUE,
                class = 1,
                predictFun = NULL,
                numPerm = 4)
#> %IncMSE importance selected.
#> Calculating interactions...
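As a quick sanity check (output not shown here), the top corner of the returned matrix can be inspected; the diagonal holds the importance values and the off-diagonal entries the pairwise interactions:

# Inspect the first few rows and columns of the vivid matrix
round(viviRf[1:4, 1:4], 3)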

Section 2: Visualizing the results

Heatmap plot

The viviHeatmap function generates a heatmap that displays variable importance and interactions, with importance values on the diagonal and interaction values on the off-diagonal. The function only requires a vivid matrix as input, which does not need to be symmetrical. Additionally, color palettes can be specified for both importance and interactions via the impPal and intPal arguments. By default, we have opted for single-hue, color-blind friendly sequential color palettes developed by Zeileis et al4. These palettes represent low and high VIVI values with low and high luminance colors, respectively, which can aid in highlighting pertinent values.

The impLims and intLims arguments determine the range of importance and interaction values that will be assigned colors. If these arguments are not provided, the default values will be calculated based on the minimum and maximum VIVI values in the vivid matrix. If any importance or interaction values fall outside of the specified limits, they will be squished to the closest limit. For brevity, only the required vivid matrix input is shown in the following code. To see a description of all the function arguments, see ?vivid::viviHeatmap().

viviHeatmap(mat = viviRf)
Figure 1: Heatmap of a random forest fit displaying 2-way interaction strength on the off diagonal and individual variable importance on the diagonal. \(x_1\) and \(x_2\) show a strong interaction with \(x_4\) being the most important for predicting \(y\).
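As an illustration of these arguments (the palettes and limits below are arbitrary choices, not the package defaults, and assume the colorspace package is installed), both the colours and the displayed ranges can be set explicitly:

# Heatmap with user-chosen colour palettes and fixed colour limits (figure not shown)
viviHeatmap(mat = viviRf,
            impPal = rev(colorspace::sequential_hcl(palette = "Blues 3", n = 100)),
            intPal = rev(colorspace::sequential_hcl(palette = "Purples 3", n = 100)),
            impLims = c(0, 20),
            intLims = c(0, 1))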

Network plot

With viviNetwork, a network graph is produced to visualize both importance and interactions. Similar to viviHeatmap, this function only requires a vivid matrix as input and uses visual elements, such as size and color, to depict the magnitude of importance and interaction values. The graph displays each variable as a node, where its size and color reflect its importance (larger and darker nodes indicate higher importance). Pairwise interactions are displayed through connecting edges, where thicker and darker edges indicate higher interaction values.

To begin we show the network using default settings.

viviNetwork(mat = viviRf)
Figure 2: Network plot of a random forest fit displaying 2-way interaction strength and individual variable importance. \(x_1\) and \(x_2\) show a strong interaction with \(x_4\) being the most important for predicting \(y\).

We can also filter out any interactions below a set value using the intThreshold argument. This can be useful when the number of variables included in the model is large or simply to highlight the strongest interactions. By default, unconnected nodes are displayed; however, they can be removed by setting the argument removeNode = TRUE.

viviNetwork(mat = viviRf, intThreshold = 0.12, removeNode = FALSE)

viviNetwork(mat = viviRf, intThreshold = 0.12, removeNode = TRUE)
Figure 3: Network plot of a random forest fit displaying 2-way interaction strength and individual variable importance. In (a) a filtered network is shown displaying all interactions above 0.12, with all nodes shown. In (b) the unconnected nodes are removed.

The network plot offers multiple customization options for how the graph is laid out through the layout argument. The default layout is a circle, but the argument accepts any igraph layout function or a numeric matrix with two columns and one row per node.

viviNetwork(mat = viviRf, 
            layout = cbind(c(1,1,1,1,2,2,2,2,2), c(1,2,4,5,1,2,3,4,5)))
Figure 4: Network plot of a random forest fit using custom layout.
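An igraph layout function can be supplied in the same way; for example (figure not shown), a force-directed layout:

# Fruchterman-Reingold force-directed layout from igraph
viviNetwork(mat = viviRf, layout = igraph::layout_with_fr)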

Finally, for the network plot to highlight any relationships in the model fit, we can cluster variables together using the cluster argument. This argument can either accept a vector of cluster memberships for nodes or an igraph package clustering function.

set.seed(1701)
viviNetwork(mat = viviRf, cluster = igraph::cluster_fast_greedy)
Figure 5: Clustered network plot of a random forest fit.
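Alternatively (as a sketch, using an arbitrary grouping of the nine predictors), a vector of cluster memberships can be supplied directly:

# Manually assign each of the nine variables to one of two clusters (figure not shown)
viviNetwork(mat = viviRf, cluster = c(1, 1, 1, 1, 2, 2, 2, 2, 2))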

Univariate Partial Dependence Plot

The pdpVars function constructs a grid of univariate PDPs with ICE curves for selected variables. We use ICE curves to assist in the identification of linear or non-linear effects. The fit, data frame used to train the model, and the name of the response variable are required inputs.

In the example below, we select the first five variables from our created vivid matrix to display.

top5 <- colnames(viviRf)[1:5]
pdpVars(data = myData,
        fit = rf,
        response = 'y',
        vars = top5)
Figure 6: Partial dependence plots (black line) with individual conditional expectation curves (colored lines) of a random forest fit on the Friedman data. The changing partial dependence and ICE curves of \(x_1\), \(x_2\), and \(x_4\) indicate that these variables have some impact on the response.

Generalized partial dependence pairs plot

By employing a matrix layout, the pdpPairs function generates a generalized pairs partial dependence plot (GPDP) that encompasses univariate partial dependence (with ICE curves) on the diagonal, bivariate partial dependence on the upper diagonal, and a scatterplot of raw variable values on the lower diagonal. Colours are assigned to points and ICE curves by the predicted \(\hat{y}\) value. As with the univariate PDP, the fit, the data frame used to train the model, and the name of the response variable are required inputs. For a full description of all the function arguments, see ?vivid::pdpPairs. In the following example, we select the first five variables to display and set the number of shown ICE curves to 100.

set.seed(1701)
pdpPairs(data = myData, 
         fit =  rf, 
         response = "y", 
         nmax = 500, 
         gridSize = 10,         
         vars = c("x1", "x2", "x3", "x4", "x5"),
         nIce = 100)
#> Generating ice/pdp fits... waiting...
#> Finished ice/pdp
Figure 7: Filtered generalized pairs partial dependence plot for a random forest fit on the Friedman data. From both the univariate and bivariate PDPs, we can see that \(x_1\), \(x_2\), and \(x_4\) have an impact on the response.

Partial dependence ‘Zenplot’

The pdpZen function creates partial dependence plots using a space-saving technique based on graph Eulerians, introduced by Hierholzer and Wiener in 18735. We refer to these plots as zen-partial dependence plots (ZPDP). They are based on zigzag expanded navigation plots, also known as zenplots, which are available in the zenplots package6. Zenplots were designed to display paired graphs of high-dimensional data with a focus on the most important 2D displays. In our version, we display bivariate PDPs in a compact zigzag layout, emphasizing the variable pairs with the largest interaction values. This format is useful when dealing with a high-dimensional predictor space.

To begin, we show a ZPDP using all the variables in the model.

set.seed(1701)
pdpZen(data = myData, fit = rf, response = "y", nmax = 500, gridSize = 10)
#> Generating ice/pdp fits... waiting...
#> Finished ice/pdp
Figure 8: Zen partial dependence plot for the random forest fit on the Friedman data. Here we display all the variables used in the random forest model.

In Figure 8, we can see the PDPs laid out in a zigzag structure, with the most influential variable pairs displayed at the top and the influence generally decreasing as we move down. In Figure 9, below, we select a subset of variables to display; in this case, the first five variables from the data. The argument zpath specifies the variables to be plotted, defaulting to all variables in the dataset aside from the response.

set.seed(1701)
pdpZen(data = myData, 
       fit = rf, 
       response = "y",
       nmax = 500, 
       gridSize = 10, 
       zpath = c("x1", "x2", "x3", "x4", "x5"))
#> Generating ice/pdp fits... waiting...
#> Finished ice/pdp
Figure 9: Zen partial dependence plot for the random forest fit on the Friedman data. Here we display only the first five variables.

We can also create a sequence or sequences of variable paths for use in pdpZen via the zPath function. The zPath function takes four arguments: viv - a matrix of interaction values; cutoff - exclude interaction values below this threshold; method - a string indicating which method to use to create the path; and connect - a logical value indicating if separate Eulerians should be connected.

You can choose between two methods when using the zPath function: "greedy.weighted" and "strictly.weighted". The first method utilizes a greedy Eulerian path algorithm for connected graphs. This method traverses each edge at least once, beginning at the highest-weighted edge and moving on to the remaining edges while prioritizing the highest-weighted edge. If the graph has an odd number of nodes, some edges may be visited more than once, or additional edges may be visited. The second method, "strictly.weighted", visits edges in strictly decreasing order by weight (in this case, interaction values). If the connect argument is set to TRUE, the sequences generated by the strictly weighted method are combined to create a single path. In the code below, we provide an example of creating zen-paths using the "strictly.weighted" method, from the top 10% of interaction scores in viviRf (i.e., the created vivid matrix).

# set zpaths with different parameters
intVals <- viviRf
diag(intVals) <- NA
intThresh <- quantile(intVals, .90, na.rm=TRUE)
zpSw <- zPath(viv = viviRf, cutoff = intThresh, connect = FALSE, method = 'strictly.weighted')



set.seed(1701)
pdpZen(data = myData, 
       fit = rf, 
       response = "y",
       nmax = 500, 
       gridSize = 10, 
       zpath = zpSw)
#> Generating ice/pdp fits... waiting...
#> Finished ice/pdp
Figure 10: ZPDP for a random forest fit on the Friedman data. The path is determined by the 'strictly.weighted' method and is unconnected.
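The "greedy.weighted" method can be used in the same way; as noted above it applies to connected graphs, so a lower cutoff may be needed if the thresholded interaction graph is disconnected. A sketch (figure not shown):

# Zen-path built with the greedy Eulerian algorithm
zpGw <- zPath(viv = viviRf, cutoff = intThresh, method = 'greedy.weighted')

set.seed(1701)
pdpZen(data = myData,
       fit = rf,
       response = "y",
       nmax = 500,
       gridSize = 10,
       zpath = zpGw)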

Section 3: Example using the predict function

We supply an internal custom predict function called CVpredictfun to both the importance and interaction calculations. CVpredictfun is a wrapper around CVpredict from the condvis2 package7. CVpredict accepts a broad range of fit classes, thus streamlining the process of calculating variable importance and interactions.

In situations where the fit class is not handled by CVpredict, supplying a custom predict function to the vivi function by way of the predictFun argument allows the agnostic VIVI values to be calculated. In the following, we provide a small example of using such a fit with vivid by using the xgboost package to create a gradient boosting machine (GBM). To begin we build the model.

library("xgboost")
gbst <- xgboost(data = as.matrix(myData[,1:9]),
                label =  as.matrix(myData[,10]),
                nrounds = 100,
                verbose = 0)

We then build the vivid matrix for the GBM fit using a custom predict function, which must be of the form given in the code snippet.

# predict function for GBM
pFun <- function(fit, data, ...) predict(fit, as.matrix(data[,1:9]))

set.seed(1701)
viviGBst <- vivi(fit = gbst,
                 data = myData,
                 response = "y",
                 reorder = FALSE,
                 normalized = FALSE,
                 predictFun = pFun,
                 gridSize = 50,
                 nmax = 500)
#> Agnostic variable importance method used.
#> Calculating interactions...

From this we can now create our visualisations. For brevity, we only show the heatmap.

viviHeatmap(mat = viviGBst)
Figure 11: Heatmap for the GBM fit on the Friedman data.

Section 4: Classification example

In this section, we briefly describe how to apply the above visualisations to a classification example using the iris data set.

To begin, we fit a ranger random forest model with “Species” as the response and create the vivi matrix, setting the category for classification to be “setosa” via the class argument.

set.seed(1701)
rfClassif <- ranger(Species ~ ., data = iris, probability = TRUE, 
                    importance = "impurity")

set.seed(101)
viviClassif  <- vivi(fit = rfClassif, 
                     data = iris, 
                     response = "Species",
                     gridSize = 10,
                     nmax = 50,
                     reorder = TRUE,
                     class = "setosa")
#> Agnostic variable importance method used.
#> Calculating interactions...

Next we plot the heatmap and network plot of the iris data.

viviHeatmap(mat = viviClassif)

viviNetwork(mat = viviClassif)
Figure 12: Heatmap (a) and network plot (b) of a random forest fit on the iris data.

PDPs are evaluated on a grid and so can extrapolate to regions where there is no data. To address this issue, we calculate a convex hull around the data and remove any evaluation points that fall outside the convex hull, as shown below.

set.seed(1701)
pdpPairs(data = iris, 
         fit = rfClassif, 
         response = "Species",
         class = "setosa",  
         convexHull = T, 
         gridSize = 10, 
         nmax = 50) 
#> Generating ice/pdp fits... waiting...
#> Finished ice/pdp
Figure 13: GPDP of a random forest fit on the iris data with extrapolated data removed.

  1. Friedman, Jerome H. (1991) Multivariate adaptive regression splines. The Annals of Statistics 19 (1), pages 1-67.↩︎

  2. Friedman, J. H. and Popescu, B. E. (2008). “Predictive learning via rule ensembles.” The Annals of Applied Statistics. JSTOR, 916–54.↩︎

  3. Fisher, A., Rudin, C., and Dominici, F. (2018). “All Models Are Wrong but Many Are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance.” arXiv preprint.↩︎

  4. Zeileis, Achim, Jason C. Fisher, Kurt Hornik, Ross Ihaka, Claire D. McWhite, Paul Murrell, Reto Stauffer, and Claus O. Wilke. 2020. “colorspace: A Toolbox for Manipulating and Assessing Colors and Palettes.” Journal of Statistical Software 96 (1): 1–49.↩︎

  5. Hierholzer, Carl, and Chr Wiener. 1873. “Über Die möglichkeit, Einen Linienzug Ohne Wiederholung Und Ohne Unterbrechung Zu Umfahren.” Mathematische Annalen 6 (1): 30–32.↩︎

  6. Hofert, Marius, and Wayne Oldford. 2020. “Zigzag Expanded Navigation Plots in R: The R Package zenplots.” Journal of Statistical Software 95 (4): 1–44.↩︎

  7. Hurley, Catherine, Mark O’Connell, and Katarina Domijan. 2022. condvis2: Interactive Conditional Visualization for Supervised and Unsupervised Models in Shiny.↩︎