Overfitting and regularization (with gluon)

Now that we’ve built a regularized logistic regression model from scratch, let’s see how to do the same thing more efficiently with gluon. We recommend that you read the from-scratch section for a description of why regularization is a good idea. As always, we begin by loading libraries and some data.

In [1]:
from __future__ import print_function
import mxnet as mx
from mxnet import autograd
from mxnet import gluon
import mxnet.ndarray as nd
import numpy as np
ctx = mx.cpu()

The MNIST Dataset

In [2]:
mnist = mx.test_utils.get_mnist()
num_examples = 1000
batch_size = 64
train_data = mx.gluon.data.DataLoader(
    mx.gluon.data.ArrayDataset(mnist["train_data"][:num_examples],
                               mnist["train_label"][:num_examples].astype(np.float32)),
                               batch_size, shuffle=True)
test_data = mx.gluon.data.DataLoader(
    mx.gluon.data.ArrayDataset(mnist["test_data"][:num_examples],
                               mnist["test_label"][:num_examples].astype(np.float32)),
                               batch_size, shuffle=False)

Multiclass Logistic Regression

In [3]:
net = gluon.nn.Sequential()
with net.name_scope():
    net.add(gluon.nn.Dense(10))

Parameter initialization

In [4]:
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

Softmax Cross Entropy Loss

In [5]:
loss = gluon.loss.SoftmaxCrossEntropyLoss()

Optimizer

By default, gluon tries to keep the coefficients from diverging by applying a weight decay penalty. So, to get the real overfitting experience we need to switch it off. We do this by passing 'wd': 0.0 when we instantiate the trainer.

In [6]:
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 0.0})

Evaluation Metric

In [7]:
def evaluate_accuracy(data_iterator, net):
    numerator = 0.
    denominator = 0.
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx).reshape((-1,784))
        label = label.as_in_context(ctx)
        output = net(data)
        # Take the highest-scoring class as the prediction and count the matches.
        predictions = nd.argmax(output, axis=1)
        numerator += nd.sum(predictions == label)
        denominator += data.shape[0]
    return (numerator / denominator).asscalar()
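
As a quick sanity check, we can evaluate the freshly initialized network before training; with random weights we would expect roughly chance-level accuracy (about 0.1 for ten classes). A minimal sketch:

# Sanity check (illustrative): an untrained model should sit near chance level.
print(evaluate_accuracy(test_data, net))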

Execute training loop

In [8]:
epochs = 700
moving_loss = 0.

for e in range(epochs):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx).reshape((-1,784))
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            cross_entropy = loss(output, label)
        cross_entropy.backward()
        trainer.step(data.shape[0])

        ##########################
        #  Keep a moving average of the losses
        ##########################
        if i == 0:
            moving_loss = nd.mean(cross_entropy).asscalar()
        else:
            moving_loss = .99 * moving_loss + .01 * nd.mean(cross_entropy).asscalar()

    test_accuracy = evaluate_accuracy(test_data, net)
    train_accuracy = evaluate_accuracy(train_data, net)
    if e % 20 == 0:
        print("Completed epoch %s. Loss: %s, Train_acc %s, Test_acc %s" %
              (e, moving_loss, train_accuracy, test_accuracy))
Completed epoch 0. Loss: 2.41042828249, Train_acc 0.183, Test_acc 0.163
Completed epoch 20. Loss: 0.980734459237, Train_acc 0.823, Test_acc 0.734
Completed epoch 40. Loss: 0.656826479353, Train_acc 0.863, Test_acc 0.762
Completed epoch 60. Loss: 0.547160148674, Train_acc 0.879, Test_acc 0.776
Completed epoch 80. Loss: 0.410905732785, Train_acc 0.894, Test_acc 0.8
Completed epoch 100. Loss: 0.5042345608, Train_acc 0.902, Test_acc 0.813
Completed epoch 120. Loss: 0.393295650603, Train_acc 0.908, Test_acc 0.817
Completed epoch 140. Loss: 0.31935924206, Train_acc 0.915, Test_acc 0.819
Completed epoch 160. Loss: 0.353683424529, Train_acc 0.922, Test_acc 0.826
Completed epoch 180. Loss: 0.441564063282, Train_acc 0.925, Test_acc 0.826
Completed epoch 200. Loss: 0.439105920901, Train_acc 0.928, Test_acc 0.827
Completed epoch 220. Loss: 0.335929108585, Train_acc 0.93, Test_acc 0.831
Completed epoch 240. Loss: 0.341114005655, Train_acc 0.933, Test_acc 0.833
Completed epoch 260. Loss: 0.265167222471, Train_acc 0.935, Test_acc 0.835
Completed epoch 280. Loss: 0.225616094168, Train_acc 0.937, Test_acc 0.836
Completed epoch 300. Loss: 0.31287811555, Train_acc 0.939, Test_acc 0.836
Completed epoch 320. Loss: 0.201846450408, Train_acc 0.941, Test_acc 0.836
Completed epoch 340. Loss: 0.26857858658, Train_acc 0.946, Test_acc 0.838
Completed epoch 360. Loss: 0.250789181783, Train_acc 0.949, Test_acc 0.839
Completed epoch 380. Loss: 0.292163310882, Train_acc 0.949, Test_acc 0.839
Completed epoch 400. Loss: 0.364868967968, Train_acc 0.95, Test_acc 0.842
Completed epoch 420. Loss: 0.16974889173, Train_acc 0.952, Test_acc 0.843
Completed epoch 440. Loss: 0.20597298219, Train_acc 0.954, Test_acc 0.845
Completed epoch 460. Loss: 0.135431291838, Train_acc 0.955, Test_acc 0.848
Completed epoch 480. Loss: 0.192532801942, Train_acc 0.96, Test_acc 0.849
Completed epoch 500. Loss: 0.139670844177, Train_acc 0.961, Test_acc 0.848
Completed epoch 520. Loss: 0.188918106269, Train_acc 0.962, Test_acc 0.85
Completed epoch 540. Loss: 0.182555058185, Train_acc 0.964, Test_acc 0.848
Completed epoch 560. Loss: 0.149213267695, Train_acc 0.965, Test_acc 0.847
Completed epoch 580. Loss: 0.215244527124, Train_acc 0.965, Test_acc 0.844
Completed epoch 600. Loss: 0.220504981679, Train_acc 0.966, Test_acc 0.842
Completed epoch 620. Loss: 0.225874061765, Train_acc 0.966, Test_acc 0.841
Completed epoch 640. Loss: 0.105731161028, Train_acc 0.967, Test_acc 0.841
Completed epoch 660. Loss: 0.229186017501, Train_acc 0.968, Test_acc 0.841
Completed epoch 680. Loss: 0.195402703398, Train_acc 0.969, Test_acc 0.842

Regularization

Now let’s see what this mysterious weight decay is all about. We begin with a bit of math. When we add an L2 penalty to the weights we are effectively adding \(\frac{\lambda}{2} \|w\|^2\) to the loss. Hence, every time we compute the gradient, an additional \(\lambda w\) term is added to \(g_t\), since that is precisely the derivative of the L2 penalty. As a result, we take a descent step not in the direction \(-\eta g_t\) but rather in the direction \(-\eta (g_t + \lambda w)\). This shrinks \(w\) towards zero at each step by \(\eta \lambda w\), hence the name weight decay. To make this work in practice we just need to set the weight decay to something nonzero.
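
To make the connection between the formula and the 'wd' option concrete, here is a minimal sketch (not a cell from the original notebook) of what a single SGD step with weight decay does to each parameter. It assumes a forward and backward pass has already been recorded, and it ignores the gradient rescaling that trainer.step performs:

# Illustrative only: one manual SGD step with weight decay, w <- w - lr * (g + wd * w).
# This mirrors what the trainer does when 'wd' is nonzero (up to gradient rescaling).
lr, wd = 0.01, 0.001
for param in net.collect_params().values():
    g = param.grad()                                   # gradient from the last backward pass
    param.set_data(param.data() - lr * (g + wd * param.data()))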

In [9]:
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 0.001})

moving_loss = 0.

for e in range(epochs):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx).reshape((-1,784))
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            cross_entropy = loss(output, label)
        cross_entropy.backward()
        trainer.step(data.shape[0])

        ##########################
        #  Keep a moving average of the losses
        ##########################
        if i == 0:
            moving_loss = nd.mean(cross_entropy).asscalar()
        else:
            moving_loss = .99 * moving_loss + .01 * nd.mean(cross_entropy).asscalar()

    test_accuracy = evaluate_accuracy(test_data, net)
    train_accuracy = evaluate_accuracy(train_data, net)
    if e % 20 == 0:
        print("Completed epoch %s. Loss: %s, Train_acc %s, Test_acc %s" %
              (e, moving_loss, train_accuracy, test_accuracy))
Completed epoch 0. Loss: 2.30056658499, Train_acc 0.255, Test_acc 0.207
Completed epoch 20. Loss: 0.953274930743, Train_acc 0.832, Test_acc 0.733
Completed epoch 40. Loss: 0.773029556608, Train_acc 0.861, Test_acc 0.763
Completed epoch 60. Loss: 0.643435905558, Train_acc 0.886, Test_acc 0.782
Completed epoch 80. Loss: 0.573625952417, Train_acc 0.892, Test_acc 0.794
Completed epoch 100. Loss: 0.471484044873, Train_acc 0.9, Test_acc 0.804
Completed epoch 120. Loss: 0.396960894484, Train_acc 0.909, Test_acc 0.814
Completed epoch 140. Loss: 0.409119025299, Train_acc 0.916, Test_acc 0.821
Completed epoch 160. Loss: 0.34956342471, Train_acc 0.918, Test_acc 0.826
Completed epoch 180. Loss: 0.364122123705, Train_acc 0.923, Test_acc 0.83
Completed epoch 200. Loss: 0.23629064181, Train_acc 0.927, Test_acc 0.83
Completed epoch 220. Loss: 0.35734446598, Train_acc 0.929, Test_acc 0.832
Completed epoch 240. Loss: 0.272298600617, Train_acc 0.931, Test_acc 0.833
Completed epoch 260. Loss: 0.362590129895, Train_acc 0.936, Test_acc 0.832
Completed epoch 280. Loss: 0.27461734356, Train_acc 0.938, Test_acc 0.835
Completed epoch 300. Loss: 0.253035167053, Train_acc 0.941, Test_acc 0.835
Completed epoch 320. Loss: 0.257027310991, Train_acc 0.942, Test_acc 0.836
Completed epoch 340. Loss: 0.359865081517, Train_acc 0.944, Test_acc 0.836
Completed epoch 360. Loss: 0.275248116974, Train_acc 0.945, Test_acc 0.839
Completed epoch 380. Loss: 0.357199658944, Train_acc 0.953, Test_acc 0.841
Completed epoch 400. Loss: 0.175214320661, Train_acc 0.955, Test_acc 0.842
Completed epoch 420. Loss: 0.270390815553, Train_acc 0.956, Test_acc 0.841
Completed epoch 440. Loss: 0.199450716782, Train_acc 0.956, Test_acc 0.842
Completed epoch 460. Loss: 0.151572242711, Train_acc 0.957, Test_acc 0.842
Completed epoch 480. Loss: 0.225918149547, Train_acc 0.96, Test_acc 0.842
Completed epoch 500. Loss: 0.309496395041, Train_acc 0.96, Test_acc 0.841
Completed epoch 520. Loss: 0.194795705577, Train_acc 0.961, Test_acc 0.841
Completed epoch 540. Loss: 0.211884802227, Train_acc 0.962, Test_acc 0.844
Completed epoch 560. Loss: 0.163623161352, Train_acc 0.962, Test_acc 0.845
Completed epoch 580. Loss: 0.1975672145, Train_acc 0.964, Test_acc 0.845
Completed epoch 600. Loss: 0.136671575637, Train_acc 0.966, Test_acc 0.845
Completed epoch 620. Loss: 0.267496398274, Train_acc 0.966, Test_acc 0.845
Completed epoch 640. Loss: 0.185523537505, Train_acc 0.966, Test_acc 0.846
Completed epoch 660. Loss: 0.154396287643, Train_acc 0.966, Test_acc 0.848
Completed epoch 680. Loss: 0.183919263789, Train_acc 0.966, Test_acc 0.848

As we can see, the test accuracy improves a bit. Note that the size of the improvement depends on the amount of weight decay. We recommend experimenting with different magnitudes of weight decay; a sketch for such a sweep follows below. For instance, a larger weight decay (e.g. \(0.01\)) will lead to inferior performance, and one that’s larger still (\(0.1\)) will lead to terrible results. This is one of the reasons why hyperparameter tuning is so important for getting good experimental results in practice.
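
One simple way to run that experiment is to sweep over a few candidate values, re-initializing the parameters each time. The loop below is only a sketch: the candidate values are ours, and the body of the training loop is the same as the one above:

# Illustrative sweep over weight-decay strengths (re-run the training loop for each value).
for wd in [0.0, 0.001, 0.01, 0.1]:
    net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': wd})
    # ... run the training loop from above ...
    print("wd = %s, test accuracy = %s" % (wd, evaluate_accuracy(test_data, net)))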