Module fast_transformers.sparse_product.clustered_sparse_product_cuda

Functions

def clustered_sparse_dot_backward(...)

clustered_sparse_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None

Compute the gradients for the sparse dot product.
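
The signature above exposes only positional names (arg0 to arg6), so the sketch below does not attempt to reproduce the calling convention. It is a slow pure-Python reference for the gradient being computed, under the assumption that the forward pass is out[l][j] = dot(Q[l], K[topk[groups[l]][j]]), with each query l assigned to a cluster groups[l] and each cluster holding a list of selected key indices. The function name, argument names, and nested-list layout are hypothetical. Unlike the extension function, which returns None, this version returns new lists.

```python
def reference_sparse_dot_backward(Q, K, groups, topk, grad_out):
    """Gradients of out[l][j] = dot(Q[l], K[topk[groups[l]][j]]).

    Assumed layout (hypothetical, not taken from the extension):
      Q: L x E queries, K: S x E keys,
      groups: length-L list mapping each query to its cluster,
      topk: C x k key indices per cluster,
      grad_out: L x k upstream gradient.
    """
    E = len(Q[0])
    grad_Q = [[0.0] * E for _ in Q]
    grad_K = [[0.0] * E for _ in K]
    for l, cluster in enumerate(groups):
        for j, s in enumerate(topk[cluster]):
            g = grad_out[l][j]
            for e in range(E):
                grad_Q[l][e] += g * K[s][e]  # d out / d Q[l] is K[s]
                grad_K[s][e] += g * Q[l][e]  # d out / d K[s] is Q[l]
    return grad_Q, grad_K
```

A key selected by several queries accumulates one contribution per query, which is why grad_K is built by accumulation rather than assignment.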

def clustered_sparse_dot_product(...)

clustered_sparse_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, arg7: at::Tensor, arg8: at::Tensor) -> None

Compute the query-key dot products only at the positions specified by topk.
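
As a rough illustration of what "only in the positions specified by topk" means, here is a slow pure-Python reference. It assumes that every query belongs to one cluster and that topk stores, per cluster, the indices of the keys to multiply against. The function name, argument names, and nested-list layout are hypothetical, since the signature above exposes only arg0 to arg8. Unlike the extension function, which returns None, this sketch returns the result.

```python
def reference_sparse_dot_product(Q, K, groups, topk):
    """Return out[l][j] = dot(Q[l], K[topk[groups[l]][j]]).

    Assumed layout (hypothetical, not taken from the extension):
      Q: L x E queries, K: S x E keys,
      groups: length-L list mapping each query to its cluster,
      topk: C x k key indices per cluster.
    """
    out = []
    for l, cluster in enumerate(groups):
        row = []
        for s in topk[cluster]:
            # Dot product of query l with one selected key only.
            row.append(sum(q * k for q, k in zip(Q[l], K[s])))
        out.append(row)
    return out
```

The output has one column per selected key (L x k) instead of one per key (L x S), which is where the savings over a dense product come from.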

def clustered_sparse_weighted_average(...)

clustered_sparse_weighted_average(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor) -> None

Compute the weighted average of the values using the sparse attention weights.
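
A slow pure-Python reference for this operation might look like the following. It assumes that the sparse attention holds one weight per selected key (L x k) and that the selected keys are identified by a per-cluster topk index list. The function name, argument names, and nested-list layout are hypothetical, since the signature above exposes only arg0 to arg4. Unlike the extension function, which returns None, this sketch returns the result.

```python
def reference_sparse_weighted_average(weights, V, groups, topk):
    """Return out[l] = sum_j weights[l][j] * V[topk[groups[l]][j]].

    Assumed layout (hypothetical, not taken from the extension):
      weights: L x k sparse attention weights,
      V: S x D values,
      groups: length-L list mapping each query to its cluster,
      topk: C x k value indices per cluster.
    """
    D = len(V[0])
    out = []
    for l, cluster in enumerate(groups):
        acc = [0.0] * D
        for j, s in enumerate(topk[cluster]):
            w = weights[l][j]
            for d in range(D):
                acc[d] += w * V[s][d]  # only selected values contribute
        out.append(acc)
    return out
```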

def clustered_sparse_weighted_average_backward(...)

clustered_sparse_weighted_average_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None

Compute the gradients for the sparse weighted average.
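
The gradients in question can be sketched with a slow pure-Python reference, assuming the forward pass is out[l] = sum_j weights[l][j] * V[topk[groups[l]][j]]. The function name, argument names, and nested-list layout are hypothetical, since the signature above exposes only arg0 to arg6. Unlike the extension function, which returns None, this version returns new lists.

```python
def reference_sparse_weighted_average_backward(weights, V, groups, topk, grad_out):
    """Gradients of out[l] = sum_j weights[l][j] * V[topk[groups[l]][j]].

    Assumed layout (hypothetical, not taken from the extension):
      weights: L x k, V: S x D,
      groups: length-L list mapping each query to its cluster,
      topk: C x k value indices per cluster,
      grad_out: L x D upstream gradient.
    """
    D = len(V[0])
    grad_weights = [[0.0] * len(row) for row in weights]
    grad_V = [[0.0] * D for _ in V]
    for l, cluster in enumerate(groups):
        for j, s in enumerate(topk[cluster]):
            for d in range(D):
                # d out[l] / d weights[l][j] is V[s]
                grad_weights[l][j] += grad_out[l][d] * V[s][d]
                # d out[l] / d V[s] is weights[l][j]
                grad_V[s][d] += weights[l][j] * grad_out[l][d]
    return grad_weights, grad_V
```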