Creating a custom attention layer
On this page, we will go through the process of creating a custom attention module and integrating it with the library. As a concrete example, we will implement quadratic kernel attention in place of the standard softmax attention.
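Concretely, where softmax attention normalizes exponentiated query–key dot products, the layer built below weights each value by a normalized squared dot product. Ignoring masking for the moment, with the temperature quadratic_temp written as τ and a small eps written as ε for numerical stability, the computation implemented below amounts to

$$
A_{ij} = \frac{\left(\tau\, q_i^\top k_j\right)^2}{\sum_{s} \left(\tau\, q_i^\top k_s\right)^2 + \epsilon},
\qquad
\hat{v}_i = \sum_{j} A_{ij}\, v_j .
$$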
New Attention
Our attention layer will closely follow the implementation of FullAttention. Let's start with the skeleton of our module.
from torch.nn import Module


class QuadraticAttention(Module):
    def __init__(self, quadratic_temp=1.0, eps=1e-6):
        super(QuadraticAttention, self).__init__()
        self.eps = eps
        self.quadratic_temp = quadratic_temp

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # implement the logic of the layer here
        pass
The queries, keys and values are already projected and split into multiple heads by the AttentionLayer. This means that we need only implement the attention part.
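For reference, this is roughly the wrapping the builders perform around any attention module; the sketch below assumes the AttentionLayer class from fast_transformers.attention and its (attention, d_model, n_heads) constructor arguments.

from fast_transformers.attention import AttentionLayer

# The AttentionLayer takes full sequences of shape (batch, length, d_model),
# projects them and splits them into heads, and only then calls our module,
# which therefore receives tensors of shape (batch, length, heads, head_dim).
wrapped = AttentionLayer(QuadraticAttention(), d_model=768, n_heads=12)

With that in mind, the complete module is implemented as follows.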
import torch
from torch.nn import Module


class QuadraticAttention(Module):
    def __init__(self, quadratic_temp=1.0, eps=1e-6):
        super(QuadraticAttention, self).__init__()
        self.eps = eps
        self.quadratic_temp = quadratic_temp

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # compute the unnormalized attention
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)  # compute the dot products
        QK = torch.square(self.quadratic_temp * QK)  # implement our custom attention twist
        QK = QK * attn_mask.float_matrix  # use the attention mask as a multiplicative mask
        QK = QK * key_lengths.float_matrix[:, None, None]  # also a multiplicative mask

        # normalize and compute the weighted average of the values
        A = QK / (QK.sum(dim=-1, keepdim=True) + self.eps)
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        return V.contiguous()
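To sanity-check the module before registering it with the builders, it can be called directly. The snippet below is a minimal sketch, assuming the FullMask and LengthMask helpers from fast_transformers.masking and random inputs with the shapes expected by the einsum calls above.

import torch
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E = 2, 10, 4, 64  # batch, sequence length, heads, dimension per head

queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)
values = torch.randn(N, L, H, E)

attn_mask = FullMask(L)  # every position may attend to every other position
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))  # no padding

attention = QuadraticAttention(quadratic_temp=2.0)
out = attention(queries, keys, values, attn_mask, lengths, lengths)
print(out.shape)  # torch.Size([2, 10, 4, 64])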
Integrate with the Builder
To add it as an option to the TransformerEncoderBuilder or the TransformerDecoderBuilder, we have to register our new attention in the appropriate attention registry. The available registries are:

- AttentionRegistry
- RecurrentAttentionRegistry
- RecurrentCrossAttentionRegistry

Similar to FullAttention, we will use AttentionRegistry because our implementation is not recurrent. The following snippet integrates our quadratic attention with the builders.
from fast_transformers.attention_registry import AttentionRegistry, \
    Optional, Float  # we also need these to add our new
                     # parameter 'quadratic_temp'

AttentionRegistry.register(
    "square", QuadraticAttention,  # attention_type, class pair
    [
        ("quadratic_temp", Optional(Float, 1.0))  # an optional parameter named
                                                  # 'quadratic_temp' of type
                                                  # float and with default
                                                  # value 1.0
    ]
)
Afterwards we can use the builder to create transformers with our new attention layer.
from fast_transformers.builders import TransformerEncoderBuilder

quadratic_bert = TransformerEncoderBuilder.from_kwargs(
    attention_type="square",  # here we select our custom attention layer
    n_layers=12,
    n_heads=12,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=3072,
    activation="gelu",
    quadratic_temp=5.0  # set the temperature for our quadratic layer
).get()  # build the transformer from the configured builder
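As a quick check, the object returned by get() is a regular PyTorch module; assuming the configuration above, it maps a batch of embeddings with feature size n_heads * query_dimensions = 768 to an output of the same shape.

import torch

x = torch.randn(2, 128, 12 * 64)  # (batch size, sequence length, model dimension)
y = quadratic_bert(x)             # every layer uses our "square" attention
print(y.shape)                    # torch.Size([2, 128, 768])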