網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
本文以一段代碼為例,簡(jiǎn)單介紹一下tensorflow與pytorch的相互轉(zhuǎn)換(主要是tensorflow轉(zhuǎn)pytorch),可能介紹的沒(méi)有那么詳細(xì),僅供參考。
由于本人只熟悉pytorch,而對(duì)tensorflow一知半解,而代碼經(jīng)常遇到tensorflow,而我希望使用pytorch,因此簡(jiǎn)單介紹一下tensorflow轉(zhuǎn)pytorch,可能存在諸多錯(cuò)誤,希望輕噴~
1.變量預(yù)定義
在TensorFlow的世界里,變量的定義和初始化是分開(kāi)的。
tensorflow中一般都是在開(kāi)頭預(yù)定義變量,聲明其數(shù)據(jù)類(lèi)型、形狀等,在執(zhí)行的時(shí)候再賦具體的值,如下圖所示,而pytorch用到時(shí)才會(huì)定義,定義和變量初始化是合在一起的。
2.創(chuàng)建變量并初始化
tensorflow中利用tf.Variable創(chuàng)建變量并進(jìn)行初始化,而pytorch中使用torch.tensor創(chuàng)建變量并進(jìn)行初始化,如下圖所示。
3.語(yǔ)句執(zhí)行
在TensorFlow的世界里,變量的定義和初始化是分開(kāi)的,所有關(guān)于圖變量的賦值和計(jì)算都要通過(guò)tf.Session的run來(lái)進(jìn)行。
sess.run([G_solver, G_loss_temp, MSE_loss],
feed_dict = {X: X_mb, M: M_mb, H: H_mb})
而在pytorch中,并不需要通過(guò)run進(jìn)行,賦值完了直接計(jì)算即可。
4.tensor
pytorch運(yùn)算時(shí)要?jiǎng)?chuàng)建完的numpy數(shù)組轉(zhuǎn)為tensor,如下:
if use_gpu is True:
X_mb = torch.tensor(X_mb, device="cuda")
M_mb = torch.tensor(M_mb, device="cuda")
H_mb = torch.tensor(H_mb, device="cuda")
else:
X_mb = torch.tensor(X_mb)
M_mb = torch.tensor(M_mb)
H_mb = torch.tensor(H_mb)
最后運(yùn)行完還要將tensor數(shù)據(jù)類(lèi)型轉(zhuǎn)換回numpy數(shù)組:
if use_gpu is True:
imputed_data=imputed_data.cpu().detach().numpy()
else:
imputed_data=imputed_data.detach().numpy()
而tensorflow中不需要這種操作。
5.其他函數(shù)
在tensorflow中包含諸多函數(shù)是pytorch中沒(méi)有的,但是都可以在其他庫(kù)中找到類(lèi)似,具體如下表所示。
tensorflow中函數(shù) | pytorch中代替(所在庫(kù)) | 參數(shù)區(qū)別 |
---|---|---|
tf.sqrt | np.sqrt(numpy) | 完全相同 |
tf.random_normal | np.random.normal(numpy) | tf.random_normal(shape = size, stddev = xavier_stddev) np.random.normal(size = size, scale = xavier_stddev) |
tf.concat | torch.cat(torch) | inputs = tf.concat(values = [x, m], axis = 1) inputs = torch.cat(dim=1, tensors=[x, m]) |
tf.nn.relu | F.relu(torch.nn.functional) | 完全相同 |
tf.nn.sigmoid | torch.sigmoid(torch) | 完全相同 |
tf.matmul | torch.matmul(torch) | 完全相同 |
tf.reduce_mean | torch.mean(torch) | 完全相同 |
tf.log | torch.log(torch) | 完全相同 |
tf.zeros | np.zeros | 完全相同 |
tf.train.AdamOptimizer | torch.optim.Adam(torch) | optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) optimizer_D = torch.optim.Adam(params=theta_D) |
原文鏈接:https://blog.csdn.net/didi_ya/article/details/125461794
相關(guān)推薦
- 2022-03-28 C語(yǔ)言怎么連接兩個(gè)數(shù)組的內(nèi)容你知道嗎_C 語(yǔ)言
- 2022-12-12 Android?SharedPreferences數(shù)據(jù)存儲(chǔ)詳解_Android
- 2022-09-08 關(guān)于keras中的Reshape用法_python
- 2023-05-11 Oracle怎么刪除數(shù)據(jù),Oracle數(shù)據(jù)刪除的三種方式_oracle
- 2023-02-05 Python?面向?qū)ο缶幊淘斀鈅python
- 2022-11-22 Python實(shí)例方法與類(lèi)方法和靜態(tài)方法介紹與區(qū)別分析_python
- 2023-09-12 git常用指令
- 2022-11-05 python中的bisect模塊與二分查找詳情_(kāi)python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門(mén)
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支