概要
SARNNは、画像からタスクに重要な位置(作業対象物やアーム)の空間座標を「明示的」に抽出し、その座標とロボットの関節角度を学習することで、対象物の位置変化に対するロバスト性を大幅に向上させた1。 下図はSARNNのネットワーク構造を示しており、カメラ画像 $i_t$ から画像特徴量 $f_t$ と物体の位置座標 $p_t$ を抽出するEncoder部、ロボットの関節角度 $a_t$ と物体の位置情報の時系列変化を学習するRecurrent部、そして画像特徴量 $f_t$ とヒートマップ $\hat h_{t+1}$ に基づいて画像を再構成するDecoder部から構成される。 Encoder上段とDecoderのCNN 層(Convolution layer and Transposed convolutional layer)は、画像の特徴量抽出と再構成をすることで物体の色や形に関する情報を抽出する。 下段のCNNはSpatial Softmax層を用いることで物体の2D位置情報を抽出する。 Recurrent部は、物体の位置情報 $p_{t+1}$ のみを予測しているため、Decoderで画像を再構成するだけの十分な情報は含まれていない。 そこで予測位置情報 $p_{t+1}$ を中心としたヒートマップ $\hat h_{t+1}$ を生成し、上段のCNNで抽出した画像特徴量 $f_t$ と掛け合わせることで、予測注意点の周辺情報に基づいて予測画像 $\hat i_{t+1}$ を生成する。
ここでは、SARNNの特徴的な機能である空間的注意機構 、ヒートマップ生成機構、誤差スケジューラ、誤差逆伝播法の実装方法とモデルクラスを示す。
空間的注意機構
空間的注意機構は、特長マップにSoftmaxをかけ合わせることで重要な情報(ピクセル値が大きい)を強調したのちに、Position-Encodingを用いて強調されたピクセルの位置情報を抽出する。
下図は、空間的注意機構の処理結果を示しており、ランダムに生成した2つのガウス分布を用いて生成した「疑似」特徴マップ対し、Softmaxをかけ合わせることで重要な位置情報(赤点)を抽出する。
このとき、CNNの特徴マップには多様な情報が含まれているため、単純にSoftmaxをかけ合わせただけでは強調されないため、温度付きSoftmax を用いることが重要である。
温度付きSoftmaxの効果については、以下サンプルプログラムのパラメータ temperature
を調整して確認するとよい。
また図中の赤点はSpatialSoftmaxで抽出した位置を表しており、一方のガウス分布の中心に生成されていることから適切に位置情報が抽出できている。
[SOURCE] SpatialSoftmax.py | |
---|---|
|
ヒートマップ生成機構
ヒートマップ生成機構は、位置情報(特定のピクセル座標)を中心としたヒートマップを生成する。
下図は、空間的注意機構で抽出した位置(図中赤点)を中心にヒートマップ生成機構が生成したヒートマップを表している。
ヒートマップの大きさは、パラメータ heatmap_size
で設定することが可能であり、ヒートマップサイズが小さいと注意点近傍の情報のみ、大きいと周辺の情報も一部加味して画像生成を行う。
なお、ヒートマップがあまりにも小さいと適切な予測画像 $\hat i_{t+1}$ の生成ができない、また大きすぎると周囲の環境変化(背景や障害物)に敏感になるためパラメータの調整が必要である。
[SOURCE] InverseSpatialSoftmax.py | |
---|---|
|
誤差スケジューラ
誤差スケジューラとは、注意点の予測誤差をエポックに応じて徐々に重み付けする callback
であり、SARNNを学習させる上で重要な機能である。
下図は引数 curve_name
ごとの重み付け曲線を示しており、横軸はエポック数、縦軸は重み付けの値である。
誤差の重み付けは0から始まり、 decay_end
(例:100)で設定したエポックで重み付けの最大値(例:0.1)を返す。
なお、重み付け最大値は __call__
メソッドで指定する。
本クラスでは、図中に示す5種類の曲線(線形補間、S字補完、逆S字補完、減速補完、加速補完)に対応している。
SARNNの学習に誤差スケジューラを用いる理由として、学習初期段階ではCNNのフィルタを「自由に」学習させることが目的である。
SARNNのEncoderとDecoderはランダムに初期化されているため、学習初期段階では視覚画像中の特徴量を適切に抽出/学習できていない。
そのような状況で得られた注意点予測誤差を逆伝搬すると作業対象物に適切に注意点が向かず、「予測画像誤差」を最小にするような注意点が学習されてしまう。
そのため、学習初期段階では注意点の予測誤差は無視し、CNNのフィルタが特徴量を学習し終えた頃に注意点予測の誤差を学習させることで、作業に重要な対象物にのみ着目した注意点を獲得することが可能である。
decay_end
がCNNの学習タイミングを調整しており、通常1000エポック程度を設定しているが、タスクによっては調整が必要である。
[SOURCE] callback.py | |
---|---|
|
誤差逆伝播法
モデルの時系列学習を行うための誤差逆伝播アルゴリズムとしてBackpropagation Through Time(BPTT)を用いる2。
RNNでは、各時刻での内部状態 $h_{t}$ は、前時刻の時刻 $t-1$ の内部状態 $h_{t-1}$ に依存する。
BPTTでは、各時刻での誤差を計算し、それを遡って勾配を計算することで、各時刻でのパラメータの更新を行う。
具体的には、画像$i_t$と関節角度 $a_{t}$ をモデルに入力し、次状態($\hat i_{t+1}$, $ \hat a_{t+1}$)を出力(予測)する。
全シーケンスの予測値と真値($f_{t+1}$, $a_{t+1}$)の平均二乗誤差 nn.MSELoss
を計算し、誤差値loss
に基づいて誤差伝番を行う。
このとき、各時刻のパラメータが、その時刻より後のすべての時刻で使用されるため、時間的な展開を行いながら逆伝播を行う。
47-54行目に示すように、SARNNは画像誤差と関節角度誤差に加え、注意点の予測誤差も計算する。 注意点の真値は存在しないため、双方向誤差3 を用いて注意点の学習を行う。 具体的には時刻 $t$ でRNNが予測した注意点 $ \hat p_{t+1}$ と時刻 $t+1$ でCNNが抽出した注意点 $p_{t+1}$ が一致するように誤差を計算する。 この双方向誤差に基づいて、LSTMで注意点と関節角度の時系列関係を学習することで、冗長な画像予測を排除するだけでなく、動作予測に重要な注意点を予測するように誘導する。
またloss_weights
は、各モダリティ誤差の重みづけを行っており、どのモダリティを重点的に学習するかを決定する。
深層予測学習では、予測された関節角度がロボットの動作指令に直結するため関節角度を重点的に学習させる。
しかし逆に画像情報の学習が不十分な場合、画像と関節角度の統合学習が適切に行えない(画像情報に対応した関節角度予測が困難になる)ため、
重み付け係数はモデルやタスクに応じて調整することが求められる。
これまでの経験上、重み付け係数は全て1.0、もしくは画像のみ0.1にすることが多い。
[SOURCE] fullBPTT.py | |
---|---|
|
model.SARNN
Bases: nn.Module
SARNN: Spatial Attention with Recurrent Neural Network.
joint_dim
を設定することで、関節自由度が異なるロボットにも対応可能である。
一方でロボットの視覚画像 im_size
は128x128ピクセルのカラー画像に対応している。
カメラ画像のピクセルサイズを変更する場合、データによってはEncoderやDecoderのCNN層の数を調整する必要がある。
k_dim
は注意点の数を表しており、任意の数を設定することが可能である。活性化関数には LeakyReLU
を用いた。
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rec_dim |
int
|
RNNの隠れ層のサイズ |
required |
k_dim |
int
|
注意点の数 |
5
|
joint_dim |
int
|
ロボット関節角度の次元数 |
14
|
temperature |
float
|
温度付きSoftmaxのハパラメータ |
0.0001
|
heatmap_size |
float
|
ヒートマップのサイズ |
0.1
|
kernel_size |
int
|
CNNのカーネルサイズ |
3
|
activation |
str
|
活性化関数 |
'lrelu'
|
im_size |
list
|
入力画像のサイズ [縦、横]. |
[128, 128]
|
Source code in ja/docs/model/src/model.py
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
|
forward(xi, xv, state=None)
時刻(t)の画像と関節角度から、次時刻(t+1)の画像、関節角度、注意点を予測する。 予測した関節角度をロボットの制御コマンドとして入力することで、 センサ情報に基づいた逐次的な動作生成が可能である。
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xi |
torch.Tensor
|
時刻tの画像 [batch_size, channels, height, width] |
required |
xv |
torch.Tensor
|
時刻tの関節角度 [batch_size, input_dim] |
required |
state |
tuple
|
LSTMのセル状態と隠れ状態 [ [batch_size, rec_dim], [batch_size, rec_dim] ] |
None
|
Returns:
Name | Type | Description |
---|---|---|
y_image |
torch.Tensor
|
予測画像 [batch_size, channels, height, width] |
y_joint |
torch.Tensor
|
予測関節角度 [batch_size, joint_dim] |
enc_pts |
torch.Tensor
|
Spatial softmaxで抽出した注意点 [batch_size, k_dim * 2] |
dec_pts |
torch.Tensor
|
RNNが予測した注意点 [batch_size, k_dim * 2] |
rnn_hid |
tuple
|
LSTMのセル状態と隠れ状態 [ [batch_size, rec_dim], [batch_size, rec_dim] ] |
Source code in ja/docs/model/src/model.py
-
Hideyuki Ichiwara, Hiroshi Ito, Kenjiro Yamamoto, Hiroki Mori, and Tetsuya Ogata. Contact-rich manipulation of a flexible object based on deep predictive learning using vision and tactility. In 2022 International Conference on Robotics and Automation (ICRA), 5375–5381. IEEE, 2022. ↩
-
David E Rumelhart, Geoffrey E Hinton, and Ronald J Williams. Learning representations by back-propagating errors. nature, 323(6088):533–536, 1986. ↩
-
Hyogo Hiruma, Hiroshi Ito, Hiroki Mori, and Tetsuya Ogata. Deep active visual attention for real-time robot motion generation: emergence of tool-body assimilation and adaptive tool-use. IEEE Robotics and Automation Letters, 7(3):8550–8557, 2022. ↩