잡다구리 너구리

[Jax] Window에서 Jax 가상 환경 설치하기(설치 및 GPU 사용) 본문

인공지능/jax 입문

[Jax] Window에서 Jax 가상 환경 설치하기(설치 및 GPU 사용)

너굴뽀이 2023. 8. 6. 12:35

 앞서 말한 것처럼 필자는 Jax 환경 구축하는 데만 30시간 이상을 갈아 넣었다. 이런 부분이 익숙하신 컴공과나 소웨과 분들이 보시면 비웃으실 수 있겠지만, 나와 같은 컴맹 분들을 위해 쉽고 상세하게 정리하고자 한다. 먼저 가장 중요한 점, 나는 미처 간과했지만 Jax는 최근에 구글에서 만들어진 라이브러리다 보니 리눅스에서 설치해야 된다. 하지만 나는 20시간 정도 박고 나서야 이 사실을 알게 되었다. 하지만 갈아 넣은 시간이 아까워서 계속 Window에서 Anaconda 가상 환경에서의 환경 구축을 시도하였고, 해당 환경 구축을 위해 수십 개의 가상환경을 수정하고 뒤엎었고, 결국은 성공했다. 때문에 나 같은 리눅스에 무지한 사람들을 위해 이에 대한 방법을 공유하고자 한다.(사실 Vmware로도 시도해 봤지만 Vmware는 애초에 Gpu를 못 쓰더라)

 

 하지만 리눅스로 하면 금방 될 듯하다. 본인이 리눅스 환경이 이미 있다 하는 사람들은 뒤로 가기를 누르고 다른 분들 글을 참조하시는 걸 추천드린다. 해당 유튜브가 자주 발생하는 오류(나도 겪은)을 극복하는 것 또한 다루니 참조하시길. 필자도 다음 글에서 해당 내용을 다룰 예정이다.

 

https://www.youtube.com/watch?v=auksaSl8jlM 

 

 앞서 말한 것처럼 본 글은 Window 바탕 Anaconda에서의 가상 환경을 통해 Jax를 설치하는 법을 다룬 글이다. Jax를 설치할 때 자주 발생하는 오류를 극복하는 내용은 추후 다른 글에서 다루고 본 문은 설치하는 법 위주만을 쓸 예정이다.

 

사용 버전

Python == 3.9.0

Tensorflow-gpu == 2.5.0

jax == 0.3.7

jaxlib == 0.3.7

 

1. Anaconda 가상환경 만들기

conda create -n 가상환경 이름 python=3.9.0

 다음과 같이 사용할 가상환경을 생성해 준다. 3.9.0 버전을 사용해 주는 이유는 왠진 모르겠지만 대부분의 사람들이 해당 버전을 통해 Jax 환경 구축을 하길래 나 또한 그렇게 했다. 다음에 가상환경을 활성화한 후 필요 라이브러리 설치 및 아래의 절차를 밟으면 된다.

 

2. Tensorflow-gpu 설치

pip install tensorflow-gpu==2.5.0

 분명 작년에 cuda 사용 환경 구축할 때는 환경에서 따로 cuda와 cudnn을 까는 명령어를 생성해서 만들었던 거 같은데 1년 사이에 세상이 바뀐 건지 tensorflow-gpu만 깔아도 자기들이 알아서 cuda와 cudnn도 깔더라. 사실 cuda와 cudnn을 따로 깔고도 시도해 봤는데 실패했다. 그냥 tensorflow-gpu만 깔아도 되는 것 같다. 어떤 글에는 Jax 설치를 위해서는 cuda와 cudnn 버전까지도 맞춰줘야 한다고 하던데 나도 이거 때문에 계속 cuda와 cudnn 버전도 신경 쓰면서 했는데 위에 올려둔 유튜브에서 cuda를 블로그 내용과 다른 버전을 쓴 걸 보고 신경 안 쓰고 진행하였다. 

 

3. Jax 설치

pip install jax==<버전>
ex) pip install jax==0.3.7

 일단 jax의 경우 jaxlib과 버전 호환이 꼭 돼야 된다. 그렇지 않으면 오류가 생기는데 필자가 환경 구축을 하며 겪은 주요한 오류들은 다른 포스트에서 다루도록 하겠다. 내가 돌리는 코드 경우 0.3.25를 쓰지만 일부 jax 버전이 UnicodeDecodeError가 뜨더라 해당 오류의 경우 python 파일에서 enoding="utf-8"을 추가해 주면 해결하는 듯 보이지만 여러 글을 참고해도 해당 오류가 해결되지 않아 그냥 설치되는 버전으로 깔았다. 그리고 jax와 jaxlib의 경우 꼭 pip으로 깔아주도록 하자. 처음에는 생각 없이 conda와 pip 중에서 그냥 다른 글들에서 보이는 대로 깔았는데, 이 경우 아래와 같이 gpu를 못 잡는 오류가 발생한다.

WARNING - No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

이거 때문에 gpu 잡게 해주려고 개고생을 했던...

 

4. Jaxlib 설치 

 기본적으로 Jax는 리눅스에서 해야 되기 때문에 jaxlib을 jax 설치하는 github에서 까는 것이 아니라 별도로 whl 파일을 다운로드해서 사용해 주어야 한다. 아래의 사이트에서 본인이 다운로드한 jaxlib 버전과 cuda, cudnn 버전에 맞게 설치를 해주어야 한다. 이런 식으로 쭉 목록이 있는데 꼭 gpu 사용을 위해서는 앞에 cuda111이 붙은 것을 다운로드해야 한다.(사실 cpu 붙은 건 안 깔아봤지만 굳이 똥인지 된장인지 찍어 먹어봐야 아는 건 아니니까) cp~~ 는 파이썬 버전을 뜻하는데 나 같은 경우는 cuda111/jaxlib-0.3.7+cuda11.cudnn82-cp39-none-win_amd64.whl을 다운로드했다. 다만 상대적으로 최신 버전이 없는데(현재 0.4.14까지 나옴) 아마 Window 지원이 끊긴 건지 최신 버전 구동이 필요한 경우는 그냥 리눅스로 설치하도록 하자

 

jaxlib

https://whls.blob.core.windows.net/unstable/index.html

 

https://whls.blob.core.windows.net/unstable/index.html

 

whls.blob.core.windows.net

 자신의 버전에 맞게 다운을 받으면 자신의 콘다 환경에서 해당 whl 파일을 깔아주면 되는데 이때 꼭 파일이 있는 경로에서 설치를 하여야 한다.

cd <경로>
pip install <파일 명.whl>

 예시를 보면 아래 사진과 같다. 나 같은 경우는 그냥 바탕화면으로 다운로드한 파일을 옮겨서 Desktop 경로로 이동해 준 뒤 거기서 pip install jaxlib-0.3.7+cuda11.cudnn82-cp39-none-win_amd64.whl을 통해 설치해 주었다. 아래 사진 같은 경우는 기존에 설치되어 있는 jaxlib 버전에 똑같은 버전 설치 명령어를 넣었지만 새로 구축하시는 여러분이라면 잘 설치가 될 것이다.

 

jaxlib-install

 

 설치가 끝났으면 아래 코드를 통해 본인이 원하는 버전으로 잘 설치되었는지 확인해보자

conda list

 여기까지 잘 따라오신 여러분들은 윈도우에서 jax를 사용할 준비가 끝난 것이다. 떨리는 마음으로 gpu 사용이 가능한지 확인해 보면 된다. 

python #python을 anaconda 프롬프트 창에서 사용할 수 있게 하는것
>> import jax
>> jax.devices()

 위의 명령어를 입력하면 된다. 아래 사진과 같이 GpuDevice가 뜨면 성공한 것이다. 만약 Cpu로 뜬다면 욕 한 마디 하시고 뭐가 문제인지 고민하며 처음부터 천천히 확인해 보거나 다른 글들을 참조하면 된다.

 

jax-check

 

 나는 엄청 끙끙거리면서 힘들게 했는데 이렇게 글로 써놓으니까 이렇게 간단한 걸 왜 그렇게 오래 걸렸지 싶긴 하다. 교수님이 보시면 여기에 30시간 박은 거에 놀랄듯... 교수님 저 힘들었어요... 국내 자료도 얼마 없고 자주 생기는 오류들을 해결하는 과정 때문에 그런 거라고 자기 위로를 한다. 혹시 하다가 오류가 생기면 다른 포스트에서 자신의 오류와 동일한지도 확인해 보는 것이 좋을 것 같다. 혹시나 정 안되면 댓글로 남기면 기억하는 내에서 답변을 해주겠다.(글을 쓰면서도 가물가물해서 얼마나 도움이 될진 모르겠지만) 

 

 마지막으로 내가 도움 되었던 사이트를 올리며 글을 마친다. 해당 블로그가 설명이 잘돼있더라

https://normal-engineer.tistory.com/313