網站首頁 編程語言 正文
一、實現(xiàn)步驟
1、準備數(shù)據(jù)
x_data = torch.tensor([[1.0],[2.0],[3.0]]) y_data = torch.tensor([[2.0],[4.0],[6.0]])
2、設計模型
class LinearModel(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(LinearModel,self).__init__() ? ? ? ? self.linear = torch.nn.Linear(1,1) ? ? ? ?? ? ? def forward(self, x): ? ? ? ? y_pred = self.linear(x) ? ? ? ? return y_pred ? ? ? ?? model = LinearModel() ?
3、構造損失函數(shù)和優(yōu)化器
criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
4、訓練過程
epoch_list = [] loss_list = [] w_list = [] b_list = [] for epoch in range(1000): ? ? y_pred = model(x_data)?? ??? ??? ??? ??? ? ?# 計算預測值 ? ? loss = criterion(y_pred, y_data)?? ?# 計算損失 ? ? print(epoch,loss) ? ?? ? ? epoch_list.append(epoch) ? ? loss_list.append(loss.data.item()) ? ? w_list.append(model.linear.weight.item()) ? ? b_list.append(model.linear.bias.item()) ? ?? ? ? optimizer.zero_grad() ? # 梯度歸零 ? ? loss.backward() ? ? ? ? # 反向傳播 ? ? optimizer.step() ? ? ? ?# 更新
5、結果展示
展示最終的權重和偏置:
# 輸出權重和偏置 print('w = ',model.linear.weight.item()) print('b = ',model.linear.bias.item())
結果為:
w = ?1.9998501539230347
b = ?0.0003405189490877092
模型測試:
# 測試模型 x_test = torch.tensor([[4.0]]) y_test = model(x_test) print('y_pred = ',y_test.data) y_pred = ?tensor([[7.9997]])
分別繪制損失值隨迭代次數(shù)變化的二維曲線圖和其隨權重與偏置變化的三維散點圖:
# 二維曲線圖 plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.show() # 三維散點圖 fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(w_list,b_list,loss_list,c='r') #設置坐標軸 ax.set_xlabel('weight') ax.set_ylabel('bias') ax.set_zlabel('loss') plt.show()
結果如下圖所示:
?到此這篇關于PyTorch實現(xiàn)線性回歸詳細過程的文章就介紹到這了,更多相關PyTorch線性回歸內容請搜索AB教程網以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持AB教程網!
二、參考文獻
- [1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5
原文鏈接:https://blog.csdn.net/weixin_43821559/article/details/123298468
相關推薦
- 2023-03-02 Python實現(xiàn)設置顯示屏分辨率_python
- 2022-08-20 windows系統(tǒng)安裝配置nginx環(huán)境_nginx
- 2022-10-29 【npm 報錯 gyp info it worked if it ends with ok 大概率是
- 2022-08-07 QT利用QProcess獲取計算機硬件信息_C 語言
- 2022-06-10 C語言?模擬實現(xiàn)strlen函數(shù)詳解_C 語言
- 2022-09-15 C/C++?左移<<,?右移>>的作用及說明_C 語言
- 2022-07-27 python中format的用法實例詳解_python
- 2022-06-15 C語言詳解實現(xiàn)字符菱形的方法_C 語言
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結構-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支