Module fast_transformers.builders.attention_builders

Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

from collections import defaultdict

from .base import BaseBuilder
from ..attention_registry import \
    AttentionRegistry, \
    RecurrentAttentionRegistry, \
    RecurrentCrossAttentionRegistry


class BaseAttentionBuilder(BaseBuilder):
    """Shared implementation of the attention builders.

    Attention parameters are assigned as plain attributes on the builder
    (``builder.name = value``). Every assignment is checked and validated
    against the registry and stored in an internal dictionary, which is later
    used by `get()` to construct the requested attention implementations.

    Arguments
    ---------
        registry: The registry object that is queried for the available
                  attention implementations and for their parameters
    """
    def __init__(self, registry):
        self._registry = registry
        # Indexing a parameter that was never assigned yields None
        self._parameters = defaultdict(lambda: None)

    @property
    def available_attentions(self):
        """Return a list with the available attention implementations."""
        return self._registry.keys

    def validate_attention_type(self, attention_type):
        """Parse the attention type according to the rules used by `get()` and
        check if the requested attention is constructible.

        Only the membership of every key in the registry is checked. A string
        whose leftmost composition contains more than one key (for instance
        'att1,att2') passes this check although `get()` rejects it.

        Arguments
        ---------
            attention_type: A string of registry keys separated by ':' and
                            ',' as described in `get()`
        """
        return all(
            all(t in self._registry for t in a.split(","))
            for a in attention_type.split(":")
        )

    def __setattr__(self, key, value):
        """Store `value` as the attention parameter `key`.

        Raises
        ------
            AttributeError: if the registry does not contain a parameter
                            named `key`
        """
        # Make sure we have normal behaviour for the class members _registry
        # and _parameters
        if key in ["_registry", "_parameters"]:
            return object.__setattr__(self, key, value)

        # Assign everything else in the parameters dictionary
        if not self._registry.contains_parameter(key):
            raise AttributeError(("{!r} is not a valid attention "
                                  "parameter name").format(key))
        self._parameters[key] = self._registry.validate_parameter(key, value)

    def __getattr__(self, key):
        """Return the stored value of the attention parameter `key`.

        Python calls this method only when the normal attribute lookup fails.

        Raises
        ------
            AttributeError: if no parameter named `key` has been stored
        """
        # Fetch _parameters without going through __getattr__ again. When the
        # instance has no _parameters yet (for example when it was created
        # without running __init__, as copy and pickle do) a plain
        # self._parameters would re-enter this method and recurse without end.
        try:
            parameters = object.__getattribute__(self, "_parameters")
        except AttributeError:
            parameters = {}

        if key in parameters:
            return parameters[key]
        raise AttributeError(("{!r} object has no attribute "
                              "{!r}").format(self.__class__.__name__, key))

    def __repr__(self):
        """Return a multi-line string that lists the stored parameters in the
        form ``ClassName.from_kwargs(name=value, ...)``."""
        return (
            "{}.from_kwargs(\n".format(self.__class__.__name__) +
            # The slice drops the comma that follows the last parameter
            "\n".join(["    {}={!r},".format(k, v)
                       for k, v in self._parameters.items()])[:-1] +
            "\n)"
        )

    def get(self, attention_type):
        """Construct the attention implementation object and return it.

        The passed in attention_type argument defines the attention to be
        created. It should be a string and in its simplest form it should
        be one of the available choices from `available_attentions`.

        However, to enable attention decoration, namely an attention
        implementation augmenting the functionality of another implementation,
        the attention type can be a colon separated list of compositions like
        the following examples:

            - 'att1' means instantiate att1
            - 'att2:att1' means instantiate att1 and decorate it with att2
            - 'att3:att1,att4' means instantiate att1 and att4 and decorate
              them with att3

        Arguments
        ---------
            attention_type: A string that contains one or more keys from
                            `available_attentions` separated with a colon to
                            denote the decoration pattern.

        Raises
        ------
            ValueError: if a key is not in the registry or if the leftmost
                        composition contains more than one key
        """
        # Process the compositions right to left so that the decorated
        # attentions exist before the attentions that decorate them
        compositions = reversed(attention_type.split(":"))
        attentions = []
        for c in compositions:
            attentions = [
                self._construct_attention(t, attentions)
                for t in c.split(",")
            ]
        if len(attentions) > 1:
            raise ValueError(("Invalid attention_type argument "
                              "{!r}").format(attention_type))
        return attentions[0]

    def _construct_attention(self, attention_type, decorated=()):
        """Construct an attention implementation object.

        Arguments
        ---------
            attention_type: A string that contains a single key from the
                            `available_attentions`
            decorated: A sequence of attention implementations to pass as
                       positional arguments to be decorated

        Raises
        ------
            ValueError: if `attention_type` is not in the registry
        """
        if attention_type not in self._registry:
            raise ValueError(("Unknown attention type "
                              "{!r}").format(attention_type))

        attention, parameters = self._registry[attention_type]
        # Parameters that were never assigned are passed to the validation as
        # None (indexing the defaultdict also stores that None)
        parameter_dictionary = {
            p: self._registry.validate_parameter(p, self._parameters[p])
            for p in parameters
        }

        return attention(*decorated, **parameter_dictionary)


class AttentionBuilder(BaseAttentionBuilder):
    """Builder of the attention implementations that are used for batch
    sequence processing or training.

    The builder is bound to `AttentionRegistry`."""
    def __init__(self):
        super().__init__(AttentionRegistry)


class RecurrentAttentionBuilder(BaseAttentionBuilder):
    """Builder of the attention implementations that are used for
    autoregressive sequence processing.

    The builder is bound to `RecurrentAttentionRegistry`."""
    def __init__(self):
        super().__init__(RecurrentAttentionRegistry)


class RecurrentCrossAttentionBuilder(BaseAttentionBuilder):
    """Builder of the attention implementations that are used for
    autoregressive cross attention computation.

    The builder is bound to `RecurrentCrossAttentionRegistry`."""
    def __init__(self):
        super().__init__(RecurrentCrossAttentionRegistry)

Classes

class AttentionBuilder

Build attention implementations for batch sequence processing or training.

Expand source code
class AttentionBuilder(BaseAttentionBuilder):
    """Builder of the attention implementations that are used for batch
    sequence processing or training.

    The builder is bound to `AttentionRegistry`."""
    def __init__(self):
        super().__init__(AttentionRegistry)

Ancestors

Inherited members

class BaseAttentionBuilder (registry)
Expand source code
class BaseAttentionBuilder(BaseBuilder):
    """Shared implementation of the attention builders.

    Attention parameters are assigned as plain attributes on the builder
    (``builder.name = value``). Every assignment is checked and validated
    against the registry and stored in an internal dictionary, which is later
    used by `get()` to construct the requested attention implementations.

    Arguments
    ---------
        registry: The registry object that is queried for the available
                  attention implementations and for their parameters
    """
    def __init__(self, registry):
        self._registry = registry
        # Indexing a parameter that was never assigned yields None
        self._parameters = defaultdict(lambda: None)

    @property
    def available_attentions(self):
        """Return a list with the available attention implementations."""
        return self._registry.keys

    def validate_attention_type(self, attention_type):
        """Parse the attention type according to the rules used by `get()` and
        check if the requested attention is constructible.

        Only the membership of every key in the registry is checked. A string
        whose leftmost composition contains more than one key (for instance
        'att1,att2') passes this check although `get()` rejects it.

        Arguments
        ---------
            attention_type: A string of registry keys separated by ':' and
                            ',' as described in `get()`
        """
        return all(
            all(t in self._registry for t in a.split(","))
            for a in attention_type.split(":")
        )

    def __setattr__(self, key, value):
        """Store `value` as the attention parameter `key`.

        Raises
        ------
            AttributeError: if the registry does not contain a parameter
                            named `key`
        """
        # Make sure we have normal behaviour for the class members _registry
        # and _parameters
        if key in ["_registry", "_parameters"]:
            return object.__setattr__(self, key, value)

        # Assign everything else in the parameters dictionary
        if not self._registry.contains_parameter(key):
            raise AttributeError(("{!r} is not a valid attention "
                                  "parameter name").format(key))
        self._parameters[key] = self._registry.validate_parameter(key, value)

    def __getattr__(self, key):
        """Return the stored value of the attention parameter `key`.

        Python calls this method only when the normal attribute lookup fails.

        Raises
        ------
            AttributeError: if no parameter named `key` has been stored
        """
        # Fetch _parameters without going through __getattr__ again. When the
        # instance has no _parameters yet (for example when it was created
        # without running __init__, as copy and pickle do) a plain
        # self._parameters would re-enter this method and recurse without end.
        try:
            parameters = object.__getattribute__(self, "_parameters")
        except AttributeError:
            parameters = {}

        if key in parameters:
            return parameters[key]
        raise AttributeError(("{!r} object has no attribute "
                              "{!r}").format(self.__class__.__name__, key))

    def __repr__(self):
        """Return a multi-line string that lists the stored parameters in the
        form ``ClassName.from_kwargs(name=value, ...)``."""
        return (
            "{}.from_kwargs(\n".format(self.__class__.__name__) +
            # The slice drops the comma that follows the last parameter
            "\n".join(["    {}={!r},".format(k, v)
                       for k, v in self._parameters.items()])[:-1] +
            "\n)"
        )

    def get(self, attention_type):
        """Construct the attention implementation object and return it.

        The passed in attention_type argument defines the attention to be
        created. It should be a string and in its simplest form it should
        be one of the available choices from `available_attentions`.

        However, to enable attention decoration, namely an attention
        implementation augmenting the functionality of another implementation,
        the attention type can be a colon separated list of compositions like
        the following examples:

            - 'att1' means instantiate att1
            - 'att2:att1' means instantiate att1 and decorate it with att2
            - 'att3:att1,att4' means instantiate att1 and att4 and decorate
              them with att3

        Arguments
        ---------
            attention_type: A string that contains one or more keys from
                            `available_attentions` separated with a colon to
                            denote the decoration pattern.

        Raises
        ------
            ValueError: if a key is not in the registry or if the leftmost
                        composition contains more than one key
        """
        # Process the compositions right to left so that the decorated
        # attentions exist before the attentions that decorate them
        compositions = reversed(attention_type.split(":"))
        attentions = []
        for c in compositions:
            attentions = [
                self._construct_attention(t, attentions)
                for t in c.split(",")
            ]
        if len(attentions) > 1:
            raise ValueError(("Invalid attention_type argument "
                              "{!r}").format(attention_type))
        return attentions[0]

    def _construct_attention(self, attention_type, decorated=()):
        """Construct an attention implementation object.

        Arguments
        ---------
            attention_type: A string that contains a single key from the
                            `available_attentions`
            decorated: A sequence of attention implementations to pass as
                       positional arguments to be decorated

        Raises
        ------
            ValueError: if `attention_type` is not in the registry
        """
        if attention_type not in self._registry:
            raise ValueError(("Unknown attention type "
                              "{!r}").format(attention_type))

        attention, parameters = self._registry[attention_type]
        # Parameters that were never assigned are passed to the validation as
        # None (indexing the defaultdict also stores that None)
        parameter_dictionary = {
            p: self._registry.validate_parameter(p, self._parameters[p])
            for p in parameters
        }

        return attention(*decorated, **parameter_dictionary)

Ancestors

Subclasses

Instance variables

var available_attentions

Return a list with the available attention implementations.

Expand source code
@property
def available_attentions(self):
    """Expose the keys of the registry, namely the attention implementations
    that this builder can construct."""
    registry = self._registry
    return registry.keys

Methods

def get(self, attention_type)

Construct the attention implementation object and return it.

The passed in attention_type argument defines the attention to be created. It should be a string and in its simplest form it should be one of the available choices from available_attentions.

However, to enable attention decoration, namely an attention implementation augmenting the functionality of another implementation, the attention type can be a colon separated list of compositions like the following examples:

- 'att1' means instantiate att1
- 'att2:att1' means instantiate att1 and decorate it with att2
- 'att3:att1,att4' means instantiate att1 and att4 and decorate
  them with att3

Arguments

attention_type: A string that contains one or more keys from
                `available_attentions` separated with a colon to
                denote the decoration pattern.
Expand source code
def get(self, attention_type):
    """Build the attention implementation described by `attention_type`.

    In the simplest case `attention_type` is a single key out of
    `available_attentions`. To allow an attention implementation to decorate
    (augment) other implementations, several compositions can be chained with
    colons, each one decorating everything on its right:

        - 'att1' means instantiate att1
        - 'att2:att1' means instantiate att1 and decorate it with att2
        - 'att3:att1,att4' means instantiate att1 and att4 and decorate
          them with att3

    Arguments
    ---------
        attention_type: A string that contains one or more keys from
                        `available_attentions` separated with a colon to
                        denote the decoration pattern.

    Raises
    ------
        ValueError: if the leftmost composition produces more than one
                    attention implementation
    """
    attentions = []
    # Walk the compositions from right to left so that the decorated
    # attentions are available when their decorators are constructed
    for group in attention_type.split(":")[::-1]:
        built = []
        for name in group.split(","):
            built.append(self._construct_attention(name, attentions))
        attentions = built

    if len(attentions) <= 1:
        return attentions[0]
    raise ValueError(("Invalid attention_type argument "
                      "{!r}").format(attention_type))
def validate_attention_type(self, attention_type)

Parse the attention type according to the rules used by get() and check if the requested attention is constructible.

Expand source code
def validate_attention_type(self, attention_type):
    """Tell whether every key in `attention_type` exists in the registry.

    The string is split with the same separators that `get()` uses, ':'
    between compositions and ',' between the keys of one composition, and the
    result is False as soon as one key is not found in the registry."""
    for composition in attention_type.split(":"):
        for name in composition.split(","):
            if name not in self._registry:
                return False
    return True

Inherited members

class RecurrentAttentionBuilder

Build attention implementations for autoregressive sequence processing.

Expand source code
class RecurrentAttentionBuilder(BaseAttentionBuilder):
    """Builder of the attention implementations that are used for
    autoregressive sequence processing.

    The builder is bound to `RecurrentAttentionRegistry`."""
    def __init__(self):
        super().__init__(RecurrentAttentionRegistry)

Ancestors

Inherited members

class RecurrentCrossAttentionBuilder

Build attention implementations for autoregressive cross attention computation.

Expand source code
class RecurrentCrossAttentionBuilder(BaseAttentionBuilder):
    """Builder of the attention implementations that are used for
    autoregressive cross attention computation.

    The builder is bound to `RecurrentCrossAttentionRegistry`."""
    def __init__(self):
        super().__init__(RecurrentCrossAttentionRegistry)

Ancestors

Inherited members