Module fast_transformers.local_product
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

import torch

from .local_product_cpu import local_dot_product as local_dot_product_cpu, \
    local_dot_backward as local_dot_backward_cpu, \
    local_weighted_average as local_weighted_average_cpu, \
    local_weighted_average_backward as local_weighted_average_backward_cpu

try:
    from .local_product_cuda import \
        local_dot_product as local_dot_product_cuda, \
        local_dot_backward as local_dot_backward_cuda, \
        local_weighted_average as local_weighted_average_cuda, \
        local_weighted_average_backward as local_weighted_average_backward_cuda
except ImportError:
    local_dot_product_cuda = None
    local_dot_backward_cuda = None
    local_weighted_average_cuda = None
    local_weighted_average_backward_cuda = None


class LocalDotProduct(torch.autograd.Function):
    """Compute the dot product of the queries and keys but only consider a
    local neighborhood of each query."""
    dot = {
        "cpu": local_dot_product_cpu,
        "cuda": local_dot_product_cuda
    }
    dot_backward = {
        "cpu": local_dot_backward_cpu,
        "cuda": local_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, queries, keys, attn_mask, key_lengths, local_context):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(queries, keys, key_lengths)
        ctx.local_context = local_context

        return LocalDotProduct.dot[queries.device.type](
            queries,
            keys,
            attn_mask,
            key_lengths,
            local_context
        )

    @staticmethod
    def backward(ctx, grad_input):
        queries, keys, key_lengths = ctx.saved_tensors
        local_context = ctx.local_context

        grads = LocalDotProduct.dot_backward[queries.device.type](
            queries,
            keys,
            key_lengths,
            grad_input,
            local_context
        )

        # plus 3 None for masks and local_context
        return grads + (None, None, None)


class LocalWeightedAverage(torch.autograd.Function):
    """Compute the weighted average of the values with the local attention."""
    avg = {
        "cpu": local_weighted_average_cpu,
        "cuda": local_weighted_average_cuda
    }
    avg_backward = {
        "cpu": local_weighted_average_backward_cpu,
        "cuda": local_weighted_average_backward_cuda
    }

    @staticmethod
    def forward(ctx, A, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(A, V)
        return LocalWeightedAverage.avg[A.device.type](A, V)

    @staticmethod
    def backward(ctx, grad_input):
        A, V = ctx.saved_tensors
        return LocalWeightedAverage.avg_backward[A.device.type](
            A, V, grad_input
        )


# Alias the autograd functions to python style snake case naming
local_dot_product = LocalDotProduct.apply
local_weighted_average = LocalWeightedAverage.apply
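The snake_case aliases at the bottom, local_dot_product and local_weighted_average, are the entry points used by the library's local attention layer. Below is a minimal usage sketch; the shapes and conventions it assumes ((N, H, L, E) query/key/value tensors, an additive (L, L) float mask, per-sequence key lengths and an integer window size local_context) are how the kernels are typically called elsewhere in fast_transformers rather than something documented on this page, and the compiled CPU/CUDA extensions must be available for the calls to work.

import torch
from fast_transformers.local_product import local_dot_product, local_weighted_average

N, H, L, E = 2, 4, 100, 32        # batch, heads, sequence length, head dimension (assumed)
local_context = 9                 # size of the local window around each query (assumed)

queries = torch.randn(N, H, L, E)
keys    = torch.randn(N, H, L, E)
values  = torch.randn(N, H, L, E)

attn_mask   = torch.zeros(L, L)                        # additive mask: 0 keeps, large negative removes
key_lengths = torch.full((N,), L, dtype=torch.int64)   # every sequence uses all L keys

# Scores restricted to the local window: (N, H, L, local_context)
QK = local_dot_product(queries, keys, attn_mask, key_lengths, local_context)
A  = torch.softmax(QK / E ** 0.5, dim=-1)

# Weighted average of the values inside each window: (N, H, L, E)
V_new = local_weighted_average(A, values)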
Sub-modules
fast_transformers.local_product.local_product_cpu
fast_transformers.local_product.local_product_cuda
Functions
def local_dot_product(...)
def local_weighted_average(...)
Classes
class LocalDotProduct (...)
Compute the dot product of the queries and keys but only consider a local neighborhood of each query.
Expand source code
class LocalDotProduct(torch.autograd.Function):
    """Compute the dot product of the queries and keys but only consider a
    local neighborhood of each query."""
    dot = {
        "cpu": local_dot_product_cpu,
        "cuda": local_dot_product_cuda
    }
    dot_backward = {
        "cpu": local_dot_backward_cpu,
        "cuda": local_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, queries, keys, attn_mask, key_lengths, local_context):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(queries, keys, key_lengths)
        ctx.local_context = local_context

        return LocalDotProduct.dot[queries.device.type](
            queries,
            keys,
            attn_mask,
            key_lengths,
            local_context
        )

    @staticmethod
    def backward(ctx, grad_input):
        queries, keys, key_lengths = ctx.saved_tensors
        local_context = ctx.local_context

        grads = LocalDotProduct.dot_backward[queries.device.type](
            queries,
            keys,
            key_lengths,
            grad_input,
            local_context
        )

        # plus 3 None for masks and local_context
        return grads + (None, None, None)
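The docstring only states the intent. As a rough illustration, here is a pure-PyTorch sketch of what the local dot product computes under one assumed convention: a window of local_context keys around each query, clamped to stay inside the sequence, with out-of-window slots filled with -inf so a later softmax gives them zero weight. It ignores attn_mask and key_lengths for brevity and is not the library's kernel.

import torch

def naive_local_dot_product(queries, keys, local_context):
    # Reference only: (N, H, L, E) queries/keys -> (N, H, L, local_context) scores.
    N, H, L, E = queries.shape
    half = local_context // 2
    out = queries.new_full((N, H, L, local_context), float("-inf"))
    for i in range(L):
        # Window of local_context keys around query i, kept inside [0, L)
        start = max(0, min(i - half, L - local_context))
        end = min(start + local_context, L)
        out[:, :, i, :end - start] = torch.einsum(
            "nhe,nhke->nhk", queries[:, :, i], keys[:, :, start:end]
        )
    return out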
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var dot
var dot_backward
Static methods
def backward(ctx, grad_input)
Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many outputs as forward returned, and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs a gradient. For example, backward will have ctx.needs_input_grad[0] = True if the first input to forward needs a gradient computed w.r.t. the output.
Expand source code
@staticmethod
def backward(ctx, grad_input):
    queries, keys, key_lengths = ctx.saved_tensors
    local_context = ctx.local_context

    grads = LocalDotProduct.dot_backward[queries.device.type](
        queries,
        keys,
        key_lengths,
        grad_input,
        local_context
    )

    # plus 3 None for masks and local_context
    return grads + (None, None, None)
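This contract is why the method ends with grads + (None, None, None): backward must return one value per argument of forward, and the non-differentiable arguments (the attention mask, the key lengths and the integer local_context) simply receive None. A small self-contained example of the same pattern, unrelated to this module's kernels, which also shows ctx.needs_input_grad:

import torch

class ScaledDot(torch.autograd.Function):
    # Toy example: forward(q, k, scale) computes (q * k).sum(-1) * scale.

    @staticmethod
    def forward(ctx, q, k, scale):
        ctx.save_for_backward(q, k)
        ctx.scale = scale                      # plain Python values go on ctx directly
        return (q * k).sum(dim=-1) * scale

    @staticmethod
    def backward(ctx, grad_output):
        q, k = ctx.saved_tensors
        grad_q = grad_k = None
        if ctx.needs_input_grad[0]:            # only compute gradients that are needed
            grad_q = grad_output.unsqueeze(-1) * k * ctx.scale
        if ctx.needs_input_grad[1]:
            grad_k = grad_output.unsqueeze(-1) * q * ctx.scale
        # One return value per forward argument; the non-tensor scale gets None
        return grad_q, grad_k, None

q = torch.randn(3, 5, requires_grad=True)
k = torch.randn(3, 5, requires_grad=True)
ScaledDot.apply(q, k, 0.5).sum().backward()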
def forward(ctx, queries, keys, attn_mask, key_lengths, local_context)
Performs the operation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can then be retrieved during the backward pass.
Expand source code
@staticmethod
def forward(ctx, queries, keys, attn_mask, key_lengths, local_context):
    # Save the inputs for the gradient computation
    ctx.save_for_backward(queries, keys, key_lengths)
    ctx.local_context = local_context

    return LocalDotProduct.dot[queries.device.type](
        queries,
        keys,
        attn_mask,
        key_lengths,
        local_context
    )
class LocalWeightedAverage (...)
Compute the weighted average of the values with the local attention.
Expand source code
class LocalWeightedAverage(torch.autograd.Function):
    """Compute the weighted average of the values with the local attention."""
    avg = {
        "cpu": local_weighted_average_cpu,
        "cuda": local_weighted_average_cuda
    }
    avg_backward = {
        "cpu": local_weighted_average_backward_cpu,
        "cuda": local_weighted_average_backward_cuda
    }

    @staticmethod
    def forward(ctx, A, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(A, V)
        return LocalWeightedAverage.avg[A.device.type](A, V)

    @staticmethod
    def backward(ctx, grad_input):
        A, V = ctx.saved_tensors
        return LocalWeightedAverage.avg_backward[A.device.type](
            A, V, grad_input
        )
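As with the dot product, a rough pure-PyTorch sketch of what the weighted average computes may help. It assumes the same windowing convention as the naive_local_dot_product sketch above (a window of A.shape[-1] values per query, clamped inside the sequence) and is an illustration, not the library's kernel.

import torch

def naive_local_weighted_average(A, values):
    # Reference only: A is (N, H, L, local_context), values is (N, H, L, E).
    N, H, L, E = values.shape
    local_context = A.shape[-1]
    half = local_context // 2
    out = values.new_zeros(N, H, L, E)
    for i in range(L):
        start = max(0, min(i - half, L - local_context))  # same assumed window as above
        end = min(start + local_context, L)
        # Weighted sum of the values inside the local window of query i
        out[:, :, i] = torch.einsum(
            "nhk,nhke->nhe", A[:, :, i, :end - start], values[:, :, start:end]
        )
    return out

Composed with naive_local_dot_product and a softmax, this reproduces (up to the masking details the sketches skip) the local attention step shown in the usage sketch near the top of the page.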
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var avg
var avg_backward
Static methods
def backward(ctx, grad_input)
Defines a formula for differentiating the operation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many outputs as forward returned, and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs a gradient. For example, backward will have ctx.needs_input_grad[0] = True if the first input to forward needs a gradient computed w.r.t. the output.
Expand source code
@staticmethod
def backward(ctx, grad_input):
    A, V = ctx.saved_tensors
    return LocalWeightedAverage.avg_backward[A.device.type](
        A, V, grad_input
    )
def forward(ctx, A, V)
Performs the operation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store tensors that can then be retrieved during the backward pass.
Expand source code
@staticmethod
def forward(ctx, A, V):
    # Save the inputs for the gradient computation
    ctx.save_for_backward(A, V)
    return LocalWeightedAverage.avg[A.device.type](A, V)