Monte-Carlo Methods for Prediction Functions

Zachary M. Jones

2017-03-05

This package allows you to marginalize arbitrary prediction functions using Monte-Carlo integration. Since many prediction functions cannot be easily decomposed into a sum of low-dimensional components, marginalization can be helpful in making these functions interpretable.
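
Concretely, for a prediction function $f$, variables $u$ to retain, and $N$ integration points $x_{-u}^{(i)}$ sampled from the data, the marginalized function at a point $x_u$ is estimated by the Monte-Carlo average

$$\hat{f}_u(x_u) = \frac{1}{N} \sum_{i = 1}^{N} f\left(x_u, x_{-u}^{(i)}\right),$$

where $-u$ denotes the variables not in $u$. With the mean as the aggregation function this is the usual partial dependence estimator.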

marginalPrediction performs this computation and then evaluates the marginalized function at a set of grid points, which can be created uniformly, subsampled from the training data, or explicitly specified via the points argument.

The creation of a uniform grid is handled by the uniformGrid method. If uniform = FALSE and the points argument isn't used to specify which points to evaluate, a sample of size n[1] is taken from the data without replacement.

The function is integrated against a sample of size n[2] taken without replacement from the data. The argument int.points can be used to override this (in which case you can specify n[2] = NA); int.points is a vector of integerish indices specifying which rows of the data to use instead. A sketch combining these options appears after the first example below.

library(mmpf)
library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
library(ggplot2)
## 
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
## 
##     margin
library(reshape2)

data(swiss)

fit = randomForest(Fertility ~ ., swiss)
mp = marginalPrediction(swiss[, -1], "Education", c(10, nrow(swiss)), fit)
mp
##     Education    preds
##  1:         1 71.78593
##  2:         7 71.81086
##  3:        13 71.77107
##  4:        18 66.38694
##  5:        24 63.79314
##  6:        30 63.10018
##  7:        36 63.41916
##  8:        41 63.30996
##  9:        47 62.56040
## 10:        53 62.56040
ggplot(data.frame(mp), aes(Education, preds)) + geom_point() + geom_line()
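
As a sketch of the grid and integration-point options described earlier (the grid size, row indices, and the two-argument uniformGrid signature here are illustrative assumptions):

# a uniform grid of size 5 over Education, via the exported uniformGrid method
uniformGrid(swiss$Education, 5)

# integrate against rows 1-25 of the data instead of a random sample;
# per the text above, n[2] may be set to NA when int.points is supplied
marginalPrediction(swiss[, -1], "Education", c(10, NA), fit, int.points = 1:25)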

The output of marginalPrediction is a data.table containing the marginalized predictions and the grid points for the variables specified by vars.

By default the Monte-Carlo expectation is computed; this is controlled by the aggregate.fun argument, whose default value is the mean function. Substituting another function, say the median, gives a different summary of the marginalized predictions.
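
For example, substituting the median requires only changing aggregate.fun (a minimal variation on the first call above):

mp.median = marginalPrediction(swiss[, -1], "Education", c(10, nrow(swiss)), fit,
  aggregate.fun = median)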

Passing the identity function to aggregate.fun, which returns its input unchanged, causes the integration points to be returned directly, so that the prediction element of the return value is a matrix of dimension n. Although n is an argument, the grid actually used can be larger or smaller depending on the interaction between n and data. For example, if a uniform grid of size 10 is requested (via n[1]) for a factor with only 5 levels, a uniform grid of size 5 is created instead. If vars is a vector of length greater than 1, the effective value of n[1] becomes the size of the Cartesian product of the grids created for each element of vars, which can be at most n[1]^length(vars).

mp = marginalPrediction(swiss[, -1], "Education", c(10, 5), fit, aggregate.fun = identity)
mp
##     Education   preds1   preds2   preds3   preds4   preds5
##  1:         1 67.26028 72.22698 77.07416 79.41154 69.77705
##  2:         7 66.99037 72.18876 76.33932 80.12122 70.63135
##  3:        13 67.69827 72.57807 76.02844 79.19610 71.35259
##  4:        18 61.36805 66.71255 71.07892 74.04416 66.62648
##  5:        24 59.28644 63.40695 68.32098 71.36805 64.16810
##  6:        30 57.98800 62.52623 67.91764 71.04811 63.75927
##  7:        36 58.46004 62.88612 68.08237 71.24201 64.09344
##  8:        41 58.30849 62.84218 68.02237 71.18201 64.00034
##  9:        47 57.28113 62.06052 67.43486 70.66994 63.33073
## 10:        53 57.28113 62.06052 67.43486 70.66994 63.33073
ggplot(melt(data.frame(mp), id.vars = "Education"), aes(Education, value, group = variable)) + geom_point() + geom_line()

predict.fun specifies a prediction function to apply to the model argument. This function must take two arguments: object, into which model is inserted, and newdata, a data.frame to compute predictions on, which is generated internally and controlled by the other arguments. This allows marginalPrediction to handle cases in which the prediction for a single data point is vector-valued, that is, classification tasks where probabilities are output, and multivariate regression and/or classification. In these cases aggregate.fun is applied separately to each column of the prediction matrix. aggregate.fun must take one argument x, a vector output by predict.fun, and return a vector of dimension no greater than that of x.

data(iris)

fit = randomForest(Species ~ ., iris)
mp = marginalPrediction(iris[, -ncol(iris)], "Petal.Width", c(10, 25), fit,
  predict.fun = function(object, newdata) predict(object, newdata = newdata, type = "prob"))
mp
##     Petal.Width  setosa versicolor virginica
##  1:   0.1000000 0.60136    0.24352   0.15512
##  2:   0.3666667 0.60136    0.24352   0.15512
##  3:   0.6333333 0.59384    0.24920   0.15696
##  4:   0.9000000 0.17560    0.56464   0.25976
##  5:   1.1666667 0.17448    0.56576   0.25976
##  6:   1.4333333 0.17448    0.55168   0.27384
##  7:   1.7000000 0.16968    0.42560   0.40472
##  8:   1.9666667 0.16208    0.22016   0.61776
##  9:   2.2333333 0.16208    0.21984   0.61808
## 10:   2.5000000 0.16208    0.21984   0.61808
plt = melt(data.frame(mp), id.vars = "Petal.Width", variable.name = "class",
  value.name = "prob")

ggplot(plt, aes(Petal.Width, prob, color = class)) + geom_line() + geom_point()

As mentioned before, vars can include multiple variables.

mp = marginalPrediction(iris[, -ncol(iris)], c("Petal.Width", "Petal.Length"), c(10, 25), fit,
  predict.fun = function(object, newdata) predict(object, newdata = newdata, type = "prob"))
mp
##      Petal.Width Petal.Length  setosa versicolor virginica
##   1:   0.1000000     1.000000 0.94472    0.05448   0.00080
##   2:   0.1000000     1.655556 0.94472    0.05448   0.00080
##   3:   0.1000000     2.311111 0.94304    0.05616   0.00080
##   4:   0.1000000     2.966667 0.44664    0.53944   0.01392
##   5:   0.1000000     3.622222 0.44456    0.54152   0.01392
##   6:   0.1000000     4.277778 0.44456    0.53464   0.02080
##   7:   0.1000000     4.933333 0.41664    0.45712   0.12624
##   8:   0.1000000     5.588889 0.41032    0.18552   0.40416
##   9:   0.1000000     6.244444 0.41032    0.18552   0.40416
##  10:   0.1000000     6.900000 0.41032    0.18552   0.40416
##  11:   0.3666667     1.000000 0.94472    0.05448   0.00080
##  12:   0.3666667     1.655556 0.94472    0.05448   0.00080
##  13:   0.3666667     2.311111 0.94304    0.05616   0.00080
##  14:   0.3666667     2.966667 0.44664    0.53944   0.01392
##  15:   0.3666667     3.622222 0.44456    0.54152   0.01392
##  16:   0.3666667     4.277778 0.44456    0.53464   0.02080
##  17:   0.3666667     4.933333 0.41664    0.45712   0.12624
##  18:   0.3666667     5.588889 0.41032    0.18552   0.40416
##  19:   0.3666667     6.244444 0.41032    0.18552   0.40416
##  20:   0.3666667     6.900000 0.41032    0.18552   0.40416
##  21:   0.6333333     1.000000 0.93440    0.06472   0.00088
##  22:   0.6333333     1.655556 0.93440    0.06472   0.00088
##  23:   0.6333333     2.311111 0.93272    0.06640   0.00088
##  24:   0.6333333     2.966667 0.43632    0.54968   0.01400
##  25:   0.6333333     3.622222 0.43424    0.55176   0.01400
##  26:   0.6333333     4.277778 0.43424    0.54488   0.02088
##  27:   0.6333333     4.933333 0.40736    0.46632   0.12632
##  28:   0.6333333     5.588889 0.40248    0.18816   0.40936
##  29:   0.6333333     6.244444 0.40248    0.18816   0.40936
##  30:   0.6333333     6.900000 0.40248    0.18816   0.40936
##  31:   0.9000000     1.000000 0.50688    0.48144   0.01168
##  32:   0.9000000     1.655556 0.50688    0.48144   0.01168
##  33:   0.9000000     2.311111 0.50520    0.48312   0.01168
##  34:   0.9000000     2.966667 0.00880    0.96640   0.02480
##  35:   0.9000000     3.622222 0.00672    0.96848   0.02480
##  36:   0.9000000     4.277778 0.00672    0.95480   0.03848
##  37:   0.9000000     4.933333 0.00576    0.76840   0.22584
##  38:   0.9000000     5.588889 0.00552    0.31528   0.67920
##  39:   0.9000000     6.244444 0.00552    0.31528   0.67920
##  40:   0.9000000     6.900000 0.00552    0.31528   0.67920
##  41:   1.1666667     1.000000 0.50592    0.48240   0.01168
##  42:   1.1666667     1.655556 0.50592    0.48240   0.01168
##  43:   1.1666667     2.311111 0.50424    0.48408   0.01168
##  44:   1.1666667     2.966667 0.00784    0.96736   0.02480
##  45:   1.1666667     3.622222 0.00576    0.96944   0.02480
##  46:   1.1666667     4.277778 0.00576    0.95576   0.03848
##  47:   1.1666667     4.933333 0.00480    0.76936   0.22584
##  48:   1.1666667     5.588889 0.00456    0.31624   0.67920
##  49:   1.1666667     6.244444 0.00456    0.31624   0.67920
##  50:   1.1666667     6.900000 0.00456    0.31624   0.67920
##  51:   1.4333333     1.000000 0.50592    0.47656   0.01752
##  52:   1.4333333     1.655556 0.50592    0.47656   0.01752
##  53:   1.4333333     2.311111 0.50424    0.47824   0.01752
##  54:   1.4333333     2.966667 0.00784    0.95656   0.03560
##  55:   1.4333333     3.622222 0.00576    0.95864   0.03560
##  56:   1.4333333     4.277778 0.00576    0.94496   0.04928
##  57:   1.4333333     4.933333 0.00480    0.75736   0.23784
##  58:   1.4333333     5.588889 0.00456    0.28616   0.70928
##  59:   1.4333333     6.244444 0.00456    0.28616   0.70928
##  60:   1.4333333     6.900000 0.00456    0.28616   0.70928
##  61:   1.7000000     1.000000 0.49312    0.36120   0.14568
##  62:   1.7000000     1.655556 0.49312    0.36120   0.14568
##  63:   1.7000000     2.311111 0.49144    0.36152   0.14704
##  64:   1.7000000     2.966667 0.00648    0.69136   0.30216
##  65:   1.7000000     3.622222 0.00440    0.69344   0.30216
##  66:   1.7000000     4.277778 0.00440    0.67816   0.31744
##  67:   1.7000000     4.933333 0.00344    0.58848   0.40808
##  68:   1.7000000     5.588889 0.00320    0.32000   0.67680
##  69:   1.7000000     6.244444 0.00320    0.32000   0.67680
##  70:   1.7000000     6.900000 0.00320    0.32000   0.67680
##  71:   1.9666667     1.000000 0.46136    0.22592   0.31272
##  72:   1.9666667     1.655556 0.46136    0.22592   0.31272
##  73:   1.9666667     2.311111 0.45968    0.22624   0.31408
##  74:   1.9666667     2.966667 0.00504    0.42424   0.57072
##  75:   1.9666667     3.622222 0.00296    0.42632   0.57072
##  76:   1.9666667     4.277778 0.00296    0.41336   0.58368
##  77:   1.9666667     4.933333 0.00216    0.11336   0.88448
##  78:   1.9666667     5.588889 0.00192    0.05088   0.94720
##  79:   1.9666667     6.244444 0.00192    0.05088   0.94720
##  80:   1.9666667     6.900000 0.00192    0.05088   0.94720
##  81:   2.2333333     1.000000 0.46136    0.22592   0.31272
##  82:   2.2333333     1.655556 0.46136    0.22592   0.31272
##  83:   2.2333333     2.311111 0.45968    0.22624   0.31408
##  84:   2.2333333     2.966667 0.00504    0.42376   0.57120
##  85:   2.2333333     3.622222 0.00296    0.42584   0.57120
##  86:   2.2333333     4.277778 0.00296    0.41288   0.58416
##  87:   2.2333333     4.933333 0.00216    0.11112   0.88672
##  88:   2.2333333     5.588889 0.00192    0.04864   0.94944
##  89:   2.2333333     6.244444 0.00192    0.04864   0.94944
##  90:   2.2333333     6.900000 0.00192    0.04864   0.94944
##  91:   2.5000000     1.000000 0.46136    0.22592   0.31272
##  92:   2.5000000     1.655556 0.46136    0.22592   0.31272
##  93:   2.5000000     2.311111 0.45968    0.22624   0.31408
##  94:   2.5000000     2.966667 0.00504    0.42376   0.57120
##  95:   2.5000000     3.622222 0.00296    0.42584   0.57120
##  96:   2.5000000     4.277778 0.00296    0.41288   0.58416
##  97:   2.5000000     4.933333 0.00216    0.11112   0.88672
##  98:   2.5000000     5.588889 0.00192    0.04864   0.94944
##  99:   2.5000000     6.244444 0.00192    0.04864   0.94944
## 100:   2.5000000     6.900000 0.00192    0.04864   0.94944
##      Petal.Width Petal.Length  setosa versicolor virginica
plt = melt(data.frame(mp), id.vars = c("Petal.Width", "Petal.Length"),
  variable.name = "class", value.name = "prob")

ggplot(plt, aes(Petal.Width, Petal.Length, fill = prob)) + geom_raster() + facet_wrap(~ class)

Permutation importance is a Monte-Carlo method which estimates the importance of variables in determining predictions by computing the change in prediction error from repeatedly permuting the values of those variables.
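
The idea can be sketched by hand using the iris classifier fit from above (the number of permutations here is an arbitrary choice):

# baseline misclassification error on the original data
baseline = mean(predict(fit, newdata = iris) != iris$Species)
# error after repeatedly permuting Sepal.Width
permuted = replicate(100, {
  tmp = iris
  tmp$Sepal.Width = sample(tmp$Sepal.Width)
  mean(predict(fit, newdata = tmp) != iris$Species)
})
mean(permuted) - baseline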

permutationImportance can compute this type of importance under arbitrary loss functions and contrast functions (which compare the loss computed on the permuted data with that on the unpermuted data).

permutationImportance(iris, "Sepal.Width", "Species", fit)
## [1] 0.01366667

For methods whose predictions are characters or unordered factors, the default loss function is the mean misclassification error. For all other types of predictions, the mean squared error is used.
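
Written out, the earlier call is roughly equivalent to supplying these defaults by hand (a sketch; the difference contrast is an assumption, matching the description below):

permutationImportance(iris, "Sepal.Width", "Species", fit,
  loss.fun = function(x, y) mean(x != y),  # mean misclassification error
  contrast.fun = function(x, y) x - y)     # permuted loss minus unpermuted loss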

It is possible, for example, to compute the expected change in the mean misclassification rate on a per-class basis. The two arguments to loss.fun are the permuted predictions and the target variable; in this case both are factor vectors.

contrast.fun takes the output of loss.fun on both the permuted and unpermuted predictions (x corresponds to the permuted predictions and y to the unpermuted predictions).

permutationImportance(iris, "Sepal.Width", "Species", fit,
  loss.fun = function(x, y) {
    # x: permuted predictions, y: observed classes
    mat = table(x, y)  # confusion matrix (predicted x observed)
    n = colSums(mat)   # observed count of each class
    diag(mat) = 0      # zero out the correct classifications
    rowSums(mat) / n   # per-class misclassification error
  },
  contrast.fun = function(x, y) x - y)  # change relative to the unpermuted loss
##     setosa versicolor  virginica 
##     0.0000     0.0282     0.0116