Module fast_transformers.recurrent.attention.self_attention.linear_attention

Implement the causally masked linear attention as a recurrent model.

Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement the causally masked linear attention as a recurrent model."""

import torch
from torch.nn import Module

from ....attention_registry import RecurrentAttentionRegistry, Optional, Int, \
    Callable, EventDispatcherInstance
from ....events import EventDispatcher
from ....feature_maps import elu_feature_map
from ..._utils import check_state


class RecurrentLinearAttention(Module):
    """Implement fast_transformers.attention.causal_linear_attention as a
    fixed-dimensional state recurrent model.

    See fast_transformers.attention.linear_attention and
    fast_transformers.attention.causal_linear_attention for the general concept
    of replacing the softmax with feature maps.

    Each call to ``forward`` processes one time step and returns an updated
    state ``[Si, Zi]`` whose size does not grow with the number of steps.

    Arguments
    ---------
        query_dimensions: int, the dimensionality of the queries; it is passed
                          to the feature map constructor
        feature_map: callable, a factory called as feature_map(query_dimensions)
                     that returns the feature map applied to the queries and
                     keys (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(RecurrentLinearAttention, self).__init__()
        # `feature_map` is a factory: it is instantiated here with the query
        # dimensionality, falling back to the elu(x)+1 feature map.
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, query, key, value, state=None, memory=None):
        """Compute the attention output for a single time step.

        Arguments
        ---------
            query: tensor for the current step; after the feature map it must
                   be 3-D, (N, H, D) (assumed N = batch, H = heads)
            key: tensor for the current step; after the feature map it must
                 match the query's (N, H, D) shape
            value: tensor of shape (N, H, M)
            state: None to start a new sequence, otherwise the [Si, Zi] list
                   returned by the previous call
            memory: passed to check_state together with `state`
                    (NOTE(review): presumably a legacy alias for `state` —
                    confirm in _utils.check_state)

        Returns
        -------
            V: tensor of shape (N, H, M), the attention output for this step
            [Si, Zi]: the updated state, Si of shape (N, H, D, M) holding the
                      running sum of key/value outer products and Zi of shape
                      (N, H, D) holding the running sum of feature-mapped keys

        Raises
        ------
            ValueError: if the batch size differs from the one in `state`

        NOTE: when no tensor involved requires grad, the state tensors passed
        in are updated in place, so the caller's `state` is modified.
        """
        # Normalize state/memory
        state = check_state(state, memory)

        # If this is a new sequence reinitialize the feature map
        if state is None:
            self.feature_map.new_feature_map()

        # Apply the feature map to the query and key
        Q = self.feature_map.forward_queries(query)
        K = self.feature_map.forward_keys(key)

        # Extract some shapes
        N, H, D = Q.shape
        _, _, M = value.shape

        # Extract the memory or initialize it
        if state is None:
            Si = query.new_zeros((N, H, D, M))
            Zi = query.new_zeros((N, H, D))
        else:
            Si, Zi = state

        # Ensure the batch size did not change
        if len(Si) != N:
            raise ValueError("The batch size changed during iteration")

        # Update the internal state.
        #
        # In-place updates are faster for inference (GitHub PR #10), but they
        # are only safe when autograd is not tracking any of the operands.
        # Writing in place into a state tensor that requires grad either
        # raises (a leaf such as a learnable initial state) or invalidates
        # tensors saved for backward by the previous step's einsums. So the
        # incoming state is checked as well, not only K and value.
        if any(t.requires_grad for t in (K, value, Si, Zi)):
            Zi = Zi + K
            Si = Si + torch.einsum("nhd,nhm->nhdm", K, value)
        else:
            Zi += K
            Si += torch.einsum("nhd,nhm->nhdm", K, value)

        # Compute the output: (Q . Si) normalized by (Q . Zi + eps)
        Z = 1. / (torch.einsum("nhd,nhd->nh", Q, Zi) + self.eps)
        V = torch.einsum("nhd,nhdm,nh->nhm", Q, Si, Z)

        return V, [Si, Zi]


# Register the attention implementation so that it becomes available in our
# builders. The same class is registered under two names, "linear" and
# "causal-linear", with identical parameter specifications.
#
# NOTE(review): `eps` is not part of either specification, so presumably
# modules built through the registry always use the constructor default
# (1e-6) — confirm against the builder before relying on a custom eps.
RecurrentAttentionRegistry.register(
    "linear", RecurrentLinearAttention,
    [
        ("query_dimensions", Int),
        ("feature_map", Optional(Callable)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
# Alias: the recurrent formulation implements the causally masked variant,
# so it is also exposed under the "causal-linear" name.
RecurrentAttentionRegistry.register(
    "causal-linear", RecurrentLinearAttention,
    [
        ("query_dimensions", Int),
        ("feature_map", Optional(Callable)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)

Classes

class RecurrentLinearAttention (query_dimensions, feature_map=None, eps=1e-06, event_dispatcher='')

Implement fast_transformers.attention.causal_linear_attention as a fixed-dimensional state recurrent model.

See fast_transformers.attention.linear_attention and fast_transformers.attention.causal_linear_attention for the general concept of replacing the softmax with feature maps.

Arguments

query_dimensions: int, the dimensionality of the queries; it is passed
                  to the feature map constructor
feature_map: callable, a factory called as feature_map(query_dimensions)
             that returns the feature map applied to the queries and
             keys (default: elu(x)+1)
eps: float, a small number to ensure the numerical stability of the
     denominator (default: 1e-6)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

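The constructor builds the feature map from query_dimensions (using feature_map if given, otherwise elu_feature_map), stores eps, and resolves event_dispatcher through EventDispatcher.get.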

Expand source code
class RecurrentLinearAttention(Module):
    """Implement fast_transformers.attention.causal_linear_attention as a
    fixed-dimensional state recurrent model.

    See fast_transformers.attention.linear_attention and
    fast_transformers.attention.causal_linear_attention for the general concept
    of replacing the softmax with feature maps.

    Each call to ``forward`` processes one time step and returns an updated
    state ``[Si, Zi]`` whose size does not grow with the number of steps.

    Arguments
    ---------
        query_dimensions: int, the dimensionality of the queries; it is passed
                          to the feature map constructor
        feature_map: callable, a factory called as feature_map(query_dimensions)
                     that returns the feature map applied to the queries and
                     keys (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(RecurrentLinearAttention, self).__init__()
        # `feature_map` is a factory: it is instantiated here with the query
        # dimensionality, falling back to the elu(x)+1 feature map.
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, query, key, value, state=None, memory=None):
        """Compute the attention output for a single time step.

        Arguments
        ---------
            query: tensor for the current step; after the feature map it must
                   be 3-D, (N, H, D) (assumed N = batch, H = heads)
            key: tensor for the current step; after the feature map it must
                 match the query's (N, H, D) shape
            value: tensor of shape (N, H, M)
            state: None to start a new sequence, otherwise the [Si, Zi] list
                   returned by the previous call
            memory: passed to check_state together with `state`
                    (NOTE(review): presumably a legacy alias for `state` —
                    confirm in _utils.check_state)

        Returns
        -------
            V: tensor of shape (N, H, M), the attention output for this step
            [Si, Zi]: the updated state, Si of shape (N, H, D, M) holding the
                      running sum of key/value outer products and Zi of shape
                      (N, H, D) holding the running sum of feature-mapped keys

        Raises
        ------
            ValueError: if the batch size differs from the one in `state`

        NOTE: when no tensor involved requires grad, the state tensors passed
        in are updated in place, so the caller's `state` is modified.
        """
        # Normalize state/memory
        state = check_state(state, memory)

        # If this is a new sequence reinitialize the feature map
        if state is None:
            self.feature_map.new_feature_map()

        # Apply the feature map to the query and key
        Q = self.feature_map.forward_queries(query)
        K = self.feature_map.forward_keys(key)

        # Extract some shapes
        N, H, D = Q.shape
        _, _, M = value.shape

        # Extract the memory or initialize it
        if state is None:
            Si = query.new_zeros((N, H, D, M))
            Zi = query.new_zeros((N, H, D))
        else:
            Si, Zi = state

        # Ensure the batch size did not change
        if len(Si) != N:
            raise ValueError("The batch size changed during iteration")

        # Update the internal state.
        #
        # In-place updates are faster for inference (GitHub PR #10), but they
        # are only safe when autograd is not tracking any of the operands.
        # Writing in place into a state tensor that requires grad either
        # raises (a leaf such as a learnable initial state) or invalidates
        # tensors saved for backward by the previous step's einsums. So the
        # incoming state is checked as well, not only K and value.
        if any(t.requires_grad for t in (K, value, Si, Zi)):
            Zi = Zi + K
            Si = Si + torch.einsum("nhd,nhm->nhdm", K, value)
        else:
            Zi += K
            Si += torch.einsum("nhd,nhm->nhdm", K, value)

        # Compute the output: (Q . Si) normalized by (Q . Zi + eps)
        Z = 1. / (torch.einsum("nhd,nhd->nh", Q, Zi) + self.eps)
        V = torch.einsum("nhd,nhdm,nh->nhm", Q, Si, Z)

        return V, [Si, Zi]

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, query, key, value, state=None, memory=None)

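Processes one time step of causally masked linear attention (this overrides torch.nn.Module.forward). The query and key are passed through the feature map φ, and the running state is updated as $S_i = S_{i-1} + \phi(k_i) v_i^\top$ and $Z_i = Z_{i-1} + \phi(k_i)$. The output is $V_i = \frac{\phi(q_i)^\top S_i}{\phi(q_i)^\top Z_i + \varepsilon}$. The method returns the output together with the updated state [Si, Zi].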

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, query, key, value, state=None, memory=None):
    """Compute the attention output for a single time step.

    Arguments
    ---------
        query: tensor for the current step; after the feature map it must
               be 3-D, (N, H, D) (assumed N = batch, H = heads)
        key: tensor for the current step; after the feature map it must
             match the query's (N, H, D) shape
        value: tensor of shape (N, H, M)
        state: None to start a new sequence, otherwise the [Si, Zi] list
               returned by the previous call
        memory: passed to check_state together with `state`
                (NOTE(review): presumably a legacy alias for `state` —
                confirm in _utils.check_state)

    Returns
    -------
        V: tensor of shape (N, H, M), the attention output for this step
        [Si, Zi]: the updated state, Si of shape (N, H, D, M) holding the
                  running sum of key/value outer products and Zi of shape
                  (N, H, D) holding the running sum of feature-mapped keys

    Raises
    ------
        ValueError: if the batch size differs from the one in `state`

    NOTE: when no tensor involved requires grad, the state tensors passed
    in are updated in place, so the caller's `state` is modified.
    """
    # Normalize state/memory
    state = check_state(state, memory)

    # If this is a new sequence reinitialize the feature map
    if state is None:
        self.feature_map.new_feature_map()

    # Apply the feature map to the query and key
    Q = self.feature_map.forward_queries(query)
    K = self.feature_map.forward_keys(key)

    # Extract some shapes
    N, H, D = Q.shape
    _, _, M = value.shape

    # Extract the memory or initialize it
    if state is None:
        Si = query.new_zeros((N, H, D, M))
        Zi = query.new_zeros((N, H, D))
    else:
        Si, Zi = state

    # Ensure the batch size did not change
    if len(Si) != N:
        raise ValueError("The batch size changed during iteration")

    # Update the internal state.
    #
    # In-place updates are faster for inference (GitHub PR #10), but they
    # are only safe when autograd is not tracking any of the operands.
    # Writing in place into a state tensor that requires grad either raises
    # (a leaf such as a learnable initial state) or invalidates tensors saved
    # for backward by the previous step's einsums. So the incoming state is
    # checked as well, not only K and value.
    if any(t.requires_grad for t in (K, value, Si, Zi)):
        Zi = Zi + K
        Si = Si + torch.einsum("nhd,nhm->nhdm", K, value)
    else:
        Zi += K
        Si += torch.einsum("nhd,nhm->nhdm", K, value)

    # Compute the output: (Q . Si) normalized by (Q . Zi + eps)
    Z = 1. / (torch.einsum("nhd,nhd->nh", Q, Zi) + self.eps)
    V = torch.einsum("nhd,nhdm,nh->nhm", Q, Si, Z)

    return V, [Si, Zi]