일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 | 29 |
30 | 31 |
- jax
- jax error
- Q factor
- FNO
- PIDoN
- HFSS
- Physics Informed Neural Operator
- PINN
- Pino
- Wireless Power Transform
- 순수 함수
- 품질 인자
- 일러스트레이터 수식
- Resonance frequency
- 맥스웰 방정식
- 연산자 학습
- Burgers Equation
- Anaconda
- DeepONet
- 딥오넷
- Maxwell Equation
- Operator Learning
- 일반기계기사
- Physics Informed Neural Network
- Neural Operator 설명
- Neural Operator
- WPT
- Wireless Power Transfer
- Governing Equation
- COLAB
- Today
- Total
잡다구리 너구리
[Jax] Jax란 무엇인가? Jax를 공부해야 되는 이유 본문
필자는 기계공학과의 한 연구실에서 학부 연구생으로 일하고 있다. 요즘 워낙 인공지능이 대세이다 보니 기계공학과에서도 Physics-Informed Neural Network(PINN)을 비롯해 여러 신경망들을 기계공학에 맞게 사용하는 연구들이 진행되고 있다. 물론 나도 그중 하나...
각설하고 최근 Physics-Informed DeepONet(PIDON)이라는 Neural Operator에 대해 공부하게 되었는데, 해당 코드가 Jax가 주가 되어 코드가 돌아가 Jax에 대해 약간이나마 조사를 해보았다. 본문은 어디까지나 아직 Jax를 본격적으로 공부하지 않았고, 필자가 본 PIDON 코드 기준으로 작성이 되어서 많이 주관적일 수 있다는 점 참고 바란다.
Jax는 최근 구글에서 만든 딥러닝을 위한 프레임워크로 Autograd와 XLA(Accelerated Linear Algebra: TensorFlw 모델을 가속화할 수 있는 선형 대수학용 도메인별 컴파일러)가 결합되어 만들어졌다. 해당 라이브러리는 Numpy 대신 활용이 가능하고, CPU, GPU, TPU에서 코드 변경이 없이 사용 가능하며 빠르다는 장점으로 최근 머신러닝에서 자주 활용이 되고 있다. 내가 공부하게 된 PIDON 또한 이러한 장점 때문에 Jax를 기반으로 코드를 짜지 않았을까 싶다. 간단하게 장점과 단점을 정리하자면 다음과 같다.
장점
1. Numpy와 문법이 유사하여 활용이 쉽다.
2. CPU, GPU, TPU의 코드가 동일하다.
3. 속도가 빠르다. JIT 컴파일 방식으로 더 빠른 속도로 실행 가능하다.
실제로 코드를 읽으면서 Jax가 작동하는 방식에 대해 Numpy와 유사한 문법으로 사용되기 때문에 Jax에 대해 부분 때문에 크게 어려웠던 점은 없다. 물론 처음 보는 vmap, Jit, ravel_pytree 등에 대한 공부는 따로 필요했지만 이런 부분들은 뭐 당연하니까... 또한 일반적인 Pytorch 기반 코드에는 cuda와 같이 gpu를 사용하기 위한 별도의 코드들이 많이 들어가 있지만, 딱히 이런 코드는 없고 그냥 cpu만 사용 가능하면 지가 알아서 cpu로 돌아가고 gpu가 사용 가능하면 자기가 알아서 gpu로 돌리더라. 다만 속도의 경우는 빠르다고는 하는데, 내가 직접 비교를 한 건 아니라 정확히 어느 정도의 속도 차이가 있는지는 모르겠다. 뭐 다들 빠르다고 하니까 빠르긴 하겠지 싶은 느낌
단점
1. functional programming을 따르기 때문에 Numpy와 다르게 유의점이 많다.
2. 최신 모델에는 활용이 많지만 전통적인 모델들에 대한 코드가 부족하다(ex) PINN 같은 모델?
3. 국내 자료가 많지 않아 진입장벽이 높은 편이다.
4. 최근에 만들어진 라이브러리다 보니 버전 업데이트가 자주 이루어져 버전 맞추기가 살짝 혼란스럽다.
1번의 경우 실제로 PIDON 코드를 공부하며 내가 직접 Jax를 통해 기존의 코드를 수정한다고 생각하였을 때 많이 어려울 것 같다는 생각이 들었다. Jax는 역전파를 통해 연산 처리를 하는 것이 아니라 순수 함수(pure function)을 통해 연산을 표현하는데, 이런 순수 함수와 jax에 의해 지원되는 데이터 타입을 사용함으로써 속도를 크게 향상시키는데 정확한 이해가 없으면 구현하기 어려울 것 같았다. 최신 라이브러리이다 보니까 2번 같은 문제는 어쩔 수 없는 부분이고, 3번과 4번이 너무 어려웠다. 환경 구축에만 30시간을 넘게 쏟았다... 확실히 국내 자료도 너무 없어서 내가 어려움을 겪은 부분들과 해결한 방법, 공부한 내용들을 공유하고자 블로그를 시작하게 된 계기가 되었다.(업데이트를 자주 할지는 모르겠지만)
Jax를 공부해야 하는 이유?
필자는 기계공학과 스마트카 융합전공 트랙을 밟고 있다 보니 자율주행에 대한 지식도 조금이나마 가지고 있다. 이러한 자율주행뿐만 아니라 점차적으로 나오는 모든 인공 지능 모델들이 기능이 심화될수록 부담하게 되는 Computation cost는 증가할 것이다. 때문에 속도가 빠르다는 장점 자체가 너무 넘사라고 생각하고 구글에서 밀고 있다는 점이 점차적으로 Jax 사용자가 많아질 것이라고 생각한다. 그렇기 때문에 어느 정도 기본 지식은 가지고 있어야 된다고 생각하고 아직 국내 사용자가 많이 없다는 점에서 추후 이걸 할 수 있을 때 어느 정도의 메리트가 있다고 생각한다. 일을 시키는 입장에서는 속도가 빠르다는 장점만 보이지 구현하기 어렵다는 장점은 안 보이니까...
다들 나처럼 시간을 박지 않았으면 좋겠는 마음에 본 블로그를 시작한다. 물론 아무것도 모르는 4학년 학부생(심지어 기계공학전공)이 쓰는 것 때문에 논리적인 설명이 부족할 수 있지만 내가 이해한 한에서 쉽게 적어볼 테니 다들 팩트체크는 크로스로 알아서 해주시길...
'인공지능 > jax 입문' 카테고리의 다른 글
[Jax] Just in Time(Jit) Compiler(컴파일러)란? (0) | 2023.08.09 |
---|---|
[Jax] Pure function(순수 함수)에 대해 (0) | 2023.08.08 |
[Jax] Jax 자주 발생하는 오류 및 해결 방법 (0) | 2023.08.07 |
[Jax] Window에서 Jax 가상 환경 설치하기(설치 및 GPU 사용) (0) | 2023.08.06 |