在 Pytorch 学习的过程中如何借助预训练模型帮助理解结构?

Richard143天前0

这几天有空余时间,把以前看到一半看不下去的深度学习书又拿出来重新翻看,遇到几个问题请教一下 v 友们

一个问题是,安装 torch 的过程中自带了一个预训练模型库 torchvision ,可以直接打印模型结构,非常简明清晰,一眼就能看明白,在学习过程中给了我很大帮助,比如引入 ResNet18 的话直接打印可以看到如下结果
ResNet(  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu): ReLU(inplace=True)  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)  (layer1): Sequential(    (0): BasicBlock(      (conv1): Conv2d(64, 64...)      ......

但是查了一下文档,torchvision 里面的模型基本都是 cv 方向用的,在看书学习 cnn 的时候还好说,但是比如想要学习 rnn 相关的时候就没有模型能参考了.

尝试自己定义 class ,使用官网范例里的 nn.LSTM()建立模型,但是结构打印就只能变成类似下面这样

for name, layer in nn.LSTM(64).named_parameters(recurse=True):    print(name, layer.shape, sep=" ")'''    weight_ih_l0 torch.Size([256, 32])    weight_hh_l0 torch.Size([256, 64])    bias_ih_l0 torch.Size([256])    bias_hh_l0 torch.Size([256])'''

搞得实在是一头雾水,搞不清楚 rnn 这东西到底是怎么搭起来的,请问 nn.module 有像 torchvision 里一样的可以查看结构的打印方式吗?

2.印象里以前学习的时候看人说过除了 torchvision 还有一些第三方的预训练模型库,效果也都不错,但是这次搜了一些关键字都没有搜到,有 v 友能推荐几个吗,以及其中有没有带有 RNN 、transformer 相关模型的

最新回复 (10)
  • conhost3天前
    引用2
    torch vision 里面是设置好的模型,打印的是模型的结构,并不是具体到某一层是怎么搭建的。而 LSTM 是一个层,作用类似于 ResNet 中的 Conv2d 。而 ResNet 是一个网络,由多个层构成。现在你打印的是 LSTM 里面的参数名和参数的维度,不是模型结构。如果你是做 NLP 的话,可以使用 hugging face 的 transformers 库,里面是近几年 NLP 上预训练模型。
  • noqwerty3天前
    引用3
    你提到的模型结构输出应该是 nn.Module 里面定义的:
    https://github.com/pytorch/pytorch/blob/0ee31ce8c8312af1a61c161d21efb21d900a0c13/torch/nn/modules/module.py#L1927-L1950
  • 楼主Richard143天前
    引用4
    @conhost 谢谢,项目地址慢慢学习,很有帮助。NLP 方面 BERT 出来以后 RNN 过气了,但是我看到一些资料说在很多场景中 RNN 还在用,尤其是涉及到计算资源限制方面的,就想都从头了解一下。
    所以这里的 LSTM 应该类比类似 CNN 这种结构,是由框架库提供的基础组件,一般不去设计其内部的具体实现方式?
    实际生产中使用模型是需要多个 RNN 层叠加实现效果吗,就像多个 CNN 叠加一样?类似方面有什么资料可以参考吗,具体结构,超参调整之类的,不像 CNN 上感觉有很多模型可以参考,RNN 完全一头雾水没有概念应该怎么做
  • rayhy3天前
    引用5
    你第一个问题可能没定义清楚。PyTorch 里所有网络都是 nn.Module 的子类,torchvision 里面的网络也是。我看你意思应该是想问如何查看官方自己实现的基础模块,比如 nn.LSTM 的结构吧。如 nn.Conv2d ,nn.LSTM 这些类是官方实现的一些基础的 nn.Module 模块,确实没法直接 print 看到细节。想了解这种官方模块的具体结构,要不就算看 torch 的源码(对于初学者来说过于复杂了),要不就是直接看论文或者相关博客。因为这些基础模块不多,同时也都非常重要,我个人还是推荐先看论文或者解读的博客,理解原理就好。初学阶段还是把它当成最底层的积木,只学习如何使用它。具体如何实现比较麻烦,比如 nn.Conv2d 可能背后被优化得一般人都认不出来。
    另外,看模型如何搭起来的还可以尝试一些第三方的把模型可视化出来的库,比如 torchinfo 。甚至自己写也不是很难。我的博客里就有一篇如何利用 torch.fx 实现模型可视化的文章,一共就百来行代码。https://wrong.wang/blog/20220520-%E5%88%A9%E7%94%A8torch.fx%E6%8F%90%E5%8F%96pytorch%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E4%BF%A1%E6%81%AF%E7%BB%98%E5%88%B6%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E5%9B%BE/
    这个就是生成出来的 ResNet18 的可视化结果:
    https://wrong.wang/blog/20220520-%E5%88%A9%E7%94%A8torch.fx%E6%8F%90%E5%8F%96pytorch%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E4%BF%A1%E6%81%AF%E7%BB%98%E5%88%B6%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E5%9B%BE/ResNet18.gv.svg
    至于第二个问题,最火的 transformer 实现之一可能就是 Hugging Face 搞的吧: https://huggingface.co/docs/transformers/index 。CV 领域比较火的还有 https://github.com/rwightman/pytorch-image-models ,也实现了超多模型。https://pytorch.org/hub/ PyTorch 官方也搞了一些非常基础的模型实现,可以参考。
  • conhost3天前
    引用6
    @Richard14 RNN ,CNN 等基础模型 torch 都使用 C++在内部实现了,不过你也可以使用 pytorch 自己按照公式实现,实现方式类似于你自己定义一个 nn.Module 。实际生产中要看具体任务是什么样的,根据资源跟延时的限制决定一层还是多层。Transformer 出来之后基本上 RNN 很少单独使用了,有一些做序列标注的会在 BERT 上面加一层 biRNN ,用来提取文本的前后依赖信息。目前一般来说,RNN 在 Transformer 上的使用都是来补充位置向量太弱的问题的。不过,在小样本上,RNN 的效果会好于 Transformer ,所以具体用什么还是要看使用场景。
  • 楼主Richard143天前
    引用7
    @conhost CV 有很多项目可以参考,网上教学一般也是从 CNN 开始教起,学完就知道大概设计一个 CNN 网络是个什么形状,叠几层 CNN ,池化,dropout ,全连接层之类的,RNN 有类似的项目可以参考吗?我现在一个人拿着一个 nn.LSTM 在风中凌乱,看网上的例子是一个 lstm 层后面接一层 dense 就输出了,我很质疑生产上大家是这么做的?感觉有点太简单了
  • conhost3天前
    引用8
    具体需要看你做什么了,lstm 后面加一个 dense 已经是一个完整的网络结构了。其递归的结构可以提取到全部的输入信息。在文本方面的话,cnn 确实需要堆叠多层,这是因为 cnn 是提取的局部信息,想要获取到全部信息,只能通过堆叠间接扩大卷积核的大小,从而能覆盖到全部输入。
  • conhost3天前
    引用9
    @Richard14 总体来说循环网络的结构确实比较抽象,在实际理解的时候内部递归结构需要展开来看,不能将其单单就理解为一个层。包括训练时候的梯度回传,也要按照 rnn 的时间步进行展开回传的。而 cnn 由于其参数共享的原因,各个窗口之间是完全并行的,因此你理解一个窗口的操作,就可以直接扩展到其他窗口。
  • flyaway2天前
    引用10
    Transformer 可以看 huggingface ,它有很多 tutorial ,适合入门。
  • ox1802天前
    引用11
    重写 lstm 的__repr__就行
  • 游客
    12
返回