網站首頁 編程語言 正文
本文以一段代碼為例,簡單介紹一下tensorflow與pytorch的相互轉換(主要是tensorflow轉pytorch),可能介紹的沒有那么詳細,僅供參考。
由于本人只熟悉pytorch,而對tensorflow一知半解,而代碼經常遇到tensorflow,而我希望使用pytorch,因此簡單介紹一下tensorflow轉pytorch,可能存在諸多錯誤,希望輕噴~
1.變量預定義
在TensorFlow的世界里,變量的定義和初始化是分開的。
tensorflow中一般都是在開頭預定義變量,聲明其數據類型、形狀等,在執行的時候再賦具體的值,如下圖所示,而pytorch用到時才會定義,定義和變量初始化是合在一起的。
2.創建變量并初始化
tensorflow中利用tf.Variable創建變量并進行初始化,而pytorch中使用torch.tensor創建變量并進行初始化,如下圖所示。
3.語句執行
在TensorFlow的世界里,變量的定義和初始化是分開的,所有關于圖變量的賦值和計算都要通過tf.Session的run來進行。
sess.run([G_solver, G_loss_temp, MSE_loss],
feed_dict = {X: X_mb, M: M_mb, H: H_mb})
而在pytorch中,并不需要通過run進行,賦值完了直接計算即可。
4.tensor
pytorch運算時要創建完的numpy數組轉為tensor,如下:
if use_gpu is True:
X_mb = torch.tensor(X_mb, device="cuda")
M_mb = torch.tensor(M_mb, device="cuda")
H_mb = torch.tensor(H_mb, device="cuda")
else:
X_mb = torch.tensor(X_mb)
M_mb = torch.tensor(M_mb)
H_mb = torch.tensor(H_mb)
最后運行完還要將tensor數據類型轉換回numpy數組:
if use_gpu is True:
imputed_data=imputed_data.cpu().detach().numpy()
else:
imputed_data=imputed_data.detach().numpy()
而tensorflow中不需要這種操作。
5.其他函數
在tensorflow中包含諸多函數是pytorch中沒有的,但是都可以在其他庫中找到類似,具體如下表所示。
tensorflow中函數 | pytorch中代替(所在庫) | 參數區別 |
---|---|---|
tf.sqrt | np.sqrt(numpy) | 完全相同 |
tf.random_normal | np.random.normal(numpy) | tf.random_normal(shape = size, stddev = xavier_stddev) np.random.normal(size = size, scale = xavier_stddev) |
tf.concat | torch.cat(torch) | inputs = tf.concat(values = [x, m], axis = 1) inputs = torch.cat(dim=1, tensors=[x, m]) |
tf.nn.relu | F.relu(torch.nn.functional) | 完全相同 |
tf.nn.sigmoid | torch.sigmoid(torch) | 完全相同 |
tf.matmul | torch.matmul(torch) | 完全相同 |
tf.reduce_mean | torch.mean(torch) | 完全相同 |
tf.log | torch.log(torch) | 完全相同 |
tf.zeros | np.zeros | 完全相同 |
tf.train.AdamOptimizer | torch.optim.Adam(torch) | optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) optimizer_D = torch.optim.Adam(params=theta_D) |
原文鏈接:https://blog.csdn.net/didi_ya/article/details/125461794
相關推薦
- 2022-10-05 Ubuntu?Server?20.04?LTS?環境下搭建vim?編輯器Python?IDE的詳細步
- 2021-12-07 關于postman上傳文件執行成功而使用collection?runner執行失敗的問題_相關技巧
- 2024-03-23 spring boot 使用AOP實現是否已登錄檢測
- 2022-06-12 ASP.NET?Core?WebApi返回結果統一包裝實踐記錄_實用技巧
- 2022-12-15 詳解Golang如何比較兩個slice是否相等_Golang
- 2021-12-09 Android音頻開發之錄制音頻(WAV及MP3格式)_Android
- 2022-09-15 windows中cmd下添加、刪除和修改靜態路由實現_DOS/BAT
- 2022-04-28 一篇文章帶你了解C++特殊類的設計_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支