Module fast_transformers.feature_maps.base
Create the feature map interface and some commonly used feature maps.
All attention implementations that expect a feature map shall receive a factory function that returns a feature map instance when called with the query dimensions.
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#
"""Create the feature map interface and some commonly used feature maps.
All attention implementations that expect a feature map shall receive a factory
function that returns a feature map instance when called with the query
dimensions.
"""
from functools import partial
import torch
from torch.nn import Module
class FeatureMap(Module):
    """Define the FeatureMap interface.

    A feature map is a module that encodes queries and keys. Symmetric
    feature maps only need to implement `forward`; asymmetric ones override
    `forward_queries` and `forward_keys` instead.

    Parameters:
        query_dims: The query dimensions; stored unchanged on
            ``self.query_dims``.
    """

    def __init__(self, query_dims):
        super().__init__()
        self.query_dims = query_dims

    def new_feature_map(self):
        """Create a new instance of this feature map. In particular, if it
        is a random feature map sample new parameters.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    def forward_queries(self, x):
        """Encode the queries `x` using this feature map.

        Delegates to calling the module itself, which dispatches to
        `forward`.
        """
        return self(x)

    def forward_keys(self, x):
        """Encode the keys `x` using this feature map.

        Delegates to calling the module itself, which dispatches to
        `forward`.
        """
        return self(x)

    def forward(self, x):
        """Encode x using this feature map. For symmetric feature maps it
        suffices to define this function, but for asymmetric feature maps one
        needs to define the `forward_queries` and `forward_keys` functions.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    @classmethod
    def factory(cls, *args, **kwargs):
        """Return a function that when called with the query dimensions
        returns an instance of this feature map.

        It is inherited by the subclasses so it is available in all feature
        maps.

        Parameters:
            *args: Extra positional arguments passed to the constructor
                after ``query_dims``.
            **kwargs: Extra keyword arguments passed to the constructor.
        """
        def inner(query_dims):
            # `cls` is the class the factory was invoked on, so subclasses
            # build instances of themselves.
            return cls(query_dims, *args, **kwargs)
        return inner
class ActivationFunctionFeatureMap(FeatureMap):
    """Define a feature map that is simply an element-wise activation
    function.

    Parameters:
        query_dims: The query dimensions, forwarded to `FeatureMap`.
        activation_function: Callable applied to the input in `forward`;
            stored on ``self.activation_function``.
    """

    def __init__(self, query_dims, activation_function):
        super().__init__(query_dims)
        self.activation_function = activation_function

    def new_feature_map(self):
        """Do nothing and return ``None``.

        This feature map only wraps the given activation function, so this
        method has no work to do.
        """
        return

    def forward(self, x):
        """Return ``self.activation_function(x)``."""
        return self.activation_function(x)
def _elu_plus_one(x):
    """Return ``torch.nn.functional.elu(x) + 1``."""
    return torch.nn.functional.elu(x) + 1


# Factory that, when called with the query dimensions, builds an
# ActivationFunctionFeatureMap whose activation is ``elu(x) + 1``.
elu_feature_map = ActivationFunctionFeatureMap.factory(_elu_plus_one)
Functions
def elu_feature_map(query_dims)
-
Expand source code
def inner(query_dims):
    """Build the feature map for `query_dims` using the captured class and
    the captured extra constructor arguments."""
    feature_map = cls(query_dims, *args, **kwargs)
    return feature_map
Classes
class ActivationFunctionFeatureMap (query_dims, activation_function)
-
Define a feature map that is simply an element-wise activation function.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class ActivationFunctionFeatureMap(FeatureMap):
    """Define a feature map that is simply an element-wise activation
    function.

    Parameters:
        query_dims: The query dimensions, forwarded to `FeatureMap`.
        activation_function: Callable applied to the input in `forward`;
            stored on ``self.activation_function``.
    """

    def __init__(self, query_dims, activation_function):
        super().__init__(query_dims)
        self.activation_function = activation_function

    def new_feature_map(self):
        """Do nothing and return ``None``."""
        return

    def forward(self, x):
        """Return ``self.activation_function(x)``."""
        return self.activation_function(x)
Ancestors
- FeatureMap
- torch.nn.modules.module.Module
Inherited members
class FeatureMap (query_dims)
-
Define the FeatureMap interface.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class FeatureMap(Module):
    """Define the FeatureMap interface.

    Parameters:
        query_dims: The query dimensions; stored unchanged on
            ``self.query_dims``.
    """

    def __init__(self, query_dims):
        super().__init__()
        self.query_dims = query_dims

    def new_feature_map(self):
        """Create a new instance of this feature map. In particular, if it
        is a random feature map sample new parameters.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    def forward_queries(self, x):
        """Encode the queries `x` using this feature map."""
        return self(x)

    def forward_keys(self, x):
        """Encode the keys `x` using this feature map."""
        return self(x)

    def forward(self, x):
        """Encode x using this feature map. For symmetric feature maps it
        suffices to define this function, but for asymmetric feature maps one
        needs to define the `forward_queries` and `forward_keys` functions.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    @classmethod
    def factory(cls, *args, **kwargs):
        """Return a function that when called with the query dimensions
        returns an instance of this feature map.

        It is inherited by the subclasses so it is available in all feature
        maps.
        """
        def inner(query_dims):
            return cls(query_dims, *args, **kwargs)
        return inner
Ancestors
- torch.nn.modules.module.Module
Subclasses
Static methods
def factory(*args, **kwargs)
-
Return a function that when called with the query dimensions returns an instance of this feature map.
It is inherited by the subclasses so it is available in all feature maps.
Expand source code
@classmethod
def factory(cls, *args, **kwargs):
    """Return a function that when called with the query dimensions returns
    an instance of this feature map.

    It is inherited by the subclasses so it is available in all feature
    maps.

    Parameters:
        *args: Extra positional arguments passed to the constructor after
            ``query_dims``.
        **kwargs: Extra keyword arguments passed to the constructor.
    """
    def inner(query_dims):
        return cls(query_dims, *args, **kwargs)
    return inner
Methods
def forward(self, x)
-
Encode x using this feature map. For symmetric feature maps it suffices to define this function, but for asymmetric feature maps one needs to define the
`forward_queries`
and `forward_keys`
functions.
Expand source code
def forward(self, x):
    """Encode x using this feature map. For symmetric feature maps it
    suffices to define this function, but for asymmetric feature maps one
    needs to define the `forward_queries` and `forward_keys` functions.

    Raises:
        NotImplementedError: Always; subclasses must override this.
    """
    raise NotImplementedError()
def forward_keys(self, x)
-
Encode the keys `x` using this feature map.
Expand source code
def forward_keys(self, x):
    """Encode the keys `x` using this feature map.

    Delegates to calling the feature map itself with `x`.
    """
    return self(x)
def forward_queries(self, x)
-
Encode the queries `x` using this feature map.
Expand source code
def forward_queries(self, x):
    """Encode the queries `x` using this feature map.

    Delegates to calling the feature map itself with `x`.
    """
    return self(x)
def new_feature_map(self)
-
Create a new instance of this feature map. In particular, if it is a random feature map sample new parameters.
Expand source code
def new_feature_map(self):
    """Create a new instance of this feature map. In particular, if it is a
    random feature map sample new parameters.

    Raises:
        NotImplementedError: Always; subclasses must override this.
    """
    raise NotImplementedError()