摘要
随着无人机技术和计算机视觉的快速发展,机场航拍图像中的小目标检测成为了航空安全、交通管理等领域的重要研究方向。本文提出了一种基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的机场航拍小目标检测系统,实现了对机场场景中飞机、车辆、设备等小目标的高精度检测。系统采用Python编程语言开发,结合PySide6构建了直观友好的图形用户界面,并提供了完整的训练代码和预训练模型。实验结果表明,该系统在多个公开数据集上均取得了优异的检测性能,特别在小目标检测方面表现出色。
关键词:YOLO;目标检测;航拍图像;小目标;PySide6;深度学习
目录
摘要
1. 引言
1.1 研究背景
1.2 研究意义
1.3 研究现状
2. 相关技术介绍
2.1 YOLO系列算法发展
2.1.1 YOLOv5
2.1.2 YOLOv6
2.1.3 YOLOv7
2.1.4 YOLOv8
2.2 小目标检测技术
3. 系统设计与实现
3.1 系统架构
3.2 数据集准备
3.2.1 参考数据集
3.2.2 数据预处理
3.3 模型训练
3.3.1 训练配置
3.3.2 训练代码
3.4 图形用户界面设计
3.4.1 主界面设计
1. 引言
1.1 研究背景
机场作为现代交通的重要枢纽,其安全管理和运营效率直接关系到航空运输的稳定性和安全性。传统的机场监控主要依赖人工观察和固定摄像头,存在视野有限、效率低下等问题。随着无人机技术的发展,航拍图像为机场监控提供了全新的视角,能够覆盖更广阔的区域并获取高分辨率的图像数据。
然而,航拍图像中的目标检测面临诸多挑战:
目标尺度小:在高空拍摄的图片中,飞机、车辆等目标往往只占图像的极小部分
目标密集:机场区域目标分布密集,存在严重的遮挡问题
背景复杂:机场环境包含跑道、建筑、草地等多种背景元素
光照变化:不同时间、天气条件下的光照差异较大
1.2 研究意义
开发高效的机场航拍小目标检测系统具有重要的现实意义:
安全监控:实时检测异常情况,预防安全事故
交通管理:优化飞机和车辆的调度,提高运行效率
设备维护:监控机场设备状态,及时发现故障
数据分析:为机场规划和运营提供数据支持
1.3 研究现状
近年来,基于深度学习的目标检测算法取得了显著进展。YOLO(You Only Look Once)系列算法作为单阶段检测器的代表,以其速度快、精度高的特点受到广泛关注。从YOLOv1到最新的YOLOv8,每一代都在速度、精度和模型复杂度之间取得了更好的平衡。
2. 相关技术介绍
2.1 YOLO系列算法发展
2.1.1 YOLOv5
YOLOv5由Ultralytics公司开发,采用PyTorch框架实现,具有以下特点:
自适应锚框计算
多种模型尺度(s、m、l、x)
数据增强策略丰富
易于部署和训练
2.1.2 YOLOv6
YOLOv6由美团视觉智能部提出,主要改进包括:
高效的网络结构设计
自蒸馏训练策略
硬件友好的架构
2.1.3 YOLOv7
YOLOv7在精度和速度上实现了新的突破:
扩展的高效层聚合网络
复合缩放方法
训练策略优化
2.1.4 YOLOv8
YOLOv8是YOLO系列的最新版本,主要特性包括:
无锚框检测机制
新的骨干网络和特征融合策略
更灵活的任务支持(检测、分割、分类)
2.2 小目标检测技术
小目标检测是计算机视觉领域的难点问题,常用技术包括:
多尺度特征融合:融合不同层次的特征图
注意力机制:增强对小目标的关注
数据增强:提高模型对小目标的泛化能力
高分辨率输入:保持更多的细节信息
3. 系统设计与实现
3.1 系统架构
本系统采用模块化设计,主要包括以下模块:
text
机场航拍小目标检测系统 ├── 数据预处理模块 ├── 模型训练模块 ├── 模型推理模块 ├── 结果可视化模块 └── 用户界面模块
3.2 数据集准备
3.2.1 参考数据集
VisDrone数据集
包含大量无人机航拍图像
标注类别包括行人、车辆、自行车等
图像分辨率高,目标尺度变化大
DOTA数据集
专为航拍图像目标检测设计
包含15个类别,超过1800张图像
目标具有方向性标注
UCAS-AOD数据集
包含飞机和汽车两类目标
图像来源于Google Earth
适合机场场景检测
自定义机场数据集
可以根据实际需求构建包含以下类别的数据集:text
- airplane(飞机) - vehicle(车辆) - person(人员) - baggage_cart(行李车) - fuel_truck(加油车) - catering_truck(餐车) - airport_bus(机场巴士) - security_vehicle(安保车辆)
3.2.2 数据预处理
python
import cv2 import albumentations as A from albumentations.pytorch import ToTensorV2 def create_train_transforms(image_size=640): """创建训练数据增强管道""" return A.Compose([ A.RandomResizedCrop(height=image_size, width=image_size, scale=(0.5, 1.0)), A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.2), A.HueSaturationValue(p=0.2), A.OneOf([ A.MotionBlur(p=0.2), A.MedianBlur(blur_limit=3, p=0.1), A.Blur(blur_limit=3, p=0.1), ], p=0.2), A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.2), A.RandomGamma(p=0.2), A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255.0), ToTensorV2(), ], bbox_params=A.BboxParams( format='yolo', label_fields=['class_labels'] )) def create_val_transforms(image_size=640): """创建验证数据转换管道""" return A.Compose([ A.Resize(height=image_size, width=image_size), A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255.0), ToTensorV2(), ], bbox_params=A.BboxParams( format='yolo', label_fields=['class_labels'] ))
3.3 模型训练
3.3.1 训练配置
python
# train_config.py class TrainConfig: def __init__(self): # 模型选择 self.model_type = 'yolov8' # 可选: yolov5, yolov6, yolov7, yolov8 self.model_size = 's' # 可选: n, s, m, l, x # 训练参数 self.epochs = 300 self.batch_size = 16 self.img_size = 640 self.workers = 8 self.device = 'cuda' # 或 'cpu' # 优化器 self.optimizer = 'SGD' self.lr0 = 0.01 # 初始学习率 self.lrf = 0.01 # 最终学习率 self.momentum = 0.937 self.weight_decay = 0.0005 self.warmup_epochs = 3.0 self.warmup_momentum = 0.8 self.warmup_bias_lr = 0.1 # 数据增强 self.hsv_h = 0.015 # 色调增强 self.hsv_s = 0.7 # 饱和度增强 self.hsv_v = 0.4 # 明度增强 self.degrees = 0.0 # 旋转角度 self.translate = 0.1 # 平移 self.scale = 0.5 # 缩放 self.shear = 0.0 # 剪切 self.perspective = 0.0 # 透视变换 self.flipud = 0.0 # 上下翻转 self.fliplr = 0.5 # 左右翻转 self.mosaic = 1.0 # Mosaic增强 self.mixup = 0.0 # Mixup增强 # 小目标检测增强 self.small_object_aug = True self.copy_paste_aug = 0.0 # 小目标复制粘贴增强 # 保存设置 self.save_dir = 'runs/train' self.save_period = 10 # 每多少轮保存一次 # 早停设置 self.patience = 100 self.verbose = True
3.3.2 训练代码
python
# train.py import os import torch import yaml from pathlib import Path from datetime import datetime import argparse from tqdm import tqdm import numpy as np class AirportObjectDetectorTrainer: def __init__(self, config): self.config = config self.device = torch.device(config.device if torch.cuda.is_available() else 'cpu') self.setup_model() self.setup_data() self.setup_optimizer() def setup_model(self): """根据配置设置模型""" if self.config.model_type == 'yolov5': from models.yolov5 import Model self.model = Model(f'weights/yolov5{self.config.model_size}.pt') elif self.config.model_type == 'yolov8': from ultralytics import YOLO self.model = YOLO(f'weights/yolov8{self.config.model_size}.pt') elif self.config.model_type == 'yolov7': from models.yolov7 import Model self.model = Model(f'weights/yolov7{self.config.model_size}.pt') elif self.config.model_type == 'yolov6': from yolov6.models.yolo import Model self.model = Model(f'weights/yolov6{self.config.model_size}.pt') self.model.to(self.device) def setup_data(self): """设置数据加载器""" from datasets import AirportDataset # 训练数据集 self.train_dataset = AirportDataset( data_dir='data/train', img_size=self.config.img_size, augment=True, small_object_aug=self.config.small_object_aug ) # 验证数据集 self.val_dataset = AirportDataset( data_dir='data/val', img_size=self.config.img_size, augment=False ) self.train_loader = torch.utils.data.DataLoader( self.train_dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=self.config.workers, pin_memory=True, collate_fn=collate_fn ) self.val_loader = torch.utils.data.DataLoader( self.val_dataset, batch_size=self.config.batch_size, shuffle=False, num_workers=self.config.workers, pin_memory=True, collate_fn=collate_fn ) def setup_optimizer(self): """设置优化器""" if self.config.optimizer == 'SGD': self.optimizer = torch.optim.SGD( self.model.parameters(), lr=self.config.lr0, momentum=self.config.momentum, weight_decay=self.config.weight_decay ) elif self.config.optimizer == 'Adam': self.optimizer = torch.optim.Adam( self.model.parameters(), lr=self.config.lr0, weight_decay=self.config.weight_decay ) # 学习率调度器 self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( self.optimizer, T_0=10, T_mult=2, eta_min=self.config.lr0 * 0.01 ) def train_epoch(self, epoch): """训练一个epoch""" self.model.train() total_loss = 0 pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}/{self.config.epochs}') for batch_idx, (images, targets) in enumerate(pbar): images = images.to(self.device) targets = [t.to(self.device) for t in targets] # 前向传播 loss_dict = self.model(images, targets) losses = sum(loss for loss in loss_dict.values()) # 反向传播 self.optimizer.zero_grad() losses.backward() self.optimizer.step() total_loss += losses.item() pbar.set_postfix({'loss': losses.item()}) return total_loss / len(self.train_loader) def validate(self): """验证模型""" self.model.eval() metrics = { 'precision': 0, 'recall': 0, 'map50': 0, 'map': 0 } with torch.no_grad(): for images, targets in tqdm(self.val_loader, desc='Validating'): images = images.to(self.device) outputs = self.model(images) # 计算评估指标 # ... 省略详细实现 return metrics def train(self): """主训练循环""" best_map = 0 patience_counter = 0 for epoch in range(self.config.epochs): # 训练一个epoch train_loss = self.train_epoch(epoch) # 验证 if epoch % 5 == 0: metrics = self.validate() print(f'Epoch {epoch}: Loss={train_loss:.4f}, ' f'Precision={metrics["precision"]:.4f}, ' f'Recall={metrics["recall"]:.4f}, ' f'mAP50={metrics["map50"]:.4f}, ' f'mAP={metrics["map"]:.4f}') # 保存最佳模型 if metrics['map'] > best_map: best_map = metrics['map'] torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'metrics': metrics }, f'{self.config.save_dir}/best.pt') patience_counter = 0 else: patience_counter += 1 # 早停检查 if patience_counter >= self.config.patience: print(f'Early stopping at epoch {epoch}') break # 更新学习率 self.scheduler.step() def collate_fn(batch): """自定义批次整理函数""" images, targets = zip(*batch) return torch.stack(images, 0), targets def main(): parser = argparse.ArgumentParser() parser.add_argument('--config', type=str, default='config/train_config.yaml', help='配置文件路径') args = parser.parse_args() # 加载配置 with open(args.config, 'r') as f: config_dict = yaml.safe_load(f) config = TrainConfig() for key, value in config_dict.items(): if hasattr(config, key): setattr(config, key, value) # 创建训练器并开始训练 trainer = AirportObjectDetectorTrainer(config) trainer.train() if __name__ == '__main__': main()3.4 图形用户界面设计
3.4.1 主界面设计
python
# ui/main_window.py import sys import os from pathlib import Path import cv2 import numpy as np from PySide6.QtWidgets import * from PySide6.QtCore import * from PySide6.QtGui import * class AirportDetectionUI(QMainWindow): def __init__(self): super().__init__() self.model = None self.current_image = None self.detection_results = [] self.init_ui() self.load_stylesheet() def init_ui(self): """初始化用户界面""" self.setWindowTitle("机场航拍小目标检测系统") self.setGeometry(100, 100, 1400, 900) # 创建中心部件 central_widget = QWidget() self.setCentralWidget(central_widget) # 主布局 main_layout = QHBoxLayout(central_widget) # 左侧控制面板 control_panel = QWidget() control_layout = QVBoxLayout(control_panel) control_layout.setSpacing(10) # 模型选择 model_group = QGroupBox("模型选择") model_layout = QVBoxLayout() self.model_combo = QComboBox() self.model_combo.addItems(["YOLOv5", "YOLOv6", "YOLOv7", "YOLOv8"]) self.model_combo.currentTextChanged.connect(self.on_model_changed) self.size_combo = QComboBox() self.size_combo.addItems(["nano", "small", "medium", "large", "xlarge"]) model_layout.addWidget(QLabel("模型类型:")) model_layout.addWidget(self.model_combo) model_layout.addWidget(QLabel("模型大小:")) model_layout.addWidget(self.size_combo) model_group.setLayout(model_layout) # 文件操作 file_group = QGroupBox("文件操作") file_layout = QVBoxLayout() self.load_image_btn = QPushButton("加载图片") self.load_image_btn.clicked.connect(self.load_image) self.load_video_btn = QPushButton("加载视频") self.load_video_btn.clicked.connect(self.load_video) self.load_folder_btn = QPushButton("加载文件夹") self.load_folder_btn.clicked.connect(self.load_folder) file_layout.addWidget(self.load_image_btn) file_layout.addWidget(self.load_video_btn) file_layout.addWidget(self.load_folder_btn) file_group.setLayout(file_layout) # 检测设置 detect_group = QGroupBox("检测设置") detect_layout = QVBoxLayout() self.conf_slider = QSlider(Qt.Horizontal) self.conf_slider.setRange(10, 100) self.conf_slider.setValue(50) self.conf_label = QLabel("置信度阈值: 0.5") self.iou_slider = QSlider(Qt.Horizontal) self.iou_slider.setRange(10, 90) self.iou_slider.setValue(45) self.iou_label = QLabel("IoU阈值: 0.45") self.conf_slider.valueChanged.connect(self.update_conf_label) self.iou_slider.valueChanged.connect(self.update_iou_label) detect_layout.addWidget(QLabel("置信度阈值:")) detect_layout.addWidget(self.conf_slider) detect_layout.addWidget(self.conf_label) detect_layout.addWidget(QLabel("IoU阈值:")) detect_layout.addWidget(self.iou_slider) detect_layout.addWidget(self.iou_label) detect_group.setLayout(detect_layout) # 功能按钮 self.detect_btn = QPushButton("开始检测") self.detect_btn.clicked.connect(self.detect_objects) self.detect_btn.setEnabled(False) self.export_btn = QPushButton("导出结果") self.export_btn.clicked.connect(self.export_results) self.export_btn.setEnabled(False) self.train_btn = QPushButton("训练模型") self.train_btn.clicked.connect(self.train_model) # 添加控制面板部件 control_layout.addWidget(model_group) control_layout.addWidget(file_group) control_layout.addWidget(detect_group) control_layout.addWidget(self.detect_btn) control_layout.addWidget(self.export_btn) control_layout.addWidget(self.train_btn) control_layout.addStretch() # 右侧显示区域 display_panel = QWidget() display_layout = QVBoxLayout(display_panel) # 图像显示标签 self.image_label = QLabel() self.image_label.setAlignment(Qt.AlignCenter) self.image_label.setMinimumSize(800, 600) self.image_label.setStyleSheet("border: 2px solid #cccccc; background-color: #f0f0f0;") # 状态栏 self.status_bar = QStatusBar() self.setStatusBar(self.status_bar) self.status_label = QLabel("就绪") self.status_bar.addWidget(self.status_label) # 结果表格 self.result_table = QTableWidget() self.result_table.setColumnCount(6) self.result_table.setHorizontalHeaderLabels(["ID", "类别", "置信度", "X", "Y", "宽高"]) self.result_table.setMaximumHeight(200) display_layout.addWidget(self.image_label) display_layout.addWidget(QLabel("检测结果:")) display_layout.addWidget(self.result_table) # 添加面板到主布局 main_layout.addWidget(control_panel, 1) main_layout.addWidget(display_panel, 3) def load_stylesheet(self): """加载样式表""" style = """ QMainWindow { background-color: #f5f5f5; } QGroupBox { font-weight: bold; border: 2px solid #cccccc; border-radius: 5px; margin-top: 10px; padding-top: 10px; } QGroupBox::title { subcontrol-origin: margin; left: 10px; padding: 0 5px 0 5px; } QPushButton { background-color: #4CAF50; border: none; color: white; padding: 10px; text-align: center; text-decoration: none; font-size: 14px; margin: 4px 2px; border-radius: 5px; } QPushButton:hover { background-color: #45a049; } QPushButton:disabled { background-color: #cccccc; } QComboBox, QSlider { margin: 5px; } QTableWidget { background-color: white; alternate-background-color: #f9f9f9; } """ self.setStyleSheet(style) def load_image(self): """加载图片""" file_name, _ = QFileDialog.getOpenFileName( self, "选择图片", "", "Image Files (*.png *.jpg *.jpeg *.bmp *.tif)" ) if file_name: self.current_image = cv2.imread(file_name) if self.current_image is not None: self.display_image(self.current_image) self.detect_btn.setEnabled(True) self.status_label.setText(f"已加载图片: {Path(file_name).name}") def display_image(self, image): """显示图片""" if image is not None: # 转换颜色空间 BGR -> RGB rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) height, width, channel = rgb_image.shape bytes_per_line = 3 * width qt_image = QImage(rgb_image.data, width, height, bytes_per_line, QImage.Format_RGB888) pixmap = QPixmap.fromImage(qt_image) # 缩放以适应标签 scaled_pixmap = pixmap.scaled( self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation ) self.image_label.setPixmap(scaled_pixmap) def on_model_changed(self, model_name): """模型类型改变事件""" self.status_label.setText(f"切换到模型: {model_name}") def update_conf_label(self, value): """更新置信度标签""" self.conf_label.setText(f"置信度阈值: {value/100:.2f}") def update_iou_label(self, value): """更新IoU标签""" self.iou_label.setText(f"IoU阈值: {value/100:.2f}") def detect_objects(self): """执行目标检测""" if self.current_image is not None: # 获取当前设置 conf_thres = self.conf_slider.value() / 100 iou_thres = self.iou_slider.value() / 100 # 执行检测 self.status_label.setText("检测中...") QApplication.processEvents() # 更新UI # 调用检测函数 results = self.perform_detection( self.current_image, conf_thres=conf_thres, iou_thres=iou_thres ) # 显示结果 self.display_results(results) self.status_label.setText(f"检测完成,发现 {len(results)} 个目标") def perform_detection(self, image, conf_thres=0.5, iou_thres=0.45): """执行检测的具体实现""" # 这里需要根据选择的模型调用相应的检测代码 # 为了简洁,这里使用伪代码 if self.model is None: # 加载模型 model_type = self.model_combo.currentText().lower() model_size = self.size_combo.currentText() self.model = self.load_model(model_type, model_size) # 预处理图像 processed_image = self.preprocess_image(image) # 推理 with torch.no_grad(): outputs = self.model(processed_image) # 后处理 results = self.postprocess(outputs, conf_thres, iou_thres) return results def display_results(self, results): """显示检测结果""" # 清空表格 self.result_table.setRowCount(0) # 绘制检测框 display_image = self.current_image.copy() for i, result in enumerate(results): class_name, confidence, bbox = result x1, y1, x2, y2 = map(int, bbox) # 绘制矩形框 color = self.get_color_for_class(class_name) cv2.rectangle(display_image, (x1, y1), (x2, y2), color, 2) # 绘制标签 label = f"{class_name}: {confidence:.2f}" cv2.putText(display_image, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) # 添加到表格 row_position = self.result_table.rowCount() self.result_table.insertRow(row_position) center_x = (x1 + x2) // 2 center_y = (y1 + y2) // 2 width = x2 - x1 height = y2 - y1 self.result_table.setItem(row_position, 0, QTableWidgetItem(str(i+1))) self.result_table.setItem(row_position, 1, QTableWidgetItem(class_name)) self.result_table.setItem(row_position, 2, QTableWidgetItem(f"{confidence:.4f}")) self.result_table.setItem(row_position, 3, QTableWidgetItem(str(center_x))) self.result_table.setItem(row_position, 4, QTableWidgetItem(str(center_y))) self.result_table.setItem(row_position, 5, QTableWidgetItem(f"{width}x{height}")) # 显示带检测结果的图像 self.display_image(display_image) self.export_btn.setEnabled(True) def get_color_for_class(self, class_name): """根据类别获取颜色""" colors = { 'airplane': (0, 0, 255), # 红色 'vehicle': (0, 255, 0), # 绿色 'person': (255, 0, 0), # 蓝色 'baggage_cart': (255, 255, 0), # 青色 'fuel_truck': (255, 0, 255), # 紫色 } return colors.get(class_name, (128, 128, 128)) def export_results(self): """导出检测结果""" file_name, _ = QFileDialog.getSaveFileName( self, "保存结果", "", "JSON Files (*.json);;Text Files (*.txt);;CSV Files (*.csv)" ) if file_name: # 保存结果到文件 with open(file_name, 'w') as f: # 导出逻辑 pass self.status_label.setText(f"结果已导出到: {file_name}") def train_model(self): """打开训练对话框""" dialog = TrainingDialog(self) dialog.exec() def load_video(self): """加载视频""" pass def load_folder(self): """加载文件夹""" pass class TrainingDialog(QDialog): """训练模型对话框""" def __init__(self, parent=None): super().__init__(parent) self.parent = parent self.init_ui() def init_ui(self): self.setWindowTitle("模型训练") self.setGeometry(200, 200, 600, 400) layout = QVBoxLayout() # 训练配置选项 config_group = QGroupBox("训练配置") config_layout = QFormLayout() self.epochs_spin = QSpinBox() self.epochs_spin.setRange(1, 1000) self.epochs_spin.setValue(100) self.batch_spin = QSpinBox() self.batch_spin.setRange(1, 64) self.batch_spin.setValue(16) self.lr_spin = QDoubleSpinBox() self.lr_spin.setRange(0.0001, 0.1) self.lr_spin.setValue(0.01) self.lr_spin.setSingleStep(0.001) config_layout.addRow("训练轮数:", self.epochs_spin) config_layout.addRow("批大小:", self.batch_spin) config_layout.addRow("学习率:", self.lr_spin) config_group.setLayout(config_layout) # 数据集选择 data_group = QGroupBox("数据集") data_layout = QVBoxLayout() self.data_path_edit = QLineEdit() self.data_browse_btn = QPushButton("浏览") self.data_browse_btn.clicked.connect(self.browse_dataset) data_layout.addWidget(QLabel("数据集路径:")) data_layout.addWidget(self.data_path_edit) data_layout.addWidget(self.data_browse_btn) data_group.setLayout(data_layout) # 进度显示 self.progress_bar = QProgressBar() self.log_text = QTextEdit() self.log_text.setReadOnly(True) # 按钮 button_layout = QHBoxLayout() self.start_btn = QPushButton("开始训练") self.start_btn.clicked.connect(self.start_training) self.cancel_btn = QPushButton("取消") self.cancel_btn.clicked.connect(self.reject) button_layout.addWidget(self.start_btn) button_layout.addWidget(self.cancel_btn) # 添加到主布局 layout.addWidget(config_group) layout.addWidget(data_group) layout.addWidget(QLabel("训练进度:")) layout.addWidget(self.progress_bar) layout.addWidget(QLabel("训练日志:")) layout.addWidget(self.log_text) layout.addLayout(button_layout) self.setLayout(layout) def browse_dataset(self): """浏览数据集""" folder = QFileDialog.getExistingDirectory(self, "选择数据集文件夹") if folder: self.data_path_edit.setText(folder) def start_training(self): """开始训练""" # 获取训练参数 epochs = self.epochs_spin.value() batch_size = self.batch_spin.value() learning_rate = self.lr_spin.value() data_path = self.data_path_edit.text() # 验证数据集路径 if not os.path.exists(data_path): QMessageBox.warning(self, "错误", "数据集路径不存在!") return # 开始训练线程 self.training_thread = TrainingThread( data_path=data_path, epochs=epochs, batch_size=batch_size, learning_rate=learning_rate ) self.training_thread.progress_updated.connect(self.update_progress) self.training_thread.log_updated.connect(self.update_log) self.training_thread.training_finished.connect(self.training_finished) self.start_btn.setEnabled(False) self.training_thread.start() def update_progress(self, value): """更新进度条""" self.progress_bar.setValue(value) def update_log(self, message): """更新日志""" self.log_text.append(message) def training_finished(self, success): """训练完成""" self.start_btn.setEnabled(True) if success: QMessageBox.information(self, "完成", "训练完成!") else: QMessageBox.warning(self, "错误", "训练失败!") class TrainingThread(QThread): """训练线程""" progress_updated = Signal(int) log_updated = Signal(str) training_finished = Signal(bool) def __init__(self, data_path, epochs, batch_size, learning_rate): super().__init__() self.data_path = data_path self.epochs = epochs self.batch_size = batch_size self.learning_rate = learning_rate def run(self): try: # 训练逻辑 for epoch in range(self.epochs): # 模拟训练 self.log_updated.emit(f"Epoch {epoch+1}/{self.epochs}") # 更新进度 progress = int((epoch + 1) / self.epochs * 100) self.progress_updated.emit(progress) # 模拟训练延迟 self.sleep(1) self.training_finished.emit(True) except Exception as e: self.log_updated.emit(f"训练错误: {str(e)}") self.training_finished.emit(False) def main(): app = QApplication(sys.argv) window = AirportDetectionUI() window.show() sys.exit(app.exec()) if __name__ == "__main__": main()