Module fast_transformers.causal_product.causal_product_cpu

Functions

def causal_dot_backward(...)

causal_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None

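Compute the gradients of the queries, keys and values given the gradient of the output of causal_dot_product. The function returns None. arg0, arg1 and arg2 are the queries, keys and values from the forward pass, and arg3 is the gradient of the forward output. The gradients with respect to the queries, keys and values are written in place into the preallocated tensors arg4, arg5 and arg6, which have the shapes of arg0, arg1 and arg2 respectively.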

def causal_dot_product(...)

causal_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

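Compute the causal (masked) weighted sum of values used by causal linear attention. Writing Q, K and V for arg0, arg1 and arg2 (shapes (N, H, L, E), (N, H, L, E) and (N, H, L, M)), each output position i receives the sum of the values at positions j <= i, that is, the current position and all previous ones, weighted by the unnormalized dot products Q[i] · K[j]; no softmax or other normalization is applied. The function returns None and writes the result in place into arg3, a preallocated tensor of shape (N, H, L, M).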
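As a hedged illustration of this contract, the sketch below computes the same quantity for a single batch element and head in pure Python, using nested lists instead of tensors; the function names are hypothetical. It shows two equivalent formulations: the direct quadratic sum over all pairs j <= i, and a prefix-state form that keeps a running E x M matrix S = sum over j <= i of K[j] V[j]^T, so that output i is Q[i]^T S. The prefix form is the usual way causal linear attention is made linear in L; whether the CPU kernel uses exactly this loop is an assumption here.

```python
# Two equivalent single-head computations of the causal dot product.
# Q, K: L x E; V: L x M (nested lists). Names are hypothetical, for illustration.


def causal_product_quadratic(Q, K, V):
    """Direct definition: out[i] = sum over j <= i of dot(Q[i], K[j]) * V[j]."""
    L, M = len(Q), len(V[0])
    out = [[0.0] * M for _ in range(L)]
    for i in range(L):
        for j in range(i + 1):
            w = sum(q * k for q, k in zip(Q[i], K[j]))
            for m in range(M):
                out[i][m] += w * V[j][m]
    return out


def causal_product_prefix(Q, K, V):
    """Same result via a running E x M state S = sum over j <= i of outer(K[j], V[j])."""
    L, E, M = len(Q), len(Q[0]), len(V[0])
    S = [[0.0] * M for _ in range(E)]
    out = []
    for i in range(L):
        # Fold the current key/value pair into the state before reading it,
        # so position i attends to itself as well as to earlier positions.
        for e in range(E):
            for m in range(M):
                S[e][m] += K[i][e] * V[i][m]
        # out[i] = Q[i]^T S
        out.append([sum(Q[i][e] * S[e][m] for e in range(E)) for m in range(M)])
    return out


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [0.5, 0.0], [0.0, 1.0]]
V = [[1.0], [2.0], [3.0]]
print(causal_product_quadratic(Q, K, V))  # [[1.0], [2.0], [7.0]]
print(causal_product_prefix(Q, K, V))     # [[1.0], [2.0], [7.0]]
```

The prefix form visits each key/value pair once and never materializes the L x L weight matrix, so its cost grows linearly in L rather than quadratically, at the price of an E x M state per head.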