日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無先后,達者為師

網(wǎng)站首頁 編程語言 正文

Python中torch.load()加載模型以及其map_location參數(shù)詳解_python

作者:eecspan ? 更新時間: 2022-11-13 編程語言

參考

TORCH.LOAD

torch.load()

函數(shù)格式為:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我們使用的時候,基本只使用前兩個參數(shù)。

模型的保存

模型保存有兩種形式,一種是保存模型的state_dict(),只是保存模型的參數(shù)。那么加載時需要先創(chuàng)建一個模型的實例model,之后通過torch.load()將保存的模型參數(shù)加載進來,得到dict,再通過model.load_state_dict(dict)將模型的參數(shù)更新。

另一種是將整個模型保存下來,之后加載的時候只需要通過torch.load()將模型加載,即可返回一個加載好的模型。

具體可參考:PyTorch模型的保存與加載。

模型加載中的map_location參數(shù)

具體來說,map_location參數(shù)是用于重定向,比如此前模型的參數(shù)是在cpu中的,我們希望將其加載到cuda:0中?;蛘呶覀冇卸鄰埧ǎ敲次覀兙涂梢詫⒖?中訓(xùn)練好的模型加載到卡2中,這在數(shù)據(jù)并行的分布式深度學(xué)習(xí)中可能會用到。

首先定義一個AlexNet,并使用cuda:0將其訓(xùn)練了一個貓狗分類,之后把模型存儲起來。

map_location=None

我們先把state_dict加載進來。

model_path = "./cuda_model.pth"
model = torch.load(model_path)
print(next(model.parameters()).device)

結(jié)果為:

cuda:0

因為保存的時候就是模型就是cuda:0的,所以加載進來也是。

map_location=torch.device()

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
print(next(model.parameters()).device)

結(jié)果為:

cpu

模型從cuda:0變成了cpu。

map_location={xx:xx}

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})
print(next(model.parameters()).device)

結(jié)果為:

cuda:1

模型從cuda:0變成了cuda:1。

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:2':'cpu'})
print(next(model.parameters()).device)

結(jié)果為:

cuda:0

模型還是cuda:0,并沒有變成cpu。因為這個map_location的映射是不對的,原始的模型就是cuda:0,而映射是cuda:2cpu,是不對的。這種情況下,map_location返回None,也就是和不加map_location相同。

總結(jié)

原文鏈接:https://blog.csdn.net/qq_43219379/article/details/123675375

欄目分類
最近更新