Module fast_transformers.local_product.local_product_cpu

Functions

def local_dot_backward(...)

local_dot_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: int) -> Tuple[at::Tensor, at::Tensor]

Compute the gradient of local_dot_product
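The sketch below is a minimal NumPy reference for what this gradient could look like. It is not the library's code. It assumes the positional arguments are (queries, keys, key_lengths, grad_out, local_context), with queries of shape (N, H, L, E), keys of shape (N, H, S, E), and grad_out of shape (N, H, L, local_context). It also assumes that slot j of query l scores key l - local_context // 2 + j. The name `local_dot_backward_ref` is hypothetical. An additive attention mask only shifts each score by a constant, so it would not change these gradients.

```python
import numpy as np

def local_dot_backward_ref(queries, keys, key_lengths, grad_out, local_context):
    """Hypothetical NumPy reference for the gradient of a local dot product.

    Assumed layout (not confirmed by the signature above):
      queries (N, H, L, E), keys (N, H, S, E), key_lengths (N,),
      grad_out (N, H, L, local_context) = gradient w.r.t. the local scores.
    Slot j of query l is assumed to score key s = l - local_context // 2 + j.
    Returns (grad_queries, grad_keys).
    """
    N, H, L, E = queries.shape
    S = keys.shape[2]
    grad_q = np.zeros_like(queries)
    grad_k = np.zeros_like(keys)
    for n in range(N):
        # Keys at or past the sequence length receive no gradient
        valid = min(S, int(key_lengths[n]))
        for l in range(L):
            start = l - local_context // 2
            for j in range(local_context):
                s = start + j
                if 0 <= s < valid:
                    g = grad_out[n, :, l, j, None]  # per-head weight, shape (H, 1)
                    # score = q_l . k_s, so d/dq_l = g * k_s and d/dk_s = g * q_l
                    grad_q[n, :, l] += g * keys[n, :, s]
                    grad_k[n, :, s] += g * queries[n, :, l]
    return grad_q, grad_k

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 4, 5))
k = rng.standard_normal((1, 2, 4, 5))
g = np.zeros((1, 2, 4, 3))
g[..., 1] = 1.0  # upstream gradient only on the centre slot (s == l)
dq, dk = local_dot_backward_ref(q, k, np.array([4]), g, 3)
```

With all of the upstream gradient on the centre slot, each query's gradient is simply its own key and vice versa. This makes the example a quick sanity check on the indexing.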

def local_dot_product(...)

local_dot_product(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: int) -> at::Tensor

Compute the dot product of each query Q with the keys K in a small local context around it
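The following NumPy sketch makes the "small context around each Q" concrete. It is a reference under stated assumptions, not the compiled implementation. It assumes the positional arguments are (queries, keys, attn_mask, key_lengths, local_context), with queries of shape (N, H, L, E), keys of shape (N, H, S, E), an additive attn_mask of shape (L, S), and key_lengths of shape (N,). It also assumes the window for query l covers keys l - local_context // 2 through l - local_context // 2 + local_context - 1. Filling out-of-range slots with -inf is another assumption. The name `local_dot_product_ref` is hypothetical.

```python
import numpy as np

def local_dot_product_ref(queries, keys, attn_mask, key_lengths, local_context):
    """Hypothetical NumPy reference for a local (banded) dot product.

    Assumed layout: queries (N, H, L, E), keys (N, H, S, E),
    attn_mask (L, S) additive, key_lengths (N,).
    Returns scores of shape (N, H, L, local_context), where slot j of query l
    holds q_l . k_s + attn_mask[l, s] for s = l - local_context // 2 + j.
    """
    N, H, L, E = queries.shape
    S = keys.shape[2]
    # Slots outside the sequence (or past key_lengths) stay at -inf, so a
    # later softmax gives them zero weight (assumed fill value)
    out = np.full((N, H, L, local_context), -np.inf)
    for n in range(N):
        valid = min(S, int(key_lengths[n]))
        for l in range(L):
            start = l - local_context // 2
            for j in range(local_context):
                s = start + j
                if 0 <= s < valid:
                    # Per-head dot product between query l and key s
                    score = (queries[n, :, l] * keys[n, :, s]).sum(-1)
                    out[n, :, l, j] = score + attn_mask[l, s]
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 5, 4))
k = rng.standard_normal((1, 2, 5, 4))
scores = local_dot_product_ref(q, k, np.zeros((5, 5)), np.array([5]), 3)
full = q @ np.swapaxes(k, -1, -2)  # dense (N, H, L, S) scores for comparison
```

Each row of `scores` is a diagonal band of the dense matrix `full`. Storing only local_context entries per query is what keeps memory linear in L rather than quadratic.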

def local_weighted_average(...)

local_weighted_average(arg0: at::Tensor, arg1: at::Tensor) -> at::Tensor

Compute, for each query Q, the weighted average of the values V in a small local context around it
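The sketch below is a hedged NumPy reference for this weighted average, not the compiled kernel. It assumes the two arguments are (attention, values), with attention of shape (N, H, L, C) for a local context C and values of shape (N, H, S, E). It also assumes the same centred window as the dot product, where slot j of query l refers to key l - C // 2 + j. The name `local_weighted_average_ref` is hypothetical.

```python
import numpy as np

def local_weighted_average_ref(attention, values):
    """Hypothetical NumPy reference for a local weighted average.

    Assumed layout: attention (N, H, L, C), values (N, H, S, E).
    Returns (N, H, L, E) with
      out[..., l, :] = sum_j attention[..., l, j] * values[..., l - C // 2 + j, :]
    where key positions outside [0, S) are skipped.
    """
    N, H, L, C = attention.shape
    S, E = values.shape[2], values.shape[3]
    out = np.zeros((N, H, L, E))
    for l in range(L):
        start = l - C // 2
        for j in range(C):
            s = start + j
            if 0 <= s < S:
                # Broadcast the (N, H) weights over the feature axis E
                out[:, :, l] += attention[:, :, l, j, None] * values[:, :, s]
    return out

rng = np.random.default_rng(0)
v = rng.standard_normal((2, 2, 4, 5))
a = np.zeros((2, 2, 4, 3))
a[..., 1] = 1.0  # all weight on the centre slot, i.e. key s == l
avg = local_weighted_average_ref(a, v)
```

Putting all the weight on the centre slot reproduces the values unchanged. Shifting it one slot to the right reads each query's next neighbour instead.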

def local_weighted_average_backward(...)

local_weighted_average_backward(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor) -> Tuple[at::Tensor, at::Tensor]

Compute the gradient of the local weighted average
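For reference, here is a minimal NumPy sketch of the gradients this function could return. It assumes the three arguments are (grad_out, attention, values), with grad_out of shape (N, H, L, E), attention of shape (N, H, L, C), and values of shape (N, H, S, E). It assumes the result is (grad_attention, grad_values) and uses the same centred window as above. The argument order and the name `local_weighted_average_backward_ref` are assumptions, not taken from the library.

```python
import numpy as np

def local_weighted_average_backward_ref(grad_out, attention, values):
    """Hypothetical NumPy reference for the local weighted average gradient.

    Assumed layout: grad_out (N, H, L, E) = gradient w.r.t. the averaged
    output, attention (N, H, L, C), values (N, H, S, E).
    Slot j of query l is assumed to refer to key s = l - C // 2 + j.
    Returns (grad_attention, grad_values).
    """
    N, H, L, C = attention.shape
    S = values.shape[2]
    grad_a = np.zeros_like(attention)
    grad_v = np.zeros_like(values)
    for l in range(L):
        start = l - C // 2
        for j in range(C):
            s = start + j
            if 0 <= s < S:
                # out_l = sum_j a_lj * v_s, so d/da_lj = g_l . v_s
                grad_a[:, :, l, j] = (grad_out[:, :, l] * values[:, :, s]).sum(-1)
                # d/dv_s accumulates a_lj * g_l over every query whose window contains s
                grad_v[:, :, s] += attention[:, :, l, j, None] * grad_out[:, :, l]
    return grad_a, grad_v

rng = np.random.default_rng(0)
g = rng.standard_normal((1, 2, 4, 5))
v = rng.standard_normal((1, 2, 4, 5))
a = np.zeros((1, 2, 4, 3))
a[..., 1] = 1.0  # weight only on key s == l
da, dv = local_weighted_average_backward_ref(g, a, v)
```

Slots whose key position falls outside the sequence get a zero attention gradient. This is consistent with those slots contributing nothing to the forward average.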