Overview
MTRNN is a type of RNN consisting of a hierarchical group of neurons with different firing rates [1]. It consists of three layers: an input-output (IO) layer and two context layers, the fast context (Cf) layer and the slow context (Cs) layer. Each context layer has recurrent connections and its own firing rate (time constant). The time constant increases from the Cf layer to the Cs layer, so the Cs layer responds more slowly to its input. Input information enters through the IO layer, propagates to the Cf layer and from there to the Cs layer, and the output is produced from the Cf layer. There is no direct connection between the IO and Cs layers; they interact only through the Cf layer. When the MTRNN is used for robot behavior learning, the Cf layer represents behavioral primitives and the Cs layer represents combinations of these primitives. Compared with LSTM, MTRNN is easier to interpret, since the role of each context layer is tied to its timescale, and it is widely used in our lab.
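To illustrate what "slower response" means, the sketch below drives a single leaky-integrator unit with a constant input switched on at t=0 and compares a small and a large time constant. It assumes the standard MTRNN update u_t = (1 - 1/tau) * u_{t-1} + (1/tau) * input from Yamashita and Tani (2008); the function name `leaky_response` and the tau values are hypothetical and not taken from MTRNN.py.

```python
def leaky_response(tau: float, steps: int, drive: float = 1.0) -> list[float]:
    """Internal state u of one leaky unit over time (hypothetical helper).

    The unit starts at u = 0 and receives a constant input `drive` from t = 0.
    Assumed update rule: u_t = (1 - 1/tau) * u_{t-1} + (1/tau) * drive.
    """
    u = 0.0
    trace = []
    for _ in range(steps):
        u = (1.0 - 1.0 / tau) * u + (1.0 / tau) * drive
        trace.append(u)
    return trace


fast = leaky_response(tau=2.0, steps=10)   # Cf-like unit: small time constant
slow = leaky_response(tau=50.0, steps=10)  # Cs-like unit: large time constant
print(f"after 10 steps: fast u = {fast[-1]:.3f}, slow u = {slow[-1]:.3f}")
# after 10 steps: fast u = 0.999, slow u = 0.183
```

With a constant drive the state approaches the input as u_t = 1 - (1 - 1/tau)^t, so the larger the time constant, the longer the unit takes to follow a change in its input.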
MTRNN.MTRNNCell
Bases: nn.Module
Multiple Timescale RNN.
Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales. This is based on the idea of hierarchical organization in human cognitive functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim | int | Number of input features. | required |
fast_dim | int | Number of fast context (Cf) neurons. | required |
slow_dim | int | Number of slow context (Cs) neurons. | required |
fast_tau | float | Time constant of the fast context. | required |
slow_tau | float | Time constant of the slow context. | required |
activation | string | Name of the activation function. If you set … | 'tanh' |
use_bias | Boolean | Whether the layers use a bias vector. | False |
use_pb | Boolean | Whether the cell uses a parametric bias (pb) vector. | False |
Reference: Yuichi Yamashita and Jun Tani, "Emergence of Functional Hierarchy in a Multiple Timescale Neural Network Model: A Humanoid Robot Experiment," PLoS Computational Biology, 4(11):e1000220, 2008.
Source code in en/docs/zoo/src/MTRNN.py
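For reference, one step of the cell can be written with the standard MTRNN leaky-integrator update of Yamashita and Tani (2008). This is a sketch under assumptions: the weight names $W_{if}, W_{ff}, W_{sf}, W_{fs}, W_{ss}$ are illustrative, bias terms are omitted (matching the default use_bias=False), $\phi$ is the activation (tanh by default), and adding the pb vector $p$ to the slow-context input is an assumption about how use_pb is wired, not something this page states.

$$
u^{f}_{t} = \left(1 - \frac{1}{\tau_{f}}\right) u^{f}_{t-1} + \frac{1}{\tau_{f}} \left( W_{if}\, x_{t} + W_{ff}\, h^{f}_{t-1} + W_{sf}\, h^{s}_{t-1} \right)
$$

$$
u^{s}_{t} = \left(1 - \frac{1}{\tau_{s}}\right) u^{s}_{t-1} + \frac{1}{\tau_{s}} \left( W_{fs}\, h^{f}_{t-1} + W_{ss}\, h^{s}_{t-1} + p \right)
$$

$$
h^{f}_{t} = \phi\!\left(u^{f}_{t}\right), \qquad h^{s}_{t} = \phi\!\left(u^{s}_{t}\right)
$$

Here $\tau_f$ and $\tau_s$ correspond to fast_tau and slow_tau. The input $x_t$ appears only in the Cf update, reflecting the absence of a direct IO-to-Cs connection.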
forward(x, state=None, pb=None)
Forward propagation of the MTRNN.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | Input tensor of shape (batch_size, input_dim). | required |
state | list | Previous states (h_fast, h_slow, u_fast, u_slow). h_fast and u_fast have shape (batch_size, fast_dim); h_slow and u_slow have shape (batch_size, slow_dim). If None, all states are initialized to zeros. | None |
pb | torch.Tensor | Parametric bias (pb) vector. Used only if self.use_pb is set to True. | None |
Returns:
Name | Type | Description |
---|---|---|
new_h_fast | torch.Tensor | Updated fast context state. |
new_h_slow | torch.Tensor | Updated slow context state. |
new_u_fast | torch.Tensor | Updated fast internal state. |
new_u_slow | torch.Tensor | Updated slow internal state. |
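The forward contract above (zero states when state is None, a 4-tuple returned and fed back at the next step) can be sketched as a small PyTorch module. This is a hypothetical re-implementation for illustration, not the code in MTRNN.py: the class name `MTRNNCellSketch`, the layer names (i2f, f2f, s2f, f2s, s2s), the restriction to 'tanh' or None as activation, and adding pb to the slow-context input are all assumptions.

```python
import torch
import torch.nn as nn


class MTRNNCellSketch(nn.Module):
    """Hypothetical MTRNN cell following the documented interface."""

    def __init__(self, input_dim, fast_dim, slow_dim, fast_tau, slow_tau,
                 activation="tanh", use_bias=False, use_pb=False):
        super().__init__()
        self.fast_dim, self.slow_dim = fast_dim, slow_dim
        self.fast_tau, self.slow_tau = fast_tau, slow_tau
        self.use_pb = use_pb
        # Assumption: only "tanh" or None (identity) are supported here.
        if activation == "tanh":
            self.activation = torch.tanh
        elif activation is None:
            self.activation = nn.Identity()
        else:
            raise ValueError(f"unsupported activation in this sketch: {activation!r}")
        # Connections into the fast context: IO -> Cf, Cf -> Cf, Cs -> Cf.
        self.i2f = nn.Linear(input_dim, fast_dim, bias=use_bias)
        self.f2f = nn.Linear(fast_dim, fast_dim, bias=use_bias)
        self.s2f = nn.Linear(slow_dim, fast_dim, bias=use_bias)
        # Connections into the slow context: Cf -> Cs, Cs -> Cs (no IO -> Cs).
        self.f2s = nn.Linear(fast_dim, slow_dim, bias=use_bias)
        self.s2s = nn.Linear(slow_dim, slow_dim, bias=use_bias)

    def forward(self, x, state=None, pb=None):
        if state is None:
            # Zero initial states with the batch size, dtype and device of x.
            zeros_f = x.new_zeros(x.shape[0], self.fast_dim)
            zeros_s = x.new_zeros(x.shape[0], self.slow_dim)
            state = (zeros_f, zeros_s, zeros_f, zeros_s)
        h_fast, h_slow, u_fast, u_slow = state

        fast_in = self.i2f(x) + self.f2f(h_fast) + self.s2f(h_slow)
        slow_in = self.f2s(h_fast) + self.s2s(h_slow)
        if self.use_pb and pb is not None:
            slow_in = slow_in + pb  # assumed shape: (batch_size, slow_dim)

        # Leaky integration: a larger tau gives a smaller update per step.
        new_u_fast = (1 - 1 / self.fast_tau) * u_fast + fast_in / self.fast_tau
        new_u_slow = (1 - 1 / self.slow_tau) * u_slow + slow_in / self.slow_tau
        new_h_fast = self.activation(new_u_fast)
        new_h_slow = self.activation(new_u_slow)
        return new_h_fast, new_h_slow, new_u_fast, new_u_slow


# Usage: unroll the cell over a sequence, feeding the returned state back in.
torch.manual_seed(0)
cell = MTRNNCellSketch(input_dim=4, fast_dim=16, slow_dim=8, fast_tau=2.0, slow_tau=12.0)
x_seq = torch.randn(10, 3, 4)  # (time, batch_size, input_dim)
state = None
for x_t in x_seq:
    state = cell(x_t, state)
h_fast, h_slow, u_fast, u_slow = state
print(h_fast.shape, h_slow.shape)  # torch.Size([3, 16]) torch.Size([3, 8])
```

Because the cell processes one time step per call, the caller owns the loop over time; this is what makes it easy to switch between feeding ground-truth inputs and feeding back the model's own predictions.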
MTRNN.BasicMTRNN
Bases: nn.Module
MTRNN Wrapper Module.
This module encapsulates the MTRNNCell, adding an output layer to it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Number of input features. | required |
fast_dim | int | Number of fast context (Cf) neurons. | required |
slow_dim | int | Number of slow context (Cs) neurons. | required |
fast_tau | float | Time constant of the fast context. | required |
slow_tau | float | Time constant of the slow context. | required |
out_dim | int | Number of output features. If None, it is set equal to in_dim. | None |
activation | string | Name of the activation function. If you set … | 'tanh' |
forward(x, state=None)
Forward propagation of the BasicMTRNN.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | Input tensor of shape (batch_size, in_dim). | required |
state | list | Previous states (h_fast, h_slow, u_fast, u_slow). h_fast and u_fast have shape (batch_size, fast_dim); h_slow and u_slow have shape (batch_size, slow_dim). If None, all states are initialized to zeros. | None |
Returns:
Name | Type | Description |
---|---|---|
y_hat | torch.Tensor | Output tensor of shape (batch_size, out_dim). |
rnn_hid | list | Updated states (h_fast, h_slow, u_fast, u_slow). |
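A typical use of the wrapper is to unroll it over a sequence and stack the per-step outputs. The sketch below is hypothetical and self-contained: `BasicMTRNNSketch` is not the class in MTRNN.py, it folds the cell into two concatenated linear layers (equivalent to summing separate bias-free projections), uses tanh for the context layers, reads the output from the fast context h_fast (assumed, since the IO layer connects only to the Cf layer), and uses a linear output layer with no output activation.

```python
import torch
import torch.nn as nn


class BasicMTRNNSketch(nn.Module):
    """Hypothetical wrapper: one MTRNN step plus a linear output layer."""

    def __init__(self, in_dim, fast_dim, slow_dim, fast_tau, slow_tau, out_dim=None):
        super().__init__()
        out_dim = in_dim if out_dim is None else out_dim  # documented default
        self.dims, self.taus = (fast_dim, slow_dim), (fast_tau, slow_tau)
        # Cf receives [x, h_fast, h_slow]; Cs receives [h_fast, h_slow] only.
        self.to_fast = nn.Linear(in_dim + fast_dim + slow_dim, fast_dim, bias=False)
        self.to_slow = nn.Linear(fast_dim + slow_dim, slow_dim, bias=False)
        self.h2o = nn.Linear(fast_dim, out_dim)  # assumed: output from h_fast

    def forward(self, x, state=None):
        if state is None:
            f, s = self.dims
            z_f, z_s = x.new_zeros(x.shape[0], f), x.new_zeros(x.shape[0], s)
            state = (z_f, z_s, z_f, z_s)
        h_f, h_s, u_f, u_s = state
        tau_f, tau_s = self.taus
        # Both updates use the previous step's h_f and h_s.
        u_f = (1 - 1 / tau_f) * u_f + self.to_fast(torch.cat([x, h_f, h_s], dim=-1)) / tau_f
        u_s = (1 - 1 / tau_s) * u_s + self.to_slow(torch.cat([h_f, h_s], dim=-1)) / tau_s
        h_f, h_s = torch.tanh(u_f), torch.tanh(u_s)
        y_hat = self.h2o(h_f)
        return y_hat, [h_f, h_s, u_f, u_s]


# Usage: unroll over a sequence and collect one prediction per time step.
torch.manual_seed(0)
model = BasicMTRNNSketch(in_dim=5, fast_dim=20, slow_dim=6, fast_tau=2.0, slow_tau=10.0)
x_seq = torch.randn(8, 2, 5)  # (time, batch_size, in_dim)
state, outputs = None, []
for x_t in x_seq:
    y_hat, state = model(x_t, state)
    outputs.append(y_hat)
y_seq = torch.stack(outputs)  # (time, batch_size, out_dim)
print(y_seq.shape)  # torch.Size([8, 2, 5])
```

Leaving out_dim as None makes the output size match the input size, which suits next-step prediction where each output is compared with the following input.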
[1] Yuichi Yamashita and Jun Tani. Emergence of functional hierarchy in a multiple timescale neural network model: a humanoid robot experiment. PLoS Computational Biology, 4(11):e1000220, 2008.