YOLO26 自定义损失函数 分类任务自定义损失的接口约定
flyfish
这个约定是 分类训练循环中调用损失函数的固定调用契约,自定义损失类必须完全符合这个契约,才能被框架正常识别、调用,不会出现参数不匹配、返回值解包失败等报错。
分别约束了「调用形式、入参格式、返回值格式、挂载位置」:
1. 调用形式约定:必须实现__call__方法,实例可直接调用
框架在训练迭代中,会以函数调用的方式直接使用损失函数实例,框架内部的调用逻辑(简化版)为:
# 训练循环中框架的固定调用写法loss,loss_items=self.model.criterion(preds,batch)因此自定义损失类必须实现__call__方法,让类的实例可以像函数一样被直接调用。
普通 PyTorch 损失函数继承nn.Module,通过forward方法实现计算,本质也是依赖nn.Module自带的__call__机制。YOLO 分类的损失不强制要求继承nn.Module,只要类实现了__call__即可正常工作。
2. 入参格式约定:固定接收 2 个参数,顺序不可修改
损失函数的__call__方法必须固定接收两个入参,顺序为(模型预测结果, 批次数据字典),不可调换、不可增减参数。
第一个参数:preds(模型前向输出)
分类模型前向传播的输出结果
格式兼容:不同版本 Ultralytics 中,分类模型的输出有两种形式:
- 直接返回形状为
[batch_size, 类别数]的分类 logits 张量 - 返回元组/列表,通常结构为
[中间特征图, 最终分类logits],有效预测值在第二个位置
对应代码中的兼容处理:
# 兼容两种输出格式,提取最终的分类预测张量preds=preds[1]ifisinstance(preds,(list,tuple))elsepreds第二个参数:batch(批次数据字典)
数据加载器(DataLoader)返回的单批次数据,固定为字典格式
分类任务固定包含键:"cls",对应形状为[batch_size]的类别索引标签(不是 one-hot 编码)
损失计算时,必须通过batch["cls"]取出真实标签
示例:batch=8 的二分类任务中,batch["cls"]是形如tensor([0, 1, 0, 1, 1, 0, 0, 1])的一维张量
3. 返回值约定:必须返回二元组
必须返回 2 个值,框架会自动解包,少返回/多返回都会直接触发报错。
| 返回值顺序 | 作用 | 格式要求 |
|---|---|---|
| 第一个值 | 用于反向传播,更新模型权重 | 必须是带计算图的标量张量(可求导),通常是批次内所有样本损失求均值后的结果 |
| 第二个值 | 用于训练日志统计、进度条打印、指标文件记录 | 必须是分离梯度后的损失值(.detach()),不参与计算图,避免显存泄漏 |
对应代码中的标准实现:
loss=focal_loss.mean()# 第一个值:带梯度的标量损失,用于反向传播更新参数returnloss,loss.detach()# 第二个值:脱梯度的损失值,仅用于日志打印和统计目标检测任务中会返回多个损失项的字典,但分类任务只有单损失,直接返回脱梯度的标量即可。
4. 挂载位置约定:必须挂载为模型的criterion属性
框架是通过self.model.criterion来定位损失函数的,因此无论你用哪种注入方式,最终都要把自定义损失的实例,赋值给模型实例的.criterion属性。
子类化标准法:模型初始化时自动调用init_criterion()生成实例并赋值给self.criterion,属于框架原生的标准流程
示例
下面是一个最小化的、完全符合接口约定的自定义损失(包装原生交叉熵),可以直接接入YOLO分类训练:
fromtorch.nnimportCrossEntropyLossclassSimpleCustomLoss:def__init__(self,label_smoothing=0.1):self.ce=CrossEntropyLoss(label_smoothing=label_smoothing)# 约定1:实现 __call__ 方法# 约定2:固定入参顺序 preds, batchdef__call__(self,preds,batch):# 兼容模型输出格式preds=preds[1]ifisinstance(preds,(list,tuple))elsepreds# 从 batch 字典中取出分类标签loss=self.ce(preds,batch["cls"])# 约定3:返回 (带梯度损失, 脱梯度损失) 二元组returnloss,loss.detach()