Feature Maps

The LinearAttention and CausalLinearAttention modules, as well as their corresponding recurrent modules, accept a feature_map argument, which is the kernel feature map used by that attention implementation. The default feature_map is a simple activation function, as used in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention".

However, the API allows for significantly more complicated feature maps that contain trainable weights or are asymmetric.

FeatureMap API

All feature maps must implement the following interface.

from torch.nn import Module


class FeatureMap(Module):
    def __init__(self, query_dimensions):
        ...

    def new_feature_map(self):
        """Create a new instance of this feature map. In particular, if it is a
        random feature map sample new parameters."""
        ...

    def forward_queries(self, x):
        """Encode the queries `x` using this feature map."""
        ...

    def forward_keys(self, x):
        """Encode the keys `x` using this feature map."""
        ...

    def forward(self, x):
        """Encode `x` using this feature map. For symmetric feature maps it
        suffices to define only this function."""
        ...

In particular, all feature maps accept the query dimensions as the first constructor parameter. After calling new_feature_map(), all calls to the forward variants should be compatible with each other; in other words, all randomness should happen in the new_feature_map method. Symmetric feature maps need only implement forward.
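
To make the interface concrete, here is a minimal sketch of a symmetric, parameter-free feature map that computes elu(x) + 1, mirroring the default feature map described below. It assumes FeatureMap can be imported from fast_transformers.feature_maps; the class name ShiftedELUFeatureMap is hypothetical and exists only to illustrate the interface.

import torch
from fast_transformers.feature_maps import FeatureMap  # assumed import path


class ShiftedELUFeatureMap(FeatureMap):
    """Hypothetical symmetric feature map computing phi(x) = elu(x) + 1."""

    def __init__(self, query_dimensions):
        super().__init__(query_dimensions)

    def new_feature_map(self):
        # Nothing to sample: this feature map has no random parameters.
        pass

    def forward(self, x):
        # The map is symmetric, so defining forward() alone suffices;
        # queries and keys are encoded identically.
        return torch.nn.functional.elu(x) + 1

Since the constructor takes only the query dimensions, such a class can be passed to a builder directly or wrapped with partial()/factory() as shown in the next section.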

Using feature maps

All modules that accept feature maps expect a factory function, namely a function that, given the query dimensions, returns a new feature map instance.

A simple way to achieve that is to use the partial() function from the built-in functools module, or the utility class method factory(), which does essentially the same thing.

from functools import partial

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.feature_maps import Favor

# Build a linear-attention transformer whose feature map is created via
# the factory() helper of Favor.
transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    feature_map=Favor.factory(n_dims=256)
).get()

# Equivalent construction using functools.partial as the factory.
transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    feature_map=partial(Favor, n_dims=256)
).get()

If you do not want to pass any parameters to the feature map, then it suffices to use the class object directly, as in the sketch below.
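
For instance, assuming the Favor defaults are acceptable, the builder can be given the class itself, which it will call with the query dimensions. This is a minimal sketch of that usage.

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.feature_maps import Favor

# Favor used with its default arguments: the class itself acts as the
# factory, since the builder calls it with the query dimensions.
transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    feature_map=Favor
).get()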

Available feature maps

  • ActivationFunctionFeatureMap uses a simple elementwise activation function as a feature map.
  • elu_feature_map is a specialization of the above where the activation function is elu(x)+1. It is also the default feature map (see the sketch after this list).
  • RandomFourierFeatures approximates the RBF kernel using random Fourier features with trigonometric functions.
  • SmoothedRandomFourierFeatures approximates the RBF kernel plus a constant for numerical stability.
  • Favor implements the positive random features designed specifically for transformers in the paper "Rethinking Attention with Performers". It should be preferred over RandomFourierFeatures.
  • GeneralizedRandomFeatures is a simplification of Favor which does not approximate softmax but can increase the rank of the resulting attention matrix.
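
As a sketch of how the first two items relate, and assuming the factory() helper forwards extra constructor arguments (just as partial() would), the default map can be reconstructed roughly as follows; elu_feature_map is the library-provided version of this.

import torch
from fast_transformers.feature_maps import ActivationFunctionFeatureMap

# A hand-rolled equivalent of the default feature map: an elementwise
# activation feature map with phi(x) = elu(x) + 1.
my_elu_feature_map = ActivationFunctionFeatureMap.factory(
    lambda x: torch.nn.functional.elu(x) + 1
)

The resulting factory can then be passed as the feature_map argument of a builder, just like Favor.factory(...) above.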