Module fast_transformers.events.filters
Define composable functions to filter events.
Expand source code
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#
"""Define composable functions to filter events."""
import weakref
from .event import Event
class EventFilter(object):
    """EventFilter instances are predicates (i.e. functions that return True
    or False) to be used with an event dispatcher for filtering event
    instances.

    The main benefit over using raw functions is that an EventFilter composes
    very easily using the operators &, | and ~. The other operand of & and |
    may be another EventFilter, a subclass of Event (converted with
    `event_class`) or any other callable (wrapped in a CallableEventFilter).
    At least one of the two operands must be an EventFilter, otherwise Python
    never reaches these operator methods.

    Example
    --------

        event_filter = AttentionEvent | layer_name_contains(transformer,
                                                            "layers.1")
        event_filter = from_layer(transformer.layers[2].attention)
        event_filter = (
            event_class(AttentionEvent) &
            (lambda ev: torch.isnan(ev.attention_matrix).any())
        )
    """
    def __call__(self, event):
        """Return a truthy value if `event` should be selected.

        Subclasses must override this method.
        """
        raise NotImplementedError()

    def _to_event_filter(self, other):
        """Convert the operand `other` into an EventFilter.

        Returns `other` unchanged if it already is an EventFilter, an
        `event_class` filter if it is a subclass of Event, a
        CallableEventFilter if it is any other callable and NotImplemented
        otherwise, so that Python reports an unsupported operand type.
        """
        if isinstance(other, EventFilter):
            return other
        if isinstance(other, type) and issubclass(other, Event):
            return event_class(other)
        if callable(other):
            return CallableEventFilter(other)
        return NotImplemented

    def __and__(self, other):
        """Select events accepted by both `self` and `other`.

        `other` is only evaluated when `self` returns a truthy value.
        """
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: self(ev) and other(ev))

    def __rand__(self, other):
        """Reflected `&`: `other` is the left operand and is evaluated first."""
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: other(ev) and self(ev))

    def __or__(self, other):
        """Select events accepted by `self` or by `other`.

        `other` is only evaluated when `self` returns a falsy value.
        """
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: self(ev) or other(ev))

    def __ror__(self, other):
        """Reflected `|`: `other` is the left operand and is evaluated first."""
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: other(ev) or self(ev))

    def __invert__(self):
        """Select exactly the events that `self` rejects."""
        return CallableEventFilter(lambda ev: not self(ev))
class CallableEventFilter(EventFilter):
    """Adapt an arbitrary predicate into an EventFilter.

    Arguments
    ---------
        event_filter: callable that receives an event and returns a truthy
                      value when the event should be selected
    """
    def __init__(self, event_filter):
        self._event_filter = event_filter

    def __call__(self, event):
        predicate = self._event_filter
        return predicate(event)
class LayerNameEventFilter(EventFilter):
    """A LayerNameEventFilter allows filtering events based on a human
    readable name of the layer that emitted them.

    The names are taken from `root.named_modules()` when the filter is
    constructed; modules added to `root` afterwards are unknown to the filter
    and their events are never selected.

    Note that LayerNameEventFilter only keeps weak references to the modules,
    which means that it does not prevent them from being garbage collected.

    Arguments
    ---------
        root: torch.nn.Module instance that represents the root container
        name_filter: callable that receives the name of the layer that
                     emitted an event and returns True if the event should
                     be selected
    """
    def __init__(self, root, name_filter):
        # A WeakKeyDictionary drops the entry of a module once it has been
        # garbage collected instead of accumulating dead weak references.
        self._names = weakref.WeakKeyDictionary(
            (m, n) for n, m in root.named_modules()
        )
        self._name_filter = name_filter

    def __call__(self, event):
        try:
            name = self._names.get(event.source)
        except TypeError:
            # The source cannot be weakly referenced (or hashed), so it cannot
            # be one of the modules collected from `root`.
            return False
        if name is None:
            return False
        return self._name_filter(name)
def event_class(klass):
    """Build a filter that keeps only events that are instances of `klass`.

    Arguments
    ---------
        klass: A class (or, as accepted by `isinstance`, a tuple of classes)
               to check the event instance against

    Returns
    -------
        An instance of EventFilter
    """
    def _is_instance_of_klass(event):
        return isinstance(event, klass)

    return CallableEventFilter(_is_instance_of_klass)
def from_layer(layer):
    """Build a filter that keeps only events whose source is `layer`.

    The comparison uses object identity (`is`), so only events emitted by
    this exact module instance are selected.

    Arguments
    ---------
        layer: An instance of torch.nn.Module to check against the event
               source

    Returns
    -------
        An instance of EventFilter
    """
    def _emitted_by_layer(event):
        return event.source is layer

    return CallableEventFilter(_emitted_by_layer)
def layer_name_contains(root, name):
    """Select events whose layer has a human readable name containing `name`.

    The human readable names are obtained from `root.named_modules()` (see
    LayerNameEventFilter).

    Arguments
    ---------
        root: torch.nn.Module that contains the layers of interest
        name: substring that must appear in the layer name

    Returns
    -------
        An instance of LayerNameEventFilter
    """
    def _contains_name(layer_name):
        return name in layer_name

    return LayerNameEventFilter(root, _contains_name)
Functions
def event_class(klass)
-
Select events that are instances of `klass`.
Arguments
klass: A class to check the event instance against
Returns
An instance of EventFilter
Expand source code
def event_class(klass):
    """Build a filter that keeps only events that are instances of `klass`.

    Arguments
    ---------
        klass: A class (or, as accepted by `isinstance`, a tuple of classes)
               to check the event instance against

    Returns
    -------
        An instance of EventFilter
    """
    def _is_instance_of_klass(event):
        return isinstance(event, klass)

    return CallableEventFilter(_is_instance_of_klass)
def from_layer(layer)
-
Select events that are dispatched from the `layer`.
Arguments
layer: An instance of torch.nn.Module to check against the event source
Returns
An instance of EventFilter
Expand source code
def from_layer(layer):
    """Build a filter that keeps only events whose source is `layer`.

    The comparison uses object identity (`is`), so only events emitted by
    this exact module instance are selected.

    Arguments
    ---------
        layer: An instance of torch.nn.Module to check against the event
               source

    Returns
    -------
        An instance of EventFilter
    """
    def _emitted_by_layer(event):
        return event.source is layer

    return CallableEventFilter(_emitted_by_layer)
def layer_name_contains(root, name)
-
Select events that contain `name` in their human readable name. We use root.named_modules() to get human readable names for the layers.
Expand source code
def layer_name_contains(root, name):
    """Select events whose layer has a human readable name containing `name`.

    The human readable names are obtained from `root.named_modules()` (see
    LayerNameEventFilter).

    Arguments
    ---------
        root: torch.nn.Module that contains the layers of interest
        name: substring that must appear in the layer name

    Returns
    -------
        An instance of LayerNameEventFilter
    """
    def _contains_name(layer_name):
        return name in layer_name

    return LayerNameEventFilter(root, _contains_name)
Classes
class CallableEventFilter (event_filter)
-
Wrap a function with an EventFilter object.
Expand source code
class CallableEventFilter(EventFilter):
    """Adapt an arbitrary predicate into an EventFilter.

    Arguments
    ---------
        event_filter: callable that receives an event and returns a truthy
                      value when the event should be selected
    """
    def __init__(self, event_filter):
        self._event_filter = event_filter

    def __call__(self, event):
        predicate = self._event_filter
        return predicate(event)
Ancestors
class EventFilter
-
EventFilter instances are predicates (i.e. functions that return True or False) to be used with an event dispatcher for filtering event instances.
The main benefit over using raw functions is that an EventFilter composes very easily using operators such as &, |, ~.
Example
event_filter = AttentionEvent | layer_name_contains(transformer, "layers.1")
event_filter = from_layer(transformer.layers[2].attention)
event_filter = (
    event_class(AttentionEvent) &
    (lambda ev: torch.isnan(ev.attention_matrix).any())
)
Expand source code
class EventFilter(object):
    """EventFilter instances are predicates (i.e. functions that return True
    or False) to be used with an event dispatcher for filtering event
    instances.

    The main benefit over using raw functions is that an EventFilter composes
    very easily using the operators &, | and ~. The other operand of & and |
    may be another EventFilter, a subclass of Event (converted with
    `event_class`) or any other callable (wrapped in a CallableEventFilter).
    At least one of the two operands must be an EventFilter, otherwise Python
    never reaches these operator methods.

    Example
    --------

        event_filter = AttentionEvent | layer_name_contains(transformer,
                                                            "layers.1")
        event_filter = from_layer(transformer.layers[2].attention)
        event_filter = (
            event_class(AttentionEvent) &
            (lambda ev: torch.isnan(ev.attention_matrix).any())
        )
    """
    def __call__(self, event):
        """Return a truthy value if `event` should be selected.

        Subclasses must override this method.
        """
        raise NotImplementedError()

    def _to_event_filter(self, other):
        """Convert the operand `other` into an EventFilter.

        Returns `other` unchanged if it already is an EventFilter, an
        `event_class` filter if it is a subclass of Event, a
        CallableEventFilter if it is any other callable and NotImplemented
        otherwise, so that Python reports an unsupported operand type.
        """
        if isinstance(other, EventFilter):
            return other
        if isinstance(other, type) and issubclass(other, Event):
            return event_class(other)
        if callable(other):
            return CallableEventFilter(other)
        return NotImplemented

    def __and__(self, other):
        """Select events accepted by both `self` and `other`.

        `other` is only evaluated when `self` returns a truthy value.
        """
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: self(ev) and other(ev))

    def __rand__(self, other):
        """Reflected `&`: `other` is the left operand and is evaluated first."""
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: other(ev) and self(ev))

    def __or__(self, other):
        """Select events accepted by `self` or by `other`.

        `other` is only evaluated when `self` returns a falsy value.
        """
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: self(ev) or other(ev))

    def __ror__(self, other):
        """Reflected `|`: `other` is the left operand and is evaluated first."""
        other = self._to_event_filter(other)
        if other is NotImplemented:
            return other
        return CallableEventFilter(lambda ev: other(ev) or self(ev))

    def __invert__(self):
        """Select exactly the events that `self` rejects."""
        return CallableEventFilter(lambda ev: not self(ev))
Subclasses
class LayerNameEventFilter (root, name_filter)
-
A LayerNameEventFilter allows filtering events based on a human readable name of the layer that emitted them.
Note that LayerNameEventFilter keeps only weak references to the modules, which means that it does not prevent them from being garbage collected.
Arguments
root: torch.nn.Module instance that represents the root container
name_filter: callable that receives the name of the layer that emitted an event and returns True if the event should be selected
Expand source code
class LayerNameEventFilter(EventFilter):
    """A LayerNameEventFilter allows filtering events based on a human
    readable name of the layer that emitted them.

    The names are taken from `root.named_modules()` when the filter is
    constructed; modules added to `root` afterwards are unknown to the filter
    and their events are never selected.

    Note that LayerNameEventFilter only keeps weak references to the modules,
    which means that it does not prevent them from being garbage collected.

    Arguments
    ---------
        root: torch.nn.Module instance that represents the root container
        name_filter: callable that receives the name of the layer that
                     emitted an event and returns True if the event should
                     be selected
    """
    def __init__(self, root, name_filter):
        # A WeakKeyDictionary drops the entry of a module once it has been
        # garbage collected instead of accumulating dead weak references.
        self._names = weakref.WeakKeyDictionary(
            (m, n) for n, m in root.named_modules()
        )
        self._name_filter = name_filter

    def __call__(self, event):
        try:
            name = self._names.get(event.source)
        except TypeError:
            # The source cannot be weakly referenced (or hashed), so it cannot
            # be one of the modules collected from `root`.
            return False
        if name is None:
            return False
        return self._name_filter(name)
Ancestors