Bright Wire

More complicated sequences call for more complicated neural networks. This tutorial shows how to use a GRU recurrent neural network to learn the Embedded Reber Grammar.

Motivation

It's possible to get a lot done with simple recurrent neural networks (SRNN). But they do have a few limitations.

First, they're fiddly to train. On a good day they work well and train up quickly. On a bad day, they blow up or fail to learn at all.

Second, they're not really able to learn long-term dependencies. They have a reasonable short-term memory, but little capacity to remember something that happened many steps ago and is relevant to what's happening now.

Reber Sequences

An example of a problem that is beyond an SRNN is learning an Embedded Reber Grammar.

Extended Reber Grammar

The Reber Grammar (contained in each of the two boxes above) is a simple finite state grammar that generates random strings like "BTSSXXTVVE" or "BPVPXVPXVPXVVE". The state transitions give the valid next characters (a T or a P is equally likely after a B, but after BT you can only get an X or any number of S characters, and so on).

To make things harder, the Embedded Reber Grammar extends the Reber Grammar with a longer-term state transition: a sequence that begins with BT must end with TE (while including any number of Reber Grammar transitions in the meantime). For example, "BTBTSSXXTVVETE" wraps the Reber string "BTSSXXTVVE" between a leading BT and a closing TE.

The job of our recurrent neural network is to predict the set of possible next state transitions after having observed any number of previous states.

It turns out that an SRNN is able to learn the Reber Grammar state transitions fairly well. However, its memory is so short-term that it can't ensure that a sequence starting with BT ends with TE.

Gated Recurrent Units

Gated Recurrent Units (GRU) are a variation on Long Short Term Memory (LSTM) recurrent neural networks. Both LSTM and GRU networks have additional parameters that control when and how their memory is updated.

Both GRU and LSTM networks can capture both long- and short-term dependencies in sequences, but GRU networks have fewer parameters and so are faster to train.

Conceptually, a GRU network has a reset gate and an update gate that help ensure its memory doesn't get taken over by tracking short-term dependencies. The network learns how to use its gates to protect its memory so that it's able to make longer-term predictions.

LSTM Gating. Chung, Junyoung, et al. “Empirical evaluation of gated recurrent neural networks on sequence modeling.” (2014)

The GRU is implemented as:

GRU Implementation
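In equation form, the standard GRU (Chung et al., 2014) computes the following, where x_t is the current input, h_{t-1} the previous memory, σ the logistic sigmoid and ⊙ element-wise multiplication (biases are omitted for brevity):

\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1}) && \text{(update gate)} \\
r_t &= \sigma(W_r x_t + U_r h_{t-1}) && \text{(reset gate)} \\
\tilde{h}_t &= \tanh(W x_t + U(r_t \odot h_{t-1})) && \text{(candidate state)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{(new memory)}
\end{aligned}

The update gate decides how much of the old memory to keep, while the reset gate decides how much of it is used when forming the candidate state.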

Since Bright Wire is implemented in terms of a graph of nodes and "wires", let's see how those equations look in graph form:

GRU Graph based implementation

As with an SRNN, the memory buffer is updated with the layer's output at each step in the sequence, and this saved output then flows into the next item in the sequence.
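As a rough mental model of that data flow, here is a toy sketch of a recurrent loop. The names and the "layer" itself are made up for illustration; this is not Bright Wire's node API.

using System;
using System.Linq;

// Toy illustration of recurrence - NOT Bright Wire's implementation.
// The point is the data flow: the hidden state produced at one step is
// the "memory" that is read back in at the next step.
static class RecurrenceSketch
{
    static void Main()
    {
        const int hiddenSize = 4;
        var hidden = new float[hiddenSize];          // memory buffer, starts at zero
        var inputs = new[] { 0.1f, 0.7f, 0.3f };     // toy one dimensional sequence

        foreach (var input in inputs) {
            // toy "layer": mixes the current input with the previous hidden state
            hidden = hidden.Select(h => (float)Math.Tanh(0.5f * h + 0.5f * input)).ToArray();

            // the saved output flows into the next step of the loop
            Console.WriteLine(string.Join(", ", hidden));
        }
    }
}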

Generating Sequences

Bright Wire includes a helper class to generate Reber Grammar and Embedded Reber Grammar sequences. For this example, we generate 500 extended (embedded) sequences of length 10 characters and split them into training and test sets.

The training data contains the set of possible next state transitions (one-hot encoded) at each point in each sequence. So for the first item in each sequence ("B"), the target marks both "T" and "P" as valid next characters, and so on.

// generate 500 extended reber grammar training examples
var grammar = new ReberGrammar(stochastic: false);
var sequences = grammar.GetExtended(10).Take(500).ToList();

// split the data into training and test sets
var data = ReberGrammar.GetOneHot(sequences).Split(0);

Next, we connect a GRU and Feed Forward layer with Tanh activation and train the neural network with RMSProp gradient descent, a learning rate of 0.1 and a batch size of 32 for 30 epochs. The GRU's memory is a buffer of 64 floating point numbers. The network approaches 100% accuracy after about 20 epochs.

(Note that the BinaryClassification error metric rounds each value in each output vector to either 1 or 0, which is how we're able to get near 100% accuracy so easily.)

using (var lap = BrightWireProvider.CreateLinearAlgebra(stochastic: false)) {
    var graph = new GraphFactory(lap);

    // binary classification rounds each output to either 0 or 1
    var errorMetric = graph.ErrorMetric.BinaryClassification;

    // configure the network properties
    graph.CurrentPropertySet
        .Use(graph.GradientDescent.RmsProp)
        .Use(graph.WeightInitialisation.Xavier)
    ;

    // create the engine
    var trainingData = graph.CreateDataSource(data.Training);
    var testData = trainingData.CloneWith(data.Test);
    var engine = graph.CreateTrainingEngine(trainingData, learningRate: 0.1f, batchSize: 32);

    // build the network
    const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 30;
    var memory = new float[HIDDEN_LAYER_SIZE];
    var network = graph.Connect(engine)
        .AddGru(memory)
        .AddFeedForward(engine.DataSource.OutputSize)
        .Add(graph.TanhActivation())
        .AddBackpropagationThroughTime(errorMetric)
    ;

    engine.Train(TRAINING_ITERATIONS, testData, errorMetric);
}
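Because the metric scores each rounded output value independently, a near-perfect score is easier to reach than it might sound. Here is a minimal sketch of a rounding-based binary classification metric; it's an illustration of the idea only, not Bright Wire's internal implementation.

using System;

// Illustration only - not Bright Wire's internal code. A rounding-based
// binary classification metric counts an output value as correct if,
// once rounded to 0 or 1, it matches the corresponding target value.
static class BinaryClassificationSketch
{
    static float Score(float[] output, float[] target)
    {
        var correct = 0;
        for (var i = 0; i < output.Length; i++) {
            var rounded = output[i] >= 0.5f ? 1f : 0f;
            if (rounded == target[i])
                correct++;
        }
        return (float)correct / output.Length;
    }

    static void Main()
    {
        // the target marks two transitions as valid; the output only needs to be
        // on the right side of the threshold at each position to score perfectly
        var target = new float[] { 0, 1, 1, 0 };
        var output = new float[] { 0.1f, 0.8f, 0.6f, 0.3f };
        Console.WriteLine(Score(output, target));    // prints 1
    }
}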

Output

The final accuracy reaches 99.79% after 30 epochs.

Reber output

Complete Source Code

View the complete source on GitHub.
