Module fast_transformers.causal_product.causal_product_cuda
Functions
def causal_dot_backward(...)
causal_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None
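Compute the gradients of the causal dot product with respect to its queries, keys and values. The function returns None, so the gradients are written in place into tensors passed as arguments.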
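The following pure-Python sketch shows which gradients are involved. It assumes the forward pass computes out[i] = sum over j <= i of (q_i · k_j) v_j for a single (batch, head) slice, with each argument stored as a list of rows. The name causal_dot_backward_ref and the list-of-rows layout are hypothetical. The sketch evaluates the direct O(L²) double sum for clarity, which is not necessarily how the CUDA kernel works. Unlike the extension, it returns the gradients instead of writing them in place.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def _dot(a: List[float], b: List[float]) -> float:
    # Inner product of two equal-length vectors.
    return sum(x * y for x, y in zip(a, b))


def causal_dot_backward_ref(
    queries: Matrix, keys: Matrix, values: Matrix, grad_out: Matrix
) -> Tuple[Matrix, Matrix, Matrix]:
    """Hypothetical reference gradients for one (batch, head) slice.

    Forward (assumed): out[i] = sum_{j <= i} (q_i . k_j) * v_j
    With g_i = grad_out[i]:
      dq_i = sum_{j <= i} (g_i . v_j) * k_j
      dk_j = sum_{i >= j} (g_i . v_j) * q_i
      dv_j = sum_{i >= j} (q_i . k_j) * g_i
    """
    L = len(queries)
    E = len(queries[0]) if L else 0
    M = len(values[0]) if L else 0
    grad_q = [[0.0] * E for _ in range(L)]
    grad_k = [[0.0] * E for _ in range(L)]
    grad_v = [[0.0] * M for _ in range(L)]
    for i in range(L):
        for j in range(i + 1):  # causal mask: only pairs with j <= i contribute
            gv = _dot(grad_out[i], values[j])  # g_i . v_j
            qk = _dot(queries[i], keys[j])  # q_i . k_j
            for e in range(E):
                grad_q[i][e] += gv * keys[j][e]
                grad_k[j][e] += gv * queries[i][e]
            for m in range(M):
                grad_v[j][m] += qk * grad_out[i][m]
    return grad_q, grad_k, grad_v
```

Batched inputs would be handled by applying the same computation independently to each batch element and attention head.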
def causal_dot_product(...)
causal_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
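Compute, for each position, the sum of the values weighted by query-key dot products, attending only to the current and previous positions. The function returns None, so the result is written in place into an output tensor passed as an argument.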
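A minimal pure-Python sketch of this computation for one (batch, head) slice follows. It assumes that rows of the queries and keys have E features and rows of the values have M features. The name causal_dot_product_ref is hypothetical. Unlike the extension, the sketch returns a new list instead of filling an output tensor. It uses a running-sum formulation: keeping S = sum over j <= i of k_j v_jᵀ lets each output row be computed as q_iᵀ S, so the whole sequence costs O(L·E·M) instead of evaluating all O(L²) query-key pairs.

```python
from typing import List

Matrix = List[List[float]]


def causal_dot_product_ref(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Hypothetical reference causal dot product for one (batch, head) slice.

    Row i of the result is sum_{j <= i} (queries[i] . keys[j]) * values[j].
    """
    L = len(queries)
    if not (len(keys) == len(values) == L):
        raise ValueError("queries, keys and values must have the same length L")
    E = len(keys[0]) if L else 0
    M = len(values[0]) if L else 0
    # Running state S = sum_{j <= i} outer(keys[j], values[j]), an E x M matrix.
    state = [[0.0] * M for _ in range(E)]
    product = []
    for i in range(L):
        # Fold position i into the state first, so position i attends to itself.
        for e in range(E):
            for m in range(M):
                state[e][m] += keys[i][e] * values[i][m]
        # Row i of the output is queries[i]^T S.
        product.append(
            [sum(queries[i][e] * state[e][m] for e in range(E)) for m in range(M)]
        )
    return product
```

Batched inputs would be handled by applying the same computation independently to each batch element and attention head.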