반응형

학습 시 Multi GPU 사용을 위해 nn.DataParallel을 사용.

# ...

model = nn.DataParallel(model)

# ...

 

 

2개의 GPU에서 학습한 모델을 불러올때 아래와 같은 에러가 발생.

pytorch RuntimeError: Error(s) in loading state_dict for XXX

 

 

nn.DataParallel로 병렬화 하면서 state_dict 키값에 'module.'이 붙으면서 맞지 않아서 발생하는 오류라고 한다.

 

학습 완료 후, 모델 추론 시 아래와 같은 방법으로 해결.

from collections import OrderedDict

model = MyModel().to(device)
state_dict = torch.load('XXX.pth')
new_state_dict = OrderedDict()

# key값에서 'module.'을 삭제
for key in state_dict:
	new_key = key.replace('module.', '')
	new_state_dict[new_key] = state_dict[key]
   
model.load_state_dict(new_state_dict)
model.eval()

 

 

반응형

'언어 | Framework > Pytorch' 카테고리의 다른 글

[Pytorch] tensor to PIL Image  (0) 2021.08.19
[Pytorch] model.eval()  (0) 2021.07.14
[Pytorch] numpy, tensor, list 변환  (0) 2021.03.24
[Pytorch] Multi GPU  (0) 2021.03.24
woongs_93