摘要:稻田虫害是威胁全球粮食安全的关键因素之一。传统的虫害检测方法依赖于农技人员目视检查,效率低下且易出错。本文详细介绍一种基于深度学习目标检测模型YOLO系列(包括YOLOv5、YOLOv8及最新的YOLOv10)的智能稻田虫害检测系统。系统以Python为核心,实现了从数据准备、模型训练与比较、性能评估到最终可交互的Web界面(基于Gradio)部署的全流程。文章不仅提供了完整的代码实现,还探讨了各版本YOLO模型在此特定农业场景下的表现差异,旨在为农业AI应用的研究与开发人员提供一份详尽的实践指南。
1. 引言
粮食安全是人类社会发展的基石。水稻作为全球半数以上人口的主食,其稳定生产至关重要。然而,稻飞虱、二化螟、稻纵卷叶螟等害虫每年导致巨大的产量损失。及时、准确地识别害虫是进行精准防控的前提。
近年来,以卷积神经网络(CNN)为代表的深度学习技术在计算机视觉领域取得了革命性进展。YOLO(You Only Look Once)系列模型因其在精度与速度间的卓越平衡,成为实时目标检测的事实标准。本系统旨在利用YOLOv5(成熟稳定)、YOLOv8(SOTA功能丰富)和YOLOv10(最新无NMS设计)构建一个端到端的虫害检测解决方案,并通过友好的Web界面降低使用门槛,赋能一线农技工作者。
2. 系统架构与技术栈
整个系统可分为四大核心模块:
数据预处理模块:处理原始图像,进行标注格式转换、数据集划分与增强。
模型训练与验证模块:支持YOLOv5/v8/v10模型的训练、验证与性能评估。
模型推理模块:加载训练好的最佳权重,对单张图片、批量图片或视频进行害虫检测。
Web交互界面模块:基于Gradio构建,允许用户上传图片并可视化检测结果。
技术栈:
深度学习框架:PyTorch
目标检测模型:YOLOv5, YOLOv8 (ultralytics), YOLOv10 (近期发布)
编程语言:Python 3.8+
Web框架:Gradio (快速构建ML演示界面)
数据处理:OpenCV, Pillow, Pandas
可视化:Matplotlib, Seaborn
3. 数据集准备与增强
3.1 数据收集与标注
我们使用一个自构建的稻田害虫图像数据集,包含Rice_Leafhopper(稻叶蝉)、Rice_Stem_Borer(二化螟)、Rice_Plant_Hopper(稻飞虱)等类别。图像通过高清摄像头在田间和实验室环境下采集。
使用标注工具LabelImg或Roboflow进行边界框标注,生成PASCAL VOC格式(XML)或YOLO格式(每张图片一个.txt文件,内容为class_id x_center y_center width_height,坐标已归一化)。
3.2 数据集结构
项目采用标准YOLO格式目录结构:
text
datasets/ └── rice_pests/ ├── train/ │ ├── images/ # 存放训练图片 │ └── labels/ # 存放对应的YOLO格式标签文件 ├── val/ │ ├── images/ │ └── labels/ └── test/ ├── images/ └── labels/
3.3 数据增强策略
为提高模型鲁棒性,采用在线数据增强(在训练时由数据加载器实时完成)。常用增强包括:随机翻转(水平、垂直)、随机旋转、亮度/对比度调整、马赛克增强(Mosaic)和混合增强(MixUp)。这些增强在YOLO的训练配置中直接启用。
关键代码:创建数据集配置文件data/rice_pests.yaml
yaml
# YOLO数据集配置文件 path: ../datasets/rice_pests # 数据集根目录 train: train/images # 训练集路径(相对path) val: val/images # 验证集路径 test: test/images # 测试集路径(可选) # 类别数量与名称 nc: 3 # number of classes names: ['Rice_Leafhopper', 'Rice_Stem_Borer', 'Rice_Plant_Hopper'] # 下载地址/说明(可选) # download: https://your-dataset-url.com
4. YOLO模型训练与比较
4.1 环境配置
创建虚拟环境并安装依赖。
bash
# 创建并激活环境(以conda为例) conda create -n rice_pest_detection python=3.8 conda activate rice_pest_detection # 安装PyTorch (请根据CUDA版本选择) pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 安装YOLOv5, YOLOv8, YOLOv10及其他依赖 pip install ultralytics # 包含YOLOv8,也支持YOLOv5/v10的部分接口 # YOLOv5 需要单独克隆其官方仓库 git clone https://github.com/ultralytics/yolov5.git cd yolov5 pip install -r requirements.txt cd .. # 安装其他库 pip install opencv-python pillow pandas matplotlib seabron scikit-learn gradio
4.2 YOLOv5训练
YOLOv5提供了清晰的命令行接口和预训练模型。
python
# train_yolov5.py import os import subprocess def train_yolov5(): """ 使用YOLOv5官方脚本进行训练 """ # 切换到yolov5目录 yolov5_dir = './yolov5' os.chdir(yolov5_dir) # 训练命令 # 使用预训练模型yolov5s.pt, 输入图像大小为640, batch size为16, 训练100个epochs cmd = [ 'python', 'train.py', '--data', '../data/rice_pests.yaml', # 数据集配置路径 '--weights', 'yolov5s.pt', # 预训练权重 '--img', '640', # 训练图像大小 '--batch', '16', '--epochs', '100', '--project', '../runs/train', # 输出目录 '--name', 'yolov5s_rice_pests', '--exist-ok', # 允许覆盖已存在的结果 '--device', '0', # GPU ID, 'cpu' 或 '0,1,2,3' ] subprocess.run(cmd) # 切换回主目录 os.chdir('..') if __name__ == '__main__': train_yolov5()4.3 YOLOv8训练
YOLOv8的API更为统一和简洁。
python
# train_yolov8.py from ultralytics import YOLO def train_yolov8(): # 加载一个预训练模型 (推荐使用最新版本) model = YOLO('yolov8s.pt') # 可以使用 yolov8n.pt, yolov8m.pt, yolov8l.pt, yolov8x.pt # 训练模型 results = model.train( data='data/rice_pests.yaml', # 数据集配置文件路径 epochs=100, imgsz=640, batch=16, device='0', # 或 'cpu' 或 [0, 1] project='runs/train', name='yolov8s_rice_pests', exist_ok=True, pretrained=True, optimizer='AdamW', # 可选优化器 lr0=0.01, # 初始学习率 # 更多超参数... ) # 验证模型在验证集上的性能 metrics = model.val() print(metrics.box.map) # 打印mAP50-95 return model if __name__ == '__main__': model = train_yolov8()4.4 YOLOv10训练
YOLOv10是2024年发布的最新版本,主要特点是无需NMS的后处理,推理速度更快。其训练方式与YOLOv8非常相似。
python
# train_yolov10.py from ultralytics import YOLO def train_yolov10(): # 注意:需要确保ultralytics版本支持YOLOv10,或从源码安装 # pip install git+https://github.com/THU-MIG/yolov10.git model = YOLO('yolov10s.pt') # 或 'yolov10n.pt', 'yolov10m.pt'等 results = model.train( data='data/rice_pests.yaml', epochs=100, imgsz=640, batch=16, device='0', project='runs/train', name='yolov10s_rice_pests', exist_ok=True, ) # 验证 metrics = model.val() return model if __name__ == '__main__': train_yolov10()4.5 训练结果分析与模型比较
训练完成后,在runs/train/目录下会生成包含所有训练日志、权重和可视化结果的文件夹。关键文件包括:
weights/best.pt: 验证集上表现最佳的权重。results.png: 训练损失、精度等指标曲线图。confusion_matrix.png: 混淆矩阵。val_batchX_labels.jpg: 验证集预测示例。
我们可以编写一个评估脚本来比较三个模型的性能:
python
# evaluate_models.py import matplotlib.pyplot as plt import pandas as pd import yaml from pathlib import Path def plot_training_results(exp_paths, model_names): """ 绘制多个模型的训练结果对比图 """ fig, axes = plt.subplots(2, 3, figsize=(15, 10)) axes = axes.flatten() metrics = ['train/box_loss', 'train/cls_loss', 'train/dfl_loss', 'val/box_loss', 'val/cls_loss', 'metrics/mAP50-95(B)'] for idx, metric in enumerate(metrics): ax = axes[idx] ax.set_title(metric) ax.set_xlabel('Epoch') ax.grid(True) for exp_path, name in zip(exp_paths, model_names): results_csv = Path(exp_path) / 'results.csv' if results_csv.exists(): df = pd.read_csv(results_csv) # 注意:列名可能包含空格,需要处理 col_name = metric if metric in df.columns: ax.plot(df[metric].values, label=name) else: # 尝试查找相似的列名 for col in df.columns: if metric.replace('/', ' ').strip() in col: ax.plot(df[col].values, label=name) break if idx == 0: ax.legend() plt.tight_layout() plt.savefig('model_comparison.png', dpi=300) plt.show() if __name__ == '__main__': # 假设训练结果保存在以下路径 exp_paths = [ 'runs/train/yolov5s_rice_pests', 'runs/train/yolov8s_rice_pests', 'runs/train/yolov10s_rice_pests' ] model_names = ['YOLOv5s', 'YOLOv8s', 'YOLOv10s'] plot_training_results(exp_paths, model_names)5. 完整的推理与Web界面代码
我们将创建一个主程序,集成模型加载、推理和Gradio界面。
python
# main_app.py import gradio as gr import cv2 import torch import numpy as np from pathlib import Path import matplotlib.pyplot as plt from PIL import Image, ImageDraw, ImageFont import tempfile import warnings warnings.filterwarnings('ignore') # 尝试导入不同版本的YOLO try: from ultralytics import YOLO as YOLOv8_10 ULTRA_AVAILABLE = True except ImportError: ULTRA_AVAILABLE = False # 假设YOLOv5的推理代码需要自定义(或使用其仓库中的detect.py) # 这里我们简化,使用一个统一的接口,实际中你可能需要适配 class PestDetector: """ 统一的害虫检测器,支持加载不同版本的YOLO模型 """ def __init__(self, model_path, model_type='yolov8'): """ 初始化检测器 Args: model_path: 模型权重文件路径 (.pt) model_type: 模型类型, 'yolov5', 'yolov8', 或 'yolov10' """ self.model_type = model_type self.model_path = model_path self.device = 'cuda' if torch.cuda.is_available() else 'cpu' if model_type in ['yolov8', 'yolov10'] and ULTRA_AVAILABLE: self.model = YOLOv8_10(model_path) self.use_ultralytics = True elif model_type == 'yolov5': # 对于YOLOv5,我们可以使用其自带的加载方式,这里简化处理 # 实际应用中,你可能需要导入yolov5的模型定义 try: import sys sys.path.append('./yolov5') # 添加yolov5路径 from models.experimental import attempt_load from utils.general import non_max_suppression from utils.torch_utils import select_device self.yolov5_model = attempt_load(model_path, device=select_device(self.device)) self.yolov5_model.eval() self.use_ultralytics = False except ImportError: raise ImportError("请确保YOLOv5仓库已克隆并放置在正确位置。") else: raise ValueError(f"不支持的模型类型: {model_type}") # 类别颜色和名称 (应与训练时一致) self.class_names = ['Rice_Leafhopper', 'Rice_Stem_Borer', 'Rice_Plant_Hopper'] self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] # 红,绿,蓝 def detect(self, image, conf_threshold=0.3, iou_threshold=0.5): """ 执行目标检测 Args: image: PIL Image 或 numpy array conf_threshold: 置信度阈值 iou_threshold: IOU阈值(用于NMS,YOLOv10可能不需要) Returns: annotated_image: 绘制了边界框的PIL Image detections: 检测结果列表,每个元素为 [class_name, confidence, [x1, y1, x2, y2]] """ # 转换输入为RGB numpy数组 if isinstance(image, Image.Image): img_np = np.array(image) img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) else: img_np = image.copy() original_h, original_w = img_np.shape[:2] if self.use_ultralytics: # 使用Ultralytics (YOLOv8/v10) 接口 results = self.model(img_np, conf=conf_threshold, iou=iou_threshold, verbose=False) # 解析结果 detections = [] for r in results: boxes = r.boxes for box in boxes: # 获取坐标、置信度、类别ID x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() conf = box.conf[0].cpu().numpy() cls_id = int(box.cls[0].cpu().numpy()) # 映射回原始图像尺寸 (Ultralytics可能已经做了resize) # 注意:results可能已经包含了原始尺寸的坐标,这里简化处理 detections.append([ self.class_names[cls_id], float(conf), [float(x1), float(y1), float(x2), float(y2)] ]) # 绘制结果 (使用Ultralytics自带的绘图功能) annotated_img_np = results[0].plot() # 返回BGR图像 annotated_img = Image.fromarray(cv2.cvtColor(annotated_img_np, cv2.COLOR_BGR2RGB)) else: # YOLOv5 推理流程 (简化版,实际应用请参考yolov5/detect.py) from utils.general import non_max_suppression, scale_boxes from utils.augmentations import letterbox # 预处理 img_size = 640 stride = 32 img = letterbox(img_np, img_size, stride=stride, auto=True)[0] img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB img = np.ascontiguousarray(img) img = torch.from_numpy(img).to(self.device) img = img.float() / 255.0 # 0 - 255 to 0.0 - 1.0 if img.ndimension() == 3: img = img.unsqueeze(0) # 推理 with torch.no_grad(): pred = self.yolov5_model(img, augment=False)[0] # NMS pred = non_max_suppression(pred, conf_threshold, iou_threshold, classes=None, agnostic=False) # 处理检测结果 detections = [] img_pil = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)) draw = ImageDraw.Draw(img_pil) # 尝试加载字体 try: font = ImageFont.truetype("arial.ttf", 20) except IOError: font = ImageFont.load_default() for i, det in enumerate(pred): if len(det): # 将坐标从预处理后的尺寸缩放回原始尺寸 det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img_np.shape).round() for *xyxy, conf, cls in det: x1, y1, x2, y2 = map(int, xyxy) cls_id = int(cls) label = f"{self.class_names[cls_id]} {conf:.2f}" # 记录检测结果 detections.append([ self.class_names[cls_id], float(conf), [float(x1), float(y1), float(x2), float(y2)] ]) # 绘制边界框和标签 color = self.colors[cls_id % len(self.colors)] draw.rectangle([x1, y1, x2, y2], outline=color, width=3) # 绘制文本背景 text_bbox = draw.textbbox((x1, y1), label, font=font) draw.rectangle(text_bbox, fill=color) draw.text((x1, y1), label, fill=(255, 255, 255), font=font) annotated_img = img_pil return annotated_img, detections # 初始化模型 (假设我们选择YOLOv8作为默认模型) DEFAULT_MODEL_PATH = "runs/train/yolov8s_rice_pests/weights/best.pt" detector = PestDetector(DEFAULT_MODEL_PATH, model_type='yolov8') def predict_image(image, conf_threshold, model_choice): """ Gradio预测函数,处理图像输入 """ # 根据下拉菜单选择加载不同的模型 (为了演示,这里每次切换都重新加载,实际可以缓存) global detector model_map = { "YOLOv5s": ("runs/train/yolov5s_rice_pests/weights/best.pt", "yolov5"), "YOLOv8s": ("runs/train/yolov8s_rice_pests/weights/best.pt", "yolov8"), "YOLOv10s": ("runs/train/yolov10s_rice_pests/weights/best.pt", "yolov10") } if model_choice in model_map: model_path, model_type = model_map[model_choice] # 只有在模型切换时才重新加载 if not hasattr(predict_image, 'current_model') or predict_image.current_model != model_choice: try: detector = PestDetector(model_path, model_type=model_type) predict_image.current_model = model_choice except Exception as e: return f"加载模型失败: {e}", None, None # 执行检测 annotated_img, detections = detector.detect(image, conf_threshold=conf_threshold) # 统计结果 stats = {name: 0 for name in detector.class_names} for det in detections: class_name = det[0] if class_name in stats: stats[class_name] += 1 # 生成统计文本 stats_text = "检测统计:\n" for name, count in stats.items(): stats_text += f"{name}: {count} 只\n" stats_text += f"总计: {len(detections)} 只" # 生成检测结果表格数据 table_data = [["类别", "置信度", "边界框"]] for det in detections[:10]: # 只显示前10个检测结果 class_name, conf, bbox = det table_data.append([class_name, f"{conf:.3f}", f"[{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]"]) if len(detections) > 10: table_data.append(["...", "...", "..."]) return annotated_img, stats_text, table_data def predict_video(video_file, conf_threshold, model_choice): """ 处理视频输入,输出检测后的视频 """ # 类似图像处理,但这里简化,只处理视频第一帧作为示例 # 实际应用中,你需要逐帧处理视频 cap = cv2.VideoCapture(video_file) ret, frame = cap.read() cap.release() if ret: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) image = Image.fromarray(frame_rgb) annotated_img, _, _ = predict_image(image, conf_threshold, model_choice) return annotated_img, "视频检测完成 (这里仅展示了第一帧)" else: return None, "无法读取视频文件" # 创建Gradio界面 with gr.Blocks(title="稻田虫害智能检测系统", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🐛 稻田虫害智能检测系统") gr.Markdown("使用YOLO深度学习模型自动检测稻田中的常见害虫。支持YOLOv5, YOLOv8, YOLOv10。") with gr.Row(): with gr.Column(scale=1): model_choice = gr.Dropdown( choices=["YOLOv5s", "YOLOv8s", "YOLOv10s"], value="YOLOv8s", label="选择模型" ) conf_threshold = gr.Slider( minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="置信度阈值" ) with gr.Tab("图像上传"): image_input = gr.Image(type="pil", label="上传稻田图片") image_button = gr.Button("检测害虫", variant="primary") with gr.Tab("视频上传"): video_input = gr.Video(label="上传田间视频") video_button = gr.Button("检测视频", variant="primary") with gr.Tab("摄像头"): camera_input = gr.Image(source="webcam", streaming=True, label="实时摄像头") camera_button = gr.Button("开始实时检测", variant="primary") with gr.Column(scale=2): image_output = gr.Image(type="pil", label="检测结果") stats_output = gr.Textbox(label="检测统计", lines=5) table_output = gr.Dataframe( headers=["类别", "置信度", "边界框"], label="检测详情 (前10个)", row_count=10 ) # 绑定事件 image_button.click( fn=predict_image, inputs=[image_input, conf_threshold, model_choice], outputs=[image_output, stats_output, table_output] ) video_button.click( fn=predict_video, inputs=[video_input, conf_threshold, model_choice], outputs=[image_output, stats_output] ) # 摄像头实时检测(简化版,实际需要更复杂的流处理) camera_button.click( fn=lambda img, conf, model: predict_image(img, conf, model)[0], inputs=[camera_input, conf_threshold, model_choice], outputs=[image_output] ) # 示例图片 gr.Markdown("### 示例图片") gr.Examples( examples=[ ["examples/test1.jpg"], ["examples/test2.jpg"], ["examples/test3.jpg"] ], inputs=image_input, outputs=[image_output, stats_output, table_output], fn=lambda img: predict_image(img, 0.4, "YOLOv8s"), cache_examples=False ) if __name__ == "__main__": # 创建示例目录(如果不存在) Path("examples").mkdir(exist_ok=True) # 启动Gradio应用 demo.launch(server_name="0.0.0.0", server_port=7860, share=False)6. 系统部署与优化建议
6.1 部署方式
本地部署:适合农技站、农业实验室。运行上述Python脚本即可启动本地Web服务器。
服务器部署:将代码部署到云服务器(如AWS, Azure, 阿里云),通过Nginx反向代理提供公网访问。
边缘设备部署:使用TensorRT, ONNX或OpenVINO将PyTorch模型转换为优化格式,部署到Jetson Nano, Raspberry Pi等边缘设备,实现田间实时检测。
6.2 性能优化
模型量化:使用PyTorch的量化工具将FP32模型转换为INT8,减少模型大小和推理时间。
模型剪枝:移除网络中不重要的权重,进一步压缩模型。
TensorRT加速:对于NVIDIA GPU,使用TensorRT可显著提升推理速度。
6.3 未来改进方向
多模态融合:结合红外、多光谱图像或环境传感器数据。
时序分析:利用视频序列信息,跟踪害虫活动轨迹,预测爆发趋势。
移动端适配:开发轻量级模型(如YOLOv5n/v8n)的Android/iOS应用。
集成GIS系统:将检测结果与地理信息系统结合,生成虫害分布热力图。