びぼーろぐ

備忘録としての勉強のログです。淡々と学んだことをログって行くので、雑な記事が多いです。

Learning Efficient Object Detection Models with Knowledge Distillation

https://papers.nips.cc/paper/6676-learning-efficient-object-detection-models-with-knowledge-distillation.pdf
2017

一言でいうと

knowledge distillationとhint learningを使った物体検出モデルの圧縮

新規性・差分

knowledge distillationは単純な分類問題に対して素晴らしい改善が示されているが、 物体検出では、回帰やRegion proposal、膨大なラベルがない場合に新たなchallengeが出現する。
これらの問題に、

  • クラスの不均衡に対処するために重みのクロスエントロピー誤差
  • 回帰コンポーネントを扱うための損失が制限された教師
  • 中間の教師分布からよりよく学ぶための適応層

で対処する。 結果、一貫してaccuracy-sppedトレードオフが改善した。(多クラス物体検出)

この論文の貢献

この論文の貢献は4つ

  • knowledge distillationを用いてend-to-endで学習可能なマルチクラス物体検出のフレームワークの提案
  • 前述のchallenge用の新しい損失関数の提案。特に、クラスの不均衡(背景である場合が多い)に対処するための重みクロスエントロピーの提案、knowledge distillationのための回帰損失を制限された教師、生徒モデルが教師の中間層ニューロンの蒸留からよりよく学習できるようにする適応層(adaptive layer)を提案している。
  • 大きなデータセットでの包括的な検証
  • 汎化と過小適合に関連付けた提案フレームワークの振る舞いの考察

ざっく理論

Faster-RCNNは3つのコンポーネントから成り立つ 1) 特徴を検出するCNN 2) Region Propossal Network(RPN) 3) 分類があっているのかと、出てきたRegion proposalがちゃんと物体のあるところに被さっているかを測るRegression(RCN) 2)は1)の特徴を使い、3)は1)も2)も両方使う。

全体構成

  1. hintを適応して、生徒モデルの特徴を教師モデルの特徴に近づけていく。
  2. knowledge distillationフレームワークを使って、RPNとRCNを強化する。
  3. 生徒モデルと教師モデルの重み量の不均衡を解消するために重みクロスエントロピーロス(weighted cross entropy loss)で損失を取る。
  4. 上界として教師の回帰出力を転送する。生徒の回帰出力のほうが教師のよりもよかったら、追加の損失は加えない。

f:id:taku-buntu:20190108012233p:plain
全体構成図

それぞれの損失の求め方

RCN, RPN, 総合的な損失

L _ { R C N } = \frac { 1 } { N } \sum _ { i } L _ { c l s } ^ { R C N } + \lambda \frac { 1 } { N } \sum _ { j } L _ { r e g } ^ { R C N }

 L _ { R P N } = \frac { 1 } { M } \sum _ { i } L _ { c l s } ^ { R P N } + \lambda \frac { 1 } { M } \sum _ { j } L _ { r e g } ^ { R P N }

 L = L _ { R P N } + L _ { R C N } + \gamma L _ { H i n t }
NはRCN、MはRPN用のバッチサイズ。[tex: L{cls}]はgrand truthラベルを使ったhard softmax lossとsoft knowledge distillation lossをあわせた損失。[tex: L{reg}]はバウンディングボックスの回帰損失(L1損失と教師が制限したL2回帰損失(4)の組み合わせ)

クラス損失

 L _ { c l s } = \mu L _ { h a r d } \left( P _ { s } , y \right) + ( 1 - \mu ) L _ { s o f t } \left( P _ { s } , P _ { t } \right)

soft target損失

 L _ { s o f \iota } \left( P _ { s } , P _ { \iota } \right) = - \sum w _ { c } P _ { l } \log P _ { s }
背景クラスの重みは大きくする。ex.  w_0=1.5

バウンディングボックスの回帰のところの損失

f:id:taku-buntu:20190108030653p:plain  L _ { r e g } = L _ { s L 1 } \left( R _ { s } , y _ { r e g } \right) + \nu L _ { b } \left( R _ { s } , R _ { l } , y _ { r e g } \right)
教師モデルといえども結構間違えることあるから、この出力をそのままターゲットにするのはまずい...
そこで、教師モデルの結果を上限として利用する。 mはマージン。生徒モデルの損失がある程度の余裕mを持って超える(生徒の損失のほうが低い)とき、生徒の損失は0。

特徴適応を用いたHint Learningの損失

教師モデルの中間特徴を生徒の中間特徴に教える。
次の式のように、特徴ベクトルV, ZでL2距離を使う。(学習)

 L _ { \operatorname { Hin } \iota } ( V , Z ) = | V - Z | _ { 2 } ^ { 2 }

L1損失で評価する。

 L _ { H i n t } ( V , Z ) = | V - Z | _ { 1 }

hint learningは指導する層(guided layer)の(チャネル, 幅, 高さ)が一致一致していないと使うことができないので、それらを一致させる層として適応層(adaptation layer)を入れる。hint layerもguided layerもFCなら、adaptation layerにはFCが使われる。Conv層なら1x1のConvを使う。

結果

  • 複雑なCNNベースの物体検出器は効率的に生徒モデルを指南することできた。
  • knowledge distillationとhint learningを今回提案した損失関数と一緒に使うことで、様々な実験で一貫して改善がみられた。
  • 精度の妥協も少なかった。特にPASCALデータセットの場合は全く精度妥協はなかった。

英語

comprehensive・・・総合的な
adaptation・・・適応的な
consistent・・・一貫して
notably・・・特に
insight・・・洞察
ample・・・十分な