언어 | Framework/Pytorch
Multi GPU 학습 모델 불러오기
woongs_93
2024. 7. 9. 14:14
반응형
학습 시 Multi GPU 사용을 위해 nn.DataParallel을 사용.
# ...
model = nn.DataParallel(model)
# ...
2개의 GPU에서 학습한 모델을 불러올때 아래와 같은 에러가 발생.
pytorch RuntimeError: Error(s) in loading state_dict for XXX
nn.DataParallel로 병렬화 하면서 state_dict 키값에 'module.'이 붙으면서 맞지 않아서 발생하는 오류라고 한다.
학습 완료 후, 모델 추론 시 아래와 같은 방법으로 해결.
from collections import OrderedDict
model = MyModel().to(device)
state_dict = torch.load('XXX.pth')
new_state_dict = OrderedDict()
# key값에서 'module.'을 삭제
for key in state_dict:
new_key = key.replace('module.', '')
new_state_dict[new_key] = state_dict[key]
model.load_state_dict(new_state_dict)
model.eval()
반응형