Module fast_transformers.feature_maps.base

Create the feature map interface and some commonly used feature maps.

All attention implementations that expect a feature map shall receive a factory function that returns a feature map instance when called with the query dimensions.

Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Create the feature map interface and some commonly used feature maps.

All attention implementations that expect a feature map shall receive a factory
function that returns a feature map instance when called with the query
dimensions.
"""

from functools import partial

import torch
from torch.nn import Module


class FeatureMap(Module):
    """Interface shared by every feature map.

    A feature map is constructed for a given number of query dimensions and
    encodes queries and keys. Symmetric maps only implement `forward`;
    asymmetric maps also override `forward_queries` and `forward_keys`.
    """
    def __init__(self, query_dims):
        super(FeatureMap, self).__init__()
        # Query dimensionality this feature map was built for.
        self.query_dims = query_dims

    def new_feature_map(self):
        """Create a new instance of this feature map; random feature maps
        should sample new parameters here. Subclasses must override this."""
        raise NotImplementedError

    def forward_queries(self, x):
        """Encode the queries `x`. By default this calls the module on `x`,
        i.e. dispatches to `forward`."""
        encoded = self(x)
        return encoded

    def forward_keys(self, x):
        """Encode the keys `x`. By default this calls the module on `x`,
        i.e. dispatches to `forward`."""
        encoded = self(x)
        return encoded

    def forward(self, x):
        """Encode `x` with this feature map.

        Implementing this method alone is enough for a symmetric feature map;
        an asymmetric one must implement `forward_queries` and `forward_keys`.
        """
        raise NotImplementedError

    @classmethod
    def factory(cls, *args, **kwargs):
        """Bind constructor arguments and return a one-argument builder.

        The returned callable takes the query dimensions and returns
        `cls(query_dims, *args, **kwargs)`. Being a classmethod, it is
        inherited by, and builds instances of, every subclass.
        """
        def inner(query_dims):
            feature_map = cls(query_dims, *args, **kwargs)
            return feature_map
        return inner


class ActivationFunctionFeatureMap(FeatureMap):
    """Symmetric feature map that applies an element-wise activation.

    Only `forward` is overridden, so queries and keys are encoded by the same
    callable.
    """
    def __init__(self, query_dims, activation_function):
        super(ActivationFunctionFeatureMap, self).__init__(query_dims)
        # Callable applied to the input in `forward` (e.g. elu(x) + 1).
        self.activation_function = activation_function

    def new_feature_map(self):
        # Nothing is sampled for this map, so there is nothing to regenerate.
        return None

    def forward(self, x):
        activation = self.activation_function
        return activation(x)


def _elu_plus_one(x):
    """Return ``elu(x) + 1`` element-wise.

    ``elu`` with its default ``alpha=1`` is bounded below by -1, so the result
    is non-negative. This is a named module-level function rather than a
    lambda because feature map instances store it as an attribute, and
    lambdas cannot be pickled (which would break ``pickle``/``torch.save`` of
    any model holding such a feature map).
    """
    return torch.nn.functional.elu(x) + 1


# ``elu_feature_map(query_dims)`` returns an ``ActivationFunctionFeatureMap``
# that encodes both queries and keys with ``elu(x) + 1``.
elu_feature_map = ActivationFunctionFeatureMap.factory(_elu_plus_one)

Functions

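def elu_feature_map(query_dims)

Return an ActivationFunctionFeatureMap for query_dims that applies elu(x) + 1 element-wise. It is produced by ActivationFunctionFeatureMap.factory, which is why the source shown below is the generic inner factory closure.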
Expand source code
def inner(query_dims):
    """Instantiate the captured class for `query_dims`, forwarding the
    positional and keyword arguments bound by `factory`."""
    feature_map = cls(query_dims, *args, **kwargs)
    return feature_map

Classes

class ActivationFunctionFeatureMap (query_dims, activation_function)

Define a feature map that is simply an element-wise activation function.

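Store query_dims and the element-wise activation_function. (Through FeatureMap.__init__ this also runs the torch.nn.Module initializer, which initializes internal Module state, shared by both nn.Module and ScriptModule.)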

Expand source code
class ActivationFunctionFeatureMap(FeatureMap):
    """Symmetric feature map that applies an element-wise activation.

    Only `forward` is overridden, so queries and keys are encoded by the same
    callable.
    """
    def __init__(self, query_dims, activation_function):
        super(ActivationFunctionFeatureMap, self).__init__(query_dims)
        # Callable applied to the input in `forward` (e.g. elu(x) + 1).
        self.activation_function = activation_function

    def new_feature_map(self):
        # Nothing is sampled for this map, so there is nothing to regenerate.
        return None

    def forward(self, x):
        activation = self.activation_function
        return activation(x)

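Ancestors

  • FeatureMap
  • torch.nn.modules.module.Module

Inherited members

  • FeatureMap:
      ◦ factory
      ◦ forward_keys
      ◦ forward_queries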

class FeatureMap (query_dims)

Define the FeatureMap interface.

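Store query_dims on the new instance. (The torch.nn.Module initializer, called via super().__init__(), initializes internal Module state, shared by both nn.Module and ScriptModule.)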

Expand source code
class FeatureMap(Module):
    """Interface shared by every feature map.

    A feature map is constructed for a given number of query dimensions and
    encodes queries and keys. Symmetric maps only implement `forward`;
    asymmetric maps also override `forward_queries` and `forward_keys`.
    """
    def __init__(self, query_dims):
        super(FeatureMap, self).__init__()
        # Query dimensionality this feature map was built for.
        self.query_dims = query_dims

    def new_feature_map(self):
        """Create a new instance of this feature map; random feature maps
        should sample new parameters here. Subclasses must override this."""
        raise NotImplementedError

    def forward_queries(self, x):
        """Encode the queries `x`. By default this calls the module on `x`,
        i.e. dispatches to `forward`."""
        encoded = self(x)
        return encoded

    def forward_keys(self, x):
        """Encode the keys `x`. By default this calls the module on `x`,
        i.e. dispatches to `forward`."""
        encoded = self(x)
        return encoded

    def forward(self, x):
        """Encode `x` with this feature map.

        Implementing this method alone is enough for a symmetric feature map;
        an asymmetric one must implement `forward_queries` and `forward_keys`.
        """
        raise NotImplementedError

    @classmethod
    def factory(cls, *args, **kwargs):
        """Bind constructor arguments and return a one-argument builder.

        The returned callable takes the query dimensions and returns
        `cls(query_dims, *args, **kwargs)`. Being a classmethod, it is
        inherited by, and builds instances of, every subclass.
        """
        def inner(query_dims):
            feature_map = cls(query_dims, *args, **kwargs)
            return feature_map
        return inner

Ancestors

  • torch.nn.modules.module.Module

Subclasses

  • ActivationFunctionFeatureMap

Static methods

def factory(*args, **kwargs)

Return a function that when called with the query dimensions returns an instance of this feature map.

It is inherited by the subclasses so it is available in all feature maps.

Expand source code
@classmethod
def factory(cls, *args, **kwargs):
    """Bind constructor arguments and return a one-argument builder.

    The returned callable takes the query dimensions and returns
    `cls(query_dims, *args, **kwargs)`. Being a classmethod, it is
    inherited by, and builds instances of, every subclass.
    """
    def inner(query_dims):
        feature_map = cls(query_dims, *args, **kwargs)
        return feature_map
    return inner

Methods

def forward(self, x)

Encode x using this feature map. For symmetric feature maps it suffices to define this function, but for asymmetric feature maps one needs to define the forward_queries and forward_keys functions.

Expand source code
def forward(self, x):
    """Encode `x` with this feature map.

    Implementing this method alone is enough for a symmetric feature map;
    an asymmetric one must implement `forward_queries` and `forward_keys`.
    """
    raise NotImplementedError
def forward_keys(self, x)

Encode the keys x using this feature map.

Expand source code
def forward_keys(self, x):
    """Encode the keys `x`. By default this calls the module on `x`,
    i.e. dispatches to `forward`."""
    encoded = self(x)
    return encoded
def forward_queries(self, x)

Encode the queries x using this feature map.

Expand source code
def forward_queries(self, x):
    """Encode the queries `x`. By default this calls the module on `x`,
    i.e. dispatches to `forward`."""
    encoded = self(x)
    return encoded
def new_feature_map(self)

Create a new instance of this feature map. In particular, if it is a random feature map, sample new parameters.

Expand source code
def new_feature_map(self):
    """Create a new instance of this feature map; random feature maps
    should sample new parameters here. Subclasses must override this."""
    raise NotImplementedError