Module fast_transformers.aggregate.aggregate_cpu

Functions

def aggregate(...)

aggregate(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

Aggregate the vectors of X according to the group indices in G, multiplying by the factors F.
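
The one-line description leaves the exact computation implicit, so here is a minimal pure-Python sketch of what it appears to describe. Everything in it is an assumption made for illustration: the name `aggregate_reference` is hypothetical, the four positional arguments are assumed to map to `X`, `G`, `F` and an output buffer `Y` written in place (which would match the `None` return type), and each factor is assumed to be indexed by group. The compiled function works on tensors, possibly with additional leading dimensions that this sketch ignores.

```python
def aggregate_reference(X, G, F, Y):
    """Hypothetical reference: add F[g] * X[l] into Y[g], where g = G[l].

    X: list of L vectors, G: list of L group indices,
    F: list of C per-group factors, Y: list of C output vectors (modified in place).
    """
    for x, g in zip(X, G):
        for e, value in enumerate(x):
            Y[g][e] += F[g] * value


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three input vectors
G = [0, 1, 0]                             # vectors 0 and 2 belong to group 0
F = [0.5, 2.0]                            # one factor per group (assumed)
Y = [[0.0, 0.0], [0.0, 0.0]]              # zero-initialised output buffer

aggregate_reference(X, G, F, Y)
print(Y)  # [[3.0, 4.0], [6.0, 8.0]]
```

With `F` set to the reciprocal of each group's size, this accumulation would yield per-group means.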

def broadcast(...)

broadcast(arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor, arg3: at::Tensor) -> None

Broadcast the aggregated vectors Y back to X according to the group indices in G, multiplying by the factors F.
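
As a rough illustration of the description above, the following pure-Python sketch copies each group's vector back to every position assigned to that group. It is a sketch under stated assumptions, not the compiled implementation: the name `broadcast_reference` is hypothetical, the positional arguments are assumed to map to `Y`, `G`, `F` and an output buffer `X` written in place, and each factor is assumed to be indexed by group. Any additional leading tensor dimensions are ignored here.

```python
def broadcast_reference(Y, G, F, X):
    """Hypothetical reference: set X[l] = F[g] * Y[g], where g = G[l].

    Y: list of C group vectors, G: list of L group indices,
    F: list of C per-group factors, X: list of L output vectors (modified in place).
    """
    for l, g in enumerate(G):
        for e, value in enumerate(Y[g]):
            X[l][e] = F[g] * value


Y = [[3.0, 4.0], [6.0, 8.0]]              # two aggregated group vectors
G = [0, 1, 0]                             # positions 0 and 2 read from group 0
F = [2.0, 0.5]                            # one factor per group (assumed)
X = [[0.0, 0.0] for _ in G]               # output buffer, one vector per position

broadcast_reference(Y, G, F, X)
print(X)  # [[6.0, 8.0], [3.0, 4.0], [6.0, 8.0]]
```

Positions that share a group index receive identical vectors, which is what makes this the natural counterpart of the aggregation step.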