It's possible to get a lot done with simple recurrent neural networks (SRNN). But they do have a few limitations.
First, they're fiddly to train. On a good day they train quickly and work well. On a bad day, their gradients explode or vanish and they fail to learn at all.
Second, they're not really able to learn long-term dependencies. They have a reasonable short-term memory, but little capacity to remember something that happened many steps ago and is relevant to what's happening right now.
An example of a problem that is beyond an SRNN is learning an Embedded Reber Grammar.
The Reber Grammar (contained in each of the two boxes above) is a simple finite state grammar that generates random strings like "BTSSXXTVVE" or "BPVPXVPXVPXVVE". The state transitions give the valid next characters (a T or P is equally likely after a B, but after BT you can only get an X or any number of Ss, etc.).
To make things harder, the Embedded Reber Grammar extends the Reber Grammar to also include a longer term state transition. So a sequence that begins with BT must end with TE (while including any number of Reber Grammar transitions in the meantime).
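To make the grammar concrete, here is a minimal sketch in Python of how such sequences can be generated. This is an illustrative reimplementation using a hand-written transition table, not Bright Wire's ReberGrammar class:

```python
import random

# Standard Reber grammar as a transition table: each state maps to the
# (character, next_state) choices that are valid there. State None = done.
REBER = {
    0: [("T", 1), ("P", 2)],
    1: [("S", 1), ("X", 3)],
    2: [("T", 2), ("V", 4)],
    3: [("X", 2), ("S", 5)],
    4: [("P", 3), ("V", 5)],
    5: [("E", None)],
}

def reber_string(rng=random):
    """Generate one valid Reber grammar string, e.g. 'BTSSXXTVVE'."""
    out, state = ["B"], 0
    while state is not None:
        ch, state = rng.choice(REBER[state])
        out.append(ch)
    return "".join(out)

def embedded_reber_string(rng=random):
    """Wrap a Reber string in the embedded grammar: BT...TE or BP...PE."""
    wrapper = rng.choice("TP")  # the long-term dependency to remember
    return "B" + wrapper + reber_string(rng) + wrapper + "E"
```

Note that the second character of an embedded sequence must match the second-to-last character, no matter how long the inner Reber string is; that is precisely the long-range dependency an SRNN fails to capture.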
The job of our recurrent neural network is to predict the set of valid next state transitions after observing any number of previous states.
It turns out that an SRNN is able to learn the Reber grammar state transitions fairly well. However, its memory is so short-term that it's not able to ensure that a sequence starting with BT ends with TE.
Gated Recurrent Units
Gated Recurrent Units (GRU) are a variation on Long Short Term Memory (LSTM) recurrent neural networks. Both LSTM and GRU networks have additional parameters that control when and how their memory is updated.
Both GRU and LSTM networks can capture both long and short term dependencies in sequences, but GRU networks involve fewer parameters and so are faster to train.
Conceptually, a GRU network has a reset and an update "gate" that help ensure its memory doesn't get taken over by tracking short term dependencies. The network learns how to use its gates to protect its memory so that it's able to make longer term predictions.
The GRU is implemented as:
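For reference, these are the standard GRU update equations as usually written (the exact notation Bright Wire uses may differ in minor details):

```latex
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\
\tilde{h}_t &= \tanh\!\left(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h\right) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
```

Here $z_t$ is the update gate, $r_t$ is the reset gate, and $h_t$ is the memory carried forward to the next step; $\odot$ is element-wise multiplication.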
Since Bright Wire is implemented in terms of a graph of nodes and "wires", let's see how those equations look in graph form:
As with an SRNN, the memory buffer is updated with the layer's output at each step in the sequence, and then this saved output flows into the next item in the sequence.
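The per-step memory update can be sketched in a few lines of NumPy. This is an illustrative single GRU step with my own parameter names, not Bright Wire's internal implementation:

```python
import numpy as np

def gru_step(x, h_prev, params):
    """One GRU step: the gates decide how much of the previous
    memory h_prev to keep versus overwrite with new information."""
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    z = sigmoid(Wz @ x + Uz @ h_prev + bz)             # update gate
    r = sigmoid(Wr @ x + Ur @ h_prev + br)             # reset gate
    h_hat = np.tanh(Wh @ x + Uh @ (r * h_prev) + bh)   # candidate memory
    return (1 - z) * h_prev + z * h_hat                # blended new memory
```

The returned vector is both the layer's output and the memory buffer fed into the next step of the sequence.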
Bright Wire includes a helper class to generate Reber Grammar and Embedded Reber Grammar sequences. For this example, we generate 500 extended (embedded) sequences with a minimum length of 6 and a maximum length of 10 characters and split them into training and test sets.
The training data contains the set of possible following state transitions (one hot encoded) at each point in each sequence. So the first item in each sequence is "B", followed by "T" and "P" etc.
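A self-contained Python sketch of how those targets can be built (again using a hand-written transition table rather than Bright Wire's helper, and with an alphabet ordering of my own choosing):

```python
import numpy as np

ALPHABET = "BTSXPVE"

# Valid next characters at each Reber grammar state.
NEXT = {
    0: {"T": 1, "P": 2},
    1: {"S": 1, "X": 3},
    2: {"T": 2, "V": 4},
    3: {"X": 2, "S": 5},
    4: {"P": 3, "V": 5},
    5: {"E": None},
}

def targets(sequence):
    """For each position, a multi-hot vector marking every valid next character."""
    state, rows = 0, []
    for ch in sequence[1:]:            # skip the leading 'B'
        row = np.zeros(len(ALPHABET))
        for c in NEXT[state]:
            row[ALPHABET.index(c)] = 1.0
        rows.append(row)
        state = NEXT[state][ch]        # follow the observed transition
    return np.array(rows)
```

For "BTSSXXTVVE", the first target row marks both T and P as valid (after observing "B"), and the final row marks only E.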
var grammar = new ReberGrammar(context.Random);
var sequences = extended
    ? grammar.GetExtended(minLength, maxLength)
    : grammar.Get(minLength, maxLength);
return new ReberSequenceTrainer(context, ReberGrammar.GetOneHot(context, sequences.Take(500)));
Next, we connect two GRUs and a feed forward layer with SoftMax activation, and train the neural network with RMSProp gradient descent and a batch size of 32 for 50 epochs. Each GRU's memory is a buffer of 50 floating point numbers. The network reaches around 95% accuracy.
var graph = _context.CreateGraphFactory();

// binary classification rounds each output to either 0 or 1
var errorMetric = graph.ErrorMetric.BinaryClassification;

// configure the network properties
graph.CurrentPropertySet
    .Use(graph.GradientDescent.RmsProp)
    .Use(graph.WeightInitialisation.Xavier)
;

// create the engine
var trainingData = graph.CreateDataSource(Training);
var testData = trainingData.CloneWith(Test);
var engine = graph.CreateTrainingEngine(trainingData, learningRate: 0.1f, batchSize: 32);

// build the network
const int HIDDEN_LAYER_SIZE = 50, TRAINING_ITERATIONS = 50;
graph.Connect(engine)
    .AddGru(HIDDEN_LAYER_SIZE)
    .AddGru(HIDDEN_LAYER_SIZE)
    .AddFeedForward(engine.DataSource.GetOutputSizeOrThrow())
    .Add(graph.SoftMaxActivation())
    .AddBackpropagationThroughTime(errorMetric)
;

engine.Train(TRAINING_ITERATIONS, testData, errorMetric);
return graph.CreateEngine(engine.Graph);
Complete Source Code
View the complete source on GitHub.