Bright Wire

More complicated sequences call for more complicated neural networks. This tutorial shows how to use a GRU recurrent neural network to learn the Embedded Reber Grammar.

Motivation

It's possible to get a lot done with simple recurrent neural networks (SRNN). But they do have a few limitations.

First, they're fiddly to train. On a good day they work well and train up quickly. On a bad day, they blow up or fail to learn at all.

Second, they're not really able to learn long-term dependencies. They have a reasonable short-term memory but no capacity to remember things that happened a while ago that might be relevant to what's happening now.

Reber Sequences

An example of a problem that is beyond an SRNN is learning an Embedded Reber Grammar.

Figure: Extended Reber Grammar

The Reber Grammar (contained in each of the two boxes above) is a simple finite-state grammar that generates random strings like "BTSSXXTVVE" or "BPVPXVPXVPXVVE". The state transitions give the valid next characters. For example, a T or P is equally likely after a B, but after BT the next character can only be an S (which may repeat) or an X, and so on.

To make things harder, the Embedded Reber Grammar extends the Reber Grammar with a longer-term state transition. A sequence that begins with BT must end with TE, and one that begins with BP must end with PE, with a complete Reber string embedded in between.
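Both grammars are small enough to write out directly. The following is a minimal, self-contained C# sketch (top-level statements, C# 9 or later) for checking the examples above by hand.

- The transition table follows the standard Reber Grammar diagram.
- The state numbers are our own labels.
- The names GenerateReber, GenerateEmbedded, IsReber and IsEmbeddedReber are hypothetical.

This is not Bright Wire's ReberGrammar class, and it applies no length limits.

```csharp
using System;
using System.Collections.Generic;
using System.Text;

// The Reber Grammar as a transition table: state -> the (symbol, next state) edges leaving it.
// State numbers are our own labels for the nodes of the usual diagram; -1 means "finished".
var reber = new Dictionary<int, (char Symbol, int Next)[]>
{
    [0] = new[] { ('B', 1) },
    [1] = new[] { ('T', 2), ('P', 3) },
    [2] = new[] { ('S', 2), ('X', 4) },
    [3] = new[] { ('T', 3), ('V', 5) },
    [4] = new[] { ('X', 3), ('S', 6) },
    [5] = new[] { ('P', 4), ('V', 6) },
    [6] = new[] { ('E', -1) },
};
var random = new Random(42);

// Walk the graph from state 0, picking uniformly among the edges that leave each state.
string GenerateReber()
{
    var text = new StringBuilder();
    for (var state = 0; state != -1;)
    {
        var edges = reber[state];
        var (symbol, next) = edges[random.Next(edges.Length)];
        text.Append(symbol);
        state = next;
    }
    return text.ToString();
}

// Embedded Reber Grammar: B, then T or P, then a complete Reber string,
// then the same T or P again, then E. The repeated letter is the long-term dependency.
string GenerateEmbedded()
{
    var branch = random.Next(2) == 0 ? 'T' : 'P';
    return $"B{branch}{GenerateReber()}{branch}E";
}

// Check a string by following the transition table symbol by symbol.
bool IsReber(string s)
{
    var state = 0;
    foreach (var c in s)
    {
        if (state == -1) return false;                               // symbols after the final E
        var edge = Array.FindIndex(reber[state], e => e.Symbol == c);
        if (edge < 0) return false;                                   // no transition for this symbol
        state = reber[state][edge].Next;
    }
    return state == -1;                                               // must stop exactly after E
}

bool IsEmbeddedReber(string s) =>
    s.Length >= 4 && s[0] == 'B' && (s[1] == 'T' || s[1] == 'P')
    && s[^2] == s[1] && s[^1] == 'E' && IsReber(s[2..^2]);

for (var i = 0; i < 3; i++)
    Console.WriteLine(GenerateEmbedded());
```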

The job of our recurrent neural network is to be able to predict the set of possible next state transitions, after having observed any number of previous states.

It turns out that an SRNN is able to learn the Reber Grammar state transitions fairly well. However, its memory is so short-term that it cannot ensure that a sequence starting with BT ends with TE.
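The prediction targets can be computed by tracking where a string is in the grammar. The C# sketch below uses a hypothetical helper, NextSymbolSets, and the same self-chosen state numbering of the standard Reber diagram. For each prefix of an embedded Reber string, it lists the symbols that may legally come next. It assumes its input is a valid embedded Reber string.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Standard Reber Grammar transitions: state -> (symbol, next state); -1 means "finished".
var reber = new Dictionary<int, (char Symbol, int Next)[]>
{
    [0] = new[] { ('B', 1) },
    [1] = new[] { ('T', 2), ('P', 3) },
    [2] = new[] { ('S', 2), ('X', 4) },
    [3] = new[] { ('T', 3), ('V', 5) },
    [4] = new[] { ('X', 3), ('S', 6) },
    [5] = new[] { ('P', 4), ('V', 6) },
    [6] = new[] { ('E', -1) },
};

// For every position in an embedded Reber string, the set of symbols that may come next.
List<HashSet<char>> NextSymbolSets(string s)
{
    var sets = new List<HashSet<char>>();
    var branch = s[1];   // the T or P chosen straight after the opening B
    var state = 0;       // position inside the embedded Reber string
    for (var i = 0; i < s.Length; i++)
    {
        if (i == 0)
            sets.Add(new HashSet<char> { 'T', 'P' });            // after the opening B
        else if (i == 1)
            sets.Add(new HashSet<char> { 'B' });                 // the embedded string starts
        else if (i < s.Length - 2)
        {
            // inside the embedded string: advance the state with the observed symbol
            state = reber[state].First(edge => edge.Symbol == s[i]).Next;
            sets.Add(state == -1
                ? new HashSet<char> { branch }                   // embedded E seen: repeat the branch
                : reber[state].Select(edge => edge.Symbol).ToHashSet());
        }
        else if (i == s.Length - 2)
            sets.Add(new HashSet<char> { 'E' });                 // after the repeated branch letter
        else
            sets.Add(new HashSet<char>());                       // nothing follows the final E
    }
    return sets;
}

var example = "BTBTXSETE";
var targets = NextSymbolSets(example);
for (var i = 0; i < example.Length; i++)
    Console.WriteLine($"{example[..(i + 1)]} -> {string.Join(",", targets[i])}");
```

For "BTBTXSETE", the only valid symbol after the embedded E (position 6) is T. That is fixed by the letter at position 1, and the gap grows with the length of the embedded string. This is the long-term dependency an SRNN struggles with.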

Gated Recurrent Units

Gated Recurrent Units (GRU) are a variation on Long Short-Term Memory (LSTM) recurrent neural networks. Both LSTM and GRU networks have additional parameters that control when and how their memory is updated.

Both GRU and LSTM networks can capture both long- and short-term dependencies in sequences, but GRU networks have fewer parameters and so are generally faster to train.
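The parameter difference can be made concrete. A GRU layer has three weight blocks (update gate, reset gate and candidate memory), while an LSTM layer has four (input, forget and output gates plus the cell candidate). The C# sketch below counts the weights in one layer. It assumes each block has:

- an input matrix,
- a recurrent matrix,
- a single bias vector.

The input size of 7 is hypothetical.

```csharp
using System;

// Parameter count for one gated recurrent layer, assuming each block (gate or candidate)
// has an input weight matrix (hidden x input), a recurrent weight matrix (hidden x hidden)
// and a single bias vector. Some implementations add a second bias vector per block.
int GatedLayerParameters(int blocks, int inputSize, int hiddenSize) =>
    blocks * (hiddenSize * inputSize + hiddenSize * hiddenSize + hiddenSize);

const int InputSize = 7;    // hypothetical: one-hot vectors over the 7 Reber symbols B, T, S, X, P, V, E
const int HiddenSize = 50;  // same as HIDDEN_LAYER_SIZE in the tutorial's network

var gru = GatedLayerParameters(3, InputSize, HiddenSize);   // update gate, reset gate, candidate
var lstm = GatedLayerParameters(4, InputSize, HiddenSize);  // input, forget, output gates, cell candidate
Console.WriteLine($"GRU: {gru}, LSTM: {lstm}");            // prints GRU: 8700, LSTM: 11600
```

Under these assumptions a GRU layer always has three quarters of the weights of an LSTM layer of the same size. There is correspondingly less to compute per step.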

Conceptually, a GRU network has a reset "gate" and an update "gate" that help ensure its memory doesn't get taken over by tracking short-term dependencies. The network learns how to use its gates to protect its memory so that it's able to make longer-term predictions.

Figure: LSTM Gating. Chung, Junyoung, et al. “Empirical evaluation of gated recurrent neural networks on sequence modeling.” (2014)

The GRU is implemented as:

Figure: GRU Implementation (equations)
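For reference, the standard GRU equations from Chung et al. (2014) are given below, with bias terms written out. Here x_t is the current input, h_{t-1} is the previous memory, σ is the logistic sigmoid and ⊙ is element-wise multiplication. Bright Wire's implementation may differ in small details, for example whether z_t weights the previous memory or the candidate.

```latex
\begin{aligned}
z_t &= \sigma\left(W_z x_t + U_z h_{t-1} + b_z\right) && \text{update gate} \\
r_t &= \sigma\left(W_r x_t + U_r h_{t-1} + b_r\right) && \text{reset gate} \\
\tilde{h}_t &= \tanh\left(W x_t + U\,(r_t \odot h_{t-1}) + b\right) && \text{candidate memory} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{new memory (the layer's output)}
\end{aligned}
```

When z_t is close to 0, the previous memory h_{t-1} passes through almost unchanged. This is how the gates let the network protect its memory from short-term noise.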

Since Bright Wire is implemented in terms of a graph of nodes and "wires", let's see how those equations look in graph form:

Figure: GRU graph-based implementation

As with an SRNN, the memory buffer is updated with the layer's output at each step in the sequence, and this saved output then flows into the next item in the sequence.
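The same data flow can be mirrored in plain C# to make it concrete. This sketch follows the Chung et al. (2014) equations rather than Bright Wire's code. The sizes, random initialisation and the helper names Step and Run are illustrative assumptions. The memory buffer starts at zero, and each step's output is saved and fed into the next step. The sketch only runs the network forward and does not train it.

```csharp
using System;

// Minimal GRU forward pass with plain arrays (no training). Sizes, initialisation
// and weight layout are illustrative choices, not Bright Wire's.
const int InputSize = 3, HiddenSize = 4;
var rng = new Random(1);

double[,] RandomMatrix(int rows, int cols)
{
    var m = new double[rows, cols];
    for (var i = 0; i < rows; i++)
        for (var j = 0; j < cols; j++)
            m[i, j] = (rng.NextDouble() - 0.5) * 0.2;
    return m;
}

// Computes W x + U h + b for one gate.
double[] Affine(double[,] w, double[] x, double[,] u, double[] h, double[] b)
{
    var y = new double[b.Length];
    for (var i = 0; i < y.Length; i++)
    {
        var sum = b[i];
        for (var j = 0; j < x.Length; j++) sum += w[i, j] * x[j];
        for (var j = 0; j < h.Length; j++) sum += u[i, j] * h[j];
        y[i] = sum;
    }
    return y;
}

double[] Map(double[] v, Func<double, double> f) => Array.ConvertAll(v, a => f(a));
double Sigmoid(double a) => 1.0 / (1.0 + Math.Exp(-a));

// Update gate (z), reset gate (r) and candidate (c) parameters; biases start at zero.
var (wz, uz, bz) = (RandomMatrix(HiddenSize, InputSize), RandomMatrix(HiddenSize, HiddenSize), new double[HiddenSize]);
var (wr, ur, br) = (RandomMatrix(HiddenSize, InputSize), RandomMatrix(HiddenSize, HiddenSize), new double[HiddenSize]);
var (wc, uc, bc) = (RandomMatrix(HiddenSize, InputSize), RandomMatrix(HiddenSize, HiddenSize), new double[HiddenSize]);

// One GRU step: combines the input with the previous memory and returns the new memory,
// which is also the layer's output.
double[] Step(double[] x, double[] previous)
{
    var z = Map(Affine(wz, x, uz, previous, bz), Sigmoid);          // update gate
    var r = Map(Affine(wr, x, ur, previous, br), Sigmoid);          // reset gate
    var reset = new double[HiddenSize];
    for (var i = 0; i < HiddenSize; i++) reset[i] = r[i] * previous[i];
    var candidate = Map(Affine(wc, x, uc, reset, bc), Math.Tanh);   // candidate memory
    var next = new double[HiddenSize];
    for (var i = 0; i < HiddenSize; i++)
        next[i] = (1 - z[i]) * previous[i] + z[i] * candidate[i];   // interpolate old and new
    return next;
}

// The memory buffer starts at zero; each step's output is saved and fed into the next step.
double[] Run(double[][] sequence)
{
    var memory = new double[HiddenSize];
    foreach (var x in sequence) memory = Step(x, memory);
    return memory;
}

var output = Run(new[] { new[] { 1.0, 0, 0 }, new[] { 0, 1.0, 0 }, new[] { 0, 0, 1.0 } });
Console.WriteLine(string.Join(", ", output));
```

Forcing the update gate shut, with a large negative bias on z, makes Step return the previous memory almost unchanged. That is one way a GRU can carry the opening T or P through to the end of an embedded Reber string.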

Generating Sequences

Bright Wire includes a helper class (ReberGrammar) to generate Reber Grammar and Embedded Reber Grammar sequences. For this example, we generate 500 embedded ("extended") sequences with a minimum length of 6 and a maximum length of 10 characters, and split them into training and test sets.

At each point in each sequence, the training data contains the set of possible following state transitions, one hot encoded (so a target can contain more than one 1). For example, the first input in every sequence is "B", and its target marks both "T" and "P" as valid.

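// generate embedded ("extended") or plain Reber Grammar sequences; extended, minLength
// and maxLength are assumed to be parameters supplied by the enclosing method
var grammar = new ReberGrammar(context.Random);
var sequences = extended
    ? grammar.GetExtended(minLength, maxLength)
    : grammar.Get(minLength, maxLength);

// one hot encode the first 500 sequences for the trainer, which provides the training and test sets
return new ReberSequenceTrainer(context, ReberGrammar.GetOneHot(context, sequences.Take(500)));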
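Here is a sketch of how one input/target pair could be turned into vectors. The alphabet ordering and the helper names OneHot and MultiHot are hypothetical. Bright Wire's ReberGrammar.GetOneHot may order symbols differently and uses its own vector types.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical symbol ordering for the 7 Reber symbols.
const string Alphabet = "BTSXPVE";

// Input: a single 1 at the index of the observed symbol.
float[] OneHot(char symbol)
{
    var v = new float[Alphabet.Length];
    v[Alphabet.IndexOf(symbol)] = 1f;
    return v;
}

// Target: a 1 at the index of every symbol that may legally come next.
float[] MultiHot(IEnumerable<char> validNext)
{
    var v = new float[Alphabet.Length];
    foreach (var symbol in validNext) v[Alphabet.IndexOf(symbol)] = 1f;
    return v;
}

// The first training pair of every embedded Reber sequence: input "B", target {T, P}.
var input = OneHot('B');
var target = MultiHot(new[] { 'T', 'P' });
Console.WriteLine($"{string.Join(",", input)} -> {string.Join(",", target)}");   // prints 1,0,0,0,0,0,0 -> 0,1,0,0,1,0,0
```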

Next, we connect two GRU layers and a feed-forward layer with softmax activation, then train the network with RMSProp gradient descent (learning rate 0.1, batch size 32) for 50 epochs. Each GRU's memory is a buffer of 50 floating-point numbers. The network reaches around 95% accuracy.
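RMSProp adapts the step size separately for each weight. Below is a minimal, generic C# sketch of its update rule, not Bright Wire's implementation. The decay rate of 0.9 and epsilon of 1e-8 are assumed typical values. The toy objective (w - 3)^2 and the helper name RmsPropUpdate are purely illustrative.

```csharp
using System;

// Minimal RMSProp: each weight keeps a running average of its squared gradient,
// and its step is divided by the square root of that average.
// decayRate and epsilon are assumed textbook defaults, not Bright Wire's settings.
void RmsPropUpdate(double[] weights, double[] gradients, double[] cache,
                   double learningRate, double decayRate = 0.9, double epsilon = 1e-8)
{
    for (var i = 0; i < weights.Length; i++)
    {
        // running average of this weight's squared gradient
        cache[i] = decayRate * cache[i] + (1 - decayRate) * gradients[i] * gradients[i];
        // weights with consistently large gradients take proportionally smaller steps
        weights[i] -= learningRate * gradients[i] / (Math.Sqrt(cache[i]) + epsilon);
    }
}

// Toy problem: minimise f(w) = (w - 3)^2, whose gradient is 2(w - 3).
var parameter = new[] { 0.0 };
var squaredGradientAverage = new double[1];
for (var step = 0; step < 2000; step++)
{
    var gradient = new[] { 2 * (parameter[0] - 3) };
    RmsPropUpdate(parameter, gradient, squaredGradientAverage, learningRate: 0.01);
}
Console.WriteLine(parameter[0]);   // ends up close to 3
```

Because each step is normalised by the recent gradient magnitude, the step size stays near the learning rate whether the raw gradients are large or small.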

var graph = _context.CreateGraphFactory();

// binary classification rounds each output to either 0 or 1
var errorMetric = graph.ErrorMetric.BinaryClassification;

// configure the network properties
graph.CurrentPropertySet
    .Use(graph.GradientDescent.RmsProp)
    .Use(graph.WeightInitialisation.Xavier)
;

// create the engine
var trainingData = graph.CreateDataSource(Training);
var testData = trainingData.CloneWith(Test);
var engine = graph.CreateTrainingEngine(trainingData, learningRate: 0.1f, batchSize: 32);

// build the network
const int HIDDEN_LAYER_SIZE = 50, TRAINING_ITERATIONS = 50;
graph.Connect(engine)
    .AddGru(HIDDEN_LAYER_SIZE)
    .AddGru(HIDDEN_LAYER_SIZE)
    .AddFeedForward(engine.DataSource.GetOutputSizeOrThrow())
    .Add(graph.SoftMaxActivation())
    .AddBackpropagationThroughTime(errorMetric)
;

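// train for TRAINING_ITERATIONS epochs, evaluating the error metric on the test data
engine.Train(TRAINING_ITERATIONS, testData, errorMetric);
// create an execution engine from the trained graph
return graph.CreateEngine(engine.Graph);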
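The code comment describes the binary classification error metric as rounding each output to either 0 or 1. The hypothetical C# sketch below shows how such an accuracy score could be computed. Counting matches per element rather than per row, the 0.5 threshold and the helper name BinaryAccuracy are assumptions, not Bright Wire's code.

```csharp
using System;

// Hypothetical "binary classification" accuracy: round every output to 0 or 1
// (threshold 0.5, an assumption) and count how many elements match the target.
double BinaryAccuracy(float[][] predictions, float[][] expected)
{
    int correct = 0, total = 0;
    for (var row = 0; row < predictions.Length; row++)
    {
        for (var col = 0; col < predictions[row].Length; col++)
        {
            var rounded = predictions[row][col] >= 0.5f ? 1f : 0f;
            if (rounded == expected[row][col]) correct++;
            total++;
        }
    }
    return (double)correct / total;
}

var outputs = new[]
{
    new[] { 0.1f, 0.9f, 0.2f },   // rounds to 0,1,0: all three match
    new[] { 0.6f, 0.4f, 0.1f },   // rounds to 1,0,0: only the last matches
};
var targets = new[]
{
    new[] { 0f, 1f, 0f },
    new[] { 0f, 1f, 0f },
};
Console.WriteLine(BinaryAccuracy(outputs, targets));   // 4 of 6 elements match
```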

Output

Figure: Reber output

Complete Source Code

View the complete source on GitHub.
