๐ก ๋ณธ ๋ฌธ์๋ '[Perception] ONNX: ๋ค๋ฅธ DNN ํ๋ ์์ํฌ ๊ฐ ๋ชจ๋ธ ํธํ ํฌ๋ฉง(pytorch, tensorflow, ...)'์ ๋ํด ์ ๋ฆฌํด๋์ ๊ธ์ ๋๋ค.
1. PyTorch, ONNX, TensorRT ๋น๊ต
1) PyTorch 2.0
์ธ์ ๊ฐ๋ถํฐ Tensorflow๋ฅผ ์์ง๋ฅด๊ณ ๊ฐ์ฅ ์ธ๊ธฐ์๋ Deep Learning Framework์ด ๋ PyTorch๊ฐ 2022๋ 12์์ ์๋ก์ด ๋ฒ์ ์ ๊ณต๊ฐํ์ต๋๋ค. ์ด ๊ธ์์ ์์์ผ ํ ์ค์ํ ๋ถ๋ถ์ PyTorch๋ ๋๋ฌด Pythonicํด์ ๋ฌธ์ ๋ผ๋ ๊ฒ์ ๋๋ค.
Python์ ํน์ง
๊ทธ๋ผ Pythonic ํ๋ค๋ ๊ฒ ๋ญ๋?
- ๊ฐ๊ฒฐ์ฑ / ๊ฐ๋
์ฑ: User friendlyํ๊ฒ ์์ฑํ๊ธฐ ๋๋ฌธ์ ์ฝ๋๊ฐ ์ง๊ด์ ์ด๊ณ over-engineering์ ๋ง์ ์ ์์ต๋๋ค.
- ๋ฐ๋๋ก ๋งํ๋ฉด ์ปดํจํฐ๊ฐ ์์๋ฃ๊ธฐ ์ด๋ ต๊ธฐ ๋๋ฌธ์ ์ปดํจํฐ๋ ์ด๊ฒ์ ๋ฒ์ญํ๋๋ผ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฝ๋๋ค.
- ์ธํฐํ๋ฆฌํฐ ์ธ์ด: ์ปดํ์ผ ์ธ์ด์๋ ๋ค๋ฅด๊ฒ ์ปดํ์ผ ๊ณผ์ ์์ด ๋งค๋ฒ ํ์คํ์ค ์คํํฉ๋๋ค. ๊ทธ๋์ ๋ํ์ ์ธ ์ปดํ์ผ ์ธ์ด์ธ C์์๋ ๋ญ๊ฐ ์๋ชป ์ฝ๋ฉ๋๋ฉด ์์ ์คํ์ ํ ์๊ฐ ์์ง๋ง Python์์๋ ์ฒ์์ ์ ์คํ๋๋ ๊ฒ ๊ฐ๋ค๊ฐ๋ ์ค๊ฐ์ ์๋ฌ๋ฅผ ๋ฐ์์ํต๋๋ค.
- ๋ฏธ๋ฆฌ ์ปดํ์ผ ํ์ง ์๊ธฐ ๋๋ฌธ์ ์ปดํจํฐ๋ ๋งค๋ฒ ์ฝ๋๋ฅผ ๋ฐ๋ผ ์๋ก์ด ๋ชจํ์ ํ๋ ๊ฒ๋๋ค.. ์คํ ์๊ฐ์ด ๋๋ฆฌ๊ฒ ์ฃ
์ด ๋ฐ์๋ Python์ ํน์ง์ ๋ง์ด ์์ง๋ง ์๋ฌดํผ ์ฌ์ฉํ๊ธฐ ์ฝ์ง๋ง ๊ทธ๋งํผ ๋๋ฆฌ๋ค๋ผ๋ ํน์ง๋ง ์์๋ฉด ๋ฉ๋๋ค.
Pytorch 2.0์ ํน์ง
๊ทธ๋์ PyTorch 2.0์์๋ torch.compile์ ํฌํจํ ๋ช๊ฐ์ง ๊ธฐ๋ฅ๋ค์ ์ถ๊ฐํ์ฌ ๋์ฑ ๋น ๋ฅด๊ฒ Inference ํ ์ ์๊ฒ ์ ๋ฐ์ดํธ๊ฐ ๋์์ต๋๋ค. ์ด๊ฑธ ์ดํดํ๋ ค๋ฉด JIT(์ ์๊ณ ์์ด์ผ ํ๋๋ฐ, ์์ฝํ์๋ฉด ๋ฏธ๋ฆฌ ์ปดํ์ผ ํ์ฌ ๊ทธ ๋ค๋ก๋ ๋น ๋ฅด๊ฒ ์ถ๋ก ํ ์ ์์์ด ํต์ฌ์ ๋๋ค. ์์ธํ ๋ด์ฉ์ torch.compile tutorial์ ์ฐธ๊ณ ํด์ฃผ์ธ์.
2) ONNX
ONNX๋ "Open Neural Network Exchange"์ ์ฝ์ด๋ก, ์คํ ์์ค ํ๋ก์ ํธ์ ๋๋ค. ONNX๋ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ํ์ค ํ์์ผ๋ก ํํํ๊ณ ์๋ก ๋ค๋ฅธ ๋ฅ๋ฌ๋ ํ๋ ์์ํฌ ๊ฐ์ ๋ชจ๋ธ์ ๋ณํํ๊ณ ๊ณต์ ํ ์ ์๊ฒ ํด์ค๋๋ค.
ONNX๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ๊ฐ์ค์น๋ฅผ ํํํ๋ ์ค๋ฆฝ์ ์ธ ํ์์ ๋๋ค. ์ด ํ์์ ๋ค์ํ ๋ฅ๋ฌ๋ ํ๋ ์์ํฌ(๋ก๋ถํฐ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๊ณ , ๋ค๋ฅธ ํ๋ ์์ํฌ์์ ๊ฐ์ ธ์์ ์คํํ๊ฑฐ๋ ๋ณํํ ์ ์์ต๋๋ค. ONNX ํ์์ ์ฌ์ฉํ๋ฉด ๊ฐ๋ฐ์๋ค์ ๋ค์ํ ํ๋ ์์ํฌ๋ฅผ ์ ์ฐํ๊ฒ ์กฐํฉํ๊ณ ๋ชจ๋ธ์ ์ฌ์ฌ์ฉํ๊ณ ๋ค์ํ ์ธ์ด, ํ๋ซํผ ๋ฐ ๋๋ฐ์ด์ค์์ ์คํํ ์ ์์ต๋๋ค.
TensorRT
TensorRT๋ ONNX์ ๋ง์ฐฌ๊ฐ์ง์ ๋๋ค. ํ์ง๋ง ์ด๊ฑด NVIDIA์์ ๋ง๋ ํ๋ ์์ํฌ๋ก์จ, NVIDIA GPU์์ ์ต์ ํ ๋ ๊ธฐ์ ์ ๋๋ค. ๊ทธ๋ฐ๋ฐ ํ์ฌ๊น์ง๋ ๋๋ถ๋ถ์ ์ฐ๊ตฌ๊ฐ NVIDIA GPU๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ด๋ผ๊ณ ํ ์ ์์ต๋๋ค.
4) PyTorch 2.0 / ONNX / TensorRT ๋น๊ต
๊ทธ๋ผ ์ด์ ์ธ๊ฐ์ง ๋ชจ๋ธ ํ์์ ๋น๊ตํ๊ฒ ์ต๋๋ค. ์ฐธ๊ณ ๋ก ์คํ GPU๋ NVIDIA RTX 3090Ti๋ก,
๋น๊ต๋ฅผ ํตํด ์ป์ ์ ์๋ ํต์ฌ 4๊ฐ์ง๋ ์๋์ ๊ฐ์ต๋๋ค.
- PyTorch 2.0์ batch size๊ฐ ์ปค์ง ์๋ก ์ ์ ๋๋น ํฐ ์ฑ๋ฅ ํฅ์์ ์ด๋ฃธ
- ๋ํ FP16์ Batch size๊ฐ ํด์๋ก ๋น์ ๋ฐํจ
- ์ด๋ PyTorch 2.0์ด ํ์ต ์ต์ ํ์ ์ด์ ์ ๋ง์ท๊ธฐ ๋๋ฌธ
- ONNX Runtime์ Batch size๊ฐ ์์ ๋ PyTorch 2.0๋ณด๋ค ํจ๊ณผ๊ฐ ์ข์
- ์ด๋ ONNX๊ฐ ์ถ๋ก ์ต์ ํ์ ์ด์ ์ ๋ง์ท๊ธฐ ๋๋ฌธ
- PyTorch Eager ๋ PyTorch 2.0์ด๋ Batch size๊ฐ ์์ผ๋ฉด ๋ณ ์ฐจ์ด๊ฐ ์์
- ์ด๊ฑด ์๋ง GPU๊ฐ ์ค๋ฒ์คํ์ด๋ผ ๋ณด์ ํ ์์์ ์ถฉ๋ถํ ํ์ฉํ์ง ๋ชปํ๊ธฐ ๋๋ฌธ
- NVIDIA GPU์์๋ TensorRT๊ฐ ์ต๊ณ
- ๋น์ฐํ ๊ฒ์ด์ง๋ง NVIDIA์์ ์ต์ ํ๋ฅผ ํ๊ธฐ ๋๋ฌธ์ Cache๋ฅผ ๊ฐ์ฅ ํจ์จ์ ์ผ๋ก ํ์ฉํ๋๋ก ์ค๊ณ๋์ผ๋ฆฌ๋ผ ์ถ์
์ค๋์ ์์ ๊ฐ์ด PyTorch 2.0, ONNX, TensorRT ๊ฐ์ ์ถ๋ก ์๋ ๋น๊ต๋ฅผ ํด๋ดค์ผ๋ฉฐ, ๊ฒฐ๋ก ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ํ์ตํ ๋๋ PyTorch 2.0
- ์ถ๋ก ํ ๋๋ TensorRT๊ฐ ์ข์ ๋ณด์ด๊ธด ํจ. ํ์ง๋ง ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๋ฐฐํฌํ ๋๋ GPU๋ฅผ ์ฐ์ง ์์ ์๋, GPU๊ฐ NVIDIA ์ ํ์ด ์๋ ์๋ ์์ต๋๋ค. ์ ๊ณ ๋ คํ์ ์ผ ํฉ๋๋ค.
- ๋ค์ํ ํ๊ฒฝ์ด๋ ๋ค๋ฅธ ํ๋ ์์ํฌ ๊ฐ ํธํ์ฑ์ ์ํด์๋ ๋ชจ๋ธ์ ONNX ํฌ๋ฉง์ผ๋ก ๋ณํํด์ผ ํฉ๋๋ค. ์ฌ๊ธฐ์ ONNX ํฌ๋ฉง์ผ๋ก ๋ณํํ๋ฉด ์ถ๋ก ์ต์ ํ๋ ์งํํฉ๋๋ค.
2. ONNX
1) ONNX ๋?
ONNX(Open Neural Network Exchange)๋ Tensorflow, PyTorch ์ ๊ฐ์ ์๋ก ๋ค๋ฅธ DNN ํ๋ ์์ํฌ ํ๊ฒฝ์์ ๋ง๋ค์ด์ง ๋ชจ๋ธ๋ค์ ์๋ก ํธํํด์ ์ฌ์ฉํ ์ ์๋๋ก ๋์์ฃผ๋ ๊ณต์ ํ๋ซํผ์ ๋๋ค. ๊ฐ๋ตํ ๋งํด, ๋ค์ํ ํ๋ซํผ ํ๊ฒฝ(Java, JS, C, C#, C++)์์ ํ๊ฒฝ์ ์ ์ฝ ์์ด ๊ตฌํ๋ ‘ML ๋ชจ๋ธ’์ ํธ์ถํ๊ณ ์ํํ์ฌ ์ํ ๊ฒฐ๊ณผ๊ฐ์ ๋ฐํ๋ฐ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
1. Framework Interoperability
์์์ ์ธ๊ธํ๋ค์ํผ ํน์ ํ๊ฒฝ์์ ์์ฑ๋ ๋ชจ๋ธ์ ๋ค๋ฅธ ํ๊ฒฝ์ผ๋ก importํ์ฌ ์์ ๋กญ๊ฒ ์ฌ์ฉ์ ํ ์ ์๋ค๋ ๊ฒ์ ONNX์ ์ต๋ ๊ฐ์ ์ ๋๋ค. ์์ปจ๋, Tensorflow์์ ๋น ๋ฅด๊ฒ ๋ชจ๋ธ์ ํ์ต ์ํจ ๋ค์ ์ด๋ฅผ ๋ชจ๋ฐ์ผ๋ก ์ฎ๊ฒจ์ ์ฌ์ฉ์ ํ๋ ๋ฑ ์ฌ๋ฌ๊ฐ์ง ๋ฐฉ์์ผ๋ก ํ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
2. Shared Optimization
HW vendor(๊ฐ์๊ธฐ์ ๊ฐ์ HW ์ ์กฐ์ ์ฒด)์ ์ ์ฅ์์ ONNX์ ๊ฐ์ ํ๋ ์์ํฌ ๊ฐ ๊ณต์ ๋๋ ํฌ๋งท์ด ์กด์ฌํ๋ฉด, ํ๋์จ์ด ์ค๊ณ์ ONNX representation์ ๊ธฐ์ค์ผ๋ก ์ต์ ํ๋ฅผ ํ๋ฉด ๋๊ธฐ ๋๋ฌธ์ ํจ์จ์ ์ ๋๋ค.
๋ง์น JSON ํฌ๋งท์ด ์ ๋ณด ํํ์ ์ํด์ ์ฌ๋ฌ ๊ฐ๋ฐ์๋ค ์ฌ์ด์์ ํฉ์๋์ด ์ฌ์ฉ๋๋ฏ, ONNX๋ผ๋ ํฉ์๋ DNN ๋ชจ๋ธ ํฌ๋งท์ด ์กด์ฌํ๋ค๊ณ ์๊ฐํ๋ฉด ๋ฉ๋๋ค. ONNX ์ฌ์ฉ๊ณผ ๊ด๋ จํ์ฌ ๋ณด๋ค ์์ธํ ํํ ๋ฆฌ์ผ์ด ํ์ํ๋ค๋ฉด ๋ค์ ํ์ด์ง๋ฅผ ์ฐธ๊ณ ํ๋ฉด ๋ฉ๋๋ค.
2) ONNX Runtime ๋?
ONNX ๋ชจ๋ธ์ ์คํํ๊ธฐ ์ํ ์์ง์ ๋๋ค. ONNX ๋ชจ๋ธ์ ์คํํ๊ธฐ ์ํด ONNX ๋ฐํ์์ ๋น ๋ฅธ ์ถ๋ก ์ ์ํ ์ต์ ํ๋ ์ปค๋์ ์ฌ์ฉํฉ๋๋ค. ๋ํ, ONNX Runtime์ CPU, GPU ๋ฐ ๋ฅ๋ฌ๋ ๊ฐ์๊ธฐ(DNNL, NNAPI, OpenVINO)๋ฅผ ์ง์ํฉ๋๋ค. ๋ฐ๋ผ์, ONNX Runtime์ ONNX ๋ชจ๋ธ์ ์คํํ๊ธฐ ์ํ ์ต์ ํ๋ ๋ฐํ์ ํ๊ฒฝ์ ์ ๊ณตํฉ๋๋ค.
+ ORT ํ์
- ‘์ถ์๋ ํฌ๊ธฐ’์ ONNX Runtime ๋น๋์์ ์ง์ํ๋ ํ์์ ๋๋ค.
- onnx ํํ์์ ort ํํ๋ก ๋ณ๊ฒฝ์ ํจ์ผ๋ก์จ ์ถ์๋ ํฌ๊ธฐ ๋น๋๋ ๋ชจ๋ฐ์ผ ๋ฐ ์น ์ ํ๋ฆฌ์ผ์ด์ ๊ณผ ๊ฐ์ด ํฌ๊ธฐ๊ฐ ์ ํ๋ ํ๊ฒฝ์์ ์ฌ์ฉํ๊ธฐ์ ๋ ์ ํฉํฉ๋๋ค.
3) ONNX ์ฌ์ฉ ์ฌ๋ก
ONNX๋ฅผ ์ด์ฉํ์ฌ์ ์๋์ ๊ฐ์ด ์ฌ์ฉํ ์ ์์ต๋๋ค.
- 1. ๋ค์ํ ML ๋ชจ๋ธ์ ๋ํ ์ถ๋ก ์ฑ๋ฅ ํฅ์์ ์ฌ์ฉ๋ฉ๋๋ค.
- 2. ๋ค๋ฅธ ํ๋์จ์ด ๋ฐ ์ด์ ์ฒด์ ์์ ์คํ์ ์ฌ์ฉ๋ฉ๋๋ค.
- 3. Python์ผ๋ก ํ๋ จํ๋ C#/C++/Java ์ฑ์ ๋ฐฐํฌ์ ์ฌ์ฉ๋ฉ๋๋ค.
- 4. ๋ค์ํ ํ๋ ์ ์ํฌ์์ ์์ฑ๋ ๋ชจ๋ธ๋ก ์ถ๋ก ํ๋ จ ๋ฐ ์ํ์ ์ฌ์ฉ๋ฉ๋๋ค.
3. ONNX ์ฌ์ฉ ๊ณผ์
์ ๊ทธ๋ฆผ์, PyTorch ๋ชจ๋ธ์ ONNX ๊ทธ๋ํ๋ก export ํ๋ ์ ์ฒด ๊ณผ์ ์ ๋์ํํ ๊ฒ์ ๋๋ค.
์งํ ๊ณผ์ ์ ์๋ ์์์ ๊ฐ์ต๋๋ค.
- ์ฒซ ๋ฒ์งธ
- PyTorch ๋ชจ๋ธ๊ณผ Sample input ์ ์ธ์๋ก ํ์ฌ, torch.onnx.export ํจ์๋ฅผ ํธ์ถํฉ๋๋ค.
- PyTorch ์ JIT ์ปดํ์ผ๋ฌ๋ฅผ ํตํด์, Trace ํน์ Script ๋ฅผ ์์ฑํฉ๋๋ค.
- Trace ์ Script ๋ ๊ทธ ์์ฑ ๋ฐฉ์๊ณผ representation ์ ์ฐจ์ด๊ฐ ์์ต๋๋ค. (์ถํ ํฌ์คํ )
- PyTorch ๋ชจ๋ธ์ forward propagation ์์ ํธ์ถ๋๋, ํจ์ ๋ฐ ์ฐ์ฐ๋ค์ ๋ํ ์ต์ ํ๋ ๊ทธ๋ํ์ธ Torch IR ์ ๋ง๋ญ๋๋ค.
- Trace ๋ Script ๋, PyTorch ์ nn.Module ์ ์์ํ๋ ๋ชจ๋ธ์,
- forward ํจ์์์ ์คํ๋๋ ์ฝ๋๋ค์ ๋ํ IR(Intermediate Representation)์ ๋ด๊ณ ์์ต๋๋ค.
- ๋ ๋ฒ์งธ
- ์์ฑ๋ trace / script (Torch IR)๋, ONNX Exporter ๋ฅผ ํตํด์ ONNX IR ๋ก ๋ณํ๋๊ณ ,
- ์ฌ๊ธฐ์์ ํ ๋ฒ ๋ Graph Optimization ์ด ์ด๋ฃจ์ด์ง๋๋ค.
- ์ธ ๋ฒ์งธ
- ์ต์ข ์ ์ผ๋ก ์์ฑ๋ ONNX ๊ทธ๋ํ๋ .onnx ํฌ๋งท์ผ๋ก ์ ์ฅ๋ฉ๋๋ค.
์ฐธ๊ณ
- [Blog] ONNX ๋?: https://wooono.tistory.com/415
- [Blog]PyTorch 2.0 vs ONNX vs TensorRT ๋น๊ต: https://thecho7.tistory.com/entry/PyTorch-20-vs-ONNX-vs-TensorRT-%EB%B9%84%EA%B5%90
- [Blog] ONNX(Open Neural Network Exchange) ์ดํดํ๊ธฐ -1: React Native ํ์ฉ: https://adjh54.tistory.com/203