
[LW] ๋„คํŠธ์›Œํฌ ๊ฒฝ๋Ÿ‰ํ™”: Knowledge Distillation ๋‹ค์–‘ํ•œ ๋ฐฉ๋ฒ•๋“ค(feat. RepDistiller)

DrawingProcess 2024. 4. 26. 16:32
๋ฐ˜์‘ํ˜•
๐Ÿ’ก This post is a write-up on '[DL] Various Knowledge Distillation Methods (feat. RepDistiller)'.
Knowledge Distillation was proposed in 2014, which makes it a fairly old technique by deep learning standards, yet it is still widely used for model compression today. This post organizes the methods derived from Knowledge Distillation at the code level, so feel free to use it as a reference.

1. Knowledge Distillation?

Knowledge Distillation ์ด๋ž€?

๋”ฅ๋Ÿฌ๋‹์—์„œ Knowledge Distillation์€ ํฐ ๋ชจ๋ธ(Teacher Network)๋กœ๋ถ€ํ„ฐ ์ฆ๋ฅ˜ํ•œ ์ง€์‹์„ ์ž‘์€ ๋ชจ๋ธ(Student Network)๋กœ transferํ•˜๋Š” ์ผ๋ จ์˜ ๊ณผ์ •์ด๋ผ๊ณ  ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

So why Knowledge Distillation?

์ง€์‹ ์ฆ๋ฅ˜๋ฅผ ์ฒ˜์Œ์œผ๋กœ ์†Œ๊ฐœํ•œ ๋…ผ๋ฌธ์ธ "Distilling the Knowledge in a Neural Network(Hinton)"์€ ๋ชจ๋ธ ๋ฐฐํฌ(model deployment) ์ธก๋ฉด์—์„œ ์ง€์‹ ์ฆ๋ฅ˜์˜ ํ•„์š”์„ฑ์„ ์ฐพ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด ๋‹ค์Œ์˜ ๋‘ ๋ชจ๋ธ์ด ์žˆ๋‹ค๋ฉด ์–ด๋–ค ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒŒ ์ ํ•ฉํ• ๊นŒ์š”?

  • Complex model T: 99% prediction accuracy + 3 hours of prediction time
  • Simple model S: 90% prediction accuracy + 3 minutes of prediction time

์–ด๋–ค ์„œ๋น„์Šค๋ƒ์— ๋”ฐ๋ผ ๋‹ค๋ฅผ ์ˆ˜ ์žˆ๊ฒ ์ง€๋งŒ, ๋ฐฐํฌ ๊ด€์ ์—์„œ๋Š” ๋‹จ์ˆœํ•œ ๋ชจ๋ธ S๊ฐ€ ์กฐ๊ธˆ ๋” ์ ํ•ฉํ•œ ๊ฒƒ์œผ๋กœ ๋ณด์ž…๋‹ˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด, ๋ณต์žกํ•œ ๋ชจ๋ธ T์™€ ๋‹จ์ˆœํ•œ ๋ชจ๋ธ S๋ฅผ ์ž˜ ํ™œ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•๋„ ์žˆ์ง€ ์•Š์„๊นŒ์š”? ๋ฐ”๋กœ ์—ฌ๊ธฐ์„œ ํƒ„์ƒํ•œ ๊ฐœ๋…์ด ์ง€์‹ ์ฆ๋ฅ˜(Knowledge Distillation)์ž…๋‹ˆ๋‹ค. ํŠนํžˆ, ๋ณต์žกํ•œ ๋ชจ๋ธ์ด ํ•™์Šตํ•œ generalization ๋Šฅ๋ ฅ์„ ๋‹จ์ˆœํ•œ ๋ชจ๋ธ S์— ์ „๋‹ฌ(transfer)ํ•ด์ฃผ๋Š” ๊ฒƒ์„ ๋งํ•ฉ๋‹ˆ๋‹ค.

Knowledge Distillation at a Glance

Teacher Network(T): cumbersome model

  • (pros) excellent performance
  • (cons) computationally expensive
  • cannot be deployed in resource-limited environments

Student Network(S): Small Model

  • (pros) fast inference
  • (cons) lower performance than T
  • suitable for deployment

Knowledge Distillation in Detail (Soft Label, KD Loss)

Then how exactly can knowledge be transferred from the large model to the small model? This becomes easy to understand once we look at the network outputs and the loss function. Let us go through Hinton's KD, the formulation from the paper that first proposed Knowledge Distillation.

1) Soft Label

์ผ๋ฐ˜์ ์œผ๋กœ, ์ด๋ฏธ์ง€ ํด๋ž˜์Šค ๋ถ„๋ฅ˜์™€ ๊ฐ™์€ task๋Š” ์‹ ๊ฒฝ๋ง์˜ ๋งˆ์ง€๋ง‰ softmax ๋ ˆ์ด์–ด๋ฅผ ํ†ตํ•ด ๊ฐ ํด๋ž˜์Šค์˜ ํ™•๋ฅ ๊ฐ’์„ ๋ฑ‰์–ด๋‚ด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. softmax ์ˆ˜์‹์„ ํ†ตํ•ด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋ฒˆ์งธ ํด๋ž˜์Šค์— ๋Œ€ํ•œ ํ™•๋ฅ ๊ฐ’()๋ฅผ ๋งŒ๋“ค์–ด๋‚ด๋Š” ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค.

์ด๋•Œ, Hinton์€ ์˜ˆ์ธกํ•œ ํด๋ž˜์Šค ์ด์™ธ์˜ ๊ฐ’์„ ์ฃผ์˜ ๊นŠ๊ฒŒ ๋ณด์•˜์Šต๋‹ˆ๋‹ค. ๊ฐœ๋ฅผ ์ œ์™ธํ•œ ๊ณ ์–‘์ด๋‚˜ ์ž๋™์ฐจ ๊ทธ๋ฆฌ๊ณ  ์ –์†Œ์˜ ํ™•๋ฅ ์„ ๋ณด์•˜์œผ๋ฉฐ, ์ด ์ถœ๋ ฅ๊ฐ’๋“ค์ด ๋ชจ๋ธ์˜ ์ง€์‹์ด ๋  ์ˆ˜ ์žˆ๋‹ค๊ณ  ๋งํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ, ์ด๋Ÿฌํ•œ ๊ฐ’๋“ค์€ softmax์— ์˜ํ•ด ๋„ˆ๋ฌด ์ž‘์•„ ๋ชจ๋ธ์— ๋ฐ˜์˜ํ•˜๊ธฐ ์‰ฝ์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ถœ๋ ฅ๊ฐ’์˜ ๋ถ„ํฌ๋ฅผ ์ข€ ๋” softํ•˜๊ฒŒ ๋งŒ๋“ค๋ฉด, ์ด ๊ฐ’๋“ค์ด ๋ชจ๋ธ์ด ๊ฐ€์ง„ ์ง€์‹์ด๋ผ๊ณ ๋„ ๋ณผ ์ˆ˜ ์žˆ์„ ๋“ฏ ํ•ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์ด ๋ฐ”๋กœ Knowledge Distillation์˜ ์‹œ์ดˆ(Hinton’s KD)์ž…๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ํ•ด๋‹น ๋…ผ๋ฌธ์—์„œ๋Š” ์ด๋Ÿฌํ•œ soft output์„ dark knowledge๋ผ๊ณ  ํ‘œํ˜„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

The softening process can be written as the formula below: the temperature $T$ is introduced into the ordinary softmax by dividing the logits by $T$, which flattens the distribution.

$$ q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)} \quad\to\quad q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$
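
To make the effect of the temperature concrete, here is a minimal NumPy sketch (the class names and logit values are made up for illustration) showing how a larger $T$ softens the output distribution:

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Temperature-scaled softmax: q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max()                      # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

# Hypothetical logits for the classes [dog, cat, car, cow]
logits = [9.0, 4.0, 1.0, 0.5]

print(softmax_with_temperature(logits, T=1))   # nearly one-hot: almost all mass on "dog"
print(softmax_with_temperature(logits, T=4))   # softer: the non-target classes become visible
```

At $T = 1$ the non-target probabilities are almost zero, while at a larger $T$ the relative similarities between classes (the dark knowledge) become visible to the student.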

2) Distillation Loss

์œ„์—์„œ ์ •์˜ํ•œ Hinton์˜ soft target์€ ๊ฒฐ๊ตญ ํฐ ๋ชจ๋ธ(T)์˜ ์ง€์‹์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด ์ด ์ง€์‹์„ ์–ด๋–ป๊ฒŒ ์ž‘์€ ๋ชจ๋ธ(S)์—๊ฒŒ ๋„˜๊ธธ ์ˆ˜ ์žˆ์„๊นŒ์š”? ๋จผ์ €, ํฐ ๋ชจ๋ธ(T)์„ ํ•™์Šต์„ ์‹œํ‚จ ํ›„ ์ž‘์€ ๋ชจ๋ธ(S)์„ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์†์‹คํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ํ•™์Šต์‹œํ‚ต๋‹ˆ๋‹ค.

$$ L = \sum_{(x, y)\in D} L_{KD}(S(x, \Theta_S, \tau ), T(x, \Theta_T, \tau )) + \lambda L_{CE}(\hat{y}_S, y) $$

  • $S$, $T$: the Student model and the Teacher model
  • $(x, y)$: an input image and its label, drawn from the dataset $D$
  • $\Theta_S$, $\Theta_T$: the trainable parameters of each model
  • $\tau$: the temperature

๋Š” ์ž˜ ํ•™์Šต๋œ Teacher model์˜ soft labels์™€ Student model์˜ soft predictions๋ฅผ ๋น„๊ตํ•˜์—ฌ ์†์‹คํ•จ์ˆ˜๋ฅผ ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด๋•Œ, ์˜จ๋„()๋Š” ๋™์ผํ•˜๊ฒŒ ์„ค์ •ํ•˜๊ณ  Cross Entropy Loss๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

2. Knowledge Distillation Schema

Offline Distillation

In early knowledge distillation, the knowledge of a pre-trained teacher model is transferred to a student model.

๋”ฐ๋ผ์„œ 2๊ฐ€์ง€ stage์˜ training ํ”„๋กœ์„ธ์Šค๊ฐ€ ์žˆ์Œ.

  • 1) Before distillation, the large teacher model is trained first.
  • 2) As mentioned above, the teacher model's logits or intermediate features are extracted as knowledge and used as guidance while training the student model by distillation.
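
The second stage can be sketched roughly as follows, reusing the `hinton_kd_loss` helper from the previous section. The optimizer, learning rate, and epoch count are placeholder choices, and `teacher`, `student`, and `train_loader` are assumed to be an already trained teacher, an untrained student, and a standard PyTorch data loader.

```python
import torch

def train_offline_distillation(teacher, student, train_loader, epochs=10, device="cpu"):
    """Stage 2 of offline KD: the pre-trained teacher stays frozen and only guides the student."""
    teacher.to(device).eval()                          # frozen, pre-trained teacher
    student.to(device).train()
    optimizer = torch.optim.SGD(student.parameters(), lr=0.05, momentum=0.9)

    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():                      # the teacher only provides soft targets
                teacher_logits = teacher(images)
            student_logits = student(images)
            loss = hinton_kd_loss(student_logits, teacher_logits, labels)  # sketched above
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return student
```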

Online Distillation

In online distillation, both the teacher model and the student model are updated simultaneously, and the whole knowledge distillation framework is end-to-end trainable.
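
Below is a minimal sketch of the idea in the style of deep mutual learning, one common instantiation of online distillation: both networks see the same batch, each is trained on the hard labels plus the other's (detached) soft predictions, and both are updated in the same step. The loss combination and temperature are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def online_distillation_step(net_a, net_b, opt_a, opt_b, images, labels, T=3.0):
    """One step in which two peer networks teach each other while both keep training."""
    logits_a, logits_b = net_a(images), net_b(images)

    def mutual_loss(own_logits, peer_logits):
        ce = F.cross_entropy(own_logits, labels)                       # hard-label term
        kd = F.kl_div(F.log_softmax(own_logits / T, dim=1),
                      F.softmax(peer_logits.detach() / T, dim=1),      # the peer acts as teacher
                      reduction="batchmean") * (T * T)
        return ce + kd

    loss_a = mutual_loss(logits_a, logits_b)
    loss_b = mutual_loss(logits_b, logits_a)

    opt_a.zero_grad(); opt_b.zero_grad()
    (loss_a + loss_b).backward()   # thanks to detach(), each network only receives its own gradients
    opt_a.step(); opt_b.step()
    return loss_a.item(), loss_b.item()
```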

Self Distillation

In self-distillation, the same networks are used for the teacher and the student models.

(self-distillation means the student learns knowledge from itself)
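
One concrete reading of this is the born-again style of self-distillation, where a trained copy of the same architecture serves as the teacher for the next generation of itself. In the sketch below, `make_model` is assumed to build a fresh instance of the architecture, `train_with_labels_only` stands for an ordinary supervised training loop (a hypothetical placeholder), and `train_offline_distillation` is the loop sketched in the offline section above.

```python
import copy

def born_again_self_distillation(make_model, train_loader, generations=3):
    """Self-distillation sketch: the same architecture teaches the next generation of itself."""
    teacher = None
    for _ in range(generations):
        student = make_model()                       # identical architecture every generation
        if teacher is None:
            train_with_labels_only(student, train_loader)                # generation 0: plain training
        else:
            train_offline_distillation(teacher, student, train_loader)   # later generations: KD from self
        teacher = copy.deepcopy(student)             # this generation becomes the next teacher
    return teacher
```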

3. Various Knowledge Distillation Methods (Algorithms)
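
Beyond Hinton's logit-based KD, many of the derived methods benchmarked in RepDistiller (FitNet, attention transfer, CRD, and others) distill intermediate representations instead of, or in addition to, the logits. As one illustration, here is a minimal sketch of a FitNets-style hint loss, where a small regressor maps a student feature map into the teacher's feature space and an L2 loss aligns the two; the 1x1-conv regressor, the chosen layer, and the feature shapes are assumptions for illustration, not RepDistiller's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HintLoss(nn.Module):
    """FitNets-style hint: regress a student feature map onto the teacher's feature map."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 convolution so the student features match the teacher's channel count
        self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, feat_s, feat_t):
        # The teacher features are detached: only the student (and the regressor) are trained
        return F.mse_loss(self.regressor(feat_s), feat_t.detach())

# Usage with made-up shapes: student has 64 channels, teacher has 256, both on 8x8 maps
hint = HintLoss(64, 256)
feat_s = torch.randn(2, 64, 8, 8)
feat_t = torch.randn(2, 256, 8, 8)
loss = hint(feat_s, feat_t)
```

In practice such a feature-level term is added to the logit-based KD and cross-entropy losses above with its own weight.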



๋ฐ˜์‘ํ˜•