10.2 Gated Recurrent Units (GRU)
As RNNs, and particularly the LSTM architecture (Section 10.1), rapidly gained popularity during the 2010s, a number of researchers began to experiment with simplified architectures. They hoped to retain the key idea of incorporating an internal state and multiplicative gating mechanisms while speeding up computation. The gated recurrent unit (GRU) (Cho et al. 2014) offered a streamlined version of the LSTM memory cell that often achieves comparable performance while being faster to compute (Chung et al. 2014).

# PyTorch
from d2l import torch as d2l
import torch
from torch import nn

# TensorFlow
from d2l import tensorflow as d2l
import tensorflow as tf

# JAX
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

# MXNet
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
npx.set_np()

10.2.1 Reset Gate and Update Gate
Here, the LSTM’s three gates are replaced by two: the reset gate and the update gate. As with LSTMs, these gates are given sigmoid activations, forcing their values to lie in the interval \((0, 1)\). Intuitively, the reset gate controls how much of the previous state we might still want to remember. Likewise, the update gate controls how much of the new state is just a copy of the old one. Figure 10.2.1 illustrates the inputs for both the reset and update gates in a GRU, given the input of the current time step and the hidden state of the previous time step. The outputs of the gates are given by two fully connected layers with a sigmoid activation function.
Mathematically, for a given time step \(t\), suppose that the input is a minibatch \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\) (number of examples \(=n\); number of inputs \(=d\)) and the hidden state of the previous time step is \(\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}\) (number of hidden units \(=h\)). Then the reset gate \(\mathbf{R}_t \in \mathbb{R}^{n \times h}\) and update gate \(\mathbf{Z}_t \in \mathbb{R}^{n \times h}\) are computed as follows:
\[ \begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xr}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hr}} + \mathbf{b}_\textrm{r}),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xz}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hz}} + \mathbf{b}_\textrm{z}), \end{aligned} \]
where \(\mathbf{W}_{\textrm{xr}}, \mathbf{W}_{\textrm{xz}} \in \mathbb{R}^{d \times h}\) and \(\mathbf{W}_{\textrm{hr}}, \mathbf{W}_{\textrm{hz}} \in \mathbb{R}^{h \times h}\) are weight parameters and \(\mathbf{b}_\textrm{r}, \mathbf{b}_\textrm{z} \in \mathbb{R}^{1 \times h}\) are bias parameters.
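To make the shapes concrete, here is a minimal NumPy sketch of the two gate equations for a single time step. It is framework-free and not part of d2l; the sizes \(n=4\), \(d=5\), \(h=3\) and the dictionary layout of the parameters are illustrative assumptions. Note how the \(1 \times h\) biases broadcast over the \(n\) rows of the minibatch, and how the sigmoid keeps every gate entry strictly inside \((0, 1)\).

```python
import numpy as np

def sigmoid(x):
    """Elementwise logistic function; outputs lie strictly in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def gru_gates(X, H_prev, p):
    """Reset gate R_t and update gate Z_t for one time step.

    X: (n, d) minibatch input; H_prev: (n, h) previous hidden state;
    p: dict with W_xr, W_xz of shape (d, h), W_hr, W_hz of shape (h, h),
    and b_r, b_z of shape (1, h), which broadcast over the n rows.
    """
    R = sigmoid(X @ p['W_xr'] + H_prev @ p['W_hr'] + p['b_r'])
    Z = sigmoid(X @ p['W_xz'] + H_prev @ p['W_hz'] + p['b_z'])
    return R, Z

n, d, h = 4, 5, 3  # illustrative sizes, not the ones used for training
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d))
H_prev = rng.normal(size=(n, h))
p = {
    'W_xr': rng.normal(scale=0.01, size=(d, h)),
    'W_hr': rng.normal(scale=0.01, size=(h, h)),
    'b_r': np.zeros((1, h)),
    'W_xz': rng.normal(scale=0.01, size=(d, h)),
    'W_hz': rng.normal(scale=0.01, size=(h, h)),
    'b_z': np.zeros((1, h)),
}
R, Z = gru_gates(X, H_prev, p)
print(R.shape, Z.shape)  # (4, 3) (4, 3)
```

Both gates have the same shape as the hidden state, \(n \times h\), so they can be combined with \(\mathbf{H}_{t-1}\) elementwise.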
10.2.4 Implementation from Scratch
To gain a better understanding of the GRU model, let’s implement it from scratch.
10.2.4.1 Initializing Model Parameters
The first step is to initialize the model parameters. We draw the weights from a Gaussian distribution with standard deviation sigma and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units. We instantiate all weights and biases relating to the update gate, the reset gate, and the candidate hidden state.
# PyTorch
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        init_weight = lambda *shape: nn.Parameter(d2l.randn(*shape) * sigma)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(d2l.zeros(num_hiddens)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

# TensorFlow
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        init_weight = lambda *shape: tf.Variable(d2l.normal(shape) * sigma)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          tf.Variable(d2l.zeros(num_hiddens)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

# JAX
class GRUScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        init_weight = lambda name, shape: self.param(name,
                                                     nn.initializers.normal(self.sigma),
                                                     shape)
        triple = lambda name : (
            init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
            init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
            self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens,)))
        self.W_xz, self.W_hz, self.b_z = triple('z')  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple('r')  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple('h')  # Candidate hidden state

# MXNet
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        init_weight = lambda *shape: d2l.randn(*shape) * sigma
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          d2l.zeros(num_hiddens))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

10.2.4.2 Defining the Model
Now we are ready to define the GRU forward computation. Its structure is the same as that of the basic RNN cell, except that the update equations are more complex.
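The forward computation below relies on two more equations, which can be read off directly from the code. The candidate hidden state and the new hidden state are

\[ \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{\textrm{hh}} + \mathbf{b}_\textrm{h}), \qquad \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t, \]

where \(\odot\) denotes elementwise multiplication, \(\mathbf{W}_{\textrm{xh}} \in \mathbb{R}^{d \times h}\), \(\mathbf{W}_{\textrm{hh}} \in \mathbb{R}^{h \times h}\), and \(\mathbf{b}_\textrm{h} \in \mathbb{R}^{1 \times h}\). As a sanity check, the following framework-free NumPy sketch (our own illustration; the helper names init_params and gru_forward are hypothetical) mirrors the loop and tests two limiting cases by saturating the gate biases: an update gate near 1 copies the old state forward, while a reset gate near 1 combined with an update gate near 0 reduces the GRU to a plain tanh RNN.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(d, h, sigma=0.01, seed=0):
    """Gaussian weights with std sigma and zero biases for z, r, and h."""
    rng = np.random.default_rng(seed)
    p = {}
    for g in ('z', 'r', 'h'):
        p[f'W_x{g}'] = rng.normal(scale=sigma, size=(d, h))
        p[f'W_h{g}'] = rng.normal(scale=sigma, size=(h, h))
        p[f'b_{g}'] = np.zeros(h)
    return p

def gru_forward(inputs, p, H):
    """inputs: (num_steps, n, d); H: (n, h). Returns all states and the last."""
    outputs = []
    for X in inputs:  # iterate over time steps
        Z = sigmoid(X @ p['W_xz'] + H @ p['W_hz'] + p['b_z'])  # update gate
        R = sigmoid(X @ p['W_xr'] + H @ p['W_hr'] + p['b_r'])  # reset gate
        H_tilde = np.tanh(X @ p['W_xh'] + (R * H) @ p['W_hh'] + p['b_h'])
        H = Z * H + (1 - Z) * H_tilde  # blend old state and candidate
        outputs.append(H)
    return outputs, H

T, n, d, h = 6, 2, 5, 4  # illustrative sizes
rng = np.random.default_rng(1)
inputs = rng.normal(size=(T, n, d))
H0 = rng.normal(size=(n, h))
p = init_params(d, h, sigma=0.5)

# Update gate saturated near 1: the state is copied forward unchanged.
_, H_skip = gru_forward(inputs, dict(p, b_z=np.full(h, 50.0)), H0)

# Reset gate near 1 and update gate near 0: a plain tanh RNN.
p_rnn = dict(p, b_r=np.full(h, 50.0), b_z=np.full(h, -50.0))
outputs, H_rnn = gru_forward(inputs, p_rnn, H0)

H_plain = H0
for X in inputs:
    H_plain = np.tanh(X @ p['W_xh'] + H_plain @ p['W_hh'] + p['b_h'])

print(np.allclose(H_skip, H0), np.allclose(H_rnn, H_plain))
```

The bias value of 50 is just a convenient way to push the sigmoid to 0 or 1; a trained GRU learns such gate values from data.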
# PyTorch
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = d2l.zeros((inputs.shape[1], self.num_hiddens),
                      device=inputs.device)
    outputs = []
    for X in inputs:  # Loop over time steps
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)  # Update gate
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)  # Reset gate
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde  # New hidden state
        outputs.append(H)
    return outputs, H

# TensorFlow
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = tf.zeros((tf.shape(inputs)[1], self.num_hiddens))
    outputs = []
    for X in tf.unstack(inputs):  # Loop over time steps
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)  # Update gate
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)  # Reset gate
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde  # New hidden state
        outputs.append(H)
    return outputs, H

# JAX
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation
    def scan_fn(H, X):
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) + d2l.matmul(H, self.W_hz) +
                        self.b_z)  # Update gate
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)  # Reset gate
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde  # New hidden state
        return H, H  # return carry, y

    if H is None:
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H

    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry

# MXNet
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = d2l.zeros((inputs.shape[1], self.num_hiddens),
                      ctx=inputs.ctx)
    outputs = []
    for X in inputs:  # Loop over time steps
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)  # Update gate
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)  # Reset gate
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde  # New hidden state
        outputs.append(H)
    return outputs, H

10.2.4.3 Training
Training a language model on The Time Machine dataset works in exactly the same manner as in Section 9.5.
# PyTorch
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

# TensorFlow
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
with d2l.try_gpu():
    gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
    model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)

# JAX
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

# MXNet
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

10.2.5 Concise Implementation
In high-level APIs, we can directly instantiate a GRU model. This encapsulates all the configuration details that we made explicit above.
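One detail is worth knowing when comparing the built-in layers with the scratch version. According to the PyTorch nn.GRU documentation, the built-in layer applies the reset gate after the recurrent matrix product and keeps separate input and recurrent biases, computing \(\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \mathbf{b}_{\textrm{xh}} + \mathbf{R}_t \odot (\mathbf{H}_{t-1} \mathbf{W}_{\textrm{hh}} + \mathbf{b}_{\textrm{hh}}))\); TensorFlow's tf.keras.layers.GRU exposes both conventions through its reset_after argument. The bias names \(\mathbf{b}_{\textrm{xh}}\), \(\mathbf{b}_{\textrm{hh}}\) and the helper functions below are our own notation. This NumPy sketch contrasts the two formulas: they coincide when the reset gate is fully open (and the biases are merged) but generally differ otherwise, so the learned weights are not interchangeable between the two formulations.

```python
import numpy as np

def candidate_reset_before(X, H, R, W_xh, W_hh, b_h):
    # Scratch formulation: reset the old state, then multiply by W_hh.
    return np.tanh(X @ W_xh + (R * H) @ W_hh + b_h)

def candidate_reset_after(X, H, R, W_xh, W_hh, b_xh, b_hh):
    # "Reset-after" formulation: multiply by W_hh first, then apply R.
    return np.tanh(X @ W_xh + b_xh + R * (H @ W_hh + b_hh))

n, d, h = 3, 4, 5  # illustrative sizes
rng = np.random.default_rng(0)
X, H = rng.normal(size=(n, d)), rng.normal(size=(n, h))
W_xh = rng.normal(scale=0.5, size=(d, h))
W_hh = rng.normal(scale=0.5, size=(h, h))
b_xh, b_hh = rng.normal(size=h), rng.normal(size=h)

# Reset gate fully open: both agree once the two biases are summed.
ones = np.ones((n, h))
same = np.allclose(candidate_reset_before(X, H, ones, W_xh, W_hh, b_xh + b_hh),
                   candidate_reset_after(X, H, ones, W_xh, W_hh, b_xh, b_hh))

# Partially closed reset gate: the two candidates generally differ.
R = rng.uniform(0.1, 0.9, size=(n, h))
diff = np.abs(candidate_reset_before(X, H, R, W_xh, W_hh, b_xh + b_hh)
              - candidate_reset_after(X, H, R, W_xh, W_hh, b_xh, b_hh)).max()
print(same, diff > 0)
```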
# PyTorch
class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = nn.GRU(num_inputs, num_hiddens)

# TensorFlow
class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = tf.keras.layers.GRU(num_hiddens, return_sequences=True,
                                       return_state=True)

# JAX
class GRU(d2l.RNN):
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        if H is None:
            batch_size = inputs.shape[1]
            H = nn.GRUCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0), (batch_size, self.num_hiddens))

        GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                      in_axes=0, out_axes=0, split_rngs={"params": False})

        H, outputs = GRU(features=self.num_hiddens)(H, inputs)
        return outputs, H

# MXNet
class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = rnn.GRU(num_hiddens)

Training is significantly faster because the built-in layer uses compiled operators rather than a Python loop over time steps.
# PyTorch
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

# TensorFlow
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
with d2l.try_gpu():
    model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

# JAX
gru = GRU(num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

# MXNet
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

The perplexity is tracked while fitting. After training, we print the 20 characters that the model predicts following the provided prefix.
# PyTorch
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has a cond the the pron'

# TensorFlow
model.predict('it has', 20, data.vocab)
'it has in the time travell'

# JAX
model.predict('it has', 20, data.vocab, trainer.state.params)
'it has in and and and and '

# MXNet
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has surd the time trave'
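Perplexity, the metric tracked while training these language models, is the exponential of the average negative log-likelihood the model assigns to the true next tokens. The NumPy sketch below computes it from a matrix of predicted probabilities and the true token indices; both inputs are made-up toy data, and the function name perplexity is our own. Uniform guessing over a vocabulary of size \(q\) gives perplexity \(q\), while a model that puts all its probability on the correct token reaches the best possible value of 1.

```python
import numpy as np

def perplexity(probs, targets):
    """probs: (num_tokens, vocab_size), rows sum to 1; targets: (num_tokens,) ints."""
    # Probability assigned to the correct token at each position.
    p_true = probs[np.arange(len(targets)), targets]
    nll = -np.log(p_true)  # per-token cross-entropy
    return float(np.exp(nll.mean()))

vocab_size = 4
targets = np.array([0, 2, 1])
uniform = np.full((len(targets), vocab_size), 1 / vocab_size)
perfect = np.eye(vocab_size)[targets]  # one-hot rows on the true tokens

print(perplexity(uniform, targets), perplexity(perfect, targets))
```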
10.2.6 Summary
Compared with LSTMs, GRUs achieve similar performance but tend to be lighter computationally. Generally, compared with simple RNNs, gated RNNs such as LSTMs and GRUs can better capture dependencies in sequences with large time step distances. GRUs contain basic RNNs as an extreme case: when the reset gate is switched on and the update gate is switched off. They can also skip subsequences by turning on the update gate.
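One simple way to see why GRUs are lighter than LSTMs is to count recurrent-layer parameters. In the from-scratch layouts, each gate or candidate block has an input weight matrix (\(d \times h\)), a recurrent weight matrix (\(h \times h\)), and one bias (\(h\)); a GRU has three such blocks, an LSTM four, and a basic RNN one. The sketch below assumes exactly this layout (one bias per block, output layer excluded); the vocabulary size \(d = 28\) is an assumed value for one-hot character inputs, and \(h = 32\) matches the hidden size used above.

```python
def num_params(num_inputs, num_hiddens, num_blocks):
    # Per block: input weights (d x h) + recurrent weights (h x h) + bias (h).
    per_block = num_inputs * num_hiddens + num_hiddens ** 2 + num_hiddens
    return num_blocks * per_block

d, h = 28, 32  # d is an assumed vocabulary size; h is the hidden size above
for name, blocks in [('RNN', 1), ('GRU', 3), ('LSTM', 4)]:
    print(name, num_params(d, h, blocks))
```

The per-step matrix multiplications scale the same way, so a GRU does roughly three quarters of the work of an LSTM of the same width.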
10.2.7 Exercises
- Assume that we only want to use the input at time step \(t'\) to predict the output at time step \(t > t'\). What are the best values for the reset and update gates for each time step?
- Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
- Compare runtime, perplexity, and the output strings for rnn.RNN and rnn.GRU implementations with each other.
- What happens if you implement only parts of a GRU, e.g., with only a reset gate or only an update gate?