網站首頁 編程語言 正文
簡述
為了方便理解卷積神經網絡的運行過程,需要對卷積神經網絡的運行結果進行可視化的展示。
大致可分為如下步驟:
- 單個圖片的提取
- 神經網絡的構建
- 特征圖的提取
- 可視化展示
單個圖片的提取
根據目標要求,需要對單個圖片進行卷積運算,但是PyTorch中讀取數據主要用到torch.utils.data.DataLoader類,因此我們需要編寫單個圖片的讀取程序:
def get_picture(picture_dir, transform):
    """Read a single image from disk and convert it with ``transform``.

    The image is loaded with ``skimage.io.imread``, resized to 256x256,
    cast to ``float32`` and handed to ``transform``.

    Args:
        picture_dir: Path of the image file to read.
        transform: Callable applied to the resized ``float32`` array
            (the article passes ``torchvision.transforms.ToTensor()``).

    Returns:
        Whatever ``transform`` returns for the resized image.
    """
    # Only the requested image is read; the earlier version also loaded a
    # hard-coded './picture/4.jpg' whose data was never returned.
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (256, 256))
    img256 = np.asarray(img256, dtype=np.float32)
    return transform(img256)
注意:神經網絡的輸入是四維形式(batch × 通道 × 高 × 寬),而我們返回的圖片是三維形式,需要使用 unsqueeze(0) 在最前面插入一個 batch 維度。
神經網絡的構建
網絡基於LeNet構建,不過為了方便展示,將其中的參數按照256×256×3的輸入尺寸進行了修正。
網(wǎng)絡構(gòu)建如下:
class LeNet(nn.Module):
    """LeNet-style convolutional network sized for 3x256x256 inputs.

    Two convolution blocks (conv -> ReLU -> 2x2 max pooling) are followed by
    three fully connected layers. The child modules are named ``conv1``,
    ``conv2``, ``fc1``, ``fc2`` and ``fc3``.

    Args:
        num_classes: Number of output values produced by ``fc3``. Defaults
            to 62, the value that used to be hard-coded.
    """

    def __init__(self, num_classes=62):
        super(LeNet, self).__init__()
        # Block 1: convolution, ReLU (a non-linear activation), max pooling.
        self.conv1 = nn.Sequential(
            # 5x5 kernel, stride 1, padding 2: 3x256x256 -> 32x256x256.
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.ReLU(),
            # 2x2 pooling halves height and width: -> 32x128x128.
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Block 2: same layout, 32x128x128 -> 64x128x128 -> 64x64x64.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # First fully connected layer; consumes the flattened 64x64x64 maps.
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 64 * 64, 128),
            nn.ReLU(),
        )
        # Second fully connected layer.
        self.fc2 = nn.Sequential(
            nn.Linear(128, 84),
            nn.ReLU(),
        )
        # Output layer: one value per class.
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` output for input batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear works on the last dimension only, so flatten everything
        # except the batch dimension before the fully connected layers.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)
特征圖的提取
直接上代碼:
class FeatureExtractor(nn.Module):
    """Run a model's direct children in order and collect selected outputs.

    Args:
        submodule: Model whose direct child modules are applied one after
            another to the input.
        extracted_layers: Names of the child modules whose outputs should be
            returned.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the outputs of the requested layers, in execution order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Stop once every requested output has been captured, so the
            # remaining (typically fully connected) layers are not executed.
            if len(outputs) >= len(self.extracted_layers):
                break
            # Layers whose name contains "fc" get a flattened input.
            if "fc" in name:
                x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
可視化展示
可視化展示使用matplotlib
代碼如下:
# Visualise the feature maps of the first extracted layer.
# `x` is the list returned by the feature extractor; x[0] is indexed as
# [sample, channel, row, column] below.
# detach() + cpu() make the conversion to numpy work for tensors that live
# on the GPU or carry gradient history; do it once, outside the loop.
feature_maps = x[0].detach().cpu().numpy()[0]
for i in range(32):
    ax = plt.subplot(6, 6, i + 1)
    ax.set_title('Feature {}'.format(i))
    ax.axis('off')
    plt.imshow(feature_maps[i, :, :], cmap='jet')
# plt.show() opens the figure; a bare plt.plot() draws nothing.
plt.show()
完整代碼
在此貼上完整代碼
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import skimage.data
import skimage.io
import skimage.transform
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as tv
import torchvision.transforms as transforms

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image whose feature maps are visualised.
pic_dir = './picture/3.jpg'

# Pre-processing: convert a numpy-like array into a PyTorch tensor.
transform = transforms.ToTensor()


def get_picture(picture_dir, transform):
    """Read a single image from disk and convert it with ``transform``.

    The image is resized to 256x256 and cast to ``float32`` before
    ``transform`` is applied; the transform's result is returned.
    """
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (256, 256))
    img256 = np.asarray(img256, dtype=np.float32)
    return transform(img256)


def get_picture_rgb(picture_dir):
    """Resize an image to 256x256, save a copy and display it.

    The resized image is written to ``./picture/4.jpg`` and then shown in a
    matplotlib figure titled ``image``.
    """
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (256, 256))
    skimage.io.imsave('./picture/4.jpg', img256)

    ax = plt.subplot()
    ax.set_title('image')
    plt.imshow(img256)
    plt.show()


class LeNet(nn.Module):
    """LeNet-style convolutional network sized for 3x256x256 inputs.

    Two convolution blocks (conv -> ReLU -> 2x2 max pooling) are followed by
    three fully connected layers. The child modules are named ``conv1``,
    ``conv2``, ``fc1``, ``fc2`` and ``fc3``.

    Args:
        num_classes: Number of output values produced by ``fc3``. Defaults
            to 62, the value that used to be hard-coded.
    """

    def __init__(self, num_classes=62):
        super(LeNet, self).__init__()
        # Block 1: convolution, ReLU (a non-linear activation), max pooling.
        self.conv1 = nn.Sequential(
            # 5x5 kernel, stride 1, padding 2: 3x256x256 -> 32x256x256.
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.ReLU(),
            # 2x2 pooling halves height and width: -> 32x128x128.
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Block 2: same layout, 32x128x128 -> 64x128x128 -> 64x64x64.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # First fully connected layer; consumes the flattened 64x64x64 maps.
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 64 * 64, 128),
            nn.ReLU(),
        )
        # Second fully connected layer.
        self.fc2 = nn.Sequential(
            nn.Linear(128, 84),
            nn.ReLU(),
        )
        # Output layer: one value per class.
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` output for input batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear works on the last dimension only, so flatten everything
        # except the batch dimension before the fully connected layers.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)


class FeatureExtractor(nn.Module):
    """Run a model's direct children in order and collect selected outputs.

    Args:
        submodule: Model whose direct child modules are applied one after
            another to the input.
        extracted_layers: Names of the child modules whose outputs should be
            returned.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the outputs of the requested layers, in execution order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Stop once every requested output has been captured, so the
            # remaining (typically fully connected) layers are not executed.
            if len(outputs) >= len(self.extracted_layers):
                break
            # Layers whose name contains "fc" get a flattened input.
            if "fc" in name:
                x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs


def get_feature():
    """Plot the 32 ``conv1`` feature maps of ``pic_dir`` for a fresh LeNet."""
    # Load the image and add the leading batch dimension the network expects.
    img = get_picture(pic_dir, transform)
    img = img.unsqueeze(0)
    img = img.to(device)

    # The network is used with its initial weights; load a checkpoint here
    # to inspect a trained model instead.
    net = LeNet().to(device)
    # net.load_state_dict(torch.load('./model/net_050.pth'))
    exact_list = ["conv1", "conv2"]
    myexactor = FeatureExtractor(net, exact_list)
    with torch.no_grad():
        x = myexactor(img)

    # Move to the CPU before converting: .numpy() raises on CUDA tensors.
    feature_maps = x[0].detach().cpu().numpy()[0]
    for i, feature_map in enumerate(feature_maps):
        ax = plt.subplot(6, 6, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        plt.imshow(feature_map, cmap='jet')
    plt.show()


if __name__ == "__main__":
    get_picture_rgb(pic_dir)
    # get_feature()  # uncomment to plot the conv1 feature maps
總結
原文鏈接:https://blog.csdn.net/ZOUZHEN_ID/article/details/84025943
相關推薦
- 2022-01-29 git 本地,遠程做了不同的修改,同步方法
- 2023-01-18 淺析Python的對象拷貝和內(nèi)存布局_python
- 2022-07-01 Keras實現(xiàn)Vision?Transformer?VIT模型示例詳解_python
- 2022-12-09 Python命名空間與作用域深入全面詳解_python
- 2022-05-20 python?使用turtle實現(xiàn)實時鐘表并生成exe_python
- 2022-10-04 Android系統(tǒng)優(yōu)化Ninja加快編譯_Android
- 2022-11-26 React?數(shù)據(jù)獲取與性能優(yōu)化詳解_React
- 2022-04-25 在?Python?中進行?One-Hot?編碼_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支