Module fast_transformers.attention.full_attention

Implement the full attention similar to the one implemented by PyTorch's MultiHeadAttention module. Note that this module is to be used in conjunction with the AttentionLayer in order to work.

Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement the full attention similar to the one implemented by PyTorch's
MultiHeadAttention module. Note that this module is to be used in conjunction
with the `fast_transformers.attention.attention_layer.AttentionLayer` in order
to work."""

from math import sqrt

import torch
from torch.nn import Dropout, Module

from ..attention_registry import AttentionRegistry, Optional, Float, \
    EventDispatcherInstance
from ..events import EventDispatcher, AttentionEvent


class FullAttention(Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(FullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
                           (accepted for interface compatibility; it is not
                           read by this implementation)
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of

        Returns
        -------
            A contiguous (N, L, H, D) tensor holding, for every query and
            head, the attention-weighted average of the values.
        """
        # Only the key/query feature size is needed, to derive the default
        # temperature.
        E = queries.shape[-1]

        # Compare against None explicitly: a plain `or` would also discard an
        # explicitly configured temperature of 0.0 and silently fall back to
        # the default.
        if self.softmax_temp is None:
            softmax_temp = 1./sqrt(E)
        else:
            softmax_temp = self.softmax_temp

        # Compute the unnormalized attention, laid out as (N, H, L, S), and
        # apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
        if not attn_mask.all_ones:
            QK = QK + attn_mask.additive_matrix
        # The two inserted axes broadcast the per-sequence key mask over the
        # head and query dimensions (presumably additive_matrix is (N, S)
        # here -- TODO confirm against the mask implementation).
        QK = QK + key_lengths.additive_matrix[:, None, None]

        # Compute the attention and the weighted average. Dropout is applied
        # to the normalized attention weights, not to the values.
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        # Let the world know of the attention matrix (the post-dropout one)
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous()


# Register the attention implementation under the name "full" so that it
# becomes available in our builders. Each (name, validator) pair below is
# named after one keyword argument of FullAttention.__init__.
AttentionRegistry.register(
    "full", FullAttention,
    [
        # No value is supplied here; FullAttention derives 1/sqrt(E) at
        # runtime when no temperature has been configured.
        ("softmax_temp", Optional(Float)),
        # 0.1 and "" mirror the constructor defaults -- presumably the second
        # argument of Optional is the default value (TODO confirm).
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)

Classes

class FullAttention (softmax_temp=None, attention_dropout=0.1, event_dispatcher='')

Implement the scaled dot product attention with softmax.

Arguments

softmax_temp: The temperature to use for the softmax attention.
              (default: 1/sqrt(d_keys) where d_keys is computed at
              runtime)
attention_dropout: The dropout rate to apply to the attention
                   (default: 0.1)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class FullAttention(Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(FullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
                           (accepted for interface compatibility; it is not
                           read by this implementation)
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of

        Returns
        -------
            A contiguous (N, L, H, D) tensor holding, for every query and
            head, the attention-weighted average of the values.
        """
        # Only the key/query feature size is needed, to derive the default
        # temperature.
        E = queries.shape[-1]

        # Compare against None explicitly: a plain `or` would also discard an
        # explicitly configured temperature of 0.0 and silently fall back to
        # the default.
        if self.softmax_temp is None:
            softmax_temp = 1./sqrt(E)
        else:
            softmax_temp = self.softmax_temp

        # Compute the unnormalized attention, laid out as (N, H, L, S), and
        # apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
        if not attn_mask.all_ones:
            QK = QK + attn_mask.additive_matrix
        # The two inserted axes broadcast the per-sequence key mask over the
        # head and query dimensions (presumably additive_matrix is (N, S)
        # here -- TODO confirm against the mask implementation).
        QK = QK + key_lengths.additive_matrix[:, None, None]

        # Compute the attention and the weighted average. Dropout is applied
        # to the normalized attention weights, not to the values.
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        # Let the world know of the attention matrix (the post-dropout one)
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous()

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths)

Implements the multihead softmax attention.

Arguments

queries: (N, L, H, E) The tensor containing the queries
keys: (N, S, H, E) The tensor containing the keys
values: (N, S, H, D) The tensor containing the values
attn_mask: An implementation of BaseMask that encodes where each
           query can attend to
query_lengths: An implementation of  BaseMask that encodes how
               many queries each sequence in the batch consists of
key_lengths: An implementation of BaseMask that encodes how
             many keys each sequence in the batch consists of
Expand source code
def forward(self, queries, keys, values, attn_mask, query_lengths,
            key_lengths):
    """Implements the multihead softmax attention.

    Arguments
    ---------
        queries: (N, L, H, E) The tensor containing the queries
        keys: (N, S, H, E) The tensor containing the keys
        values: (N, S, H, D) The tensor containing the values
        attn_mask: An implementation of BaseMask that encodes where each
                   query can attend to
        query_lengths: An implementation of BaseMask that encodes how
                       many queries each sequence in the batch consists of
                       (accepted for interface compatibility; it is not
                       read by this implementation)
        key_lengths: An implementation of BaseMask that encodes how
                     many keys each sequence in the batch consists of

    Returns
    -------
        A contiguous (N, L, H, D) tensor holding, for every query and
        head, the attention-weighted average of the values.
    """
    # Only the key/query feature size is needed, to derive the default
    # temperature.
    E = queries.shape[-1]

    # Compare against None explicitly: a plain `or` would also discard an
    # explicitly configured temperature of 0.0 and silently fall back to
    # the default.
    if self.softmax_temp is None:
        softmax_temp = 1./sqrt(E)
    else:
        softmax_temp = self.softmax_temp

    # Compute the unnormalized attention, laid out as (N, H, L, S), and
    # apply the masks
    QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
    if not attn_mask.all_ones:
        QK = QK + attn_mask.additive_matrix
    # The two inserted axes broadcast the per-sequence key mask over the
    # head and query dimensions (presumably additive_matrix is (N, S)
    # here -- TODO confirm against the mask implementation).
    QK = QK + key_lengths.additive_matrix[:, None, None]

    # Compute the attention and the weighted average. Dropout is applied
    # to the normalized attention weights, not to the values.
    A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
    V = torch.einsum("nhls,nshd->nlhd", A, values)

    # Let the world know of the attention matrix (the post-dropout one)
    self.event_dispatcher.dispatch(AttentionEvent(self, A))

    # Make sure that what we return is contiguous
    return V.contiguous()