網站首頁 編程語言 正文
pytorch定義新的自動求導函數
在pytorch中想自定義求導函數,通過實現torch.autograd.Function并重寫forward和backward函數,來定義自己的自動求導運算。參考官網上的demo:傳送門
直接上代碼,定義一個ReLu來實現自動求導
import torch
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# 我們使用ctx上下文對象來緩存,以便在反向傳播中使用,ctx存儲時候只能存tensor
# 在正向傳播中,我們接收一個上下文對象ctx和一個包含輸入的張量input;
# 我們必須返回一個包含輸出的張量,
# input.clamp(min = 0)表示講輸入中所有值范圍規定到0到正無窮,如input=[-1,-2,3]則被轉換成input=[0,0,3]
ctx.save_for_backward(input)
# 返回幾個值,backward接受參數則包含ctx和這幾個值
return input.clamp(min = 0)
@staticmethod
def backward(ctx, grad_output):
# 把ctx中存儲的input張量讀取出來
input, = ctx.saved_tensors
# grad_output存放反向傳播過程中的梯度
grad_input = grad_output.clone()
# 這兒就是ReLu的規則,表示原始數據小于0,則relu為0,因此對應索引的梯度都置為0
grad_input[input < 0] = 0
return grad_input
進行輸入數據并測試
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 使用torch的generator定義隨機數,注意產生的是cpu隨機數還是gpu隨機數
generator=torch.Generator(device).manual_seed(42)
# N是Batch, H is hidden dimension,
# D_in is input dimension;D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype,generator=generator)
y = torch.randn(N, D_out, device=device, dtype=dtype, generator=generator)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True, generator=generator)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True, generator=generator)
learning_rate = 1e-6
for t in range(500):
relu = MyRelu.apply
# 使用函數傳入參數運算
y_pred = relu(x.mm(w1)).mm(w2)
# 計算損失
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# 傳播
loss.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w1.grad.zero_()
w2.grad.zero_()
pytorch自動求導與邏輯回歸
自動求導
retain_graph設為True,可以進行兩次反向傳播
邏輯回歸
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)
#========生成數據=============
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
x0 = torch.normal(mean_value*n_data,1)+bias#類別0數據
y0 = torch.zeros(sample_nums)#類別0標簽
x1 = torch.normal(-mean_value*n_data,1)+bias#類別1數據
y1 = torch.ones(sample_nums)#類別1標簽
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)
#==========選擇模型===========
class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.features = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features(x)
x = self.sigmoid(x)
return x
lr_net = LR()#實例化邏輯回歸模型
#==============選擇損失函數===============
loss_fn = nn.BCELoss()
#==============選擇優化器=================
lr = 0.01
optimizer = torch.optim.SGD(lr_net.parameters(),lr = lr,momentum=0.9)
#===============模型訓練==================
for iteration in range(1000):
#前向傳播
y_pred = lr_net(train_x)#模型的輸出
#計算loss
loss = loss_fn(y_pred.squeeze(),train_y)
#反向傳播
loss.backward()
#更新參數
optimizer.step()
#繪圖
if iteration % 20 == 0:
mask = y_pred.ge(0.5).float().squeeze() #以0.5分類
correct = (mask==train_y).sum()#正確預測樣本數
acc = correct.item()/train_y.size(0)#分類準確率
plt.scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c='r',label='class0')
plt.scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c='b',label='class1')
w0,w1 = lr_net.features.weight[0]
w0,w1 = float(w0.item()),float(w1.item())
plot_b = float(lr_net.features.bias[0].item())
plot_x = np.arange(-6,6,0.1)
plot_y = (-w0*plot_x-plot_b)/w1
plt.xlim(-5,7)
plt.ylim(-7,7)
plt.plot(plot_x,plot_y)
plt.text(-5,5,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.title('Iteration:{}\nw0:{:.2f} w1:{:.2f} b{:.2f} accuracy:{:2%}'.format(iteration,w0,w1,plot_b,acc))
plt.legend()
plt.show()
plt.pause(0.5)
if acc > 0.99:
break
總結
原文鏈接:https://blog.csdn.net/l8947943/article/details/105633826
相關推薦
- 2022-06-24 C#利用itext實現PDF頁面處理與切分_C#教程
- 2022-02-28 npm install安裝報錯 gyp info it worked if it ends with
- 2022-07-09 kernel利用pt?regs劫持seq?operations的遷移過程詳解_C 語言
- 2022-07-19 安卓TextView的lineHeight*lineCount!=height問題,解決不支持滾動的
- 2022-05-04 R語言數據類型與相應運算的實現_R語言
- 2022-03-14 npm 依賴下載報錯 Hostname/IP does not match certificate‘
- 2022-04-19 npm install運行原理分析
- 2022-03-19 c語言執行Hello?World背后經歷的步驟_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支