首页 星云 工具 资源 星选 资讯 热门工具
:

PDF转图片 完全免费 小红书视频下载 无水印 抖音视频下载 无水印 数字星空

零基础学习人工智能—Python—Pytorch学习(九)

编程知识
2024年08月27日 08:03

前言

本文主要介绍卷积神经网络的使用的下半部分。
另外,上篇文章增加了一点代码注释,主要是解释(w-f+2p)/s+1这个公式的使用。
所以,要是这篇文章的代码看不太懂,可以翻一下上篇文章。

代码实现

之前,我们已经学习了概念,在结合我们以前学习的知识,我们可以直接阅读下面代码了。
代码里使用了,dataset.CIFAR10数据集。
CIFAR-10 数据集由 60000 张 32x32 彩色图像组成,共分为 10 个不同的类别,分别是飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。
每个类别包含 6000 张图像,其中 50000 张用于训练,10000 张用于测试。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F #nn不好使时,在这里找激活函数
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
input_size = 784  # 28x28
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 2

 
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
 
train_loader = torch.utils. data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)
print('每份100个,被分成多少份:', len(test_loader))


classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120) #这个在forward里解释
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x))) #这里x已经变成 torch.Size([4, 16, 5, 5])
        # print("两次卷积两次池化后的x.shape:",x.shape)
        x = x.view(-1,16*5*5)#这里的16*5*5就是x的后面3个维度相乘
        x = F.relu(self.fc1(x)) #fc1定义时,inputx已经是6*5*5了
        x = F.relu(self.fc2(x))
        x= self.fc3(x)
        return x


model = ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


n_total_steps = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # origin shape:[4,3,32,32]=4,3,1024
        # input layer: 3 input channels, 6 output channels, 5 kernel size
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i+1) % 2000 == 0:
            print(
                f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
print('Finished Training')


# test
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    n_class_correct = [0 for i in range(10)] #生成 10 个 0 的列表
    n_class_samples = [0 for i in range(10)]
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        print('test-images.shape:', images.shape)
        outputs = model(images)
        # max returns(value ,index)
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        for i in range(batch_size):
            label = labels[i]
            # print("label:",label) #这里存的是 0~9的数字 输出就是这样的 label: tensor(2) predicted[i]也是这样的数
            pred = predicted[i]
            if (label == pred):
                n_class_correct[label] += 1
            n_class_samples[label] += 1
    acc = 100.0*n_correct/n_samples  # 计算正确率
    print(f'accuracy ={acc}')
    
    for i in range(10):
        acc = 100.0*n_class_correct[i]/n_class_samples[i]
        print(f'Accuracy of {classes[i]}: {acc} %')

运行结果如下:

accuracy =10.26
Accuracy of plane: 0.0 %
Accuracy of car: 0.0 %
Accuracy of bird: 0.0 %
Accuracy of cat: 0.0 %
Accuracy of deer: 0.0 %
Accuracy of dog: 0.0 %
Accuracy of frog: 0.0 %
Accuracy of horse: 0.0 %
Accuracy of ship: 89.6 %
Accuracy of truck: 13.0 %

这是因为我设置的num_epochs=2,也就是循环的次数太低,所以结果的精确度就很低。
我们只要增加epochs的值,就能提高精确度了。


传送门:
零基础学习人工智能—Python—Pytorch学习—全集

这样我们卷积神经网络就学完了。


注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!



若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!

https://www.cnblogs.com/kiba/p/18381036

From:https://www.cnblogs.com/kiba/p/18381036
本文地址: http://www.shuzixingkong.net/article/1472
0评论
提交 加载更多评论
其他文章 6.2K star!推荐一款开源混沌工程测试平台:Chaos Mesh
1、Chaos Mesh 介绍 Chaos Mesh是一个开源的混沌工程平台,旨在帮助用户在生产环境中测试、验证和优化其应用程序的可靠性和稳定性。通过引入故障注入和混沌工程原则,Chaos Mesh可以模拟各种故障场景,如网络延迟、节点故障、磁盘故障等,以帮助用户发现和解决系统中的潜在问题。 项目地
6.2K star!推荐一款开源混沌工程测试平台:Chaos Mesh 6.2K star!推荐一款开源混沌工程测试平台:Chaos Mesh
网卡-热点搜索不到或者无法连接问题
大屏设置网卡开启热点后,经常收到反馈,手机端无法搜索到大屏热点、或者手机连接大屏热点失败 这类问题一般有以下几类情况: 1. 物理网卡IP与热点网卡IP相同 2. 热点网卡IP,非正常热点IP(192.168.137.X) 热点IP我们一般定为192.168.137.X,192.168.137.X是
Tomcat的配置文件中有哪些关键的配置项,它们分别有什么作用?
Tomcat的配置文件主要包括server.xml和web.xml,它们位于Tomcat安装目录下的conf文件夹中。今天的内容重点介绍 server.xml 文件的配置,V 哥会结合一些业务场景来介绍,希望可以帮助到你,以下是一些关键的配置项及其作用: server.xml中的配置项: <S
详细分析平衡树-红黑树的平衡修正 图文详解(附代码) (万字长文)
目录红黑树简述性质/规则主要规则:推导性质:红黑树的基本实现struct RBTreeNodeclass RBTree红黑树的插入红黑树插入修正前言什么时候需要变色:变色的基础:为什么需要旋转与变色变色:旋转需要修正的所有情况先认识最简单的情况1. 叔叔是红色结点注意:2.没有叔叔结点3. 叔叔是黑
详细分析平衡树-红黑树的平衡修正 图文详解(附代码) (万字长文) 详细分析平衡树-红黑树的平衡修正 图文详解(附代码) (万字长文) 详细分析平衡树-红黑树的平衡修正 图文详解(附代码) (万字长文)
One-for-All:上交大提出视觉推理的符号化与逻辑推理分离的新范式 | ECCV 2024
通过对多样化基准的严格评估,论文展示了现有特定方法在实现跨领域推理以及其偏向于数据偏差拟合方面的缺陷。从两阶段的视角重新审视视觉推理:(1)符号化和(2)基于符号或其表示的逻辑推理,发现推理阶段比符号化更擅长泛化。因此,更高效的做法是通过为不同数据领域使用分离的编码器来实现符号化,同时使用共享的推理
One-for-All:上交大提出视觉推理的符号化与逻辑推理分离的新范式 | ECCV 2024 One-for-All:上交大提出视觉推理的符号化与逻辑推理分离的新范式 | ECCV 2024 One-for-All:上交大提出视觉推理的符号化与逻辑推理分离的新范式 | ECCV 2024
折腾 Quickwit,Rust 编写的分布式搜索引擎-官方配置详解
Node configuration(节点配置) 节点配置允许您为集群中的各个节点自定义和优化设置。它被分为几个部分: 常规配置设置:共享的顶级属性 Storage(存储)设置:在storage部分定义 https://quickwit.io/docs/configuration/node-conf
折腾 Quickwit,Rust 编写的分布式搜索引擎-官方配置详解
每天那么多工作,我为什么能做到 "不忘事" ?
我相信很多朋友都遇到过丢失工作、或者忘记事情的情况,尤其是事情一多,就更容易遗漏;而如果在工作中你漏掉了某项任务,需要上级或同事重复提醒你,是很影响别人对你的印象的。 那么如何解决这个问题呢?我有一些自己的经验。
每天那么多工作,我为什么能做到 "不忘事" ? 每天那么多工作,我为什么能做到 "不忘事" ? 每天那么多工作,我为什么能做到 "不忘事" ?
[kernel] 带着问题看源码 —— 脚本是如何被 execve 调用的
Linux 脚本文件 shebang (!#) 行最大为何只有 128 字节?为何最多只能指定一个参数?如何将这些参数排列在参数列表前面?本文通过阅读 Linux 内核源码,一一为你揭秘
[kernel] 带着问题看源码 —— 脚本是如何被 execve 调用的 [kernel] 带着问题看源码 —— 脚本是如何被 execve 调用的 [kernel] 带着问题看源码 —— 脚本是如何被 execve 调用的