Bright Wire

Training a vanilla feed forward Neural Network on images of handwritten digits.

Motivation

The MNIST data set is a classic handwritten digit recognition data set. This tutorial shows how you can use Bright Wire to train a vanilla feed forward neural network to get a classification accuracy of 98%.

Getting Started

You'll need to download the four data files and unzip them to a directory on your computer.

Create a new .NET 4.6 console application and include Bright Wire.

If you have an NVIDIA GPU then you can also install the CUDA version to speed up training.

Loading the Data

Bright Wire includes a helper class for loading the MNIST data, which consists of two sets of labelled image data (one for test and the other for training).

Each image is stored as 784 pixels with an intensity ranging from 0 to 255 for each value. The Mnist helper class creates vectors of length 784 and scales each value between 0 and 1.

using (var lap = BrightWireGpuProvider.CreateLinearAlgebra()) {
    var graph = new GraphFactory(lap);

    Console.Write("Loading training data...");
    var trainingData = _BuildVectors(null, graph, Mnist.Load(dataFilesPath + "train-labels.idx1-ubyte", dataFilesPath + "train-images.idx3-ubyte"));
    var testData = _BuildVectors(trainingData, graph, Mnist.Load(dataFilesPath + "t10k-labels.idx1-ubyte", dataFilesPath + "t10k-images.idx3-ubyte"));
    Console.WriteLine($"done - {trainingData.RowCount} training images and {testData.RowCount} test images loaded");
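
The scaling step described above can be sketched in a few lines of plain C#. This is only an illustration of mapping 0 to 255 intensities onto the range 0 to 1, not the Mnist helper's actual implementation; the PixelScaling class and its Scale method are hypothetical names, and the four pixel input is made up.

```csharp
using System;
using System.Linq;

static class PixelScaling
{
    // Hypothetical helper (not part of Bright Wire):
    // maps raw 0-255 intensities to floats between 0 and 1.
    public static float[] Scale(byte[] pixels)
    {
        return pixels.Select(p => p / 255f).ToArray();
    }

    static void Main()
    {
        // A made-up four pixel "image"; a real MNIST image has 784 pixels.
        var raw = new byte[] { 0, 51, 102, 255 };
        var scaled = Scale(raw);

        // The smallest intensity maps to 0 and the largest to 1.
        Console.WriteLine(string.Join(" ", scaled));
    }
}
```

Dividing by 255 keeps every feature on the same scale, which generally makes gradient descent better behaved than feeding in raw byte values.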

Network Design

The input layer is a feature vector of length 784 (one feature for each pixel) and the output is a one hot encoded vector of length 10, with one position for each digit from 0 to 9. For example, if the training sample is an image of a three then the output vector will be 0001000000. That leaves only the dimensions of the layers in the middle to choose.
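
As a rough sketch of one hot encoding, assuming the digit's value is used directly as the zero-based index into the vector, the target for a label could be built as follows. The OneHot class and its Encode method are hypothetical names for illustration and are not part of Bright Wire.

```csharp
using System;

static class OneHot
{
    // Hypothetical helper: returns a vector of length classCount
    // with a 1 at the label's index and 0 everywhere else.
    public static float[] Encode(int label, int classCount)
    {
        if (label < 0 || label >= classCount)
            throw new ArgumentOutOfRangeException(nameof(label));

        var vector = new float[classCount];
        vector[label] = 1f;
        return vector;
    }

    static void Main()
    {
        // The digit three, with ten classes (the digits 0 to 9).
        Console.WriteLine(string.Join("", Encode(3, 10))); // prints 0001000000
    }
}
```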

A single hidden layer of size 1024 produces reasonable results, so the network ends up as 784x1024x10 with a dropout layer between the hidden layer and the output layer. The training rate is reduced to a third of its initial value after 15 epochs.
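
To give a feel for the size of a 784x1024x10 network, the short sketch below counts its trainable parameters. It assumes each feed forward layer holds a full weight matrix plus one bias per output unit, which is the usual layout but is an assumption about Bright Wire here; ParameterCount and Dense are hypothetical names.

```csharp
using System;

static class ParameterCount
{
    // Weights plus biases for one fully connected layer
    // (assumes one bias per output unit).
    public static int Dense(int inputSize, int outputSize)
    {
        return inputSize * outputSize + outputSize;
    }

    static void Main()
    {
        int hidden = Dense(784, 1024); // 784 * 1024 + 1024 = 803840
        int output = Dense(1024, 10);  // 1024 * 10 + 10 = 10250

        Console.WriteLine(hidden + output); // prints 814090
    }
}
```

Almost all of the parameters sit in the first layer, so the hidden layer size is the main lever on both capacity and training time.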

// one hot encoding uses the index of the output vector's maximum value as the classification label
var errorMetric = graph.ErrorMetric.OneHotEncoding;

// configure the network properties
graph.CurrentPropertySet
    .Use(graph.GradientDescent.RmsProp)
    .Use(graph.WeightInitialisation.Xavier)
;

// create the training engine and schedule a training rate change
const float TRAINING_RATE = 0.1f;
var engine = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 128);
engine.LearningContext.ScheduleLearningRate(15, TRAINING_RATE / 3);

// create the network
graph.Connect(engine)
    .AddFeedForward(outputSize: 1024)
    .Add(graph.LeakyReluActivation())
    .AddDropOut(dropOutPercentage: 0.5f)
    .AddFeedForward(outputSize: trainingData.OutputSize)
    .Add(graph.SigmoidActivation())
    .AddBackpropagation(errorMetric)
;

// train the network for twenty epochs, saving the model on each improvement
Models.ExecutionGraph bestGraph = null;
engine.Train(20, testData, errorMetric, model => bestGraph = model.Graph);

// export the best model (or the final one if none was saved) and execute it on the test set
var executionEngine = graph.CreateEngine(bestGraph ?? engine.Graph);
var output = executionEngine.Execute(testData);
Console.WriteLine($"Final accuracy: {output.Average(o => o.CalculateError(errorMetric)):P2}");
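
The one hot error metric used above treats the index of the largest output value as the predicted label. A minimal standalone sketch of that idea follows, assuming accuracy is simply the fraction of samples whose predicted index matches the target index. OneHotAccuracy, ArgMax and Accuracy are hypothetical names, the sample vectors are made up, and this is not Bright Wire's own implementation.

```csharp
using System;
using System.Linq;

static class OneHotAccuracy
{
    // Index of the largest value in the vector.
    public static int ArgMax(float[] v)
    {
        int best = 0;
        for (int i = 1; i < v.Length; i++)
            if (v[i] > v[best]) best = i;
        return best;
    }

    // Fraction of samples whose predicted index matches the target index.
    public static double Accuracy(float[][] outputs, float[][] targets)
    {
        return outputs
            .Zip(targets, (o, t) => ArgMax(o) == ArgMax(t) ? 1.0 : 0.0)
            .Average();
    }

    static void Main()
    {
        var outputs = new[] {
            new[] { 0.1f, 0.8f, 0.1f }, // predicts class 1
            new[] { 0.6f, 0.3f, 0.1f }, // predicts class 0
        };
        var targets = new[] {
            new[] { 0f, 1f, 0f },       // class 1 (matches)
            new[] { 0f, 0f, 1f },       // class 2 (does not match)
        };

        // One of the two predictions is correct, so the accuracy is one half.
        Console.WriteLine(Accuracy(outputs, targets));
    }
}
```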

Results

The classifier reaches a final accuracy of 98.4%.