Module fast_transformers.attention_registry.spec
Spec instances describe and check the type and value of parameters.
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#
"""Spec instances allow to describe and check the type and value of
parameters."""
from ..events import EventDispatcher
class Spec(object):
    """Describe and validate a parameter type.

    Arguments
    ---------
        predicate: A callable that checks if the value is acceptable and
                   returns its canonical value or raises ValueError.
        name: A name to create a human readable description of the Spec
    """
    def __init__(self, predicate, name="CustomSpec"):
        self._predicate = predicate
        self._name = name

    def __repr__(self):
        return self._name

    def check(self, x):
        """Return True if x is an acceptable value for this spec.

        Validation is delegated to ``get`` (not to the stored predicate
        directly) so that subclasses which override ``get`` are checked with
        their own logic. A TypeError raised during validation (e.g.
        ``int(None)``) is treated as a rejection, just like a ValueError.
        """
        try:
            self.get(x)
        except (ValueError, TypeError):
            return False
        return True

    def get(self, x):
        """Return the canonical value of x or raise ValueError."""
        return self._predicate(x)

    def __eq__(self, y):
        # A plain spec is only equal to itself; subclasses that describe a
        # value (e.g. a set of choices) override this.
        return self is y

    # Defining __eq__ implicitly sets __hash__ to None; identity equality is
    # consistent with the default identity hash, so restore it.
    __hash__ = object.__hash__
class Choice(Spec):
    """A parameter type for a set of options.

    Arguments
    ---------
        choices: A set or list of possible values for this parameter
    """
    def __init__(self, choices):
        # Initialize the base class so the attributes Spec relies on
        # (``_predicate``, used by ``check``, and ``_name``) exist.
        super(Choice, self).__init__(self.get, "Choice")
        self._choices = choices

    def get(self, x):
        """Return x unchanged if it is one of the choices.

        Raises
        ------
            ValueError: if x is not among the choices, including when x is
                        unhashable and the choices are stored in a set.
        """
        try:
            found = x in self._choices
        except TypeError:
            # e.g. testing a list for membership in a set of choices
            found = False
        if found:
            return x
        raise ValueError("{!r} is not in {!r}".format(x, self._choices))

    def __repr__(self):
        return "Choice({!r})".format(self._choices)

    def __eq__(self, x):
        if isinstance(x, Choice):
            return self._choices == x._choices
        return False

    # Equality is value-based and choices may be an unhashable list, so
    # instances stay unhashable (this is what overriding __eq__ implies).
    __hash__ = None
class _Callable(Spec):
    """Accept any callable object and return it unchanged."""
    def __init__(self):
        # Pass a real predicate (instead of None) so that the inherited
        # ``check`` and ``get`` both work.
        super(_Callable, self).__init__(_Callable._get_callable, "Callable")

    @staticmethod
    def _get_callable(x):
        """Return x if it is callable, otherwise raise ValueError."""
        if callable(x):
            return x
        raise ValueError("{!r} is not a callable".format(x))
class _EventDispatcherInstance(Spec):
    """Accept either a string or an EventDispatcher instance, returning the
    value unchanged; anything else is rejected with ValueError."""

    def __init__(self):
        predicate = _EventDispatcherInstance._get_event_dispatcher
        super(_EventDispatcherInstance, self).__init__(
            predicate,
            "EventDispatcherInstance"
        )

    @staticmethod
    def _get_event_dispatcher(x):
        # Strings are passed through as-is; presumably they are resolved to
        # a dispatcher elsewhere (NOTE(review): confirm against callers).
        if not isinstance(x, (str, EventDispatcher)):
            raise ValueError("{!r} is not an event dispatcher".format(x))
        return x
class Optional(Spec):
    """Represent an optional parameter that can either have a value or it can
    be None.

    Arguments
    ---------
        spec: The spec for the value if it is not None
        default: The returned value in case it is None
    """
    def __init__(self, spec, default=None):
        # Initialize the base class so the attributes Spec relies on
        # (``_predicate``, used by ``check``, and ``_name``) exist.
        super(Optional, self).__init__(self.get, "Optional")
        self._other_spec = spec
        self._default = default

    def __repr__(self):
        return "Optional[{!r}, {!r}]".format(self._other_spec, self._default)

    def get(self, x):
        """Return the default when x is None, otherwise delegate to the
        wrapped spec (which raises ValueError for unacceptable values).

        Note that the default is returned as-is; it is not validated by the
        wrapped spec.
        """
        if x is None:
            return self._default
        return self._other_spec.get(x)

    def __eq__(self, x):
        if isinstance(x, Optional):
            return (
                self._other_spec == x._other_spec and
                self._default == x._default
            )
        return False

    # Equality is value-based, so instances stay unhashable (this is what
    # overriding __eq__ implies).
    __hash__ = None
# Ready-made specs for common parameter types. The builtin conversion
# functions serve as predicates: they return the canonical value or raise.
# int("3") -> 3 and int(3.7) -> 3; int("a") raises ValueError.
Int = Spec(int, name="Int")
Float = Spec(float, name="Float")
# NOTE: bool() converts by truthiness, so Bool accepts almost any value.
Bool = Spec(bool, name="Bool")
# Singleton specs for values that must be callable / event dispatchers.
Callable = _Callable()
EventDispatcherInstance = _EventDispatcherInstance()
Classes
class Choice (choices)
-
A parameter type for a set of options.
Arguments
choices: A set or list of possible values for this parameter
Expand source code
class Choice(Spec): """A parameter type for a set of options. Arguments --------- choices: A set or list of possible values for this parameter """ def __init__(self, choices): self._choices = choices def get(self, x): if x in self._choices: return x raise ValueError("{!r} is not in {!r}".format(x, self._choices)) def __repr__(self): return "Choice({!r})".format(self._choices) def __eq__(self, x): if isinstance(x, Choice): return self._choices == x._choices return False
Ancestors
Methods
def get(self, x)
-
Expand source code
def get(self, x): if x in self._choices: return x raise ValueError("{!r} is not in {!r}".format(x, self._choices))
class Optional (spec, default=None)
-
Represent an optional parameter that can either have a value or it can be None.
Arguments
spec: The spec for the value if it is not None
default: The value returned in case it is None
Expand source code
class Optional(Spec): """Represent an optional parameter that can either have a value or it can be None. Arguments --------- spec: The spec for the value if it is not None default: The returned value in case it is None """ def __init__(self, spec, default=None): self._other_spec = spec self._default = default def __repr__(self): return "Optional[{!r}, {!r}]".format(self._other_spec, self._default) def get(self, x): if x is None: return self._default return self._other_spec.get(x) def __eq__(self, x): if isinstance(x, Optional): return ( self._other_spec == x._other_spec and self._default == x._default ) return False
Ancestors
Methods
def get(self, x)
-
Expand source code
def get(self, x): if x is None: return self._default return self._other_spec.get(x)
class Spec (predicate, name='CustomSpec')
-
Describe and validate a parameter type.
Arguments
predicate: A callable that checks if the value is acceptable and returns its canonical value or raises ValueError.
name: A name to create a human readable description of the Spec
Expand source code
class Spec(object): """Describe and validate a parameter type. Arguments --------- predicate: A callable that checks if the value is acceptable and returns its canonical value or raises ValueError. name: A name to create a human readable description of the Spec """ def __init__(self, predicate, name="CustomSpec"): self._predicate = predicate self._name = name def __repr__(self): return self._name def check(self, x): try: self._predicate(x) return True except ValueError: return False def get(self, x): return self._predicate(x) def __eq__(self, y): return self is y
Subclasses
- Choice
- Optional
- fast_transformers.attention_registry.spec._Callable
- fast_transformers.attention_registry.spec._EventDispatcherInstance
Methods
def check(self, x)
-
Expand source code
def check(self, x): try: self._predicate(x) return True except ValueError: return False
def get(self, x)
-
Expand source code
def get(self, x): return self._predicate(x)