一起深度学习24/04/30——ResNet

ResNet神经网络

  • 定义ResNet Block
  • 定义ResNet18
  • 加载数据集并训练、测试

定义ResNet Block

ResNet Block 的作用:
是一个残差块,用于构建ResNet
主要是为了解决神经网络中的梯度爆炸和梯度消失问题,以及缓解训练过程中的退化问题。
在传统的神经网络中,每层的输出会直接作为下一层的输入,可能会导致梯度在反向传播过程中逐渐减小,当层数比较深时,就可能导致梯度消失。故引入了跳跃连接,将每一层的输出与最初的x进行相加,当你对其进行求导,能发现比传统的多了一项对x的求导,也就是因为该项,避免了梯度消失的问题。

class ResBlk(nn.Module):
    """
    resnet Block
    """
    def __init__(self,ch_in,ch_out,stride):
        super(ResBlk,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=stride,padding=1)
        print(self.conv1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(in_channels=ch_out, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
        print(self.conv2)
        self.bn2 = nn.BatchNorm2d(ch_out)
	
        self.extra =nn.Sequential()#当输入通道数并不等于输出通道数的时候,进行转换。
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                # [b,ch_in,h,w] =>[b,ch_out,h,w]
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self,x):
        """
        :param x: [b,ch,h,w]
        :return:
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        #shor cut
        # x :[b,ch_in,h,w]  而out [b,ch_out,h,w]
        out = self.extra(x) +out #resNet的精髓所在,能够避免过拟合,梯度爆炸,梯度消失,
        return out

运行测试一下:

def main():
    blk = ResBlk(64,128,stride=4)
    tmp = torch.randn(2,64,32,32)
    out = blk(tmp)
    print(out.shape)
if __name__ == '__main__':
    main()

在这里说明一下其中的疑惑,在做该模块的时候
blk = ResBlk(64,128,stride=4) #64是输入通道数,128表示输出通道数。
tmp = torch.randn(2,64,32,32) # 2是样本数量,64是输入通道数,32是形状。
out = blk(tmp) #将其传入到ResBlok中,进行运算。
输出为torch.Size([2, 128, 8, 8])。

定义ResNet18

class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18,self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(64)
        )
        # followed 4 blocks
        # [b,64,h,w] => [b,128,h,w]
        self.blk1 =  ResBlk(64,128,stride=2)
        # [b,128,h,w] => [b,256,h,w]
        self.blk2 = ResBlk(128,256,stride=2)
        # [b,256,h,w] => [b,512,h,w]
        self.blk3 = ResBlk(256, 512,stride=2)
        # [b,512,h,w] => [b,1024,h,w]
        self.blk4 = ResBlk(512, 512,stride=2)

        self.outlayer = nn.Linear(512,10)
    def forward(self,x):
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        x = F.adaptive_avg_pool2d(x,[1,1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

加载数据集并训练、测试

import torch
import torchvision.transforms
from torch import nn, optim
from torchvision import datasets
from torch.utils.data import DataLoader
# from lenet5 import Lenet5
from learing_resnet import ResNet18
def main():
    batchsz = 32
    cifar_train= datasets.CIFAR10('data',train=True,transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((32,32)),
        torchvision.transforms.ToTensor()
    ]),download=True)
    cifar_train = DataLoader(cifar_train,batch_size=batchsz,shuffle=True)

    cifar_test= datasets.CIFAR10('data',train=False,transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((32,32)),
        torchvision.transforms.ToTensor()
    ]),download=True)
    cifar_test = DataLoader(cifar_test,batch_size=batchsz,shuffle=True)

    # x, label = iter(cifar_train)
    # print("x:",x.shape,"label:",label.shape)
    device  = torch.device('cuda')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)
    criten = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    for epoch in range(1000):
        for batchidx,(x,lable) in enumerate(cifar_train):
            x,lable = x.to(device),lable.to(device)
            logits = model(x)
            loss = criten(logits,lable)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch,loss.item())
        total_correct = 0
        total_num = 0
        model.eval()
        with torch.no_grad():
            for x,label in cifar_test:
                x,label = x.to(device),label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred,label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct /total_num
            print(epoch,acc)

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/599388.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

编程算法赛

1偶数累加 2、统计字符的数量 3、计算表达式的值 4、哥德巴赫猜想 5、进制的转换

英语学习笔记5——Nice to meet you.

Nice to meet you. 很高兴见到你。 词汇 Vocabulary Mr. 先生 用法:自己全名 / 姓 例如:Mr. Zhang Mingdong 或 Mr. Zhang,绝对不能是 Mr. Mingdong! Miss 女士,小姐 未婚 用法:自己全名 / 姓 例如&#…

【论文阅读】Fuzz4All: Universal Fuzzing with Large Language Models

文章目录 摘要一、介绍二、Fuzz4All的方法2.1、自动提示2.1.1、自动提示算法2.1.2、自动提示的例子2.1.3、与现有自动提示技术的比较 2.2、fuzzing循环2.2.1、模糊循环算法2.2.2、Oracle 三、实验设计3.1、实现3.2、被测系统和baseline3.3、实验设置以及评估指标 四、结果分析4…

P8801 [蓝桥杯 2022 国 B] 最大数字

P8801 [蓝桥杯 2022 国 B] 最大数字 分析 dfs 思路:题目的意思,要让一个数最大,用贪心去考虑,从高位开始,对其进行a / b操作,使其变为9,可让该数最大 1.a 操作:1;b 操…

嵌入式学习<1>:建立工程、GPIO

嵌入式学习_part1 本部分笔记用于学习记录,笔记源头 >>b站江科大_STM32入门教程 建立工程、GPIO 开发环境:keil MDK、STM32F103C8T6 1 )建立工程 (1)基于寄存器开发、基于标准库 或者 基于HAL库开发; &…

【代码随想录——哈希表】

1.哈希表理论基础 首先什么是 哈希表,哈希表(英文名字为Hash table,国内也有一些算法书籍翻译为散列表,大家看到这两个名称知道都是指hash table就可以了)。 那么哈希表能解决什么问题呢,一般哈希表都是用…

高素质高学历婚恋相亲交友平台有哪些?分享我的网上找对象成功脱单经历!

尽管觉得在社交软件上找到真爱的可能性很小,但我却时常看到别人成功的案例,这也让我跃跃欲试了。没想到,我真的成功了!以下是我亲身使用过的一些方法,在此与大家分享,仅供参考哦! 👉…

c++ cpp 在类中执行线程 进行恒定计算

在编程中,顺序执行是常见的模式,但是对cpu的利用率不是很高,采用线程池,又太麻烦了,原因是还得不断地把任务拆分,扫描返回值。 如果 初始化n个类的时候,传递数据自身即可异步计算,那…

六、文件查找

一、文件查找 1.查找文件内容 ​ 命令:grep keywords /dir_path/filename 2.查找系统命令 ​ 命令:which command 3.查找命令及配置文件位置 ​ 命令:whereis command 4.find查找 ​ find $find_path -name|-type|-perm|-size|-atime…

【前端】HTML基础(3)

文章目录 前言一、HTML基础1、表格标签1.1 基本使用1.2 合并单元格 2、列表标签2.1 无序列表2.2 有序列表2.3 自定义列表 3、 表单标签2.1 form标签2.2 input标签2.3 label标签2.4 select标签2.5 textarea标签 4、无语义标签5、HTML特殊字符 前言 这篇博客仅仅是对HTML的基本结…

RVM(相关向量机)、CNN_RVM(卷积神经网络结合相关向量机)、RVM-Adaboost(相关向量机结合Adaboost)

当我们谈到RVM(Relevance Vector Machine,相关向量机)、CNN_RVM(卷积神经网络结合相关向量机)以及RVM-Adaboost(相关向量机结合AdaBoost算法)时,每种模型都有其独特的原理和结构。以…

streamlit通过子目录访问

运行命令: streamlit hello 系统默认使用8501端口启动服务: 如果想通过子目录访问服务,可以这么启动服务 streamlit hello --server.baseUrlPath "app" 也可以通过以下命令换端口 streamlit hello --server.port 9999 参考&…

2024最新CTF入门的正确路线

目录 前言 一、什么是CTF比赛? 二、CTF比赛的流程 三、需要具备的知识 四、总结 前言 随着网络安全意识的增强,越来越多的人开始涉足网络安全领域,其中CTF比赛成为了重要的学习和竞赛平台。本人从事网络安全工作多年,也参加过…

甲小姐对话柳钢:CEO对股东最大的责任,是对成功的概率负责|甲子光年

只有看见最微小的事物,才能洞悉伟大的定律。 来源|甲子光年 作者|甲小姐 刘杨楠 编辑|栗子 商业史上,职业经理人成为“空降CEO”的故事往往胜少败多。 “究其原因有三条——容易自嗨、喊口号;不顾公司历…

笔试强训Day19 数学知识 动态规划 模拟

[编程题]小易的升级之路 题目链接&#xff1a;小易的升级之路__牛客网 思路&#xff1a; 按题目写即可 注意辗转相除法。 AC code&#xff1a; #include<iostream> using namespace std; int gcd(int a, int b) {return b ? gcd(b, a % b) : a; } int main() {int n…

三步学会苹果手机怎么关震动的方法!

苹果手机的震动功能在某些情况下可能会被认为是不必要的&#xff0c;比如在会议、课堂或者晚间睡眠时。因此&#xff0c;学会如何关闭苹果手机的震动功能是非常实用的。苹果手机怎么关震动&#xff1f;在本文中&#xff0c;我们将介绍三个步骤&#xff0c;帮助你关闭苹果手机的…

openEuler 22.03 GPT分区表模式下磁盘分区管理

目录 GPT分区表模式下磁盘分区管理parted交互式创建分区步骤 1 执行如下步骤对/dev/sdc磁盘分区 非交互式创建分区步骤 1 输入如下命令直接创建分区。 删除分区步骤 1 执行如下命令删除/dev/sdc1分区。 GPT分区表模式下磁盘分区管理 parted交互式创建分区 步骤 1 执行如下步骤…

ThingsBoard版本控制配合Gitee实现版本控制

1、概述 2、架构 3、导出设置 4、仓库 5、同步策略 6、扩展 7、案例 7.1、首先需要在Giitee上创建对应同步到仓库地址 ​7.2、giit仓库只能在租户层面进行配置 7.3、 配置完成后&#xff1a;检查访问权限。显示已成功验证仓库访问&#xff01;表示配置成功 7.4、添加设…

喜报 | 擎创科技荣获NIISA联盟2023年度创新技术特等奖!

为深入实施创新驱动发展战略&#xff0c;紧紧把握全球科技革命和产业变革方向&#xff0c;密切跟踪前沿科技新趋势&#xff0c;经科技部中国民营促进会业务主管部门批准以及国家互联网数据中心产业技术创新战略联盟&#xff08;以下简称联盟&#xff09;总体工作安排&#xff0…
最新文章