반응형
학습 시 Multi GPU 사용을 위해 nn.DataParallel을 사용.
# ...
model = nn.DataParallel(model)
# ...
2개의 GPU에서 학습한 모델을 불러올때 아래와 같은 에러가 발생.
pytorch RuntimeError: Error(s) in loading state_dict for XXX
nn.DataParallel로 병렬화 하면서 state_dict 키값에 'module.'이 붙으면서 맞지 않아서 발생하는 오류라고 한다.
학습 완료 후, 모델 추론 시 아래와 같은 방법으로 해결.
from collections import OrderedDict
model = MyModel().to(device)
state_dict = torch.load('XXX.pth')
new_state_dict = OrderedDict()
# key값에서 'module.'을 삭제
for key in state_dict:
new_key = key.replace('module.', '')
new_state_dict[new_key] = state_dict[key]
model.load_state_dict(new_state_dict)
model.eval()
반응형
'언어 | Framework > Pytorch' 카테고리의 다른 글
[Pytorch] tensor to PIL Image (0) | 2021.08.19 |
---|---|
[Pytorch] model.eval() (0) | 2021.07.14 |
[Pytorch] numpy, tensor, list 변환 (0) | 2021.03.24 |
[Pytorch] Multi GPU (0) | 2021.03.24 |