網(wǎng)站首頁 編程語言 正文
導(dǎo)讀
只需要添加幾行代碼,就可以得到更快速,更省顯存的PyTorch模型。
你知道嗎,在1986年Geoffrey Hinton就在Nature論文中給出了反向傳播算法?
此外,卷積網(wǎng)絡(luò)最早是由Yann le cun在1998年提出的,用于數(shù)字分類,他使用了一個卷積層。但是直到2012年晚些時(shí)候,Alexnet才通過使用多個卷積層來實(shí)現(xiàn)最先進(jìn)的imagenet。
那么,是什么讓他們現(xiàn)在如此出名,而不是之前呢?
只有在我們擁有大量計(jì)算資源的情況下,我們才能夠在最近的過去試驗(yàn)和充分利用深度學(xué)習(xí)的潛力。
但是,我們是否已經(jīng)足夠好地使用了我們的計(jì)算資源呢?我們能做得更好嗎?
這篇文章的主要內(nèi)容是關(guān)于如何利用Tensor Cores和自動混合精度更快地訓(xùn)練深度學(xué)習(xí)網(wǎng)絡(luò)。
什么是Tensor Cores?
根據(jù)NVIDIA的網(wǎng)站:
NVIDIA Turing和Volta GPUs都是由Tensor Cores驅(qū)動的,這是一項(xiàng)突破性的技術(shù),提供了突破性的AI性能。Tensor Cores可以加速AI核心的大矩陣運(yùn)算,在一次運(yùn)算中就可以完成混合精度的矩陣乘法和累加運(yùn)算。在一個NVIDIA GPU上有數(shù)百個Tensor Cores并行運(yùn)行,這大大提高了吞吐量和效率。
簡單地說,它們是專門的cores,非常適合特定類型的矩陣操作。
我們可以將兩個FP16矩陣相乘,并將其添加到一個FP16/FP32矩陣中,從而得到一個FP16/FP32矩陣。Tensor cores支持混合精度數(shù)學(xué),即以半精度(FP16)進(jìn)行輸入,以全精度(FP32)進(jìn)行輸出。上述類型的操作對許多深度學(xué)習(xí)任務(wù)具有內(nèi)在價(jià)值,而Tensor cores為這種操作提供了專門的硬件。
現(xiàn)在,使用FP16和FP32主要有兩個好處。
- FP16需要更少的內(nèi)存,因此更容易訓(xùn)練和部署大型神經(jīng)網(wǎng)絡(luò)。它還只需要較少的數(shù)據(jù)移動。
- 數(shù)學(xué)運(yùn)算在降低精度的Tensor cores運(yùn)行得更快。NVIDIA給出的Volta GPU的確切數(shù)字是:FP16的125 TFlops vs FP32的15.7 TFlops(8倍加速)。
但也有缺點(diǎn)。當(dāng)我們從FP32轉(zhuǎn)到FP16時(shí),我們需要降低精度。
FP32 vs FP16: FP32 有8個指數(shù)位和23個分?jǐn)?shù)位,而FP16有5個指數(shù)位和10個分?jǐn)?shù)位。
但是FP32真的有必要嗎?
實(shí)際上,F(xiàn)P16可以很好地表示大多數(shù)權(quán)重和梯度。所以存儲和使用FP32是很浪費(fèi)的。
那么,我們?nèi)绾问褂肨ensor Cores?
我檢查了一下我的Titan RTX GPU有576個tensor cores和4608個NVIDIA CUDA核心。但是我如何使用這些tensor cores呢?
坦白地說,NVIDIA用幾行代碼就能提供自動混合精度,因此使用tensor cores很簡單。我們需要在代碼中做兩件事:
- 需要用到FP32的運(yùn)算比如Softmax之類的就分配用FP32,而Conv之類的操作可以用FP16的則被自動分配用FP16。
- 使用損失縮放?為了保留小的梯度值。梯度值可能落在FP16的范圍之外。在這種情況下,梯度值被縮放,使它們落在FP16范圍內(nèi)。
如果你還不了解背景細(xì)節(jié)也沒關(guān)系,代碼實(shí)現(xiàn)相對簡單。
使用PyTorch進(jìn)行混合精度訓(xùn)練:
讓我們從PyTorch中的一個基本網(wǎng)絡(luò)開始。
N, D_in, D_out = 64, 1024, 512 x = torch.randn(N, D_in, device="cuda") y = torch.randn(N, D_out, device="cuda") model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) for to in range(500): y_pred = model(x) loss = torch.nn.functional.mse_loss(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step()
為了充分利用自動混合精度訓(xùn)練的優(yōu)勢,我們首先需要安裝apex庫。只需在終端中運(yùn)行以下命令。
$ git clone https://github.com/NVIDIA/apex $ cd apex $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
然后,我們只需向神經(jīng)網(wǎng)絡(luò)代碼中添加幾行代碼,就可以利用自動混合精度(AMP)。
from apex import amp N, D_in, D_out = 64, 1024, 512 x = torch.randn(N, D_in, device="cuda") y = torch.randn(N, D_out, device="cuda") model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) model, optimizer = amp.initialize(model, optimizer, opt_level="O1") for to in range(500): y_pred = model(x) loss = torch.nn.functional.mse_loss(y_pred, y) optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step()
在這里你可以看到我們用amp.initialize
初始化了我們的模型。我們還使用amp.scale_loss
來指定損失縮放。
基準(zhǔn)測試
git clone https://github.com/MLWhiz/data_science_blogs cd data_science_blogs/amp/pytorch-apex-experiment/ python run_benchmark.py python make_plot.py --GPU 'RTX' --method 'FP32' 'FP16' 'amp' --batch 128 256 512 1024 2048
這會在home目錄中生成下面的圖:
在這里,我使用不同的精度和批大小設(shè)置訓(xùn)練了同一個模型的多個實(shí)例。我們可以看到,從FP32到amp,內(nèi)存需求減少,而精度保持大致相同。時(shí)間也會減少,但不會減少那么多。這可能是由于數(shù)據(jù)集或模型太簡單。
根據(jù)NVIDIA給出的基準(zhǔn)測試,AMP比標(biāo)準(zhǔn)的FP32快3倍左右,如下圖所示。
在單精度和自動混合精度兩種精度下,加速比為固定周期訓(xùn)練的時(shí)間比。
原文鏈接:https://juejin.cn/post/7157663977437134884
相關(guān)推薦
- 2022-11-27 Rust指南之生命周期機(jī)制詳解_Rust語言
- 2022-05-31 利用Python進(jìn)行數(shù)據(jù)清洗的操作指南_python
- 2022-07-03 C語言數(shù)組全面詳細(xì)講解_C 語言
- 2022-10-25 IDEA 安裝tomcat10創(chuàng)建servlet報(bào)404錯誤
- 2022-05-10 二叉樹的遞歸和非遞歸遍歷
- 2022-03-27 C#?Razor語法規(guī)則_C#教程
- 2022-05-22 Nginx的基本概念和原理_nginx
- 2022-04-02 Docker鏡像發(fā)布到Docker?Hub的實(shí)現(xiàn)方法_docker
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支