Module fast_transformers.attention.linear_attention

Implement unmasked linear attention.

Source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement unmasked linear attention."""

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map


class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map()
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "linear", LinearAttention,
    [
        ("query_dimensions", Int),
        ("feature_map", Optional(Callable)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
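Once registered under the name "linear", the implementation can be selected by name through the library's builders. A minimal sketch, assuming the TransformerEncoderBuilder from fast_transformers.builders and purely illustrative layer sizes:

from fast_transformers.builders import TransformerEncoderBuilder

# Build an encoder whose attention layers use the "linear" implementation
# registered above; all dimensions below are illustrative.
encoder = TransformerEncoderBuilder.from_kwargs(
    n_layers=2,
    n_heads=4,
    query_dimensions=32,
    value_dimensions=32,
    feed_forward_dimensions=128,
    attention_type="linear",
).get()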

Classes

class LinearAttention (query_dimensions, feature_map=None, eps=1e-06, event_dispatcher='')

Implement unmasked attention using dot product of feature maps in O(N D^2) complexity.

Given the queries, keys and values as Q, K and V, instead of computing

V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

we make use of a feature map function Φ(.) and perform the following computation

V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

The above can be computed in O(N D^2) complexity where D is the dimensionality of Q, K and V and N is the sequence length. Depending on the feature map, however, the complexity of the attention might be limited.
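A single-head sketch of why the reordering is linear in the sequence length: computing Φ(K).t().mm(V) first produces a D x M matrix, so the L x S attention matrix is never materialized. The sizes and the elu(x)+1 feature map below are illustrative only.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, S, D, M = 8, 10, 4, 6                 # illustrative sizes
phi = lambda x: F.elu(x) + 1             # the default elu(x)+1 feature map
Q, K, V = torch.randn(L, D), torch.randn(S, D), torch.randn(S, M)

# Quadratic formulation: materialize the L x S matrix, O(L S D) time, O(L S) memory
A = phi(Q).mm(phi(K).t())
out_quadratic = (A / A.sum(dim=-1, keepdim=True)).mm(V)

# Linear formulation: associate the other way, O((L + S) D M) time, no L x S matrix
KV = phi(K).t().mm(V)                    # (D, M)
Z = phi(Q).mv(phi(K).sum(dim=0))         # (L,) normalizer
out_linear = phi(Q).mm(KV) / Z[:, None]

print(torch.allclose(out_quadratic, out_linear, atol=1e-5))  # True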

Arguments

query_dimensions: int, the dimensionality of the queries (their last
                  dimension); it is passed to the feature map constructor
feature_map: callable, a callable that applies the feature map to the
             last dimension of a tensor (default: elu(x)+1)
eps: float, a small number to ensure the numerical stability of the
     denominator (default: 1e-6)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.
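A minimal usage sketch with purely illustrative sizes, assuming the FullMask and LengthMask helpers from fast_transformers.masking as used elsewhere in the library. The inputs have shape (batch, sequence length, heads, feature dimension):

import torch
from fast_transformers.attention.linear_attention import LinearAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E, M = 2, 100, 4, 32, 64   # illustrative: batch, length, heads, query dim, value dim
queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)
values = torch.randn(N, L, H, M)

attn = LinearAttention(query_dimensions=E)
attn_mask = FullMask(L)             # must be all ones; anything else raises RuntimeError
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

out = attn(queries, keys, values, attn_mask, lengths, lengths)
print(out.shape)                    # torch.Size([2, 100, 4, 64])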

Source code
class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map()
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code
def forward(self, queries, keys, values, attn_mask, query_lengths,
            key_lengths):
    # Apply the feature map to the queries and keys
    self.feature_map.new_feature_map()
    Q = self.feature_map.forward_queries(queries)
    K = self.feature_map.forward_keys(keys)

    # Apply the key padding mask and make sure that the attn_mask is
    # all_ones
    if not attn_mask.all_ones:
        raise RuntimeError(("LinearAttention does not support arbitrary "
                            "attention masks"))
    K = K * key_lengths.float_matrix[:, :, None, None]

    # Compute the KV matrix, namely the dot product of keys and values so
    # that we never explicitly compute the attention matrix and thus
    # decrease the complexity
    KV = torch.einsum("nshd,nshm->nhmd", K, values)

    # Compute the normalizer
    Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

    # Finally compute and return the new values
    V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

    return V.contiguous()
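The three einsum calls in forward can be read as ordinary batched matrix products. The sketch below spells out the shapes on random positive tensors (as if a feature map had already been applied) and reproduces the computation with explicit matmuls; the sizes are illustrative.

import torch

N, L, S, H, D, M = 2, 5, 7, 3, 4, 6     # illustrative sizes
eps = 1e-6
Q = torch.rand(N, L, H, D) + 0.5        # positive, as if the feature map were applied
K = torch.rand(N, S, H, D) + 0.5
values = torch.randn(N, S, H, M)

# The einsum formulation used by forward()
KV = torch.einsum("nshd,nshm->nhmd", K, values)                        # (N, H, M, D)
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + eps)         # (N, L, H)
V1 = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)                     # (N, L, H, M)

# The same computation with explicit matmuls on (N, H, ...) layouts
Qh = Q.permute(0, 2, 1, 3)              # (N, H, L, D)
Kh = K.permute(0, 2, 1, 3)              # (N, H, S, D)
Vh = values.permute(0, 2, 1, 3)         # (N, H, S, M)
KV2 = Kh.transpose(-2, -1) @ Vh                                        # (N, H, D, M)
Z2 = 1 / (Qh @ Kh.sum(dim=2, keepdim=True).transpose(-2, -1) + eps)    # (N, H, L, 1)
V2 = ((Qh @ KV2) * Z2).permute(0, 2, 1, 3)                             # (N, L, H, M)

print(torch.allclose(V1, V2, atol=1e-5))  # True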