Module fast_transformers.causal_product.causal_product_cpu
Functions
def causal_dot_backward(...)
causal_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor) -> None
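Compute the gradients of the queries, keys and values, given the gradient of the output of causal_dot_product. The function returns None, so the gradients are written in place into tensors passed as arguments.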
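As a rough guide to what this function computes, here is a minimal NumPy reference sketch. It assumes queries and keys have shape (N, H, L, E), values and the output gradient have shape (N, H, L, M), and that the forward pass is out_i = sum over j <= i of (q_i . k_j) v_j. The mapping of arg0..arg2 to queries/keys/values, arg3 to the output gradient, and arg4..arg6 to preallocated gradient buffers is an assumption, not something this page states. The helper names below are hypothetical and are not part of the module.

```python
import numpy as np


def causal_dot_product_ref(queries, keys, values):
    # Hypothetical reference forward pass (masked formulation):
    # scores[i, j] = q_i . k_j, kept only for j <= i.
    L = queries.shape[2]
    mask = np.tril(np.ones((L, L)))
    scores = np.einsum("nhie,nhje->nhij", queries, keys) * mask
    return np.einsum("nhij,nhjm->nhim", scores, values)


def causal_dot_backward_ref(queries, keys, values, grad_out):
    # Hypothetical reference backward pass; returns new arrays instead of
    # writing into preallocated buffers as the extension (assumed) does.
    L = queries.shape[2]
    mask = np.tril(np.ones((L, L)))
    # Gradient w.r.t. the masked scores: d_scores[i, j] = g_i . v_j for j <= i.
    d_scores = np.einsum("nhim,nhjm->nhij", grad_out, values) * mask
    # grad q_i = sum_{j <= i} d_scores[i, j] * k_j
    grad_q = np.einsum("nhij,nhje->nhie", d_scores, keys)
    # grad k_j = sum_{i >= j} d_scores[i, j] * q_i
    grad_k = np.einsum("nhij,nhie->nhje", d_scores, queries)
    # grad v_j = sum_{i >= j} (q_i . k_j) * g_i
    scores = np.einsum("nhie,nhje->nhij", queries, keys) * mask
    grad_v = np.einsum("nhij,nhim->nhjm", scores, grad_out)
    return grad_q, grad_k, grad_v
```

Because the output is linear in each of the queries, keys and values separately, the same lower-triangular mask governs every gradient: a key or value at position j only receives gradient from outputs at positions i >= j.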
def causal_dot_product(...)
causal_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
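Compute, for each position, a weighted sum of the values at that position and all earlier positions (causal attention), where each weight is the dot product of the query with the corresponding key. The function returns None, so the result is written in place into a tensor passed as an argument.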
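The computation described above can be sketched in NumPy as follows. This is a reference sketch, not the extension's implementation. It assumes queries and keys of shape (N, H, L, E), values of shape (N, H, L, M), and an output of shape (N, H, L, M) with out_i = sum over j <= i of (q_i . k_j) v_j. That arg0..arg2 are queries/keys/values and arg3 is a preallocated output buffer is an assumption. The function names are hypothetical.

```python
import numpy as np


def causal_dot_product_masked(queries, keys, values):
    # Quadratic form: compute all q_i . k_j, zero out j > i, then mix values.
    L = queries.shape[2]
    mask = np.tril(np.ones((L, L)))
    scores = np.einsum("nhie,nhje->nhij", queries, keys) * mask
    return np.einsum("nhij,nhjm->nhim", scores, values)


def causal_dot_product_prefix(queries, keys, values):
    # Linear-in-L form: keep a running state S_i = sum_{j <= i} k_j v_j^T
    # (an E x M matrix per position) and read it out with q_i.
    kv = np.einsum("nhle,nhlm->nhlem", keys, values)  # outer products
    state = np.cumsum(kv, axis=2)                     # prefix sum over L
    return np.einsum("nhle,nhlem->nhlm", queries, state)
```

Both forms give the same result; the prefix-sum form avoids materialising the L x L score matrix, which is the usual motivation for a dedicated causal product kernel.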