일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 | 29 |
30 | 31 |
- Physics Informed Neural Operator
- Neural Operator
- 품질 인자
- Anaconda
- FNO
- PIDoN
- DeepONet
- 일반기계기사
- PINN
- Operator Learning
- Resonance frequency
- Wireless Power Transform
- Governing Equation
- 순수 함수
- jax
- jax error
- Q factor
- COLAB
- Wireless Power Transfer
- 일러스트레이터 수식
- Neural Operator 설명
- Burgers Equation
- Physics Informed Neural Network
- HFSS
- Pino
- 맥스웰 방정식
- WPT
- Maxwell Equation
- 딥오넷
- 연산자 학습
- Today
- Total
잡다구리 너구리
[Jax] Just in Time(Jit) Compiler(컴파일러)란? 본문
Jax가 경쟁력 있는 이유는 기존 방식의 딥러닝 프레임워크보다 속도가 빠르다는 점이다. 이를 가능케 하는 것 중 가장 큰 요인은 Just in Time(Jit) 컴파일러 때문이다. 본문에서는 Jit 컴파일러가 기존 python, tensorflow와 어떠한 차이점 때문에 속도 향상이 가능한 것인지, 실제 코드에서는 어떻게 구현이 되는지 알아볼 예정이다. 해당 글 작성에 있어서 Jax 홈페이지에 나와있는 글을 참고하였으며, 본문에 올라와 있는 실제 코드는 Physics Informed DeepONet의 코드의 일부이다. 아래 링크가 이에 해당된다.
Jit Compiler
Jit 컴파일러가 어떤 원리로 코드 속도를 향상시키는가를 알기 위해서는 코드가 실행될 때 어떠한 방식들로 기계어로 번역되는지를 알아야 한다. 이런 방식은 크게 인터프리터(Interpreter) 방식과 Compiler(컴파일러) 방식으로 나뉜다. 먼저 인터프리터 방식은 코드 실행 시 한 문장에 한 번씩 번역한다. 즉 line by line으로 기계어로 번역을 한다. 컴파일러는 코드 전체를 스캔하여 기계어로 번역한다. 이렇기 때문에 인터프리터가 컴파일러에 비해 기계어 변환 자체는 시간이 빠른 편이다. 그러면 인터프리터 방식이 좋은 게 아니냐? 인터프리터 방식은 실행 파일을 생성하지 않기 때문에 매번 실행할 때마다 같은 코드여도 다시 번역을 해주어야 한다. 반면 컴파일러는 변환은 오래 걸리더라도 이에 대한 실행 파일이 생성되어 다음에 실행될 때 해당 실행 파일을 실행시켜주어 인터프리터 방식에 비해 시간이 빠르다. 이런 인터프리터 방식과 컴파일러 방식의 차이는 코드와 함께 다시 언급할 예정이다.
그렇다면 Jit은 어디에 속할까? Jit 컴파일러는 인터프리터 방식과 컴파일러 방식이 혼합된 방식이다. 런타임 실행 시에 코드의 일부분을 컴파일 하여 인터프리터의 속도가 향상된다. Python, Java 같은 툴들은 인터프리터 방식이기 때문에 이런 툴들에서도 컴파일을 할 수 있게 하여 속도를 향상시키는 것이 Jit 컴파일러다.
Jit example
Jax document에는 selu 함수를 사용할 때 Jit 컴파일러를 사용하면 얼마만큼의 속도 향상이 가능한지를 보여준다. Jit 컴파일러가 없을 때는 loop 당 2.5ms가 걸리지만 Jit 컴파일러를 적용할 경우 loop 당 0.15ms로 획기적인 속도 향상을 보여준다. 사실 개인적으로는 document에 나와 있는 예시로는 정확히 얼마큼이 좋은 건지 감이 잘 안 왔다. 때문에 Jit 컴파일러가 어떻게 사용되는지를 내가 공부하는 코드를 통해 확인해 보고자 한다.
# Define a compiled update step
@partial(jit, static_argnums=(0,))
def step(self, i, opt_state, ics_batch, bcs_batch, res_batch):
params = self.get_params(opt_state)
g = grad(self.loss)(params, ics_batch, bcs_batch, res_batch)
return self.opt_update(i, g, opt_state)
해당 코드는 모델 파라미터를 업데이트하는 코드이다. @partial(jit, static_argnums=(0, )) 이 바로 Jit Compiler를 사용하게 하는 Jit Decorator이다. @partial~~을 통해 그 아래의 step 함수를 컴파일할 수 있는 것이다. step 함수를 확인해 보면 side effect를 만들어내지 않고 동일한 입력값과 출력값을 유지하는 Pure function 임을 확인할 수 있다. 만약 해당 코드를 인터프리터 방식으로 실행시킨다면, Epoch가 반복될 때마다 line by line으로 그때그때 다시 읽는다. 하지만 Jit Decorator를 적용함으로써, 한 번 실행시킨 후 실행 파일을 저장해 바로바로 사용이 가능하다.(대신에 처음 읽을 때가 오래 걸린다.) 생성된 실행 파일을 무제한으로 쓸 수 있는 것이 아닌 일정 Epoch가 반복될 때마다 사라져 다시 읽는 모습이 훈련 중에 보이지만,(아마 캐시의 문제인듯하다.) 그럼에도 불구하고 빠르다.
글을 마치며
jax document에 Jit을 사용할 때 발생하는 오류들에 대해서도 나와있었지만, 해당 부분은 그냥 skip 하였다. 원래 오류는 직접 겪어보고 몸으로 부딪혀야 이해가 되며 기억에 남는 법이라고 생각하기 때문이다. Jit 컴파일러의 사용 방법에 대해 깊게 공부한 것은 아니지만, 실제로 코드를 짠다고 생각했을 때 Jax document의 방식보다 저렇게 Jit Decorator를 쓰는 것이 현실적일 것 같아 앞선 코드를 가져왔었다. 하지만 과연 Jit Decorator를 쓸 수 있도록 논리적으로 짤 수 있을까? 아직은 모르겠다.
'인공지능 > jax 입문' 카테고리의 다른 글
[Jax] Pure function(순수 함수)에 대해 (0) | 2023.08.08 |
---|---|
[Jax] Jax 자주 발생하는 오류 및 해결 방법 (0) | 2023.08.07 |
[Jax] Window에서 Jax 가상 환경 설치하기(설치 및 GPU 사용) (0) | 2023.08.06 |
[Jax] Jax란 무엇인가? Jax를 공부해야 되는 이유 (1) | 2023.08.05 |