Module fast_transformers.aggregate.clustered_broadcast_cuda
Functions
def clustered_broadcast(...)
clustered_broadcast(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, arg7: at::Tensor, arg8: at::Tensor) -> None
Broadcast the vectors of Y back to X: each entry of X receives the vector of Y selected by its cluster index in G, multiplied by the corresponding factor in F.
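The binding's nine positional tensors (arg0 to arg8) are not documented on this page, so the sketch below does not try to reproduce its calling convention. It is a minimal NumPy reference for the computation the docstring describes. It assumes that Y holds one vector per cluster, that G gives the cluster index of every position in X, and that F holds one scaling factor per cluster. The function name broadcast_reference and the tensor shapes are hypothetical and chosen for illustration only.

```python
import numpy as np


def broadcast_reference(Y, G, F, X):
    """Hypothetical NumPy reference for the clustered broadcast semantics.

    Assumed shapes (illustrative, not read from the CUDA binding):
      Y: (N, H, C, E) one E-dimensional vector per cluster
      G: (N, H, L)    integer cluster index for each of the L positions
      F: (N, H, C)    one scaling factor per cluster
      X: (N, H, L, E) output buffer, overwritten in place
    """
    N, H, _ = G.shape
    # Index grids that broadcast against G to shape (N, H, L).
    n = np.arange(N)[:, None, None]
    h = np.arange(H)[None, :, None]
    # Y[n, h, G] gathers one vector per position -> (N, H, L, E).
    # F[n, h, G] gathers the factor of each position's cluster -> (N, H, L).
    X[...] = Y[n, h, G] * F[n, h, G][..., None]
    return X


# Usage: one sequence, one head, two clusters, three positions.
Y = np.array([[[[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]]])
G = np.array([[[1, 0, 1]]])
F = np.array([[[0.5, 2.0]]])
X = np.zeros((1, 1, 3, 3))
broadcast_reference(Y, G, F, X)
# X[0, 0] is now [[20, 40, 60], [0.5, 1, 1.5], [20, 40, 60]]
```

Advanced indexing gathers every row in one step; an explicit loop over (n, h, l) computes the same result. Setting every factor in F to 1 reduces the operation to a plain gather of cluster vectors.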