I had a partially wrong understanding of the ‘stateful’ option in LSTM while learning Keras. To check my understanding, I did some searching and experimentation, and I have summarized my findings below. The source code is available in this Bitbucket repository.
I would advise taking my conclusions with a grain of salt, as I don’t have extensive experience with Keras and could be mistaken. Also, this blog by Philippe Remy cleared up many of my questions; if you’re also trying to understand the ‘stateful’ option, I highly recommend going through it.
For the purpose of this blog, the following terminology holds:
- State = cell state that gets passed to the next time-step (not the hidden state).
- Sample = sequence. If your problem is to predict a character given the 5 characters preceding it, then the sample size (aka sequence length) is 5.
What we already know and agree upon:
- No information is exchanged between samples in a batch. For example, imagine every sample is a sentence with the label ‘positive emotion’ or ‘negative emotion’. Such sentences are independent, and if two of them end up in the same batch, it is just a coincidence.
Also, as per my understanding, a wrong usage of LSTM would be to predict labels where a sample consists of one sentence and the prediction for the current sentence relies heavily on information in the previous sentence.
- A larger batch size helps avoid a zig-zag trajectory of the weights over the loss surface, which in turn reduces the number of gradient updates needed for convergence. If we ignore this particular effect of batch size, then:
the batch size is chosen based on experimentation. Start training with a low batch size, say 4, and keep increasing it in subsequent training sessions. The time spent per epoch will first decrease (due to better parallelism) and then start increasing (hardware limitations). Choose the batch size with the least training time per epoch and stick to it from then on.
- ‘stateful’ has nothing to do with the ‘unroll’ parameter of the LSTM layer, which only controls whether the recurrence is computed as an unrolled graph or as a symbolic loop.
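The batch-size search from the second point above can be sketched as a small helper. This is a minimal sketch under a few assumptions: the batch size is doubled rather than incremented by a fixed step, the search stops at the first increase in epoch time, and `train_one_epoch` is a hypothetical callable standing in for your own training code.

```python
import time

def timed_epoch(train_one_epoch):
    """Wrap a (hypothetical) train_one_epoch(batch_size) callable so that it
    returns the wall-clock seconds one epoch took."""
    def epoch_time(batch_size):
        start = time.perf_counter()
        train_one_epoch(batch_size)
        return time.perf_counter() - start
    return epoch_time

def choose_batch_size(epoch_time, start=4, max_batch_size=1024):
    """Double the batch size until the time per epoch starts increasing,
    then return the fastest batch size seen and all measured timings."""
    timings = {}
    previous = None
    batch_size = start
    while batch_size <= max_batch_size:
        timings[batch_size] = epoch_time(batch_size)
        if previous is not None and timings[batch_size] > timings[previous]:
            break  # epochs are getting slower again (hardware limit)
        previous = batch_size
        batch_size *= 2
    best = min(timings, key=timings.get)
    return best, timings
```

Typical usage would be `best, timings = choose_batch_size(timed_epoch(train_one_epoch))`. Separating the timing from the search makes the search easy to check with a fixed table of epoch times.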
For all experimentation, this is how the simple network has been defined:
For a stateful LSTM, I just add the flag like this:
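A minimal sketch of such a network, assuming tf.keras (TensorFlow 2.x). The single sigmoid Dense output and the Adam optimizer are assumptions for this sketch, not taken from the original code. The only difference between the two variants is the stateful flag; a stateful layer needs a fixed batch size, so it is declared on the Input.

```python
import numpy as np
import tensorflow as tf

time_steps = 5   # fixed by the dataset
num_units = 3
batch_size = 4

def build_model(stateful=False):
    # Declaring batch_size on the Input fixes the batch dimension, which a
    # stateful LSTM requires (it keeps one state per row of the batch).
    inputs = tf.keras.Input(shape=(time_steps, 1), batch_size=batch_size)
    x = tf.keras.layers.LSTM(num_units, stateful=stateful)(inputs)
    # Assumption: a single sigmoid unit, since each label is 0 or 1.
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(loss="mse", optimizer="adam", metrics=["accuracy"])
    return model

stateless_model = build_model(stateful=False)
stateful_model = build_model(stateful=True)
```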
I arbitrarily chose num_units = 3 and batch_size = 4. The dataset determines time_steps, which is 5 in our case. There are 1024 samples in the training set and 256 in the test set. MSE is used as the loss function.
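Here is a code snippet of the training. Notice the shuffle = False flag. It is necessary when using the ‘stateful’ option, since shuffling would break the alignment between sample i of one batch and sample i of the next.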
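A sketch of a common training pattern for stateful layers, assuming tf.keras: train one epoch at a time with shuffle=False, then clear the carried state before the next epoch. The per-epoch `reset_states()` call, the model head, and the random toy data are assumptions for illustration, not necessarily what the original training code did.

```python
import numpy as np
import tensorflow as tf

time_steps, num_units, batch_size = 5, 3, 4
num_train_samples = 1024  # a multiple of batch_size, so no partial batch

# Toy data in the format of the first dataset: label = first time-step.
rng = np.random.default_rng(0)
X_train = rng.integers(0, 2, size=(num_train_samples, time_steps, 1)).astype("float32")
y_train = X_train[:, 0, 0]

inputs = tf.keras.Input(shape=(time_steps, 1), batch_size=batch_size)
lstm = tf.keras.layers.LSTM(num_units, stateful=True)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(lstm(inputs))
model = tf.keras.Model(inputs, outputs)
model.compile(loss="mse", optimizer="adam", metrics=["accuracy"])

num_epochs = 5
epoch_losses = []
for epoch in range(num_epochs):
    # shuffle=False keeps sample k and sample k + batch_size in the same
    # row position of consecutive batches.
    h = model.fit(X_train, y_train, batch_size=batch_size, epochs=1,
                  shuffle=False, verbose=0)
    epoch_losses.append(h.history["loss"][0])
    # Clear the carried-over state so the next epoch starts from zeros.
    lstm.reset_states()
```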
Taking a cue from Philippe’s example, here is what the dataset looks like:
| Sample | Input (5 time-steps) | Label |
|---|---|---|
| Sample 0 | [1, X, X, X, X] | 1 |
| Sample 1 | [0, X, X, X, X] | 0 |
| Sample 2 | [0, X, X, X, X] | 0 |
| Sample 3 | [1, X, X, X, X] | 1 |
| … and so on | … | … |
Each sample is 5 time-steps long, and the label is always the same as the value of the sample at the first time-step. Each ‘X’ means that the number could be either 0 or 1; it is set randomly while generating the training set.
For the sake of completeness, here is the code to generate this dataset:
```python
import numpy as np

num_train_samples, num_test_samples, time_steps = 1024, 256, 5

# Assumed initialization: inputs of shape (samples, time_steps, 1), all zeros
X_train = np.zeros((num_train_samples, time_steps, 1), dtype=int)
X_test = np.zeros((num_test_samples, time_steps, 1), dtype=int)

# Setting internal time-steps to random numbers
X_train[:, 1:] = np.random.randint(0, 2, (num_train_samples, time_steps - 1, 1), dtype=int)
X_test[:, 1:] = np.random.randint(0, 2, (num_test_samples, time_steps - 1, 1), dtype=int)

# Setting half of the first time-steps to 1
one_indexes = np.random.choice(a=num_train_samples, size=int(num_train_samples / 2), replace=False)
X_train[one_indexes, 0] = 1
one_indexes = np.random.choice(a=num_test_samples, size=int(num_test_samples / 2), replace=False)
X_test[one_indexes, 0] = 1

# Creating labels
y_train = X_train[:, 0, 0]
y_test = X_test[:, 0, 0]
```
Given the trivial nature of this problem, I started by training a ‘stateless’ LSTM. The network learns the pattern within the first epoch and achieves 100% accuracy on the test set:
This is expected, as the problem is easy. Next, I use a ‘stateful’ LSTM on the same setup. Here is the console log of training for 5 epochs:
Making the LSTM ‘stateful’ hinders learning for this specific problem, although the results do seem to improve with successive epochs.
Keras documentation describes ‘stateful’ as “Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.” Philippe’s blog states, “If the model is stateless, the cell states are reset at each sequence. With the stateful model, all the states are propagated to the next batch. It means that the state of the sample located at index i, Xi will be used in the computation of the sample (Xi+bs) in the next batch, where bs is the batch size (no shuffling).”
Seems clear enough: the state from the current batch will be used for the collocated samples in the next batch. Below is a pictorial representation of a stateless LSTM, where the state of the last time-step is discarded.
‘b’ stands for batch size. I have shown two samples belonging to two successive batches. On top, the sample index (in the entire training set) is ‘k’, and on the bottom, the sample index is ‘k + b’. x0 to x4 represent the 5 time-steps in our problem. ‘C’ is the cell state.
In the case of a stateful LSTM, the state of the last time-step is given to the collocated sample in the next batch:
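To make the collocation rule concrete, here is a toy NumPy simulation that records which initial state each sample sees. It is an illustration under assumptions: a scalar tanh RNN with made-up weights stands in for the LSTM, and it is not how Keras is implemented internally. In stateful mode, sample k + b starts from the final state of sample k; in stateless mode, every sample starts from zeros.

```python
import numpy as np

def rnn_final_states(batch, h0, w_x=0.5, w_h=0.9):
    """Toy scalar RNN over a batch of shape (b, T); returns the final state per row."""
    h = h0
    for t in range(batch.shape[1]):
        h = np.tanh(w_x * batch[:, t] + w_h * h)
    return h

def run_epoch(X, batch_size, stateful):
    """Return (initial_states, final_states), one entry per sample, in training-set order."""
    n = len(X) - len(X) % batch_size       # drop an incomplete last batch
    initial = np.zeros(n)
    final = np.zeros(n)
    h = np.zeros(batch_size)
    for start in range(0, n, batch_size):
        batch = X[start:start + batch_size]
        if not stateful:
            h = np.zeros(batch_size)       # stateless: every sample starts fresh
        # stateful: row i starts from the final state of row i in the previous batch
        initial[start:start + batch_size] = h
        h = rnn_final_states(batch, h)
        final[start:start + batch_size] = h
    return initial, final

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(12, 5)).astype(float)
b = 4
init_sf, final_sf = run_epoch(X, b, stateful=True)
init_sl, _ = run_epoch(X, b, stateful=False)
```

Here `init_sf[k + b]` equals `final_sf[k]` for every k, while `init_sl` is all zeros.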
To confirm this, I created a new problem in which the current label is determined by the collocated sample in the previous batch. Here is what the dataset looks like:
| Batch | Sample | Input (5 time-steps) | Label |
|---|---|---|---|
| Batch 0 | Sample 0 | [1, X, X, X, X] | 1 |
| Batch 0 | Sample 1 | [0, X, X, X, X] | 1 |
| Batch 0 | Sample 2 | [0, X, X, X, X] | 0 |
| Batch 0 | Sample 3 | [1, X, X, X, X] | 0 |
| Batch 1 | Sample 4 | [0, X, X, X, X] | 1 |
| Batch 1 | Sample 5 | [0, X, X, X, X] | 0 |
| Batch 1 | Sample 6 | [0, X, X, X, X] | 0 |
| Batch 1 | Sample 7 | [1, X, X, X, X] | 1 |
| Batch 2 | Sample 8 | [1, X, X, X, X] | 0 |
| Batch 2 | Sample 9 | [1, X, X, X, X] | 0 |
| Batch 2 | Sample 10 | [0, X, X, X, X] | 0 |
| Batch 2 | Sample 11 | [1, X, X, X, X] | 1 |
| … and so on | … | … | … |
The first time-step of the previous batch’s collocated sample determines the current sample’s label. So, in theory, if the state is able to propagate across batches, a stateless LSTM should fail to solve this problem while a stateful one should succeed.
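A sketch of how this second dataset could be generated with NumPy. Assumptions: the first batch’s labels are drawn at random because it has no previous batch, the first time-step is random 0/1 rather than exactly half ones, and `make_shifted_dataset` is a hypothetical helper name.

```python
import numpy as np

def make_shifted_dataset(num_samples, batch_size, time_steps=5, seed=0):
    """Label of sample k = first time-step of sample k - batch_size,
    i.e. of the collocated sample in the previous batch."""
    rng = np.random.default_rng(seed)
    X = rng.integers(0, 2, size=(num_samples, time_steps, 1))
    y = np.empty(num_samples, dtype=int)
    # The first batch has no previous batch, so its labels are random (assumption).
    y[:batch_size] = rng.integers(0, 2, size=batch_size)
    y[batch_size:] = X[:-batch_size, 0, 0]
    return X, y

X_train, y_train = make_shifted_dataset(1024, batch_size=4)
```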
Stateless LSTM network training console output
The model fails to learn anything and is probably predicting either all 0s or all 1s for the entire training/test set. Surprisingly, the stateful LSTM also fails on this problem:
It didn’t make any sense to me, and I was stuck on this result for a while. I tried setting the batch size to 1 and changing the data such that the current sample’s label is determined by the previous sample, but I couldn’t get accurate results even then. Since stateful LSTMs can pass information across batches (only to collocated samples), the network should have been able to perform well on this problem. Then I read through the ‘Don’t use `stateful` LSTM unless you know what it does’ page on GitHub once again, and it made sense. Here are some key points:
“Cell t+1 will do its best to do something with state t, but state t will be random and untrained. It might learn something it might not.” and “The hidden states are randomly initialized and untrained. I don’t think most people understand that part and end up with some weird models”.
Ben points out that although making the LSTMs stateful does allow information to be transferred across batches, backpropagation cannot pass through the batch boundary. Hence the network cannot be trained to produce useful information in the state after the last time-step, which can leave a garbage state there.
I don’t yet know when the ‘stateful’ option would be helpful. But certainly, if I am just starting with LSTMs, I probably won’t need ‘stateful’ LSTMs much.
At this point, to make sure that we’re on the same page, here are some example problems and an approach to solving each:
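The batch-boundary argument can be checked numerically on a toy model. This sketch uses a scalar linear RNN, h_t = w * h_{t-1} + u * x_t (a simplification, not an LSTM), with the informative input in batch 1 and the label on batch 2, mirroring the second dataset. Holding the carried state constant, which is effectively what happens at a batch boundary, makes the gradient of batch 2’s loss with respect to the input weight u exactly zero, whereas backpropagating through both batches gives a non-zero gradient. All names and values are illustrative assumptions.

```python
def run_sequence(w, u, xs, h0):
    """Toy linear RNN: h_t = w * h_{t-1} + u * x_t. Returns the final state."""
    h = h0
    for x in xs:
        h = w * h + u * x
    return h

def batch2_loss(u, w, xs1, xs2, y, carried=None):
    """Squared error on batch 2, whose initial state comes from batch 1.
    If `carried` is given it is treated as a constant (truncated backprop)."""
    h0 = run_sequence(w, u, xs1, 0.0) if carried is None else carried
    return (run_sequence(w, u, xs2, h0) - y) ** 2

def numeric_grad(f, x, eps=1e-6):
    """Central finite-difference derivative of f at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

w, u = 0.8, 0.5
xs1 = [1.0, 0.0, 0.0, 0.0, 0.0]  # batch 1: the informative first time-step
xs2 = [0.0, 0.0, 0.0, 0.0, 0.0]  # batch 2: no information of its own
y = 1.0                          # batch 2's label depends on batch 1

carried = run_sequence(w, u, xs1, 0.0)
full_grad = numeric_grad(lambda v: batch2_loss(v, w, xs1, xs2, y), u)
truncated_grad = numeric_grad(lambda v: batch2_loss(v, w, xs1, xs2, y, carried), u)
```

With these values `full_grad` is roughly -0.25 while `truncated_grad` is exactly 0.0: with the boundary cut, nothing tells u what to write into the state that batch 2 will read.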
(One-step time-series prediction refers to predicting just the next value; it has nothing to do with ‘time-steps’ in LSTM. Once the prediction is made and evaluated, the true value becomes available, and from then on this true value is used for making predictions, e.g. predicting global population for next year. In multi-step time-series problems, on the other hand, the model predicts many future values without access to the true future values, e.g. predicting global population for each year of the next century.)
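The difference between the two settings shows up in what goes into the input window. A minimal sketch, assuming a hypothetical `predict_next(window)` function that maps the last `time_steps` values to the next value (any trained model could play this role): one-step forecasting always builds the window from true values, while multi-step forecasting appends its own predictions.

```python
import numpy as np

def one_step_forecasts(series, time_steps, predict_next):
    """Predict each next value from a window of true past values."""
    series = np.asarray(series, dtype=float)
    return np.array([predict_next(series[i - time_steps:i])
                     for i in range(time_steps, len(series))])

def multi_step_forecast(history, time_steps, horizon, predict_next):
    """Predict `horizon` future values, feeding each prediction back as input."""
    window = list(np.asarray(history, dtype=float)[-time_steps:])
    preds = []
    for _ in range(horizon):
        nxt = predict_next(np.array(window))
        preds.append(nxt)
        window = window[1:] + [nxt]  # no true value available: reuse the prediction
    return np.array(preds)
```

In the multi-step case any prediction error stays in the window for the next `time_steps` predictions, which is why it is the harder problem.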
- One-step time-series forecasting. Say I want to predict tomorrow’s mean temperature based on historical data (each sample is the mean temperature of one day). Trying to solve this problem with a single-time-step LSTM model is plain wrong. One could argue that with a ‘stateful’ LSTM the entire historical information can pass via the state, but as we saw above, backprop cannot teach the network to put useful information into the state of the last time-step (here, there is just one time-step). Hence the network may not learn to predict the temperature.
To solve this problem we must use a multi-time-step LSTM network. However, the series must be detrended in the preprocessing stage (see next point).
- If there is a trend in the time series (e.g. an ever-increasing pattern), then to predict a correct future value, the location of a sample within the series may become crucial. However, when training a vanilla LSTM network, the information about the sample’s position within the time series is lost. This would hurt prediction accuracy on a time series that hasn’t been preprocessed. After detrending, the entire series looks identically distributed with respect to time.
- What if we know in advance that there are patterns in collocated samples across batches, like in the example above, where the current sample’s output depends on the collocated sample in the previous batch? Would ‘stateful’ be useful there?
Well, I did try a stateful LSTM above, and making it ‘stateful’ didn’t show any observable accuracy improvement. Maybe if I had trained it for longer, the network would eventually have picked up the pattern by chance.
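The two preprocessing steps suggested in the points above, detrending a series and turning it into multi-time-step samples, could look like this in NumPy. First-order differencing is used as one common way to detrend (an assumption; the points above don’t prescribe a method), and `difference`, `undifference`, and `make_windows` are hypothetical helper names.

```python
import numpy as np

def difference(series):
    """First-order differencing: removes a linear trend from the series."""
    return np.diff(np.asarray(series, dtype=float))

def undifference(first_value, diffs):
    """Invert difference(): rebuild the series from its first value and the differences."""
    return np.concatenate([[first_value], first_value + np.cumsum(diffs)])

def make_windows(series, time_steps):
    """Split a 1-D series into LSTM inputs of shape (samples, time_steps, 1)
    and next-value targets of shape (samples,)."""
    series = np.asarray(series, dtype=float)
    n = len(series) - time_steps
    if n <= 0:
        raise ValueError("series must be longer than time_steps")
    X = np.stack([series[i:i + time_steps] for i in range(n)])[..., np.newaxis]
    y = series[time_steps:]
    return X, y

# Example: a series with an upward trend, detrended and then windowed.
trend = 2.0 * np.arange(20) + 5.0
diffs = difference(trend)          # constant 2.0 everywhere: the trend is gone
X, y = make_windows(diffs, time_steps=5)
```

A model trained on the differenced series predicts the next difference; adding it to the last observed value maps the prediction back to the original scale.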
In the end, I am open to suggestions; if you find something wrong with these conclusions, please share your thoughts in the comments.