Module fast_transformers.aggregate.aggregate_cuda

Functions

def aggregate(...)

aggregate(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

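Aggregate the vectors of X into the groups given by the indices in G, multiplying each vector by the factor in F for its group. Arguments are positional (X, G, F, Y); the function returns None and writes the result in place into Y.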

def broadcast(...)

broadcast(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

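Broadcast the vectors of Y back to X: each position of X receives the vector of Y selected by its index in G, multiplied by the factor in F for that group. Arguments are positional (Y, G, F, X); the function returns None and writes the result in place into X.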
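A pure-Python reference sketch of what `broadcast` is assumed to compute for a single batch/head slice. The shapes are assumptions and are not stated on this page: Y is C x E, G holds one integer group index per row of X, F holds one factor per group, and X is L x E. It also assumes each row of X is overwritten rather than accumulated into. `broadcast_reference` is a hypothetical name and is not part of this module.

```python
from typing import List

Matrix = List[List[float]]


def broadcast_reference(Y: Matrix, G: List[int], F: List[float], X: Matrix) -> None:
    """Hypothetical pure-Python sketch of `broadcast` for one batch/head slice.

    Y: C x E group vectors, G: one group index per row of X,
    F: one factor per group, X: L x E output, overwritten in place
    (overwriting rather than accumulating is an assumption).
    """
    C = len(Y)
    for l, k in enumerate(G):
        if not 0 <= k < C:
            raise IndexError(f"group index {k} out of range for {C} groups")
        f = F[k]
        # Replace the contents of row l with the scaled group vector.
        X[l][:] = [f * v for v in Y[k]]


Y = [[3.0, 5.0], [3.0, 4.0]]
G = [0, 1, 0]
F = [1.0, 2.0]
X = [[0.0, 0.0] for _ in G]
broadcast_reference(Y, G, F, X)
print(X)  # [[3.0, 5.0], [6.0, 8.0], [3.0, 5.0]]
```

Every row of X that shares a group index receives the same scaled vector, so this is the inverse direction of `aggregate`: it maps C group vectors back onto L positions.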