Events

When training transformers, some internal representations, such as the attention matrices, are useful for identifying problems or understanding how the model works.

Instead of returning these representations as outputs of the model, we expose them through an event system. Different attention implementations can then emit different things without affecting the execution speed or changing the interfaces.

You can explore the interfaces of the event system in our API Docs.

Getting Started

Before delving deeper into the API of the event system, here is a commented code snippet that collects all the attention matrices from a forward pass of a transformer and, for each layer, plots the first head of the first sample using matplotlib.

import matplotlib.pyplot as plt
import torch

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.events import EventDispatcher, AttentionEvent

# Make a transformer as we always would
transformer = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=4,
    query_dimensions=64,
    value_dimensions=64
).get()

# Make an event handler that just appends to a list
attentions = []
def save_attention_matrix(event):
    attentions.append(event.attention_matrix.detach().cpu())

# Register said event handler for AttentionEvents
EventDispatcher.get().listen(AttentionEvent, save_attention_matrix)

# Do a forward pass like always
transformer(torch.rand(10, 100, 64*4))

# Now get and plot the attention matrices from the `attentions` list
fig, axes = plt.subplots(2, 2)
for i in range(2):
    for j in range(2):
        axes[i, j].imshow(attentions[i*2+j][0, 0])
        axes[i, j].set_title("Layer {} Head 0".format(i*2+j))
plt.tight_layout()
plt.show()

EventDispatcher

The event system is implemented by the EventDispatcher, which is shared by the transformer and attention modules as well as the rest of the system. The event dispatcher instance is injected as an argument to all attention and transformer modules. For convenience, there is also a global dictionary of event dispatchers, accessible through the get(key="") factory method as follows:

from fast_transformers.events import EventDispatcher

# The default dispatcher used by all modules unless passed as an argument
ed = EventDispatcher.get()

Unless an event dispatcher is provided via an argument, all modules simply use the default event dispatcher.
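
To make the idea of a global dictionary of dispatchers concrete, the following is a minimal sketch of a keyed factory method. ToyDispatcher is a hypothetical stand-in written for this illustration; it is not the library's EventDispatcher and only mirrors the get(key="") behaviour described above.

```python
class ToyDispatcher:
    """Hypothetical stand-in for an event dispatcher (not the library class)."""

    # One shared instance per key, created lazily on first request
    _dispatchers = {}

    def __init__(self):
        self.listeners = []

    @classmethod
    def get(cls, key=""):
        # Create the dispatcher for this key only if it does not exist yet
        if key not in cls._dispatchers:
            cls._dispatchers[key] = cls()
        return cls._dispatchers[key]


# The empty key plays the role of the default dispatcher
default = ToyDispatcher.get()
same_default = ToyDispatcher.get("")

# A different key (an arbitrary name chosen here) yields a separate dispatcher
debug = ToyDispatcher.get("debug")
```

Under this scheme, every module that asks for the same key ends up sharing one dispatcher, which is why a handler registered on the default dispatcher sees events from all modules that were not given another one.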

Methods

EventDispatcher.listen(event_filter, event_handler)

The method listen() adds an event handler to be called when an event is dispatched via this dispatcher. The event handler will only be called if the event filter callable returns true for an event. For details on the possible values of event_filter, see the Event Filters section.

The EventDispatcher automatically casts callables and Event subclasses to the corresponding filter instances.

EventDispatcher.dispatch(event)

Call every registered event handler whose filter accepts this event.
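
The interplay of listen() and dispatch() can be sketched in a few lines of plain Python. Everything below (ToyEvent, ToyAttentionEvent, ToyDispatcher) is a hypothetical toy written for this illustration, not the library's implementation; it only shows how an event class might be cast to a filter and how the filter gates the handler.

```python
class ToyEvent:
    """Toy base event that remembers which module emitted it."""

    def __init__(self, source):
        self.source = source


class ToyAttentionEvent(ToyEvent):
    """Toy event carrying an attention matrix as its payload."""

    def __init__(self, source, attention_matrix):
        super().__init__(source)
        self.attention_matrix = attention_matrix


class ToyDispatcher:
    def __init__(self):
        self._listeners = []  # list of (filter, handler) pairs

    def listen(self, event_filter, event_handler):
        # An event class is turned into an isinstance check; any other
        # callable is used as the filter directly.
        if isinstance(event_filter, type) and issubclass(event_filter, ToyEvent):
            event_class = event_filter
            event_filter = lambda ev: isinstance(ev, event_class)
        self._listeners.append((event_filter, event_handler))

    def dispatch(self, event):
        # Call each handler whose filter accepts the event
        for event_filter, event_handler in self._listeners:
            if event_filter(event):
                event_handler(event)


seen = []
dispatcher = ToyDispatcher()
dispatcher.listen(ToyAttentionEvent, lambda ev: seen.append(ev.attention_matrix))

# A plain ToyEvent is rejected by the filter, so the handler is not called
dispatcher.dispatch(ToyEvent("layer0"))
# A ToyAttentionEvent passes the filter and reaches the handler
dispatcher.dispatch(ToyAttentionEvent("layer0", [[0.5, 0.5]]))
```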

EventDispatcher.remove(event_handler)
EventDispatcher.clear()

Remove (unregister) a specific event handler using remove() or simply unregister all of the event handlers using the clear() method of the event dispatcher.
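
Since handlers stay registered until they are removed, it can be convenient to scope a handler to a block of code. The sketch below shows one way to do that with a context manager. The listening helper is hypothetical and not part of the library; it relies only on the listen() and remove() methods described above. RecordingDispatcher is a stub that exists solely so the sketch runs without the library installed.

```python
from contextlib import contextmanager


@contextmanager
def listening(dispatcher, event_filter, event_handler):
    """Register event_handler on entry and unregister it on exit.

    Hypothetical helper: works with any object that provides
    listen(event_filter, event_handler) and remove(event_handler).
    """
    dispatcher.listen(event_filter, event_handler)
    try:
        yield event_handler
    finally:
        # Runs even if the body raises, so the handler never leaks
        dispatcher.remove(event_handler)


class RecordingDispatcher:
    """Stub dispatcher used only to make this sketch self-contained."""

    def __init__(self):
        self.handlers = []

    def listen(self, event_filter, event_handler):
        self.handlers.append(event_handler)

    def remove(self, event_handler):
        self.handlers.remove(event_handler)


def on_event(event):
    pass


stub = RecordingDispatcher()
with listening(stub, lambda ev: True, on_event):
    registered_inside = list(stub.handlers)  # handler is active here
registered_after = list(stub.handlers)       # handler is gone again
```

With the real dispatcher, the same helper could presumably be used as `with listening(EventDispatcher.get(), AttentionEvent, save_attention_matrix):` around a single forward pass, so that attention matrices are collected only for that pass.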

Event Filters

The event filters are callables that accept a single argument, an instance of Event, and return whether to accept or dismiss this event. For ease of filter composition, we provide an EventFilter object that allows for boolean composition of filters using Python operators, as follows:

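from fast_transformers.events import AttentionEvent
from fast_transformers.events.filters import event_class, from_layer, \
    layer_name_contains

# `net` is assumed to be an already built transformer with at least 11 layers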

# Checking whether an event is from a specific class
filter1 = event_class(AttentionEvent)

# Checking whether an event comes from a specific layer
filter2 = from_layer(net.layers[10])

# Checking whether the human readable name of the module contains a string
filter3 = layer_name_contains(net, "layers.10")

# Check whether it comes from a specific layer *and* is an AttentionEvent
filter4 = from_layer(net.layers[10]) & event_class(AttentionEvent)
# or equivalently
filter4 = from_layer(net.layers[10]) & AttentionEvent

# Check whether the attention matrix has 4 heads
filter5 = (
    event_class(AttentionEvent) &  # unless we also use the event_class
                                   # filter the event might not have the
                                   # attention_matrix attribute
    (lambda ev: ev.attention_matrix.shape[1] == 4)  # dim 1 indexes the heads
)

See the event filters API docs for more information.
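
The boolean composition used above can be illustrated with a small operator-overloading sketch. ToyFilter is a hypothetical class written for this example, not the library's EventFilter, and the set of supported operators (`&`, `|`, `~`) is an assumption of the sketch. Events are modelled as plain dictionaries to keep it self-contained.

```python
class ToyFilter:
    """Hypothetical composable filter (not the library's EventFilter)."""

    def __init__(self, predicate):
        self.predicate = predicate

    def __call__(self, event):
        return bool(self.predicate(event))

    @staticmethod
    def _wrap(other):
        # Allow plain callables on the right-hand side of an operator
        return other if isinstance(other, ToyFilter) else ToyFilter(other)

    def __and__(self, other):
        other = self._wrap(other)
        # `and` short-circuits: `other` runs only if `self` accepted the event
        return ToyFilter(lambda ev: self(ev) and other(ev))

    def __or__(self, other):
        other = self._wrap(other)
        return ToyFilter(lambda ev: self(ev) or other(ev))

    def __invert__(self):
        return ToyFilter(lambda ev: not self(ev))


# Toy events are dictionaries; the first filter checks that the payload exists
has_matrix = ToyFilter(lambda ev: "attention_matrix" in ev)
four_wide = has_matrix & (lambda ev: len(ev["attention_matrix"][0]) == 4)

event_a = {"attention_matrix": [[0, 0, 0, 0]]}
event_b = {"queries": []}  # no attention matrix at all

result_a = four_wide(event_a)
result_b = four_wide(event_b)        # no KeyError thanks to short-circuiting
result_not = (~has_matrix)(event_b)
```

The short-circuit in `__and__` is the reason the guarding filter goes on the left: the second predicate is evaluated only for events that actually carry the attribute it inspects.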

Events

The events are subclasses of Event that contain the source layer from which they were emitted and a payload that depends on the specific event that was emitted.

The following is a list of the currently implemented events with a high-level overview of their payload as well as the layers which emit them.

QKVEvent

The QKVEvent is emitted by the attention layer and contains the queries, keys and values in the corresponding attributes.

AttentionEvent

The AttentionEvent is emitted by the full attention and contains the softmax-normalized attention matrix in the attribute attention_matrix.