Module fast_transformers.attention_registry.registry
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#
class Registry(object):
    """Hold the available attention implementations and their required
    parameters.

    Each implementation is stored under a unique ``key`` together with the
    names of the parameters it needs. Parameter specs are shared by all
    implementations in one registry: a parameter name may be registered by
    several implementations, but only with an equal (``==``) spec.
    """

    def __init__(self):
        # key -> registered class object
        self._classes = {}
        # key -> list of parameter names required by that class
        self._class_params = {}
        # parameter name -> spec object (``validate_parameter`` calls its
        # ``get(value)`` method)
        self._parameters = {}

    def register(self, key, class_object, parameter_tuples):
        """Register ``class_object`` under ``key``.

        Arguments
        ---------
            key: The name under which the implementation is registered.
            class_object: The implementation to register.
            parameter_tuples: An iterable of ``(parameter_name, spec)`` pairs
                describing the parameters the implementation requires. It is
                consumed exactly once, so generators are accepted.

        Raises
        ------
            ValueError: If ``key`` is already registered, or if a parameter
                is already registered (or repeated in ``parameter_tuples``)
                with a different spec. In either case the registry is left
                unchanged.
        """
        if key in self._classes:
            raise ValueError("{} is already registered".format(key))

        # Materialize once: the pairs are needed both for validation and for
        # recording the class' parameter names below.
        parameter_tuples = list(parameter_tuples)

        # Validate every parameter before mutating any state, so a
        # conflicting spec cannot leave a half-registered class behind (which
        # would make __getitem__ fail and block any later re-registration).
        pending = {}
        for parameter, spec in parameter_tuples:
            if parameter in pending:
                registered = pending[parameter]
            elif parameter in self._parameters:
                registered = self._parameters[parameter]
            else:
                pending[parameter] = spec
                continue
            if registered != spec:
                raise ValueError(("{} is already registered with "
                                  "spec {!r} instead of {!r}").format(
                    parameter,
                    registered,
                    spec
                ))
            pending[parameter] = spec

        # Commit: all parameters have been validated.
        self._classes[key] = class_object
        self._parameters.update(pending)
        # note which parameters are needed by this class
        self._class_params[key] = [p for p, _ in parameter_tuples]

    def __contains__(self, key):
        """Return True if an implementation is registered under ``key``."""
        return key in self._classes

    def __getitem__(self, key):
        """Return ``(class_object, parameter_names)`` registered for ``key``.

        Raises KeyError if ``key`` is not registered.
        """
        return self._classes[key], self._class_params[key]

    @property
    def keys(self):
        """List of all registered implementation keys."""
        return list(self._classes.keys())

    def contains_parameter(self, key):
        """Return True if a parameter named ``key`` has a registered spec."""
        return key in self._parameters

    def validate_parameter(self, key, value):
        """Validate ``value`` for parameter ``key`` using its spec.

        Returns whatever the spec's ``get(value)`` returns.

        Raises
        ------
            ValueError: If the spec rejects the value or the parameter is not
                registered (any exception is chained as the cause).
        """
        try:
            return self._parameters[key].get(value)
        except Exception as e:
            raise ValueError(("Invalid value {!r} for "
                              "parameter {!r}").format(value, key)) from e
# One independent Registry per attention family. Each is a distinct
# instance with its own dictionaries, so keys and parameter specs registered
# in one do not affect the others.
AttentionRegistry, RecurrentAttentionRegistry, RecurrentCrossAttentionRegistry = (
    Registry(),
    Registry(),
    Registry(),
)
Classes
class Registry
-
Hold the available attention implementations and their required parameters.
Expand source code
class Registry(object): """Hold the available attention implementations and their required parameters.""" def __init__(self): self._classes = {} self._class_params = {} self._parameters = {} def register(self, key, class_object, parameter_tuples): # register the class if the key is new if key in self._classes: raise ValueError("{} is already registered".format(key)) self._classes[key] = class_object # register the parameters for parameter, spec in parameter_tuples: if ( parameter in self._parameters and self._parameters[parameter] != spec ): raise ValueError(("{} is already registered with " "spec {!r} instead of {!r}").format( parameter, self._parameters[parameter], spec )) self._parameters[parameter] = spec # note which parameters are needed by this class self._class_params[key] = [p for p, s in parameter_tuples] def __contains__(self, key): return key in self._classes def __getitem__(self, key): return self._classes[key], self._class_params[key] @property def keys(self): return list(self._classes.keys()) def contains_parameter(self, key): return key in self._parameters def validate_parameter(self, key, value): try: return self._parameters[key].get(value) except Exception as e: raise ValueError(("Invalid value {!r} for " "parameter {!r}").format(value, key)) from e
Instance variables
var keys
-
Expand source code
@property def keys(self): return list(self._classes.keys())
Methods
def contains_parameter(self, key)
-
Expand source code
def contains_parameter(self, key): return key in self._parameters
def register(self, key, class_object, parameter_tuples)
-
Expand source code
def register(self, key, class_object, parameter_tuples): # register the class if the key is new if key in self._classes: raise ValueError("{} is already registered".format(key)) self._classes[key] = class_object # register the parameters for parameter, spec in parameter_tuples: if ( parameter in self._parameters and self._parameters[parameter] != spec ): raise ValueError(("{} is already registered with " "spec {!r} instead of {!r}").format( parameter, self._parameters[parameter], spec )) self._parameters[parameter] = spec # note which parameters are needed by this class self._class_params[key] = [p for p, s in parameter_tuples]
def validate_parameter(self, key, value)
-
Expand source code
def validate_parameter(self, key, value): try: return self._parameters[key].get(value) except Exception as e: raise ValueError(("Invalid value {!r} for " "parameter {!r}").format(value, key)) from e