# Creating a custom attention layer

On this page, we will go through the process of creating a custom attention module and integrating it with the library. We will implement quadratic kernel attention in place of the standard softmax attention.

## New Attention

Our attention layer will closely follow the implementation of FullAttention. Let's start with the skeleton of our module.

```python
import torch
from torch.nn import Module


class QuadraticAttention(Module):
    def __init__(self, quadratic_temp=1.0, eps=1e-6):
        super(QuadraticAttention, self).__init__()
        self.eps = eps
        self.quadratic_temp = quadratic_temp

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # implement the logic of the layer here
        pass
```


The queries, keys and values are already projected and split into multiple heads by the AttentionLayer. This means that we need only implement the attention part.

```python
class QuadraticAttention(Module):
    def __init__(self, quadratic_temp=1.0, eps=1e-6):
        super(QuadraticAttention, self).__init__()
        self.eps = eps
        self.quadratic_temp = quadratic_temp

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # compute the unnormalized attention
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)  # compute the dot products
        QK = torch.square(self.quadratic_temp * QK)  # implement our custom attention twist
        QK = QK * attn_mask.float_matrix  # use the attention mask as a multiplicative mask
        QK = QK * key_lengths.float_matrix[:, None, None]  # the key lengths are also a multiplicative mask

        # normalize and compute the average
        A = QK / (QK.sum(dim=-1, keepdim=True) + self.eps)
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        return V.contiguous()
```
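To sanity-check the core computation in isolation, we can run it on random tensors with plain PyTorch, leaving the masks out. The toy sizes below are hypothetical; after normalization, every attention row should sum to one.

```python
import torch

# Hypothetical toy sizes: batch N, query length L, key length S,
# heads H, key dimension E, value dimension D
N, L, S, H, E, D = 2, 4, 5, 3, 8, 8
temp, eps = 1.0, 1e-6

queries = torch.randn(N, L, H, E)
keys = torch.randn(N, S, H, E)
values = torch.randn(N, S, H, D)

# quadratic kernel instead of softmax
QK = torch.square(temp * torch.einsum("nlhe,nshe->nhls", queries, keys))
A = QK / (QK.sum(dim=-1, keepdim=True) + eps)
V = torch.einsum("nhls,nshd->nlhd", A, values)

print(V.shape)  # torch.Size([2, 4, 3, 8])
```

Since every entry of `QK` is a square, the unnormalized weights are non-negative and each row of `A` sums to (almost exactly) one.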


## Integrate with the Builder

To add it as an option to the TransformerEncoderBuilder or the TransformerDecoderBuilder, we have to register our new attention in the appropriate attention registry. The available registries are:

- AttentionRegistry
- RecurrentAttentionRegistry
- RecurrentCrossAttentionRegistry

Like FullAttention, we will use AttentionRegistry, because our implementation is not recurrent. The following snippet integrates our quadratic attention with the builders.

```python
from fast_transformers.attention_registry import AttentionRegistry, \
    Optional, Float  # we also need these to add our new attention

AttentionRegistry.register(
    "square", QuadraticAttention,  # attention_type, class pair
    [
        ("quadratic_temp", Optional(Float, 1.0))  # an optional parameter named
                                                  # quadratic_temp of type float
                                                  # and with default value 1.0
    ]
)
```


Afterwards we can use the builder to create transformers with our new attention layer.

```python
from fast_transformers.builders import TransformerEncoderBuilder

quadratic_bert = TransformerEncoderBuilder.from_kwargs(
    attention_type="square",  # here we select our custom attention layer
    n_layers=12,
    n_heads=12,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=3072,
    activation="gelu"
).get()
```