Module fast_transformers.local_product

Source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

import torch

from .local_product_cpu import local_dot_product as local_dot_product_cpu, \
    local_dot_backward as local_dot_backward_cpu, \
    local_weighted_average as local_weighted_average_cpu, \
    local_weighted_average_backward as local_weighted_average_backward_cpu

try:
    from .local_product_cuda import \
        local_dot_product as local_dot_product_cuda, \
        local_dot_backward as local_dot_backward_cuda, \
        local_weighted_average as local_weighted_average_cuda, \
        local_weighted_average_backward as local_weighted_average_backward_cuda
except ImportError:
    local_dot_product_cuda = None
    local_dot_backward_cuda = None
    local_weighted_average_cuda = None
    local_weighted_average_backward_cuda = None


class LocalDotProduct(torch.autograd.Function):
    """Compute the dot product of the queries and keys but only consider a
    local neighborhood of each query."""
    dot = {
        "cpu": local_dot_product_cpu,
        "cuda": local_dot_product_cuda
    }
    dot_backward = {
        "cpu": local_dot_backward_cpu,
        "cuda": local_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, queries, keys, attn_mask, key_lengths, local_context):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(queries, keys, key_lengths)
        ctx.local_context = local_context

        return LocalDotProduct.dot[queries.device.type](
            queries,
            keys,
            attn_mask,
            key_lengths,
            local_context
        )

    @staticmethod
    def backward(ctx, grad_input):
        queries, keys, key_lengths = ctx.saved_tensors
        local_context = ctx.local_context

        grads = LocalDotProduct.dot_backward[queries.device.type](
            queries,
            keys,
            key_lengths,
            grad_input,
            local_context
        )

        # Three extra Nones for attn_mask, key_lengths and local_context
        return grads + (None, None, None)


class LocalWeightedAverage(torch.autograd.Function):
    """Compute the weighted average of the values with the local attention."""
    avg = {
        "cpu": local_weighted_average_cpu,
        "cuda": local_weighted_average_cuda
    }
    avg_backward = {
        "cpu": local_weighted_average_backward_cpu,
        "cuda": local_weighted_average_backward_cuda
    }

    @staticmethod
    def forward(ctx, A, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(A, V)

        return LocalWeightedAverage.avg[A.device.type](A, V)

    @staticmethod
    def backward(ctx, grad_input):
        A, V = ctx.saved_tensors
        return LocalWeightedAverage.avg_backward[A.device.type](
            A, V, grad_input
        )


# Alias the autograd functions to python style snake case naming
local_dot_product = LocalDotProduct.apply
local_weighted_average = LocalWeightedAverage.apply

Sub-modules

fast_transformers.local_product.local_product_cpu
fast_transformers.local_product.local_product_cuda

Functions

def local_dot_product(...)
def local_weighted_average(...)
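
These two aliases are the intended entry points of the module. The sketch below shows how they compose into a local attention step, mirroring how the library's LocalAttention module uses them; the shapes, the all-zeros additive mask and the full key lengths are illustrative assumptions rather than required values.

import torch
from fast_transformers.local_product import (
    local_dot_product,
    local_weighted_average,
)

# Illustrative sizes: batch, heads, sequence length, head dim, window size.
N, H, L, E, C = 2, 4, 100, 32, 9

queries = torch.randn(N, H, L, E)
keys = torch.randn(N, H, L, E)
values = torch.randn(N, H, L, E)

# Additive attention mask (zeros allow everything) and per-sample key
# lengths, here set to the full sequence length.
attn_mask = torch.zeros(L, L)
key_lengths = torch.full((N,), L, dtype=torch.int64)

# Scores of each query against its local window, shape (N, H, L, C).
QK = local_dot_product(queries, keys, attn_mask, key_lengths, C)

# Softmax over each window, then the weighted average of the values,
# shape (N, H, L, E).
A = torch.softmax(QK * (E ** -0.5), dim=-1)
new_values = local_weighted_average(A, values)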

Classes

class LocalDotProduct (...)

Compute the dot product of the queries and keys but only consider a local neighborhood of each query.

Source code
class LocalDotProduct(torch.autograd.Function):
    """Compute the dot product of the queries and keys but only consider a
    local neighborhood of each query."""
    dot = {
        "cpu": local_dot_product_cpu,
        "cuda": local_dot_product_cuda
    }
    dot_backward = {
        "cpu": local_dot_backward_cpu,
        "cuda": local_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, queries, keys, attn_mask, key_lengths, local_context):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(queries, keys, key_lengths)
        ctx.local_context = local_context

        return LocalDotProduct.dot[queries.device.type](
            queries,
            keys,
            attn_mask,
            key_lengths,
            local_context
        )

    @staticmethod
    def backward(ctx, grad_input):
        queries, keys, key_lengths = ctx.saved_tensors
        local_context = ctx.local_context

        grads = LocalDotProduct.dot_backward[queries.device.type](
            queries,
            keys,
            key_lengths,
            grad_input,
            local_context
        )

        # Three extra Nones for attn_mask, key_lengths and local_context
        return grads + (None, None, None)
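
As a rough mental model for what the class above returns, the pure-PyTorch sketch below extracts the same local band out of the dense score matrix. It assumes the window of query i is centered at i, i.e. column j of the output corresponds to key i - local_context // 2 + j; out-of-range positions are left at -inf here, standing in for whatever masked value the extension writes so that a subsequent softmax ignores them.

import torch

N, H, L, E, C = 1, 2, 6, 4, 3   # toy sizes; C plays the role of local_context

Q = torch.randn(N, H, L, E)
K = torch.randn(N, H, L, E)

# Dense scores of every query against every key, shape (N, H, L, L).
QK_full = torch.einsum("nhle,nhse->nhls", Q, K)

# Re-index into the local band, shape (N, H, L, C): column j of row i
# holds the score of query i against key i - C // 2 + j.
band = torch.full((N, H, L, C), float("-inf"))
for i in range(L):
    for j in range(C):
        k = i - C // 2 + j
        if 0 <= k < L:
            band[..., i, j] = QK_full[..., i, k]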

Ancestors

  • torch.autograd.function.Function
  • torch._C._FunctionBase
  • torch.autograd.function._ContextMethodMixin
  • torch.autograd.function._HookMixin

Class variables

var dot
var dot_backward

Static methods

def backward(ctx, grad_input)

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward returned, and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs a gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs a gradient computed w.r.t. the output.

Source code
@staticmethod
def backward(ctx, grad_input):
    queries, keys, key_lengths = ctx.saved_tensors
    local_context = ctx.local_context

    grads = LocalDotProduct.dot_backward[queries.device.type](
        queries,
        keys,
        key_lengths,
        grad_input,
        local_context
    )

    # Three extra Nones for attn_mask, key_lengths and local_context
    return grads + (None, None, None)
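
In practice the None entries mean that only queries and keys receive gradients. A minimal sketch of that contract, using the same illustrative mask and length conventions as in the usage example earlier (an all-zeros additive mask and full key lengths):

import torch
from fast_transformers.local_product import local_dot_product

N, H, L, E, C = 1, 2, 10, 4, 5
queries = torch.randn(N, H, L, E, requires_grad=True)
keys = torch.randn(N, H, L, E, requires_grad=True)
attn_mask = torch.zeros(L, L)
key_lengths = torch.full((N,), L, dtype=torch.int64)

QK = local_dot_product(queries, keys, attn_mask, key_lengths, C)

# Backpropagate through the attention weights: queries and keys get
# gradients, while attn_mask, key_lengths and local_context correspond
# to the None values returned by backward.
torch.softmax(QK, dim=-1).sum().backward()
assert queries.grad is not None and keys.grad is not None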
def forward(ctx, queries, keys, attn_mask, key_lengths, local_context)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.

Source code
@staticmethod
def forward(ctx, queries, keys, attn_mask, key_lengths, local_context):
    # Save the inputs for the gradient computation
    ctx.save_for_backward(queries, keys, key_lengths)
    ctx.local_context = local_context

    return LocalDotProduct.dot[queries.device.type](
        queries,
        keys,
        attn_mask,
        key_lengths,
        local_context
    )
class LocalWeightedAverage (...)

Compute the weighted average of the values with the local attention.

Source code
class LocalWeightedAverage(torch.autograd.Function):
    """Compute the weighted average of the values with the local attention."""
    avg = {
        "cpu": local_weighted_average_cpu,
        "cuda": local_weighted_average_cuda
    }
    avg_backward = {
        "cpu": local_weighted_average_backward_cpu,
        "cuda": local_weighted_average_backward_cuda
    }

    @staticmethod
    def forward(ctx, A, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(A, V)

        return LocalWeightedAverage.avg[A.device.type](A, V)

    @staticmethod
    def backward(ctx, grad_input):
        A, V = ctx.saved_tensors
        return LocalWeightedAverage.avg_backward[A.device.type](
            A, V, grad_input
        )
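
Conceptually, the class above gathers each query's local window of values and averages them with the attention weights. A pure-PyTorch sketch of that computation, assuming the same centered window convention as for LocalDotProduct (column j of A corresponds to value i - C // 2 + j):

import torch

N, H, L, E, C = 1, 2, 6, 4, 3

A = torch.softmax(torch.randn(N, H, L, C), dim=-1)   # attention weights
V = torch.randn(N, H, L, E)                          # values

# Reference: out[..., i, :] = sum_j A[..., i, j] * V[..., i - C // 2 + j, :]
out_ref = torch.zeros(N, H, L, E)
for i in range(L):
    for j in range(C):
        k = i - C // 2 + j
        if 0 <= k < L:
            out_ref[..., i, :] += A[..., i, j:j + 1] * V[..., k, :]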

Ancestors

  • torch.autograd.function.Function
  • torch._C._FunctionBase
  • torch.autograd.function._ContextMethodMixin
  • torch.autograd.function._HookMixin

Class variables

var avg
var avg_backward

Static methods

def backward(ctx, grad_input)

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward returned, and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs a gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs a gradient computed w.r.t. the output.

Source code
@staticmethod
def backward(ctx, grad_input):
    A, V = ctx.saved_tensors
    return LocalWeightedAverage.avg_backward[A.device.type](
        A, V, grad_input
    )
def forward(ctx, A, V)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can be then retrieved during the backward pass.

Source code
@staticmethod
def forward(ctx, A, V):
    # Save the inputs for the gradient computation
    ctx.save_for_backward(A, V)

    return LocalWeightedAverage.avg[A.device.type](A, V)
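
Since both arguments are plain tensors and every window entry contributes smoothly to the output, the forward/backward pair can be checked numerically. A sketch, assuming the CPU extension is compiled for float64:

import torch
from fast_transformers.local_product import local_weighted_average

N, H, L, E, C = 1, 2, 6, 4, 3
A = torch.rand(N, H, L, C, dtype=torch.float64, requires_grad=True)
V = torch.randn(N, H, L, E, dtype=torch.float64, requires_grad=True)

# Compare the analytic gradients from backward against finite differences.
assert torch.autograd.gradcheck(local_weighted_average, (A, V))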