[Keras] Returning the hidden state in keras RNNs with return_state

There is a lot of confusion about return_state in Keras. What does it actually return, and how can we use it for stacking RNNs or building encoder/decoder models? Hopefully this post makes it a bit clearer.

Cell state vs Hidden state

The difference between the cell state (only available via return_state) and the hidden state at each timestep (returned with return_sequences) is not very intuitive, as both states summarize earlier data in the sequence. Furthermore, Keras returns a confusing number of tensors. It gets a bit clearer when we remember that in LSTMs the simple neurons are replaced by memory cells with several gates. Here we can see an LSTM cell:

LSTM cell
A peephole LSTM unit with input, output, and forget gates. Each of these gates can be thought as a “standard” neuron in a feed-forward (or multi-layer) neural network (wikipedia).

In LSTMs, return_sequences returns the hidden state (the layer's output) at every timestep, while return_state additionally returns the final hidden state and the final internal cell state, which acts as the cell's long-term memory. LSTMs have three gates (input, forget, and output); GRUs have only two (update and reset) and no separate cell state. So why does an LSTM with both options set return three tensors while a GRU returns two?

Looking at the LSTM outputs

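from numpy import array
from keras.layers import Input, LSTM
from keras.models import Model

# input: 5 timesteps with 1 feature each
inputs1 = Input(shape=(5, 1))
lstm1, state_h, state_c = LSTM(1, return_state=True, return_sequences=True)(inputs1)
model = Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
data = array([0.1, 0.2, 0.3, 0.4, 0.5]).reshape((1,5,1))
print(model.predict(data))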
[array([[[0.00734747],
        [0.02000349],
        [0.03651035],
        [0.05576567],
        [0.07689518]]], dtype=float32), array([[0.07689518]], dtype=float32), array([[0.15496857]], dtype=float32)]

Our sequence has 5 values, and the first returned array is the output produced by return_sequences: one value per timestep. It has nothing to do with return_state itself, but it makes clear that the layer output can be a single value or a whole sequence, depending on that parameter.

The two other values are only returned if return_state is set. The first of them is the hidden state (state_h), which is simply the output at the last timestep of the sequence above (0.07689518). The second (state_c) is the carry state, also called the cell state.
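To see why state_h repeats the last sequence value while state_c does not, it can help to step through a single-unit LSTM by hand. The following is a minimal pure-Python sketch of the textbook LSTM equations (without peepholes), not the Keras implementation. The weight values and the helper names pre_activation, lstm_step and run_lstm are made up for illustration, so the numbers will not match the Keras output above.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical fixed weights for a cell with 1 input feature and 1 unit.
# Each entry is (input weight, recurrent weight, bias); Keras would use its own values.
WEIGHTS = {
    "i": (0.5, 0.5, 0.0),  # input gate
    "f": (0.5, 0.5, 0.0),  # forget gate
    "g": (0.5, 0.5, 0.0),  # candidate value
    "o": (0.5, 0.5, 0.0),  # output gate
}


def pre_activation(gate, x, h):
    w, u, b = WEIGHTS[gate]
    return w * x + u * h + b


def lstm_step(x, h, c):
    """Advance the cell by one timestep and return the new (h, c)."""
    i = sigmoid(pre_activation("i", x, h))
    f = sigmoid(pre_activation("f", x, h))
    g = math.tanh(pre_activation("g", x, h))
    o = sigmoid(pre_activation("o", x, h))
    c_new = f * c + i * g          # cell (carry) state: internal memory
    h_new = o * math.tanh(c_new)   # hidden state: what the layer outputs
    return h_new, c_new


def run_lstm(xs, h=0.0, c=0.0):
    """Return the per-timestep outputs plus the final hidden and cell state."""
    outputs = []
    for x in xs:
        h, c = lstm_step(x, h, c)
        outputs.append(h)  # this list is what return_sequences=True exposes
    return outputs, h, c   # h and c are what return_state=True adds


outputs, state_h, state_c = run_lstm([0.1, 0.2, 0.3, 0.4, 0.5])
print(outputs)
print(state_h == outputs[-1])  # True: state_h is just the last output
print(state_h, state_c)        # the cell state is a different quantity
```

The loop makes the relationship explicit: the sequence output is a log of h over time, state_h is its last entry, and state_c is carried along internally without ever being emitted as output.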

Take a look at the LSTM source

With GRUs

from numpy import array
from keras.layers import Input, GRU
from keras.models import Model

inputs1 = Input(shape=(5, 1))
gru1, state_h = GRU(1, return_state=True, return_sequences=True)(inputs1)
model = Model(inputs=inputs1, outputs=[gru1, state_h])
# define input data
data = array([0.1, 0.2, 0.3, 0.4, 0.5]).reshape((1,5,1))
# make and show prediction
print(model.predict(data))
[array([[[0.02443524],
        [0.06332285],
        [0.1107745 ],
        [0.16312647],
        [0.21796873]]], dtype=float32), array([[0.21796873]], dtype=float32)]

Now we have only one state tensor (state_h), which holds the hidden state at the last timestep. Remember that the GRU cell is simpler and has no separate cell state.
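The same hand-rolled approach shows why a GRU has nothing besides state_h to return. This is again only a pure-Python sketch of the textbook GRU equations with one input feature and one unit. The weights and the names gru_step and run_gru are hypothetical, and Keras differs in details (for example how the reset gate is applied), so the values will not match the output above.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical (input weight, recurrent weight, bias) per gate.
WEIGHTS = {
    "z": (0.5, 0.5, 0.0),  # update gate
    "r": (0.5, 0.5, 0.0),  # reset gate
    "h": (0.5, 0.5, 0.0),  # candidate hidden state
}


def gru_step(x, h):
    """Advance the cell by one timestep and return the new hidden state."""
    wz, uz, bz = WEIGHTS["z"]
    wr, ur, br = WEIGHTS["r"]
    wh, uh, bh = WEIGHTS["h"]
    z = sigmoid(wz * x + uz * h + bz)
    r = sigmoid(wr * x + ur * h + br)
    h_candidate = math.tanh(wh * x + uh * (r * h) + bh)
    # Blend the old state and the candidate; there is no second state to update.
    return z * h + (1.0 - z) * h_candidate


def run_gru(xs, h=0.0):
    """Return the per-timestep outputs plus the final hidden state."""
    outputs = []
    for x in xs:
        h = gru_step(x, h)
        outputs.append(h)
    return outputs, h


outputs, state_h = run_gru([0.1, 0.2, 0.3, 0.4, 0.5])
print(outputs)
print(state_h == outputs[-1])  # True: the only state is the last output
```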

In conclusion, we can think of the hidden state as a summary of the sequence so far, while the carry state is the LSTM's internal memory, which the forget and input gates update at every timestep. In GRUs the single state equals the output of the layer at the last timestep, which is exactly what we would expect.

GRU source

When to use return_state?

In encoder/decoder architectures return_state is used a lot. It captures the final memory of the encoder and makes it available to the decoder. Apart from encoder/decoder architectures, return_state can be useful in event processing, where we stack LSTMs on top of each other. Unlike most other options, initial_state is not a constructor argument: it has to be passed to the following layer when that layer is called.
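The encoder/decoder hand-over described above could look roughly like the following. This is a minimal sketch assuming TensorFlow's bundled Keras (tensorflow.keras); the sizes (3 features, 16 units), the variable names and the random input data are made up for illustration, and the model is untrained.

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM
from tensorflow.keras.models import Model

num_features = 3   # hypothetical feature count per timestep
latent_dim = 16    # hypothetical number of LSTM units

# Encoder: we only keep its final states, not its output.
encoder_inputs = Input(shape=(None, num_features))
_, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)

# Decoder: initial_state is passed when the layer is called, not in the constructor.
decoder_inputs = Input(shape=(None, num_features))
decoder_outputs = LSTM(latent_dim, return_sequences=True)(
    decoder_inputs, initial_state=[state_h, state_c]
)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Random data: 2 samples, 5 encoder timesteps, 7 decoder timesteps.
enc_batch = np.random.rand(2, 5, num_features).astype("float32")
dec_batch = np.random.rand(2, 7, num_features).astype("float32")
decoded = model.predict([enc_batch, dec_batch], verbose=0)
print(decoded.shape)  # (2, 7, 16)
```

The encoder and decoder must use the same number of units here, because the encoder states are used directly as the decoder's starting states.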

Here is an example with Bidirectional LSTMs. Note that a bidirectional layer wraps two LSTMs, one reading the sequence forwards and one reading it in reverse, so we get 5 returned tensors: the output sequence plus a hidden state and a cell state for each direction.

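# mask: output tensor of a preceding layer (defined elsewhere)
rnn_seq, h1, c1, h2, c2 = Bidirectional(LSTM(64, return_sequences=True, return_state=True))(mask)
# feed the first layer's output sequence and its final states into the second layer
rnn = Bidirectional(LSTM(64, return_sequences=True, return_state=False))(rnn_seq, initial_state=[h1, c1, h2, c2])
# AttentionWithContext is a custom layer, not part of Keras itself
att = AttentionWithContext()(rnn)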
