基于知识蒸馏学习的轻量化高光谱图像分类模型代码 Pytorch制作 教师模型采用Resnet18,学生模型是对教师模型进行改进的轻量化模型,外加最新的注意力机制模块。 在一定基础上,可以超过教师模型。 全套项目,包含网络模型,训练代码,预测代码,直接下载数据集就能跑,拿上就能用,简单又省事儿 需要讲解另算) 内附indian pines数据集,采用30%的数据作为训练集,并附上迭代10次的模型结果,准确率90以上。
项目简介
本项目是一个基于知识蒸馏(Knowledge Distillation)技术的高光谱图像分类系统,采用教师-学生网络架构,实现高效的图像分类任务。系统包含完整的数据预处理、模型训练、知识蒸馏和预测推理流程。
核心文件结构
1. utils.py - 数据预处理工具集
主要功能:
- PCA降维:
applyPCA()函数对高光谱数据进行主成分分析,将光谱维度从上百个波段压缩到指定数量(如3个) - 边缘填充:
padWithZeros()函数对图像边缘进行零填充,便于后续提取图像块 - 图像块提取:
createImageCubes()函数以每个像素为中心提取指定大小的局部图像块(如25×25) - 数据划分:
splitTrainTestSet()函数按比例划分训练集和测试集,支持分层抽样 - 数据加载:提供两种数据加载方式,支持ENVI格式和MATLAB格式的高光谱数据
关键参数:
windowSize=25:图像块大小pca_components=3:PCA降维后的主成分数量testRatio=0.7:测试集比例
2. teacher.py - 教师网络
网络架构:基于ResNet-18的标准卷积神经网络
- 基础模块:
basicblockteacher使用标准3×3卷积 - 网络结构:
- 输入层:7×7卷积,步长2,最大池化
- 4个残差层:通道数[32, 64, 128, 256],每层2个残差块
- 分类层:全局平均池化 + 全连接层
- 特点:参数量较大,精度高,作为知识蒸馏的教师模型
3. student.py - 学生网络
轻量化改进:
- 异构卷积:
Conv2db类结合分组卷积和点卷积,减少计算量 - 坐标注意力机制:
CoordAtt模块增强特征表达能力 - 网络结构:在ResNet基础上替换部分标准卷积为轻量化模块
核心组件:
hsigmoid和hswish:激活函数CoordAtt:坐标注意力机制,通过空间注意力提升特征质量
4. train.py - 训练框架
三种训练模式:
基于知识蒸馏学习的轻量化高光谱图像分类模型代码 Pytorch制作 教师模型采用Resnet18,学生模型是对教师模型进行改进的轻量化模型,外加最新的注意力机制模块。 在一定基础上,可以超过教师模型。 全套项目,包含网络模型,训练代码,预测代码,直接下载数据集就能跑,拿上就能用,简单又省事儿 需要讲解另算) 内附indian pines数据集,采用30%的数据作为训练集,并附上迭代10次的模型结果,准确率90以上。
1. 单独训练教师网络
trainer.train_teacher()- 使用交叉熵损失训练
- 保存为
teachernetalonekl
2. 单独训练学生网络(注释状态)
# trainer.train_student()- 基准对比实验
3. 知识蒸馏训练
trainer.train_teacher_student()- 核心蒸馏损失函数:
def distillation(self, y, labels, teacher_scores, temp, alpha): return self.KLDivLoss(F.log_softmax(y/temp, dim=1), F.softmax(teacher_scores/temp, dim=1)) * (temp*temp*2.0*alpha) + \ F.cross_entropy(y, labels) * (1. - alpha)- 结合KL散度损失和交叉熵损失
- 温度参数
temp=10,平衡参数alpha=0.9
训练配置:
- epoch:10
- 学习率:0.001
- 批大小:128
- 设备:自动检测CUDA
5. pre.py - 预测推理
功能流程:
- 加载训练好的学生网络模型(
teacherstudentnet.pkl) - 对输入高光谱数据进行PCA处理和边缘填充
- 逐像素提取图像块并进行分类预测
- 可视化显示分类结果
关键特性:
- 跳过标签为0的背景像素
- 实时显示处理进度
- 使用spectral库进行结果可视化
数据处理流程
原始数据
- 输入:Indianpines.mat(高光谱图像)和Indianpines_gt.mat(地面真实标签)
- 维度:145×145像素,224个光谱波段
处理步骤
- PCA降维:224波段 → 3个主成分
- 边缘填充:处理边界像素问题
- 图像块提取:生成25×25×3的图像块
- 数据转置:适应PyTorch的(N, C, H, W)格式
- 数据划分:70%测试集,分层抽样
系统特点
1. 知识蒸馏优势
- 模型压缩:将教师网络知识迁移到更小的学生网络
- 性能保持:学生网络接近教师网络的准确率
- 推理加速:轻量化网络更适合部署
2. 技术创新
- 异构卷积:平衡计算效率和特征提取能力
- 坐标注意力:提升空间特征感知能力
- 多损失函数:结合蒸馏损失和分类损失
3. 工程实践
- 完整流程:从数据预处理到模型部署的全流程实现
- 模块化设计:各功能模块独立,便于维护和扩展
- 可视化支持:训练过程监控和结果可视化
应用场景
本系统适用于各种高光谱图像分类任务,如:
- 农业遥感监测
- 地质矿物识别
- 环境变化检测
- 军事目标识别
使用说明
- 准备高光谱数据文件
- 运行train.py进行模型训练
- 使用pre.py进行预测推理
- 查看生成的分类结果图
该系统通过知识蒸馏技术实现了高精度与高效率的平衡,为高光谱图像分类提供了一套完整的解决方案。