news 2026/3/6 8:01:38

基于PyTorch的深度学习基础课程之九:分类模型评价指标(1|3)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
基于PyTorch的深度学习基础课程之九:分类模型评价指标(1|3)

本文详细讨论了分类模型的常用评价指标,包括准确率、平均准确率、混淆矩阵、精确率、召回率、F1值和AUC等。对这些指标含义的理解和运用,尤其是在不平衡样本数据集上的应用,是设计恰当模型和指导AI大模型调整模型需要掌握的知识。对这些指标的讨论采用了示例入手、逐步推进的方式,便于读者理解。

在本专栏的前述文章里,对分类模型的评价采用了最简单的准确率。本文详细讨论分类模型的常用评价指标。无论是自己设计模型,还是指导AI大模型去调整模型,评价指标显然是必须理解的内容。

本文仍然采用示例入手的分析方法,便于读者理解。读者也可暂时跳过公式推导部分,先掌握应用方法。

1 准确率(Accuracy)

准确率是指在分类中,用模型对测试集进行分类,分类正确的样本数占总数的比例:
accuracy=ncorrectntotal(式9-1) \text{accuracy} = \frac{n_{\text{correct}}}{n_{\text{total}}}\tag{式9-1}accuracy=ntotalncorrect(9-1)
sklearn库中提供了一个专门对模型进行评估的包metrics,该包可以满足一般的模型评估需求。其中提供了准确率计算函数,函数原型为:sklearn.metrics.accuracy_score(y_true,y_pred,normalize=True,sample_weight=None)。其中,normalize默认值为True,返回正确分类的比例,如果设为False,则返回正确分类的样本数。

准确率的计算比较简单,在前文的示例中是直接用Python代码来实现,如先统计预测成功的样本数的代码:correct += (predicted == labels).sum().item(),最后将它除以样本总数即可。

2 平均准确率(Average Per-class Accuracy)

准确率指标虽然简单、易懂,但它没有对不同类别进行区分。不同类别下分类错误的代价可能不同,例如在重大病患诊断中,漏诊可能要比误诊给治疗带来更为严重的后果,此时准确率就不足以反映预测的效果。

如果样本类别分布不平衡,即有的类别下的样本过多,有的类别下的样本个数过少,准确率也难以反映真实预测效果。例如,某病的发病率为十万分之一,如果简单地将所有样本都判断为负,就可以取得99.999%的准确率,但是实际上,该模型并没有任何作用。

平均准确率的全称为:按类平均准确率,即计算每个类别的准确率,然后再计算它们的平均值。在上面的例子中,模型对健康人的预测成功率为100%,而对病患的预测成功率为0,因此该模型的平均准确率为50%。

实际上,类别不平衡问题是深度学习中一个重要而且经常面对的问题,人们从评价指标、损失函数等多方面对该问题进行了研究。

下面还是用MNIST数据集的示例来说明平均准确率等。在该示例中,所有图片分为两类:数字0的图片和非数字0的图片,显然这样划分后的样本集是不平衡的,方便用于示例各类指标的评价作用。划分新的样本集,用来训练一个二分类的模型,并用准确率和平均准确率来对之进行评价的代码见代码9-1.1。

在数据预处理部分,重设了标本的标签,将所有非0数字图片的标签都设为1。在每轮训练时,都计算样本集上的准确率和平均准确率,可见从第1轮训练到第10轮训练,模型对小类别样本的识别率是逐渐升高的。

准确率和平均准确率的计算并不难实现,建议在理解原理的基础上直接采用sklearn中的函数来实现。sklearn中的平均准确率函数为sklearn.metrics.balanced_accuracy_score()。

代码9-1.1 准确率和平均准确率示例

### 1.导入和设置环境importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoader,TensorDatasetimportdatetimefromtorchvisionimportdatasets,transformsimportnumpyasnpimportsklearn.metricsasmetricsimportmatplotlib.pyplotaspltimportseabornassns# 设置中文字体plt.rcParams['font.sans-serif']=['SimHei','Microsoft YaHei','DejaVu Sans Fallback']# 用来正常显示中文标签plt.rcParams['axes.unicode_minus']=False# 用来正常显示负号# 设置随机种子torch.manual_seed(0)### 2.训练样本和验证样本数据预处理# 数据预处理方式transform=transforms.Compose([transforms.ToTensor(),# 转换为 torch.Tensor])# 加载MNIST数据集train_dataset=datasets.MNIST('./data',train=True,download=True,transform=transform)val_dataset=datasets.MNIST('./data',train=False,transform=transform)# 样本拉平、归一化X_train=train_dataset.data.float().view(-1,784)/255.0y_train=train_dataset.targets X_val=val_dataset.data.float().view(-1,784)/255.0y_val=val_dataset.targets# 将标签重新映射为0和非0两类:0 -> 类别0,非0 -> 类别1y_train_binary=(y_train!=0).long()# 0保持为0,非0变为1y_val_binary=(y_val!=0).long()# 转换为独热编码(现在只有2个类别)y_train=torch.nn.functional.one_hot(y_train_binary,num_classes=2).float()y_val=torch.nn.functional.one_hot(y_val_binary,num_classes=2).float()# 创建数据加载器batch_size=200train_loader=DataLoader(TensorDataset(X_train,y_train),batch_size=batch_size,shuffle=True)val_loader=DataLoader(TensorDataset(X_val,y_val),batch_size=batch_size)### 3.定义神经网络模型classMNISTModel(nn.Module):def__init__(self):super(MNISTModel,self).__init__()self.fc1=nn.Linear(784,784)self.fc2=nn.Linear(784,784)self.fc3=nn.Linear(784,2)# 修改:输出层改为2个神经元,对应0和非0两类self.sigmoid=nn.Sigmoid()defforward(self,x):x=self.sigmoid(self.fc1(x))x=self.sigmoid(self.fc2(x))x=self.fc3(x)returnx### 4.创建模型并用训练样本对它进行训练model=MNISTModel()# 实例化模型类得到模型对象criterion=nn.CrossEntropyLoss()# 定义损失函数optimizer=optim.SGD(model.parameters(),lr=0.15)# 定义优化器# 训练模型,开始计时start_time=datetime.datetime.now()epochs=10forepochinrange(epochs):# 每轮中的训练model.train()train_loss=0.0# 用于计算训练集上的指标train_preds=[]train_labels=[]forbatch_X,batch_yintrain_loader:optimizer.zero_grad()outputs=model(batch_X)# 获取原始二元标签,相当于从独热编码标签变回原始标签_,batch_labels=torch.max(batch_y.data,1)loss=criterion(outputs,batch_y)loss.backward()optimizer.step()train_loss+=loss.item()# 收集训练集的预测和标签withtorch.no_grad():_,predicted=torch.max(outputs.data,1)train_preds.append(predicted.cpu())train_labels.append(batch_labels.cpu())# 计算训练集指标train_preds_all=torch.cat(train_preds).numpy()train_labels_all=torch.cat(train_labels).numpy()# 使用sklearn计算指标train_acc=metrics.accuracy_score(train_labels_all,train_preds_all)*100train_acc_0=metrics.accuracy_score(train_labels_all[train_labels_all==0],train_preds_all[train_labels_all==0])*100ifnp.sum(train_labels_all==0)>0else0train_acc_1=metrics.accuracy_score(train_labels_all[train_labels_all==1],train_preds_all[train_labels_all==1])*100ifnp.sum(train_labels_all==1)>0else0train_avg_acc=metrics.balanced_accuracy_score(train_labels_all,train_preds_all)*100print(f'Epoch{epoch+1}/{epochs}:')print(f' 训练集准确率:{train_acc:.2f}%')print(f' 类别0(数字0)准确率:{train_acc_0:.2f}%')print(f' 类别1(非0)准确率:{train_acc_1:.2f}%')print(f' 训练集平均准确率:{train_avg_acc:.2f}%')# 训练结束,终止计时end_time=datetime.datetime.now()print(f"\n训练用时:{end_time-start_time}")

输出:

Epoch 1/10:
训练集准确率: 89.86%
类别0(数字0)准确率: 0.47%
类别1(非0)准确率: 99.65%
训练集平均准确率: 50.06%
Epoch 2/10:
训练集准确率: 95.36%
类别0(数字0)准确率: 56.07%
类别1(非0)准确率: 99.67%
训练集平均准确率: 77.87%
Epoch 3/10:
训练集准确率: 98.41%
类别0(数字0)准确率: 88.23%
类别1(非0)准确率: 99.52%
训练集平均准确率: 93.88%
Epoch 4/10:
训练集准确率: 98.77%
类别0(数字0)准确率: 91.44%
类别1(非0)准确率: 99.57%
训练集平均准确率: 95.50%
Epoch 5/10:
训练集准确率: 98.90%
类别0(数字0)准确率: 93.01%
类别1(非0)准确率: 99.54%
训练集平均准确率: 96.28%
Epoch 6/10:
训练集准确率: 99.00%
类别0(数字0)准确率: 93.47%
类别1(非0)准确率: 99.60%
训练集平均准确率: 96.53%
Epoch 7/10:
训练集准确率: 99.01%
类别0(数字0)准确率: 93.99%
类别1(非0)准确率: 99.56%
训练集平均准确率: 96.77%
Epoch 8/10:
训练集准确率: 99.03%
类别0(数字0)准确率: 94.14%
类别1(非0)准确率: 99.57%
训练集平均准确率: 96.85%
Epoch 9/10:
训练集准确率: 99.08%
类别0(数字0)准确率: 94.55%
类别1(非0)准确率: 99.58%
训练集平均准确率: 97.06%
Epoch 10/10:
训练集准确率: 99.09%
类别0(数字0)准确率: 94.55%
类别1(非0)准确率: 99.59%
训练集平均准确率: 97.07%

训练用时: 0:00:52.823934

### 5.训练好的模型对验证集的预测数据model.eval()all_predictions=[]all_labels=[]all_probabilities=[]# 用于计算AUC的概率withtorch.no_grad():forbatch_X,batch_yinval_loader:outputs=model(batch_X)probabilities=torch.softmax(outputs,dim=1)# 获取概率_,predicted=torch.max(outputs.data,1)_,labels=torch.max(batch_y.data,1)all_predictions.append(predicted.cpu())all_labels.append(labels.cpu())all_probabilities.append(probabilities[:,1].cpu())# 转换为numpy数组all_predictions=torch.cat(all_predictions).numpy()all_labels=torch.cat(all_labels).numpy()all_probabilities=torch.cat(all_probabilities).numpy()### 6.计算验证集上的准确率与平均准确率overall_accuracy=metrics.accuracy_score(all_labels,all_predictions)*100balanced_acc=metrics.balanced_accuracy_score(all_labels,all_predictions)*100print(f'验证集准确率:{overall_accuracy:.2f}%')print(f'验证集平均准确率:{balanced_acc:.2f}%')

输出:

验证集准确率: 99.07%
验证集平均准确率: 97.94%

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!