網站首頁 編程語言 正文
返回最大值的index
import torch
a=torch.tensor([[.1,.2,.3],
? ? ? ? ? ? ? ? [1.1,1.2,1.3],
? ? ? ? ? ? ? ? [2.1,2.2,2.3],
? ? ? ? ? ? ? ? [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())
輸出:
tensor([ 2, ?2, ?2, ?2])
tensor(11)
pytorch 找最大值
題意:使用神經網絡實現,從數組中找出最大值。
提供數據:兩個 csv 文件,一個存訓練集:n 個 m 維特征自然數數據,另一個存每條數據對應的 label ,就是每條數據中的最大值。
這里將隨機構建訓練集:
#%%
import numpy as np
import pandas as pd
import torch
import random
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
def GetData(m, n):
dataset = []
for j in range(m):
max_v = random.randint(0, 9)
data = [random.randint(0, 9) for i in range(n)]
dataset.append(data)
label = [max(dataset[i]) for i in range(len(dataset))]
data_list = np.column_stack((dataset, label))
data_list = data_list.astype(np.float32)
return data_list
#%%
# 數據集封裝 重載函數len, getitem
class GetMaxEle(Data.Dataset):
def __init__(self, trainset):
self.data = trainset
def __getitem__(self, index):
item = self.data[index]
x = item[:-1]
y = item[-1]
return x, y
def __len__(self):
return len(self.data)
# %% 定義網絡模型
class SingleNN(nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(SingleNN, self).__init__()
self.hidden = nn.Linear(n_feature, n_hidden)
self.relu = nn.ReLU()
self.predict = nn.Linear(n_hidden, n_output)
def forward(self, x):
x = self.hidden(x)
x = self.relu(x)
x = self.predict(x)
return x
def train(m, n, batch_size, PATH):
# 隨機生成 m 個 n 個維度的訓練樣本
data_list =GetData(m, n)
dataset = GetMaxEle(data_list)
trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True)
net = SingleNN(n_feature=10, n_hidden=100,
n_output=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#
total_epoch = 100
for epoch in range(total_epoch):
for index, data in enumerate(trainset):
input_x, labels = data
labels = labels.long()
optimizer.zero_grad()
output = net(input_x)
# print(output)
# print(labels)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
# scheduled_optimizer.step()
print(f"Epoch {epoch}, loss:{loss.item()}")
# %% 保存參數
torch.save(net.state_dict(), PATH)
#測試
def test(m, n, batch_size, PATH):
data_list = GetData(m, n)
dataset = GetMaxEle(data_list)
testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
dataiter = iter(testloader)
input_x, labels = dataiter.next()
net = SingleNN(n_feature=10, n_hidden=100,
n_output=10)
net.load_state_dict(torch.load(PATH))
outputs = net(input_x)
_, predicted = torch.max(outputs, 1)
print("Ground_truth:",labels.numpy())
print("predicted:",predicted.numpy())
if __name__ == "__main__":
m = 1000
n = 10
batch_size = 64
PATH = './max_list.pth'
train(m, n, batch_size, PATH)
test(m, n, batch_size, PATH)
初始的想法是使用全連接網絡+分類來實現, 但是結果不盡人意,主要原因:不同類別之間的樣本量差太大,幾乎90%都是最大值。
比如代碼中隨機構建 10 個 0~9 的數字構成一個樣本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 該樣本標簽是9。
原文鏈接:https://blog.csdn.net/lrt366/article/details/94408090
相關推薦
- 2023-10-11 小程序|頁面傳參的三種方式
- 2022-09-25 nginx平滑升級、nginx支持的kill信號
- 2022-09-03 Python?Pandas中DataFrame.drop_duplicates()刪除重復值詳解_p
- 2022-07-01 Oracle數據庫用戶密碼過期的解決方法_oracle
- 2022-11-09 PostgreSQL?HOT與PHOT有哪些區別_PostgreSQL
- 2023-05-31 Pandas中map(),applymap(),apply()函數的使用方法_python
- 2021-12-15 git_stats?web代碼圖形統計工具詳解_其它綜合
- 2022-12-24 提升Go語言開發效率的小技巧實例(GO語言語法糖)匯總_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支