Module fast_transformers.recurrent.transformers

Implement transformer encoders and decoders as RNNs that will be used with different recurrent attention mechanisms.

In all cases there exists no sequence dimension and the shapes are batch x heads x dims.

This module's interface is designed with the linear attention in mind. The interface is subject to change given the implementation of other recurrent attentions.
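
The layers below drive whatever attention module they are given one timestep at a time. As a minimal sketch of that contract (the class name is illustrative, not part of the library), the attention module is called with the current features of shape (N, E) as query, key and value, plus an opaque state, and must return the new features together with the updated state:

from torch.nn import Module


class MyRecurrentAttention(Module):
    """Hypothetical stand-in for a recurrent attention module.

    A real implementation (e.g. the recurrent linear attention mentioned
    above) would keep a running summary of everything seen so far in state.
    """
    def forward(self, query, key, value, state=None):
        # query, key, value: (N, E) features for the current timestep
        new_state = state  # a real module would update this
        return value, new_state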

#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

"""Implement transformer encoders and decoders as RNNs that will be used with
different recurrent attention mechanisms.

In all cases there exists no sequence dimension and the shapes are batch x
heads x dims.

This module's interface is designed with the linear attention in mind. The
interface is subject to change given the implementation of other recurrent
attentions.
"""

import warnings

import torch
from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList
import torch.nn.functional as F

from ..events import EventDispatcher
from ..masking import LengthMask
from ._utils import check_state


class RecurrentTransformerEncoderLayer(Module):
    """Attention to the previous inputs and feed forward with skip connections.

    This transformer encoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerEncoderLayer . The results should
    be identical given the same inputs and a lower triangular mask.

    Arguments
    ---------
        attention: The attention implementation to use given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: relu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1,
                 activation="relu", event_dispatcher=""):
        super(RecurrentTransformerEncoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.attention = attention
        self.linear1 = Linear(d_model, d_ff)
        self.linear2 = Linear(d_ff, d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, state=None, memory=None):
        """Apply the transformer encoder to the input x using the provided
        memory.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            state: The state can vary depending on the attention implementation
            memory: **Deprecated** name for the state argument
        """
        # Normalize the state name
        state = check_state(state, memory)

        # Run the self attention and add it to the input
        x2, state = self.attention(x, x, x, state)
        x = x + self.dropout(x2)

        # Run the fully connected part of the layer
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.linear1(y)))
        y = self.dropout(self.linear2(y))

        return self.norm2(x+y), state


class RecurrentTransformerEncoder(Module):
    """RecurrentTransformerEncoder is a sequence of
    RecurrentTransformerEncoderLayer instances.

    RecurrentTransformerEncoder keeps a separate state per
    RecurrentTransformerEncoderLayer.

    Arguments
    ---------
        layers: list, RecurrentTransformerEncoderLayer instances or instances
                that implement the same interface
        norm_layer: A normalization layer to be applied to the final output
                    (default: None which means no normalization)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, layers, norm_layer=None, event_dispatcher=""):
        super(RecurrentTransformerEncoder, self).__init__()
        self.layers = ModuleList(layers)
        self.norm = norm_layer
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, state=None, memory=None):
        """Apply all recurrent transformer layers to the input x using the
        provided state.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor of each recurrent
               transformer encoder layer
            state: A list of objects to be passed to each recurrent
                   transformer encoder layer
            memory: **Deprecated** name for the state argument
        """
        # Initialize the memory to None if not given
        state = check_state(state, memory)
        if state is None:
            state = [None]*len(self.layers)

        # Apply all the transformers
        for i, layer in enumerate(self.layers):
            x, s = layer(x, state[i])
            state[i] = s

        # Apply the normalization if needed
        if self.norm is not None:
            x = self.norm(x)

        return x, state


class RecurrentTransformerDecoderLayer(Module):
    """Attention to the previous inputs and a preprocessed memory.

    This transformer decoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerDecoderLayer . The results should
    be identical given the same inputs and a lower triangular mask for x_mask.

    Arguments
    ---------
        self_attention: The attention implementation to use for self attention
                        given as a nn.Module
        cross_attention: The attention implementation to use for cross
                         attention given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: relu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu", event_dispatcher=""):
        super(RecurrentTransformerDecoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.linear1 = Linear(d_model, d_ff)
        self.linear2 = Linear(d_ff, d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, memory, memory_length_mask=None, state=None):
        """Apply the transformer decoder to the input x and also attend to
        memory.

        Note the memory mask is assumed to be a full mask.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            memory: A sequence of features (N, S, E) that the input will attend
                    to. S is the sequence length and E is the same as for x.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of.
            state: The state varies depending on the attention implementations
                   but it allows for recurrent implementation.
        """
        # Normalize the mask
        N = x.shape[0]
        L = memory.shape[1]
        memory_length_mask = memory_length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))

        # Extract the individual states for the self attention and the cross
        # attention
        self_state, cross_state = state or [None, None]

        # First apply the self attention and add it to the input
        x2, self_state = self.self_attention(x, x, x, state=self_state)
        x = self.norm1(x + self.dropout(x2))

        # Secondly apply the cross attention and add it to the previous output
        x2, cross_state = self.cross_attention(
            x, memory, memory, memory_length_mask, state=cross_state
        )
        x = self.norm2(x + self.dropout(x2))

        # Finally run the fully connected part of the layer
        y = x
        y = self.dropout(self.activation(self.linear1(y)))
        y = self.dropout(self.linear2(y))

        return self.norm3(x+y), [self_state, cross_state]


class RecurrentTransformerDecoder(Module):
    """RecurrentTransformerDecoder is little more than a sequence of
    RecurrentTransformerDecoderLayer instances.

    Similar to the recurrent encoder, a separate state is kept per decoder layer.

    Arguments
    ---------
        layers: list, RecurrentTransformerDecoderLayer instances or instances
                that implement the same interface
        norm_layer: A normalization layer to be applied to the final output
                    (default: None which means no normalization)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, layers, norm_layer=None, event_dispatcher=""):
        super(RecurrentTransformerDecoder, self).__init__()
        self.layers = ModuleList(layers)
        self.norm = norm_layer
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, memory, memory_length_mask=None, state=None):
        """Apply all recurrent transformer layers to the input x using the
        provided state.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            memory: A sequence of features (N, S, E) that the input will attend
                    to. S is the sequence length and E is the same as for x.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of
            state: A list of objects to be passed to each recurrent
                   transformer decoder layer
        """
        # Initialize the state to None if not given
        if state is None:
            state = [None]*len(self.layers)

        # Apply all the transformers
        for i, layer in enumerate(self.layers):
            x, s = layer(x, memory, memory_length_mask=memory_length_mask,
                         state=state[i])
            state[i] = s

        # Apply the normalization if needed
        if self.norm is not None:
            x = self.norm(x)

        return x, state

Classes

class RecurrentTransformerDecoder (layers, norm_layer=None, event_dispatcher='')

RecurrentTransformerDecoder is little more than a sequence of RecurrentTransformerDecoderLayer instances.

Similar to the recurrent encoder, a separate state is kept per decoder layer.

Arguments

layers: list, RecurrentTransformerDecoderLayer instances or instances
        that implement the same interface
norm_layer: A normalization layer to be applied to the final output
            (default: None which means no normalization)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class RecurrentTransformerDecoder(Module):
    """RecurrentTransformerDecoder is little more than a sequence of
    RecurrentTransformerDecoderLayer instances.

    Similar to the recurrent encoder, a separate state is kept per decoder layer.

    Arguments
    ---------
        layers: list, RecurrentTransformerDecoderLayer instances or instances
                that implement the same interface
        norm_layer: A normalization layer to be applied to the final output
                    (default: None which means no normalization)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, layers, norm_layer=None, event_dispatcher=""):
        super(RecurrentTransformerDecoder, self).__init__()
        self.layers = ModuleList(layers)
        self.norm = norm_layer
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, memory, memory_length_mask=None, state=None):
        """Apply all recurrent transformer layers to the input x using the
        provided state.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            memory: A sequence of features (N, S, E) that the input will attend
                    to. S is the sequence length and E is the same as for x.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of
            state: A list of objects to be passed to each recurrent
                   transformer decoder layer
        """
        # Initialize the state to None if not given
        if state is None:
            state = [None]*len(self.layers)

        # Apply all the transformers
        for i, layer in enumerate(self.layers):
            x, s = layer(x, memory, memory_length_mask=memory_length_mask,
                         state=state[i])
            state[i] = s

        # Apply the normalization if needed
        if self.norm is not None:
            x = self.norm(x)

        return x, state

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x, memory, memory_length_mask=None, state=None)

Apply all recurrent transformer layers to the input x using the provided state.

Arguments

x: The input features of shape (N, E) where N is the batch size and
   E is d_model passed in the constructor
memory: A sequence of features (N, S, E) that the input will attend
        to. S is the sequence length and E is the same as for x.
memory_length_mask: An implementation of a BaseMask that encodes
                    how many elements each memory sequence in the
                    batch consists of
state: A list of objects to be passed to each recurrent
       transformer decoder layer
def forward(self, x, memory, memory_length_mask=None, state=None):
    """Apply all recurrent transformer layers to the input x using the
    provided state.

    Arguments
    ---------
        x: The input features of shape (N, E) where N is the batch size and
           E is d_model passed in the constructor
        memory: A sequence of features (N, S, E) that the input will attend
                to. S is the sequence length and E is the same as for x.
        memory_length_mask: An implementation of a BaseMask that encodes
                            how many elements each memory sequence in the
                            batch consists of
        state: A list of objects to be passed to each recurrent
               transformer decoder layer
    """
    # Initialize the state to None if not given
    if state is None:
        state = [None]*len(self.layers)

    # Apply all the transformers
    for i, layer in enumerate(self.layers):
        x, s = layer(x, memory, memory_length_mask=memory_length_mask,
                     state=state[i])
        state[i] = s

    # Apply the normalization if needed
    if self.norm is not None:
        x = self.norm(x)

    return x, state
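
As a usage sketch, the decoder can be stepped one output position at a time against a precomputed memory. The two attention classes below are hypothetical stand-ins that only mirror the calling convention; in practice they would be recurrent attention modules such as the recurrent linear attention the module docstring refers to.

import torch
from torch.nn import Module

from fast_transformers.recurrent.transformers import (
    RecurrentTransformerDecoder,
    RecurrentTransformerDecoderLayer,
)


class ToySelfAttention(Module):
    # Stand-in: called as attention(x, x, x, state=...) by the decoder layer.
    def forward(self, query, key, value, state=None):
        return value, state


class ToyCrossAttention(Module):
    # Stand-in: called as attention(x, memory, memory, length_mask, state=...).
    def forward(self, query, keys, values, key_lengths, state=None):
        return values.mean(dim=1), state


N, S, E = 4, 10, 64  # batch size, memory length, model dimension
decoder = RecurrentTransformerDecoder([
    RecurrentTransformerDecoderLayer(ToySelfAttention(), ToyCrossAttention(), E)
    for _ in range(2)
])

memory = torch.randn(N, S, E)  # e.g. the output of a (non-recurrent) encoder
x = torch.randn(N, E)          # features of the first decoding step
state = None                   # the decoder keeps one state per layer
for _ in range(5):
    x, state = decoder(x, memory, state=state)  # x: (N, E) at every step
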
class RecurrentTransformerDecoderLayer (self_attention, cross_attention, d_model, d_ff=None, dropout=0.1, activation='relu', event_dispatcher='')

Attention to the previous inputs and a preprocessed memory.

This transformer decoder layer is the recurrent dual of fast_transformers.transformers.TransformerDecoderLayer . The results should be identical given the same inputs and a lower triangular mask for x_mask.

Arguments

self_attention: The attention implementation to use for self attention
                given as a nn.Module
cross_attention: The attention implementation to use for cross
                 attention given as a nn.Module
d_model: The input feature dimensionality
d_ff: The dimensionality of the intermediate features after the
      attention (default: d_model*4)
dropout: The dropout rate to apply to the intermediate features
         (default: 0.1)
activation: {'relu', 'gelu'} Which activation to use for the feed
            forward part of the layer (default: relu)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class RecurrentTransformerDecoderLayer(Module):
    """Attention to the previous inputs and a preprocessed memory.

    This transformer decoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerDecoderLayer . The results should
    be identical given the same inputs and a lower triangular mask for x_mask.

    Arguments
    ---------
        self_attention: The attention implementation to use for self attention
                        given as a nn.Module
        cross_attention: The attention implementation to use for cross
                         attention given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: relu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu", event_dispatcher=""):
        super(RecurrentTransformerDecoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.linear1 = Linear(d_model, d_ff)
        self.linear2 = Linear(d_ff, d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, memory, memory_length_mask=None, state=None):
        """Apply the transformer decoder to the input x and also attend to
        memory.

        Note the memory mask is assumed to be a full mask.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            memory: A sequence of features (N, S, E) that the input will attend
                    to. S is the sequence length and E is the same as for x.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of.
            state: The state varies depending on the attention implementations
                   but it allows for recurrent implementation.
        """
        # Normalize the mask
        N = x.shape[0]
        L = memory.shape[1]
        memory_length_mask = memory_length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))

        # Extract the individual states for the self attention and the cross
        # attention
        self_state, cross_state = state or [None, None]

        # First apply the self attention and add it to the input
        x2, self_state = self.self_attention(x, x, x, state=self_state)
        x = self.norm1(x + self.dropout(x2))

        # Secondly apply the cross attention and add it to the previous output
        x2, cross_state = self.cross_attention(
            x, memory, memory, memory_length_mask, state=cross_state
        )
        x = self.norm2(x + self.dropout(x2))

        # Finally run the fully connected part of the layer
        y = x
        y = self.dropout(self.activation(self.linear1(y)))
        y = self.dropout(self.linear2(y))

        return self.norm3(x+y), [self_state, cross_state]

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x, memory, memory_length_mask=None, state=None)

Apply the transformer decoder to the input x and also attend to memory.

Note the memory mask is assumed to be a full mask.

Arguments

x: The input features of shape (N, E) where N is the batch size and
   E is d_model passed in the constructor
memory: A sequence of features (N, S, E) that the input will attend
        to. S is the sequence length and E is the same as for x.
memory_length_mask: An implementation of a BaseMask that encodes
                    how many elements each memory sequence in the
                    batch consists of.
state: The state varies depending on the attention implementations
       but it allows for recurrent implementation.
def forward(self, x, memory, memory_length_mask=None, state=None):
    """Apply the transformer decoder to the input x and also attend to
    memory.

    Note the memory mask is assumed to be a full mask.

    Arguments
    ---------
        x: The input features of shape (N, E) where N is the batch size and
           E is d_model passed in the constructor
        memory: A sequence of features (N, S, E) that the input will attend
                to. S is the sequence length and E is the same as for x.
        memory_length_mask: An implementation of a BaseMask that encodes
                            how many elements each memory sequence in the
                            batch consists of.
        state: The state varies depending on the attention implementations
               but it allows for recurrent implementation.
    """
    # Normalize the mask
    N = x.shape[0]
    L = memory.shape[1]
    memory_length_mask = memory_length_mask or \
        LengthMask(x.new_full((N,), L, dtype=torch.int64))

    # Extract the individual states for the self attention and the cross
    # attention
    self_state, cross_state = state or [None, None]

    # First apply the self attention and add it to the input
    x2, self_state = self.self_attention(x, x, x, state=self_state)
    x = self.norm1(x + self.dropout(x2))

    # Secondly apply the cross attention and add it to the previous output
    x2, cross_state = self.cross_attention(
        x, memory, memory, memory_length_mask, state=cross_state
    )
    x = self.norm2(x + self.dropout(x2))

    # Finally run the fully connected part of the layer
    y = x
    y = self.dropout(self.activation(self.linear1(y)))
    y = self.dropout(self.linear2(y))

    return self.norm3(x+y), [self_state, cross_state]
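
A single decoder layer can also be driven directly. In the sketch below (the attention classes are hypothetical stand-ins), no memory_length_mask is given, so the layer builds a full LengthMask covering all S memory elements, and the returned state is the two-element list [self_state, cross_state] that is fed back on the next call.

import torch
from torch.nn import Module

from fast_transformers.recurrent.transformers import RecurrentTransformerDecoderLayer


class ToySelfAttention(Module):
    def forward(self, query, key, value, state=None):
        return value, state


class ToyCrossAttention(Module):
    def forward(self, query, keys, values, key_lengths, state=None):
        return values.mean(dim=1), state


N, S, E = 2, 7, 32
layer = RecurrentTransformerDecoderLayer(ToySelfAttention(), ToyCrossAttention(), E)

memory = torch.randn(N, S, E)
x = torch.randn(N, E)

y, state = layer(x, memory)               # first step, default full mask
self_state, cross_state = state           # one entry per attention module
y, state = layer(y, memory, state=state)  # later steps reuse the state
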
class RecurrentTransformerEncoder (layers, norm_layer=None, event_dispatcher='')

RecurrentTransformerEncoder is a sequence of RecurrentTransformerEncoderLayer instances.

RecurrentTransformerEncoder keeps a separate state per RecurrentTransformerEncoderLayer.

Arguments

layers: list, RecurrentTransformerEncoderLayer instances or instances
        that implement the same interface
norm_layer: A normalization layer to be applied to the final output
            (default: None which means no normalization)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class RecurrentTransformerEncoder(Module):
    """RecurrentTransformerEncoder is a sequence of
    RecurrentTransformerEncoderLayer instances.

    RecurrentTransformerEncoder keeps a separate state per
    RecurrentTransformerEncoderLayer.

    Arguments
    ---------
        layers: list, RecurrentTransformerEncoderLayer instances or instances
                that implement the same interface
        norm_layer: A normalization layer to be applied to the final output
                    (default: None which means no normalization)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, layers, norm_layer=None, event_dispatcher=""):
        super(RecurrentTransformerEncoder, self).__init__()
        self.layers = ModuleList(layers)
        self.norm = norm_layer
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, state=None, memory=None):
        """Apply all recurrent transformer layers to the input x using the
        provided state.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor of each recurrent
               transformer encoder layer
            state: A list of objects to be passed to each recurrent
                   transformer encoder layer
            memory: **Deprecated** name for the state argument
        """
        # Initialize the memory to None if not given
        state = check_state(state, memory)
        if state is None:
            state = [None]*len(self.layers)

        # Apply all the transformers
        for i, layer in enumerate(self.layers):
            x, s = layer(x, state[i])
            state[i] = s

        # Apply the normalization if needed
        if self.norm is not None:
            x = self.norm(x)

        return x, state

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x, state=None, memory=None)

Apply all recurrent transformer layers to the input x using the provided state.

Arguments

x: The input features of shape (N, E) where N is the batch size and
   E is d_model passed in the constructor of each recurrent
   transformer encoder layer
state: A list of objects to be passed to each recurrent
       transformer encoder layer
memory: **Deprecated** name for the state argument
def forward(self, x, state=None, memory=None):
    """Apply all recurrent transformer layers to the input x using the
    provided state.

    Arguments
    ---------
        x: The input features of shape (N, E) where N is the batch size and
           E is d_model passed in the constructor of each recurrent
           transformer encoder layer
        state: A list of objects to be passed to each recurrent
               transformer encoder layer
        memory: **Deprecated** name for the state argument
    """
    # Initialize the memory to None if not given
    state = check_state(state, memory)
    if state is None:
        state = [None]*len(self.layers)

    # Apply all the transformers
    for i, layer in enumerate(self.layers):
        x, s = layer(x, state[i])
        state[i] = s

    # Apply the normalization if needed
    if self.norm is not None:
        x = self.norm(x)

    return x, state
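
A usage sketch for a stack of encoder layers (the attention class is a hypothetical stand-in): the encoder allocates one state slot per layer on the first call and returns the updated list, which is passed back in when the next element of the sequence is processed.

import torch
from torch.nn import LayerNorm, Module

from fast_transformers.recurrent.transformers import (
    RecurrentTransformerEncoder,
    RecurrentTransformerEncoderLayer,
)


class ToyRecurrentAttention(Module):
    # Stand-in: called as attention(x, x, x, state) by the encoder layer.
    def forward(self, query, key, value, state=None):
        return value, state


N, T, E = 8, 10, 64  # batch size, sequence length, model dimension
encoder = RecurrentTransformerEncoder(
    [RecurrentTransformerEncoderLayer(ToyRecurrentAttention(), E) for _ in range(3)],
    norm_layer=LayerNorm(E),
)

sequence = torch.randn(N, T, E)  # fed to the encoder one element at a time
state = None                     # becomes a list with one entry per layer
for t in range(T):
    y, state = encoder(sequence[:, t], state=state)  # y: (N, E)
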
class RecurrentTransformerEncoderLayer (attention, d_model, d_ff=None, dropout=0.1, activation='relu', event_dispatcher='')

Attention to the previous inputs and feed forward with skip connections.

This transformer encoder layer is the recurrent dual of fast_transformers.transformers.TransformerEncoderLayer . The results should be identical given the same inputs and a lower triangular mask.

Arguments

attention: The attention implementation to use given as a nn.Module
d_model: The input feature dimensionality
d_ff: The dimensionality of the intermediate features after the
      attention (default: d_model*4)
dropout: The dropout rate to apply to the intermediate features
         (default: 0.1)
activation: {'relu', 'gelu'} Which activation to use for the feed
            forward part of the layer (default: relu)
event_dispatcher: str or EventDispatcher instance to be used by this
                  module for dispatching events (default: the default
                  global dispatcher)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class RecurrentTransformerEncoderLayer(Module):
    """Attention to the previous inputs and feed forward with skip connections.

    This transformer encoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerEncoderLayer . The results should
    be identical given the same inputs and a lower triangular mask.

    Arguments
    ---------
        attention: The attention implementation to use given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: relu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1,
                 activation="relu", event_dispatcher=""):
        super(RecurrentTransformerEncoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.attention = attention
        self.linear1 = Linear(d_model, d_ff)
        self.linear2 = Linear(d_ff, d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, x, state=None, memory=None):
        """Apply the transformer encoder to the input x using the provided
        memory.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            state: The state can vary depending on the attention implementation
            memory: **Deprecated** name for the state argument
        """
        # Normalize the state name
        state = check_state(state, memory)

        # Run the self attention and add it to the input
        x2, state = self.attention(x, x, x, state)
        x = x + self.dropout(x2)

        # Run the fully connected part of the layer
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.linear1(y)))
        y = self.dropout(self.linear2(y))

        return self.norm2(x+y), state

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x, state=None, memory=None)

Apply the transformer encoder to the input x using the provided memory.

Arguments

x: The input features of shape (N, E) where N is the batch size and
   E is d_model passed in the constructor
state: The state can vary depending on the attention implementation
memory: **Deprecated** name for the state argument
def forward(self, x, state=None, memory=None):
    """Apply the transformer encoder to the input x using the provided
    memory.

    Arguments
    ---------
        x: The input features of shape (N, E) where N is the batch size and
           E is d_model passed in the constructor
        state: The state can vary depending on the attention implementation
        memory: **Deprecated** name for the state argument
    """
    # Normalize the state name
    state = check_state(state, memory)

    # Run the self attention and add it to the input
    x2, state = self.attention(x, x, x, state)
    x = x + self.dropout(x2)

    # Run the fully connected part of the layer
    y = x = self.norm1(x)
    y = self.dropout(self.activation(self.linear1(y)))
    y = self.dropout(self.linear2(y))

    return self.norm2(x+y), state
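
Finally, a single encoder layer step by step (ToyRecurrentAttention is a hypothetical stand-in for a real recurrent attention module): each call consumes one (N, E) input together with the state returned by the previous call.

import torch
from torch.nn import Module

from fast_transformers.recurrent.transformers import RecurrentTransformerEncoderLayer


class ToyRecurrentAttention(Module):
    # Stand-in: a real recurrent attention would summarize the past in state.
    def forward(self, query, key, value, state=None):
        return value, state


N, E = 4, 64
layer = RecurrentTransformerEncoderLayer(ToyRecurrentAttention(), E)

x1 = torch.randn(N, E)
x2 = torch.randn(N, E)
y1, state = layer(x1)               # first timestep, no previous state
y2, state = layer(x2, state=state)  # second timestep continues from it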