Module fast_transformers.bucket_product.bucket_product_cpu
Functions
def bucket_qksum(...)
bucket_qksum(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: int, arg5: float) -> Tuple[at::Tensor, at::Tensor]
Compute Q K.sum(-2), where Q and K are bucket-sparse matrices.
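To make the bucketed product concrete, here is a minimal pure-Python reference sketch. This page does not define "bucket sparse", and the real arguments (arg0 to arg5, including the trailing int and float) are undocumented here. The sketch therefore assumes that every query i and key j carries an integer bucket id and that only pairs sharing a bucket contribute. `bucket_qksum_reference` is a hypothetical helper for illustration. It does not call the extension and does not reproduce its signature or its second return value.

```python
from collections import defaultdict


def bucket_qksum_reference(Q, K, q_buckets, k_buckets):
    """Hypothetical reference: out[i] = q_i . (sum of k_j with k_buckets[j] == q_buckets[i]).

    Q is L x E and K is S x E (nested lists, K non-empty); q_buckets / k_buckets
    hold one integer bucket id per query / key (an assumed encoding).
    Returns a list of L floats.
    """
    dim = len(K[0])
    # Sum the keys of each bucket once: O(S * E) instead of O(L * S * E).
    key_sums = defaultdict(lambda: [0.0] * dim)
    for k, b in zip(K, k_buckets):
        acc = key_sums[b]
        for e in range(dim):
            acc[e] += k[e]
    out = []
    for q, b in zip(Q, q_buckets):
        ks = key_sums.get(b)  # None if no key falls in this query's bucket
        out.append(sum(qe * ke for qe, ke in zip(q, ks)) if ks is not None else 0.0)
    return out
```

If every query and key is placed in a single bucket, the sketch reduces to the dense Q K.sum(-2): each query is dotted with the sum of all keys.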
def bucket_qksum_grad(...)
bucket_qksum_grad(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: int, arg7: float) -> Tuple[at::Tensor, at::Tensor]
Compute the gradients with respect to Q and K for the bucket_qksum operation.
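The gradients of a bucketed Q K.sum(-2) can be sketched in pure Python. The sketch assumes the forward is out[i] = q_i · Σ k_j over keys j that share query i's integer bucket id. Under that assumption, dout[i]/dq_i is the key sum of the query's bucket, and the gradient for k_j is the grad_out-weighted sum of the queries in its bucket. `bucket_qksum_grad_reference` is hypothetical and does not model the real argument list (arg0 to arg7).

```python
from collections import defaultdict


def bucket_qksum_grad_reference(Q, K, q_buckets, k_buckets, grad_out):
    """Hypothetical reference gradients of out[i] = q_i . sum_{j in bucket(i)} k_j.

    grad_out[i] is dLoss/dout[i]; Q and K are non-empty nested lists (L x E, S x E).
    Returns (grad_Q, grad_K) with the same shapes as Q and K.
    """
    dim = len(Q[0])
    zero = [0.0] * dim
    # dout[i]/dq_i is the key sum of query i's bucket.
    key_sums = defaultdict(lambda: [0.0] * dim)
    for k, b in zip(K, k_buckets):
        for e in range(dim):
            key_sums[b][e] += k[e]
    # dLoss/dk_j sums grad_out[i] * q_i over the queries in key j's bucket.
    weighted_q = defaultdict(lambda: [0.0] * dim)
    for q, g, b in zip(Q, grad_out, q_buckets):
        for e in range(dim):
            weighted_q[b][e] += g * q[e]
    grad_Q = [[g * s for s in key_sums.get(b, zero)] for g, b in zip(grad_out, q_buckets)]
    grad_K = [list(weighted_q.get(b, zero)) for b in k_buckets]
    return grad_Q, grad_K
```

Both gradients reuse one per-bucket reduction each, so the backward pass stays linear in the number of queries and keys, like the forward.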
def bucket_qkv(...)
bucket_qkv(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: int, arg6: float) -> Tuple[at::Tensor, at::Tensor]
Compute Q K^T V, where Q and K are bucket-sparse matrices.
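A bucketed Q K^T V can be evaluated without forming the L x S attention matrix by grouping the product as Q (K^T V) within each bucket. The following is a minimal pure-Python sketch of that idea. As above, the same-bucket rule and the integer bucket ids are assumptions, since this page does not define "bucket sparse". `bucket_qkv_reference` is a hypothetical helper, not the extension's API, and it ignores the undocumented int and float arguments.

```python
from collections import defaultdict


def bucket_qkv_reference(Q, K, V, q_buckets, k_buckets):
    """Hypothetical reference for Q K^T V with a same-bucket mask.

    out[i] = sum over keys j with k_buckets[j] == q_buckets[i] of (q_i . k_j) * v_j.
    Q: L x E, K: S x E, V: S x D (non-empty nested lists). Returns L x D.
    """
    e_dim, d_dim = len(K[0]), len(V[0])
    # Per-bucket E x D matrix KV_b = sum_j k_j v_j^T (associativity: Q (K^T V)).
    kv = defaultdict(lambda: [[0.0] * d_dim for _ in range(e_dim)])
    for k, v, b in zip(K, V, k_buckets):
        m = kv[b]
        for e in range(e_dim):
            for d in range(d_dim):
                m[e][d] += k[e] * v[d]
    out = []
    for q, b in zip(Q, q_buckets):
        m = kv.get(b)
        if m is None:  # no key shares this query's bucket
            out.append([0.0] * d_dim)
            continue
        # out_i = q_i^T KV_b
        out.append([sum(q[e] * m[e][d] for e in range(e_dim)) for d in range(d_dim)])
    return out
```

The cost is O((L + S) E D) rather than O(L S (E + D)) for the explicit masked product, which is the usual motivation for this ordering.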
def bucket_qkv_grad(...)
bucket_qkv_grad(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, arg7: int, arg8: float) -> Tuple[at::Tensor, at::Tensor, at::Tensor]
Compute the gradients with respect to Q, K and V for the bucket_qkv operation.
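The three gradients of out[i] = Σ_j (q_i · k_j) v_j, restricted to same-bucket pairs, can also be written in terms of per-bucket outer-product sums. With KV_b = Σ_j k_j v_j^T and QG_b = Σ_i q_i g_i^T, they are dq_i = KV_b g_i, dk_j = QG_b v_j and dv_j = QG_b^T k_j. The pure-Python sketch below assumes the same integer-bucket semantics as the forward sketch. `bucket_qkv_grad_reference` and `_bucket_outer_sums` are hypothetical names, and the real argument list (arg0 to arg8) is not modeled.

```python
from collections import defaultdict


def _bucket_outer_sums(A, B, buckets):
    """Per-bucket sum of outer products a b^T, each a len(A[0]) x len(B[0]) matrix."""
    rows, cols = len(A[0]), len(B[0])
    acc = defaultdict(lambda: [[0.0] * cols for _ in range(rows)])
    for a, b_vec, bucket in zip(A, B, buckets):
        m = acc[bucket]
        for r in range(rows):
            for c in range(cols):
                m[r][c] += a[r] * b_vec[c]
    return acc


def bucket_qkv_grad_reference(Q, K, V, q_buckets, k_buckets, G):
    """Hypothetical reference gradients of out[i] = sum_j (q_i . k_j) v_j over
    keys j with k_buckets[j] == q_buckets[i]. G = dLoss/dout (L x D).
    Q, K, V are non-empty nested lists (L x E, S x E, S x D)."""
    e_dim, d_dim = len(K[0]), len(V[0])
    zero = [[0.0] * d_dim for _ in range(e_dim)]
    kv = _bucket_outer_sums(K, V, k_buckets)  # KV_b = sum_j k_j v_j^T
    qg = _bucket_outer_sums(Q, G, q_buckets)  # QG_b = sum_i q_i g_i^T
    grad_Q, grad_K, grad_V = [], [], []
    for g, b in zip(G, q_buckets):
        m = kv.get(b, zero)
        # dq_i = KV_b g_i
        grad_Q.append([sum(m[e][d] * g[d] for d in range(d_dim)) for e in range(e_dim)])
    for k, v, b in zip(K, V, k_buckets):
        m = qg.get(b, zero)
        # dk_j = QG_b v_j and dv_j = QG_b^T k_j
        grad_K.append([sum(m[e][d] * v[d] for d in range(d_dim)) for e in range(e_dim)])
        grad_V.append([sum(m[e][d] * k[e] for e in range(e_dim)) for d in range(d_dim)])
    return grad_Q, grad_K, grad_V
```

Only two per-bucket reductions (KV_b and QG_b) are needed, so the backward pass keeps the same linear cost as the forward sketch.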