Distilling the Knowledge in a Neural Network
arxiv.org
2015/3
今回は初のモデル圧縮論文読みということで、詳しめに書きます。
一言で言うと
Knowledge Distillationという蒸留を利用したモデル圧縮
この論文のすごいところ
アンサンブルモデルを(一つの)蒸留モデルに転移させたら、同等のネットワーク構造のもとに比べて、高い精度を達成することができる。
感想
やばい、全然わからんかった(汗)
そもそも蒸留のことそんなに知らなかったので、"3"のデータなくてもソフトターゲットを使うと学習されてく的なのは面白いと思いました。
また、スペシャリストモデルなるものを作って、混同しやすいクラスを見分けるモデルでアンサンブルするのも面白いと思いました。確かに、並列学習させることができる場合、そっちのほうが精度良く、効率よくアンサンブルモデルの素を作ることができると思いました。
混同しやすいクラスはクラスタリングで求めてくんですか、参考になります。
ざっく理論
転移セットとソフトターゲットのlogit(Softmax前の値)でクロスエントロピー $$\frac { \partial C } { \partial z _ { i } } = \frac { 1 } { T } \left( q _ { i } - p _ { i } \right) = \frac { 1 } { T } \left( \frac { e ^ { z _ { i } / T } } { \sum _ { j } e ^ { z _ { j } / T } } - \frac { e ^ { v _ { i } / T } } { \sum _ { j } e ^ { v _ { j } / T } } \right)$$
Tはsoftmaxの温度(強度?)で、この値がlogitよりも大きかったら
$$\frac { \partial C } { \partial z _ { i } } \approx \frac { 1 } { T } \left( \frac { 1 + z _ { i } / T } { N + \sum _ { j } z _ { j } / T } - \frac { 1 + v _ { i } / T } { N + \sum _ { j } v _ { j } / T } \right)$$
が成り立つ。
それぞれのlogitがゼロ平均化(平均が0になるようにベクトルを変換)されると
と が0になるので
$$\frac { \partial C } { \partial z _ { i } } \approx \frac { 1 } { N T ^ { 2 } } \left( z _ { i } - v _ { i } \right)$$
とできる。
つまりは、
・高い温度Tで設定しておくと
蒸留は個々の転移サンプルがゼロ平均化されてるとき、$$1 / 2 \left( z _ { i } - v _ { i } \right) ^ { 2 }$$を最小化することを目指すことと等価になる。
・低い温度のときは、
平均よりもかなり低いマッチングロジットに対しては注意を払わなくなる。
これらのロジットはcumbersomeモデル(もともとの大きめのモデル)の学習に使われたコスト関数に全く制約を受けないので、とてもノイジー?(いろいろな情報が入ってるってこと?)になりうるとのこと。
どうやら温度Tを高温にすると、出力確率を緩やかにすることができるらしい。
確かに、指数関数は根本付近に近づくと、緩やかになることが想像できる。
一方、負のロジットはcumbersomeモデルから取得された知識について有益な情報を伝える可能性がある
結果
- MNISTでは教えてもない"3"を他のデータからのsoft targetラベルで学習させると、"3"をまあまあ認識できるようになった。"3"のクラスのバイアスが小さかった。
- 音響認識では、蒸留モデルが同等のサイズの学習モデルよりも性能が良いことを目指している。
- それぞれの観測値の状態グランドトゥルースシーケンスを強制的にアライメントすることで得られるネットとラベルで作られる推論値とのクロスエントロピーを局所的に最小化することで、フレームごとの分類ができる。 $$\theta = \arg \max _ { \boldsymbol { \theta } ^ { \prime } } P \left( h _ { t } | \mathbf { s } _ { t } ; \boldsymbol { \theta } ^ { \prime } \right)$$ =時刻tの観測値をにマッピングする音響モデルPのパラメータ。が正解のHMMの状態。
- 10のモデルからアンサンブルしたモデルを作る。アンサンブル方法はそれぞれの推論結果の平均。それぞれのモデルは別々の学習セットで学習させて、多様性をもたせた。アンサンブルモデルは個々のモデルよりめっちゃ性能よくなる。[1, 2, 5, 10]の温度で試し、ハードターゲットの相対重みは0.5とした。
- 大きなデータセットでスペシャリストモデルをアンサンブルする
- アンサンブルモデルは学習時間が過剰にかかる
- ソフトターゲットでスペシャリストモデルの過学習を抑える
- 混同しやすいクラス("3"と"8"とか)で作ったサブセットで作ったスペシャリストモデルと全てのデータで学習したジェネラリストモデルでアンサンブル
- スペシャリストモデルのソフトマックスは、考慮しないすべてのクラスをゴミ箱クラスとしてまとめることで非常に小さく抑えることができる
- 過学習を抑えることと低レベルな特徴検出能力を共有するためにそれぞれのスペシャリストモデルのパラはジェネラリストモデルで初期化
- 混同しやすいクラスをどうやって決めるのか?-->クラスタリングアルゴリズム。 それをジェネラリストモデルの予測の共分散行列に適用するので、一緒に出力されるクラス[tex:Sm]のセットはスペシャリストモデルmのためのターゲットとして使われる。
- 2ステップでトップ1のクラス分類を行う。
- 正則化としてのソフトターゲット。 スペシャリストモデルの過学習抑制にも使える。
覚えた英語
cumbersome・・・厄介な、めんどうな、扱いにくい
advantageous・・・有利
discrete・・・離散的
take A into account・・・Aを考慮して
alignment・・・位置合わせ
outperform・・・めちゃ性能いい
vary・・・多様にする
indeed・・・たしかに、実際に
asynchronous・・・非同期の
replica・・・レプリカ、複製品
shade・・・次第に変化する
dustbin・・・ゴミ箱
undersampling・・・少数派のデータ件数に合うように多数派データからランダムに抽出する方法
oversampling・・・少数派のデータをもとに不足分のデータを補完すること
derive・・・得る、導出する
derive from・・・由来する