잡다구리 너구리

[PIDoN] Burger's Equation 예제 코드 분석 [2/2] 본문

인공지능/Deep Learning

[PIDoN] Burger's Equation 예제 코드 분석 [2/2]

너굴뽀이 2025. 4. 30. 20:22

 이전 글에 이은 Physics Informed DeepONet의 Burger's Equation 예제 코드 분석, 이번 글에는 PI-DeepONet에 정의된 손실 함수를 집중적으로 다뤄보고자 한다. 이전 글과 PIDoN의 개념적인 내용은 아래 글을 통해 확인할 수 있다. 시작하기 앞서, PIDoN은 PINN과 동일하게 물리적 정보를 포함한 지배 방정식을 이용한 손실함수와 초기 조건과 경계 조건을 반영한 손실 함수를 사용한다는 점을 인지한다면 이해하기는 어렵지 않다.

 

2024.02.11 - [인공지능/Deep Learning] - [PIDoN] Physics-Informed DeepONet에 대하여

 

[PIDoN] Physics-Informed DeepONet에 대하여

최근 PINN과 같이 Single Instance 특성을 가진, 즉 같은 문제 상황이더라도 Input 데이터 등의 조건이 바뀌면 재훈련해야 된다는 문제로 인해 Operator Learning과 같은 Multi Instance의 특성을 가진 알고리즘

sarasara.tistory.com

2025.03.22 - [인공지능/Deep Learning] - [PIDoN] Burger's Equation 예제 코드 분석 [1/2]

 

Physics Loss

 

 PI-DeepONet의 PI를 담당하는 Physics Loss를 먼저 설정해 준다. 예제 코드에서는 Loss_res라는 이름으로 사용되었는데, 아마 물리적 손실 함수의 개념 자체가 지배 방정식이 있을 때 좌항과 우항의 잔차를 이용해서 손실 함수를 계산하기 때문에 residual의 뜻의 Loss_res가 아닌가 싶다. 다른 물리적 기반 인공 신경망과 마찬가지로 해당 손실 함수는 기본 PDE의 제약 조건을 적용한다. PI-DeepONet의 솔루션 연산자가 $G_{\theta}$일 때 입력 함수 $u^{(i)}$에 대한 PDE의 잔차는 아래의 식과 같이 결정된다.

 

$$R_{\theta}^{(i)}(x, t) = \frac {dG_{\theta}(u^{(i)})(x, t)}{dt}+G_{\theta}(u^{(i)})(x, t)\frac {dG_{\theta}(u^{(i)})(x, t)}{dx}-ν\frac {d^2G_{\theta}(u^{(i)})(x, t)}{dt^2}$$

 

 얼핏 봤을 때는 어려워 보이지만, Burger's Eqation의 형태에 솔루션 연산자를 거친 값을 넣어줬다고 생각해 보면 이해가 어렵지 않을 것이다. 이때의 $u(i)$는 [0,1]의 공간에서 균일한 간격으로 배치된 고정 센서 $\lbrace x_{i} \rbrace^m_{i=1}$를 평가한 입력 함수이다. 위의 수식과 같은 과정이 residual_net 부분에 정의가 되어 있다.

 

 손실 함수는 아래의 식과 같이 사용된다. $N$은 training data의 sample의 수이며, $Q$는 PDE 잔차를 평가할 Location의 수이다. 쉽게 생각하면 한 필드 내의 유체의 움직임을 PIDoN을 통해 예측하려고 할 때, $Q$는 필드 내의 센서이고, $N$은 들어갈 데이터의 개수가 된다.

 

$$\mathcal {L}_{physics}({\theta})=\frac {1}{NQ}\sum_{i=1}^{N}\sum_{j=1}^{Q}\left|R_{\theta}^{(i)}(x_{r, j}^{(i)}, t_{r, j}^{(i)})\right|^2.$$

 

 한 가지 눈 여겨볼 점은, 예측값을 계산하는 과정 속에서 vmap 함수를 사용한다는 점이다. vmap은 jax에서 제공하는 함수 중 하나로 for-loop를 벡터화해서 자동으로 병렬 처리를 해주는 함수이다. 이는 jax가 타 프레임워크에 비해 더 빠른 속도를 제공하는 이유 중 하나이다. residual_net(params, $u_i$, $t_i$, $x_i$)를 여러 데이터에 반복 적용하기 위해 사용되며, vmap($f$, input)를 예로 들면, $f$는 벡터화하려는 함수를 뜻하고, input은 각 인풋에 대한 벡터화 여부를 지정한다. 즉, 아래 코드에서는 self.residual_net를 병렬 적용하기 위해 벡터화를 하며, (None, 0, 0, 0)에서는 params는 공통적으로 사용하기 위해 None으로 지정 후, $u_i, t_i, x_i$는 샘플마다 각각 다르기 때문에 여러 값을 처리하기 위해 벡터화를 하는 것이다. 

 

    # Define PDE residual
    def residual_net(self, params, u, t, x):
        s = self.operator_net(params, u, t, x)
        s_t = grad(self.operator_net, argnums=2)(params, u, t, x)
        s_x = grad(self.operator_net, argnums=3)(params, u, t, x)
        s_xx= grad(grad(self.operator_net, argnums=3), argnums=3)(params, u, t, x)

        res = s_t + s * s_x - 0.01 * s_xx
        return res

    # Define residual loss
    def loss_res(self, params, batch):
        # Fetch data
        inputs, outputs = batch
        u, y = inputs
        # Compute forward pass
        pred = vmap(self.residual_net, (None, 0, 0, 0))(params, u, y[:,0], y[:,1])

        # Compute loss
        loss = np.mean((outputs.flatten() - pred)**2)
        return loss

 

Initial Loss

 

 다음은 초기 조건을 통해 계산되는 손실 함수를 정의한다. Loss_ics로 지정이 되어 있으며, 이는 아래와 같다.

 

$$\mathcal {L}_{IC}({\theta})=\frac {1}{NP}\sum_{i=1}^{N}\sum_{j=1}^{P}\left|G_{\theta}(u^{(i)})(x_{ic, j}^{(i)},0)-u^{(i)}(x_{ic, j}^{(i)})\right|^2.$$

 

 $N$은 physics loss와 동일하게 training data의 sample 수이며, $P$는 초기 조건을 평가할 Location의 수이다. $Q$와 같은 역할인데, 해당 코드에서는 한 데이터셋에 대해 physics loss, initial loss, boundary loss에 대한 데이터를 각각 나눈다.

 

    # Define initial loss
    def loss_ics(self, params, batch):
        # Fetch data
        inputs, outputs = batch
        u, y = inputs

        # Compute forward pass
        s_pred = vmap(self.operator_net, (None, 0, 0, 0))(params, u, y[:,0], y[:,1])

        # Compute loss
        loss = np.mean((outputs.flatten() - s_pred)**2)
        return loss

 

Boundary Loss

 다음은 경계 조건으로부터 발생하는 손실 함수이며, 식은 아래와 같다.

 

$$\mathcal {L}_{BC}({\theta})=\frac {1}{NP}\sum_{i=1}^{N}\sum_{j=1}^{P}\left|G_{\theta}(u^{(i)})(0, t_{bc, j}^{(i)})-G_{\theta}(u^{(i)})(1, t_{bc, j}^{(i)})\right|^2\\+\frac {1}{NP}\sum_{i=1}^{N}\sum_{j=1}^{P}\left|\frac {dG_{\theta}(u^{(i)})(x, t)}{dx}\vert_{(0, t_{bc, j}^{(i)})}-\frac {dG_{\theta}(u^{(i)})(x, t)}{dx}\vert_{(1, t_{bc, j}^{(i)})}\right|^2$$

 

    # Define ds/dx
    def s_x_net(self, params, u, t, x):
         s_x = grad(self.operator_net, argnums=3)(params, u, t, x)
         return s_x

    # Define boundary loss
    def loss_bcs(self, params, batch):
        # Fetch data
        inputs, outputs = batch
        u, y = inputs

        # Compute forward pass
        s_bc1_pred = vmap(self.operator_net, (None, 0, 0, 0))(params, u, y[:,0], y[:,1])
        s_bc2_pred = vmap(self.operator_net, (None, 0, 0, 0))(params, u, y[:,2], y[:,3])

        s_x_bc1_pred = vmap(self.s_x_net, (None, 0, 0, 0))(params, u, y[:,0], y[:,1])
        s_x_bc2_pred = vmap(self.s_x_net, (None, 0, 0, 0))(params, u, y[:,2], y[:,3])

        # Compute loss
        loss_s_bc = np.mean((s_bc1_pred - s_bc2_pred)**2)
        loss_s_x_bc = np.mean((s_x_bc1_pred - s_x_bc2_pred)**2)

        return loss_s_bc + loss_s_x_bc

 

 

Total Loss

 

 마지막으로 위의 세 손실 함수를 바탕으로 최종 손실 함수를 정의한다. 최종 손실 함수의 식은 다음과 같다.

 

$$\mathcal {L}({\theta})=\mathcal {L}_{IC}({\theta})+\mathcal {L}_{BC}({\theta})+\mathcal {L}_{physics}({\theta})$$

 

 위의 식은 기본적인 식의 구성이며, 논문에서는 손실 함수의 상호 균형을 맞추기 위한 hyper-perametre $\lambda$를 사용한다. 아래 코드에서는 20의 수치를 사용하였고, 해당 논문에서는 [1, 5, 10, 20, 50, 100]의 hyper-parameter 세팅을 통해 실험해 본 결과 20의 수치가 가장 성능이 우수하였다는 결과를 제시한다. hyper-parameter를 고려한 손실 함수의 식은 아래와 같다.

 

$$\mathcal {L}({\theta})=\lambda\mathcal {L}_{IC}({\theta})+\mathcal {L}_{BC}({\theta})+\mathcal {L}_{physics}({\theta})$$

 

 사담을 하자면, 최근 연구까지는 모르기에 내가 잘 모르는 거일 수도 있지만, 아직 왜 이런 hyper-parameter가 효과가 좋은지는 밝혀지지도 않았고, 개인적으로는 조금 억지라는 느낌도 든다. 너무 성능이 좋아졌으니 장땡인 느낌.. 나조차도 전에 matlab을 통해 wave equation을 physics informed neural network로 예측하는 코드를 짤 때, 코드 전체적인 부분에서 아무 이상도 없었는데, 예측이 안되다가 혹시나 하니까 loss 부분에 hyper-parameter를 적용했더니 잘 되었던 기억이 있다. 기억이 가물가물하긴 하지만, 2~3년 전에 갔던 학회에서도 SKY에서 한 신경망에 대해 hyper-parameter를 각각 다르게 적용했더니, 이게 잘 되더라 라는 알맹이 없는 발표를 들었던 기억도 있다. 여하튼 해당 코드는 아래와 같다.

 

    # Define total loss
    def loss(self, params, ics_batch, bcs_batch, res_batch):
        loss_ics = self.loss_ics(params, ics_batch)
        loss_bcs = self.loss_bcs(params, bcs_batch)
        loss_res = self.loss_res(params, res_batch)
        loss =  20 * loss_ics + loss_bcs +  loss_res
        return loss

 

Update Step

 

 나머지 부분은 일반적인 신경망과 비슷하여, 크게 어려운 부분은 없지만, jax의 특징을 엿볼 수 있는 부분을 살펴보자. 일단 코드를 먼저 보면 다음과 같다.

 

    # Define a compiled update step
    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, ics_batch, bcs_batch, res_batch):
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, ics_batch, bcs_batch, res_batch)
        return self.opt_update(i, g, opt_state)

 

 @partial이라는 특이한 코드를 사용하는데, 이는 마찬가지로 jax에만 있는 jit decorator라고 한다. 이를 이해하기 위해서는 먼저, 코드 실행 방식을 알아야 한다. 일반적으로 코드 실행은 인터프리터 방식과 컴파일 방식으로 나뉜다. 인터프리터 방식은 코드를 한 줄씩 번역해서 실행하는 방식이고, 컴파일 방식은 코드를 한 번에 번역하여 저장하고 실행하는 방식이다. 후자가 코드 실행 속도가 더 빠른데, 왜냐면 이미 번역된 코드를 다시 읽지 않고, 저장된 컴파일된 버전을 사용하기 때문이다.

 

 Phthon 같은 언어는 기본적으로 인터프리터 방식으로 동작하여, 코드를 실행할 때 한 줄씩 번역해서 처리한다. 그렇기 때문에 VS code에서 코드를 실행시키고 아래쪽을 보면 line by line으로 실행되는 것을 볼 수 있는 것이다. 반면, jax에서는 jit decorator를 사용해 부분적인 컴파일을 수행할 수 있다. jit은 특정 함수에 대해 한 번 컴파일된 결과를 저장해 두고, 이후 실행 시 저장된 컴파일된 코드를 재사용하기 때문에, 신경망 훈련과 같은 반본적인 코드 실행에서 성능이 크게 향상된다. 이렇게 jit decorator가 적용된 함수는 첫 실행 시에 컴파일되고, 다음 epoch 때는 다시 컴파일할 필요 없이 저장된 컴파일 결과를 불러와 실행 속도를 개선시킬 수 있다. pure function만을 사용하다는 제약이 있긴 하지만...

 

글을 마치며

 

 모든 코드를 다 적기에는 무리도 있고, 딱히 어려운 부분이 없어, 할 말이 없는 파트들도 있어서 여기서 코드 정리를 마치려고 한다. 해당 코드를 공부할 때는 jax에 대한 상당한 의구심을 가지고 있었는데, 내가 jax 기반 환경을 구축하는 데에 상당한 시간을 썼기 때문이다. 물론, 리눅스에서 설치해야 되는 걸 무식하게 Window에 설치하려고 그런 것도 있지만... pure function이란 개념 자체가 처음 들었을 때, 복합적인 코드에서 모든 함수에 pure function 조건을 고려하며 짤 수 있을까...라는 생각이 들었다. 회사에서 일하고 있는 지금도 해당 생각은 유효한데, 물론 내가 소프트웨어 쪽에서 일하는 게 아니라 그런 걸 수 있지만, 가장 중요한 것은 프로그램을 만들어도 유지 보수성이 용이한가라고 생각한다. 본인만 아는 코드는 의미가 없기 때문에... 그런 점에서 수많은 라이브러리가 있는데, 대부분의 개발자들이 jax를 선택할 이유가 있을까 싶은 생각이다.