基于 ResNet-18 的 CIFAR-10 物体分类
基于 ResNet-18 的 CIFAR-10 物体分类
小嗷犬介绍
环境准备
使用到的库:
- Pytorch
- matplotlib
- d2l
d2l 为斯坦福大学李沐教授打包的一个库,其中包含一些深度学习中常用的函数方法。
安装:
1 |
|
Pytorch 环境请自行配置。
数据集介绍
CIFAR-10 是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。每个图片的尺寸为32 × 32
,每个类别有 6000 个图像,数据集中一共有 50000 张训练图片和 10000 张测试图片。
下载地址:
官网(较慢):http://www.cs.toronto.edu/~kriz/cifar.html
CSDN:https://download.csdn.net/download/qq_63585949/86928673
也可以使用 Pytorch 自动下载,速度基本等于官网速度。
网络模型介绍
残差神经网络(ResNet) 是由微软研究院的 何恺明、张祥雨、任少卿、孙剑 等人提出的。ResNet 在 2015 年的 ILSVRC(ImageNet Large Scale Visual Recognition Challenge)中取得了冠军。
残差神经网络 的主要贡献是发现了“退化现象(Degradation)”,并针对退化现象发明了 “快捷连接(Shortcut connection)”,极大的消除了深度过大的神经网络训练困难问题。神经网络的“深度”首次突破了 100 层、最大的神经网络甚至超过了 1000 层。
正常块(左)与残差块(右):
两种具体结构(包含以及不包含 1*1 卷积层的残差块):
ResNet-18 网络结构:
导入相关库
1 |
|
定义 ResNet-18 网络结构
1 |
|
下载并配置数据集和加载器
1 |
|
定义训练函数
训练完成后会保存模型,可以修改模型的保存路径。
1 |
|
训练模型(或加载模型)
如果环境正确配置了 CUDA,则会由 GPU 进行训练。
加载模型需要根据自身情况修改路径。
1 |
|
可视化展示
1 |
|
预测图
结果来自训练轮数epochs=20
,准确率Accuracy=80.46%
的 ResNet-18 模型: