Module fast_transformers.sparse_product
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import torch
from .sparse_product_cpu import \
sparse_dot_product as sparse_dot_product_cpu, \
sparse_dot_backward as sparse_dot_backward_cpu, \
sparse_weighted_average as sparse_weighted_average_cpu, \
sparse_weighted_average_backward as sparse_weighted_average_backward_cpu
try:
from .sparse_product_cuda import \
sparse_dot_product as sparse_dot_product_cuda, \
sparse_dot_backward as sparse_dot_backward_cuda, \
sparse_weighted_average as sparse_weighted_average_cuda, \
sparse_weighted_average_backward as \
sparse_weighted_average_backward_cuda
except ImportError:
sparse_dot_product_cuda = None
sparse_dot_backward_cuda = None
sparse_weighted_average_cuda = None
sparse_weighted_average_backward_cuda = None
from .clustered_sparse_product_cpu import \
clustered_sparse_dot_product as clustered_sparse_dot_product_cpu, \
clustered_sparse_dot_backward as clustered_sparse_dot_backward_cpu, \
clustered_sparse_weighted_average as \
clustered_sparse_weighted_average_cpu, \
clustered_sparse_weighted_average_backward as \
clustered_sparse_weighted_average_backward_cpu
try:
from .clustered_sparse_product_cuda import \
clustered_sparse_dot_product as clustered_sparse_dot_product_cuda, \
clustered_sparse_dot_backward as clustered_sparse_dot_backward_cuda, \
clustered_sparse_weighted_average as \
clustered_sparse_weighted_average_cuda, \
clustered_sparse_weighted_average_backward as \
clustered_sparse_weighted_average_backward_cuda
except ImportError:
clustered_sparse_dot_product_cuda = None
clustered_sparse_dot_backward_cuda = None
clustered_sparse_weighted_average_cuda = None
clustered_sparse_weighted_average_backward_cuda = None
class SparseDotProduct(torch.autograd.Function):
"""Compute the dot products only at the positions specified by topk."""
dot = {
"cpu": sparse_dot_product_cpu,
"cuda": sparse_dot_product_cuda
}
dot_backward = {
"cpu": sparse_dot_backward_cpu,
"cuda": sparse_dot_backward_cuda
}
@staticmethod
def forward(ctx, Q, K, topk):
# Save the inputs to compute the gradient
ctx.save_for_backward(Q, K, topk)
# Create the output tensor
device = Q.device
N, H, L, E = Q.shape
_, _, _, k = topk.shape
product = torch.empty((N, H, L, k), device=device)
# Actually perform the dot product
SparseDotProduct.dot[device.type](Q, K, topk, product)
return product
@staticmethod
def backward(ctx, grad_output):
# Extract the saved tensors and allocate memory for the gradients
Q, K, topk = ctx.saved_tensors
grad_Q = torch.zeros_like(Q)
grad_K = torch.zeros_like(K)
SparseDotProduct.dot_backward[Q.device.type](
Q,
K,
topk,
grad_output,
grad_Q,
grad_K
)
return grad_Q, grad_K, None
class SparseWeightedAverage(torch.autograd.Function):
"""Compute the weighted average only for the topk values."""
avg = {
"cpu": sparse_weighted_average_cpu,
"cuda": sparse_weighted_average_cuda
}
avg_backward = {
"cpu": sparse_weighted_average_backward_cpu,
"cuda": sparse_weighted_average_backward_cuda
}
@staticmethod
def forward(ctx, weights, values, topk):
# Save the tensors to compute the gradient
ctx.save_for_backward(weights, values, topk)
# Allocate the output tensor
N, H, L, _ = weights.shape
_, _, _, E = values.shape
output = values.new_zeros(N, H, L, E)
# Compute the average
SparseWeightedAverage.avg[weights.device.type](
weights,
values,
topk,
output
)
return output
@staticmethod
def backward(ctx, grad_output):
# Extract the saved tensors and allocate memory for the gradients
weights, values, topk = ctx.saved_tensors
grad_weights = torch.zeros_like(weights)
grad_values = torch.zeros_like(values)
if grad_output.stride()[-1] != 1:
grad_output = grad_output.contiguous()
SparseWeightedAverage.avg_backward[weights.device.type](
weights,
values,
topk,
grad_output,
grad_weights,
grad_values
)
return grad_weights, grad_values, None
class ClusteredSparseDotProduct(torch.autograd.Function):
"""Compute the dot products only at the positions specified by topk."""
dot = {
"cpu": clustered_sparse_dot_product_cpu,
"cuda": clustered_sparse_dot_product_cuda
}
dot_backward = {
"cpu": clustered_sparse_dot_backward_cpu,
"cuda": clustered_sparse_dot_backward_cuda
}
@staticmethod
def forward(ctx, Q, K, topk, groups, counts, lengths):
# Save the inputs to compute the gradient
ctx.save_for_backward(Q, K, topk, groups)
device = Q.device
N, H, L, E = Q.shape
_, _, C, k = topk.shape
# Create the output tensor
product = torch.empty((N, H, L, k), device=device)
        # Unfortunately the cpu and gpu interfaces are different so
        # the entire call is surrounded by an if-else block
if device.type == "cpu":
ClusteredSparseDotProduct.dot[device.type](
Q,
K,
groups,
topk,
product
)
else:
queries_per_block = min(L, 1024//k)
threads = k * queries_per_block
blocks = ((L*k)//threads) + C + 1
query_map = torch.ones((N, H, blocks), dtype=torch.int32).cuda() * L
blocks_map = torch.ones((N, H, blocks), dtype=torch.int32).cuda() * -1
_, sorted_group_indices = torch.sort(groups, descending=True, dim=-1)
# Actually perform the dot product
ClusteredSparseDotProduct.dot[device.type](
Q,
K,
topk,
lengths,
blocks_map,
query_map,
counts,
sorted_group_indices,
product
)
return product
@staticmethod
def backward(ctx, grad_output):
# Extract the saved tensors and allocate memory for the gradients
Q, K, topk, groups = ctx.saved_tensors
grad_Q = torch.zeros_like(Q)
grad_K = torch.zeros_like(K)
ClusteredSparseDotProduct.dot_backward[Q.device.type](
Q,
K,
groups,
topk,
grad_output,
grad_Q,
grad_K
)
return grad_Q, grad_K, None, None, None, None
class ClusteredSparseWeightedAverage(torch.autograd.Function):
"""Compute the weighted average only for the topk values."""
avg = {
"cpu": clustered_sparse_weighted_average_cpu,
"cuda": clustered_sparse_weighted_average_cuda
}
avg_backward = {
"cpu": clustered_sparse_weighted_average_backward_cpu,
"cuda": clustered_sparse_weighted_average_backward_cuda
}
@staticmethod
def forward(ctx, weights, values, topk, groups):
# Save the tensors to compute the gradient
ctx.save_for_backward(weights, values, topk, groups)
# Allocate the output tensor
N, H, L, _ = weights.shape
_, _, _, E = values.shape
output = values.new_zeros(N, H, L, E)
# Compute the average
ClusteredSparseWeightedAverage.avg[weights.device.type](
weights,
values,
groups,
topk,
output
)
return output
@staticmethod
def backward(ctx, grad_output):
# Extract the saved tensors and allocate memory for the gradients
weights, values, topk, groups = ctx.saved_tensors
grad_weights = torch.zeros_like(weights)
grad_values = torch.zeros_like(values)
if grad_output.stride()[-1] != 1:
grad_output = grad_output.contiguous()
ClusteredSparseWeightedAverage.avg_backward[weights.device.type](
weights,
values,
groups,
topk,
grad_output,
grad_weights,
grad_values
)
return grad_weights, grad_values, None, None
# Alias the autograd functions to python style snake case naming
clustered_sparse_dot_product = ClusteredSparseDotProduct.apply
clustered_sparse_weighted_average = ClusteredSparseWeightedAverage.apply
# Alias the autograd functions to python style snake case naming
sparse_dot_product = SparseDotProduct.apply
sparse_weighted_average = SparseWeightedAverage.apply
Sub-modules
fast_transformers.sparse_product.clustered_sparse_product_cpu
fast_transformers.sparse_product.clustered_sparse_product_cuda
fast_transformers.sparse_product.sparse_product_cpu
fast_transformers.sparse_product.sparse_product_cuda
Functions
def clustered_sparse_dot_product(Q, K, topk, groups, counts, lengths)
def clustered_sparse_weighted_average(weights, values, topk, groups)
def sparse_dot_product(Q, K, topk)
def sparse_weighted_average(weights, values, topk)
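These functions are aliases for the apply methods of the autograd functions defined above. The following sketch shows how the non-clustered pair is typically used and checks the result against a dense reference; it assumes the C++/CUDA extensions are built and that topk holds int64 indices into the key/value sequence dimension (the shapes N, H, L, S, E, k below are purely illustrative).

import torch
from fast_transformers.sparse_product import (
    sparse_dot_product, sparse_weighted_average
)

N, H, L, S, E, k = 2, 4, 64, 64, 32, 8
Q = torch.randn(N, H, L, E, requires_grad=True)
K = torch.randn(N, H, S, E, requires_grad=True)
V = torch.randn(N, H, S, E, requires_grad=True)

# For every query, the indices of the k keys it attends to (int64)
topk = torch.randint(0, S, (N, H, L, k))

# Dot products only at the selected positions: (N, H, L, k)
QK = sparse_dot_product(Q, K, topk)

# Dense reference: full product followed by a gather along the key dimension
QK_ref = torch.einsum("nhle,nhse->nhls", Q, K).gather(-1, topk)
assert torch.allclose(QK, QK_ref, atol=1e-5)

# Weighted average of the selected values: (N, H, L, E)
A = torch.softmax(QK, dim=-1)
out = sparse_weighted_average(A, V, topk)

# Gradients flow back to Q, K and V through the custom kernels
out.sum().backward()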
Classes
class ClusteredSparseDotProduct (...)
Compute the dot products only at the positions specified by topk.
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var dot
var dot_backward
Static methods
def backward(ctx, grad_output)
Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
def forward(ctx, Q, K, topk, groups, counts, lengths)
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
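Only the cluster-level topk is passed to forward: every query shares the index set of its cluster (groups maps each query to its cluster id, while counts and lengths are bookkeeping for the CUDA kernel). As a rough illustration of the semantics rather than of the kernel interface, the computation can be written in plain PyTorch as follows; the shapes and the reference function name are assumptions for the sketch.

import torch

def clustered_sparse_dot_product_reference(Q, K, topk, groups):
    # Q: (N, H, L, E), K: (N, H, S, E)
    # topk: (N, H, C, k) key indices per cluster, groups: (N, H, L) cluster ids
    N, H, L, E = Q.shape
    k = topk.shape[-1]
    # Broadcast each cluster's indices to its queries: (N, H, L, k)
    per_query = topk.long().gather(
        2, groups.long().unsqueeze(-1).expand(N, H, L, k)
    )
    # Gather the selected keys and take the dot products: (N, H, L, k)
    K_sel = K.gather(
        2, per_query.reshape(N, H, L * k, 1).expand(-1, -1, -1, E)
    ).reshape(N, H, L, k, E)
    return torch.einsum("nhle,nhlke->nhlk", Q, K_sel)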
class ClusteredSparseWeightedAverage (...)
Compute the weighted average only for the topk values.
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var avg
var avg_backward
Static methods
def backward(ctx, grad_output)
Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
def forward(ctx, weights, values, topk, groups)
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
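The weights here are per query while topk is per cluster, so the kernel again broadcasts each cluster's indices to its queries before averaging the selected values. A plain-PyTorch sketch of the semantics, under the same shape assumptions as the sketch for ClusteredSparseDotProduct:

import torch

def clustered_sparse_weighted_average_reference(weights, values, topk, groups):
    # weights: (N, H, L, k), values: (N, H, S, E)
    # topk: (N, H, C, k) key indices per cluster, groups: (N, H, L) cluster ids
    N, H, L, k = weights.shape
    E = values.shape[-1]
    per_query = topk.long().gather(
        2, groups.long().unsqueeze(-1).expand(N, H, L, k)
    )
    # Gather the selected values and take the weighted sum: (N, H, L, E)
    V_sel = values.gather(
        2, per_query.reshape(N, H, L * k, 1).expand(-1, -1, -1, E)
    ).reshape(N, H, L, k, E)
    return torch.einsum("nhlk,nhlke->nhle", weights, V_sel)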
class SparseDotProduct (...)
Compute the dot products only at the positions specified by topk.
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var dot
var dot_backward
Static methods
def backward(ctx, grad_output)
Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
def forward(ctx, Q, K, topk)
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
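The backward pass routes each incoming gradient entry to the query and the key that produced the corresponding dot product. The following plain-PyTorch sketch shows what dot_backward computes; it is illustrative only and assumes topk holds int64 key indices.

import torch

def sparse_dot_backward_reference(Q, K, topk, grad_product):
    # Q: (N, H, L, E), K: (N, H, S, E), topk and grad_product: (N, H, L, k)
    N, H, L, E = Q.shape
    k = topk.shape[-1]
    idx = topk.long().reshape(N, H, L * k, 1).expand(-1, -1, -1, E)
    K_sel = K.gather(2, idx).reshape(N, H, L, k, E)
    # grad_Q[l] = sum_j grad[l, j] * K[topk[l, j]]
    grad_Q = torch.einsum("nhlk,nhlke->nhle", grad_product, K_sel)
    # grad_K[s] accumulates grad[l, j] * Q[l] over every (l, j) with topk[l, j] == s
    contrib = (grad_product.unsqueeze(-1) * Q.unsqueeze(-2)).reshape(N, H, L * k, E)
    grad_K = torch.zeros_like(K).scatter_add_(2, idx.contiguous(), contrib)
    return grad_Q, grad_K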
class SparseWeightedAverage (...)
Compute the weighted average only for the topk values.
Ancestors
- torch.autograd.function.Function
- torch._C._FunctionBase
- torch.autograd.function._ContextMethodMixin
- torch.autograd.function._HookMixin
Class variables
var avg
var avg_backward
Static methods
def backward(ctx, grad_output)
Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans indicating whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs its gradient computed w.r.t. the output.
def forward(ctx, weights, values, topk)
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store tensors that can then be retrieved during the backward pass.
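For reference, the quantity computed by this function is a gather of the selected values followed by a weighted sum; the custom kernels simply avoid materialising the gathered tensor. A plain-PyTorch sketch, assuming topk holds int64 key indices per query:

import torch

def sparse_weighted_average_reference(weights, values, topk):
    # weights and topk: (N, H, L, k), values: (N, H, S, E)
    N, H, L, k = weights.shape
    E = values.shape[-1]
    # output[l] = sum_j weights[l, j] * values[topk[l, j]]
    V_sel = values.gather(
        2, topk.long().reshape(N, H, L * k, 1).expand(-1, -1, -1, E)
    ).reshape(N, H, L, k, E)
    return torch.einsum("nhlk,nhlke->nhle", weights, V_sel)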