Module fast_transformers.aggregate

#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#


import torch

from .aggregate_cpu import aggregate as aggregate_cpu, \
    broadcast as broadcast_cpu
# The CUDA kernels are optional: if the extension was not built
# (e.g. a CPU-only install), only the CPU code paths are available.
try:
    from .aggregate_cuda import aggregate as aggregate_gpu, \
        broadcast as broadcast_gpu
    from .clustered_broadcast_cuda import \
        clustered_broadcast as clustered_broadcast_gpu
except ImportError:
    pass


def aggregate(X, G, F, Y=None):
    # Aggregate the rows of X into one output row per group, according to
    # the group indices in G and the per-group factors in F. Allocate the
    # output (or zero the provided one) and dispatch to the CPU or CUDA
    # kernel based on the device of X.
    device = X.device
    if Y is None:
        Y = torch.zeros(
            F.shape + (X.shape[-1],),
            device=device,
            dtype=X.dtype
        )
    else:
        Y.zero_()

    if device.type == "cpu":
        aggregate_cpu(X, G, F, Y)
    else:
        aggregate_gpu(X, G, F, Y)

    return Y


def broadcast(Y, G, F, X=None):
    # Broadcast the per-group rows of Y back to the individual positions
    # given by the group indices in G, scaled by the factors in F. Allocate
    # the output if needed and dispatch by device.
    device = Y.device
    if X is None:
        X = torch.zeros(
            G.shape + (Y.shape[-1],),
            device=device,
            dtype=Y.dtype
        )

    if device.type == "cpu":
        broadcast_cpu(Y, G, F, X)
    else:
        broadcast_gpu(Y, G, F, X)

    return X


def clustered_broadcast(Y, groups, counts, lengths, X=None):
    # Broadcast the per-cluster rows of Y back to the individual queries
    # given by `groups`. Only a CUDA kernel is available.
    device = Y.device
    if X is None:
        X = torch.zeros(
            groups.shape + (Y.shape[-1],),
            device=device,
            dtype=Y.dtype
        )

    if device.type == "cpu":
        raise NotImplementedError
    else:
        N, H, C, E = Y.shape
        _, _, L, E = X.shape

        # Launch geometry: at most 1024 queries per CUDA block, plus
        # C + 1 extra blocks.
        queries_per_block = min(L, 1024)
        threads = queries_per_block
        blocks = (L//threads) + C + 1

        # Per-(sample, head) block bookkeeping passed to the kernel,
        # initialized with the sentinel values L and -1 respectively.
        query_map = torch.ones((N, H, blocks),
                               dtype=torch.int32,
                               device=Y.device) * L
        blocks_map = torch.ones_like(query_map,
                                     dtype=torch.int32,
                                     device=Y.device) * -1

        # Order the queries by cluster id (descending) and broadcast with
        # uniform factors of 1, i.e. copy the cluster rows without rescaling.
        _, sorted_group_indices = torch.sort(groups, descending=True, dim=-1)
        factors = torch.ones_like(counts, dtype=Y.dtype)
        clustered_broadcast_gpu(
            Y,
            groups,
            factors,
            X,
            lengths,
            blocks_map,
            query_map,
            counts,
            sorted_group_indices,
        )

    return X

Sub-modules

fast_transformers.aggregate.aggregate_cpu
fast_transformers.aggregate.aggregate_cuda
fast_transformers.aggregate.clustered_broadcast_cuda

Functions

def aggregate(X, G, F, Y=None)

Aggregate the rows of X into one output row per group, according to the group indices G and the per-group factors F. If Y is None, a zero tensor of shape F.shape + (X.shape[-1],) is allocated; otherwise the provided Y is zeroed in place. The computation is dispatched to the CPU or CUDA kernel depending on the device of X, and Y is returned.
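
A minimal sketch of a CPU call, assuming the CPU extension is built. The sizes below are hypothetical; the shapes of G and F are inferred from how Y is allocated (one output row per group) and int32 group indices are assumed.

import torch
from fast_transformers.aggregate import aggregate

# Hypothetical sizes: batch N, heads H, queries L, feature dim E, groups C
N, H, L, E, C = 2, 4, 100, 32, 10

X = torch.rand(N, H, L, E)                              # per-query features
G = torch.randint(0, C, (N, H, L), dtype=torch.int32)   # group index of each query (assumed int32)
F = torch.ones(N, H, C)                                 # one factor per group

Y = aggregate(X, G, F)     # dispatches to aggregate_cpu since X is on the CPU
print(Y.shape)             # torch.Size([2, 4, 10, 32]) == F.shape + (E,)
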
def broadcast(Y, G, F, X=None)

Broadcast the per-group rows of Y back to the individual positions given by the group indices G, scaled by the factors F. If X is None, a zero tensor of shape G.shape + (Y.shape[-1],) is allocated. The computation is dispatched to the CPU or CUDA kernel depending on the device of Y, and X is returned.
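
A companion sketch under the same assumptions as above; the shape of F is assumed to be one factor per group. It shows broadcast as roughly the inverse mapping of aggregate: every query receives the row of its group.

import torch
from fast_transformers.aggregate import broadcast

N, H, L, E, C = 2, 4, 100, 32, 10

Y = torch.rand(N, H, C, E)                              # one row per group
G = torch.randint(0, C, (N, H, L), dtype=torch.int32)   # group index of each query (assumed int32)
F = torch.ones(N, H, C)                                 # per-group factors (assumed shape)

X = broadcast(Y, G, F)     # each query receives its group's row, scaled by F
print(X.shape)             # torch.Size([2, 4, 100, 32]) == G.shape + (E,)
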
def clustered_broadcast(Y, groups, counts, lengths, X=None)

Broadcast the per-cluster rows of Y (of shape (N, H, C, E)) back to the L individual queries given by groups, using the per-cluster counts and the per-sample lengths. Only a CUDA kernel is provided; calling this with CPU tensors raises NotImplementedError. Before launching the kernel, the function computes the launch geometry (at most 1024 queries per block, plus C + 1 extra blocks), allocates per-(sample, head) block bookkeeping buffers initialized with sentinel values, sorts the queries by cluster id in descending order, and passes uniform factors of 1 so the cluster rows are copied without rescaling.
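
For a concrete feel of the launch geometry computed above, the same arithmetic worked through for hypothetical sizes (this does not touch the CUDA kernel):

# Hypothetical sizes: L queries, C clusters
L, C = 3000, 10

queries_per_block = min(L, 1024)     # threads per CUDA block, capped at 1024
threads = queries_per_block
blocks = (L // threads) + C + 1      # 3000 // 1024 = 2, plus 10 + 1 extra = 13

print(threads, blocks)               # 1024 13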