📃 On Layer Normalization in the Transformer Architecture (Review)

This paper appeared at ICML 2020 and introduced the Pre-LN structure used in the ZeRO training tutorial. It seems to be from Microsoft Research…? Which would explain why the ZeRO tutorial uses it. I read it while training BERT, so this is only a brief write-up.

Abstract

  • Looks at why learning rate warm-up is needed, and why the position of the LN (Layer Normalization) component matters.
  • In the Post-LN structure (the original Transformer block), the expected gradients near the output layer are very large at initialization.
    • That is why warm-up is essential; without it, training is extremely unstable.
  • The Pre-LN structure presented here, however, is fine even at initialization.

Introduction

  • The Transformer block is expressive enough to perform well whenever training goes well, but in practice training turns out to be very sensitive to the LR.
    • According to this paper, that is because the gradients at initialization are highly unstable.
    • This makes it much harder to optimize than CNNs or other Seq2Seq models.
  • The warm-up stage was introduced to cope with this, but it slows training down and requires a lot of extra hyperparameter tuning.
  • To analyze the problem, the authors use mean field theory,
  • and by moving the LN position they show
    • that gradients are distributed much more evenly across layers, and
    • that LN plays a large role in gradient control.
  • In the end, the Pre-LN model trains faster and reaches comparable performance in fewer steps.

(Skipping the rest.)

Optimization for the Transformer

Transformer with Post Layer Normalization

  • Noting this because the notation comes up later (see the sketch below):
  • \(W^1\), \(b^1\) -> the intermediate dense layer; in the base configuration, the 768 -> 3072 layer
  • \(W^2\), \(b^2\) -> the intermediate output layer; the 3072 -> 768 layer
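To make the notation concrete, here is a minimal PyTorch sketch of the FFN sublayer in a Post-LN block. The BERT-base dimensions (768, 3072) match the bullets above; the class and attribute names are my own scaffolding, not from the paper.

```python
import torch
import torch.nn as nn

class PostLNFeedForward(nn.Module):
    """FFN sublayer of a Post-LN Transformer block, BERT-base sizes."""

    def __init__(self, d_model: int = 768, d_ff: int = 3072):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)  # W^1, b^1: 768 -> 3072
        self.w2 = nn.Linear(d_ff, d_model)  # W^2, b^2: 3072 -> 768
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Post-LN: LayerNorm is applied AFTER the residual addition.
        return self.ln(x + self.w2(torch.relu(self.w1(x))))
```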

The learning rate warm-up stage

  • During warm-up, the LR increases linearly up to its maximum (see the sketch after this list).
  • To show how critical this is for the Post-LN structure, they run the IWSLT14 German-to-English translation task (Popel & Bojar, 2018).
  • Results:
    • Warm-up seems to matter regardless of whether Adam or SGD is used.
    • Training also appears to be quite sensitive to the number of warm-up steps.
      • This is not cited in the paper, but the LAMB optimizer paper was also a big help on this point.
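A minimal sketch of that linear warm-up schedule using PyTorch's LambdaLR; the model, lr_max, and warmup_steps here are placeholders of mine, not values from the paper.

```python
import torch

model = torch.nn.Linear(768, 768)   # placeholder model
lr_max, warmup_steps = 1e-4, 4000   # placeholder hyperparameters

optimizer = torch.optim.Adam(model.parameters(), lr=lr_max)
# lr(t) = lr_max * t / warmup_steps for t <= warmup_steps, then constant.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
)

for step in range(warmup_steps):
    # ... forward/backward would go here ...
    optimizer.step()
    scheduler.step()
```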

Understanding the Transformer at initialization

The notation and initialization are explained in detail, but I'll skip them.

Post-LN Transformer vs Pre-LN Transformer

  • In any case, if \(Z\) is \((\epsilon, \delta)\)-bounded, we know that with high probability it does not lie far from its expectation.
  • With that in mind, look at the theorem below.
  • The thing to notice is the term dividing by \(L\) inside the square root.
  • For Pre-LN, the Frobenius norm of the gradient is much smaller.
  • Post-LN is \(O(d\sqrt{\ln d})\) regardless of the number of layers, whereas for Pre-LN the bound is \(O(d\sqrt{\ln d / L})\), so the gradient on the last layers shrinks as the number of layers grows.
  • Three lemmas are used to understand the theorem; in summary:
    1. The expected l2 norm of a d-dimensional Gaussian vector passed through ReLU -> basically a tool for the later proofs.
    2. The scale of the intermediate/final outputs of Pre/Post-LN layers:
      • in Post-LN the scale does not change much from layer to layer,
      • whereas in Pre-LN the scale grows with the number of layers.
    3. The gradient through Layer Normalization is inversely proportional to the norm of its input vector.
  • So the main idea becomes: "LayerNorm normalizes the gradient."
  • Because the Post-LN structure keeps the scale constant, the gradient at the last layer is also constant; in Pre-LN the scale grows in proportion to the number of layers, so the gradient gets normalized by \(\sqrt L\) (see the sketch after this list).
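The whole argument comes down to where the LN sits relative to the residual connection. A minimal sketch of the two orderings for a single sublayer; `sublayer` stands in for either self-attention or the FFN, and the function names are mine.

```python
import torch.nn as nn

def post_ln_step(x, sublayer: nn.Module, ln: nn.LayerNorm):
    # Post-LN: residual addition first, then normalize.
    # The output scale stays roughly constant across layers, but the
    # gradients near the output are large at initialization.
    return ln(x + sublayer(x))

def pre_ln_step(x, sublayer: nn.Module, ln: nn.LayerNorm):
    # Pre-LN: normalize first, then add the residual.
    # The residual stream's norm grows with depth, and since LN's
    # gradient is inversely proportional to its input norm, the last
    # layers' gradients get normalized by roughly sqrt(L).
    return x + sublayer(ln(x))
```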

Empirical verification of the theory and discussion

๊ฐ ๋ ˆ์ด์–ด Gradient์˜ Frobenius Norm์„ ๊ธฐ๋กํ•œ ๊ฒฐ๊ณผ

From this we can see that Pre-LN should train better: with Post-LN the gradients come in so large that you need either a small LR or a warm-up stage, whereas the Pre-LN structure should train well with LR adjustment alone.
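A minimal sketch of how one could take this measurement oneself; the toy model and loss below are placeholders, and the point is just reading off each parameter's gradient Frobenius norm after one backward pass at initialization.

```python
import torch
import torch.nn as nn

# Placeholder stand-in for a stack of Transformer blocks.
model = nn.Sequential(*[nn.Linear(768, 768) for _ in range(6)])

x = torch.randn(8, 768)
loss = model(x).pow(2).mean()  # placeholder loss
loss.backward()

for name, p in model.named_parameters():
    # .norm() on a 2-D gradient tensor is its Frobenius norm.
    print(name, p.grad.norm().item())
```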

Experiment

They ran both NMT and BERT experiments, but what I care about is, of course, BERT.

When trained on a corpus of similar size to BERT's, Pre-LN showed far better results.

To reach a validation loss around 1.7, the Post-LN model with 10k warm-up steps took 700k steps, while the Pre-LN model needed only 500k steps of training.

Conclusion

The conclusion really boils down to: Pre-LN just trains better, and the warm-up stage can be safely removed. Given what this paper says, I wonder whether the BERT analysis papers I read earlier, on task-specific weights, were somewhat mis-analyzed??

์ž์„ธํ•œ ์ฆ๋ช…์€ ๋…ผ๋ฌธ Appendix์— ์กด์žฌํ•œ๋‹ค.

July 16, 2020
Tags: paper