Module fast_transformers.sparse_product.clustered_sparse_product_cuda
Functions
def clustered_sparse_dot_backward(...)
clustered_sparse_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None
Compute the gradients for the sparse dot product.
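The seven `arg0`..`arg6` tensors are not named in this listing, so their roles and order cannot be shown here. As a rough guide to the math only, the following pure-Python sketch computes the gradients of a sparse dot product for a single attention head. The function name `sparse_dot_backward_reference`, its parameters, and the list-of-lists layout are assumptions made for illustration; this is not the CUDA kernel, and cluster grouping is not modeled.

```python
def sparse_dot_backward_reference(queries, keys, topk, grad_out):
    """Hypothetical reference for the gradients of a sparse dot product.

    Forward (assumed): out[l][j] = <queries[l], keys[topk[l][j]]>
    queries:  L lists of E floats
    keys:     S lists of E floats
    topk:     L lists of k key indices
    grad_out: L lists of k floats (gradient w.r.t. the forward output)
    Returns (grad_queries, grad_keys) with the shapes of queries and keys.
    """
    dim = len(queries[0])
    grad_q = [[0.0] * dim for _ in queries]
    grad_k = [[0.0] * dim for _ in keys]
    for l, indices in enumerate(topk):
        for j, s in enumerate(indices):
            g = grad_out[l][j]
            for e in range(dim):
                # d out[l][j] / d queries[l][e] = keys[s][e]
                grad_q[l][e] += g * keys[s][e]
                # d out[l][j] / d keys[s][e] = queries[l][e]
                grad_k[s][e] += g * queries[l][e]
    return grad_q, grad_k


queries = [[1.0, 0.0], [0.0, 2.0]]
keys = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
topk = [[0, 1], [2, 0]]
grad_out = [[1.0, 1.0], [1.0, 1.0]]
grad_q, grad_k = sparse_dot_backward_reference(queries, keys, topk, grad_out)
```

Keys that no query selects receive a zero gradient, and a key selected by several queries accumulates one contribution per selection.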
def clustered_sparse_dot_product(...)
clustered_sparse_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, arg7: at::Tensor, arg8: at::Tensor) -> None
Compute the dot product only in the positions specified by topk.
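Since the binding returns `None`, the result is presumably written into one of the tensor arguments, but which one is not documented in this listing. To illustrate the operation itself, here is a minimal pure-Python sketch for a single attention head. The name `sparse_dot_product_reference` and the list-of-lists layout are assumptions for illustration; it does not reproduce the CUDA kernel's argument order or its cluster handling.

```python
def sparse_dot_product_reference(queries, keys, topk):
    """Hypothetical reference: dot products only at the selected key positions.

    queries: L lists of E floats
    keys:    S lists of E floats
    topk:    L lists of k key indices (one index set per query)
    Returns  L lists of k floats, out[l][j] = <queries[l], keys[topk[l][j]]>.
    """
    out = []
    for q, indices in zip(queries, topk):
        row = []
        for s in indices:
            # Only the keys listed in topk are touched; all others are skipped.
            row.append(sum(qe * ke for qe, ke in zip(q, keys[s])))
        out.append(row)
    return out


queries = [[1.0, 0.0], [0.0, 2.0]]
keys = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
topk = [[0, 1], [2, 0]]
products = sparse_dot_product_reference(queries, keys, topk)
```

The output has one entry per (query, selected key) pair, so its size grows with k rather than with the full number of keys.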
def clustered_sparse_weighted_average(...)
clustered_sparse_weighted_average(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor) -> None
Average the values using the sparse attention.
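As an illustration of what a sparse weighted average computes, the sketch below combines, for each query, only the value rows selected by its top-k indices. It is plain Python for a single attention head; the name `sparse_weighted_average_reference`, its parameters, and the data layout are assumptions, not the extension's actual interface.

```python
def sparse_weighted_average_reference(weights, values, topk):
    """Hypothetical reference: weighted sum over the selected value rows.

    weights: L lists of k floats (sparse attention weights)
    values:  S lists of D floats
    topk:    L lists of k value indices matching the weights
    Returns  L lists of D floats,
             out[l][d] = sum_j weights[l][j] * values[topk[l][j]][d].
    """
    dim = len(values[0])
    out = []
    for w_row, indices in zip(weights, topk):
        acc = [0.0] * dim
        for w, s in zip(w_row, indices):
            for d in range(dim):
                acc[d] += w * values[s][d]
        out.append(acc)
    return out


weights = [[0.25, 0.75], [1.0, 0.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
topk = [[0, 2], [1, 0]]
averaged = sparse_weighted_average_reference(weights, values, topk)
```

The sketch does not normalize the weights; it assumes they already sum to one per query, as they would after a softmax over the selected positions.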
def clustered_sparse_weighted_average_backward(...)
clustered_sparse_weighted_average_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None
Compute the gradients for the sparse weighted average.
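The gradients of a sparse weighted average can be sketched in plain Python as follows, for a single attention head. The function name `sparse_weighted_average_backward_reference`, the parameter names, and the list-of-lists layout are assumptions for illustration; the roles of the real `arg0`..`arg6` tensors are not documented in this listing.

```python
def sparse_weighted_average_backward_reference(weights, values, topk, grad_out):
    """Hypothetical reference for the gradients of a sparse weighted average.

    Forward (assumed): out[l][d] = sum_j weights[l][j] * values[topk[l][j]][d]
    weights:  L lists of k floats
    values:   S lists of D floats
    topk:     L lists of k value indices
    grad_out: L lists of D floats (gradient w.r.t. the forward output)
    Returns (grad_weights, grad_values).
    """
    dim = len(values[0])
    grad_w = [[0.0] * len(row) for row in weights]
    grad_v = [[0.0] * dim for _ in values]
    for l, indices in enumerate(topk):
        for j, s in enumerate(indices):
            for d in range(dim):
                # d out[l][d] / d weights[l][j] = values[s][d]
                grad_w[l][j] += grad_out[l][d] * values[s][d]
                # d out[l][d] / d values[s][d] = weights[l][j]
                grad_v[s][d] += weights[l][j] * grad_out[l][d]
    return grad_w, grad_v


weights = [[0.25, 0.75], [1.0, 0.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
topk = [[0, 2], [1, 0]]
grad_out = [[1.0, 1.0], [1.0, 1.0]]
grad_w, grad_v = sparse_weighted_average_backward_reference(
    weights, values, topk, grad_out
)
```

The weight gradient keeps the sparse L x k shape of the weights, while the value gradient is dense and accumulates over every query that selected a given row.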