Module fast_transformers.attention.conditional_full_attention

Implement a self-attention module that delegates to full attention or another attention implementation depending on the input sequence length.

Source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement a self attention that delegates to full attention or another
attention depending on the input sequence length."""

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
    EventDispatcherInstance
from ..events import EventDispatcher
from .full_attention import FullAttention


class ConditionalFullAttention(Module):
    """"Delegate to full attention if the input sequence is short.

    Arguments
    ---------
        other_attention: Use the passed attention module if the sequence is
                         longer than 'length_limit'.
        length_limit: An integer denoting the maximum sequence length to
                      consider.
        softmax_temp: See fast_transformers.attention.full_attention.
        attention_dropout: See fast_transformers.attention.full_attention.
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, other_attention, length_limit=512, softmax_temp=None,
                 attention_dropout=0.1, event_dispatcher=""):
        super(ConditionalFullAttention, self).__init__()
        self.full_attention = FullAttention(softmax_temp, attention_dropout)
        self.other_attention = other_attention
        self.length_limit = length_limit
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Extract some shapes to compare with the length limit
        L = queries.shape[1]
        S = values.shape[1]

        if L > self.length_limit or S > self.length_limit:
            return self.other_attention(queries, keys, values, attn_mask,
                                        query_lengths, key_lengths)
        else:
            return self.full_attention(queries, keys, values, attn_mask,
                                       query_lengths, key_lengths)


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "conditional-full", ConditionalFullAttention,
    [
        ("length_limit", Optional(Int, 512)),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
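
The snippet below is a minimal usage sketch, not part of the module's documentation: it constructs the module directly with a LinearAttention instance as the fallback for long sequences and runs one forward pass. The (batch, length, heads, head_dim) tensor layout, the FullMask/LengthMask helpers from fast_transformers.masking, and LinearAttention taking the query dimensionality as its first argument are assumptions based on the rest of the library.

import torch

from fast_transformers.attention.conditional_full_attention import \
    ConditionalFullAttention
from fast_transformers.attention.linear_attention import LinearAttention
from fast_transformers.masking import FullMask, LengthMask

# Sketch only: shapes follow the library's (N, L, H, E) convention,
# i.e. (batch, sequence length, heads, head dimension).
N, L, H, E = 2, 1024, 4, 64

attention = ConditionalFullAttention(
    LinearAttention(E),  # used whenever L or S exceeds length_limit
    length_limit=512
)

queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)
values = torch.randn(N, L, H, E)

attn_mask = FullMask(L)                                       # no masking
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))  # full lengths

# L = 1024 > 512, so this call is routed to the LinearAttention instance
out = attention(queries, keys, values, attn_mask, lengths, lengths)
print(out.shape)  # expected: torch.Size([2, 1024, 4, 64])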

Classes

class ConditionalFullAttention (other_attention, length_limit=512, softmax_temp=None, attention_dropout=0.1, event_dispatcher='')

"Delegate to full attention if the input sequence is short.

Arguments

other_attention: The attention module to delegate to when the sequence is
                 longer than 'length_limit'.
length_limit: An integer denoting the maximum sequence length for which
              full attention is used (default: 512).
softmax_temp: See fast_transformers.attention.full_attention.
attention_dropout: See fast_transformers.attention.full_attention.
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Source code
class ConditionalFullAttention(Module):
    """"Delegate to full attention if the input sequence is short.

    Arguments
    ---------
        other_attention: Use the passed attention module if the sequence is
                         longer than 'length_limit'.
        length_limit: An integer denoting the maximum sequence length to
                      consider.
        softmax_temp: See fast_transformers.attention.full_attention.
        attention_dropout: See fast_transformers.attention.full_attention.
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, other_attention, length_limit=512, softmax_temp=None,
                 attention_dropout=0.1, event_dispatcher=""):
        super(ConditionalFullAttention, self).__init__()
        self.full_attention = FullAttention(softmax_temp, attention_dropout)
        self.other_attention = other_attention
        self.length_limit = length_limit
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Extract some shapes to compare with the length limit
        L = queries.shape[1]
        S = values.shape[1]

        if L > self.length_limit or S > self.length_limit:
            return self.other_attention(queries, keys, values, attn_mask,
                                        query_lengths, key_lengths)
        else:
            return self.full_attention(queries, keys, values, attn_mask,
                                       query_lengths, key_lengths)

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code
def forward(self, queries, keys, values, attn_mask, query_lengths,
            key_lengths):
    # Extract some shapes to compare with the length limit
    L = queries.shape[1]
    S = values.shape[1]

    if L > self.length_limit or S > self.length_limit:
        return self.other_attention(queries, keys, values, attn_mask,
                                    query_lengths, key_lengths)
    else:
        return self.full_attention(queries, keys, values, attn_mask,
                                   query_lengths, key_lengths)