Module fast_transformers.causal_product

Source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

import torch

from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \
    causal_dot_backward as causal_dot_backward_cpu

try:
    from .causal_product_cuda import \
        causal_dot_product as causal_dot_product_cuda, \
        causal_dot_backward as causal_dot_backward_cuda
except ImportError:
    causal_dot_product_cuda = causal_dot_backward_cuda = None


class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""
    dot = {
        "cpu": causal_dot_product_cpu,
        "cuda": causal_dot_product_cuda
    }
    dot_backward = {
        "cpu": causal_dot_backward_cpu,
        "cuda": causal_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), device=device)

        # Actually perform the dot product
        CausalDotProduct.dot[device.type](
            Q.data,
            K.data,
            V.data,
            product
        )

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data,
            K.data,
            V.data,
            grad_out,
            grad_Q,
            grad_K,
            grad_V
        )

        return grad_Q, grad_K, grad_V


# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply

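As a quick usage sketch (the example sizes N, H, L, E, M below are arbitrary, and the extension for your device is assumed to be built), the fused kernel can be checked against an explicit quadratic computation that masks out future positions:

import torch
from fast_transformers.causal_product import causal_dot_product

# Arbitrary example sizes: batch, heads, sequence length, query/key dim, value dim
N, H, L, E, M = 2, 4, 50, 32, 64
Q = torch.rand(N, H, L, E)
K = torch.rand(N, H, L, E)
V = torch.rand(N, H, L, M)

out = causal_dot_product(Q, K, V)  # (N, H, L, M)

# Reference: out[n, h, i] = sum_{j <= i} (Q[n, h, i] . K[n, h, j]) * V[n, h, j],
# written here with a full L x L score matrix and a lower-triangular (causal) mask
scores = torch.einsum("nhle,nhse->nhls", Q, K)
ref = torch.einsum("nhls,nhsm->nhlm", torch.tril(scores), V)

print(torch.allclose(out, ref, atol=1e-4))
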
Sub-modules

fast_transformers.causal_product.causal_product_cpu
fast_transformers.causal_product.causal_product_cuda

Functions

def causal_dot_product(...)

Alias of CausalDotProduct.apply; computes the causal dot product of the queries, keys and values.

Classes

class CausalDotProduct (...)

Compute the weighted sum of values but attending only to previous values.

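For queries and keys of shape (N, H, L, E) and values of shape (N, H, L, M), position i of the output is the sum over j <= i of (Q_i . K_j) V_j. The naive pure-PyTorch reference below (a sketch for illustration only, not part of this module) makes the recurrence explicit by carrying a running sum of key-value outer products, which is essentially how the compiled kernels avoid materialising the L x L score matrix:

import torch

def causal_dot_product_naive(Q, K, V):
    # Hypothetical reference implementation: keep a running sum S of
    # key-value outer products and contract it with each query in turn.
    N, H, L, E = Q.shape
    M = V.shape[-1]
    S = torch.zeros(N, H, E, M, dtype=Q.dtype, device=Q.device)
    out = torch.zeros(N, H, L, M, dtype=Q.dtype, device=Q.device)
    for i in range(L):
        S = S + torch.einsum("nhe,nhm->nhem", K[:, :, i], V[:, :, i])
        out[:, :, i] = torch.einsum("nhe,nhem->nhm", Q[:, :, i], S)
    return out
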
Ancestors

  • torch.autograd.function.Function
  • torch._C._FunctionBase
  • torch.autograd.function._ContextMethodMixin
  • torch.autograd.function._HookMixin

Class variables

var dot
Dispatch table mapping a device type ("cpu" or "cuda") to the compiled forward kernel.

var dot_backward
Dispatch table mapping a device type ("cpu" or "cuda") to the compiled backward kernel.

Static methods

def backward(ctx, grad_out)

Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input requires a gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.

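In practice this backward is not called directly; it runs through autograd whenever gradients flow into the output of causal_dot_product. A minimal sketch (example sizes are arbitrary, and the device extension is assumed to be built):

import torch
from fast_transformers.causal_product import causal_dot_product

Q = torch.rand(1, 2, 10, 4, requires_grad=True)
K = torch.rand(1, 2, 10, 4, requires_grad=True)
V = torch.rand(1, 2, 10, 8, requires_grad=True)

out = causal_dot_product(Q, K, V)  # forward() dispatches to the device kernel
out.sum().backward()               # backward() fills grad_Q, grad_K, grad_V
print(Q.grad.shape, K.grad.shape, V.grad.shape)
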
def forward(ctx, Q, K, V)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can then be retrieved during the backward pass.

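The ctx protocol used here is the standard one for custom autograd functions: forward() stashes whatever backward() will need via ctx.save_for_backward(), and backward() reads it back from ctx.saved_tensors. A toy Function (the name Square is hypothetical, purely for illustration) showing the same pattern:

import torch

class Square(torch.autograd.Function):
    # Same protocol as CausalDotProduct: save the inputs in forward,
    # retrieve them in backward to build the gradients.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.rand(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x))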