日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無先后,達(dá)者為師

網(wǎng)站首頁 編程語言 正文

Pytorch?linear?多維輸入的參數(shù)問題_python

作者:又是花落時(shí) ? 更新時(shí)間: 2022-10-16 編程語言

問題: 由于 在輸入lstm 層 每個(gè)batch 做了根據(jù)輸入序列最大長度做了padding,導(dǎo)致每個(gè) batch 的 length 不同。 導(dǎo)致輸出 長度不同 。如:(batch, length, output_dim): (12,128,10),(12,111,10). 但是輸入 linear 層的時(shí)候沒有出現(xiàn)問題。

網(wǎng)站解釋:

官網(wǎng) pytorch linear:

  • Input:(*, H_{in})(?,Hin?)where*?means any number of dimensions including none andH_{in} = \text{in\_features}Hin?=in_features. 任意維度 number 理解有歧義 (a)number. k可以理解三維,四維。。。 (b) 可以理解 為某一維度的數(shù) 。
  • Output:(*, H_{out})(?,Hout?)where all but the last dimension are the same shape as the input andH_{out} = \text{out\_features}Hout?=out_features.

代碼解釋:

分別 用三維 和二維輸入數(shù)組,查看他們參數(shù)數(shù)目是否一樣。

import torch
 
x = torch.randn(128, 20)  # 輸入的維度是(128,20)
m = torch.nn.Linear(20, 30)  # 20,30是指維度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
 
# ans = torch.mm(input,torch.t(m.weight))+m.bias 等價(jià)于下面的
ans = torch.mm(x, m.weight.t()) + m.bias   
print('ans.shape:\n', ans.shape)
 
print(torch.equal(ans, output))

output:

m.weight.shape:
  torch.Size([30, 20])
m.bias.shape:
 torch.Size([30])
output.shape:
 torch.Size([128, 30])
ans.shape:
 torch.Size([128, 30])
True
x = torch.randn(128, 30,20)  # 輸入的維度是(128,30,20)
m = torch.nn.Linear(20, 30)  # 20,30是指維度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
ouput:
m.weight.shape:
  torch.Size([30, 20])
m.bias.shape:
 torch.Size([30])
output.shape:
 torch.Size([128, 30, 30])

結(jié)果:

(128,30,20),和 (128,20) 分別是如 nn.linear(30,20) 層。

weight.shape 均為: (30,20)

linear() 參數(shù)數(shù)目只和 input_dim ,output_dim 有關(guān)。

weight 在源碼的定義, 沒找到如何計(jì)算多維input的代碼。

原文鏈接:https://blog.csdn.net/u013996948/article/details/126406694

欄目分類
最近更新