網站首頁 編程語言 正文
引言
solver算是caffe的核心的核心,它協調著整個模型的運作。caffe程序運行必帶的一個參數就是solver配置文件。運行代碼一般為
# caffe train --solver=*_slover.prototxt
在Deep Learning中,往往loss function是非凸的,沒有解析解,我們需要通過優化方法來求解。solver的主要作用就是交替調用前向(forward)算法和后向(backward)算法來更新參數,從而最小化loss,實際上就是一種迭代的優化算法。
到目前的版本,caffe提供了六種優化算法來求解最優參數,在solver配置文件中,通過設置type類型來選擇。
- Stochastic Gradient Descent (
type: "SGD"
), - AdaDelta (
type: "AdaDelta"
), - Adaptive Gradient (
type: "AdaGrad"
), - Adam (
type: "Adam"
), - Nesterov’s Accelerated Gradient (
type: "Nesterov"
) and - RMSprop (
type: "RMSProp"
)
具體的每種方法的介紹,請看本系列的下一篇文章, 本文著重介紹solver配置文件的編寫。
Solver的流程:
- 1.設計好需要優化的對象,以及用于學習的訓練網絡和用于評估的測試網絡。(通過調用另外一個配置文件prototxt來進行)
- 2.通過forward和backward迭代的進行優化來跟新參數。
- 3.定期的評價測試網絡。 (可設定多少次訓練后,進行一次測試)
- 4.在優化過程中顯示模型和solver的狀態
在每一次的迭代過程中,solver做了這幾步工作:
- 1、調用forward算法來計算最終的輸出值,以及對應的loss
- 2、調用backward算法來計算每層的梯度
- 3、根據選用的slover方法,利用梯度進行參數更新
- 4、記錄并保存每次迭代的學習率、快照,以及對應的狀態。
接下來,我們先來看一個實例:
net: "examples/mnist/lenet_train_test.prototxt" test_iter: 100 test_interval: 500 base_lr: 0.01 momentum: 0.9 type: SGD weight_decay: 0.0005 lr_policy: "inv" gamma: 0.0001 power: 0.75 display: 100 max_iter: 20000 snapshot: 5000 snapshot_prefix: "examples/mnist/lenet" solver_mode: CPU
接下來,我們對每一行進行詳細解譯:
net: "examples/mnist/lenet_train_test.prototxt"
設置深度網絡模型。每一個模型就是一個net,需要在一個專門的配置文件中對net進行配置,每個net由許多的layer所組成。每一個layer的具體配置方式可參考本系列文文章中的(2)-(5)。注意的是:文件的路徑要從caffe的根目錄開始,其它的所有配置都是這樣。
訓練測試模型
也可用train_net和test_net來對訓練模型和測試模型分別設定。例如:
train_net: "examples/hdf5_classification/logreg_auto_train.prototxt" test_net: "examples/hdf5_classification/logreg_auto_test.prototxt"
接下來第二行:
test_iter: 100
這個要與test layer中的batch_size結合起來理解。mnist數據中測試樣本總數為10000,一次性執行全部數據效率很低,因此我們將測試數據分成幾個批次來執行,每個批次的數量就是batch_size。假設我們設置batch_size為100,則需要迭代100次才能將10000個數據全部執行完。因此test_iter設置為100。執行完一次全部數據,稱之為一個epoch
test_interval: 500
測試間隔。也就是每訓練500次,才進行一次測試。
base_lr: 0.01 lr_policy: "inv" gamma: 0.0001 power: 0.75
這四行可以放在一起理解,用于學習率的設置。只要是梯度下降法來求解優化,都會有一個學習率,也叫步長。base_lr用于設置基礎學習率,在迭代的過程中,可以對基礎學習率進行調整。怎么樣進行調整,就是調整的策略,由lr_policy來設置。
lr_policy可以設置為下面這些值,相應的學習率的計算為:
- - fixed:保持base_lr不變.
- - step: 如果設置為step,則還需要設置一個stepsize, 返回 base_lr * gamma ^ (floor(iter / stepsize)),其中iter表示當前的迭代次數
- - exp: 返回base_lr * gamma ^ iter, iter為當前迭代次數
- - inv: 如果設置為inv,還需要設置一個power, 返回base_lr * (1 + gamma * iter) ^ (- power)
- - multistep:如果設置為multistep,則還需要設置一個stepvalue。這個參數和step很相似,step是均勻等間隔變化,而multistep則是根據 stepvalue值變化
- - poly:學習率進行多項式誤差, 返回 base_lr (1 - iter/max_iter) ^ (power)
- - sigmoid:學習率進行sigmod衰減,返回 base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
multistep示例:
base_lr: 0.01 momentum: 0.9 weight_decay: 0.0005 # The learning rate policy lr_policy: "multistep" gamma: 0.9 stepvalue: 5000 stepvalue: 7000 stepvalue: 8000 stepvalue: 9000 stepvalue: 9500
參數
接下來的參數:
momentum :0.9
上一次梯度更新的權重,具體可參看下一篇文章。
type: SGD
優化算法選擇。這一行可以省掉,因為默認值就是SGD??偣灿辛N方法可選擇,在本文的開頭已介紹。
weight_decay: 0.0005
權重衰減項,防止過擬合的一個參數。
display: 100
每訓練100次,在屏幕上顯示一次。如果設置為0,則不顯示。
max_iter: 20000
最大迭代次數。這個數設置太小,會導致沒有收斂,精確度很低。設置太大,會導致震蕩,浪費時間。
snapshot: 5000 snapshot_prefix: "examples/mnist/lenet"
快照。將訓練出來的model和solver狀態進行保存,snapshot用于設置訓練多少次后進行保存,默認為0,不保存。snapshot_prefix設置保存路徑。
還可以設置snapshot_diff,是否保存梯度值,默認為false,不保存。
也可以設置snapshot_format,保存的類型。有兩種選擇:HDF5 和BINARYPROTO ,默認為BINARYPROTO
solver_mode: CPU
設置運行模式。默認為GPU,如果你沒有GPU,則需要改成CPU,否則會出錯。
注意:以上的所有參數都是可選參數,都有默認值。根據solver方法(type)的不同,還有一些其它的參數,在此不一一列舉。
原文鏈接:https://www.cnblogs.com/denny402/p/5074049.html
相關推薦
- 2022-05-06 python使用xlrd模塊讀取excel的方法實例_python
- 2022-10-10 React?組件的常用生命周期函數匯總_Redis
- 2022-03-16 Android線程池源碼閱讀記錄介紹_Android
- 2022-10-18 go日志庫中的logrus_Golang
- 2022-04-06 Python中shutil模塊的使用詳解_python
- 2022-12-04 go高并發時append方法偶現錯誤解決分析_Golang
- 2022-10-09 React高階組件的使用淺析_React
- 2022-09-29 數據設計之權限的實現_數據庫其它
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支