博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
用Pytorch训练MNIST分类模型
阅读量:4522 次
发布时间:2019-06-08

本文共 2930 字,大约阅读时间需要 9 分钟。

本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\)

编写代码的步骤如下

  1. 载入数据集,分别为训练集和测试集
  2. 让数据集可以迭代
  3. 定义模型,定义损失函数,训练模型
代码
import torchimport torch.nn as nnimport torchvision.transforms as transformsimport torchvision.datasets as dsetsfrom torch.autograd import Variable'''下载训练集和测试集'''train_dataset = dsets.MNIST(root='./datasets',                            train=True,                             transform=transforms.ToTensor(),                            download=True)test_dataset = dsets.MNIST(root='./datasets',                           train=False,                            transform=transforms.ToTensor())'''让数据集可以迭代'''batch_size = 100n_iters = 3000num_epochs = n_iters / (len(train_dataset) / batch_size)num_epochs = int(num_epochs)train_loader = torch.utils.data.DataLoader(dataset=train_dataset,                                            batch_size=batch_size,                                            shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,                                           batch_size=batch_size,                                           shuffle=False)'''定义模型'''class LogisticRegressionModel(nn.Module):    def __init__(self, input_dim, output_dim):        super(LogisticRegressionModel, self).__init__()        self.linear = nn.Linear(input_dim, output_dim)        def forward(self, x):        out = self.linear(x)        return out'''实例化模型'''input_dim = 28*28output_dim = 10model = LogisticRegressionModel(input_dim, output_dim)'''定义损失计算方式'''criterion = nn.CrossEntropyLoss()learning_rate = 0.001optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)'''训练次数'''iter = 0for epoch in range(num_epochs):    for i, (images, labels) in enumerate(train_loader):        images = Variable(images.view(-1, 28*28))        labels = Variable(labels)                #梯度置零        optimizer.zero_grad()                #计算输出        outputs = model(images)                #计算损失,内部会自动softmax然后进行Crossentropy        loss = criterion(outputs, labels)                #反向传播        loss.backward()                #更新参数        optimizer.step()                iter += 1                if iter % 500 == 0:            #计算准确度            correct = 0            total = 0            for images, labels in test_loader:                images = Variable(images.view(-1, 28*28))                                #获得输出,输出的大小为(batch_size,10)                outputs = model(images)                                #获得预测值,输出的大小为(batch_size,1)                _, predicted = torch.max(outputs.data, 1)                                #labels的size是(100,)                total += labels.size(0)                #返回的是预测值和标签值相等的个数                correct += (predicted == labels).sum()                        accuracy = 100 * correct / total                        # Print Loss            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.data[0], accuracy))
输出如下

1414681-20190207155833073-448672162.jpg

转载于:https://www.cnblogs.com/MartinLwx/p/10354889.html

你可能感兴趣的文章
node学习之搭建服务器并加装静态资源
查看>>
android 按menu键解锁功能的开关
查看>>
wpf 自定义窗口,最大化时覆盖任务栏解决方案
查看>>
Linux 下的dd命令使用详解
查看>>
POJ-1273 Drainage Ditches 最大流Dinic
查看>>
ASP.NET学习记录点滴
查看>>
uva 12097(二分)
查看>>
[Noip2016] 愤怒的小鸟
查看>>
Linux系统基础管理
查看>>
JAVA wait()和notifyAll()实现线程间通讯
查看>>
python全栈脱产第11天------装饰器
查看>>
koa2 从入门到进阶之路 (一)
查看>>
Java / Android 基于Http的多线程下载的实现
查看>>
求职历程-----我的简历
查看>>
[总结]数据结构(板子)
查看>>
网页图片加载失败,用默认图片替换
查看>>
C# 笔记
查看>>
android 之输入法
查看>>
编译参数-ObjC的说明
查看>>
配置Synergy(Server : XP, client: Win7)
查看>>