pytorch初学笔记(十四):损失函数
创始人
2024-03-05 02:24:03
0

 目录

一、损失函数 

1.1 L1损失函数

1.1.1 简介

1.1.2 参数设定

1.1.3 代码实现

1.2 MSE损失函数(平方和)

1.2.1 简介

1.2.2 参数介绍

1.2.3 代码实现

1.3 损失函数的作用

二、在神经网络中使用loss function

2.1 使用交叉熵损失函数 

2.2 反向传播


一、损失函数 

torch.nn — PyTorch 1.13 documentation

每一个样本经过模型后会得到一个预测值,然后得到的预测值和真实值的差值就成为损失(当然损失值越小证明模型越是成功),我们知道有许多不同种类的损失函数,这些函数本质上就是计算预测值和真实值的差距的一类型函数,然后经过库(如pytorch,tensorflow等)的封装形成了有具体名字的函数。

1.1 L1损失函数

1.1.1 简介

L1损失函数: 基于逐像素比较差异,然后取绝对值。

在这里插入图片描述 

1.1.2 参数设定 

L1Loss — PyTorch 1.13 documentation

 CLASS torch.nn.L1Loss(size_average=Nonereduce=Nonereduction='mean')

 

我们一般设定reduction的值来显示平均值或者和。

 参数设定:

reduction可取的值: 'none' | 'mean' | 'sum'

  • 'none': no reduction will be applied
  •  'mean': the sum of the output will be divided by the number of elements in the output,求的是平均值,即各个差求和之后除以总数。

  •  'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction.只求和,不除总数。

  • Default: 'mean'

1.1.3 代码实现

 设置reduction的值为默认值和sum,观察区别。

import torch
from torch.nn import L1Lossinputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))loss1 = L1Loss()
result1 = loss1(inputs,targets)
print(result1)loss2 = L1Loss(reduction="sum")
result2 = loss2(inputs,targets)
print(result2)
  • 当取值为默认值mean时,求的是平均值,sum=(1-1+2-2+5-3)=2, n=3, result = sum/n=0.6667 
  • 当取值为sum时,求的是和,即result=2

 

1.2 MSE损失函数(平方和)

1.2.1 简介

均方误差(Mean Square Error,MSE)是回归损失函数中最常用的误差,它是预测值f(x)与目标值y之间差值平方和的均值,其公式如下所示:
在这里插入图片描述 

 

1.2.2 参数介绍

MSELoss — PyTorch 1.13 documentation

 

与上面的L1损失函数一样,我们可以改变reduction的值来进行对应数值的输出。

 

1.2.3 代码实现

import torch
from torch.nn import L1Loss, MSELossinputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))loss_mse1 = MSELoss()
result1 = loss_mse1(inputs,targets)
print(result1)loss_mse2 = MSELoss(reduction="sum")
result2 = loss_mse2(inputs,targets)
print(result2)

 可以看到reduction设置不同的值对应的输出也不同。

 

1.3 损失函数的作用

  1. 计算实际输出和目标之间的差距
  2. 为更新输出(反向传播)提供一定的依据

二、在神经网络中使用loss function

2.1 使用交叉熵损失函数 

使用上次定义的神经网络和CIFAR10数据集进行图像分类,分类问题使用交叉熵损失函数。

import torch.nn
from torch import nn
import torchvision.datasets
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Maweiyi(torch.nn.Module):def __init__(self):super(Maweiyi, self).__init__()self.model1 = Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Flatten(),Linear(in_features=1024, out_features=64),Linear(in_features=64, out_features=10))def forward(self, x):x = self.model1(x)return xmaweiyi = Maweiyi()
# 使用交叉熵损失函数
loss_cross = nn.CrossEntropyLoss()for data in dataloader:imgs,labels = dataoutputs = maweiyi(imgs)results = loss_cross(outputs,labels)print(results)

可以看到使用loss function计算出了在神经网路中预测的output和真实值labels之间的差距大小。 

 

2.2 反向传播

results_loss.backward() 

  

相关内容

热门资讯

保存时出现了1个错误,导致这篇... 当保存文章时出现错误时,可以通过以下步骤解决问题:查看错误信息:查看错误提示信息可以帮助我们了解具体...
汇川伺服电机位置控制模式参数配... 1. 基本控制参数设置 1)设置位置控制模式   2)绝对值位置线性模...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
表格中数据未显示 当表格中的数据未显示时,可能是由于以下几个原因导致的:HTML代码问题:检查表格的HTML代码是否正...
本地主机上的图像未显示 问题描述:在本地主机上显示图像时,图像未能正常显示。解决方法:以下是一些可能的解决方法,具体取决于问...
表格列调整大小出现问题 问题描述:表格列调整大小出现问题,无法正常调整列宽。解决方法:检查表格的布局方式是否正确。确保表格使...
不一致的条件格式 要解决不一致的条件格式问题,可以按照以下步骤进行:确定条件格式的规则:首先,需要明确条件格式的规则是...
Android|无法访问或保存... 这个问题可能是由于权限设置不正确导致的。您需要在应用程序清单文件中添加以下代码来请求适当的权限:此外...
【NI Multisim 14...   目录 序言 一、工具栏 🍊1.“标准”工具栏 🍊 2.视图工具...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...