博客
关于我
net网络查看其参数state_dict,data,named_parameters
阅读量:789 次
发布时间:2023-02-15

本文共 1940 字,大约阅读时间需要 6 分钟。

PyTorch 中 nn.Linear 层的初始化与参数获取详解

在 PyTorch 中,nn.Linear 层是一种常用的全连接层,用于将输入数据通过线性变换映射到输出数据。以下将详细介绍 nn.Linear 层的初始化过程以及如何获取其参数。

代码示例

import torchnet = nn.Linear(2, 2)print(net.state_dict())

输出结果

OrderedDict([('weight', tensor([[ 0.6155, -0.4649],                               [-0.1848, -0.0663]])),            ('bias', tensor([-0.0265, -0.4134]))])

初始化过程

在使用 nn.Linear 时,PyTorch 会自动为权重 (weight) 和偏置 (bias) 初始化默认值。具体来说:

  • 权重 (weight) 的形状为 (out_features, in_features),初始化方法为:

    [\text{math: U}(-\sqrt{k}, \sqrt{k}), \text{其中 } k = \frac{1}{\text{in_features}}]

  • 偏置 (bias) 的形状为 (out_features), 初始化方法同上。

自定义初始化方法

如果需要自定义初始化,可以通过定义一个函数并应用到网络中:

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):    if type(m) == nn.Linear:        nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

参数获取方法

获取网络参数的方法如下:

net = nn.Linear(2, 2)print('net.state_dict():', net.state_dict())print('net.bias:', net.bias)print('net.bias.data:', net.bias.data)print('net.bias.grad:', net.bias.grad)  # 由于没有 backward 且没有梯度计算,梯度值为空print('net.named_parameters():', net.named_parameters())print([(name, param) for name, param in net.named_parameters()])

输出结果

net.state_dict(): OrderedDict([('weight', tensor([[-0.6907, -0.4128],                                                  [-0.4212,  0.0620]])),                              ('bias', tensor([-0.2019, -0.3731]))])net.bias: Parameter containing: tensor([-0.2019, -0.3731], requires_grad=True)net.bias.data: tensor([-0.2019, -0.3731])net.bias.grad: Nonenet.named_parameters():(generator object module.named_parameters at 0x151b22ba0)[    ('weight', Parameter containing: tensor([[-0.6907, -0.4128],                                            [-0.4212,  0.0620]], requires_grad=True)),    ('bias', Parameter containing: tensor([-0.2019, -0.3731], requires_grad=True))]

李沐老师手动实现网络层的启示

从李沐老师的实现可以看出,权重和偏置在初始化时确实经过了随机初始化过程,符合 PyTorch 中 nn.Linear 的默认初始化方式。

这意味着在实际应用中,建议使用 nn.init 函数对权重和偏置进行适当的初始化,而不是手动赋值。

转载地址:http://etcfk.baihongyu.com/

你可能感兴趣的文章