Multi GPU 학습 모델 불러오기
·
언어 | Framework/Pytorch
학습 시 Multi GPU 사용을 위해 nn.DataParallel을 사용.# ...model = nn.DataParallel(model)# ...  2개의 GPU에서 학습한 모델을 불러올때 아래와 같은 에러가 발생.pytorch RuntimeError: Error(s) in loading state_dict for XXX  nn.DataParallel로 병렬화 하면서 state_dict 키값에 'module.'이 붙으면서 맞지 않아서 발생하는 오류라고 한다. 학습 완료 후, 모델 추론 시 아래와 같은 방법으로 해결.from collections import OrderedDictmodel = MyModel().to(device)state_dict = torch.load('XXX.pth')new_state..
[Pytorch] Multi GPU
·
언어 | Framework/Pytorch
pytorch에서 여러개의 GPU 사용하기. import torch model = MyModel() # CNN이든 뭐든 사용할 모델 device = 'cuda' if torch.cuda.is_available() else 'cpu' if (device == 'cuda') and (torch.cuda.device_count() > 1): model = nn.DataParallel(model) model.to(device) 간단하다.
woongs_93
'GPU' 태그의 글 목록