Code Companion¶
To accompany Probabilistic Machine Learning: New Frontiers for Modeling Consumers and their Choices
Introduction¶
In this notebook, we will demonstrate how to use two probabilistic programming languages -- Stan and Pyro -- to estimate probabilistic machine learning models.
Prerequisites:
- Stan will be called using cmdstanpy, which requires a working CmdStan installation. This can be installed easily by calling cmdstanpy.install_cmdstan() in Python after installing the cmdstanpy package.
- Pyro requires torch as a dependency.
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
np.random.seed(0)
DGP: Mixed Logit¶
To illustrate the key ideas, we will use the mixed logit model, as introduced in Section 1 of the paper. That model generates choices as: \begin{gather*} P(y_{it} = j) = \frac{\exp(\theta_{i}'x_{ijt})}{\sum_{k} \exp(\theta_{i}'x_{ikt})} \\ \theta_i \sim \mathcal{N}(\mu, \Sigma). \end{gather*} Here, we will assume that $\Sigma = \mathrm{diag}(\sigma^2)$, where $\sigma$ is a vector.
We will start by generating some data from this model. Then we'll specify a model, using this likelihood plus some appropriate priors, using both Stan and Pyro.
n_cust = 200 # The number of customers we are simulating
n_obs = 10 # The number of observations per customer (i.e., choices, or purchases within the category)
n_alts = 3 # The number of choice alternatives
n_feats = 2 # The number of features of the alternatives (e.g., price, promotion)
n_pars = n_alts + n_feats - 1 # Parameters per customer: (n_alts - 1) intercepts plus n_feats coefficients
coef_idx = [j for j in range(n_alts-1, n_pars)] # Indices of the feature coefficients within theta
X = np.random.normal(size=(n_cust, n_obs, n_alts, n_feats))
def softmax(x):
return np.exp(x)/sum(np.exp(x))
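As an aside (not needed for the small utilities simulated here), a numerically stable variant subtracts the maximum before exponentiating, which avoids overflow for large utilities while returning the same probabilities:
def softmax_stable(x):
    z = np.exp(x - np.max(x))  # subtract the max to avoid overflow in exp
    return z / z.sum()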
# DGP
mu = np.random.normal(loc=0, scale=1, size=(n_alts+n_feats-1))
sigma = np.random.uniform(low=0.5, high=1.5, size=n_pars)
theta = np.random.multivariate_normal(mean=mu, cov=np.diag(sigma), size=n_cust)
y = np.zeros((n_cust,n_obs), dtype=int)
for i in range(n_cust):
for t in range(n_obs):
v = np.zeros(n_alts)
v[0] = theta[i,coef_idx] @ X[i,t,0] # Normalize first brand's intercept
for j in range(1,n_alts):
v[j] = theta[i,j-1] + theta[i,coef_idx] @ X[i,t,j]
probs = softmax(v)
y[i,t] = np.random.choice(n_alts, size=1, p=probs)[0]+1
Inference with Stan¶
Stan is a PPL that is callable from many different programming languages. Stan has its own syntax for defining models, and can be viewed as a standalone programming language for the purpose of building probabilistic models. Here, we will build a Stan model and call it using cmdstanpy, a Python interface to CmdStan, the command-line version of Stan. Stan is fairly beginner-friendly, but does require learning its syntax (which is similar to other high-level programming languages like R).
import cmdstanpy
In Stan code, models are specified in blocks:
- In the data block, all of the data is named, with dimensions specified
- In the transformed data block, any useful transformations or new variables can be defined
- In the parameters block, any latent variables are defined, with their dimensions specified
- In transformed parameters, any useful transformations of the parameters are computed. Note that transformations can be defined here, or in the next block (model). By defining them here, Stan will save samples of them.
- Finally, in the model block, the actual DGP is defined, using the "~" notation to denote "is distributed as", and using various distribution functions to define the model's priors and likelihood.
We copy the Stan code below for reference, though the code we actually call is saved separately as mixed_logit.stan.
stan_mixed_logit_example_code = """
data {
int<lower=1> n_cust;
int<lower=1> n_obs; // assumes every customer has the same number of obs
int<lower=1> n_alts;
int<lower=1> n_feats;
array[n_cust, n_obs, n_alts, n_feats] real X;
array[n_cust, n_obs] int<lower=1,upper=n_alts> y;
}
transformed data {
// create some convenient indexes
int<lower=1> n_pars = n_alts+n_feats-1;
}
parameters {
matrix[n_pars, n_cust] z_theta;
vector[n_pars] mu;
cholesky_factor_corr[n_pars] L_Omega;
vector<lower=0>[n_pars] sigma;
}
transformed parameters {
matrix[n_pars, n_cust] theta;
theta = rep_matrix(mu, n_cust) + diag_matrix(sqrt(sigma)) * z_theta;
}
model {
to_vector(z_theta) ~ normal(0,1);
sigma ~ cauchy(0,2.5);
mu ~ normal(0,5);
for (i in 1:n_cust) {
for (t in 1:n_obs) {
vector[n_alts] utils; // temp variable
utils[1] = dot_product(theta[n_alts:n_pars,i], to_vector(X[i,t,1]));
for (j in 2:n_alts) {
utils[j] = theta[j-1,i] + dot_product(theta[n_alts:n_pars,i], to_vector(X[i,t,j]));
}
y[i,t] ~ categorical_logit(utils);
}
}
}
"""
Notice how easy the model is to modify: we are able to use whatever priors we like for any of the variables, including heavier-tailed choices like the Cauchy (here, effectively a half-Cauchy, given the positivity constraint on sigma), which is weakly informative due to its fat tails.
model = cmdstanpy.CmdStanModel(stan_file="mixed_logit.stan");
To use Stan, we need to format our data as a dictionary, where each entry in the dictionary has a matching variable name in the data block of the Stan program:
stan_data = {
"n_cust": n_cust,
"n_obs": n_obs,
"n_alts": n_alts,
"n_feats": n_feats,
"X": X,
"y": y
}
Inference: HMC/NUTS¶
With the model defined, we need to do inference, to either approximate or draw samples from the posterior. We'll start with the latter, which is very straightforward via HMC/NUTS, using the .sample method on the model. This automatically runs NUTS with 2000 iterations (1000 warmup and 1000 sampling, by default).
(Additional documentation can be found here: https://mc-stan.org/cmdstanpy/users-guide/examples/MCMC%20Sampling.html)
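Before running, note that the sampler settings can be adjusted; below is a hedged sketch (not run here) of commonly adjusted arguments, with keyword names following cmdstanpy's .sample interface. The call we actually use keeps things simple, with a single chain for speed:
# Sketch only (not run here): multiple parallel chains and explicit iteration counts.
fit_multi = model.sample(
    data=stan_data,
    chains=4,
    parallel_chains=4,
    iter_warmup=1000,
    iter_sampling=1000,
    seed=0,
)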
fit = model.sample(stan_data, chains=1, seed=0)
09:35:13 - cmdstanpy - INFO - CmdStan start processing
09:35:57 - cmdstanpy - INFO - CmdStan done processing.
From this, we can reason about the posterior. For instance, we can look at the posterior distribution of $\mu$:
# Access the posterior draws of mu:
mu_samples_stan_nuts = fit.stan_variable('mu')
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
# Select the subplot for this parameter
ax = axs[i // 2, i % 2]
sns.histplot(mu_samples_stan_nuts[:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
# Adjust the layout
plt.tight_layout()
plt.legend()
# Show the plots
plt.show()
Inference: ADVI¶
Stan also supports two types of variational inference, the first being Automatic Differentiation Variational Inference (ADVI). In ADVI, all variables are first converted to unconstrained spaces by reversible transformations. Then mean field normal variational families are assumed in that unconstrained space, and optimized. This makes ADVI very automatic, but also somewhat limited, and (in the authors' experience) prone to yielding poor approximations to the posterior.
(Additional documentation can be found here: https://mc-stan.org/cmdstanpy/users-guide/examples/Variational%20Inference.html)
vi = model.variational(data=stan_data, seed=0)
09:35:57 - cmdstanpy - INFO - Chain [1] start processing 09:36:00 - cmdstanpy - INFO - Chain [1] done processing
First, notice that, in terms of computation, this was much faster: just 3 seconds to get an approximate posterior, as opposed to almost 1 minute for NUTS. With the speed-up in computation comes a cost, which is part of the "iron simplex" described in the paper: ADVI is faster than HMC, but the approximate posterior is not particularly accurate in terms of uncertainty quantification. We demonstrate that below, looking at the four approximate posterior distributions of $\mu$, comparing them to the posterior estimated by NUTS. We see that, for some parameters, it is reasonably accurate, while for others, it is biased. In all cases, ADVI underestimates the posterior variance.
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
# Select the subplot for this parameter
ax = axs[i // 2, i % 2]
sns.histplot(mu_samples_stan_nuts[:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(vi.variational_sample[:, vi.column_names.index(f'mu[{i+1}]')], bins=30, kde=True, ax=ax, label='ADVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
# Adjust the layout
plt.tight_layout()
plt.legend()
# Show the plots
plt.show()
In some cases, as we will see later for black box SVI in Pyro, variational inference does a poor job of recovering global, population-level parameters like $\mu$, but a better job of recovering local, individual-level parameters like $\theta_{i}$. With ADVI, at least in this example, this is not the case:
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
# Select the subplot for this parameter
ax = axs[i // 2, i % 2]
sns.histplot(fit.stan_variable('theta')[:,0,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(vi.variational_sample[:, vi.column_names.index(f'theta[{i+1},1]')], bins=30, kde=True, ax=ax, label='ADVI', stat='density')
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\theta_{{1,{i}}}$')
# Adjust the layout
plt.tight_layout()
plt.legend()
# Show the plots
plt.show()
We see that the local variables are similarly biased, as was $\mu$. This is likely a systematic bias within ADVI: in the authors' experience, ADVI rarely results in a good approximation, except in very simple cases.
Inference: Pathfinder VI¶
Stan also implements another automatic variational inference method called Pathfinder (https://jmlr.org/papers/v23/21-0889.html). This algorithm is reported to be faster than ADVI and to yield better approximations to the posterior. We will explore its performance below.
Note: again, the power of PML and PPLs is that we can very easily substitute one inference algorithm for another. Moreover, we can adjust the settings of each of these algorithms to yield more or less accurate results. Below, we use Pathfinder with settings recommended by one of the method's developers (see: https://users.aalto.fi/~ave/casestudies/Birthdays/birthdays.html), and with psis_resample=False, which was required to get the algorithm to return meaningful posterior draws.
pf = model.pathfinder(
data=stan_data,
psis_resample=False,
seed=0,
inits=0.1,
num_paths=10,
num_single_draws=40,
draws=1000,
history_size=100,
max_lbfgs_iters=100,
)
09:36:01 - cmdstanpy - INFO - Chain [1] start processing 09:36:26 - cmdstanpy - INFO - Chain [1] done processing
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
# Select the subplot for this parameter
ax = axs[i // 2, i % 2]
sns.histplot(mu_samples_stan_nuts[:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(pf.draws()[:, pf.column_names.index(f'mu[{i+1}]')], bins=30, kde=True, ax=ax, label='Pathfinder', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
# Adjust the layout
plt.tight_layout()
plt.legend()
# Show the plots
plt.show()
Here, we see that Pathfinder does a better job than ADVI at approximating the posterior for some of the variables. It also doesn't exhibit the same problematic underestimation of the posterior variance -- in fact, at least for this example, it seems to overestimate uncertainty.
(One caveat is that the Pathfinder algorithm settings were set manually, while ADVI's were left at the default. It's possible that running ADVI for longer, or enabling other features (like a full-rank covariance), would improve its approximation.)
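For instance, a hedged sketch of how those ADVI settings might be adjusted through cmdstanpy's variational() interface (argument names as documented there; we did not run this variant):
# Sketch only (not run here): ADVI with a full-rank normal family, more iterations,
# and more Monte Carlo samples per gradient estimate.
vi_fullrank = model.variational(
    data=stan_data,
    algorithm="fullrank",   # full-rank normal instead of the mean-field default
    iter=20000,             # more optimization iterations
    grad_samples=10,        # more samples per stochastic gradient
    seed=0,
)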
Inference with Pyro¶
Pyro is a PPL that works differently from Stan: rather than introducing a standalone syntax, Pyro infers probability models from Python code. It relies on PyTorch, one of the most popular deep learning libraries, for fast computation and automatic differentiation. Thus, all latent variables in the Python code defining a model must be represented as PyTorch or Pyro objects.
Pyro is more flexible than Stan, in the following ways:
- As in Stan, it has a robust implementation of NUTS; however, it also allows for computation using GPUs, which can greatly accelerate computation times
- For VI, Pyro has a sophisticated implementation of BBVI, which can also use stochastic gradients (i.e., SVI). This form of VI is more flexible than ADVI/Pathfinder, as it allows the user to create a custom variational distribution (more here: https://pyro.ai/examples/svi_part_i.html)
This additional flexibility comes with a cost: in our experience, Pyro can be a bit more challenging to use, especially if one wants to take advantage of its more advanced features. The investment of time in learning it, however, can be worth it, especially for implementation of large models on large datasets where scalability is a central concern.
Let's start by importing all the required libraries:
import torch
import pyro
import pyro.distributions as dist
import torch.nn.functional as F
torch.set_default_dtype(torch.float64)
# Setting the random seed for reproducibility
pyro.set_rng_seed(0)
Next, we will do some slight reshuffling of the data. Our Pyro code is slightly more general than our Stan code, allowing for different numbers of observations per person, which we will accommodate via an ID vector. Since Pyro operates on PyTorch, all data must be in the form of torch tensors.
pyro_data = {}
pyro_data['X'] = torch.tensor(X.reshape((n_cust*n_obs, n_alts, n_feats)))
pyro_data['y'] = torch.tensor(y.reshape(-1) - 1)
pyro_data['id'] = torch.repeat_interleave(torch.arange(n_cust), n_obs)
n_obs_total, n_brands, n_feats = pyro_data['X'].shape
n_nonzero_icepts = n_brands - 1
n_pars = n_nonzero_icepts + n_feats
As described above, in Pyro, the model is written as a standard Python function that encodes the DGP using Pyro sampling statements and torch-based objects:
# Define model
def pyro_model(X, y, id):
    # Population-level parameters: mean and scale of the heterogeneity distribution
    mu = pyro.sample("mu", dist.Normal(0, 5).expand([n_pars]).to_event(1))
    sigma = pyro.sample("sigma", dist.HalfNormal(5).expand([n_pars]).to_event(1))
    # Individual-level coefficients, one vector per customer
    with pyro.plate("coefs", n_cust):
        theta = pyro.sample("theta", dist.Normal(mu, sigma).to_event(1))
    theta_icepts = theta[:, :n_nonzero_icepts]
    theta_feats = theta[:, n_nonzero_icepts:]
    # The first brand's intercept is normalized to zero
    zeros_for_first_brand = torch.zeros((n_obs_total, 1))
    theta_icepts_with_zero = torch.cat([zeros_for_first_brand, theta_icepts[id]], dim=1)
    theta_feats_expanded = theta_feats[id].unsqueeze(1)
    X_weighted = X * theta_feats_expanded
    utils = theta_icepts_with_zero + torch.sum(X_weighted, dim=-1)
    p = F.softmax(utils, dim=-1)
    # Observations are conditionally independent given theta
    with pyro.plate("data", len(y)):
        pyro.sample("choices", dist.Categorical(probs=p), obs=y)
One interesting thing to note about the Pyro model is the pyro.plate specification: it declares the sampling statements inside it as conditionally independent, and can be used to accelerate inference, as we will see later.
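To preview how that conditional independence can be exploited, below is a hedged sketch (not run in this notebook) of a mini-batched variant of the model; pyro_model_minibatch and batch_size are our own names, and the plate's subsample_size argument draws a random mini-batch each step and rescales the log-likelihood accordingly:
# Sketch: mini-batching via pyro.plate's subsample_size (hypothetical variant; not run here)
def pyro_model_minibatch(X, y, id, batch_size=1000):
    mu = pyro.sample("mu", dist.Normal(0, 5).expand([n_pars]).to_event(1))
    sigma = pyro.sample("sigma", dist.HalfNormal(5).expand([n_pars]).to_event(1))
    with pyro.plate("coefs", n_cust):
        theta = pyro.sample("theta", dist.Normal(mu, sigma).to_event(1))
    # The plate yields a random set of observation indices and rescales the likelihood
    with pyro.plate("data", len(y), subsample_size=batch_size) as ind:
        X_b, y_b, id_b = X[ind], y[ind], id[ind]
        # First brand's intercept normalized to zero, as in pyro_model
        icepts = torch.cat([torch.zeros(len(ind), 1), theta[id_b, :n_nonzero_icepts]], dim=1)
        utils = icepts + torch.sum(X_b * theta[id_b, n_nonzero_icepts:].unsqueeze(1), dim=-1)
        pyro.sample("choices", dist.Categorical(logits=utils), obs=y_b)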
Inference: NUTS¶
Now, we define the inference algorithm. There are two steps:
- First, we need to define our MCMC "kernel" -- this simply specifies how each iteration of the MCMC algorithm is generated. Here, we will use NUTS, but Pyro also allows for sampling with random-walk Metropolis-Hastings (RWMH) and vanilla HMC. We use 1000 warmup samples and 1000 draws from the posterior, though in many cases, a smaller number will suffice, as NUTS is very efficient.
- Second, we use the kernel to define an MCMC object, which we then use to run MCMC.
from pyro.infer import MCMC, NUTS
kernel = NUTS(pyro_model)
mcmc = MCMC(kernel, num_samples=1000, warmup_steps=1000)
mcmc.run(pyro_data["X"], pyro_data["y"], pyro_data["id"])
Sample: 100%|██████████| 2000/2000 [00:52, 38.43it/s, step size=2.25e-01, acc. prob=0.841]
Now, we will collect the samples, and compare them to Stan's NUTS outputs:
pyro_nuts_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(mu_samples_stan_nuts[:,i], bins=30, kde=True, ax=ax, label="Stan", stat='density')
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label='Pyro', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
We see that the two almost perfectly coincide. This is as expected: NUTS is a fairly robust algorithm that samples from the true posterior.
Inference: SVI (BBVI)¶
Now, let's explore how to do this same thing, but using stochastic variational inference. We will demonstrate three approaches:
- In the first approach, we will use a mean field normal variational family, which assumes that every latent variable's posterior can be approximated by a normal distribution. This is very easy to do in Pyro, as we will demonstrate below.
- In the second approach, we will use a normal distribution with a full covariance matrix as the variational family. This allows for also approximating posterior covariances.
- Finally, in a third approach, we will define a custom variational family, showcasing the flexibility of Pyro.
Approach 1: Mean Field Normal Variational Family
Another nice benefit of using PPLs is that they are modular. Similar to Stan, to use VI to estimate our model, we don't have to change our model. Instead, we just need to modify the inference procedure.
In Pyro, the variational approximation is called the guide. Thus, setting the variational family is equivalent to defining the guide. To use a mean field normal variational family, Pyro has a very useful utility function called AutoNormal. We demonstrate below:
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam
from tqdm import trange
guide = AutoNormal(pyro_model)
optimizer = Adam({"lr": 0.01})
svi = SVI(pyro_model, guide, optimizer, loss=Trace_ELBO())
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi.step(pyro_data['X'], pyro_data['y'], pyro_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [00:04<00:00, 436.66it/s]
Notice how fast this is: mere seconds to compute an approximate posterior.
To assess convergence of the algorithm, we can plot the loss over iterations:
plt.plot(losses)
We see the algorithm converged very quickly. We could have run it for just 1000 iterations and still achieved convergence.
Note that convergence of the algorithm does not mean that it "converged to the truth" -- in VI, the posterior is always being approximated. To assess how good that approximation is, we can again draw some samples from the approximation, and compare them to NUTS. Pyro provides a utility class called Predictive for this and for computing other types of posterior predictive quantities.
from pyro.infer import Predictive
predictive = Predictive(pyro_model, guide=guide, num_samples=1000)
svi_samples = predictive(pyro_data['X'], pyro_data['y'], pyro_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
We see that, similar to ADVI, the approximation for $\mu$ is pretty bad: posterior variances are underestimated, and in some cases, the distribution does not even capture the true value of $\mu$.
Such behavior is typical for variational inference: in our experience, especially in mean field VI, uncertainty in population parameters like $\mu$ is often underestimated, while inference for individual-level parameters like $\theta_i$ is often better. Indeed, we see that below, where we compare the four components of $\theta_i$ for the first simulated person (indexed 0):
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['theta'][:,0,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['theta'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\theta_{{1,{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
This is an important distinction: often for decision-making, it is really the local, individual-level variables like $\theta_i$ that truly matter.
Still, we may want to see if we can improve inference for all parameters, including global variables like $\mu$. To do so, we can experiment with different variational families.
Approach 2: Multivariate Normal Variational Family
from pyro.infer.autoguide import AutoMultivariateNormal
guide_mvn = AutoMultivariateNormal(pyro_model)
optimizer = Adam({"lr": 0.01})
svi_mvn = SVI(pyro_model, guide_mvn, optimizer, loss=Trace_ELBO())
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi_mvn.step(pyro_data['X'], pyro_data['y'], pyro_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [00:21<00:00, 93.52it/s]
Note that this algorithm took longer to run: the approximation is more complex, and thus, more computationally expensive. Again, this emphasizes the "iron simplex" -- to get accuracy, we often must sacrifice speed.
As before, let's check convergence:
plt.plot(losses)
In this case, the oscillation of the loss is more notable than before. An oscillating loss is expected: this is stochastic variational inference, so the gradient is estimated by sampling at every step.
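(If the raw trace is hard to read, one simple option is to smooth it; a quick sketch using a moving average:)
# Smooth the noisy ELBO trace with a simple moving average to make the trend easier to see
window = 50
smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")
plt.plot(smoothed)
plt.xlabel("SVI step")
plt.ylabel("Smoothed ELBO loss")
plt.show()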
Let's compare this posterior approximation to the truth:
predictive = Predictive(pyro_model, guide=guide_mvn, num_samples=1000)
svi_samples = predictive(pyro_data['X'], pyro_data['y'], pyro_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
In this case, we can see that the approximation has improved, but does not completely align with NUTS.
Note that, in keeping with the idea of the iron simplex, the multivariate normal variational family will scale very poorly as the number of parameters grows, since it becomes difficult to model them all simultaneously with a full-rank covariance matrix. To handle this, Pyro also has an AutoLowRankMultivariateNormal guide, but that also introduces some approximation error. There's always a trade-off!
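(As a hedged sketch of what that would look like -- the rank is a tuning choice, and 5 here is arbitrary; we do not run this variant:)
# Sketch: a low-rank multivariate normal guide (not run here)
from pyro.infer.autoguide import AutoLowRankMultivariateNormal
guide_lowrank = AutoLowRankMultivariateNormal(pyro_model, rank=5)
svi_lowrank = SVI(pyro_model, guide_lowrank, Adam({"lr": 0.01}), loss=Trace_ELBO())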
The Effect of num_particles
Besides changing the variational approximation, another way we can potentially improve our inference procedure is by using more samples in the stochastic gradient estimator (the num_particles argument of Trace_ELBO), which reduces the variance of the gradient estimate and should allow the optimization to converge better. While there's no evidence that this particular model is having trouble converging, in complex models, this is sometimes necessary. This will again come at the cost of speed. We demonstrate first with the original mean field normal SVI:
pyro.clear_param_store()
guide = AutoNormal(pyro_model)
optimizer = Adam({"lr": 0.01})
svi = SVI(pyro_model, guide, optimizer, loss=Trace_ELBO(num_particles=10))
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi.step(pyro_data['X'], pyro_data['y'], pyro_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [00:42<00:00, 47.53it/s]
predictive = Predictive(pyro_model, guide=guide, num_samples=1000)
svi_samples = predictive(pyro_data['X'], pyro_data['y'], pyro_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
And again with the multivariate normal:
pyro.clear_param_store()
guide_mvn = AutoMultivariateNormal(pyro_model)
optimizer = Adam({"lr": 0.01})
svi_mvn = SVI(pyro_model, guide_mvn, optimizer, loss=Trace_ELBO(num_particles=10))
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi_mvn.step(pyro_data['X'], pyro_data['y'], pyro_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [02:55<00:00, 11.41it/s]
predictive = Predictive(pyro_model, guide=guide_mvn, num_samples=1000)
svi_samples = predictive(pyro_data['X'], pyro_data['y'], pyro_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
We can see this not only improved the stability of SVI (i.e., fewer oscillations in the ELBO loss), but also led to better posterior approximations. The cost was longer computation time: in the multivariate normal case, even longer than NUTS. That being said, the promise of these methods is in massive data, where the stochastic gradients and ability to mini-batch really matter. Hence, when we scale the problem up, we should see that this approach scales better. We will explore that later.
Before we demonstrate the scalability of this approach, we explore one final way of specifying the variational family: a custom guide.
Approach 3: Custom Variational Family (+ Variational EM)
Pyro also allows for customizing the variational family by writing custom guide functions. While custom guides can be very elaborate, an easy way of doing this is by combining custom guides with automatic guides via the simple-to-use AutoGuideList. We show both approaches below.
First, we use a fully custom guide with three parts:
- We retain the mean field normal assumption for $\mu$
- Rather than estimating a posterior for $\sigma$, we treat it as a hyperparameter and optimize it. This can be accomplished using Pyro's dist.Delta, and makes our inference equivalent to variational EM (i.e., variational approximations for some parameters, then, conditional on those, point optimization for the others).
- We allow for additional uncertainty in $\theta$ through a Student-t distribution.
def custom_guide(X, y, id):
    # Mean field normal for mu
    mu_loc = pyro.param("mu_loc", torch.zeros(n_pars))
    mu_scale = pyro.param("mu_scale", 0.1*torch.ones(n_pars), constraint=dist.constraints.positive)
    mu = pyro.sample("mu", dist.Normal(mu_loc, mu_scale).to_event(1))
    # Point mass (Delta) for sigma: optimized rather than given a posterior (variational EM)
    sigma_delta = pyro.param("sigma_delta", torch.ones(n_pars), constraint=dist.constraints.positive)
    sigma = pyro.sample("sigma", dist.Delta(sigma_delta).to_event(1))
    # Student-t for the individual-level coefficients, allowing heavier tails
    theta_loc = pyro.param("theta_loc", torch.zeros((n_cust, n_pars)))
    theta_scale = pyro.param("theta_scale", 0.1*torch.ones((n_cust, n_pars)), constraint=dist.constraints.positive)
    with pyro.plate("coefs", n_cust):
        theta = pyro.sample("theta", dist.StudentT(torch.tensor(2.0), theta_loc, theta_scale).to_event(1))
pyro.clear_param_store()
optimizer = Adam({"lr": 1e-2})
svi_custom = SVI(pyro_model, custom_guide, optimizer, loss=Trace_ELBO(num_particles=10))
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi_custom.step(pyro_data['X'], pyro_data['y'], pyro_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [00:37<00:00, 53.14it/s]
predictive = Predictive(pyro_model, guide=custom_guide, num_samples=1000)
svi_samples = predictive(pyro_data['X'], pyro_data['y'], pyro_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
Another useful tool is Pyro's AutoGuideList, which allows for easily combining automatic guides. Here, we will use a mean field normal guide for $\mu$, a delta guide for $\sigma$, and a low-rank multivariate normal guide for $\theta$:
import pyro.poutine as poutine
from pyro.infer.autoguide import AutoGuideList, AutoDelta, AutoLowRankMultivariateNormal
list_guide = AutoGuideList(pyro_model)
list_guide.append(AutoNormal(poutine.block(pyro_model, expose=["mu"])))
list_guide.append(AutoDelta(poutine.block(pyro_model, expose=["sigma"])))
list_guide.append(AutoLowRankMultivariateNormal(poutine.block(pyro_model, hide=["mu","sigma"])))
Here, the poutine.block functionality controls which variables each AutoGuide handles, by exposing some sites and hiding others. Together, the three pieces give a guide similar in spirit to the one we manually specified above, except that $\theta$ now receives a low-rank multivariate normal approximation rather than a Student-t.
pyro.clear_param_store()
optimizer = Adam({"lr": 1e-2})
svi_custom = SVI(pyro_model, list_guide, optimizer, loss=Trace_ELBO(num_particles=10))
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi_custom.step(pyro_data['X'], pyro_data['y'], pyro_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [01:06<00:00, 30.17it/s]
predictive = Predictive(pyro_model, guide=list_guide, num_samples=1000)
svi_samples = predictive(pyro_data['X'], pyro_data['y'], pyro_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
While we include this guide mostly to demonstrate the flexibility of Pyro and its AutoGuideList functionality, we also note that this variational family seems to work well, except for one parameter.
These examples also highlight the importance of choosing a good variational family, and one of the key challenges of variational inference: while in this simple case, it is possible to compare to the "true" NUTS-based posterior, in practice, it is impossible to know a priori how good an approximation is.
Illustration: Usefulness of VI¶
Until now, we have only explored examples with a small data set. In this case, even NUTS runs reasonably quickly, and hence there's no real reason (besides curiosity) for exploring different approaches to SVI. However, when the data become large, SVI can make inference feasible, at the expense of accuracy. To demonstrate, let's scale up our logit example:
n_cust = 40000 # A new number of customers; we leave the remaining settings as before (10 choices per customer, 3 alternatives, 2 features)
theta = np.random.multivariate_normal(mean=mu, cov=np.diag(sigma), size=n_cust)
X = np.random.normal(size=(n_cust, n_obs, n_alts, n_feats))
y = np.zeros((n_cust,n_obs), dtype=int)
for i in range(n_cust):
for t in range(n_obs):
v = np.zeros(n_alts)
v[0] = theta[i,coef_idx] @ X[i,t,0]
for j in range(1,n_alts):
v[j] = theta[i,j-1] + theta[i,coef_idx] @ X[i,t,j]
probs = softmax(v)
y[i,t] = np.random.choice(n_alts, size=1, p=probs)[0]+1
pyro_bigger_data = {}
pyro_bigger_data['X'] = torch.tensor(X.reshape((n_cust*n_obs, n_alts, n_feats)))
pyro_bigger_data['y'] = torch.tensor(y.reshape(-1) - 1)
pyro_bigger_data['id'] = torch.repeat_interleave(torch.arange(n_cust), n_obs)
n_obs_total, n_brands, n_feats = pyro_bigger_data['X'].shape
n_nonzero_icepts = n_brands - 1
n_pars = n_nonzero_icepts + n_feats
Here, we scaled up the example only modestly: to 40,000 customers from 200. For a sense of scale, 40,000 customers is roughly equivalent to the number of households in Nielsen's panel data. Let's re-explore some of our inference algorithms.
To begin, let's run NUTS:
pyro.clear_param_store()
kernel = NUTS(pyro_model)
mcmc = MCMC(kernel, num_samples=1000, warmup_steps=1000)
mcmc.run(pyro_bigger_data["X"], pyro_bigger_data["y"], pyro_bigger_data["id"])
Sample: 100%|██████████| 2000/2000 [1:51:20, 3.34s/it, step size=7.23e-02, acc. prob=0.792]
As we can see, even for this modest size data, NUTS is quite slow. Imagine, then, how long the model would take to estimate for 400,000, or even 40M customers!
We will save the results for later use:
pyro_bigger_nuts_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
In contrast, SVI can run quite quickly. We start with the simple mean field normal guide:
pyro.clear_param_store()
guide = AutoNormal(pyro_model)
optimizer = Adam({"lr": 0.01})
svi_bigger = SVI(pyro_model, guide, optimizer, loss=Trace_ELBO(num_particles=10))
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi_bigger.step(pyro_bigger_data['X'], pyro_bigger_data['y'], pyro_bigger_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [14:49<00:00, 2.25it/s]
plt.plot(losses[1000:])
predictive = Predictive(pyro_model, guide=guide, num_samples=1000)
svi_samples_mfn = predictive(pyro_bigger_data['X'], pyro_bigger_data['y'], pyro_bigger_data['id'])
And now try our custom, variational EM guide:
pyro.clear_param_store()
optimizer = Adam({"lr": 0.01})
svi_bigger = SVI(pyro_model, custom_guide, optimizer, loss=Trace_ELBO(num_particles=10))
num_steps = 2000
losses = []
for step in trange(num_steps):
loss = svi_bigger.step(pyro_bigger_data['X'], pyro_bigger_data['y'], pyro_bigger_data['id'])
losses.append(loss)
100%|██████████| 2000/2000 [16:38<00:00, 2.00it/s]
As we can see, SVI runs in a small fraction of the time that NUTS takes -- roughly 15 minutes versus nearly two hours -- even using our custom guide.
Per the iron simplex, this gain in computation comes at the cost of accuracy:
predictive = Predictive(pyro_model, guide=custom_guide, num_samples=1000)
svi_samples_vem = predictive(pyro_bigger_data['X'], pyro_bigger_data['y'], pyro_bigger_data['id'])
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_bigger_nuts_samples['mu'][:,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples_mfn['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI-MF', stat='density')
sns.histplot(svi_samples_vem['mu'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI-VEM', stat='density')
ax.axvline(x=mu_val, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\mu_{{{i}}}$')
# Adjust the layout
plt.tight_layout()
plt.legend()
# Show the plots
plt.show()
One thing to note is that these plots make the bias seem quite severe, but these parameters are also estimated quite precisely: the posteriors are narrow, and for each parameter, the posterior medians from the different methods are fairly close on the scale of the parameter.
As we saw before, the accuracy of SVI is often lower for population parameters like $\mu$, and higher for individual-level parameters like $\theta_i$, though in this case, the difference is somewhat less dramatic:
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
for i, mu_val in enumerate(mu):
ax = axs[i // 2, i % 2]
sns.histplot(pyro_bigger_nuts_samples['theta'][:,0,i], bins=30, kde=True, ax=ax, label="NUTS", stat='density')
sns.histplot(svi_samples_mfn['theta'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI-MF', stat='density')
sns.histplot(svi_samples_vem['theta'].detach().numpy()[:,0,i], bins=30, kde=True, ax=ax, label='SVI-VEM', stat='density')
ax.set_xlabel('Samples')
ax.set_ylabel('Density')
ax.set_title(f'Posterior of $\\theta_{{1,{i}}}$')
plt.tight_layout()
plt.legend()
plt.show()
Note that the fat tails are a function of the assumed Student-t variational family on $\theta$.
Final Thoughts¶
In this code companion, we have demonstrated both the usefulness of PPLs and the concept of the Iron Simplex, through demonstrations of HMC/NUTS and different forms of variational inference. The key takeaway is that Bayesian inference can be fast and reasonably scalable, but the further one pushes in that direction, the more approximations must be made. Another takeaway is how easy PPLs are to work with: with minimal code changes, we can change both the model and how inference works, employing fairly sophisticated inference algorithms.
One aspect of scalability that is not covered by the Iron Simplex is pure gains in computation speed from hardware and basic algorithmic efficiency improvements. In all of the previous examples, computation times were based on an M2 MacBook Pro laptop running the models on its CPU. Using faster CPUs and, most importantly, GPUs can accelerate model training dramatically. Some libraries, including Pyro and its cousin, NumPyro, can use GPUs to accelerate all types of Bayesian inference. NumPyro in particular deserves special note: as of this writing, NumPyro has, in our experience, the fastest implementation of NUTS, which, even when deployed on a basic CPU, can yield dramatic improvements in computation time, due to its reliance on JIT compilation and the JAX library. Deploying these algorithms on GPUs would improve computation time even more dramatically, all without needing to sacrifice accuracy.
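To give a flavor of that, below is a hedged sketch of how the same mixed logit model might be written in NumPyro; we did not run this, and numpyro_model (along with the NPMCMC/NPNUTS aliases) is our own naming. The API closely mirrors Pyro's, with JAX arrays in place of torch tensors:
# Sketch only: the mixed logit model in NumPyro (not run in this notebook)
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as ndist
from numpyro.infer import MCMC as NPMCMC, NUTS as NPNUTS

def numpyro_model(X, y, id):
    mu = numpyro.sample("mu", ndist.Normal(0, 5).expand([n_pars]).to_event(1))
    sigma = numpyro.sample("sigma", ndist.HalfNormal(5).expand([n_pars]).to_event(1))
    with numpyro.plate("coefs", n_cust):
        theta = numpyro.sample("theta", ndist.Normal(mu, sigma).to_event(1))
    # First brand's intercept is normalized to zero, as in the Pyro model
    icepts = jnp.concatenate([jnp.zeros((X.shape[0], 1)), theta[id, :n_nonzero_icepts]], axis=1)
    utils = icepts + jnp.sum(X * theta[id, n_nonzero_icepts:][:, None, :], axis=-1)
    with numpyro.plate("data", len(y)):
        numpyro.sample("choices", ndist.Categorical(logits=utils), obs=y)

# mcmc = NPMCMC(NPNUTS(numpyro_model), num_warmup=1000, num_samples=1000)
# mcmc.run(random.PRNGKey(0),
#          jnp.asarray(pyro_bigger_data['X'].numpy()),
#          jnp.asarray(pyro_bigger_data['y'].numpy()),
#          jnp.asarray(pyro_bigger_data['id'].numpy()))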
Additional Models and Resources¶
For readers interested in exploring PPLs beyond the examples here, and to see how they can be used to estimate all of the models referenced in the paper, there are a number of resources available. Some helpful links include:
- General Help
- Stan reference manual: https://mc-stan.org/docs/stan-users-guide/
- Pyro tutorials: http://pyro.ai/examples/intro_long.html
- Flexible Models of Distributions
- Finite mixture models in Stan: https://mc-stan.org/docs/stan-users-guide/finite-mixtures.html
- Dirichlet process mixtures in Pyro: https://pyro.ai/examples/dirichlet_process_mixture.html
- Mixed Membership Models
- Clustering-type models in Stan: https://mc-stan.org/docs/stan-users-guide/clustering.html
- Topic modeling in Pyro: https://pyro.ai/examples/prodlda.html
- Nonlinearities and Dynamics
- Gaussian processes in Stan: https://mc-stan.org/docs/stan-users-guide/gaussian-processes.html
- Gaussian processes in Pyro: https://pyro.ai/examples/gp.html
- While there is no specific documentation about Bayesian neural networks in Pyro, Pyro is capable of handling them. See examples like:
- Variational Autoencoders: https://pyro.ai/examples/vae.html
- Deep Markov Model: https://pyro.ai/examples/dmm.html?highlight=neural
- Attend Infer Repeat: https://pyro.ai/examples/air.html?highlight=neural
- Deep Generative Models
- VAE in Pyro: http://pyro.ai/examples/vae.html
- Sparse Gamma Deep Exponential Family in Pyro: http://pyro.ai/examples/sparse_gamma.html
- Data Fusion
- Missingness in Stan: https://mc-stan.org/docs/stan-users-guide/missing-data.html
- Advanced Topics in Inference
- Amortized inference in Pyro:
- In the context of VAEs: http://pyro.ai/examples/vae.html
- In the context of LDA: http://pyro.ai/examples/lda.html
- Normalizing flows as variational families in Pyro:
- Introduction: http://pyro.ai/examples/normalizing_flows_intro.html
- General tutorial: http://pyro.ai/examples/svi_flow_guide.html
- Example in VAEs: http://pyro.ai/examples/vae_flow_prior.html
(Note: none of these are maintained by the authors of this paper.)