Module fast_transformers.causal_product.causal_product_cuda

Functions

def causal_dot_backward(...)

causal_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None

Compute the gradients for the causal dot product.
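The compiled kernel needs a CUDA device and returns None, so the following is only a pure-Python sketch of the gradients it computes, for a single (batch, head) slice. It assumes the forward pass computes the unnormalized causal product out[i] = Σ_{j≤i} (Q[i]·K[j]) V[j]. It also assumes, consistent with the `-> None` return type, that arg0 to arg6 are queries, keys, values, the upstream gradient of the output, and three preallocated tensors for the query, key and value gradients that the kernel fills in place. The name `causal_dot_backward_ref` and the list-of-lists layout are hypothetical. This illustrates the math, not the CUDA implementation.

```python
def causal_dot_backward_ref(Q, K, V, G):
    """Hypothetical reference gradients for one (batch, head) slice.

    Q, K: L x D; V, G: L x M (G is the gradient w.r.t. the output).
    Assumes out[i] = sum_{j <= i} (Q[i] . K[j]) * V[j], which gives
      grad_Q[i] = sum_{j <= i} (G[i] . V[j]) * K[j]
      grad_K[j] = sum_{i >= j} (G[i] . V[j]) * Q[i]
      grad_V[j] = sum_{i >= j} (Q[i] . K[j]) * G[i]
    """
    L = len(Q)
    D = len(Q[0]) if L else 0
    M = len(V[0]) if L else 0
    grad_Q = [[0.0] * D for _ in range(L)]
    grad_K = [[0.0] * D for _ in range(L)]
    grad_V = [[0.0] * M for _ in range(L)]

    # Forward sweep: S = sum_{j <= i} K[j] V[j]^T (D x M), grad_Q[i] = S G[i].
    S = [[0.0] * M for _ in range(D)]
    for i in range(L):
        for d in range(D):
            for m in range(M):
                S[d][m] += K[i][d] * V[i][m]
        for d in range(D):
            grad_Q[i][d] = sum(S[d][m] * G[i][m] for m in range(M))

    # Reverse sweep: R = sum_{i >= j} Q[i] G[i]^T (D x M).
    R = [[0.0] * M for _ in range(D)]
    for j in reversed(range(L)):
        for d in range(D):
            for m in range(M):
                R[d][m] += Q[j][d] * G[j][m]
        # grad_K[j] = R V[j]
        for d in range(D):
            grad_K[j][d] = sum(R[d][m] * V[j][m] for m in range(M))
        # grad_V[j] = R^T K[j]
        for m in range(M):
            grad_V[j][m] = sum(R[d][m] * K[j][d] for d in range(D))
    return grad_Q, grad_K, grad_V
```

The query gradient only needs a running prefix sum and the key and value gradients only need a running suffix sum, so the sketch stays linear in the sequence length L instead of materializing the L x L attention matrix.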

def causal_dot_product(...)

causal_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

Compute the weighted sum of values, attending only to the current and previous positions.
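Because the kernel requires CUDA and returns None, here is a pure-Python sketch of the quantity it is assumed to compute for a single (batch, head) slice: the unnormalized causal linear-attention product out[i] = Σ_{j≤i} (Q[i]·K[j]) V[j]. It also assumes, consistent with the `-> None` return type, that arg0 to arg3 are queries, keys, values and a preallocated output tensor that the kernel fills in place. Both function names and the list-of-lists layout are hypothetical. The naive version spells out the definition; the running-state version shows why the product can be computed in time linear in the sequence length.

```python
def causal_dot_product_naive(Q, K, V):
    """Hypothetical O(L^2) reference: out[i] = sum_{j <= i} (Q[i] . K[j]) * V[j]."""
    M = len(V[0]) if V else 0
    out = []
    for i in range(len(Q)):
        row = [0.0] * M
        for j in range(i + 1):  # causal: only positions j <= i contribute
            w = sum(q * k for q, k in zip(Q[i], K[j]))
            for m in range(M):
                row[m] += w * V[j][m]
        out.append(row)
    return out


def causal_dot_product_ref(Q, K, V):
    """Hypothetical O(L) reference using a running D x M state.

    S accumulates K[j] V[j]^T for j <= i, so out[i] = Q[i]^T S.
    """
    L = len(Q)
    D = len(Q[0]) if L else 0
    M = len(V[0]) if L else 0
    S = [[0.0] * M for _ in range(D)]
    out = []
    for i in range(L):
        # Add the outer product K[i] V[i]^T first, so position i attends to itself.
        for d in range(D):
            for m in range(M):
                S[d][m] += K[i][d] * V[i][m]
        out.append([sum(Q[i][d] * S[d][m] for d in range(D)) for m in range(M)])
    return out
```

For example, with Q = [[1, 0], [0, 1], [1, 1]], K = [[1, 2], [3, 4], [5, 6]] and V = [[1], [10], [100]], both functions return [[1.0], [42.0], [1173.0]]. The running state has a fixed size D x M regardless of L, which is what makes the causal product usable as a recurrence.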