Module fast_transformers.recurrent.attention.cross_attention
Autoregressive implementations for cross attention as a recurrent module.
The attention implementations in this module expect a single input for the query and a sequence of inputs for the keys and values. The sequence of keys and values is fixed for all queries.
Example
import torch
from fast_transformers.recurrent.attention import RecurrentCrossAttentionLayer, RecurrentCrossFullAttention
att = RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), 16, 4)
state = None
x = torch.rand(8, 16)
memory = torch.rand(8, 64, 16)
for i in range(10):
    x, state = att(x, memory, memory, state=state)
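The full (softmax) variant used above, RecurrentCrossFullAttention, computes ordinary scaled dot-product attention for a single query per step over the fixed memory. The following minimal sketch shows that per-step computation with plain torch operations; the helper name softmax_cross_attention_step and the tensor shapes are illustrative assumptions, not the library's actual API.

    import torch

    def softmax_cross_attention_step(q, K, V, softmax_temp=None):
        # Hypothetical helper for illustration only.
        # q: (N, H, E)    a single query per sequence (one decoding step)
        # K: (N, S, H, E) fixed memory keys
        # V: (N, S, H, D) fixed memory values
        E = q.shape[-1]
        softmax_temp = softmax_temp or 1.0 / (E ** 0.5)
        # score the single query against all S memory positions, per head
        scores = torch.einsum("nhe,nshe->nsh", q, K) * softmax_temp
        weights = torch.softmax(scores, dim=1)
        # weighted average of the memory values, shape (N, H, D)
        return torch.einsum("nsh,nshd->nhd", weights, V)

Because the memory never changes between steps, the state threaded through the layer in the example above is meant to let work on the keys and values (such as their projections) be done once and reused for every subsequent query.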
Source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#
"""Autoregressive implementations for cross attention as a recurrent module.
The attention implementations in this module expect a single input for the query
and a sequence of inputs for the keys and values. The sequence of keys and values
is fixed for all queries.
Example
--------
    import torch
    from fast_transformers.recurrent.attention import \
        RecurrentCrossAttentionLayer, RecurrentCrossFullAttention

    att = RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), 16, 4)
    state = None
    x = torch.rand(8, 16)
    memory = torch.rand(8, 64, 16)
    for i in range(10):
        x, state = att(x, memory, memory, state=state)
"""
from .attention_layer import RecurrentCrossAttentionLayer
from .full_attention import RecurrentCrossFullAttention
from .linear_attention import RecurrentCrossLinearAttention
Sub-modules
fast_transformers.recurrent.attention.cross_attention.attention_layer
    Similar to the corresponding module in fast_transformers.attention, this module performs all the query, key, value projections and output projections …
fast_transformers.recurrent.attention.cross_attention.full_attention
    Implement the typical softmax attention as a recurrent cross attention module to speed up autoregressive decoding.
fast_transformers.recurrent.attention.cross_attention.linear_attention
    Implement unmasked linear attention as a recurrent cross attention module to speed up autoregressive decoding (see the sketch below).
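The linear variant goes a step further: because the memory is fixed, it can be summarized once into two small tensors, after which each decoding step costs time independent of the memory length S. The sketch below illustrates this idea using the common elu(x) + 1 feature map; the function names and the exact state layout are assumptions for illustration rather than the module's documented interface.

    import torch

    def elu_feature_map(x):
        # a common feature map for linear attention (assumed here)
        return torch.nn.functional.elu(x) + 1

    def linear_cross_attention_step(q, K, V, state=None, eps=1e-6):
        # Hypothetical helper for illustration only.
        # q: (N, H, E) single-step queries
        # K: (N, S, H, E), V: (N, S, H, D) fixed memory keys and values
        Q = elu_feature_map(q)
        if state is None:
            Kf = elu_feature_map(K)
            # Summarize the whole memory once:
            #   S_kv: (N, H, E, D) feature-weighted sum of the values
            #   Z:    (N, H, E)    sum of the key features, used for normalization
            S_kv = torch.einsum("nshe,nshd->nhed", Kf, V)
            Z = Kf.sum(dim=1)
            state = (S_kv, Z)
        S_kv, Z = state
        # every query after the first only touches the cached summaries
        norm = 1.0 / (torch.einsum("nhe,nhe->nh", Q, Z) + eps)
        out = torch.einsum("nhe,nhed,nh->nhd", Q, S_kv, norm)
        return out, state

Calling the helper again with the returned state skips the pass over the memory entirely, which is what makes autoregressive decoding with a long memory cheap.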