Module fast_transformers.clustering.hamming
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import torch
from .cluster_cpu import cluster as cluster_cpu
try:
from .cluster_cuda import cluster as cluster_gpu
except ImportError:
pass
def cluster(
    hashes,
    lengths,
    groups=None,
    counts=None,
    centroids=None,
    distances=None,
    bitcounts=None,
    clusters=30,
    iterations=10,
    bits=32
):
    """Cluster hashes using a few iterations of K-Means with hamming distance.

    All the tensors default initialized to None are optional buffers to avoid
    memory allocations. distances and bitcounts are only used by the CUDA
    version of this call. clusters will be ignored if centroids is provided.

    On the CPU path the supplied buffers are handed to the kernel untouched
    and K is read from centroids.shape[2], so a supplied centroids buffer must
    be 3-dimensional there. On every other device the buffers are reshaped
    with .view() to the shapes listed below, K is derived from
    centroids.numel(), and the returned groups and counts are those views.

    Arguments
    ---------
        hashes: A long tensor of shape (N, H, L) containing a hashcode for each
                query.
        lengths: An int tensor of shape (N,) containing the sequence length for
                 each sequence in hashes.
        groups: An int tensor buffer of shape (N, H, L) containing the cluster
                to which the corresponding hash belongs.
        counts: An int tensor buffer of shape (N, H, K) containing the number
                of elements in each cluster.
        centroids: A long tensor buffer of shape (N, H, K) containing the
                   centroid for each cluster.
        distances: An int tensor of shape (N, H, L) containing the distance to
                   the closest centroid for each hash.
        bitcounts: An int tensor of shape (N, H, K, bits) containing the number
                   of elements that have 1 for a given bit.
        clusters: The number of clusters to use for each sequence. It is
                  ignored if centroids is not None.
        iterations: How many k-means iterations to perform.
        bits: How many of the least-significant bits in hashes to consider.

    Returns
    -------
        groups and counts as defined above.

    Raises
    ------
        RuntimeError: If hashes is not on the CPU and the compiled CUDA
                      clustering extension could not be imported.
    """
    device = hashes.device
    N, H, L = hashes.shape

    # Unfortunately cpu and gpu have different APIs so each device family gets
    # its own code path.
    if device.type == "cpu":
        if groups is None:
            groups = torch.empty((N, H, L), dtype=torch.int32)
        if centroids is None:
            centroids = torch.empty((N, H, clusters), dtype=torch.int64)
        K = centroids.shape[2]
        if counts is None:
            counts = torch.empty((N, H, K), dtype=torch.int32)

        cluster_cpu(
            hashes, lengths,
            centroids, groups, counts,
            iterations, bits
        )

        return groups, counts

    # The CUDA kernel is imported optionally at module level, so the name may
    # not exist. Resolve it before allocating any device memory and turn the
    # otherwise opaque NameError into an actionable message.
    try:
        gpu_kernel = cluster_gpu
    except NameError:
        raise RuntimeError(
            "The CUDA clustering extension (cluster_cuda) is not available, "
            "so hashes on device '{}' cannot be clustered; build the "
            "extension with CUDA support or move the inputs to the "
            "CPU".format(device)
        ) from None

    if groups is None:
        groups = torch.empty((N, H, L), dtype=torch.int32, device=device)
    if centroids is None:
        centroids = torch.empty((N, H, clusters), dtype=torch.int64,
                                device=device)
    # Derived from the element count (not the shape) so that flat or
    # differently shaped centroid buffers are accepted as well.
    K = centroids.numel() // N // H
    if counts is None:
        counts = torch.empty((N, H, K), dtype=torch.int32, device=device)
    if distances is None:
        distances = torch.empty((N, H, L), dtype=torch.int32,
                                device=device)
    if bitcounts is None:
        bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32,
                                device=device)

    # Give every buffer, including caller supplied ones, the layout that is
    # passed to the kernel.
    groups = groups.view(N, H, L)
    counts = counts.view(N, H, K)
    centroids = centroids.view(N, H, K)
    distances = distances.view(N, H, L)
    bitcounts = bitcounts.view(N, H, K, -1)

    gpu_kernel(
        hashes, lengths,
        centroids, distances, bitcounts, groups, counts,
        iterations, bits
    )

    return groups, counts
Sub-modules
fast_transformers.clustering.hamming.cluster_cpu
fast_transformers.clustering.hamming.cluster_cuda
Functions
def cluster(hashes, lengths, groups=None, counts=None, centroids=None, distances=None, bitcounts=None, clusters=30, iterations=10, bits=32)
-
Cluster hashes using a few iterations of K-Means with hamming distance.
All the tensors default initialized to None are optional buffers to avoid memory allocations. distances and bitcounts are only used by the CUDA version of this call. clusters will be ignored if centroids is provided.
Arguments
hashes: A long tensor of shape (N, H, L) containing a hashcode for each query. lengths: An int tensor of shape (N,) containing the sequence length for each sequence in hashes. groups: An int tensor buffer of shape (N, H, L) containing the cluster to which the corresponding hash belongs. counts: An int tensor buffer of shape (N, H, K) containing the number of elements in each cluster. centroids: A long tensor buffer of shape (N, H, K) containing the centroid for each cluster. distances: An int tensor of shape (N, H, L) containing the distance to the closest centroid for each hash. bitcounts: An int tensor of shape (N, H, K, bits) containing the number of elements that have 1 for a given bit. clusters: The number of clusters to use for each sequence. It is ignored if centroids is not None. iterations: How many k-means iterations to perform. bits: How many of the least-significant bits in hashes to consider.
Returns
groups and counts as defined above.
Expand source code
def cluster(
    hashes,
    lengths,
    groups=None,
    counts=None,
    centroids=None,
    distances=None,
    bitcounts=None,
    clusters=30,
    iterations=10,
    bits=32
):
    """Cluster hashes using a few iterations of K-Means with hamming distance.

    All the tensors default initialized to None are optional buffers to avoid
    memory allocations. distances and bitcounts are only used by the CUDA
    version of this call. clusters will be ignored if centroids is provided.

    On the CPU path the supplied buffers are handed to the kernel untouched
    and K is read from centroids.shape[2], so a supplied centroids buffer must
    be 3-dimensional there. On every other device the buffers are reshaped
    with .view() to the shapes listed below, K is derived from
    centroids.numel(), and the returned groups and counts are those views.

    Arguments
    ---------
        hashes: A long tensor of shape (N, H, L) containing a hashcode for each
                query.
        lengths: An int tensor of shape (N,) containing the sequence length for
                 each sequence in hashes.
        groups: An int tensor buffer of shape (N, H, L) containing the cluster
                to which the corresponding hash belongs.
        counts: An int tensor buffer of shape (N, H, K) containing the number
                of elements in each cluster.
        centroids: A long tensor buffer of shape (N, H, K) containing the
                   centroid for each cluster.
        distances: An int tensor of shape (N, H, L) containing the distance to
                   the closest centroid for each hash.
        bitcounts: An int tensor of shape (N, H, K, bits) containing the number
                   of elements that have 1 for a given bit.
        clusters: The number of clusters to use for each sequence. It is
                  ignored if centroids is not None.
        iterations: How many k-means iterations to perform.
        bits: How many of the least-significant bits in hashes to consider.

    Returns
    -------
        groups and counts as defined above.

    Raises
    ------
        RuntimeError: If hashes is not on the CPU and the compiled CUDA
                      clustering extension could not be imported.
    """
    device = hashes.device
    N, H, L = hashes.shape

    # Unfortunately cpu and gpu have different APIs so each device family gets
    # its own code path.
    if device.type == "cpu":
        if groups is None:
            groups = torch.empty((N, H, L), dtype=torch.int32)
        if centroids is None:
            centroids = torch.empty((N, H, clusters), dtype=torch.int64)
        K = centroids.shape[2]
        if counts is None:
            counts = torch.empty((N, H, K), dtype=torch.int32)

        cluster_cpu(
            hashes, lengths,
            centroids, groups, counts,
            iterations, bits
        )

        return groups, counts

    # The CUDA kernel is imported optionally at module level, so the name may
    # not exist. Resolve it before allocating any device memory and turn the
    # otherwise opaque NameError into an actionable message.
    try:
        gpu_kernel = cluster_gpu
    except NameError:
        raise RuntimeError(
            "The CUDA clustering extension (cluster_cuda) is not available, "
            "so hashes on device '{}' cannot be clustered; build the "
            "extension with CUDA support or move the inputs to the "
            "CPU".format(device)
        ) from None

    if groups is None:
        groups = torch.empty((N, H, L), dtype=torch.int32, device=device)
    if centroids is None:
        centroids = torch.empty((N, H, clusters), dtype=torch.int64,
                                device=device)
    # Derived from the element count (not the shape) so that flat or
    # differently shaped centroid buffers are accepted as well.
    K = centroids.numel() // N // H
    if counts is None:
        counts = torch.empty((N, H, K), dtype=torch.int32, device=device)
    if distances is None:
        distances = torch.empty((N, H, L), dtype=torch.int32,
                                device=device)
    if bitcounts is None:
        bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32,
                                device=device)

    # Give every buffer, including caller supplied ones, the layout that is
    # passed to the kernel.
    groups = groups.view(N, H, L)
    counts = counts.view(N, H, K)
    centroids = centroids.view(N, H, K)
    distances = distances.view(N, H, L)
    bitcounts = bitcounts.view(N, H, K, -1)

    gpu_kernel(
        hashes, lengths,
        centroids, distances, bitcounts, groups, counts,
        iterations, bits
    )

    return groups, counts