Module fast_transformers.utils

Boilerplate code for dealing with fast_transformers modules.

#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>
#

"""Boilerplate code for dealing with fast_transformers modules."""


def make_mirror(src_module, dst_module):
    """Set the parameters of dst_module to those of src_module so that both
    modules share the same parameter objects.

    The most notable use case is making a recurrent transformer mirror a batch
    transformer for fast inference.

    Arguments
    ---------
        src_module: Module to take the parameters from
        dst_module: Module to set the parameters to

    Returns
    -------
        None; dst_module is modified in place
    """
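    def setattr_recursive(mod, key, value):
        # Resolve a dotted name such as "layers.0.weight": descend one
        # attribute per dotted component and set the last one to value.
        key, *next_key = key.split(".", maxsplit=1)
        if not next_key:
            setattr(mod, key, value)
        else:
            setattr_recursive(getattr(mod, key), next_key[0], value)

    # Point every parameter of dst_module at the matching parameter object of
    # src_module, so the weights are shared rather than copied.
    for name, param in src_module.named_parameters():
        setattr_recursive(dst_module, name, param)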
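The nested setattr_recursive helper walks a dotted parameter name (as produced by named_parameters()) one attribute at a time. The sketch below replays the same splitting logic on plain types.SimpleNamespace objects, chosen here only so the example runs without PyTorch; the attribute names (encoder, attention, weight) are hypothetical stand-ins for a real module tree.

```python
from types import SimpleNamespace


def setattr_recursive(mod, key, value):
    # Same logic as the helper nested inside make_mirror: split off the first
    # dotted component, descend into that attribute, and recurse on the rest.
    key, *next_key = key.split(".", maxsplit=1)
    if not next_key:
        setattr(mod, key, value)
    else:
        setattr_recursive(getattr(mod, key), next_key[0], value)


# Hypothetical nested object standing in for a module whose
# named_parameters() would yield the name "encoder.attention.weight".
dst = SimpleNamespace(
    encoder=SimpleNamespace(attention=SimpleNamespace(weight="old"))
)

shared = ["stand-in for a parameter"]  # identity is what matters here
setattr_recursive(dst, "encoder.attention.weight", shared)

print(dst.encoder.attention.weight is shared)  # True
```

Because "a.b.c".split(".", maxsplit=1) returns ["a", "b.c"], each recursive call consumes exactly one component, and a name without dots is set directly on the current object.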

Functions

def make_mirror(src_module, dst_module)

Sets the parameters of dst_module to those of src_module so that both modules share the same parameter objects.

The most notable use case is making a recurrent transformer mirror a batch transformer for fast inference.

Arguments

src_module: Module to take the parameters from
dst_module: Module to set the parameters to

Returns

None; dst_module is modified in place
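A minimal usage sketch, assuming PyTorch is installed. To keep it self-contained it repeats make_mirror from the source above instead of importing fast_transformers, and the build() helper with its two nn.Sequential models is a hypothetical stand-in for a batch/recurrent transformer pair; make_mirror only requires that the two modules expose matching parameter names.

```python
import torch
from torch import nn


def make_mirror(src_module, dst_module):
    # Copy of fast_transformers.utils.make_mirror, repeated for self-containment.
    def setattr_recursive(mod, key, value):
        key, *next_key = key.split(".", maxsplit=1)
        if not next_key:
            setattr(mod, key, value)
        else:
            setattr_recursive(getattr(mod, key), next_key[0], value)

    for name, param in src_module.named_parameters():
        setattr_recursive(dst_module, name, param)


def build():
    # Hypothetical stand-in model; names are "0.weight", "0.bias", "2.weight", ...
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


batch_model = build()
recurrent_model = build()
make_mirror(batch_model, recurrent_model)

# The parameters are now the same objects, not copies.
same = all(
    p_src is p_dst
    for p_src, p_dst in zip(batch_model.parameters(), recurrent_model.parameters())
)
print(same)  # True

# An in-place update on the source is visible through the mirror.
with torch.no_grad():
    batch_model[0].weight.fill_(1.0)
print(recurrent_model[0].weight.sum().item())  # 32.0 (8 x 4 weight of ones)
```

Since the mirror holds references rather than copies, in-place updates to the source parameters, such as optimizer steps, need no re-synchronization, and no extra memory is spent on a second set of weights.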