Recurrent Transformers
The transformer layers implemented in the fast_transformers.transformers module process the entire sequence simultaneously. This module, on the other hand, implements transformers as recurrent networks, namely as networks that process the sequence one element at a time while updating an internal state.
The TransformerEncoder and TransformerEncoderLayer are replaced by RecurrentTransformerEncoder and RecurrentTransformerEncoderLayer, and the TransformerDecoder and TransformerDecoderLayer by RecurrentTransformerDecoder and RecurrentTransformerDecoderLayer respectively.
Forward method
RecurrentTransformerEncoder or RecurrentTransformerEncoderLayer
forward(x, state=None)
Arguments
- x: The input features of shape (N, E) where N is the batch size and E is the d_model passed in the constructor. Note that x corresponds to a specific element in the sequence and not the entire sequence.
- state: A python object that varies depending on the attention implementation
RecurrentTransformerDecoder or RecurrentTransformerDecoderLayer
forward(x, memory, memory_length_mask=None, state=None)
- x: The input features of shape (N, E) where N is the batch size and E is the d_model passed in the constructor. Note that x corresponds to a specific element in the sequence and not the entire sequence.
- memory: A sequence of features (N, S, E) that the input will attend to, where S is the sequence length and E is the same as for x.
- memory_length_mask: An implementation of BaseMask that encodes how many elements each memory sequence in the batch consists of.
- state: A python object that varies depending on the attention implementation
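To make the call pattern concrete, the following self-contained sketch threads x, memory, and state through a decoding loop using the documented signature. The toy_decoder function is a hypothetical stand-in (not part of fast_transformers): it merely averages the memory and counts steps in state, the way a real recurrent decoder would carry its attention state forward.

```python
import numpy as np

def toy_decoder(x, memory, state=None):
    # Hypothetical stand-in for RecurrentTransformerDecoder.forward:
    # attend "uniformly" to the memory and count steps in the state.
    step = 0 if state is None else state
    y = x + memory.mean(axis=1)  # (N, E)
    return y, step + 1

N, S, E = 2, 5, 8
memory = np.random.rand(N, S, E)  # encoder output features (N, S, E)
x = np.zeros((N, E))              # features of the first decoded element

state = None
outputs = []
for t in range(4):
    # each call consumes one element and returns the updated state
    x, state = toy_decoder(x, memory, state=state)
    outputs.append(x)
```

The point of the sketch is the state threading: the caller never inspects the state, it only passes back whatever the previous call returned.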
Note
The masks differ from those of the batch counterparts. Namely, by processing the sequence one element at a time, recurrent encoders and decoders implicitly enforce a triangular causal mask on self attention. In addition, recurrent decoders enforce a full mask on cross attention.
Available Attentions
Not all attention formulations can be written in an autoregressive fashion as a recurrent model. In particular, since the sequence is passed to the transformer element by element, the result is the same as passing a causal mask to a batch transformer. The current list of recurrent attention implementations is:
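Linear attention is one formulation that does admit a recurrent form, which is why attention_type="linear" works in the example below. A self-contained numpy sketch (not the library implementation) of the idea: with feature-mapped queries and keys, the attention output at step t depends only on running sums, so the pair of sums can serve as the recurrent state.

```python
import numpy as np

def phi(x):
    # elu(x) + 1, a positive feature map commonly used by linear attention
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(1)
L, E = 5, 4
Q = phi(rng.normal(size=(L, E)))  # feature-mapped queries
K = phi(rng.normal(size=(L, E)))  # feature-mapped keys
V = rng.normal(size=(L, E))

# Batch causal linear attention:
# y_t = (q_t^T sum_{i<=t} k_i v_i^T) / (q_t^T sum_{i<=t} k_i)
batch_out = np.stack([
    Q[t] @ (K[:t + 1].T @ V[:t + 1]) / (Q[t] @ K[:t + 1].sum(axis=0))
    for t in range(L)
])

# Recurrent form: the state is the pair (S, z), updated per element
S = np.zeros((E, E))  # running sum of k_i v_i^T
z = np.zeros(E)       # running sum of k_i
rec_out = []
for t in range(L):
    S = S + np.outer(K[t], V[t])
    z = z + K[t]
    rec_out.append(Q[t] @ S / (Q[t] @ z))
rec_out = np.stack(rec_out)

assert np.allclose(batch_out, rec_out)
```

The recurrent form does constant work per element regardless of how many elements have been processed, which is what makes autoregressive inference efficient.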
Example
The following example builds a random recurrent transformer encoder and feeds its output back as its input 100 times.
# for simplicity ignore all the classification
# layers and the embedding layers
import torch

from fast_transformers.builders import RecurrentEncoderBuilder

model = RecurrentEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=8,
    n_heads=12,
    feed_forward_dimensions=1536,
    query_dimensions=32,
    value_dimensions=32
).get()

x0 = torch.rand(
    10,     # batch size
    12*32   # feature size (n_heads * query_dimensions)
)

# feed the output back as input, carrying the recurrent state along
state = None
x = x0
for i in range(100):
    x, state = model(x, state=state)