Module fast_transformers.sparse_product.sparse_product_cuda

Functions

def sparse_dot_backward(...)

sparse_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor) -> None

Compute the gradients for the sparse dot product.
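To make the gradient computation concrete, here is a minimal reference sketch in plain Python lists. It assumes the forward pass computes `product[l][j] = dot(queries[l], keys[topk[l][j]])` for a single head. The function name, the argument names, the argument order, and the 2-D shapes are all assumptions for illustration; the signature above does not say what `arg0` to `arg5` mean, and the CUDA binding writes into tensors supplied by the caller rather than returning values.

```python
def sparse_dot_backward_reference(queries, keys, topk, grad_product):
    """Hypothetical reference: gradients of the sparse dot product.

    queries:      L x E nested list
    keys:         S x E nested list
    topk:         L x k nested list of key indices
    grad_product: L x k nested list (gradient of the loss w.r.t. the product)
    """
    dim = len(queries[0])
    grad_queries = [[0.0] * dim for _ in queries]
    grad_keys = [[0.0] * dim for _ in keys]
    for l, selected in enumerate(topk):
        for j, key_index in enumerate(selected):
            g = grad_product[l][j]
            for e in range(dim):
                # d product[l][j] / d queries[l][e] = keys[key_index][e]
                grad_queries[l][e] += g * keys[key_index][e]
                # A key may be selected by several queries, so accumulate.
                grad_keys[key_index][e] += g * queries[l][e]
    return grad_queries, grad_keys


queries = [[1.0, 0.0], [0.0, 2.0]]
keys = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
topk = [[1, 0], [2, 0]]
grad_product = [[1.0, 1.0], [1.0, 0.0]]
grad_queries, grad_keys = sparse_dot_backward_reference(
    queries, keys, topk, grad_product
)
print(grad_queries)  # [[3.0, 1.0], [0.0, 3.0]]
print(grad_keys)     # [[1.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
```

Keys that no query selects receive a zero gradient, which is why the key gradient is accumulated into a zero-initialised buffer.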

def sparse_dot_product(...)

sparse_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

Compute the dot products only at the positions specified by topk.
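As a rough illustration of what "only at the positions specified by topk" means, the following dependency-free sketch computes, for each query row, the dot product with just the keys whose indices appear in that row of `topk`. The names `queries`, `keys`, and `sparse_dot_product_reference` are hypothetical, as are the single-head 2-D shapes; the real binding takes four tensors and returns `None`, so it presumably fills a preallocated output tensor instead of returning one.

```python
def sparse_dot_product_reference(queries, keys, topk):
    """Hypothetical reference: dot each query with its selected keys only.

    queries: L x E nested list
    keys:    S x E nested list
    topk:    L x k nested list of key indices
    returns: L x k nested list of dot products
    """
    product = []
    for l, query in enumerate(queries):
        row = []
        for key_index in topk[l]:
            key = keys[key_index]
            row.append(sum(q * k for q, k in zip(query, key)))
        product.append(row)
    return product


queries = [[1.0, 0.0], [0.0, 2.0]]
keys = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
topk = [[1, 0], [2, 0]]  # query 0 attends to keys 1 and 0, query 1 to keys 2 and 0
print(sparse_dot_product_reference(queries, keys, topk))
# [[2.0, 1.0], [6.0, 2.0]]
```

The output has one column per selected key rather than one per key, so the cost scales with `k` instead of the full sequence length.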

def sparse_weighted_average(...)

sparse_weighted_average(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

Compute the weighted average of the values using the sparse attention weights.
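The operation can be sketched in plain Python as follows, assuming the sparse attention is given as one weight per selected position together with the indices of the selected values. The names `weights`, `values`, `topk`, and `sparse_weighted_average_reference` are hypothetical, and the single-head 2-D shapes are a simplification; the mapping of `arg0` to `arg3` onto these roles is not stated in the signature above.

```python
def sparse_weighted_average_reference(weights, values, topk):
    """Hypothetical reference: weighted sum over the selected values only.

    weights: L x k nested list of attention weights
    values:  S x E nested list
    topk:    L x k nested list of value indices
    returns: L x E nested list
    """
    dim = len(values[0])
    output = []
    for l, weight_row in enumerate(weights):
        accumulator = [0.0] * dim
        for weight, value_index in zip(weight_row, topk[l]):
            for e in range(dim):
                accumulator[e] += weight * values[value_index][e]
        output.append(accumulator)
    return output


weights = [[0.5, 0.5], [1.0, 0.0]]
values = [[2.0, 0.0], [0.0, 4.0], [6.0, 6.0]]
topk = [[0, 1], [2, 0]]
print(sparse_weighted_average_reference(weights, values, topk))
# [[1.0, 2.0], [6.0, 6.0]]
```

If each row of weights sums to one, as after a softmax over the selected positions, the result is a true weighted average; otherwise it is a weighted sum.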

def sparse_weighted_average_backward(...)

sparse_weighted_average_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor) -> None

Compute the gradients for the sparse weighted average.
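A minimal sketch of these gradients, assuming the forward pass computes `output[l][e] = sum_j weights[l][j] * values[topk[l][j]][e]` for a single head. All names, the argument order, and the 2-D shapes are assumptions made for illustration; the binding itself takes six tensors and returns `None`, so it presumably writes the gradients into caller-provided tensors.

```python
def sparse_weighted_average_backward_reference(weights, values, topk, grad_output):
    """Hypothetical reference: gradients of the sparse weighted average.

    weights:     L x k nested list of attention weights
    values:      S x E nested list
    topk:        L x k nested list of value indices
    grad_output: L x E nested list (gradient of the loss w.r.t. the output)
    """
    dim = len(values[0])
    grad_weights = [[0.0] * len(row) for row in weights]
    grad_values = [[0.0] * dim for _ in values]
    for l, selected in enumerate(topk):
        for j, value_index in enumerate(selected):
            for e in range(dim):
                # Each weight scales one selected value vector.
                grad_weights[l][j] += grad_output[l][e] * values[value_index][e]
                # A value may be selected by several rows, so accumulate.
                grad_values[value_index][e] += weights[l][j] * grad_output[l][e]
    return grad_weights, grad_values


weights = [[0.5, 0.5], [1.0, 0.0]]
values = [[2.0, 0.0], [0.0, 4.0], [6.0, 6.0]]
topk = [[0, 1], [2, 0]]
grad_output = [[1.0, 1.0], [1.0, 0.0]]
grad_weights, grad_values = sparse_weighted_average_backward_reference(
    weights, values, topk, grad_output
)
print(grad_weights)  # [[2.0, 4.0], [6.0, 2.0]]
print(grad_values)   # [[0.5, 0.5], [0.5, 0.5], [1.0, 0.0]]
```

Note that a weight of zero still receives a gradient (the last entry of `grad_weights` above), because the gradient with respect to a weight depends on the selected value, not on the weight itself.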