Module fast_transformers.sparse_product.sparse_product_cuda
Functions
def sparse_dot_backward(...)
sparse_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor) -> None
Compute the gradients for the sparse dot product.
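The signature above exposes only positional names (arg0 to arg5) and returns None, so results are presumably written into tensors supplied by the caller. As a rough illustration of the math only, the sketch below assumes the forward pass is product[l][j] = dot(Q[l], K[topk[l][j]]) and computes the matching gradients for one (batch, head) slice using plain Python lists. The names Q, K, topk, grad_out, grad_Q and grad_K, and the helper sparse_dot_backward_ref, are hypothetical and not taken from the extension.

```python
def sparse_dot_backward_ref(Q, K, topk, grad_out):
    """Reference gradients, assuming product[l][j] = dot(Q[l], K[topk[l][j]]).

    Q: L x E queries, K: S x E keys, topk: L x k key indices,
    grad_out: L x k upstream gradient. All names are assumptions.
    """
    E = len(Q[0])
    grad_Q = [[0.0] * E for _ in Q]
    grad_K = [[0.0] * E for _ in K]
    for l, indices in enumerate(topk):
        for j, s in enumerate(indices):
            g = grad_out[l][j]
            for e in range(E):
                grad_Q[l][e] += g * K[s][e]  # d product / d Q[l] is K[s]
                grad_K[s][e] += g * Q[l][e]  # a key may be selected by several queries
    return grad_Q, grad_K


Q = [[1.0, 2.0], [3.0, 4.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
topk = [[0, 2], [1, 2]]
grad_out = [[1.0, 1.0], [1.0, 2.0]]
grad_Q, grad_K = sparse_dot_backward_ref(Q, K, topk, grad_out)
print(grad_Q)  # [[2.0, 1.0], [2.0, 3.0]]
print(grad_K)  # [[1.0, 2.0], [3.0, 4.0], [7.0, 10.0]]
```

Key 2 is selected by both queries, so its gradient accumulates two contributions. A parallel kernel would have to handle that accumulation safely, for example with atomic adds.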
def sparse_dot_product(...)
sparse_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
Compute the dot product only in the positions specified by topk.
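To make "only in the positions specified by topk" concrete, here is a minimal pure-Python sketch for a single (batch, head) slice. It assumes that topk holds, for each query row, the indices of the keys to score. The argument names Q, K and topk and the helper sparse_dot_product_ref are hypothetical, since the binding only documents arg0 to arg3. The real function returns None, so it presumably writes into a preallocated output tensor rather than returning a value.

```python
def sparse_dot_product_ref(Q, K, topk):
    """Return an L x k list where entry [l][j] = dot(Q[l], K[topk[l][j]]).

    Q: L x E queries, K: S x E keys, topk: L x k key indices (assumed layout).
    """
    return [
        [sum(q * k for q, k in zip(Q[l], K[s])) for s in indices]
        for l, indices in enumerate(topk)
    ]


Q = [[1.0, 2.0], [3.0, 4.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
topk = [[0, 2], [1, 2]]
product = sparse_dot_product_ref(Q, K, topk)
print(product)  # [[1.0, 3.0], [4.0, 7.0]]
```

Only L x k products are computed instead of the full L x S matrix, which is where the savings of sparse attention come from.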
def sparse_weighted_average(...)
sparse_weighted_average(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
Compute the weighted average of the values using the sparse attention weights.
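A small sketch of what this averaging might look like for one (batch, head) slice, assuming each query keeps k attention weights together with the indices (topk) of the value rows they refer to. The names weights, values and topk and the helper sparse_weighted_average_ref are hypothetical, because the binding only lists arg0 to arg3.

```python
def sparse_weighted_average_ref(weights, values, topk):
    """Return out[l] = sum_j weights[l][j] * values[topk[l][j]].

    weights: L x k attention weights, values: S x D value rows,
    topk: L x k indices into values (assumed layout).
    """
    D = len(values[0])
    out = [[0.0] * D for _ in weights]
    for l, indices in enumerate(topk):
        for j, s in enumerate(indices):
            w = weights[l][j]
            for d in range(D):
                out[l][d] += w * values[s][d]
    return out


weights = [[0.25, 0.75], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
topk = [[0, 2], [1, 2]]
out = sparse_weighted_average_ref(weights, values, topk)
print(out)  # [[1.75, 1.5], [1.0, 1.5]]
```

Value rows not listed in topk for a query contribute nothing to that query's output, as if their attention weight were zero.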
def sparse_weighted_average_backward(...)
sparse_weighted_average_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor) -> None
Compute the gradients for the sparse weighted average.
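For reference, the gradients can be sketched in pure Python under the assumption that the forward pass is out[l] = sum_j weights[l][j] * values[topk[l][j]] for one (batch, head) slice. The names weights, values, topk, grad_out, grad_weights and grad_values, and the helper sparse_weighted_average_backward_ref, are hypothetical. The real binding returns None and presumably fills caller-provided gradient tensors.

```python
def sparse_weighted_average_backward_ref(weights, values, topk, grad_out):
    """Reference gradients for out[l] = sum_j weights[l][j] * values[topk[l][j]].

    grad_out: L x D upstream gradient. All names are assumptions.
    """
    D = len(values[0])
    grad_weights = [[0.0] * len(row) for row in weights]
    grad_values = [[0.0] * D for _ in values]
    for l, indices in enumerate(topk):
        for j, s in enumerate(indices):
            for d in range(D):
                # d out[l] / d weights[l][j] is the selected value row
                grad_weights[l][j] += grad_out[l][d] * values[s][d]
                # each selected value row receives the weighted upstream gradient
                grad_values[s][d] += weights[l][j] * grad_out[l][d]
    return grad_weights, grad_values


weights = [[0.25, 0.75], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
topk = [[0, 2], [1, 2]]
grad_out = [[1.0, 2.0], [4.0, 0.0]]
grad_weights, grad_values = sparse_weighted_average_backward_ref(
    weights, values, topk, grad_out
)
print(grad_weights)  # [[1.0, 6.0], [0.0, 8.0]]
print(grad_values)   # [[0.25, 0.5], [2.0, 0.0], [2.75, 1.5]]
```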