Module fast_transformers.attention.local_attention

Implement local context attention.

Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Implement local context attention."""

from math import sqrt

import torch
from torch.nn import Module, Dropout
from torch.nn import functional as F

from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..local_product import local_dot_product, local_weighted_average


class LocalAttention(Module):
    """Implement fast local attention where a query can only attend to
    neighboring keys.

    In this attention module the query Q_i can only attend to a key K_j if
    |i-j| < local_context/2.

    Arguments
    ---------
        local_context: The neighborhood to consider for local attention.
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(LocalAttention, self).__init__()
        self.local_context = local_context
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the local attention.

        The attn_mask can be any mask; only the values that fall in the
        neighborhood of each query will be considered.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        context = self.local_context
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Permute the dimensions to NHLE instead of NLHE
        queries = queries.permute(0, 2, 1, 3).contiguous()
        keys = keys.permute(0, 2, 1, 3).contiguous()
        values = values.permute(0, 2, 1, 3).contiguous()

        QK = local_dot_product(
            queries,
            keys,
            attn_mask.additive_matrix_finite,
            key_lengths.lengths,
            self.local_context
        )
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))

        V_new = local_weighted_average(A, values)

        return V_new.permute(0, 2, 1, 3).contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "local", LocalAttention,
    [
        ("local_context", Int),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
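
Because the module registers itself under the name "local", it can be selected through the library's builders. The following is a minimal usage sketch, not taken from this module: it assumes that TransformerEncoderBuilder.from_kwargs accepts attention_type and forwards attention-specific parameters such as local_context to LocalAttention.

# Minimal sketch (assumption: the builder forwards local_context to the
# registered "local" attention).
from fast_transformers.builders import TransformerEncoderBuilder

builder = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=8,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=1024,
    attention_type="local",   # the name registered above
    local_context=32          # neighborhood considered by LocalAttention
)
model = builder.get()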

Functions

def local_dot_product(...)
def local_weighted_average(...)

Classes

class LocalAttention (local_context, softmax_temp=None, attention_dropout=0.1, event_dispatcher='')

Implement fast local attention where a query can only attend to neighboring keys.

In this attention module the query Q_i can only attend to a key K_j if |i-j| < local_context/2.
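
To make the neighborhood rule concrete, the short sketch below (plain PyTorch, independent of this module) computes the boolean matrix of allowed query/key pairs for a toy sequence; it is purely illustrative.

import torch

# Which keys each query may attend to under |i - j| < local_context / 2,
# for a sequence of length 8 and local_context = 5.
L = 8
local_context = 5
i = torch.arange(L).unsqueeze(1)   # query positions, shape (L, 1)
j = torch.arange(L).unsqueeze(0)   # key positions,   shape (1, L)
allowed = (i - j).abs() < local_context / 2   # (L, L) boolean band matrix
print(allowed.int())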

Arguments

local_context: The neighborhood to consider for local attention.
softmax_temp: The temperature to use for the softmax attention.
              (default: 1/sqrt(d_keys) where d_keys is computed at
              runtime)
attention_dropout: The dropout rate to apply to the attention
                   (default: 0.1)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.
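
A direct instantiation is straightforward; the values below are illustrative, not recommendations.

from fast_transformers.attention.local_attention import LocalAttention

attention = LocalAttention(
    local_context=32,        # size of the local neighborhood
    softmax_temp=None,       # defaults to 1/sqrt(E) at runtime
    attention_dropout=0.1    # dropout applied to the attention weights
)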

Expand source code
class LocalAttention(Module):
    """Implement fast local attention where a query can only attend to
    neighboring keys.

    In this attention module the query Q_i can only attend to a key K_j if
    |i-j| < local_context/2.

    Arguments
    ---------
        local_context: The neighborhood to consider for local attention.
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(LocalAttention, self).__init__()
        self.local_context = local_context
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the local attention.

        The attn_mask can be any mask; only the values that fall in the
        neighborhood of each query will be considered.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        context = self.local_context
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Permute the dimensions to NHLE instead of NLHE
        queries = queries.permute(0, 2, 1, 3).contiguous()
        keys = keys.permute(0, 2, 1, 3).contiguous()
        values = values.permute(0, 2, 1, 3).contiguous()

        QK = local_dot_product(
            queries,
            keys,
            attn_mask.additive_matrix_finite,
            key_lengths.lengths,
            self.local_context
        )
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))

        V_new = local_weighted_average(A, values)

        return V_new.permute(0, 2, 1, 3).contiguous()

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths)

Implements the local attention.

The attn_mask can be any mask; only the values that fall in the neighborhood of each query will be considered.

Arguments

queries: (N, L, H, E) The tensor containing the queries
keys: (N, S, H, E) The tensor containing the keys
values: (N, S, H, D) The tensor containing the values
attn_mask: An implementation of BaseMask that encodes where each
           query can attend to
query_lengths: An implementation of BaseMask that encodes how
               many queries each sequence in the batch consists of
key_lengths: An implementation of BaseMask that encodes how
             many keys each sequence in the batch consists of
Expand source code
def forward(self, queries, keys, values, attn_mask, query_lengths,
            key_lengths):
    """Implements the local attention.

    The attn_mask can be any mask; only the values that fall in the
    neighborhood of each query will be considered.

    Arguments
    ---------
        queries: (N, L, H, E) The tensor containing the queries
        keys: (N, S, H, E) The tensor containing the keys
        values: (N, S, H, D) The tensor containing the values
        attn_mask: An implementation of BaseMask that encodes where each
                   query can attend to
        query_lengths: An implementation of BaseMask that encodes how
                       many queries each sequence in the batch consists of
        key_lengths: An implementation of BaseMask that encodes how
                     many keys each sequence in the batch consists of
    """
    # Extract some shapes and compute the temperature
    N, L, H, E = queries.shape
    _, S, _, D = values.shape
    context = self.local_context
    softmax_temp = self.softmax_temp or 1./sqrt(E)

    # Permute the dimensions to NHLE instead of NLHE
    queries = queries.permute(0, 2, 1, 3).contiguous()
    keys = keys.permute(0, 2, 1, 3).contiguous()
    values = values.permute(0, 2, 1, 3).contiguous()

    QK = local_dot_product(
        queries,
        keys,
        attn_mask.additive_matrix_finite,
        key_lengths.lengths,
        self.local_context
    )
    A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))

    V_new = local_weighted_average(A, values)

    return V_new.permute(0, 2, 1, 3).contiguous()
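
A minimal end-to-end sketch of calling forward is shown below. It assumes the mask helpers FullMask and LengthMask from fast_transformers.masking and that the local_product extension used by local_dot_product is available; shapes follow the (N, L, H, E) convention documented above.

import torch
from fast_transformers.attention.local_attention import LocalAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E, D = 2, 16, 4, 32, 32   # batch, length, heads, key dim, value dim
attention = LocalAttention(local_context=8)

queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)
values = torch.randn(N, L, H, D)

# A full attention mask; LocalAttention only ever looks at the local
# neighborhood of each query anyway.
attn_mask = FullMask(L)
# All sequences in this toy batch use their full length.
lengths = LengthMask(torch.full((N,), L, dtype=torch.long), max_len=L)

out = attention(queries, keys, values, attn_mask, lengths, lengths)
print(out.shape)   # expected: (N, L, H, D) -> torch.Size([2, 16, 4, 32])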