網(wǎng)站首頁 編程語言 正文
一、實(shí)現(xiàn)步驟
1、準(zhǔn)備數(shù)據(jù)
x_data = torch.tensor([[1.0],[2.0],[3.0]]) y_data = torch.tensor([[2.0],[4.0],[6.0]])
2、設(shè)計(jì)模型
class LinearModel(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(LinearModel,self).__init__() ? ? ? ? self.linear = torch.nn.Linear(1,1) ? ? ? ?? ? ? def forward(self, x): ? ? ? ? y_pred = self.linear(x) ? ? ? ? return y_pred ? ? ? ?? model = LinearModel() ?
3、構(gòu)造損失函數(shù)和優(yōu)化器
criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
4、訓(xùn)練過程
epoch_list = [] loss_list = [] w_list = [] b_list = [] for epoch in range(1000): ? ? y_pred = model(x_data)?? ??? ??? ??? ??? ? ?# 計(jì)算預(yù)測(cè)值 ? ? loss = criterion(y_pred, y_data)?? ?# 計(jì)算損失 ? ? print(epoch,loss) ? ?? ? ? epoch_list.append(epoch) ? ? loss_list.append(loss.data.item()) ? ? w_list.append(model.linear.weight.item()) ? ? b_list.append(model.linear.bias.item()) ? ?? ? ? optimizer.zero_grad() ? # 梯度歸零 ? ? loss.backward() ? ? ? ? # 反向傳播 ? ? optimizer.step() ? ? ? ?# 更新
5、結(jié)果展示
展示最終的權(quán)重和偏置:
# 輸出權(quán)重和偏置 print('w = ',model.linear.weight.item()) print('b = ',model.linear.bias.item())
結(jié)果為:
w = ?1.9998501539230347
b = ?0.0003405189490877092
模型測(cè)試:
# 測(cè)試模型 x_test = torch.tensor([[4.0]]) y_test = model(x_test) print('y_pred = ',y_test.data) y_pred = ?tensor([[7.9997]])
分別繪制損失值隨迭代次數(shù)變化的二維曲線圖和其隨權(quán)重與偏置變化的三維散點(diǎn)圖:
# 二維曲線圖 plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.show() # 三維散點(diǎn)圖 fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(w_list,b_list,loss_list,c='r') #設(shè)置坐標(biāo)軸 ax.set_xlabel('weight') ax.set_ylabel('bias') ax.set_zlabel('loss') plt.show()
結(jié)果如下圖所示:
?到此這篇關(guān)于PyTorch實(shí)現(xiàn)線性回歸詳細(xì)過程的文章就介紹到這了,更多相關(guān)PyTorch線性回歸內(nèi)容請(qǐng)搜索AB教程網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持AB教程網(wǎng)!
二、參考文獻(xiàn)
- [1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5
原文鏈接:https://blog.csdn.net/weixin_43821559/article/details/123298468
相關(guān)推薦
- 2022-11-16 Django?報(bào)錯(cuò):Broken?pipe?from?('127.0.0.1',?58924)的解決
- 2023-07-13 用webpack做一些前端打包時(shí)的性能優(yōu)化
- 2022-12-24 Android開發(fā)中Signal背后的bug與解決_Android
- 2022-01-08 vscode推送代碼失敗,報(bào)錯(cuò)You are not allowed to push code to
- 2024-01-14 在springboot中給mybatis加攔截器
- 2022-07-30 os.path模塊下的顯示路徑方法
- 2022-08-20 docker鏡像alpine中安裝oracle客戶端_docker
- 2022-06-19 mybatis-plus的sql語句打印問題小結(jié)_MsSql
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支