Bayes by Backprop with gluon (NN, classification)

After discussing Bayes by Backprop from scratch in a previous notebook, we can now look at the corresponding implementation as gluon components.

We start off with the usual set of imports.

In [1]:
from __future__ import print_function
import collections
import mxnet as mx
import numpy as np
from mxnet import nd, autograd
from matplotlib import pyplot as plt
from mxnet import gluon

For easy tuning and experimentation, we define a dictionary holding the hyper-parameters of our model.

In [2]:
config = {
    "num_hidden_layers": 2,
    "num_hidden_units": 400,
    "batch_size": 128,
    "epochs": 10,
    "learning_rate": 0.001,
    "num_samples": 1,
    "pi": 0.25,
    "sigma_p": 1.0,
    "sigma_p1": 0.75,
    "sigma_p2": 0.01,
}

Also, we specify the device context for MXNet.

In [3]:
ctx = mx.cpu()

Load MNIST data

We will again train and evaluate the algorithm on the MNIST data set, which we load as follows:

In [4]:
def transform(data, label):
    return data.astype(np.float32)/126.0, label.astype(np.float32)

mnist = mx.test_utils.get_mnist()
num_inputs = 784
num_outputs = 10
batch_size = config['batch_size']

train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),
                                      batch_size, shuffle=True)
test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),
                                     batch_size, shuffle=False)

num_train = sum([batch_size for i in train_data])
num_batches = num_train / batch_size

In order to reproduce and compare the results from the paper, we preprocess the pixels by dividing by 126.

Model definition

Neural net modeling

As our model we are using a straightforward MLP, and we wire up the network just as we are used to in gluon. Note that we are not using any special layers during the definition of our network, as we believe that Bayes by Backprop should be thought of as a training method rather than a special architecture.

In [5]:
num_layers = config['num_hidden_layers']
num_hidden = config['num_hidden_units']

net = gluon.nn.Sequential()
with net.name_scope():
    for i in range(num_layers):
        net.add(gluon.nn.Dense(num_hidden, activation="relu"))
    net.add(gluon.nn.Dense(num_outputs))

Build objective/loss

Again, we define our loss function as described in Bayes by Backprop from scratch. Note that we are bundling all of this functionality as part of a gluon.loss.Loss subclass, where the loss computation is performed in the hybrid_forward function.
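
For reference, the quantity that hybrid_forward computes below is the single-sample, minibatch-scaled variational free energy

\[\mathcal{F}(\mathcal{D}_i, \theta) \approx \frac{1}{M}\left(\log q(\mathbf{w} \mid \theta) - \log P(\mathbf{w})\right) - \log P(\mathcal{D}_i \mid \mathbf{w}),\]

where \(\mathbf{w}\) is a single Monte Carlo sample drawn from \(q(\mathbf{w} \mid \theta)\), \(\mathcal{D}_i\) is the current minibatch, and \(M\) is the total number of minibatches (num_batches in the code); the KL term is weighted uniformly by \(\frac{1}{M}\) across minibatches.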

In [6]:
class BBBLoss(gluon.loss.Loss):
    def __init__(self, log_prior="gaussian", log_likelihood="softmax_cross_entropy",
                 sigma_p1=1.0, sigma_p2=0.1, pi=0.5, weight=None, batch_axis=0, **kwargs):
        super(BBBLoss, self).__init__(weight, batch_axis, **kwargs)
        self.log_prior = log_prior
        self.log_likelihood = log_likelihood
        self.sigma_p1 = sigma_p1
        self.sigma_p2 = sigma_p2
        self.pi = pi

    def log_softmax_likelihood(self, yhat_linear, y):
        return nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)

    def log_gaussian(self, x, mu, sigma):
        return -0.5 * np.log(2.0 * np.pi) - nd.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

    def gaussian_prior(self, x):
        sigma_p = nd.array([self.sigma_p1], ctx=ctx)
        return nd.sum(self.log_gaussian(x, 0., sigma_p))

    def gaussian(self, x, mu, sigma):
        scaling = 1.0 / nd.sqrt(2.0 * np.pi * (sigma ** 2))
        bell = nd.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))

        return scaling * bell

    def scale_mixture_prior(self, x):
        sigma_p1 = nd.array([self.sigma_p1], ctx=ctx)
        sigma_p2 = nd.array([self.sigma_p2], ctx=ctx)
        pi = self.pi

        first_gaussian = pi * self.gaussian(x, 0., sigma_p1)
        second_gaussian = (1 - pi) * self.gaussian(x, 0., sigma_p2)

        return nd.log(first_gaussian + second_gaussian)

    def hybrid_forward(self, F, output, label, params, mus, sigmas, sample_weight=None):
        log_likelihood_sum = nd.sum(self.log_softmax_likelihood(output, label))
        prior = None
        if self.log_prior == "gaussian":
            prior = self.gaussian_prior
        elif self.log_prior == "scale_mixture":
            prior = self.scale_mixture_prior
        log_prior_sum = sum([nd.sum(prior(param)) for param in params])
        log_var_posterior_sum = sum([nd.sum(self.log_gaussian(params[i], mus[i], sigmas[i])) for i in range(len(params))])
        return 1.0 / num_batches * (log_var_posterior_sum - log_prior_sum) - log_likelihood_sum

bbb_loss = BBBLoss(log_prior="scale_mixture", sigma_p1=config['sigma_p1'], sigma_p2=config['sigma_p2'])
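
The scale_mixture prior chosen here places an independent mixture of two zero-mean Gaussians on every weight,

\[\log P(\mathbf{w}) = \sum_i \log \left( \pi \, \mathcal{N}(w_i \mid 0, \sigma_1^2) + (1 - \pi) \, \mathcal{N}(w_i \mid 0, \sigma_2^2) \right),\]

with \(\sigma_1 = 0.75\) and \(\sigma_2 = 0.01\) taken from config; since config['pi'] is not passed to the constructor, \(\pi\) keeps the class default of \(0.5\).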

Parameter initialization

First, we need to initialize all of the network’s parameters, which at this point are only point estimates of the weights. We will soon see how we can still train the network in a Bayesian fashion without interfering with the network’s architecture.

In [7]:
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

Then we have to forward-propagate a single batch of data once so that all network parameters (weights and biases) are set up with the desired initializer specified above.

In [8]:
for i, (data, label) in enumerate(train_data):
    data = data.as_in_context(ctx).reshape((-1, 784))
    net(data)
    break
In [9]:
weight_scale = .1
rho_offset = -3

# initialize the variational parameters: a mean (mu) and a rho (mapped to a standard deviation via softplus) for each weight
mus = []
rhos = []

shapes = list(map(lambda x: x.shape, net.collect_params().values()))

for shape in shapes:
    mu = gluon.Parameter('mu', shape=shape, init=mx.init.Normal(weight_scale))
    rho = gluon.Parameter('rho',shape=shape, init=mx.init.Constant(rho_offset))
    mu.initialize(ctx=ctx)
    rho.initialize(ctx=ctx)
    mus.append(mu)
    rhos.append(rho)

variational_params = mus + rhos

raw_mus = list(map(lambda x: x.data(ctx), mus))
raw_rhos = list(map(lambda x: x.data(ctx), rhos))

Optimizer

Now we still have to choose the optimizer we wish to use for training; this time, we are using the Adam optimizer. Note that we pass the variational parameters (the mus and rhos) to the trainer, since these are the parameters we actually optimize.

In [10]:
trainer = gluon.Trainer(variational_params, 'adam', {'learning_rate': config['learning_rate']})

Main training loop

Sampling

Recall the three-step process for sampling weights from the variational parameters:

  1. Sample \(\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0},\mathbf{I}^{d})\)
In [11]:
def sample_epsilons(param_shapes):
    epsilons = [nd.random_normal(shape=shape, loc=0., scale=1.0, ctx=ctx) for shape in param_shapes]
    return epsilons
  2. Transform \(\mathbf{\rho}\) to a positive vector via the softplus function: \(\mathbf{\sigma} = \text{softplus}(\mathbf{\rho}) = \log(1 + \exp(\mathbf{\rho}))\)
In [12]:
def softplus(x):
    return nd.log(1. + nd.exp(x))

def transform_rhos(rhos):
    return [softplus(rho) for rho in rhos]
  3. Compute \(\mathbf{w}\): \(\mathbf{w} = \mathbf{\mu} + \mathbf{\sigma} \circ \mathbf{\epsilon}\), where the \(\circ\) operator represents the element-wise multiplication. This is the “reparametrization trick” for separating the randomness from the parameters of \(q\).
In [13]:
def transform_gaussian_samples(mus, sigmas, epsilons):
    samples = []
    for j in range(len(mus)):
        samples.append(mus[j] + sigmas[j] * epsilons[j])
    return samples

Putting these three steps together we get:

In [14]:
def generate_weight_sample(layer_param_shapes, mus, rhos):
    # sample epsilons from standard normal
    epsilons = sample_epsilons(layer_param_shapes)

    # map the rhos to standard deviations (sigmas) via the softplus function
    sigmas = transform_rhos(rhos)

    # obtain a sample from q(w|theta) by transforming the epsilons
    layer_params = transform_gaussian_samples(mus, sigmas, epsilons)

    return layer_params, sigmas
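
As a quick, purely illustrative sanity check (not part of the original training flow), we could draw a single set of weights from the current variational posterior and confirm that each sampled tensor matches the shape of the corresponding layer parameter:

# draw one Monte Carlo sample of all layer parameters from q(w|theta)
sample_params, sample_sigmas = generate_weight_sample(shapes, raw_mus, raw_rhos)

# each sampled tensor should have the same shape as its layer parameter
print([p.shape for p in sample_params])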

Evaluation metric

In order to be able to assess our model’s performance, we define a helper function which evaluates our accuracy on an ongoing basis.

In [15]:
def evaluate_accuracy(data_iterator, net, layer_params):
    numerator = 0.
    denominator = 0.
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx).reshape((-1, 784))
        label = label.as_in_context(ctx)

        for l_param, param in zip(layer_params, net.collect_params().values()):
            param._data[list(param._data.keys())[0]] = l_param

        output = net(data)
        predictions = nd.argmax(output, axis=1)
        numerator += nd.sum(predictions == label)
        denominator += data.shape[0]
    return (numerator / denominator).asscalar()
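
Note that this helper plugs in a single, fixed set of weights (later we pass the posterior means raw_mus). Since Bayes by Backprop gives us a full distribution over the weights, one could alternatively average the predictive distribution over several posterior samples. A minimal sketch of such a Monte Carlo evaluation, reusing generate_weight_sample from above and not used in the rest of this notebook, could look as follows:

def evaluate_accuracy_mc(data_iterator, net, num_samples=10):
    numerator = 0.
    denominator = 0.
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx).reshape((-1, 784))
        label = label.as_in_context(ctx)
        # accumulate the softmax outputs over several weight samples from q(w|theta)
        probs = nd.zeros((data.shape[0], num_outputs), ctx=ctx)
        for _ in range(num_samples):
            sampled_params, _ = generate_weight_sample(shapes, raw_mus, raw_rhos)
            # overwrite the network parameters with the sampled weights
            for s_param, param in zip(sampled_params, net.collect_params().values()):
                param._data[list(param._data.keys())[0]] = s_param
            probs = probs + nd.softmax(net(data))
        predictions = nd.argmax(probs, axis=1)
        numerator += nd.sum(predictions == label)
        denominator += data.shape[0]
    return (numerator / denominator).asscalar()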

Complete loop

The complete training loop is given below.

In [16]:
epochs = config['epochs']
learning_rate = config['learning_rate']
smoothing_constant = .01
train_acc = []
test_acc = []

for e in range(epochs):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx).reshape((-1, 784))
        label = label.as_in_context(ctx)
        label_one_hot = nd.one_hot(label, 10)

        with autograd.record():
            # generate sample
            layer_params, sigmas = generate_weight_sample(shapes, raw_mus, raw_rhos)

            # overwrite network parameters with sampled parameters
            for sample, param in zip(layer_params, net.collect_params().values()):
                param._data[list(param._data.keys())[0]] = sample

            # forward-propagate the batch
            output = net(data)

            # calculate the loss
            loss = bbb_loss(output, label_one_hot, layer_params, raw_mus, sigmas)

            # backpropagate for gradient calculation
            loss.backward()

        trainer.step(data.shape[0])

        # calculate moving loss for monitoring convergence
        curr_loss = nd.mean(loss).asscalar()
        moving_loss = (curr_loss if ((i == 0) and (e == 0))
                       else (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss)

    test_accuracy = evaluate_accuracy(test_data, net, raw_mus)
    train_accuracy = evaluate_accuracy(train_data, net, raw_mus)
    train_acc.append(np.asscalar(train_accuracy))
    test_acc.append(np.asscalar(test_accuracy))
    print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" %
          (e, moving_loss, train_accuracy, test_accuracy))

plt.plot(train_acc)
plt.plot(test_acc)
plt.show()
Epoch 0. Loss: 2121.26601683, Train_acc 0.950017, Test_acc 0.9503
Epoch 1. Loss: 1918.8522369, Train_acc 0.963667, Test_acc 0.961
Epoch 2. Loss: 1813.43826684, Train_acc 0.970483, Test_acc 0.966
Epoch 3. Loss: 1740.46931458, Train_acc 0.969417, Test_acc 0.9658
Epoch 4. Loss: 1681.04620544, Train_acc 0.973767, Test_acc 0.9694
Epoch 5. Loss: 1625.9179831, Train_acc 0.975267, Test_acc 0.9709
Epoch 6. Loss: 1568.97912286, Train_acc 0.975317, Test_acc 0.9709
Epoch 7. Loss: 1509.50606071, Train_acc 0.977733, Test_acc 0.973
Epoch 8. Loss: 1449.39600539, Train_acc 0.978467, Test_acc 0.9721
Epoch 9. Loss: 1390.66561781, Train_acc 0.97885, Test_acc 0.9736
[Figure: training and test accuracy per epoch]

For demonstration purposes, we can now take a look at one particular weight by plotting its distribution.

In [17]:
def gaussian(x, mu, sigma):
    scaling = 1.0 / nd.sqrt(2.0 * np.pi * (sigma ** 2))
    bell = nd.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))

    return scaling * bell

def show_weight_dist(mean, variance):
    sigma = nd.sqrt(variance)
    x = np.linspace(mean.asscalar() - 4*sigma.asscalar(), mean.asscalar() + 4*sigma.asscalar(), 100)
    plt.plot(x, gaussian(nd.array(x, ctx=ctx), mean, sigma).asnumpy())
    plt.show()

mu = raw_mus[0][0][0]
var = softplus(raw_rhos[0][0][0]) ** 2

show_weight_dist(mu, var)
[Figure: Gaussian density of the selected weight]

Weight pruning

To measure the degree of redundancy in the trained network and to reduce its parameter count, we now examine the effect of setting some of the weights to \(0\) and evaluating the test accuracy afterwards. We can achieve this by ordering the weights according to their signal-to-noise ratio, \(\frac{|\mu_i|}{\sigma_i}\), and setting a certain percentage of the weights with the lowest ratios to \(0\).

We can calculate the signal-to-noise-ratio as follows:

In [18]:
def signal_to_noise_ratio(mus, sigmas):
    sign_to_noise = []
    for j in range(len(mus)):
        sign_to_noise.extend([nd.abs(mus[j]) / sigmas[j]])
    return sign_to_noise

We further introduce a few helper methods which turn our list of weights into a single vector containing all weights. This will make our subsequent actions easier.

In [19]:
def vectorize_matrices_in_vector(vec):
    for i in range(0, (num_layers + 1) * 2, 2):
        if i == 0:
            vec[i] = nd.reshape(vec[i], num_inputs * num_hidden)
        elif i == num_layers * 2:
            vec[i] = nd.reshape(vec[i], num_hidden * num_outputs)
        else:
            vec[i] = nd.reshape(vec[i], num_hidden * num_hidden)

    return vec

def concact_vectors_in_vector(vec):
    concat_vec = vec[0]
    for i in range(1, len(vec)):
        concat_vec = nd.concat(concat_vec, vec[i], dim=0)

    return concat_vec

def transform_vector_structure(vec):
    vec = vectorize_matrices_in_vector(vec)
    vec = concact_vectors_in_vector(vec)

    return vec

In addition, we have a helper method which transforms the pruned weight vector back into the original layered structure.

In [20]:
from functools import reduce
import operator

def prod(iterable):
    return reduce(operator.mul, iterable, 1)

def restore_weight_structure(vec):
    pruned_weights = []

    index = 0

    for shape in shapes:
        incr = prod(shape)
        pruned_weights.extend([nd.reshape(vec[index : index + incr], shape)])
        index += incr

    return pruned_weights

The actual pruning of the vector happens in the following function. Note that it accepts an ordered list of percentages so that we can evaluate the performance at different pruning rates. For each percentage, we take the indices of the weights with the lowest signal-to-noise ratios and set the weights at these indices to \(0\).

In [21]:
def prune_weights(sign_to_noise_vec, prediction_vector, percentages):
    pruning_indices = nd.argsort(sign_to_noise_vec, axis=0)
    # keep an unpruned copy so that every percentage starts from the full weight vector
    original_vector = prediction_vector.copy()

    for percentage in percentages:
        prediction_vector = original_vector.copy()
        pruning_indices_percent = pruning_indices[0:int(len(pruning_indices)*percentage)]
        for pr_ind in pruning_indices_percent:
            prediction_vector[int(pr_ind.asscalar())] = 0
        pruned_weights = restore_weight_structure(prediction_vector)
        test_accuracy = evaluate_accuracy(test_data, net, pruned_weights)
        print("%s --> %s" % (percentage, test_accuracy))

Putting the above functions together:

In [22]:
# recompute the current standard deviations from the trained rhos
sigmas = transform_rhos(raw_rhos)
sign_to_noise = signal_to_noise_ratio(raw_mus, sigmas)
sign_to_noise_vec = transform_vector_structure(sign_to_noise)

mus_copy = raw_mus.copy()
mus_copy_vec = transform_vector_structure(mus_copy)

prune_weights(sign_to_noise_vec, mus_copy_vec, [0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 1.0])
0.1 --> 0.9737
0.25 --> 0.9737
0.5 --> 0.9748
0.75 --> 0.9754
0.95 --> 0.9697
0.98 --> 0.9549
1.0 --> 0.098

Depending on the number of units used in the original network, the highest pruning percentage achievable without significantly reducing the predictive performance can vary. The paper, for example, reports almost no change in test accuracy when pruning 95% of the weights in a Bayesian neural network with 1200 hidden units, which yields a significantly sparser network and therefore faster predictions and reduced memory requirements.

Conclusion

We have taken a look at an efficient Bayesian treatment for neural networks using variational inference via the “Bayes by Backprop” algorithm (introduced by the “Weight Uncertainty in Neural Networks” paper). We have implemented a stochastic version of the variational lower bound and optimized it in order to find an approximation to the posterior distribution over the weights of an MLP network on the MNIST data set. As a result, we achieve regularization on the network’s parameters and can quantify our uncertainty about the weights accurately. Finally, we saw that it is possible to significantly reduce the number of weights in the neural network after training while still keeping a high accuracy on the test set.

We also note that, given this model implementation, we were able to reproduce the paper’s results on the MNIST data set, achieving a comparable test accuracy for all documented instances of the MNIST classification problem.

For whinges or inquiries, open an issue on GitHub.