ConvNeXt 的 torchvision 版本 推理实现
flyfish
importtorchimportosfromPILimportImagefromtorchvisionimporttransforms# 引入你自定义的ConvNeXt模型文件(确保和本文件同目录)fromconvnext_tinyimportconvnext_tiny# ===================== 配置 =====================# 本地模型权重路径(.pth / .pt / .pth.tar 格式)MODEL_WEIGHT_PATH="convnext_tiny-983f1562.pth"# 测试图像路径TEST_IMAGE_PATH="test.jpg"# 分类类别数(必须和训练模型时的num_classes一致)NUM_CLASSES=1000# 输出Top-K预测结果TOPK=5# =====================================================================# 设备配置:自动使用GPU,无GPU则用CPUDEVICE=torch.device("cuda"iftorch.cuda.is_available()else"cpu")defget_convnext_preprocess():""" ConvNeXt 官方标准图像预处理(必须严格匹配训练流程) 步骤:Resize → CenterCrop → ToTensor → Normalize """preprocess=transforms.Compose([# 缩放到256像素transforms.Resize(256),# 中心裁剪为224x224(ConvNeXt-Tiny标准输入尺寸)transforms.CenterCrop(224),# 转为Tensor,数值归一化到 [0, 1]transforms.ToTensor(),# ImageNet 标准归一化(ConvNeXt训练使用的均值/方差)transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])returnpreprocessdefload_local_model(model_path:str,num_classes:int):""" 初始化自定义ConvNeXt模型 + 加载本地权重 :param model_path: 本地权重文件路径 :param num_classes: 分类类别数 :return: 评估模式的模型 """# 1. 初始化模型结构model=convnext_tiny(num_classes=num_classes)# 2. 检查权重文件是否存在ifnotos.path.exists(model_path):raiseFileNotFoundError(f"模型权重不存在!路径:{model_path}")# 3. 加载本地权重(map_location适配CPU/GPU)checkpoint=torch.load(model_path,map_location=DEVICE)# strict=False:适配修改了分类头(num_classes)的模型,避免权重不匹配报错model.load_state_dict(checkpoint,strict=False)# 4. 模型迁移到设备 + 开启评估模式(关闭dropout/bn训练特性)model=model.to(DEVICE)model.eval()returnmodeldefimage_preprocess(image_path:str,transform):""" 图像完整预处理:加载 → 转RGB → 预处理 → 增加Batch维度 → 设备迁移 """# 检查图像是否存在ifnotos.path.exists(image_path):raiseFileNotFoundError(f"测试图像不存在!路径:{image_path}")# 1. 加载图像 + 强制转为RGB(解决灰度图/透明通道报错问题)image=Image.open(image_path).convert("RGB")# 2. 应用预处理 [H, W, C] → [C, H, W]tensor_image=transform(image)# 3. 增加Batch维度:[C, H, W] → [1, C, H, W](模型要求批量输入)tensor_image=tensor_image.unsqueeze(0)# 4. 迁移到设备tensor_image=tensor_image.to(DEVICE)returnimage,tensor_imagedefmodel_infer(model,input_tensor):""" 模型推理 + 后处理 1. 无梯度推理(节省显存/加速) 2. Softmax将Logits转为概率 3. 提取Top-K置信度和类别索引 """withtorch.no_grad():# 禁用梯度计算,推理必备# 前向传播,输出原始预测值 (logits)outputs=model(input_tensor)# 后处理1:softmax归一化,得到0~1的概率值probabilities=torch.softmax(outputs,dim=1)# 后处理2:获取Top-K 类别索引 + 置信度topk_probs,topk_indices=torch.topk(probabilities,k=TOPK)# 转为numpy格式,方便打印topk_probs=topk_probs.cpu().numpy()[0]topk_indices=topk_indices.cpu().numpy()[0]returntopk_probs,topk_indicesdefprint_result(topk_probs,topk_indices,class_names=None):"""打印最终预测结果"""print("\n"+"="*50)print(f"图像分类预测结果(Top-{TOPK})")print("="*50)ifclass_namesisNone:class_names=[f"类别_{i}"foriinrange(NUM_CLASSES)]fori,(idx,prob)inenumerate(zip(topk_indices,topk_probs)):print(f"Top{i+1}:{class_names[idx]:<10}| 置信度:{prob:.4f}({prob*100:.2f}%)")print("="*50)if__name__=='__main__':# ========== 1. 初始化预处理工具 ==========transform=get_convnext_preprocess()print("图像预处理配置完成")# ========== 2. 加载本地模型 ==========print(f"正在加载本地模型:{MODEL_WEIGHT_PATH}")model=load_local_model(MODEL_WEIGHT_PATH,NUM_CLASSES)print("模型加载完成,已切换到评估模式")# ========== 3. 图像预处理 ==========print(f"正在处理图像:{TEST_IMAGE_PATH}")raw_image,input_tensor=image_preprocess(TEST_IMAGE_PATH,transform)print("图像预处理完成")# ========== 4. 模型推理 ==========print("开始模型推理...")topk_probs,topk_indices=model_infer(model,input_tensor)print("推理完成")# ========== 5. 打印结果 ==========# 替换为真实类别名称# 示例:custom_classes = ["cat", "dog", "car", "tree"]custom_classes=Noneprint_result(topk_probs,topk_indices,custom_classes)输出
图像预处理配置完成 正在加载本地模型:convnext_tiny-983f1562.pth 模型加载完成,已切换到评估模式 正在处理图像:test.jpg 图像预处理完成 开始模型推理... 推理完成==================================================图像分类预测结果(Top-5)==================================================Top1:类别_209|置信度:0.3588(35.88%)Top2:类别_227|置信度:0.2372(23.72%)Top3:类别_168|置信度:0.0283(2.83%)Top4:类别_208|置信度:0.0222(2.22%)Top5:类别_166|置信度:0.0149(1.49%)==================================================初始化准备 → 模型加载 → 图像预处理 → 前向推理 → 后处理输出
整体推理总流程
从__main__入口开始,代码会按顺序执行 5 个步骤:
- 初始化图像预处理规则
- 加载本地权重,构建并初始化 ConvNeXt 模型
- 读取测试图像,完成标准化预处理
- 输入模型执行前向推理,得到原始预测值
- 对预测结果做后处理,格式化打印 Top-K 分类结果
分阶段说明
阶段1:全局配置与预处理初始化
对应代码:顶部配置项 +get_convnext_preprocess()
这是推理的前置准备,目标是保证推理时的图像处理逻辑和模型训练时完全一致,否则会出现精度大幅下降。
固定全局超参数
指定权重路径、测试图路径、分类类别数、输出 Top-K 数量;
自动选择运行设备:有 GPU 则用 CUDA 加速,没有则自动回退到 CPU,保证代码可跨环境运行。定义标准预处理流水线
按顺序执行 4 步操作,和 ImageNet 训练流程严格对齐:Resize(256):把图像短边缩放到 256 像素,保持宽高比不变;CenterCrop(224):从缩放后的图像中心裁剪出 224×224 的正方形,匹配 ConvNeXt-Tiny 的标准输入尺寸;ToTensor():把 PIL 图像(0~255 整数,HWC 格式)转为 PyTorch 张量(0~1 浮点数,CHW 格式);Normalize:用 ImageNet 数据集的均值和标准差做通道归一化,把数据分布对齐到训练时的分布,是保证精度的关键一步。
阶段2:模型加载与初始化
对应函数:load_local_model()
目标是构建和权重匹配的网络结构,载入预训练参数,切换到推理专用状态。
构建模型骨架
调用convnext_tiny(num_classes=1000),按照我们之前分析的 CNBlock + 四阶段结构,实例化一个随机初始化的 ConvNeXt-Tiny 网络。载入本地权重文件
用torch.load读取.pth权重文件,map_location=DEVICE保证权重自动加载到对应设备,避免 GPU 权重在 CPU 上报错;model.load_state_dict(checkpoint, strict=False)把权重参数填入模型骨架。strict=False是容错设计:如果修改了分类头类别数、或有少量层不匹配,不会直接报错,只加载匹配的层,适配微调后的模型推理。
切换到推理模式
model.to(DEVICE):把模型整体迁移到 GPU/CPU;model.eval():开启评估模式,这是推理的必要操作。
它会关闭所有训练专用的随机组件:比如StochasticDepth(随机深度)、Dropout 都会停止工作,保证每次推理结果稳定、可复现;同时 BatchNorm/LayerNorm 会用训练好的全局统计量,而非 batch 统计量。
阶段3:图像预处理与格式对齐
对应函数:image_preprocess()
目标是把一张普通图片,转换成模型能接收的标准输入张量。
图像读取与格式统一
Image.open(...).convert("RGB"):强制转为 3 通道 RGB 格式,兼容灰度图、带透明通道的 RGBA 图,避免通道数不匹配报错。应用预处理流水线
调用阶段1定义的 transform,完成缩放、裁剪、转张量、归一化,输出形状为[3, 224, 224]的张量(通道 × 高 × 宽)。补充 Batch 维度
unsqueeze(0)把形状从[C, H, W]扩展为[1, C, H, W]。
PyTorch 模型默认接受批量输入,第一维是 batch 大小;单张图推理也要补一个 batch=1 的维度,否则维度不匹配会报错。数据迁移到对应设备
把图像张量移到和模型相同的设备(GPU/CPU),保证运算时设备一致。
阶段4:模型前向推理
对应函数:model_infer()
这是计算环节,执行前向传播得到预测结果,同时做推理优化。
禁用梯度计算
with torch.no_grad():是推理的标准写法:
推理不需要反向传播,关闭梯度计算可以大幅节省显存、提升推理速度;
避免推理过程中意外修改模型参数。前向传播,得到原始输出
outputs = model(input_tensor)调用模型的forward方法,数据依次经过 Stem → 4个Stage → 全局平均池化 → 分类头,最终输出形状为[1, 1000]的原始预测值(logits),数值没有归一化,不代表概率。后处理:转概率 + 取Top-K
torch.softmax(outputs, dim=1):对 1000 个类别的原始输出做归一化,得到 0~1 之间的概率值,所有类别概率和为 1;torch.topk(probabilities, k=5):从 1000 个类别中选出置信度最高的 5 个,返回对应的概率值和类别索引。结果转numpy格式
把张量从 GPU 移回 CPU,转为 numpy 数组,方便后续打印输出。
阶段5:结果格式化输出
对应函数:print_result()
把 Top-K 的索引和置信度转换成可读的结果打印出来:
如果传入了自定义类别名称列表,就显示具体类别名;
默认显示类别_索引的占位符,同时打印置信度的小数和百分比形式,直观展示预测可信度。