
Overview

MTRNN is a type of RNN composed of hierarchically organized groups of neurons with different response speeds [1]. It consists of three layers: an IO layer and two context layers (the Cf and Cs layers) whose firing speeds (time constants) differ from the IO layer's, each with recurrent connections. The time constant increases from the Cf layer to the Cs layer, so the response to input becomes correspondingly slower. Input information passes through the Cf and Cs layers and is emitted at the Output layer. There is no direct connection between the IO layer and the Cs layer; they interact via the Cf layer. MTRNN enables robot motion learning: the Cf layer represents (learns) motion primitives, and the Cs layer represents their combinations. Because MTRNN is more interpretable than LSTM, it is frequently used in the Ogata Laboratory.
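For reference, the update implemented by MTRNNCell below is standard leaky-integrator dynamics: each neuron's internal state is a weighted mix of its previous value and the new synaptic input, with the time constant τ controlling the mixing rate (a larger τ means a slower response, hence fast_tau < slow_tau). Restating the code in MTRNNCell.forward, with f the activation (tanh by default):

u_i(t) = \bigl(1 - \tfrac{1}{\tau_i}\bigr)\, u_i(t-1) + \tfrac{1}{\tau_i} \sum_j w_{ij}\, h_j(t-1), \qquad h_i(t) = f\bigl(u_i(t)\bigr)

where the sum runs over all units connected to neuron i (for the fast context this also includes the external input x).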

MTRNN

MTRNN.MTRNNCell

Bases: nn.Module

Multiple Timescale RNN.

Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales. This is based on the idea of hierarchical organization in human cognitive functions.

Parameters:

input_dim (int, required): Number of input features.
fast_dim (int, required): Number of fast context neurons.
slow_dim (int, required): Number of slow context neurons.
fast_tau (float, required): Time constant of the fast context.
slow_tau (float, required): Time constant of the slow context.
activation (str, default 'tanh'): If set to None, no activation is applied (i.e. "linear" activation: a(x) = x).
use_bias (bool, default False): Whether the layers use a bias vector.
use_pb (bool, default False): Whether a parametric bias (pb) vector is added to the slow context input.

Yuichi Yamashita, Jun Tani, "Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment.", PLoS Computational Biology, 4(11):e1000220, 2008. https://doi.org/10.1371/journal.pcbi.1000220
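A minimal instantiation sketch. The dimensions and time constants below are illustrative only, and the import path is an assumption (adjust it to wherever MTRNN.py lives in your project):

import torch
from MTRNN import MTRNNCell  # assumed import path

# slow_tau > fast_tau, so the slow (Cs) context responds more slowly than the fast (Cf) one
cell = MTRNNCell(input_dim=10, fast_dim=60, slow_dim=20, fast_tau=2.0, slow_tau=12.0)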

Source code in ja/docs/zoo/src/MTRNN.py
class MTRNNCell(nn.Module):
    #:: MTRNNCell
    """Multiple Timescale RNN.

    Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales.
    This is based on the idea of hierarchical organization in human cognitive functions.

    Arguments:
        input_dim (int): Number of input features.
        fast_dim (int): Number of fast context neurons.
        slow_dim (int): Number of slow context neurons.
        fast_tau (float): Time constant value of fast context.
        slow_tau (float): Time constant value of slow context.
        activation (string, optional): If set to `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).
        use_bias (Boolean, optional): Whether the layer uses a bias vector. The default is False.
        use_pb (Boolean, optional): Whether a parametric bias (pb) vector is added to the slow context input. The default is False.

    Yuichi Yamashita, Jun Tani,
    "Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment.",
    PLoS Computational Biology, 4(11):e1000220, 2008.
    https://doi.org/10.1371/journal.pcbi.1000220
    """

    def __init__(
        self,
        input_dim,
        fast_dim,
        slow_dim,
        fast_tau,
        slow_tau,
        activation="tanh",
        use_bias=False,
        use_pb=False,
    ):
        super(MTRNNCell, self).__init__()

        self.input_dim = input_dim
        self.fast_dim = fast_dim
        self.slow_dim = slow_dim
        self.fast_tau = fast_tau
        self.slow_tau = slow_tau
        self.use_bias = use_bias
        self.use_pb = use_pb

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = get_activation_fn(activation)
        else:
            self.activation = activation

        # Input Layers
        self.i2f = nn.Linear(input_dim, fast_dim, bias=use_bias)

        # Fast context layer
        self.f2f = nn.Linear(fast_dim, fast_dim, bias=False)
        self.f2s = nn.Linear(fast_dim, slow_dim, bias=use_bias)

        # Slow context layer
        self.s2s = nn.Linear(slow_dim, slow_dim, bias=False)
        self.s2f = nn.Linear(slow_dim, fast_dim, bias=use_bias)

    def forward(self, x, state=None, pb=None):
        """Forward propagation of the MTRNN.

        Arguments:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
            state (list): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
                   If None, initialize states to zeros.
            pb (torch.Tensor, optional): Parametric bias vector added to the slow context input. Used if self.use_pb is set to True.

        Returns:
            new_h_fast (torch.Tensor): Updated fast context state.
            new_h_slow (torch.Tensor): Updated slow context state.
            new_u_fast (torch.Tensor): Updated fast internal state.
            new_u_slow (torch.Tensor): Updated slow internal state.
        """
        batch_size = x.shape[0]
        if state is not None:
            prev_h_fast, prev_h_slow, prev_u_fast, prev_u_slow = state
        else:
            device = x.device
            prev_h_fast = torch.zeros(batch_size, self.fast_dim).to(device)
            prev_h_slow = torch.zeros(batch_size, self.slow_dim).to(device)
            prev_u_fast = torch.zeros(batch_size, self.fast_dim).to(device)
            prev_u_slow = torch.zeros(batch_size, self.slow_dim).to(device)

        # Update of fast internal state
        new_u_fast = (1.0 - 1.0 / self.fast_tau) * prev_u_fast + 1.0 / self.fast_tau * (
            self.i2f(x) + self.f2f(prev_h_fast) + self.s2f(prev_h_slow)
        )

        # Update of slow internal state
        _input_slow = self.f2s(prev_h_fast) + self.s2s(prev_h_slow)
        if pb is not None:
            _input_slow += pb

        new_u_slow = (1.0 - 1.0 / self.slow_tau) * prev_u_slow + 1.0 / self.slow_tau * _input_slow

        # Compute the activation for both fast and slow context states
        new_h_fast = self.activation(new_u_fast)
        new_h_slow = self.activation(new_u_slow)

        return new_h_fast, new_h_slow, new_u_fast, new_u_slow

forward(x, state=None, pb=None)

Forward propagation of the MTRNN.

Parameters:

x (torch.Tensor, required): Input tensor of shape (batch_size, input_dim).
state (list, default None): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim). If None, states are initialized to zeros.
pb (torch.Tensor, default None): Parametric bias vector added to the slow context input. Used when self.use_pb is set to True.

Returns:

new_h_fast (torch.Tensor): Updated fast context state.
new_h_slow (torch.Tensor): Updated slow context state.
new_u_fast (torch.Tensor): Updated fast internal state.
new_u_slow (torch.Tensor): Updated slow internal state.

Source code in ja/docs/zoo/src/MTRNN.py
def forward(self, x, state=None, pb=None):
    """Forward propagation of the MTRNN.

    Arguments:
        x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
        state (list): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
               If None, initialize states to zeros.
        pb (torch.Tensor, optional): Parametric bias vector added to the slow context input. Used if self.use_pb is set to True.

    Returns:
        new_h_fast (torch.Tensor): Updated fast context state.
        new_h_slow (torch.Tensor): Updated slow context state.
        new_u_fast (torch.Tensor): Updated fast internal state.
        new_u_slow (torch.Tensor): Updated slow internal state.
    """
    batch_size = x.shape[0]
    if state is not None:
        prev_h_fast, prev_h_slow, prev_u_fast, prev_u_slow = state
    else:
        device = x.device
        prev_h_fast = torch.zeros(batch_size, self.fast_dim).to(device)
        prev_h_slow = torch.zeros(batch_size, self.slow_dim).to(device)
        prev_u_fast = torch.zeros(batch_size, self.fast_dim).to(device)
        prev_u_slow = torch.zeros(batch_size, self.slow_dim).to(device)

    # Update of fast internal state
    new_u_fast = (1.0 - 1.0 / self.fast_tau) * prev_u_fast + 1.0 / self.fast_tau * (
        self.i2f(x) + self.f2f(prev_h_fast) + self.s2f(prev_h_slow)
    )

    # Update of slow internal state
    _input_slow = self.f2s(prev_h_fast) + self.s2s(prev_h_slow)
    if pb is not None:
        _input_slow += pb

    new_u_slow = (1.0 - 1.0 / self.slow_tau) * prev_u_slow + 1.0 / self.slow_tau * _input_slow

    # Compute the activation for both fast and slow context states
    new_h_fast = self.activation(new_u_fast)
    new_h_slow = self.activation(new_u_slow)

    return new_h_fast, new_h_slow, new_u_fast, new_u_slow
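Building on the instantiation sketch above, the cell can be unrolled over a sequence by feeding the returned states back in on the next step (shapes follow the docstring; the data here is random, for illustration only):

import torch

seq = torch.randn(50, 8, 10)  # (seq_len, batch_size, input_dim)
state = None                  # zero-initialized inside the cell on the first step
fast_states = []
for x_t in seq:
    h_fast, h_slow, u_fast, u_slow = cell(x_t, state)
    state = (h_fast, h_slow, u_fast, u_slow)
    fast_states.append(h_fast)  # e.g. collect fast context activity for analysis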

MTRNN.BasicMTRNN

Bases: nn.Module

MTRNN Wrapper Module.

This module encapsulates the MTRNNCell, adding an output layer to it.

Parameters:

in_dim (int, required): Number of input features.
fast_dim (int, required): Number of fast context neurons.
slow_dim (int, required): Number of slow context neurons.
fast_tau (float, required): Time constant of the fast context.
slow_tau (float, required): Time constant of the slow context.
out_dim (int, default None): Number of output features. If None, set equal to in_dim.
activation (str, default 'tanh'): If set to None, no activation is applied (i.e. "linear" activation: a(x) = x).
Source code in ja/docs/zoo/src/MTRNN.py
class BasicMTRNN(nn.Module):
    #:: BasicMTRNN
    """MTRNN Wrapper Module.

    This module encapsulates the MTRNNCell, adding an output layer to it.

    Arguments:
        in_dim (int): Number of input features.
        fast_dim (int): Number of fast context neurons.
        slow_dim (int): Number of slow context neurons.
        fast_tau (float): Time constant value of fast context.
        slow_tau (float): Time constant value of slow context.
        out_dim (int, optional): Number of output features. If None, set equal to in_dim.
        activation (string, optional): If set to `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).
    """

    def __init__(
        self, in_dim, fast_dim, slow_dim, fast_tau, slow_tau, out_dim=None, activation="tanh"
    ):
        super(BasicMTRNN, self).__init__()

        if out_dim is None:
            out_dim = in_dim

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = get_activation_fn(activation)

        self.mtrnn = MTRNNCell(
            in_dim, fast_dim, slow_dim, fast_tau, slow_tau, activation=activation
        )
        # Output of RNN
        self.rnn_out = nn.Sequential(nn.Linear(fast_dim, out_dim), activation)

    def forward(self, x, state=None):
        """Forward propagation of the BasicMTRNN.

        Arguments:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
            state (list): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
                   If None, initialize states to zeros.

        Returns:
            y_hat (torch.Tensor): Output tensor of shape (batch_size, out_dim).
            rnn_hid (list): Updated states (h_fast, h_slow, u_fast, u_slow).
        """
        rnn_hid = self.mtrnn(x, state)
        y_hat = self.rnn_out(rnn_hid[0])

        return y_hat, rnn_hid

forward(x, state=None)

Forward propagation of the BasicMTRNN.

Parameters:

x (torch.Tensor, required): Input tensor of shape (batch_size, input_dim).
state (list, default None): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim). If None, states are initialized to zeros.

Returns:

y_hat (torch.Tensor): Output tensor of shape (batch_size, out_dim).
rnn_hid (list): Updated states (h_fast, h_slow, u_fast, u_slow).

Source code in ja/docs/zoo/src/MTRNN.py
def forward(self, x, state=None):
    """Forward propagation of the BasicMTRNN.

    Arguments:
        x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
        state (list): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
               If None, initialize states to zeros.

    Returns:
        y_hat (torch.Tensor): Output tensor of shape (batch_size, out_dim).
        rnn_hid (list): Updated states (h_fast, h_slow, u_fast, u_slow).
    """
    rnn_hid = self.mtrnn(x, state)
    y_hat = self.rnn_out(rnn_hid[0])

    return y_hat, rnn_hid
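A minimal next-step-prediction sketch with the wrapper, assuming teacher forcing on a recorded trajectory (the import path, shapes, and hyperparameter values are illustrative assumptions, not prescribed by the library):

import torch
from MTRNN import BasicMTRNN  # assumed import path

model = BasicMTRNN(in_dim=10, fast_dim=60, slow_dim=20, fast_tau=2.0, slow_tau=12.0)
criterion = torch.nn.MSELoss()

traj = torch.randn(50, 8, 10)  # (seq_len, batch_size, in_dim)
state, loss = None, 0.0
for t in range(traj.shape[0] - 1):
    y_hat, state = model(traj[t], state)         # predict the next frame
    loss = loss + criterion(y_hat, traj[t + 1])  # compare against the actual next frame
loss.backward()  # backpropagate through the whole unrolled sequence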

  1. Yuichi Yamashita and Jun Tani. "Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment." PLoS Computational Biology, 4(11):e1000220, 2008. https://doi.org/10.1371/journal.pcbi.1000220