Module fast_transformers.local_product.local_product_cuda
Functions
def local_dot_backward(...)
local_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: int) -> Tuple[at::Tensor, at::Tensor]
Compute the gradients of local_dot_product, returned as a pair of tensors.
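The binding exposes only positional names (arg0 to arg4), so the following is a minimal pure-Python sketch of what such a backward pass could compute, not the extension's actual kernel. It assumes a single sequence, a window of `local_context` keys centred on each query position, and an upstream gradient with one entry per (query, window slot). The function name `local_dot_backward_reference` and all parameter names are hypothetical.

```python
def local_dot_backward_reference(queries, keys, grad, local_context):
    """Gradients of windowed dot products w.r.t. queries and keys.

    Assumed layout (not confirmed by the binding's signature):
    queries, keys: lists of L vectors; grad: L rows of local_context floats.
    """
    L, dim = len(queries), len(queries[0])
    half = local_context // 2
    grad_q = [[0.0] * dim for _ in range(L)]
    grad_k = [[0.0] * dim for _ in range(L)]
    for i in range(L):
        for c in range(local_context):
            j = i - half + c  # key index covered by window slot c
            if not 0 <= j < L:
                continue  # slots outside the sequence carry no gradient
            g = grad[i][c]
            for d in range(dim):
                grad_q[i][d] += g * keys[j][d]
                grad_k[j][d] += g * queries[i][d]
    return grad_q, grad_k


Q = [[1.0], [2.0]]
K = [[3.0], [4.0]]
G = [[1.0, 1.0, 1.0], [1.0, 0.5, 1.0]]
grad_q, grad_k = local_dot_backward_reference(Q, K, G, local_context=3)
```

Returning the pair `(grad_q, grad_k)` mirrors the `Tuple[at::Tensor, at::Tensor]` return type above; the ordering of the pair in the real extension is an assumption here.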
def local_dot_product(...)
local_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: int) -> at::Tensor
Compute the dot products of Q and K within a small context window around each query position.
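To make "a small context around each Q" concrete, here is a minimal pure-Python sketch of a windowed dot product for a single sequence. It is an illustration under stated assumptions, not the CUDA kernel: the window is assumed to be centred on the query position, slots that fall outside the sequence are filled with a large negative value (so a later softmax would ignore them), and the mask and length tensors that the four-tensor signature suggests are left out. The name `local_dot_product_reference` and the `fill` parameter are hypothetical.

```python
def local_dot_product_reference(queries, keys, local_context, fill=-1e24):
    """Score each query against the local_context keys around its position.

    queries, keys: lists of L vectors (lists of floats) with equal dimension.
    Returns L rows of local_context scores.
    """
    L = len(queries)
    half = local_context // 2
    out = []
    for i, q in enumerate(queries):
        row = []
        for c in range(local_context):
            j = i - half + c  # key index covered by window slot c
            if 0 <= j < L:
                row.append(sum(a * b for a, b in zip(q, keys[j])))
            else:
                row.append(fill)  # assumed padding for out-of-range slots
        out.append(row)
    return out


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
scores = local_dot_product_reference(Q, K, local_context=3)
# scores[1] == [2.0, 4.0, 6.0]: the middle query sees all three keys
```

The output has one row per query and `local_context` columns, rather than the full L by L score matrix, which is what keeps the memory cost linear in the sequence length.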
def local_weighted_average(...)
local_weighted_average(arg0: at::Tensor, arg1: at::Tensor) -> at::Tensor
Compute the weighted average of V over a small context window around each query position.
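As an illustration of the weighted average, the sketch below assumes that the first argument holds one row of `local_context` attention weights per query and the second holds the value vectors, with the window centred on the query position. These roles are inferred from the description, since the binding only names its arguments arg0 and arg1, and `local_weighted_average_reference` is a hypothetical name.

```python
def local_weighted_average_reference(attention, values):
    """Average the values in each query's window using its attention weights.

    attention: L rows of local_context weights; values: L vectors.
    Window slots outside the sequence are skipped.
    """
    L, dim = len(values), len(values[0])
    half = len(attention[0]) // 2
    out = []
    for i in range(L):
        acc = [0.0] * dim
        for c, w in enumerate(attention[i]):
            j = i - half + c  # value index covered by window slot c
            if 0 <= j < L:
                for d in range(dim):
                    acc[d] += w * values[j][d]
        out.append(acc)
    return out


A = [[0.0, 0.5, 0.5], [0.25, 0.5, 0.25], [0.5, 0.5, 0.0]]
V = [[2.0, 0.0], [4.0, 8.0], [6.0, 4.0]]
averaged = local_weighted_average_reference(A, V)
# averaged == [[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]]
```

In this example each row of `A` sums to one, as it would after a softmax over the local scores, so every output row is a convex combination of neighbouring value vectors.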
def local_weighted_average_backward(...)
local_weighted_average_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor) -> Tuple[at::Tensor, at::Tensor]
Compute the gradients of local_weighted_average, returned as a pair of tensors.
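The two returned tensors are presumably the gradients with respect to the attention weights and the values. The sketch below shows what that computation could look like for a single sequence in pure Python, assuming a window centred on each query position. The function name, the parameter names, and the order of the returned pair are all assumptions, since the binding names its arguments only arg0 to arg2.

```python
def local_weighted_average_backward_reference(attention, values, grad_out):
    """Gradients of a windowed weighted average w.r.t. weights and values.

    attention: L rows of local_context weights; values, grad_out: L vectors.
    """
    L, dim = len(values), len(values[0])
    local_context = len(attention[0])
    half = local_context // 2
    grad_attention = [[0.0] * local_context for _ in range(L)]
    grad_values = [[0.0] * dim for _ in range(L)]
    for i in range(L):
        for c in range(local_context):
            j = i - half + c  # value index covered by window slot c
            if not 0 <= j < L:
                continue  # out-of-range slots contribute nothing
            for d in range(dim):
                grad_attention[i][c] += grad_out[i][d] * values[j][d]
                grad_values[j][d] += attention[i][c] * grad_out[i][d]
    return grad_attention, grad_values


A = [[0.0, 0.5, 0.5], [1.0, 0.0, 0.0]]
V = [[2.0], [3.0]]
G = [[1.0], [2.0]]
grad_attention, grad_values = local_weighted_average_backward_reference(A, V, G)
```

Each weight's gradient is the dot product of the upstream gradient with the value it multiplied, and each value accumulates the upstream gradients of every query whose window covers it, scaled by the corresponding weight.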