基于深度学习的MNIST手写数字数据集识别(准确率99%,附代码)_mnist数据集99正确率-程序员宅基地

技术标签: 算法  python  机器学习  深度学习  人工智能  

1.Mnist数据集介绍

1.1 基本介绍

Mnist数据集可以算是学习深度学习最常用到的了。这个数据集包含70000张手写数字图片,分别是60000张训练图片和10000张测试图片,训练集由来自250个不同人手写的数字构成,一般来自高中生,一半来自工作人员,测试集(test set)也是同样比例的手写数字数据,并且保证了测试集和训练集的作者不同。每个图片都是2828个像素点,数据集会把一张图片的数据转成一个2828=784的一维向量存储起来。
里面的图片数据如下所示,每张图是0-9的手写数字黑底白字的图片,存储时,黑色用0表示,白色用0-1的浮点数表示。
在这里插入图片描述

1.2 数据集下载

1)官网下载

Mnist数据集的下载地址如下:http://yann.lecun.com/exdb/mnist/
打开后会有四个文件:
在这里插入图片描述

  • 训练数据集:train-images-idx3-ubyte.gz
  • 训练数据集标签:train-labels-idx1-ubyte.gz
  • 测试数据集:t10k-images-idx3-ubyte.gz
  • 测试数据集标签:t10k-labels-idx1-ubyte.gz

将这四个文件下载后放置到需要用的文件夹下即可不要解压!下载后是什么就怎么放!

2)代码导入

文件夹下运行下面的代码,即可自动检测数据集是否存在,若没有会自动进行下载,下载后在这一路径:
在这里插入图片描述

# 下载数据集
from torchvision import datasets, transforms

train_set = datasets.MNIST("data",train=True,download=True, transform=transforms.ToTensor(),)
test_set = datasets.MNIST("data",train=False,download=True, transform=transforms.ToTensor(),)

参数解释:

  • datasets.MNIST:是Pytorch的内置函数torchvision.datasets.MNIST,可以导入数据集
  • train=True :读入的数据作为训练集
  • transform:读入我们自己定义的数据预处理操作
  • download=True:当我们的根目录(root)下没有数据集时,便自动下载

如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:原文件夹/data/MNIST/raw

2.代码部分

2.1文件夹目录:

在这里插入图片描述

  • test:自己写的测试图片
  • main:主函数
  • model:训练的模型参数,会自动生成
  • data:数据集文件夹

2.2 运行结果

14轮左右,模型识别准确率达到99%以上
在这里插入图片描述

2.3代码

1) 导入必要的包及预处理,本人学习时做了较多注释,且用的是下载好的文件,如果是自己的请更改对应的文件目录哦。

import os
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn
from torch.nn import Conv2d, Linear, ReLU
from torch.nn import MaxPool2d
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


# Dataset:创建数据集的函数;__init__:初始化数据内容和标签
# __geyitem:获取数据内容和标签
# __len__:获取数据集大小
# daataloader:数据加载类,接受来自dataset已经加载好的数据集
# torchbision:图形库,包含预训练模型,加载数据的函数、图片变换,裁剪、旋转等
# torchtext:处理文本的工具包,将不同类型的额文件转换为datasets

# 预处理:将两个步骤整合在一起
transform = transforms.Compose({
    
    transforms.ToTensor(),  # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理
    # transforms.Normalize((0.1307,),(0.3081)),    # 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
})

2)加载数据集

# 加载数据集
# 训练数据集
train_data = MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
# transform:指示加载的数据集应用的数据预处理的规则,shuffle:洗牌,是否打乱输入数据顺序
# 测试数据集
test_data = MNIST(root="./data", train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True)

train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

3)构建模型

成功运行的话请给个免费的赞吧!(调试不易)

模型主要由两个卷积层,两个池化层,以及三个全连接层构成,激活函数使用relu.

class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.conv1 = Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, padding=0)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=0)
        self.maxpool2 = MaxPool2d(2)
        self.linear1 = Linear(320, 128)
        self.linear2 = Linear(128, 64)
        self.linear3 = Linear(64, 10)
        self.relu = ReLU()

    def forward(self, x):
        x = self.relu(self.maxpool1(self.conv1(x)))
        x = self.relu(self.maxpool2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)

        return x

# 损失函数CrossentropyLoss
model = MnistModel()#实例化
criterion = nn.CrossEntropyLoss()   # 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成

# SGD,优化器,梯度下降算法e
optimizer = torch.optim.SGD(model.parameters(), lr=0.14)#lr:学习率

4)模型训练
每次训练完成后会自动保存参数到pkl模型中,如果路径中有Pkl文件,下次运行会自动加载上一次的模型参数,在这个基础上继续训练,第一次运行时没有模型参数,结束后会自动生成。

# 模型训练
def train():
    # index = 0
    for index, data in enumerate(train_loader):#获取训练数据以及对应标签
        # for data in train_loader:
       input, target = data   # input为输入数据,target为标签
       y_predict = model(input) #模型预测
       loss = criterion(y_predict, target)
       optimizer.zero_grad() #梯度清零
       loss.backward()#loss值反向传播
       optimizer.step()#更新参数
       # index += 1
       if index % 100 == 0: # 每一百次保存一次模型,打印损失
           torch.save(model.state_dict(), "./model/model.pkl")   # 保存模型
           torch.save(optimizer.state_dict(), "./model/optimizer.pkl")
           print("训练次数为:{},损失值为:{}".format(index, loss.item() ))


5)加载模型
第一次运行这里需要一个空的model文件夹

# 加载模型
if os.path.exists('./model/model.pkl'):
   model.load_state_dict(torch.load("./model/model.pkl"))#加载保存模型的参数

6)模型测试

# 模型测试
def test():
    correct = 0     # 正确预测的个数
    total = 0   # 总数
    with torch.no_grad():   # 测试不用计算梯度
        for data in test_loader:
            input, target = data
            output = model(input)   # output输出10个预测取值,概率最大的为预测数
            probability, predict = torch.max(input=output.data, dim=1)    # 返回一个元祖,第一个为最大概率值,第二个为最大概率值的下标
            # loss = criterion(output, target)
            total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小
            correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数
        print("测试准确率为:%.6f" %(correct / total))

6)自己手写数字图片识别函数(可选用)
这部分主要是加载训练好的pkl模型测试自己的数据,因此在进行自己手写图的测试时,需要有训练好的pkl文件,并且就不要调用train()函数和test()函数啦注意:这个图片像素也要说黑底白字,28*28像素,否则无法识别

def test_mydata():
    image = Image.open('./test/test_two.png')   #读取自定义手写图片
    image = image.resize((28, 28))   # 裁剪尺寸为28*28
    image = image.convert('L')  # 转换为灰度图像
    transform = transforms.ToTensor()
    image = transform(image)
    image = image.resize(1, 1, 28, 28)
    output = model(image)
    probability, predict = torch.max(output.data, dim=1)
    print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))
    plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')
    plt.imshow(image.squeeze())
    plt.show()

7)MNIST中的数据识别测试数据
训练过程中的打印信息我进行了修改,这里设置的训练轮数是15轮,每次训练生成的pkl模型参数也是会更新的,想要更多训练信息可以查看对应的教程哦~

#测试识别函数
if __name__ == '__main__':
    #训练与测试
    for i in range(15):#训练和测试进行15轮
        print({
    "————————第{}轮测试开始——————".format (i + 1)})
        train()
        test(

8)测试自己的手写数字图片(可选)
这部分主要是与tset_mydata()函数结合,加载训练好的pkl模型测试自己的数据,因此在进行自己手写图的测试时,需要有训练好的pkl文件,并且就不要调用train()函数和test()函数啦。注意:这个图片像素也要说黑底白字,28*28像素,否则无法识别

# 测试主函数
if __name__ == '__main__':
    test_mydata()

将所有代码按顺序放到编辑器中,安装好对应的包,就可以顺利运行啦。

  • 如果成功运行了请给我一个免费的赞吧,这对我真的很重要!!!
  • 如果不成功请留下你的问题,我会尽量帮忙的!

完整代码放下面:

import os
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn
from torch.nn import Conv2d, Linear, ReLU
from torch.nn import MaxPool2d
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


# Dataset:创建数据集的函数;__init__:初始化数据内容和标签
# __geyitem:获取数据内容和标签
# __len__:获取数据集大小
# daataloader:数据加载类,接受来自dataset已经加载好的数据集
# torchbision:图形库,包含预训练模型,加载数据的函数、图片变换,裁剪、旋转等
# torchtext:处理文本的工具包,将不同类型的额文件转换为datasets

# 预处理:将两个步骤整合在一起
transform = transforms.Compose({
    
    transforms.ToTensor(),  # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理
    # transforms.Normalize((0.1307,),(0.3081)),    # 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
})

# normalize执行以下操作:image=(image-mean)/std?????
# input[channel] = (input[channel] - mean[channel]) / std[channel]

# 加载数据集
# 训练数据集
train_data = MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
# transform:指示加载的数据集应用的数据预处理的规则,shuffle:洗牌,是否打乱输入数据顺序
# 测试数据集
test_data = MNIST(root="./data", train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True)

train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))
# print(test_data)
# print(train_data)


class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.conv1 = Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, padding=0)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=0)
        self.maxpool2 = MaxPool2d(2)
        self.linear1 = Linear(320, 128)
        self.linear2 = Linear(128, 64)
        self.linear3 = Linear(64, 10)
        self.relu = ReLU()

    def forward(self, x):
        x = self.relu(self.maxpool1(self.conv1(x)))
        x = self.relu(self.maxpool2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)

        return x


# 损失函数CrossentropyLoss
model = MnistModel()#实例化
criterion = nn.CrossEntropyLoss()   # 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成

# SGD,优化器,梯度下降算法e
optimizer = torch.optim.SGD(model.parameters(), lr=0.14)#lr:学习率


# 模型训练
def train():
    # index = 0
    for index, data in enumerate(train_loader):#获取训练数据以及对应标签
        # for data in train_loader:
       input, target = data   # input为输入数据,target为标签
       y_predict = model(input) #模型预测
       loss = criterion(y_predict, target)
       optimizer.zero_grad() #梯度清零
       loss.backward()#loss值反向传播
       optimizer.step()#更新参数
       # index += 1
       if index % 100 == 0: # 每一百次保存一次模型,打印损失
           torch.save(model.state_dict(), "./model/model.pkl")   # 保存模型
           torch.save(optimizer.state_dict(), "./model/optimizer.pkl")
           print("训练次数为:{},损失值为:{}".format(index, loss.item() ))

# 加载模型
if os.path.exists('./model/model.pkl'):
   model.load_state_dict(torch.load("./model/model.pkl"))#加载保存模型的参数


# 模型测试
def test():
    correct = 0     # 正确预测的个数
    total = 0   # 总数
    with torch.no_grad():   # 测试不用计算梯度
        for data in test_loader:
            input, target = data
            output = model(input)   # output输出10个预测取值,概率最大的为预测数
            probability, predict = torch.max(input=output.data, dim=1)    # 返回一个元祖,第一个为最大概率值,第二个为最大概率值的下标
            # loss = criterion(output, target)
            total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小
            correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数
        print("测试准确率为:%.6f" %(correct / total))


#测试识别函数
if __name__ == '__main__':
    #训练与测试
    for i in range(15):#训练和测试进行5轮
        print({
    "————————第{}轮测试开始——————".format (i + 1)})
        train()
        test()


def test_mydata():
    image = Image.open('./test/test_two.png')   #读取自定义手写图片
    image = image.resize((28, 28))   # 裁剪尺寸为28*28
    image = image.convert('L')  # 转换为灰度图像
    transform = transforms.ToTensor()
    image = transform(image)
    image = image.resize(1, 1, 28, 28)
    output = model(image)
    probability, predict = torch.max(output.data, dim=1)
    print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))
    plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')
    plt.imshow(image.squeeze())
    plt.show()

# 测试主函数
# if __name__ == '__main__':
#     test_mydata()
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Ps_hello/article/details/134169021

智能推荐

51单片机的中断系统_51单片机中断篇-程序员宅基地

文章浏览阅读3.3k次,点赞7次,收藏39次。CPU 执行现行程序的过程中,出现某些急需处理的异常情况或特殊请求,CPU暂时中止现行程序,而转去对异常情况或特殊请求进行处理,处理完毕后再返回现行程序断点处,继续执行原程序。void 函数名(void) interrupt n using m {中断函数内容 //尽量精简 }编译器会把该函数转化为中断函数,表示中断源编号为n,中断源对应一个中断入口地址,而中断入口地址的内容为跳转指令,转入本函数。using m用于指定本函数内部使用的工作寄存器组,m取值为0~3。该修饰符可省略,由编译器自动分配。_51单片机中断篇

oracle项目经验求职,网络工程师简历中的项目经验怎么写-程序员宅基地

文章浏览阅读396次。项目经验(案例一)项目时间:2009-10 - 2009-12项目名称:中驰别克信息化管理整改完善项目描述:项目介绍一,建立中驰别克硬件档案(PC,服务器,网络设备,办公设备等)二,建立中驰别克软件档案(每台PC安装的软件,财务,HR,OA,专用系统等)三,能过建立的档案对中驰别克信息化办公环境优化(合理使用ADSL宽带资源,对域进行调整,对文件服务器进行优化,对共享打印机进行调整)四,优化完成后..._网络工程师项目经历

LVS四层负载均衡集群-程序员宅基地

文章浏览阅读1k次,点赞31次,收藏30次。LVS:Linux Virtual Server,负载调度器,内核集成, 阿里的四层SLB(Server Load Balance)是基于LVS+keepalived实现。NATTUNDR优点端口转换WAN性能最好缺点性能瓶颈服务器支持隧道模式不支持跨网段真实服务器要求anyTunneling支持网络private(私网)LAN/WAN(私网/公网)LAN(私网)真实服务器数量High (100)High (100)真实服务器网关lvs内网地址。

「技术综述」一文道尽传统图像降噪方法_噪声很大的图片可以降噪吗-程序员宅基地

文章浏览阅读899次。https://www.toutiao.com/a6713171323893318151/作者 | 黄小邪/言有三编辑 | 黄小邪/言有三图像预处理算法的好坏直接关系到后续图像处理的效果,如图像分割、目标识别、边缘提取等,为了获取高质量的数字图像,很多时候都需要对图像进行降噪处理,尽可能的保持原始信息完整性(即主要特征)的同时,又能够去除信号中无用的信息。并且,降噪还引出了一..._噪声很大的图片可以降噪吗

Effective Java 【对于所有对象都通用的方法】第13条 谨慎地覆盖clone_为继承设计类有两种选择,但无论选择其中的-程序员宅基地

文章浏览阅读152次。目录谨慎地覆盖cloneCloneable接口并没有包含任何方法,那么它到底有什么作用呢?Object类中的clone()方法如何重写好一个clone()方法1.对于数组类型我可以采用clone()方法的递归2.如果对象是非数组,建议提供拷贝构造器(copy constructor)或者拷贝工厂(copy factory)3.如果为线程安全的类重写clone()方法4.如果为需要被继承的类重写clone()方法总结谨慎地覆盖cloneCloneable接口地目的是作为对象的一个mixin接口(详见第20_为继承设计类有两种选择,但无论选择其中的

毕业设计 基于协同过滤的电影推荐系统-程序员宅基地

文章浏览阅读958次,点赞21次,收藏24次。今天学长向大家分享一个毕业设计项目基于协同过滤的电影推荐系统项目运行效果:项目获取:https://gitee.com/assistant-a/project-sharing21世纪是信息化时代,随着信息技术和网络技术的发展,信息化已经渗透到人们日常生活的各个方面,人们可以随时随地浏览到海量信息,但是这些大量信息千差万别,需要费事费力的筛选、甄别自己喜欢或者感兴趣的数据。对网络电影服务来说,需要用到优秀的协同过滤推荐功能去辅助整个系统。系统基于Python技术,使用UML建模,采用Django框架组合进行设

随便推点

你想要的10G SFP+光模块大全都在这里-程序员宅基地

文章浏览阅读614次。10G SFP+光模块被广泛应用于10G以太网中,在下一代移动网络、固定接入网、城域网、以及数据中心等领域非常常见。下面易天光通信(ETU-LINK)就为大家一一盘点下10G SFP+光模块都有哪些吧。一、10G SFP+双纤光模块10G SFP+双纤光模块是一种常规的光模块,有两个LC光纤接口,传输距离最远可达100公里,常用的10G SFP+双纤光模块有10G SFP+ SR、10G SFP+ LR,其中10G SFP+ SR的传输距离为300米,10G SFP+ LR的传输距离为10公里。_10g sfp+

计算机毕业设计Node.js+Vue基于Web美食网站设计(程序+源码+LW+部署)_基于vue美食网站源码-程序员宅基地

文章浏览阅读239次。该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流项目运行环境配置:项目技术:Express框架 + Node.js+ Vue 等等组成,B/S模式 +Vscode管理+前后端分离等等。环境需要1.运行环境:最好是Nodejs最新版,我们在这个版本上开发的。其他版本理论上也可以。2.开发环境:Vscode或HbuilderX都可以。推荐HbuilderX;3.mysql环境:建议是用5.7版本均可4.硬件环境:windows 7/8/10 1G内存以上;_基于vue美食网站源码

oldwain随便写@hexun-程序员宅基地

文章浏览阅读62次。oldwain随便写@hexun链接:http://oldwain.blog.hexun.com/ ...

渗透测试-SQL注入-SQLMap工具_sqlmap拖库-程序员宅基地

文章浏览阅读843次,点赞16次,收藏22次。用这个工具扫描其它网站时,要注意法律问题,同时也比较慢,所以我们以之前写的登录页面为例子扫描。_sqlmap拖库

origin三图合一_神教程:Origin也能玩转图片拼接组合排版-程序员宅基地

文章浏览阅读1.5w次,点赞5次,收藏38次。Origin也能玩转图片的拼接组合排版谭编(华南师范大学学报编辑部,广州 510631)通常,我们利用Origin软件能非常快捷地绘制出一张单独的绘图。但是,我们在论文的撰写过程中,经常需要将多种科学实验图片(电镜图、示意图、曲线图等)组合在一张图片中。大多数人都是采用PPT、Adobe Illustrator、CorelDraw等软件对多种不同类型的图进行拼接的。那么,利用Origin软件能否实..._origin怎么把三个图做到一张图上

51单片机智能电风扇控制系统proteus仿真设计( 仿真+程序+原理图+报告+讲解视频)_电风扇模拟控制系统设计-程序员宅基地

文章浏览阅读4.2k次,点赞4次,收藏51次。51单片机智能电风扇控制系统仿真设计( proteus仿真+程序+原理图+报告+讲解视频)仿真图proteus7.8及以上 程序编译器:keil 4/keil 5 编程语言:C语言 设计编号:S0042。_电风扇模拟控制系统设计