Module fast_transformers.aggregate.aggregate_cuda
Functions
def aggregate(...)
aggregate(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
Aggregate the vectors of X based on the indices in G multiplied by F.
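The docstring names X, G and F but not which positional argument is which, and the function returns None, so it presumably writes its result into an output tensor. The sketch below assumes arg0 = X (vectors), arg1 = G (integer group index per vector), arg2 = F (one scaling factor per group) and arg3 = Y (output, written in place). Under that assumption, aggregate sums the vectors of X that share a group index and scales each group's sum by F. The helper name `aggregate_reference` is hypothetical, and the sketch handles a single 2-D slice with plain lists; the CUDA kernel presumably repeats this for every batch and head of 4-D tensors.

```python
from typing import List

Matrix = List[List[float]]


def aggregate_reference(X: Matrix, G: List[int], F: List[float], C: int) -> Matrix:
    """Hypothetical pure-Python sketch of aggregate for one 2-D slice.

    Assumes every index in G lies in [0, C). Computes
    Y[c] = F[c] * sum(X[l] for all l with G[l] == c).
    """
    E = len(X[0]) if X else 0
    Y = [[0.0] * E for _ in range(C)]  # one zero-initialised row per group
    for l, c in enumerate(G):
        # Accumulate vector l into its group row, scaled by that group's factor.
        for e in range(E):
            Y[c][e] += F[c] * X[l][e]
    return Y


# Vectors 0 and 2 belong to group 0, vector 1 to group 1.
Y = aggregate_reference([[1.0, 2.0], [10.0, 20.0], [5.0, 6.0]], [0, 1, 0], [0.5, 2.0], C=2)
print(Y)
```

Setting F[c] to the reciprocal of the number of members in group c makes each row of Y the mean of its group, which is the usual way a centroid would be computed from such a kernel.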
def broadcast(...)
broadcast(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
Broadcast the vectors of Y based on the indices in G multiplied by F back to X.
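Broadcast reads as the inverse direction of aggregate: each position l receives a copy of its group's vector from Y, scaled by that group's factor. The sketch assumes arg0 = Y (one vector per group), arg1 = G (group index per output position), arg2 = F (one factor per group) and arg3 = X (output, written in place); this mapping is inferred from the docstring, not confirmed by it. The helper name `broadcast_reference` is hypothetical and, like the aggregate sketch, covers a single 2-D slice rather than the batched CUDA tensors.

```python
from typing import List

Matrix = List[List[float]]


def broadcast_reference(Y: Matrix, G: List[int], F: List[float]) -> Matrix:
    """Hypothetical pure-Python sketch of broadcast for one 2-D slice.

    Assumes every index in G is a valid row of Y. Computes
    X[l] = F[G[l]] * Y[G[l]] for each position l.
    """
    # Each output row is the (scaled) vector of the group it belongs to.
    return [[F[c] * y for y in Y[c]] for c in G]


# Positions 0 and 2 share group 0, so they receive the same vector.
X = broadcast_reference([[3.0, 4.0], [20.0, 40.0]], [0, 1, 0], [1.0, 0.5])
print(X)
```

Read together with the aggregate sketch, this describes a gather/scatter pair: aggregate compresses many vectors into one per group, and broadcast expands the per-group results back to the original positions.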