Feature Maps
The LinearAttention and CausalLinearAttention modules, as well as their corresponding recurrent modules, accept a feature_map argument which is the kernel feature map for each attention implementation. The default feature_map is a simple activation function, as used in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention".
However, the API allows for significantly more complicated feature maps that contain trainable weights or are asymmetric.
FeatureMap API
All feature maps must implement the following interface.
```python
class FeatureMap(Module):
    def __init__(self, query_dimensions):
        ...

    def new_feature_map(self):
        """Create a new instance of this feature map. In particular, if it is
        a random feature map sample new parameters."""
        ...

    def forward_queries(self, x):
        """Encode the queries `x` using this feature map."""
        ...

    def forward_keys(self, x):
        """Encode the keys `x` using this feature map."""
        ...

    def forward(self, x):
        # For symmetric feature maps it suffices to define this function
        ...
```
In particular, all feature maps accept the query dimensions as the first constructor parameter. After calling new_feature_map(), all calls to the forward variants should be compatible with each other; namely, all randomness should happen in the new_feature_map method. Symmetric feature maps, which encode queries and keys in the same way, need only implement forward.
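For illustration, here is a minimal sketch of a custom symmetric feature map that follows this interface. The class is hypothetical (it is not part of the library) and simply applies an elementwise ReLU:

```python
import torch
from torch.nn import Module


class ReLUFeatureMap(Module):
    """Hypothetical deterministic, symmetric feature map (illustration only)."""

    def __init__(self, query_dimensions):
        super().__init__()
        self.query_dimensions = query_dimensions

    def new_feature_map(self):
        # Nothing to resample: this feature map has no random parameters.
        pass

    def forward(self, x):
        # Elementwise non-negative map applied to the features.
        return torch.relu(x) + 1e-6

    # The map is symmetric, so queries and keys are encoded identically.
    def forward_queries(self, x):
        return self(x)

    def forward_keys(self, x):
        return self(x)
```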
Using feature maps
All modules that accept feature maps expect a factory function, namely a function that, given the query dimensions, returns a new feature map instance.
A simple way to achieve that is to use partial() from the built-in functools module, or the utility class method factory(), which is essentially the same.
```python
from functools import partial

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.feature_maps import Favor

transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    feature_map=Favor.factory(n_dims=256)
).get()

transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    feature_map=partial(Favor, n_dims=256)
).get()
```
If you do not want to pass any parameters to the feature map, then it suffices to use the class object directly.
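For instance, continuing the example above, the class itself can serve as the factory (a sketch assuming Favor's default constructor arguments are acceptable):

```python
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.feature_maps import Favor

# Pass the class directly; the builder will instantiate it with the
# query dimensions and its default keyword arguments.
transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    feature_map=Favor
).get()
```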
Available feature maps
- ActivationFunctionFeatureMap uses a simple elementwise activation function as a feature map.
- elu_feature_map is a specialization of the above where the activation function is elu(x) + 1. It is also the default feature map.
- RandomFourierFeatures approximates the RBF kernel using random Fourier features with trigonometric functions.
- SmoothedRandomFourierFeatures approximates the RBF kernel plus a constant for numerical stability.
- Favor implements the positive random features designed specifically for transformers in the paper "Rethinking Attention with Performers". It should be preferred over RandomFourierFeatures (a usage sketch follows this list).
- GeneralizedRandomFeatures is a simplification of Favor which does not approximate softmax but can increase the rank of the resulting attention matrix.
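As a rough sketch of using one of these feature maps on its own, outside a builder, the snippet below instantiates Favor directly. The positional query-dimensions argument follows the interface described above, while the tensor shapes are illustrative assumptions rather than documented requirements:

```python
import torch
from fast_transformers.feature_maps import Favor

# Assumed constructor call, based on the interface above and the
# Favor.factory(n_dims=256) example: query dimensions first, then options.
feature_map = Favor(32, n_dims=256)

# Sample the random projection; all randomness happens here, so queries
# and keys encoded afterwards are consistent with each other.
feature_map.new_feature_map()

# Illustrative tensors whose last dimension is the query dimension.
queries = torch.randn(2, 100, 8, 32)
keys = torch.randn(2, 100, 8, 32)

Q = feature_map.forward_queries(queries)
K = feature_map.forward_keys(keys)
```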