Overview

MTRNN is a recurrent neural network (RNN) composed of a hierarchical group of neurons with different firing rates¹. It consists of three layers: an input-output (IO) layer and two context layers (Cf and Cs) with different firing rates (time constants), each with recurrent connections. The time constant increases from the Cf layer to the Cs layer, so the Cs layer responds more slowly to input. Input information is passed through the Cf and Cs layers to the output layer. There is no direct connection between the IO and Cs layers; they interact through the Cf layer. The MTRNN allows a robot to learn behaviors, with the Cf layer representing behavioral primitives and the Cs layer learning how to combine these primitives. Compared to LSTM, the MTRNN is more interpretable and is widely used in our lab.
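
Concretely, each context unit is a leaky integrator whose leak is set by its time constant τ. In the notation used here (ours, not the source's), the update implemented by MTRNNCell.forward below is

$$
u_t^{(i)} = \left(1 - \frac{1}{\tau_i}\right) u_{t-1}^{(i)} + \frac{1}{\tau_i}\, z_t^{(i)}, \qquad h_t^{(i)} = \tanh\!\left(u_t^{(i)}\right),
$$

where for the fast layer $z_t^{\mathrm{fast}} = W_{if} x_t + W_{ff} h_{t-1}^{\mathrm{fast}} + W_{sf} h_{t-1}^{\mathrm{slow}}$ and for the slow layer $z_t^{\mathrm{slow}} = W_{fs} h_{t-1}^{\mathrm{fast}} + W_{ss} h_{t-1}^{\mathrm{slow}}$ (plus the parametric bias, if used). A large $\tau$ keeps $u_t$ close to $u_{t-1}$, which is what makes the Cs layer slow.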

MTRNN

MTRNN.MTRNNCell

Bases: nn.Module

Multiple Timescale RNN.

Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales. This is based on the idea of hierarchical organization in human cognitive functions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_dim | int | Number of input features. | required |
| fast_dim | int | Number of fast context neurons. | required |
| slow_dim | int | Number of slow context neurons. | required |
| fast_tau | float | Time constant of the fast context. | required |
| slow_tau | float | Time constant of the slow context. | required |
| activation | str | If set to None, no activation is applied (i.e. "linear" activation: a(x) = x). | 'tanh' |
| use_bias | bool | Whether the layer uses a bias vector. | False |
| use_pb | bool | Whether the recurrent layer uses a parametric bias (pb) vector. | False |

Yuichi Yamashita and Jun Tani, "Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment", PLoS Computational Biology, 4(11): e1000220, 2008. https://doi.org/10.1371/journal.pcbi.1000220

Source code in en/docs/zoo/src/MTRNN.py
import torch
import torch.nn as nn

# Assumed import: get_activation_fn maps an activation name (e.g. "tanh")
# to the corresponding callable/module.
from eipl.utils import get_activation_fn


class MTRNNCell(nn.Module):
    #:: MTRNNCell
    """Multiple Timescale RNN.

    Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales.
    This is based on the idea of hierarchical organization in human cognitive functions.

    Arguments:
        input_dim (int): Number of input features.
        fast_dim (int): Number of fast context neurons.
        slow_dim (int): Number of slow context neurons.
        fast_tau (float): Time constant value of fast context.
        slow_tau (float): Time constant value of slow context.
        activation (str, optional): If set to `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).
        use_bias (bool, optional): Whether the layer uses a bias vector. Default is False.
        use_pb (bool, optional): Whether the recurrent layer uses a parametric bias (pb) vector. Default is False.

    Yuichi Yamashita and Jun Tani,
    "Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment",
    PLoS Computational Biology, 4(11): e1000220, 2008.
    https://doi.org/10.1371/journal.pcbi.1000220
    """

    def __init__(
        self,
        input_dim,
        fast_dim,
        slow_dim,
        fast_tau,
        slow_tau,
        activation="tanh",
        use_bias=False,
        use_pb=False,
    ):
        super(MTRNNCell, self).__init__()

        self.input_dim = input_dim
        self.fast_dim = fast_dim
        self.slow_dim = slow_dim
        self.fast_tau = fast_tau
        self.slow_tau = slow_tau
        self.use_bias = use_bias
        self.use_pb = use_pb

        # Resolve the activation: a name string (legacy), a callable, or
        # None for the identity ("linear") activation, as documented above.
        if isinstance(activation, str):
            self.activation = get_activation_fn(activation)
        elif activation is None:
            self.activation = lambda x: x
        else:
            self.activation = activation

        # Input Layers
        self.i2f = nn.Linear(input_dim, fast_dim, bias=use_bias)

        # Fast context layer
        self.f2f = nn.Linear(fast_dim, fast_dim, bias=False)
        self.f2s = nn.Linear(fast_dim, slow_dim, bias=use_bias)

        # Slow context layer
        self.s2s = nn.Linear(slow_dim, slow_dim, bias=False)
        self.s2f = nn.Linear(slow_dim, fast_dim, bias=use_bias)

    def forward(self, x, state=None, pb=None):
        """Forward propagation of the MTRNN.

        Arguments:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
            state (tuple): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
                   If None, states are initialized to zeros.
            pb (torch.Tensor): Parametric bias (pb) vector added to the slow context input. Used if self.use_pb is set to True.

        Returns:
            new_h_fast (torch.Tensor): Updated fast context state.
            new_h_slow (torch.Tensor): Updated slow context state.
            new_u_fast (torch.Tensor): Updated fast internal state.
            new_u_slow (torch.Tensor): Updated slow internal state.
        """
        batch_size = x.shape[0]
        if state is not None:
            prev_h_fast, prev_h_slow, prev_u_fast, prev_u_slow = state
        else:
            device = x.device
            prev_h_fast = torch.zeros(batch_size, self.fast_dim).to(device)
            prev_h_slow = torch.zeros(batch_size, self.slow_dim).to(device)
            prev_u_fast = torch.zeros(batch_size, self.fast_dim).to(device)
            prev_u_slow = torch.zeros(batch_size, self.slow_dim).to(device)

        # Update of fast internal state
        new_u_fast = (1.0 - 1.0 / self.fast_tau) * prev_u_fast + 1.0 / self.fast_tau * (
            self.i2f(x) + self.f2f(prev_h_fast) + self.s2f(prev_h_slow)
        )

        # Update of slow internal state
        _input_slow = self.f2s(prev_h_fast) + self.s2s(prev_h_slow)
        if pb is not None:
            _input_slow += pb

        new_u_slow = (1.0 - 1.0 / self.slow_tau) * prev_u_slow + 1.0 / self.slow_tau * _input_slow

        # Compute the activation for both fast and slow context states
        new_h_fast = self.activation(new_u_fast)
        new_h_slow = self.activation(new_u_slow)

        return new_h_fast, new_h_slow, new_u_fast, new_u_slow
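
A minimal usage sketch (dimensions and time constants are illustrative, not prescribed by the module): the cell is applied once per timestep, and the returned 4-tuple is fed back in as the next state.

import torch

cell = MTRNNCell(input_dim=10, fast_dim=50, slow_dim=5, fast_tau=2.0, slow_tau=12.0)

x_seq = torch.randn(20, 8, 10)  # (time, batch_size, input_dim), dummy data
state = None                    # zero-initialized inside the cell on the first step
for x in x_seq:
    state = cell(x, state)      # (new_h_fast, new_h_slow, new_u_fast, new_u_slow)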

forward(x, state=None, pb=None)

Forward propagation of the MTRNN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor of shape (batch_size, input_dim). | required |
| state | tuple | Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim). If None, states are initialized to zeros. | None |
| pb | torch.Tensor | Parametric bias (pb) vector added to the slow context input. Used if self.use_pb is set to True. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| new_h_fast | torch.Tensor | Updated fast context state. |
| new_h_slow | torch.Tensor | Updated slow context state. |
| new_u_fast | torch.Tensor | Updated fast internal state. |
| new_u_slow | torch.Tensor | Updated slow internal state. |

Source code in en/docs/zoo/src/MTRNN.py
def forward(self, x, state=None, pb=None):
    """Forward propagation of the MTRNN.

    Arguments:
        x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
        state (tuple): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
               If None, states are initialized to zeros.
        pb (torch.Tensor): Parametric bias (pb) vector added to the slow context input. Used if self.use_pb is set to True.

    Returns:
        new_h_fast (torch.Tensor): Updated fast context state.
        new_h_slow (torch.Tensor): Updated slow context state.
        new_u_fast (torch.Tensor): Updated fast internal state.
        new_u_slow (torch.Tensor): Updated slow internal state.
    """
    batch_size = x.shape[0]
    if state is not None:
        prev_h_fast, prev_h_slow, prev_u_fast, prev_u_slow = state
    else:
        device = x.device
        prev_h_fast = torch.zeros(batch_size, self.fast_dim).to(device)
        prev_h_slow = torch.zeros(batch_size, self.slow_dim).to(device)
        prev_u_fast = torch.zeros(batch_size, self.fast_dim).to(device)
        prev_u_slow = torch.zeros(batch_size, self.slow_dim).to(device)

    # Update of fast internal state
    new_u_fast = (1.0 - 1.0 / self.fast_tau) * prev_u_fast + 1.0 / self.fast_tau * (
        self.i2f(x) + self.f2f(prev_h_fast) + self.s2f(prev_h_slow)
    )

    # Update of slow internal state
    _input_slow = self.f2s(prev_h_fast) + self.s2s(prev_h_slow)
    if pb is not None:
        _input_slow += pb

    new_u_slow = (1.0 - 1.0 / self.slow_tau) * prev_u_slow + 1.0 / self.slow_tau * _input_slow

    # Compute the activation for both fast and slow context states
    new_h_fast = self.activation(new_u_fast)
    new_h_slow = self.activation(new_u_slow)

    return new_h_fast, new_h_slow, new_u_fast, new_u_slow
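
A sketch of the parametric-bias path: a pb tensor of shape (batch_size, slow_dim) is added to the slow context's input (the `_input_slow += pb` line above). Note that the forward pass applies pb whenever it is not None, regardless of `use_pb`; the shapes and values below are illustrative.

import torch

cell = MTRNNCell(input_dim=10, fast_dim=50, slow_dim=5, fast_tau=2.0, slow_tau=12.0, use_pb=True)

x = torch.randn(8, 10)
pb = torch.zeros(8, 5, requires_grad=True)  # learnable, typically one pb per sequence
h_fast, h_slow, u_fast, u_slow = cell(x, state=None, pb=pb)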

MTRNN.BasicMTRNN

Bases: nn.Module

MTRNN Wrapper Module.

This module encapsulates the MTRNNCell, adding an output layer to it.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_dim | int | Number of input features. | required |
| fast_dim | int | Number of fast context neurons. | required |
| slow_dim | int | Number of slow context neurons. | required |
| fast_tau | float | Time constant of the fast context. | required |
| slow_tau | float | Time constant of the slow context. | required |
| out_dim | int | Number of output features. If None, set equal to in_dim. | None |
| activation | str | If set to None, no activation is applied (i.e. "linear" activation: a(x) = x). | 'tanh' |
Source code in en/docs/zoo/src/MTRNN.py
class BasicMTRNN(nn.Module):
    #:: BasicMTRNN
    """MTRNN Wrapper Module.

    This module encapsulates the MTRNNCell, adding an output layer to it.

    Arguments:
        in_dim (int): Number of input features.
        fast_dim (int): Number of fast context neurons.
        slow_dim (int): Number of slow context neurons.
        fast_tau (float): Time constant of the fast context.
        slow_tau (float): Time constant of the slow context.
        out_dim (int, optional): Number of output features. If None, set equal to in_dim.
        activation (str, optional): If set to `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).
    """

    def __init__(
        self, in_dim, fast_dim, slow_dim, fast_tau, slow_tau, out_dim=None, activation="tanh"
    ):
        super(BasicMTRNN, self).__init__()

        if out_dim is None:
            out_dim = in_dim

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = get_activation_fn(activation)

        self.mtrnn = MTRNNCell(
            in_dim, fast_dim, slow_dim, fast_tau, slow_tau, activation=activation
        )
        # Output of RNN; skip the activation when it is None ("linear" output)
        out_layers = [nn.Linear(fast_dim, out_dim)]
        if activation is not None:
            out_layers.append(activation)
        self.rnn_out = nn.Sequential(*out_layers)

    def forward(self, x, state=None):
        """Forward propagation of the BasicMTRNN.

        Arguments:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
            state (tuple): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
                   If None, states are initialized to zeros.

        Returns:
            y_hat (torch.Tensor): Output tensor of shape (batch_size, out_dim).
            rnn_hid (tuple): Updated states (h_fast, h_slow, u_fast, u_slow).
        """
        rnn_hid = self.mtrnn(x, state)
        y_hat = self.rnn_out(rnn_hid[0])

        return y_hat, rnn_hid
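
A minimal single-step sketch (all dimensions illustrative): since out_dim defaults to in_dim, the output has the same size as the input.

import torch

model = BasicMTRNN(in_dim=10, fast_dim=50, slow_dim=5, fast_tau=2.0, slow_tau=12.0)

x = torch.randn(8, 10)
y_hat, state = model(x)  # state is the (h_fast, h_slow, u_fast, u_slow) tuple
print(y_hat.shape)       # torch.Size([8, 10])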

forward(x, state=None)

Forward propagation of the BasicMTRNN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | torch.Tensor | Input tensor of shape (batch_size, input_dim). | required |
| state | tuple | Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim). If None, states are initialized to zeros. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| y_hat | torch.Tensor | Output tensor of shape (batch_size, out_dim). |
| rnn_hid | tuple | Updated states (h_fast, h_slow, u_fast, u_slow). |

Source code in en/docs/zoo/src/MTRNN.py
def forward(self, x, state=None):
    """Forward propagation of the BasicMTRNN.

    Arguments:
        x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
        state (tuple): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
               If None, states are initialized to zeros.

    Returns:
        y_hat (torch.Tensor): Output tensor of shape (batch_size, out_dim).
        rnn_hid (tuple): Updated states (h_fast, h_slow, u_fast, u_slow).
    """
    rnn_hid = self.mtrnn(x, state)
    y_hat = self.rnn_out(rnn_hid[0])

    return y_hat, rnn_hid
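
A hedged training sketch: teacher-forced rollout over a dummy sequence with an MSE next-step prediction loss. The data, optimizer, and hyperparameters are illustrative, not part of the module.

import torch
import torch.nn as nn

model = BasicMTRNN(in_dim=10, fast_dim=50, slow_dim=5, fast_tau=2.0, slow_tau=12.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

seq = torch.randn(8, 30, 10)  # (batch, time, in_dim), dummy data
state, loss = None, 0.0
for t in range(seq.shape[1] - 1):
    y_hat, state = model(seq[:, t], state)
    loss = loss + criterion(y_hat, seq[:, t + 1])  # predict the next step

optimizer.zero_grad()
loss.backward()
optimizer.step()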

  1. Yuichi Yamashita and Jun Tani. Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment. PLoS Computational Biology, 4(11): e1000220, 2008. https://doi.org/10.1371/journal.pcbi.1000220