びぼーろぐ

備忘録としての勉強のログです。淡々と学んだことをログって行くので、雑な記事が多いです。

Do Deep Nets Really Need to be Deep?

arxiv.org
2014

一言でいうと

浅いネットで深いネットをmimic learning(ものまね学習)することで、モデル圧縮をする。

この論文のすごいところ

教師モデルを使って浅くても優秀な生徒モデルを作ることができることを示した

感想

優れた教師で学習させたら、生徒モデルも性能がよくなるっていうことが定量化されていたので良かった。この論文のウリは事前にHinton先生のKnowledge Distillationを読んでいたのであまり衝撃はなかった。

ざっく理論

ディープなネットでラベル付されたデータを使ってシャローネットを学習。
ディープネット=softmaxとクロスエントロピー関数で学習
シャローネット=183のp値(ディープネットのsoftmaxの出力)とのクロスエントロピーで学習

softmaxを通して確率化してしまうとlogit[10, 20, 30]も[-10, 0, 10]も同じ値になってしまう。これだと重要な情報がなくなってしまうため、logitをターゲットとして使うほうが良い。

回帰を真似る目的関数はこれ(SNN-MIMIC)

 \mathcal { L } ( W , \beta ) = \frac { 1 } { 2 T } \sum _ { t } \left| g \left( x ^ { ( t ) } ; W , \beta \right) - z ^ { ( t ) } \right| _ { 2 } ^ { 2 }
training data= {(x^{(1)}, z^{(1)}), ...}
W=入力特徴xと中間層間の重み行列
 \beta=中間層から出力層の重み
 g \left( x ^ { ( t ) } ; W , \beta \right) = \beta f \left( W x ^ { ( t ) } \right)= t^{th}番目の訓練データの出力

KLダイバージェンスをものまねモデルの損失関数に使うことも考えた。

 \mathrm { KL } \left( p _ { \text { teacher } } | p _ { \text { student } } \right) とL2ノルム損失を最小化する。

線形層を導入してmimic learningを高速化する

シャローネットはディープネットを真似しなきゃいけないから、一層あたりのパラメータ数が多く必要になる(そういうもんなのか...)。 つまりは、Wのサイズが大きくなるから学習に時間がかかる。
--> W \in \mathbb { R } ^ { H \times D }を低ランク近似行列U \in \mathbb { R } ^ { H \times k }V \in \mathbb { R } ^ { k \times D }にする。

 \mathcal { L } ( U , V , \beta ) = \frac { 1 } { 2 T } \sum _ { t } \left| \beta f \left( U V x ^ { ( t ) } \right) - z ^ { ( t ) } \right| _ { 2 } ^ { 2 }

結果

f:id:taku-buntu:20181229224240p:plain
omparison of shallow and deep models: phone error rate (PER) on TIMIT core test set

f:id:taku-buntu:20181229224329p:plain
ShallowMimicNetは100Mになると通常の(アンサンブルCNNと比べれば小さいがそれでも大きい)CNNと同等の結果を示すようになる

浅いものまねモデルは過学習しにくい、1層あたりのパラ数が増えれば精度も上がっていく。

  • なぜものまねモデルはオリジナルラベルで学習したときよりも精度が高いのか?
    • 一部の正解ラベルに間違いがあって、教師モデルはそれを排除している
    • 正解ラベルの状態では生徒のものまねモデルはその領域を表現できない。教師モデルはもっと柔らかい表現を生徒に教えることができる
    • 蒸留が有効である理由と同じ。教師モデルからの出力は0/1のラベルよりも多くの情報を含んでいる

f:id:taku-buntu:20181229231927p:plain
x軸: 教師モデルの精度、y軸: ものまねモデルの精度
優れた教師から学ぶと生徒モデルも良くなる

英単語

fidelity・・・忠実
logarithm・・・対数