Module fast_transformers.bucket_product.bucket_product_cpu

Functions

def bucket_qksum(...)

bucket_qksum(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: int, arg5: float) -> Tuple[at::Tensor, at::Tensor]

Compute Q K.sum(-2), i.e. the product of Q with the sum of K over its second-to-last dimension, where Q and K are bucket sparse matrices.
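The docstring does not say how the bucket structure, the integer arg4, the float arg5 or the two returned tensors are interpreted, so the sketch below does not call bucket_qksum. It is a minimal dense reference for the quantity Q K.sum(-2) alone, assuming a single query matrix Q of shape (L, D) and key matrix K of shape (S, D) with batch and head axes omitted. The name dense_qksum is hypothetical and bucket sparsity is deliberately ignored.

```python
from typing import List

Matrix = List[List[float]]


def dense_qksum(Q: Matrix, K: Matrix) -> List[float]:
    """Hypothetical dense reference for Q K.sum(-2); ignores bucket sparsity.

    Q has shape (L, D) and K has shape (S, D); batch and head axes are omitted.
    """
    d = len(K[0])
    # K.sum(-2): add up the S key rows into a single D-dimensional vector.
    k_sum = [sum(row[j] for row in K) for j in range(d)]
    # Dot every query row with the summed key vector: one scalar per query.
    return [sum(q[j] * k_sum[j] for j in range(d)) for q in Q]


Q = [[1.0, 2.0], [0.5, -1.0]]
K = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0]]
print(dense_qksum(Q, K))  # [11.0, -2.5]
```

Because K.sum(-2) is shared by every query, the dense computation costs O((L + S) D) instead of the O(L S D) needed to form Q K^T first and then sum it.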

def bucket_qksum_grad(...)

bucket_qksum_grad(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: int, arg7: float) -> Tuple[at::Tensor, at::Tensor]

Compute the gradient with respect to Q and K for the bucket_qksum operation.
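The extra tensor arguments of bucket_qksum_grad are not documented, so the sketch below only derives the gradients of the dense quantity z[i] = Q[i] . sum_s K[s], assuming a hypothetical upstream gradient G with one entry per query and the same (L, D) and (S, D) shapes as above. The helper names are hypothetical, and bucket sparsity is again ignored.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def dense_qksum(Q: Matrix, K: Matrix) -> List[float]:
    # Forward pass of the dense reference: z[i] = Q[i] . sum_s K[s].
    d = len(K[0])
    k_sum = [sum(row[j] for row in K) for j in range(d)]
    return [sum(q[j] * k_sum[j] for j in range(d)) for q in Q]


def dense_qksum_grad(Q: Matrix, K: Matrix, G: List[float]) -> Tuple[Matrix, Matrix]:
    """Hypothetical dense gradients of sum_i G[i] * z[i] with respect to Q and K."""
    d = len(K[0])
    k_sum = [sum(row[j] for row in K) for j in range(d)]
    # dz[i]/dQ[i] = k_sum, so row i of dQ is G[i] * k_sum.
    dQ = [[g * k_sum[j] for j in range(d)] for g in G]
    # Each key row enters every z[i] only through k_sum, so all rows of dK
    # are identical and equal to sum_i G[i] * Q[i].
    gq = [sum(G[i] * Q[i][j] for i in range(len(Q))) for j in range(d)]
    dK = [list(gq) for _ in K]
    return dQ, dK


Q = [[1.0, 2.0], [0.5, -1.0]]
K = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0]]
G = [1.0, 2.0]
dQ, dK = dense_qksum_grad(Q, K, G)
print(dQ)  # [[3.0, 4.0], [6.0, 8.0]]
print(dK)  # [[2.0, 0.0], [2.0, 0.0], [2.0, 0.0]]
```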

def bucket_qkv(...)

bucket_qkv(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: int, arg6: float) -> Tuple[at::Tensor, at::Tensor]

Compute Q K^T V where Q and K are bucket sparse matrices.
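As with bucket_qksum, the meaning of the integer and float arguments and of the returned pair is not documented, so the sketch below is only a dense reference for Q K^T V, assuming Q of shape (L, D), K of shape (S, D) and V of shape (S, M). The helpers matmul, transpose and dense_qkv are hypothetical, and bucket sparsity is ignored.

```python
from typing import List

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    # (n x k) @ (k x m) -> (n x m) with plain Python loops.
    return [[sum(a[t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for a in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def dense_qkv(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Hypothetical dense reference for Q K^T V; ignores bucket sparsity.

    Shapes: Q (L, D), K (S, D), V (S, M) -> output (L, M).
    """
    # Evaluate as Q (K^T V): the (D x M) product K^T V is shared by all
    # queries, so the (L x S) matrix Q K^T is never formed.
    return matmul(Q, matmul(transpose(K), V))


Q = [[1.0, 2.0], [0.5, -1.0]]
K = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(dense_qkv(Q, K, V))  # [[7.0, 10.0], [-2.5, -3.0]]
```

By associativity, (Q K^T) V and Q (K^T V) are equal, but the second grouping costs O((L + S) D M) rather than O(L S (D + M)).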

def bucket_qkv_grad(...)

bucket_qkv_grad(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, arg7: int, arg8: float) -> Tuple[at::Tensor, at::Tensor, at::Tensor]

Compute the gradient with respect to Q, K and V for the bucket_qkv operation.
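The sketch below gives the analytic gradients of the dense quantity sum(G * (Q K^T V)), assuming a hypothetical upstream gradient G of shape (L, M) and the shapes used above. It does not reproduce bucket_qkv_grad's undocumented arguments; the helper names are hypothetical and bucket sparsity is ignored. Each gradient is grouped so that no (L x S) intermediate is formed.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    # (n x k) @ (k x m) -> (n x m) with plain Python loops.
    return [[sum(a[t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for a in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def dense_qkv_grad(Q: Matrix, K: Matrix, V: Matrix,
                   G: Matrix) -> Tuple[Matrix, Matrix, Matrix]:
    """Hypothetical dense gradients of sum(G * (Q K^T V)) with respect to Q, K, V.

    Shapes: Q (L, D), K (S, D), V (S, M), upstream gradient G (L, M).
    """
    # dQ = G (K^T V)^T, shape (L, D); K^T V is only (D x M).
    dQ = matmul(G, transpose(matmul(transpose(K), V)))
    # dK = V (G^T Q), shape (S, D); G^T Q is only (M x D).
    dK = matmul(V, matmul(transpose(G), Q))
    # dV = K (Q^T G), shape (S, M); Q^T G is only (D x M).
    dV = matmul(K, matmul(transpose(Q), G))
    return dQ, dK, dV


Q = [[1.0, 2.0], [0.5, -1.0]]
K = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
G = [[1.0, 0.0], [0.0, 1.0]]
dQ, dK, dV = dense_qkv_grad(Q, K, V, G)
print(dQ, dK, dV)
```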