这几天有空余时间,把以前看到一半看不下去的深度学习书又拿出来重新翻看,遇到几个问题请教一下 v 友们
一个问题是,安装 torch 的过程中自带了一个预训练模型库 torchvision ,可以直接打印模型结构,非常简明清晰,一眼就能看明白,在学习过程中给了我很大帮助,比如引入 ResNet18 的话直接打印可以看到如下结果ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64...) ......
但是查了一下文档,torchvision 里面的模型基本都是 cv 方向用的,在看书学习 cnn 的时候还好说,但是比如想要学习 rnn 相关的时候就没有模型能参考了.
尝试自己定义 class ,使用官网范例里的 nn.LSTM()建立模型,但是结构打印就只能变成类似下面这样
for name, layer in nn.LSTM(64).named_parameters(recurse=True): print(name, layer.shape, sep=" ")''' weight_ih_l0 torch.Size([256, 32]) weight_hh_l0 torch.Size([256, 64]) bias_ih_l0 torch.Size([256]) bias_hh_l0 torch.Size([256])'''
搞得实在是一头雾水,搞不清楚 rnn 这东西到底是怎么搭起来的,请问 nn.module 有像 torchvision 里一样的可以查看结构的打印方式吗?
2.印象里以前学习的时候看人说过除了 torchvision 还有一些第三方的预训练模型库,效果也都不错,但是这次搜了一些关键字都没有搜到,有 v 友能推荐几个吗,以及其中有没有带有 RNN 、transformer 相关模型的
最新回复 (10)
- 你第一个问题可能没定义清楚。PyTorch 里所有网络都是 nn.Module 的子类,torchvision 里面的网络也是。我看你意思应该是想问如何查看官方自己实现的基础模块,比如 nn.LSTM 的结构吧。如 nn.Conv2d ,nn.LSTM 这些类是官方实现的一些基础的 nn.Module 模块,确实没法直接 print 看到细节。想了解这种官方模块的具体结构,要不就算看 torch 的源码(对于初学者来说过于复杂了),要不就是直接看论文或者相关博客。因为这些基础模块不多,同时也都非常重要,我个人还是推荐先看论文或者解读的博客,理解原理就好。初学阶段还是把它当成最底层的积木,只学习如何使用它。具体如何实现比较麻烦,比如 nn.Conv2d 可能背后被优化得一般人都认不出来。
另外,看模型如何搭起来的还可以尝试一些第三方的把模型可视化出来的库,比如 torchinfo 。甚至自己写也不是很难。我的博客里就有一篇如何利用 torch.fx 实现模型可视化的文章,一共就百来行代码。https://wrong.wang/blog/20220520-%E5%88%A9%E7%94%A8torch.fx%E6%8F%90%E5%8F%96pytorch%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E4%BF%A1%E6%81%AF%E7%BB%98%E5%88%B6%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E5%9B%BE/
这个就是生成出来的 ResNet18 的可视化结果:
https://wrong.wang/blog/20220520-%E5%88%A9%E7%94%A8torch.fx%E6%8F%90%E5%8F%96pytorch%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E4%BF%A1%E6%81%AF%E7%BB%98%E5%88%B6%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84%E5%9B%BE/ResNet18.gv.svg
至于第二个问题,最火的 transformer 实现之一可能就是 Hugging Face 搞的吧: https://huggingface.co/docs/transformers/index 。CV 领域比较火的还有 https://github.com/rwightman/pytorch-image-models ,也实现了超多模型。https://pytorch.org/hub/ PyTorch 官方也搞了一些非常基础的模型实现,可以参考。 - @Richard14 RNN ,CNN 等基础模型 torch 都使用 C++在内部实现了,不过你也可以使用 pytorch 自己按照公式实现,实现方式类似于你自己定义一个 nn.Module 。实际生产中要看具体任务是什么样的,根据资源跟延时的限制决定一层还是多层。Transformer 出来之后基本上 RNN 很少单独使用了,有一些做序列标注的会在 BERT 上面加一层 biRNN ,用来提取文本的前后依赖信息。目前一般来说,RNN 在 Transformer 上的使用都是来补充位置向量太弱的问题的。不过,在小样本上,RNN 的效果会好于 Transformer ,所以具体用什么还是要看使用场景。
- 游客12楼
发新帖
主题数 460490 | 帖子数 7176701 | 注册排名 2 |