基于 NiN 的 Fashion-MNIST 服饰分类
基于 NiN 的 Fashion-MNIST 服饰分类
小嗷犬介绍
环境准备
使用到的库:
- Pytorch
- matplotlib
- d2l
d2l 为斯坦福大学李沐教授打包的一个库,其中包含一些深度学习中常用的函数方法。
安装:
1 |
|
Pytorch 环境请自行配置。
数据集介绍
Fashion-MNIST 是一个替代 MNIST 手写数字集的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。
Fashion-MNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000
的训练测试数据划分,28x28
的灰度图片。你可以直接用它来测试你的机器学习和深度学习算法性能,且不需要改动任何的代码。
下载地址:
本文使用 Pytorch 自动下载。
网络模型介绍
Network In Network (NIN) 是由 Min Lin 等人于 2014 年提出,在 CIFAR-10 和 CIFAR-100 分类任务中达到当时的最好水平,其网络结构是由三个多层感知机(NiN 块)堆叠而成。NiN 模型论文 《Network In Network》 发表于 ICLR-2014,NIN 以一种全新的角度审视了卷积神经网络中的卷积核设计,通过引入子网络结构代替纯卷积中的线性映射部分,这种形式的网络结构激发了更复杂的卷积神经网络的结构设计,GoogLeNet 的 Inception 结构就是来源于这个思想。结构图如下:
NiN 块:
导入相关库
1 |
|
定义 NiN 网络结构
1 |
|
下载并配置数据集和加载器
这里 NiN 输入图片尺寸应为 224*224
,我们将 28*28
的 Fashion-MNIST 图片拉大到 224*224
。
1 |
|
定义训练函数
训练完成后会保存模型,可以修改模型的保存路径。
1 |
|
训练模型(或加载模型)
如果环境正确配置了 CUDA,则会由 GPU 进行训练。
加载模型需要根据自身情况修改路径。
1 |
|
可视化展示
1 |
|
预测图
结果来自训练轮数epochs=20
,准确率Accuracy=89.41%
的 NiN 模型: