{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Bayes by Backprop from scratch (NN, classification)\n", "\n", "We have already learned how to implement deep neural networks and how to use them for classification and regression tasks. To fight overfitting, we further introduced a concept called _dropout_, which randomly switches off a fraction of the hidden units during training.\n", "\n", "Recall the classic architecture of an MLP (shown below, without bias terms). So far, when training a neural network, our goal was to find an optimal point estimate for the weights.\n", "\n", "![](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/img/bbb_nn_classic.png?raw=true)\n", "\n", "While networks trained using this approach usually perform well in regions with lots of data, they fail to express uncertainty in regions with little or no data, leading to overconfident decisions. This drawback motivates the application of Bayesian learning to neural networks, introducing probability distributions over the weights. In theory, these distributions can take many forms. To make our lives easier and to have an intuitive understanding of the distribution at each weight, we will use a Gaussian distribution.\n", "\n", "![](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/img/bbb_nn_bayes.png?raw=true)\n", "\n", "Unfortunately though, exact Bayesian inference on the parameters of a neural network is intractable. One promising way of addressing this problem is presented by the \"Bayes by Backprop\" algorithm (introduced by the \"[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424)\" paper) which derives a variational approximation to the true posterior. 
This algorithm not only makes networks more \"honest\" about their overall uncertainty, but also acts as a regularizer, removing the need for dropout in this model.\n", "\n", "While we will try to explain the most important concepts of this algorithm in this notebook, we also encourage the reader to consult the paper for deeper insights." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start implementing this idea and evaluate its performance on the MNIST classification problem. We start off with the usual set of imports." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from __future__ import print_function\n", "import collections\n", "import mxnet as mx\n", "import numpy as np\n", "from mxnet import nd, autograd\n", "from matplotlib import pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For easy tuning and experimentation, we define a dictionary holding the hyper-parameters of our model." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "config = {\n", "    \"num_hidden_layers\": 2,\n", "    \"num_hidden_units\": 400,\n", "    \"batch_size\": 128,\n", "    \"epochs\": 10,\n", "    \"learning_rate\": 0.001,\n", "    \"num_samples\": 1,\n", "    \"pi\": 0.25,\n", "    \"sigma_p\": 1.0,\n", "    \"sigma_p1\": 0.75,\n", "    \"sigma_p2\": 0.1,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Also, we specify the device context for MXNet." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "ctx = mx.cpu()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load MNIST data\n", "\n", "We will again train and evaluate the algorithm on the MNIST data set and therefore load the data set as follows:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "def transform(data, label):\n", "    return data.astype(np.float32)/126.0, label.astype(np.float32)\n", "\n", "mnist = mx.test_utils.get_mnist()\n", "num_inputs = 784\n", "num_outputs = 10\n", "batch_size = config['batch_size']\n", "\n", "train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),\n", "                                      batch_size, shuffle=True)\n", "test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),\n", "                                     batch_size, shuffle=False)\n", "\n", "# every batch is counted as full, so num_batches equals the number of mini-batches per epoch\n", "num_train = sum([batch_size for i in train_data])\n", "num_batches = num_train / batch_size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To reproduce the results from the paper and compare against them, we preprocess the pixels by dividing by 126." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model definition\n", "\n", "### Activation function\n", "\n", "As in many past examples, we will again use the ReLU as our activation function for the hidden units of our neural network." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def relu(X):\n", "    return nd.maximum(X, nd.zeros_like(X))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Neural net modeling\n", "\n", "As our model we are using a straightforward MLP and we are wiring up our network just as we are used to." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "num_layers = config['num_hidden_layers']\n", "\n", "# define function for evaluating MLP\n", "def net(X, layer_params):\n", "    layer_input = X\n", "    # hidden layers with ReLU: every (W, b) pair except the last one\n", "    for i in range(len(layer_params) // 2 - 1):\n", "        h_linear = nd.dot(layer_input, layer_params[2*i]) + layer_params[2*i + 1]\n", "        layer_input = relu(h_linear)\n", "    # last layer without ReLU\n", "    output = nd.dot(layer_input, layer_params[-2]) + layer_params[-1]\n", "    return output\n", "\n", "# define network weight shapes\n", "layer_param_shapes = []\n", "num_hidden = config['num_hidden_units']\n", "for i in range(num_layers + 1):\n", "    if i == 0: # input layer\n", "        W_shape = (num_inputs, num_hidden)\n", "        b_shape = (num_hidden,)\n", "    elif i == num_layers: # last layer\n", "        W_shape = (num_hidden, num_outputs)\n", "        b_shape = (num_outputs,)\n", "    else: # hidden layers\n", "        W_shape = (num_hidden, num_hidden)\n", "        b_shape = (num_hidden,)\n", "    layer_param_shapes.extend([W_shape, b_shape])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build objective/loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we briefly mentioned at the beginning of the notebook, we will use variational inference to make the approximation of the posterior tractable. While we cannot compute the posterior $P(\\mathbf{w}\\ |\\ \\mathcal{D})$ directly, we try to find the parameters $\\mathbf{\\theta}$ of a distribution on the weights $q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})$ (commonly referred to as the _variational posterior_) that minimizes the KL divergence with the true posterior. 
Formally this can be expressed as:\n", "\n", "\\begin{equation*}\n", "\\begin{split}\n", "\\theta^{*} & = \\arg\\min_{\\theta} \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w}\\ |\\ \\mathcal{D})]\\\\\n", "& = \\arg\\min_{\\theta} \\int q(\\mathbf{w}\\ |\\ \\mathbf{\\theta}) \\log \\frac{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}{P(\\mathbf{w}) P(\\mathcal{D}\\ |\\ \\mathbf{w})} d\\mathbf{w} \\\\\n", "& = \\arg\\min_{\\theta} \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w})] - \\mathbb{E}_{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}[\\log P(\\mathcal{D}\\ |\\ \\mathbf{w})]\n", "\\end{split}\n", "\\end{equation*}\n", "\n", "The second line applies Bayes' rule and drops the evidence term $\\log P(\\mathcal{D})$, which does not depend on $\\mathbf{\\theta}$.\n", "\n", "The resulting loss function, commonly referred to as the _variational free energy_ or the negative _evidence lower bound_ (_ELBO_), has to be minimized and is given as follows:\n", "\n", "\\begin{equation*}\n", "\\mathcal{F}(\\mathcal{D}, \\mathbf{\\theta}) = \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w})] - \\mathbb{E}_{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}[\\log P(\\mathcal{D}\\ |\\ \\mathbf{w})]\n", "\\end{equation*}\n", "\n", "The cost function balances two terms: the expected log-likelihood $\\log P(\\mathcal{D}\\ |\\ \\mathbf{w})$, which rewards fitting the data, and the KL term, which keeps the variational posterior close to the simple prior $P(\\mathbf{w})$.\n", "\n", "We can approximate this exact cost through a Monte Carlo sampling procedure as follows:\n", "\n", "\\begin{equation*}\n", "\\mathcal{F}(\\mathcal{D}, \\mathbf{\\theta}) \\approx \\sum_{i = 1}^{n} \\log q(\\mathbf{w}^{(i)}\\ |\\ \\mathbf{\\theta}) - \\log P(\\mathbf{w}^{(i)}) - \\log P(\\mathcal{D}\\ |\\ \\mathbf{w}^{(i)})\n", "\\end{equation*}\n", "\n", "where $\\mathbf{w}^{(i)}$ corresponds to the $i$-th Monte Carlo sample from the variational posterior. 
While writing this notebook, we noticed that even a single sample leads to good results, so we will stick to sampling just once throughout the notebook.\n", "\n", "Since we will be working with mini-batches, the exact form of the loss on mini-batch $i$ looks as follows:\n", "\n", "\\begin{equation*}\n", "\\begin{split}\n", "\\mathcal{F}(\\mathcal{D}_i, \\mathbf{\\theta}) & = \\frac{1}{M} \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w})] - \\mathbb{E}_{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}[\\log P(\\mathcal{D}_i\\ |\\ \\mathbf{w})]\\\\\n", "& \\approx \\frac{1}{M} (\\log q(\\mathbf{w}^{(1)}\\ |\\ \\mathbf{\\theta}) - \\log P(\\mathbf{w}^{(1)})) - \\log P(\\mathcal{D}_i\\ |\\ \\mathbf{w}^{(1)})\n", "\\end{split}\n", "\\end{equation*}\n", "\n", "where $M$ corresponds to the number of batches,\n", "and $\\mathcal{F}(\\mathcal{D}, \\mathbf{\\theta}) = \\sum_{i = 1}^{M} \\mathcal{F}(\\mathcal{D}_i, \\mathbf{\\theta})$.\n", "\n", "Let's now look at each of these terms individually." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Likelihood\n", "\n", "As in many past examples, we will again use the softmax to define our likelihood $P(\\mathcal{D}_i\\ |\\ \\mathbf{w})$. Revisit the [MLP from scratch notebook](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter03_deep-neural-networks/mlp-scratch.ipynb) for a detailed motivation of this function." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def log_softmax_likelihood(yhat_linear, y):\n", "    return nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prior\n", "\n", "Since we are introducing a Bayesian treatment for the network, we need to define a prior over the weights.\n", "\n", "#### Gaussian prior\n", "\n", "A popular and simple prior is the Gaussian distribution. 
The prior over the entire weight vector simply corresponds to the product of the individual Gaussians:\n", "\n", "\\begin{equation*}\n", "P(\\mathbf{w}) = \\prod_i \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_p^2)\n", "\\end{equation*}\n", "\n", "We can define the Gaussian distribution and our Gaussian prior as seen below. Note that we are ultimately interested in the log-prior $\\log P(\\mathbf{w})$ and therefore compute the sum of the log-Gaussians.\n", "\n", "\\begin{equation*}\n", "\\log P(\\mathbf{w}) = \\sum_i \\log \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_p^2)\n", "\\end{equation*}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "LOG2PI = np.log(2.0 * np.pi)\n", "\n", "def log_gaussian(x, mu, sigma):\n", "    return -0.5 * LOG2PI - nd.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)\n", "\n", "def gaussian_prior(x):\n", "    sigma_p = nd.array([config['sigma_p']], ctx=ctx)\n", "    \n", "    return nd.sum(log_gaussian(x, 0., sigma_p))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Scale mixture prior\n", "\n", "Instead of a single Gaussian, the paper also suggests using a scale mixture prior for $P(\\mathbf{w})$ as an alternative:\n", "\n", "\\begin{equation*}\n", "P(\\mathbf{w}) = \\prod_i \\bigg ( \\pi \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_1^2) + (1 - \\pi) \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_2^2)\\bigg )\n", "\\end{equation*}\n", "\n", "where $\\pi \\in [0, 1]$, $\\sigma_1 > \\sigma_2$ and $\\sigma_2 \\ll 1$. 
Again we are interested in the log-prior $\\log P(\\mathbf{w})$, which can be expressed as follows:\n", "\n", "\\begin{equation*}\n", "\\log P(\\mathbf{w}) = \\sum_i \\log \\bigg ( \\pi \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_1^2) + (1 - \\pi) \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_2^2)\\bigg )\n", "\\end{equation*}" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def gaussian(x, mu, sigma):\n", "    scaling = 1.0 / nd.sqrt(2.0 * np.pi * (sigma ** 2))\n", "    bell = nd.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))\n", "    \n", "    return scaling * bell\n", "\n", "def scale_mixture_prior(x):\n", "    sigma_p1 = nd.array([config['sigma_p1']], ctx=ctx)\n", "    sigma_p2 = nd.array([config['sigma_p2']], ctx=ctx)\n", "    pi = config['pi']\n", "    \n", "    first_gaussian = pi * gaussian(x, 0., sigma_p1)\n", "    second_gaussian = (1 - pi) * gaussian(x, 0., sigma_p2)\n", "    \n", "    return nd.log(first_gaussian + second_gaussian)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Variational Posterior\n", "\n", "The last missing piece in the equation is the variational posterior. Again, we choose a Gaussian distribution for this purpose. 
The variational posterior on the weights is centered on the mean vector $\\mathbf{\\mu}$ and has variance $\\mathbf{\\sigma}^2$:\n", "\n", "\\begin{equation*}\n", "q(\\mathbf{w}\\ |\\ \\theta) = \\prod_i \\mathcal{N}(\\mathbf{w}_i\\ |\\ \\mathbf{\\mu}_i,\\mathbf{\\sigma}_i^2)\n", "\\end{equation*}\n", "\n", "The log-posterior $\\log q(\\mathbf{w}\\ |\\ \\theta)$ is given by:\n", "\n", "\\begin{equation*}\n", "\\log q(\\mathbf{w}\\ |\\ \\theta) = \\sum_i \\log \\mathcal{N}(\\mathbf{w}_i\\ |\\ \\mathbf{\\mu}_i,\\mathbf{\\sigma}_i^2)\n", "\\end{equation*}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Combined Loss\n", "\n", "After introducing the data likelihood, the prior, and the variational posterior, we are now able to build our combined loss function: $\\mathcal{F}(\\mathcal{D}_i, \\mathbf{\\theta}) = \\frac{1}{M} (\\log q(\\mathbf{w}\\ |\\ \\mathbf{\\theta}) - \\log P(\\mathbf{w})) - \\log P(\\mathcal{D}_i\\ |\\ \\mathbf{w})$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def combined_loss(output, label_one_hot, params, mus, sigmas, log_prior, log_likelihood):\n", "    \n", "    # Calculate data likelihood\n", "    log_likelihood_sum = nd.sum(log_likelihood(output, label_one_hot))\n", "    \n", "    # Calculate prior\n", "    log_prior_sum = sum([nd.sum(log_prior(param)) for param in params])\n", "\n", "    # Calculate variational posterior\n", "    log_var_posterior_sum = sum([nd.sum(log_gaussian(params[i], mus[i], sigmas[i])) for i in range(len(params))])\n", "    \n", "    # Calculate total loss\n", "    return 1.0 / num_batches * (log_var_posterior_sum - log_prior_sum) - log_likelihood_sum" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use vanilla stochastic gradient descent to optimize the variational parameters. 
Note that this implements the updates described in the paper, as the gradient contribution due to the reparametrization trick is automatically included by taking the gradients of the combined loss function with respect to the variational parameters." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def SGD(params, lr):\n", "    for param in params:\n", "        param[:] = param - lr * param.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation metric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To assess our model's performance, we define a helper function that evaluates the accuracy on a given data set." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def evaluate_accuracy(data_iterator, net, layer_params):\n", "    numerator = 0.\n", "    denominator = 0.\n", "    for i, (data, label) in enumerate(data_iterator):\n", "        data = data.as_in_context(ctx).reshape((-1, 784))\n", "        label = label.as_in_context(ctx)\n", "        output = net(data, layer_params)\n", "        predictions = nd.argmax(output, axis=1)\n", "        numerator += nd.sum(predictions == label)\n", "        denominator += data.shape[0]\n", "    return (numerator / denominator).asscalar()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parameter initialization\n", "\n", "We are using a Gaussian distribution for each individual weight as our variational posterior, which means that we need to store two parameters, a mean and a standard deviation, for each weight. The standard deviation must be positive, which we ensure by using the softplus function to express $\\mathbf{\\sigma}$ in terms of an unconstrained parameter $\\mathbf{\\rho}$. 
While gradient descent will be performed on $\\theta = (\\mathbf{\\mu}, \\mathbf{\\rho})$, the distribution for each individual weight is represented as $w_i \\sim \\mathcal{N}(w_i\\ |\\ \\mu_i,\\sigma_i^2)$ with $\\sigma_i = \\text{softplus}(\\mathbf{\\rho}_i)$.\n", "\n", "We initialize $\\mathbf{\\mu}$ with a Gaussian around $0$ (just as we would initialize standard weights of a neural network). It is important to initialize $\\mathbf{\\rho}$ (and hence $\\sigma$) to a small value, otherwise learning might not work properly." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "weight_scale = .1\n", "rho_offset = -3\n", "\n", "# initialize variational parameters; mean and rho (pre-softplus standard deviation) for each weight\n", "mus = []\n", "rhos = []\n", "\n", "for shape in layer_param_shapes:\n", "    mu = nd.random_normal(shape=shape, ctx=ctx, scale=weight_scale)\n", "    rho = rho_offset + nd.zeros(shape=shape, ctx=ctx)\n", "    mus.append(mu)\n", "    rhos.append(rho)\n", "\n", "variational_params = mus + rhos" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since these are the parameters we wish to do gradient descent on, we need to allocate space for storing the gradients." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "for param in variational_params:\n", "    param.attach_grad()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main training loop\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main training loop is pretty similar to the one we used in the MLP example. The only adaptation we need to make is to add the weight sampling which is performed during each optimization step. 
Generating a set of weights, which will subsequently be used in the neural network and the loss function, is a 3-step process:\n", "\n", "1) Sample $\\mathbf{\\epsilon} \\sim \\mathcal{N}(\\mathbf{0},\\mathbf{I}^{d})$" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sample_epsilons(param_shapes):\n", "    epsilons = [nd.random_normal(shape=shape, loc=0., scale=1.0, ctx=ctx) for shape in param_shapes]\n", "    return epsilons" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2) Transform $\\mathbf{\\rho}$ to a positive vector via the softplus function: $\\mathbf{\\sigma} = \\text{softplus}(\\mathbf{\\rho}) = \\log(1 + \\exp(\\mathbf{\\rho}))$" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def softplus(x):\n", "    return nd.log(1. + nd.exp(x))\n", "\n", "def transform_rhos(rhos):\n", "    return [softplus(rho) for rho in rhos]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3) Compute $\\mathbf{w}$: $\\mathbf{w} = \\mathbf{\\mu} + \\mathbf{\\sigma} \\circ \\mathbf{\\epsilon}$, where the $\\circ$ operator represents element-wise multiplication. This is the \"reparametrization trick\" for separating the randomness from the parameters of $q$." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def transform_gaussian_samples(mus, sigmas, epsilons):\n", "    samples = []\n", "    for j in range(len(mus)):\n", "        samples.append(mus[j] + sigmas[j] * epsilons[j])\n", "    return samples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Complete loop\n", "\n", "The complete training loop is given below." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0. Loss: 2626.47417991, Train_acc 0.945617, Test_acc 0.9455\n", "Epoch 1. 
Loss: 2606.28165139, Train_acc 0.962783, Test_acc 0.9593\n", "Epoch 2. Loss: 2600.2452303, Train_acc 0.969783, Test_acc 0.9641\n", "Epoch 3. Loss: 2595.75639899, Train_acc 0.9753, Test_acc 0.9684\n", "Epoch 4. Loss: 2592.98582057, Train_acc 0.978633, Test_acc 0.9723\n", "Epoch 5. Loss: 2590.05895182, Train_acc 0.980483, Test_acc 0.9733\n", "Epoch 6. Loss: 2588.57918775, Train_acc 0.9823, Test_acc 0.9756\n", "Epoch 7. Loss: 2586.00932367, Train_acc 0.984, Test_acc 0.9749\n", "Epoch 8. Loss: 2585.4614887, Train_acc 0.985883, Test_acc 0.9765\n", "Epoch 9. Loss: 2582.92995846, Train_acc 0.9878, Test_acc 0.9775\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD9CAYAAABQvqc9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VOW9x/HPLwkJEJZs7GQjbEZlDTsKola9dddaRVF7\nXdpbra3ebrb32tYutrfW2lvtgkoLrtfaRVu1rigKYQmbspOEbOxJSAgJIcs8948zQEhRAkxyksz3\n/Xrxysw5ZzK/GZLvPHnOc57HnHOIiEh4iPC7ABERaTsKfRGRMKLQFxEJIwp9EZEwotAXEQkjCn0R\nkTDSotA3s4vNbLOZ5ZrZt4+zP9XM3jGzj8zsPTMb3GTfz8xsXfDf50NZvIiInJwThr6ZRQKPA5cA\nmcANZpbZ7LCHgQXOuVHAg8BDwcd+FhgHjAEmAV83s16hK19ERE5GS1r6E4Fc51y+c64OeAG4otkx\nmcC7wdsLm+zPBBY55xqcc9XAR8DFp1+2iIicipaE/iCguMn9kuC2ptYCVwdvXwX0NLPE4PaLzay7\nmSUB5wHJp1eyiIicqqgQfZ+vA4+Z2a3AImA70Oice9PMJgBLgL1ANtDY/MFmdidwJ0BsbOz4kSNH\nhqgsEZHwsHLlylLnXJ8THdeS0N/Osa3zwcFtRzjndhBs6ZtZD+Aa51xFcN+PgR8H9z0HbGn+BM65\nucBcgKysLJeTk9OCskRE5DAzK2zJcS3p3lkBDDOzdDOLBq4HXmn2ZElmdvh73Q/MC26PDHbzYGaj\ngFHAmy17CSIiEmonbOk75xrM7G7gDSASmOecW29mDwI5zrlXgJnAQ2bm8Lp37go+vAvwgZkB7Adu\ncs41hP5liIhIS1h7m1pZ3TsiIifPzFY657JOdJyuyBURCSMKfRGRMKLQFxEJIwp9EZEwEqqLs0RE\n5BQVl9ewOLeUgIPZk1Ja9bkU+iIibay8uo7svDI+zC1lcW4pReU1AIxNiVPoi4h0dAfrGllRUM7i\n3FI+zC1lw879OAc9Y6KYnJHIbdPTmTY0kYw+PVq9FoW+iEiINTQG+Hh7JYtzS1mcW8bKwn3UNQbo\nEmmMS4nnvguGM21YEqMG9SYqsm1PrSr0RUROk3OOvL3VLMkr5cOtpWTnl1FV600+kDmgF7dOS2Pa\n0CQmpMXTPdrf2FXoi4icgj37a1mcV8qHW8tYnFvKrv21ACQndOPSUQOYNjSJKUMSSewR43Olx1Lo\ni4i0QFVtPcvyy4+cfN265wAA8d27MHVoEtOHJjEtI4mUxO4+V/rpFPoiIsdR1xBgddE+r18+r4w1\nxRU0Bhxdu0QwMT2Ra8cPZtrQJD
IH9CIiwvwut8UU+iIiQCDg2LSryuuXzy1lWX45B+sbiTAYnRzH\nf8zIYNrQJMalxhETFel3uadMoS8iYevwRVGL88pYkltKWXUdABl9Yrkuy2vJTxqSSO9uXXyuNHQU\n+iISNvZU1ZKdV8aS3DIW55VSsu8gAH17xjBjeB+mDU1i2tAk+vfu6nOlrUehLyKdVuXBepbll7Ek\nr4wleaVs2e2dfO3VNYopGYnccc6QIxdFBRd76vQU+iLSaRysaySnsNwL+dxSPt5eScBB1y4RTEhL\n4Opxg5mWkUTmwF5EdqCTr6Gk0BeRDqu+McBHJRUszvVa8qsKK6hrDBAVYYxNieMrs4YxNSORMSkd\n++RrKCn0RaTDCAQcG3ftJzvPuyBq+bZyqusaMTt65evUjEQmpCUQG6N4Ox69KyLSbjnnKCjzRtgs\nySslO6+MfTX1AAzpE8tV4wYxLSOJyUMSiY+N9rnajkGhLyLtyq7K2mDIe102Oyu96Q0G9O7KrJH9\nmDY0kSkZiQzo3c3nSjsmhb6I+KqixptbfnGeF/T5e6sBb3qDKRmJ3JXhDaNMS+weNiNsWpNCX0Ta\n1J6qWlYW7GN5QTnLt5UfmVs+NjqSiekJ3DAhhalDEzmjf8ea3qCjUOiLSKtxzpFfWk1OQTkrCvaR\nU1BOQZm3SlTXLhGMSY7j3guGM21oIqMGx9GljeeWD0cKfREJmfrGAOt37Ccn2IpfWbjvyNQGCbHR\nZKXGc+OkVLLS4jlzYG+ioxTybU2hLyKn7MChBlYV7jvSkl9dvI/a+gAAqYndmTmiLxPS4slKSyCj\nT6z65NsBhb6ItNie/bWsKNjHioJycgrL2bBjPwEHEQaZA3tx/YQUJqYnkJUaT99enXf+mo5MoS8i\nx3V4CcCcgnKWF5STU7CPonKvP75bl0jGpsRx96xhTEiLZ2xKPD10MVSHoP8lEQG8RUPW7ag85qTr\n4QuhEmOjyUqL5+YpqWSlJXDmwF466dpBKfRFwlRVbT2riiqCIV/OmuKKI/3xaYndOf+MfkxMSyAr\nLZ70JPXHdxYKfZEwUVlTz7JtZSzNL2dpfhmbdnn98ZERxpkDezF7YioT0uIZnxZP357qj++sFPoi\nnVRlTT3LC7yAX5pfduQiqJioCMalxPOVWcOYkJbA2JQ4TU4WRvQ/LdJJVB6sZ8W2YMhvK2P9Di/k\no6MiGJ8Sz9fOH86UjERGJ/fWNMNhTKEv0kHtr20S8vnlrN/hLRgSHRXBuJQ4vnr+MKYMSWR0chxd\nuyjkxaPQF+kg9tfWk1NQfqRPfl1wVajoyIgjC4ZMyUhkjEJePoVCX6SdqqqtJ6dg35E++Y+bhPyY\n4Bj5yUMSGJcSr5CXFlPoi7QTBw41sKLgaHfNuu2VNAYcXSKNscnx3H3eUCYPSWRcqkJeTp1CX8Qn\nBw41HNNd83GTkB+THMeXZ2Z4IZ8ST7dohbyERotC38wuBn4FRAJPOud+2mx/KjAP6AOUAzc550qC\n+/4H+CwQAbwFfNU550L2CkQ6iJq6BlY06a75qMQL+agIL+T/Y4YX8uNTFfLSek4Y+mYWCTwOXAiU\nACvM7BXn3IYmhz0MLHDOzTezWcBDwBwzmwpMA0YFj/sQmAG8F7qXINI+1dY3sqpoH9l5ZWTnlbGm\nuIKGYMiPTo7jSzOGHAn57tH6o1vaRkt+0iYCuc65fAAzewG4Amga+pnAfcHbC4G/BW87oCsQDRjQ\nBdh9+mWLtD91DQE+KqlgSTDkVxbto64hQITB2YPjuP2cIUzJSCQrNV4XQ4lvWvKTNwgobnK/BJjU\n7Ji1wNV4XUBXAT3NLNE5l21mC4GdeKH/mHNu4+mXLeK/huCCIUvyysjOL2PFtnIO1jdiBmf078Wc\nyalMzUhkQnoCvbp28btcESB0J3K/DjxmZrcCi4DtQKOZDQXOAAYHj3vLzM5xzn3Q9MFmdidwJ0BK
\nSkqIShIJrUDAsXHX/iPdNcu3lVN1qAGAYX17cF3WYKZkJDIpPZH42GifqxU5vpaE/nYgucn9wcFt\nRzjnduC19DGzHsA1zrkKM7sDWOqcOxDc9zowBfig2ePnAnMBsrKydJJX2gXnHLl7DpCdX8aS3DKW\nbSs7MtVwelIsl44eyJSMRCYPSdAEZdJhtCT0VwDDzCwdL+yvB2Y3PcDMkoBy51wAuB9vJA9AEXCH\nmT2E170zA3g0RLWLhJRzjsKyGi/kg6350gOHABgU143zz+jH1IxEpmQkMqB3N5+rFTk1Jwx951yD\nmd0NvIE3ZHOec269mT0I5DjnXgFmAg+ZmcPr3rkr+PCXgFnAx3gndf/pnPt76F+GyKnZXnGQJbml\nZOeXsTSvjB2VtQD07RnD9KFewE8ZkkRyQjfNJy+dgrW3IfNZWVkuJyfH7zKkk9qzv5bsfK8VvySv\n7Mjyfwmx0UwZksjkjESmZiQyRIuGSAdjZiudc1knOk7jxqTT21l5kD8uKeDtDbvJ21sNQK+uUUwa\nksgXpqUxJSOR4X17EhGhkJfOT6EvndbmXVXMXZTPy2u244DpQ5P4/IRkpgxJInNgLyIV8hKGFPrS\nqTjnWL6tnN8vyufdTXvo1iWSmyanctv0dJITuvtdnojvFPrSKTQGHG9t2MXv3s9nTXEFCbHR3Hfh\ncOZMTtWYeZEmFPrSodXWN/KXVdt54oN8tpVWk5rYnR9eeRafGz9Y0w+LHIdCXzqkipo6nllayB+X\nFFB6oI5Rg3vzmxvHcdGZ/dVXLx1LXQ3sK4B928AiYMQlrfp0Cn3pULZXHOSpD7bxwooiauoamTmi\nD188N4PJQxI0xFLar4P7oHwblOd74V4e/LdvG1TtPHpc/1EKfRGAjTv3M3dRPq+s3YEBl48eyB3n\nDuGMAb38Lk0EnIMDu71QPxzmh2+X50NtxbHH9+gPCUMgYxYkpEN8+tGvrUyhL+2Wc47s/DJ+/34+\n72/ZS/foSG6dmsa/T09nUJymQZA21tgAlcVNWur5XrfM4ZCvrzl6rEVCXLIX4mdd0yTYh0B8GkT7\nN5JMoS/tTmPA8c91u/j9ojw+KqkkqUc037hoBDdNSqV3d01RLK2ovvZo/3rzVntFEQQajh4b1dUL\n8Ph0GDLTC/bD4R6XApHt82dVoS/txsG6Rl5aWcwTH2yjqLyG9KRYfnLV2Vw9bpBG4khoOee12ktW\nQPEK2L3OC/b9O/CmCQuK6Q0JaTBgNGReGQz2IV6w9xwAERF+vYJTptAX3+2rrmNBdiHzswsor65j\nTHIc3/m3kVyYqZE4EiL1B2HHGihZDsXLoSQHDuzy9kV1g/5nQ/q5R7tgDod7t3joZAMEFPrim+Ly\nGp76cBv/t6KYg/WNnD+yL1+ckcGEtHiNxJFT55zXRVOS44V8yQrY9fHRrpn4dBgyAwZP8P71O7Pd\ndsW0BoW+tLl12yuZuyifVz/eSYTBFWMGcee5Qxjer6ffpUlHVFcNO1YfbcGXLIfqvd6+LrEwaBxM\nvQeSJ8KgLOjRx996fabQlzbhnGNxbhm/X5THB1tL6RETxW3T0/nCtDQtSNKWnIOda2HbIujSDXr2\n94YP9uwHPfpBVIzfFX4657y+95IVwf745bB7PbhGb3/iUBh6wdFWfN9MiFTMNaV3Q1pVQ2OA19bt\n4vfv57F+x3769IzhWxePZPakFHp3C58/qX0VaISipbDpH7DxH1BZ9MnHdo0LfhD0C37tG/xQCG7r\n0c/7gIjp1TZ93YcOwPaVR0O+ZAXUlHn7ont6rfhz7jsa8t0TWr+mDk6hL62mZF8Ndz27irUllQzp\nE8vPrjmbK8cOIiZKI3FaXcMhrzW/8RXY9BrUlEJktHcx0IxvwvCLwAWgahcc2OOd1Kza7X09sNu7\nXZTtfW089K/fP6pb8K+D4AfDMR8U/Y7e7p7U8hEuzkFZ7tEW
fMkK2LPBqxMgaQQMvwSSgwHfZyRE\n6GfpZCn0pVW8u2k39/7fWgIBx6+uH8NlowZqkZLWdugA5L4FG/8OW96EuiqvNTz8MzDyUhh2IcQ0\nO2/Ss/+nf0/nvKtJD+wJfkDsPvr18O29myD/fThU+a+Pt0iI7XP0A6L5B0VU16P98dtzvOkKwBsq\nOXi8V/fgCd7tbvGheZ/CnEJfQqqhMcAv3trCb9/L48yBvfjNjeNITYz1u6zOq7oMtrzuBX3eQq9V\n3j0JzroKRl7mjVI5nX56My9su8VDnxGffmz9waN/JRzY1eSDIvhXRNVO2LnGO8l6uPXuPYnXaj/j\nsmDAT4Sk4R1yDHxHoNCXkNmzv5avPL+aZdvKuWFiCt+7LFMXVbWGyhLY9KoX9IWLvQDtnQxZ/+4F\nZ8pkf7o9unQLXqGa9unHNTZ43U0Hdnsjb/qdCV17t0WFgkJfQmRJXin3PL+G6kMN/PLzo7lq7GC/\nS+pc9m6BTX/3TsTuWOVt6zMSpt/nBf2A0R3nIqLIKK9r50RdS9IqFPpyWgIBx2/ey+WRt7aQnhTL\nc3dM0nj7UHDO6wrZGAz60s3e9kHj4fzveUGfNMzfGqVDUujLKdtXXce9L67hvc17uXz0QB66+mxi\nY/QjdcoCjd6ImY1/97pvKou9E6Fp02DC7TDys9B7kN9VSgen31A5JauL9nH3c6vZW3WIH155FjdN\nStHUCaeivha2ve8Nrdz8ujcGPTIGhp4PM+/3FtTQ2HMJIYW+nBTnHH9cUsBPXttIv15deek/pjBq\ncJzfZXUsh6pg65tet83WN6HugHex0/CLvCGKQy+AmB5+VymdlEJfWqyqtp5v//ljXv14Jxec0Y9f\nfG605rf/NPUHjy6ycXiZvNKtXhdOY503fv2sa+CMyyH9nPY/BYJ0Cgp9aZGNO/fz5WdXUVRew/2X\njOTOc4eoOwfgYEWzpfG2HV1ZqWrHscd27e3N8DjhDu9EbPJEXVEqbU6hLyf04opi/vvldfTu1oXn\n75jMxPQw6mN2zrvI6MiC1s1WUzp8BelhPfp587APmXl0XvbD65+qb17aAYW+fKKDdY088PI6/rSy\nhKkZifzq+rH06dkJuyAaG2B/ybHdMEda7QVQX330WIvwLoRKSA+upNRkwY34NIjW1cfSvin05bjy\n9x7gy8+uYvPuKu6ZNZSvXjC8Y69i1VgfbKU3a6mXbwuufVp/9NjIGC/AE4KLbRxuqScM8QI/Ktq3\nlyFyuhT68i9e/Wgn3/rzR3SJNP5w6wRmjujrd0mnrrEeVj8N7/3Uu+z/sJheXrD3PxsyLz92mbye\nAzXvi3RaCn05oq4hwE9e28gflxQwLiWOx2aPY2BcB13gxDlv7Ps7D3rT9SZPhgsfhISMYP96YseZ\ntkAkhBT6AgTnvn9uNWuLK7htejrfungk0VEdtLVb8CG89YC3+EafkXD9895FTgp5EYW+wMJNe7j3\nxTU0Njp+e+M4Ljl7gN8lnZpd6+CdH3gXPPUcCJc/BqNv0HJ5Ik3otyGMNTQG+OXbW3h8YR5nDOjF\nb28cR1pSBxx9UlEEC38Ca1+Arr28bpyJd3pT/YrIMRT6YWpPVS33PL+apfnlXD8hme9ffmbHm/u+\nphw++AUsnwsYTLsHpt+rFZZEPoVCPwwtzS/jK8+vpqq2noc/N5prx3ewue/ramDZb+HDR715a8bM\n9iYn693BXoeIDxT6YSQQcPz2/Tx+8eZm0pJiefq2iYzs38vvslqusQHWPAMLH/KW4Bvxb3D+A9D3\nDL8rE+kwWhT6ZnYx8CsgEnjSOffTZvtTgXlAH6AcuMk5V2Jm5wG/bHLoSOB659zfQlG8tFxFTR33\nvbiWdzft4dJRA/jpNaPo0VHmvncONv0D3v4BlG2F5EnwuT9C6hS/KxPpcE74W29mkcDjwIVACbDC\nzF5xzm1octjDwALn3Hwz
mwU8BMxxzi0ExgS/TwKQC7wZ4tcgJ7CmuIK7nl3FnqpaHrziTOZMTu04\nk6UVLIa3vwclKyBpBFz/nNfC7yj1i7QzLWnqTQRynXP5AGb2AnAF0DT0M4H7grcXAsdryV8LvO6c\nqzn1cuVkOOdYkF3Ij17dQN+eXfnTl6YyJrmDzH2/e4M3/HLLP4PDL38No2dr+KXIaWrJb9AgoLjJ\n/RJgUrNj1gJX43UBXQX0NLNE51xZk2OuBx453hOY2Z3AnQApKSktq1w+VfWhBr7154/4x0c7mTWy\nL49cN5q47h1gzpiK4uDwy+e9qRIu+D5M/CJEd/e7MpFOIVTNpq8Dj5nZrcAiYDvQeHinmQ0Azgbe\nON6DnXNzgbkAWVlZLkQ1ha3Kg/XcMm85H5VU8M2LR/ClczOIaO+TpR0ZfvmEd3/q3TD9Pk1HLBJi\nLQn97UByk/uDg9uOcM7twGvpY2Y9gGuccxVNDrkO+Ktzrh5pVeXVdcx5ahlbdlfx25vGc9GZ/f0u\n6dPV1cCy33nDLw/tPzr8Mi75xI8VkZPWktBfAQwzs3S8sL8emN30ADNLAsqdcwHgfryRPE3dENwu\nrWhv1SFuenIZBWXVzL05i/Pa8+yYjQ2w5ll47yGo2gnDL/GGX/bL9LsykU7thKHvnGsws7vxumYi\ngXnOufVm9iCQ45x7BZgJPGRmDq97567DjzezNLy/FN4PefVyxK7KWmY/uZSdFbXMu3UC04Ym+V3S\n8TkHm171TtKWboHBE+HaeZA61e/KRMKCOde+utCzsrJcTk6O32V0KCX7apj9xDLKq+uYd+uE9ruc\nYWG2N/tlyXJIGg7nfw9GflbDL0VCwMxWOueyTnScxr91cIVl1cx+YhlVtfU8fdtExqa0w3ln9mz0\nLqza8jr0HACX/S+MuVHDL0V8oN+6Dix3zwFufHIpdQ0BnrtjMmcN6u13SceqLvO6cVY/DdE9vZb9\npC9p+KWIjxT6HdSmXfu56cllALxw5xRG9O/pc0VNBBohZx68+yNvQrRJX4Jzv6HhlyLtgEK/A1q3\nvZKbnlpGTFQEz94+maF9e/hd0lFFS+G1r8OujyH9XLjk59B3pN9ViUiQQr+DWV20j5vnLadX1y48\nd8ckUhPbyaInVbu9OXLWPg+9BsPn5kPmFTpJK9LOKPQ7kOXbyvnCH5aT1DOGZ2+fxOD4dtA33ljv\nLWKy8CFoPATn/Kf3L7qdfBiJyDEU+h3E4txSbp+fw8C4rjx7+2T69+7qd0mQ/z68/k3YuwmGXgiX\n/AwSM/yuSkQ+hUK/A1i4eQ9ffHol6YmxPHP7JPr0jPG3oMoSePO/YP1fIS4VbngBhl+srhyRDkCh\n3869sX4Xdz+3iuH9evL0bZNIiPVxpsyGQ5D9GCx6GFwAzvsuTP2KFiAX6UAU+u3Y39fu4Gv/t4az\nB/Vm/r9PpHe3Lv4Vs/UtryunPB9GXgoX/QTiU/2rR0ROiUK/nfrzyhK+8dJaslITmPeFCf4tbVi+\nDd74Dmx+DRKHwk1/hqEX+FOLiJw2hX479PzyIr7z14+ZmpHIEzdn0T3ah/+muhpY/Kg35XFEFFzw\nA5j8ZYjqAAuxiMgnUui3M/OXFPC9V9Zz3og+/Pam8XTtEtm2BRyeBfOf90NlEZx1LXzmh9BrYNvW\nISKtQqHfjvz+/Tween0Tn8nsx69njyUmqo0Dv3QrvP4tyHsH+mbCra9C2vS2rUFEWpVCv53433e2\n8shbW7h01AB++fkxdImMaLsnP3QAFv0csh/3RuJc/FOYcDtE+njiWERahULfZ845Hn5zM48vzOPq\ncYP4+bWjiWyr9Wydg3V/hjf/G6p2eNMdX/B96NGOV9wSkdOi0PeRc44fvbqRpz7cxg0Tk/nxlWe3\n3QLmuzfAa9+Awg9hwGi4bj4kT2yb5xYR3yj0fRIIOB54ZR3PLC3i1qlpfO+yTKwtrmg9WA
Hv/dSb\nL6drL7j0lzDuFoho4/MHIuILhb4PGgOO+//yES/mlPDFGUP49sUjWz/wAwFvBsy3vwfVpTD+Vm8h\ncs1xLxJWFPptrKExwH/+aS0vr9nBPecP494LhrV+4O9Y43XllCyHwRPgxj/BwLGt+5wi0i4p9NtQ\nXUOAr76wmtfX7eIbF43grvOGtu4T1pTDuz+EnD9AbBJc8RsYfQNEtOHIIBFpVxT6baS2vpG7n1vF\n2xv38N+XZnLb9PTWezLnYM1z8OZ3oXa/t1zhzG9Dt7jWe04R6RAU+m3gYF0jdz6dwwdbS/nhlWcx\nZ3IrTlRWuR3+/lXIfQuSJ8Olj0C/M1vv+USkQ1Hot7LqQw3cNn8Fy7aV8z/XjOK6Ccmt80TOweqn\n4Y3vQqABLv4ZTLxTXTkicgyFfivaX1vPF/6wgjXFFTz6+TFcMWZQ6zxRRTH8/R7IexdSp8MVv4aE\nIa3zXCLSoSn0W0lFTR03z1vOhh37eeyGsVxy9oDQP4lzsGo+vPFf3qIm//YwZN2m1r2IfCKFfiso\nO3CIm55aTt6eA/zupvFckNkv9E9SUQSvfAXy34O0c+DyX0NCK54cFpFOQaEfYpU19Vw/dylF5TU8\neUsW5w7vE9onCARg5R/grQe8+599BMZ/Qa17EWkRhX6IPbOskK17DvDs7ZOYNjQptN98X4HXut+2\nCNJneK17LVkoIidBoR9CDY0BnllayDnDkkIb+IEA5DwFb30PLAIufdSbRqEt5uoRkU5FoR9Cb2/c\nzc7KWn5weQjHxZfnw8tf8WbDzJgFl/0vxLXSsE8R6fQU+iG0ILuQQXHdOP+MEJy4DQRgxRPw9ve9\nNWov/zWMnaPWvYicFoV+iGzdXcWSvDK+dfHI018EpSwPXr4bipbA0Avgsl9B78GhKVREwppCP0QW\nZBcSHRXB50/nittAIyz7PbzzIERGexOkjZmt1r2IhIxCPwT219bz51UlXD56IAmx0af2TUpz4eW7\noHgpDLsILnsUeg0MbaEiEvYU+iHwl5Ul1NQ1csuUtJN/cKARlv4G3v0RRMXAlb+D0derdS8irUKh\nf5oCAceC7ELGpsRx9uDeJ/fgvVvg5S9DyQoYfom3dGGvVpiuQUQkqEWXcZrZxWa22cxyzezbx9mf\nambvmNlHZvaemQ1usi/FzN40s41mtsHM0kJXvv8W55WSX1p9cq38QCMs/hX8bjqUboWrn4Abnlfg\ni0irO2FL38wigceBC4ESYIWZveKc29DksIeBBc65+WY2C3gImBPctwD4sXPuLTPrAQRC+gp8Nn9J\nIUk9ornk7P4te8DezfC3L8P2HBh5qTeNQs9WmJtHROQ4WtLSnwjkOufynXN1wAvAFc2OyQTeDd5e\neHi/mWUCUc65twCccwecczUhqbwdKC6v4Z1Nu7lhYgoxUZGffnBjA3z4S/jdOd4FV9c8BZ9/RoEv\nIm2qJaE/CChucr8kuK2ptcDVwdtXAT3NLBEYDlSY2V/MbLWZ/Tz4l0On8MyyQiLMmD0p5dMP3LMR\nnrrQu9Bq+GfgrmVw9rU6WSsibS5UUzN+HZhhZquBGcB2oBGv++ic4P4JwBDg1uYPNrM7zSzHzHL2\n7t0bopJaV219I/+3opiLzuzHgN7djn9QYwMsehh+fy5UFMK1f4DrnoYefdu2WBGRoJaM3tkONL3i\naHBw2xHOuR0EW/rBfvtrnHMVZlYCrHHO5Qf3/Q2YDDzV7PFzgbkAWVlZ7tReStt6Ze0OKmrqmTM5\n7fgH7F7v9d3vXANnXuUtcBIb4lk3RUROUktCfwUwzMzS8cL+emB20wPMLAkod84FgPuBeU0eG2dm\nfZxze4FZQE6oiveLc44F2QUM79eDyUMSjt0ZCMAHv4D3fwZde8Pn5sOZV/pSp4hIcyfs3nHONQB3\nA28AG4FdB6NNAAAMPElEQVQXnXPrzexBM7s8eNhMYL
OZbQH6AT8OPrYRr2vnHTP7GDDgiZC/ija2\nuriCddv3c/OUNKx5v/yqP8LCH8EZl8FdyxX4ItKutOjiLOfca8BrzbY90OT2S8BLn/DYt4BRp1Fj\nu7NgSQE9Y6K4amyz89n1B+H9/4HkSXDtPJ2oFZF2R2vsnaS9VYd49eOdXJs1mNiYZp+Zy5+Aqp1w\n/gMKfBFplxT6J+mF5UXUNzrmTG62TGHtfvjwEW+hk7Tp/hQnInICCv2TUN8Y4NllRZw7vA9D+vQ4\ndmf243Bwn9fKFxFppxT6J+GtDbvZtb+WW6Y0a+VXl0H2Y3DG5TBwrD/FiYi0gEL/JMxfUkByQjdm\njmh2cdWHj0B9DZz3XX8KExFpIYV+C23atZ9l28qZMzn12OUQK7d7J3BHXQ99R/pXoIhICyj0W2hB\ndiExURFcl9VsOcRF/wMuADP/ZcZpEZF2R6HfApUH6/nrqu1cOWYQcd2bLIdYlgernoasL0B86id/\nAxGRdkKh3wIvrSzhYH0jc5qfwF34E28B83O+7k9hIiInSaF/AoGA4+nsAsanxnPWoCbLIe76GNa9\nBJO/pDnxRaTDUOifwAe5pRSU1XBz81b+uz+GmN4w7av+FCYicgoU+iewYEkBST1iuOSsJuvXFi+H\nLa/DtHugW7x/xYmInCSF/qcoKqvh3c17mD0pheio4FvlHLzzIMT2gUlf8rdAEZGTpND/FM8sKyTS\njBubLoeYvxAKPvBO3sb0+OQHi4i0Qwr9T3CwLrgc4ln96derq7fxcCu/d7I3TFNEpINR6H+CV9Zu\np/JgPbdMSTu6cdM/YMdq70KsqBjfahMROVUK/eNwzjF/SSEj+/dkQlrwRG2gEd79ESQO86ZcEBHp\ngBT6x7GycB8bdu7nlqlNlkP86EXYuwlmfRciW7TgmIhIu6PQP4752YX06hrFFWMGehsa6uC9n8CA\n0XDGFf4WJyJyGhT6zezZX8vrH+/kuqxkukcHW/Sr5kNFEcx6ACL0lolIx6UEa+a55UU0OsdNh5dD\nrKuGRT+HlKkw9Hx/ixMROU3qnG6irsFbDnHG8D6kJcV6G5fPhQO74XPztdi5iHR4auk38cb6Xeyt\nOnR0mObBCvjwURj2GUid4mttIiKhoNBv4unsQlISujNjeB9vQ/ZjUFsBs/7L38JEREJEoR+0Ycd+\nlheUc/OUVCIiDA7shezfwJlXeaN2REQ6AYV+0NNLC+jaJYLPjQ8uh/jBL6ChVoudi0inotAHKmvq\n+evq7Vw1dhC9u3eBimLIeQrGzIakYX6XJyISMgp94E8ri6mtDzBncpq34f2feV9nfMu3mkREWkPY\nh34g4FiQXcjEtAQyB/aC0q2w5jnIug3ikv0uT0QkpMI+9N/fspei8hpunhq8GGvhjyGqK5zzn/4W\nJiLSCsI+9OdnF9C3ZwwXndkfdq6F9X+FKV+GHn38Lk1EJOTCOvQLSqt5b/NebpyUSpfICG/q5K5x\nMOVuv0sTEWkVYR36Ty8tpEukccOkZCjMhq1vwvSvQbc4v0sTEWkVYRv6NXUNvJhTzMVnDaBvjxhv\nGcQe/WDiF/0uTUSk1YRt6P9t9Q6qahu4ZUoq5L4DRUvg3G9AdHe/SxMRaTVhGfrOORZkF5A5oBfj\nU3rDOz+AuBQYd4vfpYmItKqwDP0VBfvYtKuKW6amYhv/Drs+gpnfgahov0sTEWlVYRn687ML6N2t\nC5ef3c8bsdNnJIy6zu+yRERaXYtC38wuNrPNZpZrZt8+zv5UM3vHzD4ys/fMbHCTfY1mtib475VQ\nFn8qdlXW8sa6XXx+QjLdNv4JyrZ6UydHRPpdmohIqzth6JtZJPA4cAmQCdxgZpnNDnsYWOCcGwU8\nCDzUZN9B59yY4L/LQ1T3KTuyHGJWf3jvpzBwLIy81O+yRETaREta+hOBXOdcvnOuDngBuKLZMZnA\nu8HbC4+zv12oaw
jw3LIiZo3oS8q2F6GyGM5/QMsgikjYaEnoDwKKm9wvCW5rai1wdfD2VUBPM0sM\n3u9qZjlmttTMrjytak/T6+t2UnrgELdO6OMtdp52Dgw5z8+SRETaVKhO5H4dmGFmq4EZwHagMbgv\n1TmXBcwGHjWzjOYPNrM7gx8MOXv37g1RSf9qQXYh6UmxTCt9Car3qpUvImGnJaG/HWg6x/Dg4LYj\nnHM7nHNXO+fGAt8NbqsIft0e/JoPvAeMbf4Ezrm5zrks51xWnz6tM9HZuu2VrCzcx23j44lY8r8w\n/BJIntgqzyUi0l61JPRXAMPMLN3MooHrgWNG4ZhZkpkd/l73A/OC2+PNLObwMcA0YEOoij8ZC7IL\n6B4dybWH/gKHKrXYuYiEpROGvnOuAbgbeAPYCLzonFtvZg+a2eHRODOBzWa2BegH/Di4/Qwgx8zW\n4p3g/alzrs1Df191HS+v2cGcs7rSdeVcOOta6H9WW5chIuK7qJYc5Jx7DXit2bYHmtx+CXjpOI9b\nApx9mjWethdzijnUEOBLEX+FhkNw3nf8LklExBed/orcxoDj6aWFXJpcR/yGZ2HcHEj8l3PJIiJh\noUUt/Y5s4aY9lOw7yAt9X4byCDj3m36XJCLim07f0l+wtJDJPUsZVPQyTLwDeje/xEBEJHx06pZ+\n/t4DLNqyl7cH/Q2r6g7T7/W7JBERX3Xqlv7TSwsZG5XP0LJ3vXVvY5P8LklExFedtqVffaiBl3JK\neLH3yxCIhyl3+V2SiIjvOm1L/6+rt5NZ9zFnVK+A6fdB115+lyQi4rtO2dJ3zrFgyTYejX0J120A\nNvEOv0sSEWkXOmXoL80vZ1Dph2RGb4QZv4Qu3fwuSUSkXeiU3TtPL8nn29EvEohLg7Fz/C5HRKTd\n6HQt/R0VB4na/Aojogph1hMQ2cXvkkRE2o1O19J/ITufeyNepC5xJJx1jd/liIi0K52qpX+ooZHq\n5QtIj9gFF/5Si52LiDTTqVr6/1xTwG2BF9mfOBpGXOJ3OSIi7U6naumXLvwdA62cwL/9Qcsgiogc\nR6dp6Rft3M2VB55nR8IkIjJm+l2OiEi71Gla+ik9HDVDz6Hb9K/5XYqISLvVaUKfnv3pPud5v6sQ\nEWnXOk33joiInJhCX0QkjCj0RUTCiEJfRCSMKPRFRMKIQl9EJIwo9EVEwohCX0QkjJhzzu8ajmFm\ne4HC0/gWSUBpiMrp6PReHEvvx7H0fhzVGd6LVOdcnxMd1O5C/3SZWY5zLsvvOtoDvRfH0vtxLL0f\nR4XTe6HuHRGRMKLQFxEJI50x9Of6XUA7ovfiWHo/jqX346iweS86XZ++iIh8ss7Y0hcRkU/QaULf\nzC42s81mlmtm3/a7Hj+ZWbKZLTSzDWa23sy+6ndNfjOzSDNbbWb/8LsWv5lZnJm9ZGabzGyjmU3x\nuyY/mdm9wd+TdWb2vJl19bum1tQpQt/MIoHHgUuATOAGM8v0typfNQD/6ZzLBCYDd4X5+wHwVWCj\n30W0E78C/umcGwmMJozfFzMbBNwDZDnnzgIigev9rap1dYrQByYCuc65fOdcHfACcIXPNfnGObfT\nObcqeLsK75d6kL9V+cfMBgOfBZ70uxa/mVlv4FzgKQDnXJ1zrsLfqnwXBXQzsyigO7DD53paVWcJ\n/UFAcZP7JYRxyDVlZmnAWGCZv5X46lHgm0DA70LagXRgL/CHYHfXk2YW63dRfnHObQceBoqAnUCl\nc+5Nf6tqXZ0l9OU4zKwH8Gfga865/X7X4wczuxTY45xb6Xct7UQUMA74rXNuLFANhO05MDOLx+sV\nSAcGArFmdpO/VbWuzhL624HkJvcHB7eFLTPrghf4zzrn/uJ3PT6aBlxuZgV43X6zzOwZf0vyVQlQ\n4pw7/JffS3gfAuHqAmCbc26vc64e+Asw1eeaWlVnCf0VwDAzSzezaLwTMa/4XJNv
zMzw+mw3Ouce\n8bsePznn7nfODXbOpeH9XLzrnOvULblP45zbBRSb2YjgpvOBDT6W5LciYLKZdQ/+3pxPJz+xHeV3\nAaHgnGsws7uBN/DOvs9zzq33uSw/TQPmAB+b2Zrgtu84517zsSZpP74CPBtsIOUDX/C5Ht8455aZ\n2UvAKrxRb6vp5Ffn6opcEZEw0lm6d0REpAUU+iIiYUShLyISRhT6IiJhRKEvIhJGFPoiImFEoS8i\nEkYU+iIiYeT/AUSFzzyvHHEbAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "epochs = config['epochs']\n", "learning_rate = config['learning_rate']\n", "smoothing_constant = .01\n", "train_acc = []\n", "test_acc = []\n", "\n", "for e in range(epochs):\n", " for i, (data, label) in enumerate(train_data):\n", " data = data.as_in_context(ctx).reshape((-1, 784))\n", " label = label.as_in_context(ctx)\n", " label_one_hot = nd.one_hot(label, 10)\n", " \n", " with autograd.record():\n", " # sample epsilons from standard normal\n", " epsilons = sample_epsilons(layer_param_shapes)\n", " \n", " # compute softplus for variance\n", " sigmas = transform_rhos(rhos)\n", "\n", " # obtain a sample from q(w|theta) by transforming the epsilons\n", " layer_params = transform_gaussian_samples(mus, sigmas, epsilons)\n", " \n", " # forward-propagate the batch\n", " output = net(data, layer_params)\n", " \n", " # calculate the loss\n", " loss = combined_loss(output, label_one_hot, layer_params, mus, sigmas, gaussian_prior, log_softmax_likelihood)\n", " \n", " # backpropagate for gradient calculation\n", " loss.backward()\n", " \n", " # apply stochastic gradient descent to variational parameters\n", " SGD(variational_params, learning_rate)\n", " \n", " # calculate moving loss for monitoring convergence\n", " curr_loss = nd.mean(loss).asscalar()\n", " moving_loss = (curr_loss if ((i == 0) and (e == 0)) \n", " else (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss)\n", "\n", " \n", " test_accuracy = evaluate_accuracy(test_data, net, mus)\n", " train_accuracy = evaluate_accuracy(train_data, net, mus)\n", " train_acc.append(np.asscalar(train_accuracy))\n", " 
test_acc.append(np.asscalar(test_accuracy))\n", " print(\"Epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" %\n", " (e, moving_loss, train_accuracy, test_accuracy))\n", " \n", "plt.plot(train_acc)\n", "plt.plot(test_acc)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For demonstration purposes, we can now take a look at one particular weight by plotting its distribution." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW8AAAD8CAYAAAC4uSVNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8U9ed9/HPT5J3Y+N9A2NWA2bHIWQnJGEN2dpOQhLa\naZqhnaaT6Tad9uksT/t0pp1Mp03b6UaTTtM2zU46gbAkJCSELIDZbBYDZjPejcEbXiWd5w+L1HEA\ny7bkq+X3fr38QpaupC9X0s9H555zjxhjUEopFVxsVgdQSik1cFq8lVIqCGnxVkqpIKTFWymlgpAW\nb6WUCkJavJVSKghp8VZKqSCkxVsppYKQFm+llApCDn89cGpqqsnLy/PXwyulVEjavXv3WWNMWn/b\n+a145+XlUVRU5K+HV0qpkCQip73ZTrtNlFIqCGnxVkqpIKTFWymlgpAWb6WUCkJavJVSKghp8VZK\nqSCkxVsppYKQ38Z5KxVIGlo7+d99VTS2dX143ewxSdw0MQ2bTSxMptTgeF28ReQrwMOAAUqAzxpj\nOvwVTClfKK5o5HfvnmJ9cTVdLjfiqdMXl24dkxLLg1eP4d55o0mIjrAuqFID5FXxFpEc4FFgqjGm\nXUSeB+4DfufHbEoNmjGGX7x1nP/cfIT4KAcr541m1TVjmJA+AoAup5vNB2v4/fun+LcNh/nDB6d5\n4jOFTMoYYW1wpbw0kG4TBxAjIt1ALFDln0hKDU1Ht4tvvFjMK/uruHNWNt+7axoj+rSqIx02VszM\nZsXMbHadOscXn97DPb94j5+unMXCyRkWJVfKe14dsDTGVAI/BMqBaqDJGPNa3+1EZLWIFIlIUX19\nvW+TKuWFpvZu7l3zAa/sr+IfFufz+L2zPla4+7oqL5lXvnQdeamxfO6pIv7w/qlhyarUUHhVvEUk\nCbgTGAtkA3Ei8mDf7Ywxa4wxhcaYwrS0fk+KpZRPudyGR5/Zy8HKJn69ai6P3DwBEe8ORmYlxvD8\n56/hlsnp/MsrB9laWufntEoNjbdDBW8FThpj6o0x3cBa4Fr/xVJq4B7bVMrbR+v5zp0FLC7IHPD9\nYyMd/HTlbKZkJvDoM3spq2v1Q0qlfMPb4l0OzBeRWOlpytwCHPZfLKUG5uW9Ffx62wlWzR/DA1eP\nGfTjxEY6+M1nCol02Fj9+yKa2rp9mFIp3/G2z3sH8CKwh55hgjZgjR9zKeW10ppm/vGlEuaPS+Zf\nVkwd8uPljIzhV6vmcuZ8G199fh/m4rhCpQKI1zMsjTH/aoyZbIyZZoxZZYzp9Gcwpbzhdhu+tbaE\n+CgHP79/DhF230waviovmX9cMpk3SutYX1ztk8dUypd0erwKak/vLGdveSP/tHwKKfFRPn3sz143\nluk5iXxn3SGa2rX7RAUWLd4qaNU1d/DYxlKum5DC3bNzfP74dpvw/Xumc+5CJ49tKvX
54ys1FFq8\nVdD6zrpDdLrcfO+u6V4PCRyoaTmJfPa6sTy9o5zdp8/55TmUGgwt3ioovXOsnldLqvm7mycwNjXO\nr8/11dsmkZ0YzbdfPoDLrQcvVWDQ4q2CjjGG/9x8hJyRMay+aZzfny8uysE3l02htKaF9cV6VggV\nGLR4q6Dz+qFaiiua+PtbJhLlsA/Lc94+PYvJmSN4fMsxnC73sDynUleixVsFFbfb8KPXjzI2NY57\n5vj+IOXl2GzCV2+bxMmzF1i7t3LYnlepy9HirYLK+pJqSmta+PKtE3H4aEy3t26bmsGMUYn8ZMsx\nOp2uYX1upfrS4q2ChtPl5vHXj5KfMYIVM7KH/flFhK8tyqeysZ3nd50Z9udXqjct3ipovLK/ihNn\nL/DVRZMsW7rsxompzMtL5r+3lmnrW1lKi7cKCsYYfvPOSSamx7NoqnWLJYgIjyycQG1zJ+v367R5\nZR0t3ioovH+8gcPVzTx8w1i/Tcjx1o0TU5mUEc8T20/qSauUZbR4q6Dwm3dOkBofyZ2zhm+EyeWI\nCA9fP47D1c28d7zB6jgqTGnxVgGvrK6FrUfqWTU/j+iI4RnX3Z87ZmWTGh/JE++csDqKClNavFXA\ne3L7KSIdNh6cn2t1lA9FR9hZNT+PrUfqKatrsTqOCkPermGZLyL7ev00i8iX/R1OqYbWTtbuqeAT\nc3J8fsrXoXpwfi5RDhtPbj9ldRQVhrxdSeeIMWaWMWYWMBdoA172azKlgGd2ltPpdPO568daHeVj\nUuKjuGfOKNbuqaCxrcvqOCrMDKbb5BbguDHmtK/DKNWb2214ZucZrh2fwoT0EVbHuaRV88fQ6XSz\ndo9OmVfDazDF+z7gGV8HUaqvbcfqqWxsZ+W8wOnr7mtqdgIzR4/kTzvLddigGlYDKt4iEgncAbxw\nmdtXi0iRiBTV19f7Ip8KY8/sLCclLpLFBZlWR7miB+blUlbXStHp81ZHUWFkoC3vpcAeY0ztpW40\nxqwxxhQaYwrT0tKGnk6FrbrmDrYcruOTc0cR6QjsQVG3z8wiPsrBMzvKrY6iwshAPxUr0S4TNQxe\n2F2By22496rRVkfpV2ykg7tmZ7O+pFoPXKph43XxFpE44DZgrf/iKHXxQGU514xLYVxavNVxvLJy\nXi5deuBSDSOvi7cx5oIxJsUY0+TPQEq9U3aWivPt3H914B6o7KsgO5GZo0fyjB64VMMksDsTVVh6\nvugMSbERLCqw7uyBg7HyqtEcq2tl75lGq6OoMKDFWwWUpvZuXj9Uyx0zs4dtfUpfWTYjiyiHjZe1\n60QNAy3eKqBsKKmmy+nmnjmjrI4yYAnRESwqyGRdcRVdTl2kWPmXFm8VUF7eU8n4tDhmjEq0Osqg\n3DMnh8a2brYeqbM6igpxWrxVwChvaGPnqXPcM2eU5QsuDNYNE1JJjY9i7Z4Kq6OoEKfFWwWMl/dW\nIgJ3zbZ+wYXBctht3DUrmzdL6zh/Qcd8K//R4q0CgjGGtXsrmD82hZyRMVbHGZJ75oyi22VYX1xl\ndRQVwrR4q4Cwp/w8pxvauGdO8La6L5qancDkzBG8pKNOlB9p8VYB4eW9lURH2Fg6PcvqKD7xiTmj\n2HemkRP1rVZHUSFKi7eyXLfLzYaSGm6bmkl8lMPqOD6xYmY2IrC+uNrqKCpEafFWlnvveAPnLnSx\nYkZotLoBMhOjmZeXzCv7q3S6vPILLd7Kcuv2VzEi2sFN+aF1GuEVM7Mpq2vlSK0uUKx8T4u3slRH\nt4vNB2pYXJAZdNPh+7N0WiZ2m/DKPh11onxPi7ey1NtH62npdHLHzGyro/hcSnwU109IZV2xdp0o\n39PirSy1bn8VyXGRXDs+xeoofrFiZjZnzrWzv0LPpKx8ayCLMYwUkRdFpFREDovINf4MpkJfW5eT\nNw7XsWx6Jg57aLYjFhVkEGm3sW6/dp0o3xrIJ+Y
nwCZjzGRgJnDYP5FUuNhyuI72bhcrZoRel8lF\nCdERLMhPY31xFW63dp0o3/GqeItIInAj8CSAMabLGKNnnFdDsm5/FRkJUVyVl2x1FL9aMTOb2uZO\ndp06Z3UUFUK8bXmPBeqB/xGRvSLyhGdNS6UGpaWjm7eP1rNsehY2W3CeQdBbCyenEx1hY0OJTthR\nvuNt8XYAc4BfGmNmAxeAb/bdSERWi0iRiBTV19f7MKYKNW+W1tHldLM8RKbDX0lclIMFk9LZeKBG\nu06Uz3hbvCuACmPMDs/vL9JTzD/CGLPGGFNojClMSwutCRfKtzaUVJOREMWc3CSrowyLZTOyqGvp\nZHf5eaujqBDhVfE2xtQAZ0Qk33PVLcAhv6VSIe1Cp5O3jtSzdFrod5lctHByOpEO7TpRvjOQ0SZ/\nBzwtIsXALODf/RNJhbo3S+vodLpZOi3T6ijDJj7KwYJJaWws0a4T5RteF29jzD5Pl8gMY8xdxhj9\n/qcGZeOBatJGRFEY4qNM+lo2PYua5g72ntGBWmroQnNmhApYbV1O3iytY0lBz3k/wsnCKelE2rXr\nRPmGFm81rN46Uk9Ht5tlYTDKpK+E6AhunJTKxpJqPdeJGjIt3mpYvVpSTWp8JPPGhleXyUVLp2VR\n1dTBPu06UUOkxVsNm45uF1tL67htavh1mVx069QMHDZh08Eaq6OoIKfFWw2bd46dpa3LFVajTPpK\njIng2gmpbDpQo10naki0eKths+lADQnRDuaPC83Tv3prSUEmpxvaKK3RFXbU4GnxVsOi2+Vmy+Fa\nbp2SQaQjvN92iwoyEOn5Y6bUYIX3p0gNmx0nztHU3s2SMO4yuSg1vudMipu131sNgRZvNSw2Hawm\nJsLOjZP0nDfQ03VSWtPCybMXrI6igpQWb+V3brdh88Fabp6cRnREaC0yPFgXv4Fo14kaLC3eyu/2\nlJ+nvqWTxQXaZXJR9sgYZo5K1CGDatC0eCu/23Sghki7jYWT062OElAWT8tk/5lGqhrbrY6igpAW\nb+VXxhg2HazhugkpjIiOsDpOQFni+SaiBy7VYGjxVn51qLqZivPtOsrkEsalxTMpI16LtxoULd7K\nrzYfqMEmcOuUDKujBKTFBZnsPHmOcxe6rI6igowWb+VXmw/WUpiXTEp8lNVRAtLigkzcBrYcrrU6\nigoyXhdvETklIiUisk9EivwZSoWGk2cvcKS2RUeZXEFBdgI5I2PYrEMG1QANtOV9szFmljGm0C9p\nVEi52Je7uEC7TC5HRFhckMk7ZWdp7XRaHUcFEe02UX6z+WAN03ISGJUUa3WUgLa4IIMup5u3j9Rb\nHUUFkYEUbwO8JiK7RWT1pTYQkdUiUiQiRfX1+kYMZ7XNHewtb2TxVO0y6U9hXjIpcZE6YUcNyECK\n9/XGmDnAUuAREbmx7wbGmDWeRYoL09L0HBbh7LVDPQfgFusQwX7ZbcKtUzLYWlpHp9NldRwVJAay\nenyl59864GVgnr9CqeC3+UAN41LjmJgeb3WUoLBkWiatnU7eK2uwOooKEl4VbxGJE5ERFy8Di4AD\n/gymgldTWzcfnGhgUUEmIuG53NlAXTshhfgoh07YUV7ztuWdAWwXkf3ATuBVY8wm/8VSweyN0lqc\nbqOzKgcgymFnQX4arx+qxeXW5dFU/xzebGSMOQHM9HMWFSI2HaghMyGaGTmJVkcJKkumZbK+uJqi\nU+e4OsyXilP906GCyqfau1xsO1bP4oIMbGG6QvxgLchPJ9JhY/NBnW2p+qfFW/nU20fr6eh266zK\nQYiPcnDDhFQ2H9SV5VX/tHgrn9p8sIaRsRHMG5tsdZSgtLggk8rGdg5WNVsdRQU4Ld7KZ7pdbt7w\nrBDvsOtbazBumZKOTfQc36p/+glTPvPBiQaaO5zaZTIEKfFRzBurK8ur/mnxVj6z6UANsZF2bpiY\nanWUoLa4IJO
jta2cqG+1OooKYFq8lU+43YbXD9WyIF9XiB+qRZ5vLnquE3UlWryVT+wpP0+drhDv\nEzkjY5gxKlHP8a2uSIu38gldId63lkzLZH9FE5W6sry6DC3easiMMWw8UMP1E1N1hXgfubiy/Gva\ndaIuQ4u3GrKDVc1UNuoK8b40Li2e/IwRbNSuE3UZWrzVkG08UI3dJtymK8T71OJpmew6dY76lk6r\no6gApMVbDdmmAzXMH5dMUlyk1VFCytJpmRgDrx/Sc52oj9PirYbkWG0Lx+svfNhHq3xncuYIxqTE\n6pBBdUlavNWQbDxQgwg6RNAPRIQl0zJ5r+wsTW3dVsdRAWZAxVtE7CKyV0TW+yuQCi6bDtQwJzeJ\n9IRoq6OEpCUFmTjdhjdKtetEfdRAW95/Dxz2RxAVfE43XOBQdTNLdZSJ38wcNZKsxGg2lGjXifoo\nr4u3iIwClgNP+C+OCiYXC4oOEfQfm62n62TbsXpaOrTrRP3FQFrejwPfANx+yqKCzIaSamaOHsmo\npFiro4S05dOz6HK6ebO0zuooKoB4u3r87UCdMWZ3P9utFpEiESmqr6/3SUAVmMob2iipbGL5dG11\n+9uc3CQyEqJ4tbja6igqgHjb8r4OuENETgHPAgtF5I99NzLGrDHGFBpjCtPS0nwYUwWajQd6CsnS\naVkWJwl9NpuwdFoWbx2tp7XTaXUcFSC8Kt7GmG8ZY0YZY/KA+4A3jTEP+jWZCmgbSqqZMSqR0cna\nZTIclmnXiepDx3mrAas438b+iiaWTddW93ApHJNE+ogoNmjXifIYcPE2xrxljLndH2FUcNjoGWWy\nTLtMhk1P10kmW4/UcUG7ThTa8laD8GpJNdNyEshN0S6T4bR0ehad2nWiPLR4qwGpbGxn35lG7TKx\nwFV5yaTGR7GhRLtOlBZvNUCvFlcBcPv0bIuThB+7TVg+PZM3S+t01InS4q0GZt3+nok52mVijRUz\ns+l0utmip4kNe1q8lddOnr1ASWUTK2Zol4lV5uQmkZ0Yzbr9VVZHURbT4q28tt5TMJZr8baMzSbc\nPjObbcfq9TSxYU6Lt/LauuIq5uUlk5UYY3WUsHb7jCy6XYbNukhDWNPirbxypKaFo7WtrJiprW6r\nTc9JZExKLOuKtesknGnxVl5Zt78Km8ASnZhjORFhxYxs3i07y9lWXZw4XGnxVv0yxrCuuIprx6eS\nNiLK6jiKnlEnbgMbdcx32NLirfpVXNHE6YY2btcDlQEjP3MEkzLieUVHnYQtLd6qX3/eV0mk3aan\nfw0wd87KYdep85w512Z1FGUBLd7qipwuN+v2V3HLlHQSYyOsjqN6uWNmzyxXbX2HJy3e6oq2l53l\nbGsXd83OsTqK6mN0cizz8pJZu6cCY4zVcdQw0+KtrujPeytJjIlgQb6ujBSI7pqdw/H6CxysarY6\nihpm3q5hGS0iO0Vkv4gcFJHv+DuYst6FTiebD9ayfEYWUQ671XHUJSyfnkWk3cbaPZVWR1HDzNuW\ndyew0BgzE5gFLBGR+f6LpQLB5oM1tHe7uFu7TAJWYmwECyen88r+Kpwut9Vx1DDydg1LY4xp9fwa\n4fnRTrYQ9/LeSkYlxTA3N8nqKOoK7pqdw9nWTt493mB1FDWMvO7zFhG7iOwD6oDXjTE7/BdLWa2u\nuYN3y85y16wcbDaxOo66gpsnp5EQ7eDlPRVWR1HDyOvibYxxGWNmAaOAeSIyre82IrJaRIpEpKi+\nvt6XOdUw+/O+StwGHWUSBKIcdpbPyGbzwVpaOvRMg+FiMAsQNwJbgSWXuG2NMabQGFOYlqajE4KV\nMYbniyqYkzuSCenxVsdRXvhU4Sjau128qqvLhw1vR5ukichIz+UY4Dag1J/BlHX2nWmkrK6Vvyoc\nbXUU5aXZo3v+0L6wW7tOwoW3Le8sYKuIFAO76OnzXu+/WMpKL+yuIDrCposuB
BER4VNzR7H79HnK\n6lr7v4MKet6ONik2xsw2xswwxkwzxnzX38GUNdq7XKzbV8Wy6VmMiNbp8MHk7jk52G3Ci9r6Dgs6\nw1J9xOaDNbR0OvnUXO0yCTbpI6K5OT+Nl/ZU6JjvMKDFW33EC7vPMDo5hqvHJlsdRQ3CJ+eOpr6l\nk23HdLRXqNPirT505lwb75Y18Km5o3Vsd5BaODmdlLhInt+lXSehTou3+tDzRWcQgU/MHWV1FDVI\nkQ4bd8/OYcvhWupaOqyOo/xIi7cCoNvl5tldZ1iYn07OSF0dPpitvDoXp9vwQpG2vkOZFm8FwJZD\ntdS3dPLA/Fyro6ghGp8Wz7XjU/jTjnJcbj0FUajS4q0A+OOO0+SMjOGmSelWR1E+8MDVY6hsbGfb\nUT1wGaq0eCtO1LfyblkDK+eNxq4HKkPCbVMzSI2P4ukdp62OovxEi7fimZ3lOGyi0+FDSKTDxr1X\njeLN0joqG9utjqP8QIt3mOvodvHC7goWFWSQnhBtdRzlQ/ddlYsBnt1ZbnUU5QdavMPcq8XVNLZ1\n8+DVY6yOonxsdHIsN+en8+yuM3Q5dcZlqNHiHcaMMfz23ZOMT4vjmvEpVsdRfrDqmjHUt3SyoURP\nFRtqtHiHsV2nznOwqpmHrh+LiB6oDEU3TUxjXFocT24/iTE6bDCUaPEOY09uP8HI2Ajuma0zKkOV\nzSY8dN1YSiqbKDp93uo4yoe0eIep8oY2XjtUy/3zcomJtFsdR/nRPXNySIyJ4Ml3TlodRfmQFu8w\n9dT7p7CL8Olr8qyOovwsNtLB/Vfn8tqhGs6ca7M6jvIRb5dBGy0iW0XkkIgcFJG/93cw5T8tHd08\nt+sMy2dkkZmowwPDwaevGYOI8Lv3TlkdRfmIty1vJ/A1Y8xUYD7wiIhM9V8s5U/PF1XQ2unkoevG\nWh1FDZOsxBiWTc/iuV1naNYV5kOCt8ugVRtj9ngutwCHgRx/BlP+0el08ZttJ5g3NpmZo0daHUcN\no9U3jKO108kf3tcp86FgwH3eIpIHzAZ2XOK21SJSJCJF9fV6QpxA9PKeSmqaO/jSzROsjqKG2fRR\nidw4KY3fbj9Je5fL6jhqiAZUvEUkHngJ+LIxprnv7caYNcaYQmNMYVpamq8yKh9xutz88u3jTM9J\n5IaJqVbHURZ4ZMF4Gi508dwunTIf7Lwu3iISQU/hftoYs9Z/kZS/vFpSzemGNh65eYJOyglTV49L\n4aq8JH697YROmQ9y3o42EeBJ4LAx5kf+jaT8we02/GLrcSamx7NoaobVcZSFvnjzBKqbOvjz3kqr\no6gh8LblfR2wClgoIvs8P8v8mEv52BuldRypbeGLN4/XxYXD3IJJaRRkJ/DLt4/rSjtBzNvRJtuN\nMWKMmWGMmeX52eDvcMo33G7D41uOkpscy4oZ2VbHURYTEb508wROnr2gre8gpjMsw8CGA9UcrGrm\nK7dNxGHXl1zB4oJMpuUk8OMtR7XvO0jpJznEOV1ufvTaUSZlxHPHTB2ar3rYbMLXF+VTcb6dZ3Xk\nSVDS4h3iXtpTwYmzF/j6onxdn1J9xE2T0pg3NpmfvlFGW5fT6jhqgLR4h7CObhc/2XKMWaNHcpuO\nMFF9iAjfWJzP2dZOnnpPZ10GGy3eIezpHeVUNXXwjcX5Oq5bXVJhXjILJ6fzq7eP09Sm5zwJJlq8\nQ9S5C138ZMtRbpiYyrUTdDaluryvL8qnpaObx984anUUNQBavEPUj14/woUuF/98u578UV3Z1OwE\n7puXy+/fP82x2har4ygvafEOQYeqmvnTjnJWzR/DpIwRVsdRQeBrt00iLtLOd9cf0rUug4QW7xBj\njOE76w6SGBPBV26dZHUcFSRS4qP48q2TeOfYWV4/VGt1HOUFLd4hZkNJDTtOnuNri/JJjI2wOo4K\nIquuGcPE9Hi+9+phOrr1lLGBTot3CGnu6
Ob/rT/ElKwEVs7LtTqOCjIRdhv/uqKA8nNt/GJrmdVx\nVD+0eIeQ728opa6lgx/cM10n5KhBuX5iKnfPzuEXbx3ncPXHTtmvAogW7xDx3vGzPLOznIdvGKfL\nm6kh+Zfbp5IYE8E/vlSM06XnPQlUWrxDQHuXi2++VEJeSqwepFRDlhQXyXfuLKC4oonfvnvS6jjq\nMrR4h4AfvnaE8nNt/OATM4iJtFsdR4WA5dOzuG1qBv/12lFO1LdaHUddwkCWQfutiNSJyAF/BlID\ns+1oPU9uP8mq+WOYPy7F6jgqRIgI37trGjGRdh59di+dTh19EmgG0vL+HbDETznUINS1dPDV5/eR\nnzGCby+fYnUcFWIyEqJ57BMzOFDZzGObjlgdR/XhdfE2xmwDzvkxixoAt9vwtef309rp5Gf3zyY6\nQrtLlO8tKsjkM9eM4cntJ3mzVCfvBBLt8w5Sa945wTvHzvKvKwp0Crzyq28tm8KUrAS+/kIxNU0d\nVsdRHj4t3iKyWkSKRKSovr7elw+tetl2tJ7HNpWyfEYW91012uo4KsRFR9j57/tn09Ht4vN/KNLZ\nlwHCp8XbGLPGGFNojClMS0vz5UMrj7K6Vh750x7yMxN47BMz9DzdaliMT4vnx/fOYn9FE994sVhP\nXhUAtNskiDS2dfHwU7uIctj4zafnEhflsDqSCiOLCzL5h8X5vLK/ip/r9HnLDWSo4DPA+0C+iFSI\nyOf8F0v11el08bd/3ENVYwe/XjWXUUmxVkdSYeiLC8Zz9+wcfvjaUdYXV1kdJ6x53XQzxqz0ZxB1\ned0uN1/6017eP9HAj++dydwxyVZHUmFKRPj+PdOpON/GV57bR1ykg5snp1sdKyxpt0mAc3mGBL5+\nqJbv3lnA3bNHWR1JhbnoCDtP/vVV5GeO4At/3M17x89aHSksafEOYG634f+sLeGV/VV8c+lkPn1N\nntWRlAIgITqC3z90NbnJsTz8VBG7T+sUkOGmxTtAdTpdPPrsXp4rOsOjCyfwhZvGWx1JqY9Ijovk\n6YevJn1EFA8+sZOtR+qsjhRWtHgHoNZOJ5/7XRHri6v51tLJfHVRvtWRlLqk9IRoXvjCtYxNjeNv\nniriz3srrY4UNrR4B5jqpnZWrvmA90808MNPzeTz2uJWAS5tRBTPfn4+hXlJfPm5ffzq7eM6DnwY\naPEOIO8fb2DFz7Zzor6VNavm8sm5enBSBYeE6Ah+99l53D4jix9sLOWRP+2htdNpdayQprM8AoDb\nbXhi+wn+Y9MR8lJieXb1NUxIj7c6llIDEh1h52crZzNjVCI/2FjKkZoWfr1qLhPS9dw7/qAtb4ud\nOdfGA0/s4N83lLJoagb/+6XrtXCroCUirL5xPH98+Goa27pZ/tPt/GbbCVxu7UbxNS3eFnG7DX/4\n4DSLH99GSWUT379nOr94YA7xOuVdhYBrx6ey8cs3cOOkNP5tw2H+6tfvc1xX5PEp8deBhcLCQlNU\nVOSXxw52O0408L1XD1NS2cT1E1L5j0/OIGdkjNWxlPI5Ywx/3lfJ/33lEG1dTj59TR6PLpxIYmyE\n1dEClojsNsYU9redNvOG0dHaFn702lE2HawhKzGaH987k7tm5eiZAVXIEhHunj2K6yek8V+vHeG3\n755k7Z4K/m7hRFbOy9U1V4dAW97DYP+ZRn6+tYzXDtUSG2nnCzeN529uGKdvXBV2DlU1871XD/He\n8QZS4iJ56PqxrLpmDAnR2hK/yNuWtxZvP2nrcrK+uJpndpazt7yRxJgI/vraPP762jyS4iKtjqeU\npXaePMeBHCOWAAAJlUlEQVTPt5bx9tF64iLt3DErm5Xzcpmekxj230S1eFug0+ni3bKzbCipYfOB\nGlo6nYxPi+P+q8dw71Wj9WCkUn0cqGziqfdOsa64io5uN1OyErh9RhZLp2UyLi08R11p8R4mZ861\nsb3sL
NuPnWXbsXpaOpyMiHawaGom980bTeGYpLBvSSjVn+aObv53XxVr91Swt7wRgPyMEdw4KZXr\nJ6YxLy85bLoZfV68RWQJ8BPADjxhjPnBlbYPxeLd1N7NkZoWSiqb2Ft+nr3ljVQ2tgOQmRDNjZNS\nWTo9i+vGpxLp0FGYSg1GdVM7mw7U8NrBWnafPk+Xy02k3cbU7ARm545kdm4SU7MSyEuJxWEPvc+Z\nT4u3iNiBo8BtQAWwC1hpjDl0ufsEY/E2xtDc4aSqsZ3qpnZON7Rx6uwFTjW0UVbX+mGhBsgZGcOs\n3JEUjknihompjE+L1xa2Uj7W3uVix8kG3j/RwN7TjRRXNtLR7QYg0mFjYno849LiGZsSy5iUOEYl\nxZA9MoaMhOigbUD5eqjgPKDMGHPC8+DPAncCly3ew8kYQ5fLTZfTTbfL0NHtotPppqPbRVuXi/Yu\nF21dTlo7e35aOpw0tXdz/kIXje3dNLR2Ut/aydmWLtr7rIwdF2knLzWOuWOSeGB+LlMyE5ianUBG\nQrRF/1ulwkdMpJ0F+eksyO9Zrafb5eZobQul1S2U1jRTWtPCvjPnebW4it6TOEUgOTaS1PgoUkdE\nkhTb8zMyNoKE6Ajiox3ER/X8xETaiY20ExNhJzrCTpTDRpTDToRDiLTbsNskIBtm3hbvHOBMr98r\ngKt9Hwf+7dVDvFlahzHgNgaXMbjdnstug9sYnG6D02Xodrlxus2gpt5GR9gYGdPzYqbERzInN4nU\n+CgyE6LJHhlD1shoRifFkhofGZAvnFLhKMJuoyA7kYLsxI9c3+l0UXG+vedbc2MHlY3tngZZJ2db\nO6lqbKaxrYum9m4GWi5EIMLWU8QddsFhE+w2Gw6bYBOw2QSbiKfIg02En62czZSsBB/+zz/Op8Mf\nRGQ1sBogNzd3UI+RlRjD5KwEbOLZMSIfXrbbenbQxZ0XYe/ZmXabjSiHjUh7z3XRvf6C9vxVdRAb\naSc+ysGIaAdxUQ6iI8Lj4IdS4SDKYWd8Wjzj+xmh4nYb2rpdtHR009Lh5EKn0/PN3EWH00VHt/vD\nb+7dH36b7/lG73Jf/LenUel0uXGbnsd0GfNhg9MYiBmG+uJt8a4ERvf6fZTnuo8wxqwB1kBPn/dg\nAj10/VgeYuxg7qqUUldks8mH3SVZif1vH8i87dHfBUwUkbEiEgncB7ziv1hKKaWuxKuWtzHGKSJf\nAjbTM1Twt8aYg35NppRS6rK87vM2xmwANvgxi1JKKS8F50BIpZQKc1q8lVIqCGnxVkqpIKTFWyml\ngpAWb6WUCkJ+OyWsiLQAR/zy4EOXCpy1OsRlaLbBCdRsgZoLNNtg+TvbGGNMWn8b+XN1gCPenBnL\nCiJSpNkGTrMNXKDmAs02WIGSTbtNlFIqCGnxVkqpIOTP4r3Gj489VJptcDTbwAVqLtBsgxUQ2fx2\nwFIppZT/aLeJUkoFoSEVbxFJFpHXReSY59+ky2y3SUQaRWR9n+vHisgOESkTkec8p5v1iQFk+4xn\nm2Mi8ple178lIkdEZJ/nJ90HmZZ4HrNMRL55idujPPuhzLNf8nrd9i3P9UdEZPFQs/gil4jkiUh7\nr330K1/m8jLbjSKyR0ScIvLJPrdd8rUNkGyuXvvN56dX9iLbV0XkkIgUi8gbIjKm121+229DzGX1\nPvuCiJR4nn+7iEztdZvfPp+XZYwZ9A/wGPBNz+VvAv9xme1uAVYA6/tc/zxwn+fyr4C/HUqegWYD\nkoETnn+TPJeTPLe9BRT6MI8dOA6MAyKB/cDUPtt8EfiV5/J9wHOey1M920cBYz2PYw+AXHnAAV/t\no0FmywNmAL8HPunNa2t1Ns9trRbvt5uBWM/lv+31mvptvw0lV4Dss4Rel+8ANnku++3zeaWfoXab\n3Ak85bn8FHDXpTYyxrwBtPS+TkQEWAi82N/9/ZhtMfC6MeacMeY88Dq
wxIcZevtwEWdjTBdwcRHn\ny2V+EbjFs5/uBJ41xnQaY04CZZ7HszqXv/WbzRhzyhhTDLj73Nffr+1QsvmbN9m2GmPaPL9+QM/q\nWODf/TaUXP7mTbbmXr/GARcPGPrz83lZQy3eGcaYas/lGiBjAPdNARqNMU7P7xX0LHTsK95ku9TC\nyr0z/I/nK9I/+6BY9fdcH9nGs1+a6NlP3tzXilwAY0Vkr4i8LSI3+CjTQLL5477D8fjRIlIkIh+I\niC8bLTDwbJ8DNg7yvsOVCwJgn4nIIyJynJ5v9o8O5L6+1u8MSxHZAmRe4qZv9/7FGGNEZFiHrvg5\n2wPGmEoRGQG8BKyi5+uv+otqINcY0yAic4E/i0hBnxaKurQxnvfXOOBNESkxxhwf7hAi8iBQCNw0\n3M99JZfJZfk+M8b8HPi5iNwP/BPg82Mp3uq3eBtjbr3cbSJSKyJZxphqEckC6gbw3A3ASBFxeFpz\nl1zU2M/ZKoEFvX4fRU9fN8aYSs+/LSLyJ3q+Bg2leHuziPPFbSpExAEk0rOfvFoAerhzmZ4Ov04A\nY8xuT4tkElA0jNmudN8Ffe77lk9S/eXxB/2a9Hp/nRCRt4DZ9PSVDls2EbmVnobOTcaYzl73XdDn\nvm8FQK6A2Ge9PAv8cpD39Y0hdvL/Jx89KPjYFbZdwMcPWL7ARw9YftFXnfneZKPnoMxJeg7MJHku\nJ9PzRy3Vs00EPf28XxhiHgc9B3/G8pcDIgV9tnmEjx4YfN5zuYCPHhA5ge8OWA4lV9rFHPQc6KkE\nkn34Gvabrde2v+PjByw/9toGSLYkIMpzORU4Rp+DY8Pwml4sfBO9+UwEQK5A2GcTe11eARR5Lvvt\n83nFzEP8D6cAb3h25JaLLzI9X3ee6LXdO0A90E5Pf9Biz/XjgJ30dPC/cPHF8dGL4W22hzzPXwZ8\n1nNdHLAbKAYOAj/xxYsBLAOOet6c3/Zc913gDs/laM9+KPPsl3G97vttz/2OAEt9+iYYZC7gE579\nsw/YA6zw+Ru0/2xXed5TF+j5lnLwSq9tIGQDrgVKPB/4EuBzFmTbAtR6Xrt9wCvDsd8GmytA9tlP\ner3ft9KruPvz83m5H51hqZRSQUhnWCqlVBDS4q2UUkFIi7dSSgUhLd5KKRWEtHgrpVQQ0uKtlFJB\nSIu3UkoFIS3eSikVhP4/6wFSm8ylAfcAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def show_weight_dist(mean, variance):\n", " sigma = nd.sqrt(variance)\n", " x = np.linspace(mean.asscalar() - 4*sigma.asscalar(), mean.asscalar() + 4*sigma.asscalar(), 100)\n", " plt.plot(x, gaussian(nd.array(x, ctx=ctx), mean, sigma).asnumpy())\n", " plt.show()\n", " \n", "mu = mus[0][0][0]\n", "var = softplus(rhos[0][0][0]) ** 2\n", "\n", "show_weight_dist(mu, var)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great! We have obtained a fully functional Bayesian neural network. However, the number of weights now is twice as high as for traditional neural networks. 
As we will see in the final section of this notebook, we can drastically reduce the number of weights our network uses for prediction with _weight pruning_." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Weight pruning\n", "\n", "To measure the degree of redundancy present in the trained network and to reduce the model's parameter count, we now examine the effect of setting some of the weights to $0$ and evaluating the test accuracy afterwards. We can achieve this by ordering the weights according to their signal-to-noise ratio, $\\frac{|\\mu_i|}{\\sigma_i}$, and setting a certain percentage of the weights with the lowest ratios to $0$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can calculate the signal-to-noise ratio as follows:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def signal_to_noise_ratio(mus, sigmas):\n", "    # one |mu| / sigma array per parameter array in the list\n", "    sign_to_noise = []\n", "    for j in range(len(mus)):\n", "        sign_to_noise.append(nd.abs(mus[j]) / sigmas[j])\n", "    return sign_to_noise" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We further introduce a few helper functions which turn our list of weights into a single vector containing all weights. This makes the subsequent sorting and pruning easier." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def vectorize_matrices_in_vector(vec):\n", "    # flatten the matrices stored at the even indices in place;\n", "    # the remaining entries are assumed to be flat already\n", "    for i in range(0, (num_layers + 1) * 2, 2):\n", "        if i == 0:\n", "            vec[i] = nd.reshape(vec[i], num_inputs * num_hidden)\n", "        elif i == num_layers * 2:\n", "            vec[i] = nd.reshape(vec[i], num_hidden * num_outputs)\n", "        else:\n", "            vec[i] = nd.reshape(vec[i], num_hidden * num_hidden)\n", "\n", "    return vec\n", "\n", "def concat_vectors_in_vector(vec):\n", "    # join all flat arrays into one long vector\n", "    concat_vec = vec[0]\n", "    for i in range(1, len(vec)):\n", "        concat_vec = nd.concat(concat_vec, vec[i], dim=0)\n", "\n", "    return concat_vec\n", "\n", "def transform_vector_structure(vec):\n", "    vec = vectorize_matrices_in_vector(vec)\n", "    vec = concat_vectors_in_vector(vec)\n", "\n", "    return vec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition, a further helper function transforms the pruned weight vector back into the original layered structure." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from functools import reduce\n", "import operator\n", "\n", "def prod(iterable):\n", "    # product of all entries, e.g. the number of elements for a given shape\n", "    return reduce(operator.mul, iterable, 1)\n", "\n", "def restore_weight_structure(vec):\n", "    pruned_weights = []\n", "\n", "    index = 0\n", "\n", "    # cut the flat vector into consecutive slices and reshape each one\n", "    for shape in layer_param_shapes:\n", "        incr = prod(shape)\n", "        pruned_weights.append(nd.reshape(vec[index : index + incr], shape))\n", "        index += incr\n", "\n", "    return pruned_weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The actual pruning of the vector happens in the following function. Note that this function accepts a list of pruning fractions (values between 0 and 1) so that we can evaluate the performance at different pruning rates. For each fraction, we start from a fresh copy of the means, take the indices of the weights with the lowest signal-to-noise ratios, and set the weights at those indices to $0$." 
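, "\n", "\n", "As a small illustration with made-up numbers (not taken from the trained network), consider three weights with $(\\mu, \\sigma)$ equal to $(0.8, 0.1)$, $(-0.05, 0.2)$ and $(0.3, 0.6)$. Their signal-to-noise ratios are\n", "\n", "$$\\frac{|0.8|}{0.1} = 8, \\qquad \\frac{|-0.05|}{0.2} = 0.25, \\qquad \\frac{|0.3|}{0.6} = 0.5.$$\n", "\n", "Sorting in ascending order puts the second weight first, so pruning one third of these weights would zero out only the second weight, and pruning two thirds would also zero out the third."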
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true, "scrolled": false }, "outputs": [], "source": [ "def prune_weights(sign_to_noise_vec, prediction_vector, percentages):\n", "    # indices of all weights, sorted by ascending signal-to-noise ratio\n", "    pruning_indices = nd.argsort(sign_to_noise_vec, axis=0)\n", "\n", "    for percentage in percentages:\n", "        # start from an unpruned copy for every pruning rate\n", "        pruned_vector = prediction_vector.copy()\n", "        pruning_indices_percent = pruning_indices[0:int(len(pruning_indices)*percentage)]\n", "        for pr_ind in pruning_indices_percent:\n", "            pruned_vector[int(pr_ind.asscalar())] = 0\n", "        pruned_weights = restore_weight_structure(pruned_vector)\n", "        test_accuracy = evaluate_accuracy(test_data, net, pruned_weights)\n", "        print(\"%s --> %s\" % (percentage, test_accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Putting the above functions together:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.1 --> 0.9777\n", "0.25 --> 0.9779\n", "0.5 --> 0.9756\n", "0.75 --> 0.9602\n", "0.95 --> 0.7259\n", "0.99 --> 0.3753\n", "1.0 --> 0.098\n" ] } ], "source": [ "sign_to_noise = signal_to_noise_ratio(mus, sigmas)\n", "sign_to_noise_vec = transform_vector_structure(sign_to_noise)\n", "\n", "mus_copy = mus.copy()\n", "mus_copy_vec = transform_vector_structure(mus_copy)\n", "\n", "prune_weights(sign_to_noise_vec, mus_copy_vec, [0.1, 0.25, 0.5, 0.75, 0.95, 0.99, 1.0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our run above, pruning half of the weights leaves the test accuracy almost unchanged (0.9756, versus 0.9777 when pruning 10%), and pruning 75% still yields 0.9602. Depending on the number of units used in the original network and the number of training epochs, the highest achievable pruning percentage (without significantly reducing the predictive performance) can vary. The paper, for example, reports almost no change in the test accuracy when pruning 95% of the weights in a Bayesian neural network with two hidden layers of 1200 units each. This creates a significantly sparser network, leading to faster predictions and reduced memory requirements." 
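, "\n", "\n", "To get a rough sense of scale for the network in this notebook, assume the layer shapes implied by the configuration: 784 inputs, two hidden layers of 400 units, 10 outputs, and one bias per unit. Under that assumption, the number of mean parameters is\n", "\n", "$$784 \\cdot 400 + 400 \\cdot 400 + 400 \\cdot 10 + (400 + 400 + 10) = 478410,$$\n", "\n", "so pruning 75% of them would leave roughly 120,000 non-zero weights for prediction."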
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "We have taken a look at an efficient Bayesian treatment for neural networks using variational inference via the \"Bayes by Backprop\" algorithm (introduced by the \"[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424)\" paper). We have implemented a stochastic version of the variational lower bound and optimized it in order to find an approximation to the posterior distribution over the weights of an MLP on the MNIST data set. As a result, we achieve regularization of the network's parameters and can quantify our uncertainty about the weights. Finally, we saw that it is possible to significantly reduce the number of weights in the neural network after training while still keeping a high accuracy on the test set.\n", "\n", "We also note that, with this implementation, we were able to reproduce the paper's results on the MNIST data set, achieving comparable test accuracy for all documented instances of the MNIST classification problem." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" } }, "nbformat": 4, "nbformat_minor": 2 }