【初级】用nn.Module, nn.Sequential构造深度学习网络
创始人
2024-05-19 18:10:55
0

在Pytorch的网络构造代码中,都会看到继承nn.Module类的子类
然后重新编写子类的构造函数__init__()和forward函数 。
这里编写也有一些经验之谈。
1:
一般把conv,dense等函数放到init函数里面,而nn.Functional一般用来连接一些不需要训练参数的层比如relu,bn等等
2:
forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
下面先看一个简单的例子。

本文,先介绍nn.Sequential来构建基本网络。

Abcnet(nn.Module):def __init__(self, in_channels, out_channels, xxxx xxxx):super().__init__() def forward(sefl, xxxxx):

关于super函数的解读,可以参考我的另一篇文章。

nn.Sequential是nn.Module的一个子类,它作为一个有序的容器,网络模块将按照在传入构造器的顺序依次被添加到计算图中执行。

我们直接看看nn.Sequential的源码

   def __init__(self, *args):super(Sequential, self).__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)def forward(self, input):for module in self:input = module(input)return input

if len(args) == 1 and isinstance(args[0], OrderedDict) 其判断是否使用OrderedDict.
OrderDict的构造方法是

collections.OrderDict([('a', nn.conv), ('b', nn.pool), ('c', nn.dense)])

如果自己定义了名称的话使用自定义的名称,否则将使用idx自动定义。
再看看nn.Sequential的三种实现

第一种实现

import torch.nn as nn
model = nn.Sequential(nn.Conv2d(1,20,5), # in_channels=1, out_channels=20, kernel_size=5nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential((0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(1): ReLU()(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(3): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''

这里nn.Sequential里面不是OrderDict
所以运行的是nn.Sequential源码里面的
for idx, module in enumerate(args):
self.add_module(str(idx), module)
注意这个add_module是nn.Module的方法之一,def add_module(self, name, module):
这样的网络构造有一个问题
每一个层是没有名称,默认的是以0、1、2、3来命名,从上面的运行结果也可以看出。

第二种实现

import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))print(model)
print(model[2]) # 通过索引获取第几个层 
#注意model["conv2"] 是错误的
#这其实是由它的定义实现的,看上面的Sequenrial定义可知,只支持index访问。
'''
运行结果为:
Sequential((conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(relu1): ReLU()(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(relu2): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''

第三种实现

import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential()
model.add_module("conv1",nn.Conv2d(1,20,5))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(20,64,5))
model.add_module('relu2', nn.ReLU())print(model)
print(model[2]) # 通过索引获取第几个层
这个就是直接用父类nn.Module的add_module(self, name, module)方法,语法上类似keras

reference

https://blog.csdn.net/qq_27825451/article/details/90550890
https://blog.csdn.net/qq_27825451/article/details/90551513

相关内容

热门资讯

【NI Multisim 14...   目录 序言 一、工具栏 🍊1.“标准”工具栏 🍊 2.视图工具...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
AWSECS:访问外部网络时出... 如果您在AWS ECS中部署了应用程序,并且该应用程序需要访问外部网络,但是无法正常访问,可能是因为...
Android|无法访问或保存... 这个问题可能是由于权限设置不正确导致的。您需要在应用程序清单文件中添加以下代码来请求适当的权限:此外...
北信源内网安全管理卸载 北信源内网安全管理是一款网络安全管理软件,主要用于保护内网安全。在日常使用过程中,卸载该软件是一种常...
AWSElasticBeans... 在Dockerfile中手动配置nginx反向代理。例如,在Dockerfile中添加以下代码:FR...
AsusVivobook无法开... 首先,我们可以尝试重置BIOS(Basic Input/Output System)来解决这个问题。...
ASM贪吃蛇游戏-解决错误的问... 要解决ASM贪吃蛇游戏中的错误问题,你可以按照以下步骤进行:首先,确定错误的具体表现和问题所在。在贪...
月入8000+的steam搬砖... 大家好,我是阿阳 今天要给大家介绍的是 steam 游戏搬砖项目,目前...