Module fast_transformers.causal_product
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import torch
from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \
    causal_dot_backward as causal_dot_backward_cpu

try:
    from .causal_product_cuda import \
        causal_dot_product as causal_dot_product_cuda, \
        causal_dot_backward as causal_dot_backward_cuda
except ImportError:
    causal_dot_product_cuda = causal_dot_backward_cuda = None


class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""
    dot = {
        "cpu": causal_dot_product_cpu,
        "cuda": causal_dot_product_cuda
    }
    dot_backward = {
        "cpu": causal_dot_backward_cpu,
        "cuda": causal_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), device=device)

        # Actually perform the dot product
        CausalDotProduct.dot[device.type](
            Q.data,
            K.data,
            V.data,
            product
        )

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data,
            K.data,
            V.data,
            grad_out,
            grad_Q,
            grad_K,
            grad_V
        )

        return grad_Q, grad_K, grad_V


# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply
Sub-modules
fast_transformers.causal_product.causal_product_cpu
fast_transformers.causal_product.causal_product_cuda
Functions
def causal_dot_product(...)
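A minimal usage sketch, assuming the compiled CPU/CUDA extensions are available. The elu(x) + 1 feature map is only an illustration of how queries and keys are commonly made non-negative before the causal product; it is not required by this function:

import torch
from fast_transformers.causal_product import causal_dot_product

N, H, L, E, M = 2, 4, 100, 32, 64   # batch, heads, sequence length, key dim, value dim
Q = torch.nn.functional.elu(torch.randn(N, H, L, E)) + 1
K = torch.nn.functional.elu(torch.randn(N, H, L, E)) + 1
V = torch.randn(N, H, L, M)

# For every position i, compute the sum over j <= i of (Q_i . K_j) V_j
out = causal_dot_product(Q, K, V)   # shape (N, H, L, M)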
Classes
class CausalDotProduct (...)
-
Compute the weighted sum of values but attending only to previous values.
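In symbols, for queries Q, keys K and values V indexed by position, the output at position i is (a sketch of the intended semantics, not part of the generated docstring):

\mathrm{out}_i = \sum_{j \le i} \left( Q_i^\top K_j \right) V_j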
Expand source code
class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""
    dot = {
        "cpu": causal_dot_product_cpu,
        "cuda": causal_dot_product_cuda
    }
    dot_backward = {
        "cpu": causal_dot_backward_cpu,
        "cuda": causal_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), device=device)

        # Actually perform the dot product
        CausalDotProduct.dot[device.type](
            Q.data,
            K.data,
            V.data,
            product
        )

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data,
            K.data,
            V.data,
            grad_out,
            grad_Q,
            grad_K,
            grad_V
        )

        return grad_Q, grad_K, grad_V
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var dot
var dot_backward
Static methods
def backward(ctx, grad_out)
-
Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward returned, and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs a gradient. For example, backward will have ctx.needs_input_grad[0] = True if the first input to forward needs a gradient computed w.r.t. the output.
Expand source code
@staticmethod
def backward(ctx, grad_out):
    # Extract the saved tensors
    Q, K, V = ctx.saved_tensors

    # Allocate memory for the gradients
    grad_Q = torch.zeros_like(Q)
    grad_K = torch.zeros_like(K)
    grad_V = torch.zeros_like(V)

    # Actually compute the gradients
    CausalDotProduct.dot_backward[Q.device.type](
        Q.data,
        K.data,
        V.data,
        grad_out,
        grad_Q,
        grad_K,
        grad_V
    )

    return grad_Q, grad_K, grad_V
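As a sanity check, the hand-written backward above can be compared against autograd applied to a naive quadratic implementation of the same product. This is a sketch, assuming the compiled extension is importable; naive_causal_product is an illustrative helper, not part of the module:

import torch
from fast_transformers.causal_product import causal_dot_product

def naive_causal_product(Q, K, V):
    # Quadratic reference: compute all dot products, then mask out j > i
    A = torch.tril(torch.einsum("nhle,nhse->nhls", Q, K))
    return torch.einsum("nhls,nhsm->nhlm", A, V)

N, H, L, E, M = 1, 2, 16, 8, 8
Q = torch.rand(N, H, L, E, requires_grad=True)
K = torch.rand(N, H, L, E, requires_grad=True)
V = torch.rand(N, H, L, M, requires_grad=True)
grad_out = torch.rand(N, H, L, M)

# Gradients from autograd on the reference vs. the hand-written backward
grads_ref = torch.autograd.grad(naive_causal_product(Q, K, V), (Q, K, V), grad_out)
grads_ext = torch.autograd.grad(causal_dot_product(Q, K, V), (Q, K, V), grad_out)
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(grads_ref, grads_ext)))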
def forward(ctx, Q, K, V)
-
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
Expand source code
@staticmethod
def forward(ctx, Q, K, V):
    # Save the inputs for the gradient computation
    ctx.save_for_backward(Q, K, V)

    # Create the output tensor
    device = Q.device
    N, H, L, _ = Q.shape
    _, _, _, M = V.shape
    product = torch.zeros((N, H, L, M), device=device)

    # Actually perform the dot product
    CausalDotProduct.dot[device.type](
        Q.data,
        K.data,
        V.data,
        product
    )

    return product
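For reference, the product that forward delegates to the C++/CUDA kernels can also be written as a linear-time recurrence over the sequence, accumulating S_i = sum over j <= i of K_j V_j^T and reading it out with Q_i at each step. This is a sketch under the shape convention used above; causal_product_reference is an illustrative helper, not part of the module:

import torch

def causal_product_reference(Q, K, V):
    # Q, K: (N, H, L, E), V: (N, H, L, M)
    N, H, L, E = Q.shape
    M = V.shape[-1]
    S = Q.new_zeros(N, H, E, M)            # running sum of K_j V_j^T
    out = Q.new_zeros(N, H, L, M)
    for i in range(L):
        S = S + torch.einsum("nhe,nhm->nhem", K[:, :, i], V[:, :, i])
        out[:, :, i] = torch.einsum("nhe,nhem->nhm", Q[:, :, i], S)
    return out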