📃 Efficient 8-Bit Quantization of Transformer Neural Machine Language Translation Model (Review)

This paper quantizes an FP32 model to INT8 in TensorFlow. They report a 1.5x performance gain at the cost of only a 0.5 BLEU score drop, and the work is optimized for Intel CPUs. The arXiv link is https://arxiv.org/abs/1906.00532; the paper is from Intel.

1. Introduction

  • Contributions
    • Quantized a trained FP32 Transformer model to INT8 to achieve < 0.5 drop in state-of-the-art (SOTA) BLEU score.
    • Improve inference performance by:
      1. Optimizing quantized MatMuls for tensor shapes and sizes in the Transformer model
      2. Reducing overhead due to quantization operations in the Transformer model compute graph
      3. Optimizing input pipeline by ordering sentences by token length
      4. Implementing parallel execution of batches with increased inference throughput

Section 2: skipped.

3. Model Description

  • The Transformer uses scaled dot-product attention.
  • A softmax sits inside that computation, and it is clear that quantizing it would cause a large accuracy loss.
  • There is also layer norm, which computes a mean and variance, so wouldn't that be hard to quantize as well?

4. Quantization with accuracy

  • x_q = round((x − min) / (max − min) × (2⁸ − 1)), and dequantization is x ≈ x_q × (max − min) / (2⁸ − 1) + min
  • Quantization follows the formula above; since it is 8-bit, the min-to-max range naturally gets a scale of 256 levels.

4.1. Naïve Quantization

  • ์œ„์˜ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์ด ์ง„ํ–‰ํ•  ๋•Œ dequantizationํ•˜๋Š” ๋ฐฉ๋ฒ•:
  • NMT ํƒœ์Šคํฌ์˜€๋Š”๋ฐ, Stop token ๋‚ด๋ฑ‰๋Š”๋ฐ ์‹คํŒจํ•ด์„œ acc๊ฐ€ ๋งŽ์ด ๋–จ์–ด์ ธ๋ฒ„๋ฆผ

4.2. KL-Divergence for optimal saturation thresholds

  • Since the real problem is mapping values well no matter how the quantization is done, appropriately shrinking or stretching the representable range is what matters.
  • This relies on the assumption that maintaining small differences between tensor values that are close together is more important than representing the absolute extreme values or the outliers. Ideally, the numerical distribution of values in the mapped INT8 tensor representations should be as close as possible to the distribution of values for FP32 tensors.

  • So they use KL divergence.
  • The idea was borrowed from 8-bit Inference with TensorRT.
  • 600 random samples are used as calibration data.
  • Three ways of setting the min/max thresholds were tested:
    1. Symmetric: “threshold_min = - threshold_max”
    2. Computed independently for each side
    3. Computed conjugately, then made symmetric
  • Computing them independently turned out best.
  • In the end, quantization proceeds as shown below.
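The TensorRT-style calibration can be sketched roughly as follows (a simplified toy version in plain Python; the bin and level counts are illustrative, not the paper's exact settings): for every candidate cutoff, compare the clipped histogram with a copy that has been merged down to the INT8 level count and expanded back, and keep the cutoff with the lowest KL divergence.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """D_KL(P || Q); eps guards against log(0) when Q lost a populated bin."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

def best_threshold(samples, num_bins=512, num_levels=128):
    """For each candidate cutoff bin i, build the clipped histogram P and a
    copy Q merged down to num_levels buckets and expanded back (simulating
    the coarser INT8 resolution); return the cutoff minimizing KL(P || Q)."""
    vmax = max(abs(v) for v in samples)
    hist = [0] * num_bins
    for v in samples:
        hist[min(int(abs(v) / vmax * num_bins), num_bins - 1)] += 1
    best_i, best_kl = num_bins, float("inf")
    for i in range(num_levels, num_bins + 1):
        p = hist[:i]
        p[i - 1] += sum(hist[i:])            # fold clipped outliers into the last bin
        q = [0.0] * i
        group = i / num_levels
        for lvl in range(num_levels):
            lo, hi = int(lvl * group), int((lvl + 1) * group)
            total = sum(hist[lo:hi])
            nonzero = sum(1 for b in hist[lo:hi] if b > 0)
            for b in range(lo, hi):
                if hist[b] > 0:
                    q[b] = total / nonzero   # spread bucket mass over its nonzero bins
        sp, sq = sum(p), sum(q)
        if sq == 0:
            continue
        d = kl_divergence([x / sp for x in p], [x / sq for x in q])
        if d < best_kl:
            best_i, best_kl = i, d
    return best_i / num_bins * vmax          # threshold back in original units
```

Widening the threshold wastes INT8 levels on outliers (coarsening small values), while narrowing it clips real mass into the last bin; the KL search balances the two.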

5. Improving Performance

This is the part of the paper I found most interesting, to the point of thinking, “isn't this where the performance gains actually came from??”

  • Why convert to INT8:
  • INT8 MatMuls using VNNI provides a speed-up of 3.7X over FP32 MatMuls using AVX512.

  • They implemented the TensorFlow operations directly with MKL (probably as custom ops?).
  • TensorFlow 1.12 uses a library called GEMMLOWP, so it does not support INT8/VNNI.
  • It also requires extra data-conversion steps, which is inefficient.
  • Even their hand-written version wasn't fast at first; on inspection they found an unoptimized part of MatMul and optimized it.
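Conceptually, the INT8 kernel multiplies 8-bit integers while accumulating into a 32-bit integer (the pattern VNNI fuses into one instruction), and dequantizes only once per output element. A minimal sketch assuming symmetric quantization with zero point 0; the function name and scales are illustrative, not from the paper:

```python
def quantized_matmul(a_q, b_q, a_scale, b_scale):
    """Multiply two INT8-quantized matrices (lists of lists of ints),
    accumulating in plain Python ints the way hardware accumulates into
    INT32, then rescale the result back to float once per output."""
    rows, inner, cols = len(a_q), len(b_q), len(b_q[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            acc = 0                              # INT32 accumulator in real kernels
            for k in range(inner):
                acc += a_q[i][k] * b_q[k][j]
            out[i][j] = acc * a_scale * b_scale  # dequantize the accumulated sum
    return out
```

Keeping the inner loop purely in integers is what makes the VNNI speedup possible; the float rescale happens outside the hot loop.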

  • ๊ทธ ์™ธ์—๋„ ์•„๋ž˜์ฒ˜๋Ÿผ ์ตœ์ ํ™”ํ•จ
    • GatherND๋ฅผ ์ตœ์ ํ™”ํ–ˆ๋Š”๋ฐ ๊ทธ ์ด์œ ๋Š” ์„ฑ๋Šฅ ํ–ฅ์ƒ์ด ์•„๋‹ˆ๋ผ ๋ฐ์ดํ„ฐ ํ†ต์‹ ์„ ๋น ๋ฅด๊ฒŒํ•˜๊ธฐ ์œ„ํ•ด์„œ ์ง„ํ–‰ํ•จ. 32bit๋ณด๋‹ค 8bit๋‚˜๋ฅด๋Š”๊ฒŒ ์•ฝ 3.8x๋ฐฐ๊ฐ€ ๋นจ๋ž๊ธฐ ๋•Œ๋ฌธ
    • input sentence sortingํ•ด์„œ ์—ฐ์‚ฐ ์ง„ํ–‰ํ•จ
    • Quantization ์ค‘์—์„œ ๋ถˆํ•„์š”ํ•œ reshape ๋“ฑ์˜ ์—ฐ์‚ฐ์„ ์ œ๊ฑฐํ•จ
    • batching์„ parallel๋กœ ์ž‘์„ฑํ•จ
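The sentence-sorting trick can be sketched like this (a hypothetical helper, not the paper's pipeline code): grouping sentences of similar token length into the same batch minimizes padding work.

```python
def batch_by_length(sentences, batch_size):
    """Sort sentences by token count, then slice into fixed-size batches
    so each batch contains sentences of similar length (less padding)."""
    ordered = sorted(sentences, key=lambda s: len(s.split()))
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]
```

Since every sequence in a batch must be padded to the batch's longest sentence, sorting first means almost no wasted computation on pad tokens.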

6. Throughput Performance Results

  • Environment setup: skipped.
  • With the right settings, parallelism works well and throughput improves by up to 4.5x.
  • However, the input-pipeline optimization sped up FP32 as well, so in the end it comes out about 1.5x faster than FP32.

7. Conclusion

We optimized the compute graph by reducing number of operations, improved kernels of key operations such as MatMuls and GatherNd, optimized order of sentences in the input pipeline and finally used parallel batching to achieve the highest throughput gains of 1.5X.


Overall the paper amounts to “computing in 8-bit works fine” plus “there is still plenty of room for optimization.” I'm not sure whether the MKL-based optimizations can be applied to TF2.

Written on April 27, 2020
Tags: paper