Module fast_transformers.aggregate.aggregate_cpu
Functions
def aggregate(...)
-
aggregate(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
Aggregate the vectors of X based on the group indices in G, multiplied by the factors F.
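This page does not name the arguments or give their shapes. As a rough mental model, here is a pure-NumPy sketch of the aggregation step. It makes several assumptions that this page does not state. X holds input vectors of shape (N, H, L, E). G gives each vector's integer group index. F holds one factor per group, with shape (N, H, C). The result is accumulated into a preallocated Y of shape (N, H, C, E). The in-place output is consistent with the `-> None` return type. The name `aggregate_reference`, the argument order, and the shapes are hypothetical. They are not the extension's documented API.

```python
import numpy as np


def aggregate_reference(X, G, F, Y):
    """Hypothetical NumPy reference for the aggregate kernel (not the library code).

    Assumed shapes, not stated in the module docs:
      X: (N, H, L, E) input vectors
      G: (N, H, L)    integer group index of each vector
      F: (N, H, C)    multiplicative factor of each group
      Y: (N, H, C, E) output buffer, accumulated in place
    """
    N, H, L, _ = X.shape
    C = Y.shape[2]
    for n in range(N):
        for h in range(H):
            for i in range(L):
                k = G[n, h, i]
                if 0 <= k < C:  # ignore out-of-range group indices
                    Y[n, h, k] += F[n, h, k] * X[n, h, i]


# Three 2-d vectors in one batch/head, assigned to groups 0, 1, 0.
X = np.array([[[[1.0, 2.0], [3.0, 4.0], [5.0, 8.0]]]])
G = np.array([[[0, 1, 0]]])
F = np.array([[[0.5, 1.0]]])  # 1 / group size turns each sum into a mean
Y = np.zeros((1, 1, 2, 2))
aggregate_reference(X, G, F, Y)
# group 0 -> [3, 5], group 1 -> [3, 4]
```

Choosing F as the reciprocal of each group's size turns the weighted sum into a per-group mean. This is one natural use of the factors, not necessarily the only one.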
def broadcast(...)
-
broadcast(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None
Broadcast the aggregated vectors Y back to X based on the group indices in G, multiplied by the factors F.
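Broadcast is the reverse direction: each vector position in X receives its group's aggregated vector, scaled by that group's factor. The NumPy sketch below follows the same assumed shapes as the aggregation sketch. It also assumes that X is a preallocated output buffer overwritten in place. The name `broadcast_reference`, the argument order, and the handling of out-of-range indices (left untouched here) are hypothetical choices, not confirmed behavior of the compiled function.

```python
import numpy as np


def broadcast_reference(Y, G, F, X):
    """Hypothetical NumPy reference for the broadcast kernel (not the library code).

    Assumed shapes, not stated in the module docs:
      Y: (N, H, C, E) aggregated group vectors
      G: (N, H, L)    integer group index of each output vector
      F: (N, H, C)    multiplicative factor of each group
      X: (N, H, L, E) output buffer, overwritten in place
    """
    N, H, L, _ = X.shape
    C = Y.shape[2]
    for n in range(N):
        for h in range(H):
            for i in range(L):
                k = G[n, h, i]
                if 0 <= k < C:  # vectors with out-of-range indices are left untouched
                    X[n, h, i] = F[n, h, k] * Y[n, h, k]


# Two group vectors broadcast back to three positions assigned to groups 0, 1, 0.
Y = np.array([[[[3.0, 5.0], [3.0, 4.0]]]])
G = np.array([[[0, 1, 0]]])
F = np.ones((1, 1, 2))  # a factor of 1 copies each group vector unchanged
X = np.zeros((1, 1, 3, 2))
broadcast_reference(Y, G, F, X)
# X[0, 0] is now [[3, 5], [3, 4], [3, 5]]
```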