SARNN
SARNN "explicitly" extracts the spatial coordinates of critical positions in the task, such as target objects and arms, from images, and learns the coordinates along with the robot's joint angles of the robot1. This greatly improves robustness to changes in the object's position. The figure below illustrates the network structure of SARNN, which consists of an encoder responsible for extracting image features $f_t$ and object position coordinates $p_t$ from camera images $i_t$, a recurrent module that learns the temporal changes in the robot's joint angles and object position coordinates $p_t$, and a decoder that reconstructs images based on the image features $f_t$ and heat maps $\hat h_{t+1}$.
The upper part of the encoder and decoder consists of CNN layers (Convolutional and Transposed Convolutional layers), which extract and reconstruct the color and shape information of objects from image features. The lower part uses a Spatial Softmax layer to extract the 2D position information of objects. The recurrent module predicts only the object's position information $\hat p_{t+1}$, which by itself is not sufficient for the decoder to reconstruct an image. Therefore, a heat map $\hat h_{t+1}$ centered on the predicted coordinates $\hat p_{t+1}$ is generated. By multiplying it with the image features extracted by the upper CNN, a predicted image $\hat i_{t+1}$ is generated from the information around the predicted attention point.
Here, we show the implementation method and model classes for the distinctive features of SARNN: Spatial Attention Mechanism, Heatmap Generator, Loss Scheduler, and Backpropagation Through Time.
Spatial Attention Mechanism
The spatial attention mechanism emphasizes important information (pixels with large values) by applying softmax to the feature map, and then extracts the position information of the emphasized pixels using position encoding. The figure below illustrates the result of the spatial attention mechanism, where important position information (the red dots) is extracted by applying softmax to a "pseudo" feature map generated from two randomly placed Gaussian distributions. Since CNN feature maps contain diverse information, a plain softmax does not emphasize them effectively; to further sharpen the features, it is critical to use softmax with temperature. The effect of the temperature can be observed by adjusting the temperature parameter in the provided example program. The red dots in the figure indicate the positions extracted by spatial softmax; because they appear at the center of one of the Gaussian distributions, the position information is extracted accurately.
[SOURCE] SpatialSoftmax.py
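As a reference, the following is a minimal, hedged sketch of a spatial softmax with temperature (this is not the SpatialSoftmax.py implementation itself; the class name, the normalized coordinate range [-1, 1], and the tensor shapes are assumptions):

```python
import torch
import torch.nn as nn

class SpatialSoftmaxSketch(nn.Module):
    """Extracts the expected (x, y) coordinate of each feature-map channel."""

    def __init__(self, height, width, temperature=1e-4):
        super().__init__()
        self.temperature = temperature
        # Position encoding: a normalized (x, y) coordinate for every pixel.
        pos_y, pos_x = torch.meshgrid(
            torch.linspace(-1.0, 1.0, height),
            torch.linspace(-1.0, 1.0, width),
            indexing="ij",
        )
        self.register_buffer("pos_x", pos_x.reshape(-1))
        self.register_buffer("pos_y", pos_y.reshape(-1))

    def forward(self, feature_map):
        # feature_map: (batch, channels, height, width)
        b, c, h, w = feature_map.shape
        flat = feature_map.reshape(b, c, h * w)
        # Softmax with temperature: a small temperature sharpens the
        # distribution so only the strongest activations remain.
        attention = torch.softmax(flat / self.temperature, dim=-1)
        # Expected coordinates under the attention distribution.
        expected_x = torch.sum(attention * self.pos_x, dim=-1)
        expected_y = torch.sum(attention * self.pos_y, dim=-1)
        points = torch.stack([expected_x, expected_y], dim=-1)  # (b, c, 2)
        return points, attention.reshape(b, c, h, w)
```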
Heatmap Generator
The heatmap generator produces a heatmap centered on specific pixel coordinates that represent the position information. The figure below illustrates a heatmap generated by the heatmap generator, centered on the position extracted by the spatial attention mechanism (the red dot in the figure). The size of the heatmap can be adjusted with the heatmap_size parameter. A smaller heatmap considers only the information near the attention point, while a larger one includes some of the surrounding information in the generated image. Note that if the heatmap is too small, the corresponding predicted image $\hat i_{t+1}$ may not be generated, while if it is too large, the sensitivity parameter may need to be adjusted to account for changes in the environment, such as the background and obstacles.
[SOURCE] InverseSpatialSoftmax.py
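For reference, a minimal sketch of a heatmap generator in the same spirit as the inverse spatial softmax is shown below (a hedged sketch, not the InverseSpatialSoftmax.py implementation; the Gaussian form and the normalized coordinate range are assumptions):

```python
import torch
import torch.nn as nn

class HeatmapGeneratorSketch(nn.Module):
    """Generates a Gaussian heatmap centered on each attention point."""

    def __init__(self, height, width, heatmap_size=0.1):
        super().__init__()
        self.heatmap_size = heatmap_size
        pos_y, pos_x = torch.meshgrid(
            torch.linspace(-1.0, 1.0, height),
            torch.linspace(-1.0, 1.0, width),
            indexing="ij",
        )
        self.register_buffer("pos_x", pos_x)  # (height, width)
        self.register_buffer("pos_y", pos_y)

    def forward(self, points):
        # points: (batch, k_dim, 2), normalized (x, y) coordinates in [-1, 1]
        px = points[..., 0][..., None, None]  # (batch, k_dim, 1, 1)
        py = points[..., 1][..., None, None]
        # Squared distance from every pixel to each attention point.
        squared_dist = (self.pos_x - px) ** 2 + (self.pos_y - py) ** 2
        # A smaller heatmap_size gives a narrower peak around the point.
        heatmap = torch.exp(-squared_dist / self.heatmap_size)
        return heatmap  # (batch, k_dim, height, width)
```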
Loss Scheduler
The loss scheduler is a callback function that gradually assigns weight to the prediction error of the attention points based on the number of epochs; it is an important element of SARNN training. The figure below shows the weighting curve for each curve_name argument, where the horizontal axis is the number of epochs and the vertical axis is the weighting value. The weighting starts at 0 and gradually reaches the maximum weighting value (e.g. 0.1) at the epoch specified by decay_end (e.g. 100). Note that the maximum weighting value is determined by the __call__ method. This class supports five types of curves, as shown in the figure: linear, S-curve, inverse S-curve, decay, and acceleration interpolation.
The reason for using the loss scheduler in SARNN training is to allow the CNN filters to be trained freely in the early stages. Since the encoder and decoder weights of SARNN are randomly initialized, visual features may not be correctly extracted or learned during the initial learning phase.
If the attention-point prediction error obtained in such a situation is backpropagated, the attention points may not be correctly directed to the target object; instead, the model learns whatever attention points minimize the image prediction error. Therefore, by ignoring the attention-point prediction error at the initial stage of learning, it is possible to obtain attention points that focus only on the target object. The attention-point prediction error is then learned once the CNN filters have finished learning the visual features. The decay_end parameter sets this CNN learning period; it is typically set to about 1000 epochs, but may need to be adjusted depending on the task.
[SOURCE] callback.py
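As an illustration of the idea, the hedged sketch below implements only a linear ramp; the class name, the internal epoch counter, and the argument of __call__ are assumptions, and the actual callback.py additionally provides the S-curve, inverse S-curve, decay, and acceleration curves:

```python
class LinearLossSchedulerSketch:
    """Returns the attention-point loss weight for the current epoch."""

    def __init__(self, decay_end=1000):
        self.decay_end = decay_end
        self.epoch = 0

    def __call__(self, max_weight=0.1):
        # Ramp linearly from 0 to max_weight over decay_end epochs,
        # then stay constant; the maximum weight is passed to __call__.
        weight = max_weight * min(self.epoch / self.decay_end, 1.0)
        self.epoch += 1
        return weight

# Example: multiply the attention-point loss by scheduler(0.1) once per epoch.
```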
Backpropagation Through Time
We use Backpropagation Through Time (BPTT) to train the model on time-series data [2]. In an RNN, the internal state $h_{t}$ at each time step depends on the internal state $h_{t-1}$ at the previous time step $t-1$. In BPTT, the loss is computed at each time step, accumulated over the sequence, and the gradients are then propagated backwards through the unrolled network. Specifically, the model takes the input image $i_t$ and joint angles $a_{t}$ and outputs the next state ($\hat i_{t+1}$, $\hat a_{t+1}$). The mean squared error (MSE) between the predictions and the true values ($i_{t+1}$, $a_{t+1}$) over all sequences is computed with nn.MSELoss, and error backpropagation is performed on the resulting loss value. Since the parameters at each time step are used at all subsequent time steps, backpropagation is performed over the temporally unrolled network.
Lines 47-54 show that SARNN computes not only the image loss and the joint angle loss, but also the prediction loss of the attention points. Since no ground truth is available for the attention points, a bidirectional loss [3] is used to learn them. Specifically, the model updates its weights to minimize the loss between the attention points $\hat p_{t+1}$ predicted by the RNN at each time step and the attention points $p_{t+1}$ extracted by the CNN at the same time step $t+1$. Through this bidirectional loss, the LSTM learns the time-series relationship between the attention points and the joint angles. This approach not only eliminates redundant image predictions, but also encourages the CNN to extract attention points that are critical for motion prediction.
In addition, loss_weights assigns a weight to each modality's loss, thereby determining how much the training focuses on each modality. In deep predictive learning, the joint angles are learned intensively because they directly affect the robot's motion commands. However, if the image information is not adequately learned, image and joint-angle learning may not be properly integrated, making it difficult to predict joint angles that correspond to the image information. Therefore, the weighting coefficients need to be adjusted for the model and the task. In our experience, the weighting factor is often set to 1.0 for all modalities, or to 0.1 for images only.
[SOURCE] fullBPTT.py
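The following hedged sketch illustrates one training step in the spirit of fullBPTT.py; the function signature, variable names, and loss bookkeeping are assumptions, not the EIPL code itself:

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, images, joints, loss_weights=(1.0, 1.0), pt_weight=0.1):
    # images: (batch, seq_len, C, H, W), joints: (batch, seq_len, joint_dim)
    criterion = nn.MSELoss()
    state = None
    img_loss, joint_loss = 0.0, 0.0
    enc_list, dec_list = [], []
    seq_len = images.shape[1]
    for t in range(seq_len - 1):
        y_image, y_joint, enc_pts, dec_pts, state = model(images[:, t], joints[:, t], state)
        # Predictions at time t are compared with the ground truth at t+1.
        img_loss = img_loss + criterion(y_image, images[:, t + 1])
        joint_loss = joint_loss + criterion(y_joint, joints[:, t + 1])
        enc_list.append(enc_pts)  # points extracted by the CNN at time t
        dec_list.append(dec_pts)  # points predicted by the RNN for time t+1
    # Bidirectional point loss: RNN predictions for t+1 vs. the points the CNN
    # actually extracts at t+1 (encoder sequence shifted by one step).
    pt_loss = criterion(torch.stack(dec_list[:-1]), torch.stack(enc_list[1:]))
    loss = loss_weights[0] * img_loss + loss_weights[1] * joint_loss + pt_weight * pt_loss
    optimizer.zero_grad()
    loss.backward()  # gradients flow back through the unrolled time steps
    optimizer.step()
    return loss.item()
```

Here pt_weight would be supplied by the loss scheduler described above, so that the attention-point loss is ignored in the early epochs.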
model.SARNN
Bases: nn.Module
SARNN: Spatial Attention with Recurrent Neural Network. This model "explicitly" extracts positions from the image that are important to the task, such as the target object or arm position, and learns the time-series relationship between these positions and the robot's joint angles. This enables the generation of motions that are robust to changes in object position and lighting.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rec_dim | int | The dimension of the recurrent state in the LSTM cell. | required |
| k_dim | int | The dimension of the attention points. | 5 |
| joint_dim | int | The dimension of the joint angles. | 14 |
| temperature | float | The temperature parameter for the softmax function. | 0.0001 |
| heatmap_size | float | The size of the heatmap in the InverseSpatialSoftmax layer. | 0.1 |
| kernel_size | int | The size of the convolutional kernel. | 3 |
| activation | str | The name of the activation function. | 'lrelu' |
| im_size | list | The size of the input image [height, width]. | [128, 128] |
Source code in en/docs/model/src/model.py
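A hypothetical instantiation using the defaults listed above (the import path and the choice of rec_dim=50 are assumptions; adapt them to your installation):

```python
from eipl.model import SARNN  # import path assumed

model = SARNN(
    rec_dim=50,          # LSTM recurrent state dimension (required; value assumed)
    k_dim=5,             # number of attention points
    joint_dim=14,        # number of joint angles
    temperature=1e-4,    # softmax temperature of the spatial attention
    heatmap_size=0.1,    # heatmap size of the InverseSpatialSoftmax layer
    kernel_size=3,
    activation="lrelu",
    im_size=[128, 128],
)
```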
forward(xi, xv, state=None)
Forward pass of the SARNN module. Predicts the image, joint angles, and attention points for the next state (t+1) based on the image and joint angles at the current state (t). By feeding the predicted joint angles to the robot as control commands, sequential motion can be generated from the sensor information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| xi | torch.Tensor | Input image tensor of shape (batch_size, channels, height, width). | required |
| xv | torch.Tensor | Input vector tensor of shape (batch_size, input_dim). | required |
| state | tuple | Initial hidden state and cell state of the LSTM cell. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| y_image | torch.Tensor | Decoded image tensor of shape (batch_size, channels, height, width). |
| y_joint | torch.Tensor | Decoded joint prediction tensor of shape (batch_size, joint_dim). |
| enc_pts | torch.Tensor | Encoded points tensor of shape (batch_size, k_dim * 2). |
| dec_pts | torch.Tensor | Decoded points tensor of shape (batch_size, k_dim * 2). |
| rnn_hid | tuple | Tuple containing the hidden state and cell state of the LSTM cell. |
Source code in en/docs/model/src/model.py
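A hedged sketch of closed-loop motion generation with forward() (it assumes a SARNN instance named model, uses dummy tensors in place of real sensor data, and omits data normalization):

```python
import torch

state = None
xi = torch.zeros(1, 3, 128, 128)  # current camera image (batch, channels, H, W)
xv = torch.zeros(1, 14)           # current joint angles (batch, joint_dim)
for step in range(100):
    y_image, y_joint, enc_pts, dec_pts, state = model(xi, xv, state)
    # y_joint is sent to the robot as the next control command; the next
    # camera frame then replaces xi and the loop repeats.
    xv = y_joint
```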
1. Hideyuki Ichiwara, Hiroshi Ito, Kenjiro Yamamoto, Hiroki Mori, and Tetsuya Ogata. Contact-rich manipulation of a flexible object based on deep predictive learning using vision and tactility. In 2022 International Conference on Robotics and Automation (ICRA), 5375–5381. IEEE, 2022.
2. David E. Rumelhart, Geoffrey E. Hinton, and Ronald J. Williams. Learning representations by back-propagating errors. Nature, 323(6088):533–536, 1986.
3. Hyogo Hiruma, Hiroshi Ito, Hiroki Mori, and Tetsuya Ogata. Deep active visual attention for real-time robot motion generation: emergence of tool-body assimilation and adaptive tool-use. IEEE Robotics and Automation Letters, 7(3):8550–8557, 2022.