建立Linear对象
建立一个Linear对象,其输入特征维度为3,输出维度为4。数据随机生成1
2import torch.nn as nn
linear1 = nn.Linear(3, 4, bias=True)
通过weight和bias可以查看数据,weight的维度为4行3列,bias的维度是一维的,为4. 可以方便做广播操作1
2
3
4
5
6print(linear1.weight)
Parameter containing:
tensor([[ 0.0139, 0.2096, -0.5021],
[-0.4480, 0.5108, 0.0279],
[-0.3873, -0.5569, 0.3556],
[-0.4588, -0.0081, -0.3111]], requires_grad=True)
1 | print(linear1.bias) |
上述两个变量的类型均为torch.nn.parameter.Parameter,
可以通过print(linear1.weight.data)获得tensor数据
1 | linear1.weight.data |
linear 计算
1 | import torch |
等价计算1
2
3
4y2 = x @ (linear1.weight.data).t() + linear1.bias
y2.shape
# 输出为
torch.Size([2, 4])
详细计算
其中x.shape = (2,3), $A^T$.shape=(3,4), b.shape=(4)。
初始化
对Linear的数据初始化1
2
3
4
5def _init_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
_init_weights(linear1)
这里的两个操作,一个torch.nn.init.xavier_uniform_
,一个m.bias.data.fill_
都是inplace操作的 。
为何要手动初始化,可以参考https://zhuanlan.zhihu.com/p/25110150