びぼーろぐ

備忘録としての勉強のログです。淡々と学んだことをログって行くので、雑な記事が多いです。

FITNETS: HINTS FOR THIN DEEP NETS

arxiv.org
2014/12

一言でいうと

knowledge distillationの拡張。今回の生徒モデルは教師モデルよりも深くて細長いモデル。教師モデルの出力だけを使うのではなく、中間特徴も生徒モデルの最終結果(出力層の結果)を改善させるために、ヒントとして利用する。生徒モデルの中間層のパラメータ数は、教師モデルのものよりも少ないため、それらをマッチングさせるマップの役割を果たすパラメータが導入される。これにより、深い生徒モデルはおよそ10.4倍のパラメータを削減できた。 得られたdeep and thinな生徒モデルを(FitNet)と呼ぶ。

結論

教師の中間層から中間レベルのヒントを教えてやることで、とても深い細長い生徒モデルでは、パラメータを少なくし、かつ教師モデルよりも高速で汎化性能の高いモデルを作ることができた。
データセットベンチマークから、キャパシティの少ないDNNは10倍のパラメータを持つネットワークよりも比較的良い特徴表現を得ることがわかった。

手法

knowledge distillation

knowledge distillationの損失は以下の式で取る

 \mathcal { L } _ { K D } \left( \mathbf { W } _ { \mathbf { S } } \right) = \mathcal { H } \left( \mathbf { y } _ { \text { true } } , \mathbf { P } _ { \mathbf { S } } \right) + \lambda \mathcal { H } \left( \mathbf { P } _ { \mathrm { T } } ^ { \tau } , \mathbf { P } _ { \mathbf { S } } ^ { \tau } \right)

ここで、 \mathcal { H } はクロスエントロピー関数で、 \lambdaは生徒モデルがハードターゲットかソフトターゲットのどちらから多く学ぶかを決める比率パラメータ。 このままだと、同じ深さの生徒モデルしか教えることができない。

Hint training

教師モデルの中間層の特徴をhintとして、生徒モデルの学習プロセスを指導する。 guided layer(FitNet(生徒モデル)の中間層)は教師のhint layerから学習。 hint layerとguided layerのペアは過剰に正則化(regularized)されていない層から選ばないと行けない。 教師モデルは生徒モデルよりもwiderなので、特徴サイズを一致させるためにregressorをguided layerに挿入する必要がある。
hint learningの損失関数が以下

 \mathcal { L } _ { H T } \left( \mathbf { W } _ { \text { Guided } } , \mathbf { W } _ { \mathbf { r } } \right) = \frac { 1 } { 2 } \left| u _ { h } \left( \mathbf { x } ; \mathbf { W } _ { \text { Hint } } \right) - r \left( v _ { g } \left( \mathrm { x } ; \mathbf { W } _ { \text { Guided } } \right) ; \mathbf { W } _ { \mathrm { r } } \right) \right| ^ { 2 }
 u _ { h } v _ { g } W _ { Hint } W _ { Guided }を対応付けるネスト関数で、 rは次元を一致させるregressor。

FitNetのステージごとの訓練

f:id:taku-buntu:20190108165442p:plain
生徒モデルのヒントを使った訓練

(a):教師モデルを学習することと、生徒モデルをランダムに初期化する。
(b):FitNetのguided layerのパラメータ W _ { Guided }の上に W _ { r }をつけて、 \mathbf { W } _ { Guided }を上式の L _ { HT }が最小化するように学習。
(c): L _ { KD }が最小化するように全体のFitNetのパラメータ W _ { S }を訓練。

f:id:taku-buntu:20190108170551p:plain
FitNEtsのステージごとの訓練