Module fast_transformers.aggregate.clustered_broadcast_cuda

Functions

def clustered_broadcast(...)

clustered_broadcast(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, arg7: at::Tensor, arg8: at::Tensor) -> None

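Broadcast the vectors of Y back to X: each position of X receives the vector of Y whose index is given by G, multiplied by the corresponding factor in F. The result is written into X in place, which is why the function returns None.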
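The operation can be illustrated with a small pure-PyTorch reference built on torch.gather. This is a sketch, not the extension's implementation. The function name reference_clustered_broadcast is hypothetical, and the tensor shapes (Y as (N, H, C, E), G as (N, H, L), F as (N, H, C), X as (N, H, L, E)) are assumptions, since this page does not document them. The sketch models only Y, G, F and X. Which of arg0 to arg8 correspond to these, and what the remaining arguments carry, is not stated here. It also assumes X is overwritten rather than accumulated into.

```python
import torch


def reference_clustered_broadcast(Y, G, F, X):
    """Hypothetical reference for the broadcast described above.

    Assumed shapes (not confirmed by the extension's docs):
      Y: (N, H, C, E)  one E-dimensional vector per cluster
      G: (N, H, L)     cluster index of each of the L positions
      F: (N, H, C)     one scaling factor per cluster
      X: (N, H, L, E)  output, overwritten in place (assumption)
    """
    E = Y.shape[-1]
    idx = G.long()  # torch.gather requires int64 indices
    # For every position l, pick its cluster's vector: Y[n, h, G[n, h, l], :]
    gathered = torch.gather(Y, 2, idx.unsqueeze(-1).expand(-1, -1, -1, E))
    # Pick the matching factor: F[n, h, G[n, h, l]]
    scale = torch.gather(F, 2, idx)
    # Scale each gathered vector and write into X; return None like the extension
    X.copy_(gathered * scale.unsqueeze(-1))


# Usage: two clusters (C=2) of 3-dim vectors broadcast to four positions (L=4)
Y = torch.arange(6.0).reshape(1, 1, 2, 3)       # cluster 0: [0,1,2], cluster 1: [3,4,5]
F = torch.tensor([[[2.0, 0.5]]])                # factor 2.0 for cluster 0, 0.5 for cluster 1
G = torch.tensor([[[1, 0, 0, 1]]], dtype=torch.int32)
X = torch.empty(1, 1, 4, 3)
reference_clustered_broadcast(Y, G, F, X)
print(X[0, 0])  # rows: [1.5,2,2.5], [0,2,4], [0,2,4], [1.5,2,2.5]
```

Gathering the factor per position and multiplying afterwards keeps the sketch vectorized; a CUDA kernel can instead exploit the fact that many positions share a cluster, which is presumably what the extra arguments support.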