網站首頁 編程語言 正文
由于Pytorch不像TensorFlow有谷歌巨頭做維護,很多功能并沒有很高級的封裝,比如說沒有tf.one_hot函數。
本篇介紹將一個mini batch的label向量變成形狀為[batch size, class numbers]的one hot編碼的兩種方法,涉及到
tensor.scatter_
tensor.index_select
前言
本文將針對全連接網絡和全卷積網絡輸出的形式不同,將one hot編碼分兩種情況。
- 第一種針對網絡輸出是二維,即全連接層的輸出形式, [Batchsize, Num_class]
- 第二種針對輸出是四維特征圖,即分割網絡的輸出形式,[Batchsize, Num_class, H,W]
先將第一種情況
使用scatter_獲得one hot 編碼
我相信在CSDN上找這個函數用法的人都是看不懂官方介紹的,所以我不會像其他地方那樣,搬官方教程,我也是琢磨了很久才看懂這個函數,但函數聲明還是要看看的。
tensor.scatter_(dim, index, src)?
-
dim
: 指定了覆蓋數據是從哪個軸作為依據。后面再詳細解釋。值的范圍是從0到 sum(tensor.shape)-1 -
index
: 告訴函數要將src中對應的值放到tensor的哪個位置。index的shape要和src一致,或者src可以通過廣播機制實現shape一致。 -
src
: 保存了想用來覆蓋tensor的值
我們先看一個例子,例子從別的博客copy過來,但我會做更加詳細的介紹。覺得講得好請留言作為鼓勵。
>>> x = torch.rand(2, 5) >>> x ?0.4319 ?0.6500 ?0.4080 ?0.8760 ?0.2355 ?0.2609 ?0.4711 ?0.8486 ?0.8573 ?0.1029 [torch.FloatTensor of size 2x5] >>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) ?0.4319 ?0.4711 ?0.8486 ?0.8760 ?0.2355 ?0.0000 ?0.6500 ?0.0000 ?0.8573 ?0.0000 ?0.2609 ?0.0000 ?0.4080 ?0.0000 ?0.1029 [torch.FloatTensor of size 3x5]
注意到dim為0,代表以第一個維度作為依托。index是一個二維數組。
[0,1,2,0,0]
[2,0,0,1,2]
那么我們要覆蓋tensor的位置有10個,分別為
[0,0];[1,1];[2,2];[0,3];[0,4]
[2,0];[0,1];[0,2];[1,3];[2,4]
dim指定了index我們要將index的值作為哪一個軸的值。其他軸就是按照0到max shape -1變化罷了。比如說dim為0,那么index的值都作為坐標的第一個位置的值,另一個位置從0到4變換。
你們可以驗證下,是不是這10個位置被覆蓋了。10個位置的第一個軸是index的數字,第二個數字是index中的列數,從0到4。
要覆蓋的位置有了,那么用什么值覆蓋呢?別忘了我們的index的維度和src是一樣的。index中選擇什么位置的坐標,就對應用src對應的位置的值代替。
比如說要代替tensor中[0,0]的值,index中[0,0]就是第0行第0列對應的位置,那我們用src第0行第0列的值代替tensor的值。大家可以去驗證一下。
我們看看下面的的情況,如果dim為1呢。
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23) >>> z
先分析一下
dim為1,那么index的值都作為坐標的第2個位置的值,第一個位置的值應該從0到1變化。
所以要被代替的位置有
[0,2];[1,3]
而[0,2]的位置要填入的值為1.23,[1,3]要填入的值為1.23。(廣播機制將1.23這個標量擴展到了shape為(2,1))
好的,函數用法知道了。我們現在看看如何用該函數將label編碼為one hot編碼。
首先設想一個batch size為8的label。有10類,所以label中的數字應該是從0到9的。
import torch as t import numpy as np batch_size = 8 class_num = 10 label = np.random.randint(0,class_num,size=(batch_size,1)) label = t.LongTensor(label)
我們就獲得了一個label,shape是(8,1),必須是2維。如果是(8,)下面的內容會報錯的。
y_one_hot = t.zeros(batch_size,class_num).scatter_(1,label,1) print(y_one_hot) ''' tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], ? ? ? ? [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], ? ? ? ? [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], ? ? ? ? [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]) '''
搞定。下面我們看下面一種方法。
使用tensor.index_select獲得one hot編碼
還是先看下index_select的用法。
tensor.index_select( dim, index, out=None)
-
dim
: 指定按什么維度取tensor中的向量 -
index
: 是一個一維的張量。描述了按照dim維度取出tensor對應的index值的向量。
我們不看例子了,直接看方法,以此為例。
ones = torch.sparse.torch.eye(class_num) return ones.index_select(0,label)
這里的label是一維的向量,不是二維的。因為index制定了必須是一維的
先生成一個單位矩陣,尺寸是[class_num, class_num]。
dim為0,以為這按照行來取tensor的向量。具體取哪一行呢,就是label中的值了。
這時我們應該也明白為啥這兩行代碼能實現one hot編碼了吧。
如果label是[ 1,3,0],有四類。那我們得到就是
[0,1,0,0]
[0,0,0,1]
[1,0,0,0]
第二種針對分割網絡的one_hot編碼
對于分割類任務,網絡的GT肯定是二維數組,而不是像分類任務那樣的一維數組了。而對于分割任務,我們將其視作很多個像素值的分類任務,將ground truth 直接 reshape為向量形式,然后用上面的方法轉為one hot編碼,然后再reshape回來。核心是不變的。
下面舉個例子。
import torch import numpy as np gt = np.random.randint(0,5, size=[15,15]) ?#先生成一個15*15的label,值在5以內,意思是5類分割任務 gt = torch.LongTensor(gt) def get_one_hot(label, N): ? ? size = list(label.size()) ? ? label = label.view(-1) ? # reshape 為向量 ? ? ones = torch.sparse.torch.eye(N) ? ? ones = ones.index_select(0, label) ? # 用上面的辦法轉為換one hot ? ? size.append(N) ?# 把類別輸目添到size的尾后,準備reshape回原來的尺寸 ? ? return ones.view(*size) gt_one_hot = get_one_hot(gt, 5) print(gt_one_hot) print(gt_one_hot.shape) print(gt_one_hot.argmax(-1) == gt) ?# 判斷one hot 轉換方式是否正確,全是1就是正確的
另外注意,在Pytorch中,如果要和網絡輸出的特征圖一起計算loss,還要把上面輸出的one hot編碼的最后一個維度使用permute轉到通道維度上。
總結
原文鏈接:https://blog.csdn.net/qq_34914551/article/details/88700334
相關推薦
- 2022-04-23 C語言復雜鏈表的復制實例詳解_C 語言
- 2022-11-15 python內置模塊OS?實現SHELL端文件處理器_python
- 2023-03-18 C++虛函數和多態超詳細分析_C 語言
- 2023-01-08 C#?Winform文本面板帶滾動條的實現過程_C#教程
- 2022-10-12 排查服務器異常流量教程詳解_nginx
- 2021-12-15 C#?多線程學習之基礎入門_C#教程
- 2022-10-07 react性能優化useMemo與useCallback使用對比詳解_React
- 2023-10-16 input框錄入身份證自動填寫性別年齡
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支