일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 | 29 |
30 | 31 |
- PIDoN
- 순수 함수
- Pino
- DeepONet
- 맥스웰 방정식
- jax error
- Q factor
- jax
- Burgers Equation
- COLAB
- Resonance frequency
- FNO
- 일반기계기사
- Physics Informed Neural Operator
- Neural Operator 설명
- 딥오넷
- Maxwell Equation
- 품질 인자
- PINN
- WPT
- Operator Learning
- 일러스트레이터 수식
- 연산자 학습
- HFSS
- Wireless Power Transfer
- Physics Informed Neural Network
- Governing Equation
- Anaconda
- Wireless Power Transform
- Neural Operator
- Today
- Total
잡다구리 너구리
[Jax] Pure function(순수 함수)에 대해 본문
Jax는 Pure function(순수 함수)을 통해 연산을 표현하고, 사용되는 데이터들이 불변적이어야 되는 특징을 가져야 되기 때문에 Pure function이 매우 중요한 개념이라고 할 수 있다. 하지만 필자는 Pure function에 대해 처음 들어보았고, 개념이 애매하게 잡혀있는 것 같아 이를 정리하며 좀 더 확고하게 다지고자 한다.
여러 글과 유튜브, jax 공식 document를 통해 내 나름대로 이해하며 정의를 내려보았다. 먼저 Jax가 Pure function만을 사용해야 하는 이유는 jax가 functional programming을 따르기 때문에 데이터들이 불변적이라는 속성을 띄어야 하기 때문이다. Pure function을 만족하기 위해서는 두 가지가 지켜져야 한다. 먼저, 동일한 입력값을 주고 함수가 실행되었을 때, 항상 동일한 출력을 가져야 한다. 즉 함수 자체도 불변적이어야 한다는 것이다. 또한 함수가 Side effect(부수 효과)를 가지면 안 된다. 즉 함수 외부에 영향을 끼쳐서는 안 되는 것이다. 본문의 예제 코드는 Jax의 예제 코드를 바탕으로 하였다.
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
JAX Quickstart — JAX documentation
JAX Quickstart JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research. With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It can differ
jax.readthedocs.io
1. 동일한 입력값과 출력값을 유지하는 함수
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
해당 함수의 경우 함수 외부의 g로 인해 함수가 실행되어 x 값과 g 값을 더한 값이 반환이 된다. 그렇다면 코드를 실행시켰을 때의 결과는 4, 15(중간에 g = 10이 있기 때문에), [14. ] 가 나와야 될 것이다. 하지만 코드의 실행 결과는 아래와 같다.
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
두 번째 print 구문이 15가 아닌 5가 출력된 모습을 확인할 수 있다. 이러한 원인이 바로 함수가 동일한 입력값과 출력값을 유지하지 못하기 때문에, 불변성을 잃어 jax가 구동이 잘못된 것이다. 위의 함수는 함수 외부의 g 값에 따라 출력값이 달라지기 때문에 input으로 동일하게 4를 넣더라도 g가 0,3,5 등 숫자가 바뀌게 되면 다른 출력값을 낸다. 즉, 동일한 입력값에 대해 동일한 출력값을 유지하지 못하는 것이다.
2. Side effect를 가지면 안 된다.
함수가 side effect를 가지는 데에 대한 여부는 함수 외부에 영향을 끼치는 요소가 있는지이다. 함수 외부에 영향을 끼치는 요소가 있을 경우 해당 함수가 side effect를 가지고 있다고 본다.
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
위의 코드는 함수를 실행시킬 경우 "Executing function"을 실행시킨다. 즉, 외부에 영향을 끼치는 것이다. 이를 실행하면 결과가 아래와 같이 나온다.
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
함수가 side effect를 가지고 있기 때문에 두 번째 print 구문이 제대로 실행되지 않은 것이다. 이와 같이 Pure function을 위해서는 함수가 외부에 영향을 끼쳐서는 안 된다. 위의 코드와 같이 무언가를 출력하는 것이 아닌, 함수 외부를 변경하는 것 또한 side effect이다. 아래 코드와 같은 함수도 Pure function이 아닌 것이다. 함수가 실행되게 되면 함수 외부의 변수인 a의 값이 변경되기 때문에 Pure function이라고 할 수 없다.
a = 0
def impure_print_side_effect(x):
a = x # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
글을 마치며
Jax를 처음 접할 때는 이것을 안 쓸 이유가 없어 보였지만, 막상 공부를 해보니 생각보다 사용 조건이 까다롭다는 것을 느꼈다. 함수 자체가 Pure function을 유지하는 것을 계속 생각하며 코딩하기란 하나의 제약 조건이 생기는 것이기에 쉽지 않을 것 같다. 같은 원리로 Iterator의 경우 functional programming을 위배할 여지가 크기 때문에 Jax에서 Iterator를 사용하는 것을 추천하지 않는다고 한다. 하지만 한 편으로는 활용의 여지가 있는 분야가 있다고 생각한다. 기계공학과의 경우는 물리적 정보를 사용하거나, 수치해석적 방법을 통해 Partial differental equations을 푸는 연구가 많이 진행되고 있는데, 오히려 이런 요소에서는 functional programming을 사용하기가 용이하지 않을까? 아마 조금 더 공부를 해봐야 알 것 같다. 혹시 내가 이해한 것이 잘못되었다면, 이에 대한 지적은 언제든지 환영입니다.
'인공지능 > jax 입문' 카테고리의 다른 글
[Jax] Just in Time(Jit) Compiler(컴파일러)란? (0) | 2023.08.09 |
---|---|
[Jax] Jax 자주 발생하는 오류 및 해결 방법 (0) | 2023.08.07 |
[Jax] Window에서 Jax 가상 환경 설치하기(설치 및 GPU 사용) (0) | 2023.08.06 |
[Jax] Jax란 무엇인가? Jax를 공부해야 되는 이유 (1) | 2023.08.05 |