Module fast_transformers.attention.linear_attention
Implement unmasked linear attention.
Source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement unmasked linear attention."""

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map

class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map()
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()

# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "linear", LinearAttention,
    [
        ("query_dimensions", Int),
        ("feature_map", Optional(Callable)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
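
Because LinearAttention is registered under the name "linear", it can also be instantiated through the library's builders. A minimal sketch (the hyper-parameter values below are illustrative, not defaults):

import torch
from fast_transformers.builders import TransformerEncoderBuilder

# Build a small encoder whose self-attention is the "linear" attention
# registered above; query_dimensions and value_dimensions are kept equal so
# the model dimension is n_heads * 32 = 128.
builder = TransformerEncoderBuilder.from_kwargs(
    n_layers=2,
    n_heads=4,
    query_dimensions=32,
    value_dimensions=32,
    feed_forward_dimensions=128,
    attention_type="linear"   # resolved through AttentionRegistry
)
encoder = builder.get()

x = torch.randn(2, 100, 128)  # (batch, sequence length, model dimension)
y = encoder(x)                # same shape as x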
Classes
class LinearAttention (query_dimensions, feature_map=None, eps=1e-06, event_dispatcher='')
Implement unmasked attention using dot product of feature maps in O(N D^2) complexity.
Given the queries, keys and values as Q, K, V instead of computing
V' = softmax(Q.mm(K.t()), dim=-1).mm(V),
we make use of a feature map function Φ(.) and perform the following computation
V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).
The above can be computed in O(N D^2) complexity where D is the dimensionality of Q, K and V and N is the sequence length. Depending on the feature map, however, the complexity of the attention might be limited.
Arguments
- feature_map: callable, a callable that applies the feature map to the last dimension of a tensor (default: elu(x)+1)
- eps: float, a small number to ensure the numerical stability of the denominator (default: 1e-6)
- event_dispatcher: str or EventDispatcher instance to be used by this module for dispatching events (default: the default global dispatcher)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
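
A minimal usage sketch for calling the module directly (the mask helpers come from fast_transformers.masking; tensor shapes follow the einsum equations in the source above):

import torch
from fast_transformers.masking import FullMask, LengthMask
from fast_transformers.attention.linear_attention import LinearAttention

N, L, H, E, M = 2, 100, 4, 32, 32             # batch, length, heads, key dim, value dim
queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)
values = torch.randn(N, L, H, M)

attention = LinearAttention(query_dimensions=E)
attn_mask = FullMask(L)                       # must be all ones for LinearAttention
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

out = attention(queries, keys, values, attn_mask, lengths, lengths)
print(out.shape)                              # torch.Size([2, 100, 4, 32])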
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
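
To see why the forward pass avoids materializing the attention matrix, the sketch below (illustrative, not part of the library) checks that the einsum formulation used above matches the explicit normalize(Φ(Q).mm(Φ(K).t())).mm(V) computation:

import torch

torch.manual_seed(0)
N, L, S, H, E, M = 2, 5, 7, 3, 4, 6
eps = 1e-6

# Feature-mapped queries and keys, e.g. elu(x) + 1 as in the default feature map
Q = torch.nn.functional.elu(torch.randn(N, L, H, E)) + 1
K = torch.nn.functional.elu(torch.randn(N, S, H, E)) + 1
values = torch.randn(N, S, H, M)

# Linear formulation: contract keys with values first, never forming the L x S matrix
KV = torch.einsum("nshd,nshm->nhmd", K, values)
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + eps)
V_linear = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

# Explicit formulation: build the L x S similarity matrix and normalize its rows
A = torch.einsum("nlhd,nshd->nhls", Q, K)
A = A / (A.sum(dim=-1, keepdim=True) + eps)
V_explicit = torch.einsum("nhls,nshm->nlhm", A, values)

print(torch.allclose(V_linear, V_explicit, atol=1e-5))  # True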