기계공학의 잡다구리 생활

[Jax] Jax 자주 발생하는 오류 및 해결 방법 본문

인공지능/jax 입문

[Jax] Jax 자주 발생하는 오류 및 해결 방법

타구스 2023. 8. 7. 14:41

  필자는 이전 글에 언급한 것처럼 Window에서 Anaconda 가상 환경에 jax 실행 환경을 구축했다. 본 문에서는 jax 가상 환경을 구축하며 나를 괴롭게 했던 오류들을 정리해 보고자 한다. 필자는 jax, jaxlib 0.3.7 버전에서 발생한 문제이나 가상 환경을 구축하면서 다른 버전들도 많이 설치해 봤는데 다들 동일한 오류가 발생한 걸로 보아 아마 다들 대중적으로 발생하는 오류지 않을까 생각한다. 그러니 본인 버전과 달라도 크게 개의치 않고 일단 해결 방법을 따라 해보면 될 듯하다. 필자를 애먹게 했던 순으로 오류를 정리해 볼까 한다.

 

1. WARNING - No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

 

 해당 오류를 처음 접했을 때는 멘붕이 왔다. 분명 cuda, cudnn을 전에 구축해둔 가상 환경을 복제해서 사용했기 때문에 GPU 사용이 안 될 리가 없고, 실제로 Cuda 실행 확인을 해봤을 때 실행이 되었었다. 아마 지금 되돌아볼 때는 여러 원인으로 추정이 되는데, 일단 라이브러리 버전 문제이다. Jax의 경우 일정 버전 이상에서는 특정 cuda 버전 이상을 필요로 한다. Jaxlib 설치를 꼭 앞에 cuda111이 붙은 파일로 설치를 해야 된다. 또한 pip으로 설치를 안 해서 채널이 섞인 문제가 있을 것이다. 사실 이전 글에 쓴 순서대로 잘 따라왔다면 해당 오류는 발생하지 않을 것이라 생각하지만, 이전 글대로 실행한 후 해결이 되지 않아 본 글로 넘어온 사람이라면 버전을 확인해 보면 될 것이다. 아래 링크를 통해 Jax의 버전 업데이트 과정을 확인할 수 있으며, 특정 cuda 버전에 대한 정보도 확인할 수 있다. 해당 오류 때문에 필자는 엄청난 고생을 했기 때문에 원래 상세하게 이에 대해 풀고 싶었지만, 막상 풀려고 하다 보니 이전 글의 설치 순서대로 따라오면 해당 오류가 발생하지 않을 거란 생각 때문에 딱히 할 말이 없다.

 

https://jax.readthedocs.io/en/latest/changelog.html

 

Change log — JAX documentation

Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla

jax.readthedocs.io

 

2. ValueError: DenseElementsAttr colud not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug

 앞서 GPU가 잡히지 않는 오류를 힘겹게 극복한 후(여기까지만 이미 시간이 갈린 상태) 이제 코드가 실행될 것이라는 생각에 코드를 구동시켜 보았다. 오류가 아주 길게 나오며 오류 중간에 ValueError : DenseElementsAttr~~이 뜨는 모습과 아래 사진과 같은 오류가 뜨는 것을 확인하였다. 

 

jax-error

 

 이게 코드 상의 오류인지, jax 버전에 따른 오류인지(코드 github에서는 jax 0.3.25버전으로 작업했다.), jax 자체의 오류인지 많이 헤맸다. 각설하고 결론만 말하면 정답은 jax 자체의 오류가 맞는 것 같다. 먼저 발등에 불 떨어진 여러분들을 위해 해결 방법부터 말씀드리겠다. 먼저 본인의 가상환경 폴더를 열어준다. 나 같은 경우는 해당 경로는 다음과 같다.

 

C:\Users\user\anaconda3\envs\PIDON

 

 PIDON이 가상환경 명이다. 그 후 폴더 검색에 cuda_prng.py를 쳐준다. 그런 다음 jaxlib 경로로 잡혀있는 파일을 열어주면 된다. 글로 보면 감이 안 잡힐 것을 위해 사진을 통해 보면 아래와 같이 접근하면 된다.

 

cuda-prng.py-접근법

 

 다들 잘 따라오셨을 것이라 믿고 해당 파일(첫 번째 꺼)를 열어준다. 아마 내 기억 상에는 하나만 있었던 거 같은데 이 글 쓴다고 jaxlib 버전을 몇 번 바꿔깔았더니 몇 개가 더 생겼나? 첫 번째 꺼 바꿨는데도 안 되면 그냥 다 바꿔주면 될 것 같다. 그러면 뭘 바꾸냐? 해당 파일을 보면 아래와 같은 코드가 있다. 나 같은 경우는 76 line에 있었다.

 

  layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
                                       type=ir.IndexType.get())

 

 이 코드를 아래와 같이 바꿔주면 된다. 조금 더 상세하게 설명하면 np.arrange 안에 dtype을 지정해주면 된다. 이럴 경우 위의 오류가 해결된다.

 

  layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1,dtype=np.int64),
                                       type=ir.IndexType.get())

 

 오류의 정확한 원인은 모르겠으나 아마 윈도우로 깜으로써 데이터 형식이 꼬이는 그런 원인이 아닌가 싶다. pip으로 특정 파일을 설치할 때 UnicodeDecodeError가 발생해서 파이썬 파일을 까서 안에 encoding="utf-8"을 입력해 주는 것처럼...

 

글을 마치며..

 

 이런 오류들 외에 되게 중간중간 자잘한 오류들이 많았다. jax, jaxlib의 버전이 맞지 않아 생기는 오류 등등. 이런 오류들이 많아 가상 환경 구축을 성공하고 나서 극복한 과정들을 공유해야겠다고 생각해서 블로그를 시작하게 되었다. 그런데 막상 글을 쓰다 보니 어떤 오류를 겪었는지 기억이 나지 않아(능지 이슈) 내가 가장 고생했던 기억나는 두 오류에 대해서만 말하는 거에 대해 심심 찮은 사과의 말씀을 전한다. 

 

 사실 해결하지 못한 오류도 있다. 앞서 지나치듯 언급한 pip 설치 오류, Window에서 까는 거다 보니까 특정 jax 버전을 깔면 UnicodeDecodeError가 발생해 "UnicodeDecodeError: "cp949" codec can't decode byte 0xe2 in position 1304"가 뜨며 설치가 안 되더라. 해당 오류 원인 자체는 알고 있어 구글링을 통해 이걸 해결하려고 여러 방법을 시도했는데 왜인진 모르겠으나 끝까지 해결을 못했다. 이것이 내 코드는 0.3.25를 필요로 하는데 0.3.7 환경을 구축한 이유이다. 이거 때문에 2번 오류가 버전 오류이지 않을까도 생각했었고, 여하튼 Window로 UnicodeDecodeError를 해결하신 분은 어떻게 해결하셨는지 댓글로 알려주시면 감사

 

 이제 Jax 공부는 시작해야 되는 거라 다음 포스팅은 Jax가 될지, 기존에 공부했던 내용이 될지는 모르겠지만 긴 글 읽어주셔서 감사를 전하며, 한 명에게라도 도움이 되었으면 싶다.