Module fast_transformers.recurrent.attention.cross_attention.attention_layer
Similar to the corresponding module in fast_transformers.attention, this module performs all of the query, key, and value projections as well as the output projection, leaving the actual attention computation to the inner attention module.
The crucial difference with respect to the recurrent self-attention module (fast_transformers.recurrent.attention.RecurrentAttentionLayer) is that it does not recompute the key and value projections when the state is not None.
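As an illustration of this caching behaviour, here is a minimal sketch using a hypothetical, toy inner attention module that simply stores the projected keys and values in its state and returns a plain average of the values. The toy module ignores key_lengths (a real inner attention expects a fast_transformers.masking.BaseMask there), which is why None is passed below.

import torch
from torch.nn import Module

from fast_transformers.recurrent.attention.cross_attention.attention_layer import \
    RecurrentCrossAttentionLayer


class ToyCrossAttention(Module):
    """Hypothetical inner attention: caches (keys, values) in the state and
    returns a uniform average of the values. Only meant to show the API."""
    def forward(self, query, keys, values, key_lengths, state=None):
        if state is None:
            state = (keys, values)              # first call: keys/values were just projected
        cached_keys, cached_values = state      # later calls: reuse the cached projections
        return cached_values.mean(dim=1), state # (N, H, d_values)


layer = RecurrentCrossAttentionLayer(ToyCrossAttention(), d_model=64, n_heads=4)

N, S = 2, 10
memory = torch.rand(N, S, 64)   # encoder output used as keys and values
x0 = torch.rand(N, 64)          # decoder input at the first time step

# First call: state is None, so the keys and values are projected here
y0, state = layer(x0, memory, memory, key_lengths=None)

# Second call: the returned state is passed back, so keys and values are ignored
y1, state = layer(torch.rand(N, 64), None, None, key_lengths=None, state=state)

print(y0.shape, y1.shape)       # torch.Size([2, 64]) torch.Size([2, 64])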
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#
"""Similar to the corresponding module in fast_transformers.attention, this
module performs all the query, key, value projections and output projections
leaving the implementation of the attention to the inner attention module.
The crucial difference with respect to the self attention recurrent module
(fast_transformers.recurrent.attention.RecurrentAttentionLayer) is that it
doesn't recompute the projections for the keys and values if the state is not
None.
"""
from torch.nn import Linear, Module
from ....events import EventDispatcher
class RecurrentCrossAttentionLayer(Module):
    """See fast_transformers.attention.attention_layer.AttentionLayer.

    The differences with the aforementioned module as well as the
    RecurrentAttentionLayer are that this module projects the query every time
    and the keys and values only the first time they are provided.

    Arguments
    ---------
        attention: Specific inner attention implementation that just computes a
                   weighted average of values given a similarity of queries and
                   keys.
        d_model: The input feature dimensionality
        n_heads: The number of heads for the multi head attention
        d_keys: The dimensionality of the keys/queries
                (default: d_model/n_heads)
        d_values: The dimensionality of the values (default: d_model/n_heads)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None, event_dispatcher=""):
        super(RecurrentCrossAttentionLayer, self).__init__()

        # Fill d_keys and d_values
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = Linear(d_model, d_keys * n_heads)
        self.key_projection = Linear(d_model, d_keys * n_heads)
        self.value_projection = Linear(d_model, d_values * n_heads)
        self.out_projection = Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, query, keys, values, key_lengths, state=None):
        """Attend to the keys and values based on the passed in query.

        In the argument description we make use of the following sizes

            - N: the batch size
            - S: the sequence length of the keys and values
            - D: The input feature dimensionality passed in the constructor as
              'd_model'

        Argument
        --------
            query: (N, D) The tensor containing the queries
            keys: (N, S, D) The tensor containing the keys
            values: (N, S, D) The tensor containing the values
            key_lengths: A fast_transformers.masking.BaseMask implementation
                         that defines the length of each key/value sequence
            state: The state varies depending on the inner attention
                   implementation, but if it is not None then the keys and
                   values are ignored
        """
        # Extract some shapes
        N, _ = query.shape
        H = self.n_heads

        # Project the query
        query = self.query_projection(query).view(N, H, -1)

        # Project the keys and values if there is no state
        if state is None:
            _, S, _ = keys.shape
            keys = self.key_projection(keys).view(N, S, H, -1)
            values = self.value_projection(values).view(N, S, H, -1)
        else:
            keys = None
            values = None

        new_value, state = self.inner_attention(
            query,
            keys,
            values,
            key_lengths,
            state=state
        )
        new_value = new_value.view(N, -1)

        # Project the output and return
        return self.out_projection(new_value), state
Classes
class RecurrentCrossAttentionLayer (attention, d_model, n_heads, d_keys=None, d_values=None, event_dispatcher='')
See fast_transformers.attention.attention_layer.AttentionLayer.
The differences with the aforementioned module as well as the RecurrentAttentionLayer are that this module projects the query every time and the keys and values only the first time they are provided.
Arguments
attention: Specific inner attention implementation that just computes a weighted average of values given a similarity of queries and keys.
d_model: The input feature dimensionality
n_heads: The number of heads for the multi head attention
d_keys: The dimensionality of the keys/queries (default: d_model/n_heads)
d_values: The dimensionality of the values (default: d_model/n_heads)
event_dispatcher: str or EventDispatcher instance to be used by this module for dispatching events (default: the default global dispatcher)
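A quick sketch of the default dimensioning follows; the None passed for attention is only a placeholder so the projection layers can be inspected (calling forward would, of course, require a real inner attention module).

from fast_transformers.recurrent.attention.cross_attention.attention_layer import \
    RecurrentCrossAttentionLayer

# d_keys and d_values default to d_model // n_heads = 512 // 8 = 64
layer = RecurrentCrossAttentionLayer(attention=None, d_model=512, n_heads=8)

print(layer.query_projection)   # Linear(in_features=512, out_features=512, bias=True)
print(layer.value_projection)   # Linear(in_features=512, out_features=512, bias=True)
print(layer.out_projection)     # Linear(in_features=512, out_features=512, bias=True)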
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, query, keys, values, key_lengths, state=None)
Attend to the keys and values based on the passed in query.
In the argument description we make use of the following sizes
- N: the batch size
- S: the sequence length of the keys and values
- D: The input feature dimensionality passed in the constructor as 'd_model'
Argument
query: (N, D) The tensor containing the queries
keys: (N, S, D) The tensor containing the keys
values: (N, S, D) The tensor containing the values
key_lengths: A fast_transformers.masking.BaseMask implementation that defines the length of each key/value sequence
state: The state varies depending on the inner attention implementation, but if it is not None then the keys and values are ignored
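In a typical decoder loop the same memory tensors can be passed at every step; because the returned state is fed back in, they are projected only at the first step and ignored afterwards. Below is a self-contained sketch with a hypothetical softmax-based inner attention that caches the projections in its state (it ignores key_lengths, which a real implementation would use for masking).

import torch
from fast_transformers.recurrent.attention.cross_attention.attention_layer import \
    RecurrentCrossAttentionLayer


class CachingSoftmaxCrossAttention(torch.nn.Module):
    # Hypothetical inner attention: softmax cross attention over the cached
    # projected keys/values; masking via key_lengths is omitted for brevity.
    def forward(self, query, keys, values, key_lengths, state=None):
        if state is None:
            state = (keys, values)                   # cache the projections once
        K, V = state                                 # (N, S, H, E), (N, S, H, Dv)
        scores = torch.einsum("nhe,nshe->nsh", query, K) / K.shape[-1] ** 0.5
        weights = torch.softmax(scores, dim=1)       # attend over the S positions
        return torch.einsum("nsh,nshd->nhd", weights, V), state


layer = RecurrentCrossAttentionLayer(CachingSoftmaxCrossAttention(),
                                     d_model=32, n_heads=4)
memory = torch.rand(3, 7, 32)                        # encoder output: N=3, S=7, D=32
state = None
for t in range(5):                                   # autoregressive decoding loop
    x_t = torch.rand(3, 32)                          # decoder input at step t
    # keys/values are projected only at t == 0; at t > 0 they are ignored
    y_t, state = layer(x_t, memory, memory, key_lengths=None, state=state)
    assert y_t.shape == (3, 32)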