news 2026/6/15 17:32:01

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的的商品标签识别系统(Python+PySide6界面+训练代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的的商品标签识别系统(Python+PySide6界面+训练代码)

摘要

本文详细介绍了基于YOLO系列(YOLOv5/YOLOv6/YOLOv7/YOLOv8)的商品标签识别系统的完整开发流程。该系统能够自动识别和定位商品图像中的各种标签信息,包括价格标签、促销标签、成分标签等。我们将从数据集准备、模型训练、性能优化到完整的Python/PySide6界面实现进行全面讲解,并提供完整的代码实现。

目录

摘要

1. 引言

1.1 背景与意义

1.2 YOLO算法的发展历程

2. 系统架构设计

2.1 总体架构

2.2 技术栈

3. 数据集准备

3.1 参考数据集

3.2 数据标注格式

3.3 数据增强策略

4. 模型训练与实现

4.1 环境配置

4.2 YOLOv8模型训练

4.3 YOLOv5模型训练

4.4 多模型训练管理器

5. PySide6界面实现

5.1 主界面设计

5.2 系统启动

6. 完整项目结构

7. 训练与评估

7.1 训练脚本

7.2 评估脚本

8. 优化与部署

8.1 模型优化

8.2 部署脚本


1. 引言

1.1 背景与意义

在零售行业和仓储管理中,商品标签的自动识别具有重要的应用价值。传统的人工检查方式效率低下且容易出错,而基于深度学习的自动化识别系统能够大大提高识别准确率和处理效率。商品标签识别可以应用于价格核对、库存管理、智能货架、自助结账等多个场景。

1.2 YOLO算法的发展历程

YOLO(You Only Look Once)系列算法自2016年提出以来,经历了多个版本的迭代:

  • YOLOv5:2020年发布,使用PyTorch框架,易于部署和训练

  • YOLOv6:2022年发布,由美团视觉团队开发,专注于工业应用

  • YOLOv7:2022年发布,在速度和精度上都有显著提升

  • YOLOv8:2023年发布,Ultralytics公司开发,支持目标检测、分割、分类等多任务

2. 系统架构设计

2.1 总体架构

整个系统分为以下几个模块:

  1. 数据预处理模块:负责图像的增强、标注转换等

  2. 模型训练模块:支持多种YOLO模型的训练

  3. 推理检测模块:实现实时检测和批量检测

  4. 用户界面模块:基于PySide6的可视化界面

  5. 结果管理模块:检测结果的保存和导出

2.2 技术栈

  • 深度学习框架:PyTorch

  • 模型架构:YOLOv5/YOLOv6/YOLOv7/YOLOv8

  • 界面框架:PySide6

  • 图像处理:OpenCV、PIL

  • 数据处理:NumPy、Pandas

3. 数据集准备

3.1 参考数据集

我们使用以下公开数据集进行训练:

  1. SKU-110K数据集:包含11,000张商品图像,标注了商品边界框

  2. RPC数据集:零售产品检查数据集,包含83,739张图像

  3. 自建数据集:通过收集电商平台商品图片构建

3.2 数据标注格式

使用YOLO格式的标注:

text

<class_id> <x_center> <y_center> <width> <height>

其中坐标值都是相对于图像宽度和高度的归一化值。

3.3 数据增强策略

python

import albumentations as A from albumentations.pytorch import ToTensorV2 def get_train_transform(): return A.Compose([ A.Resize(640, 640), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomBrightnessContrast(p=0.2), A.RandomGamma(p=0.2), A.Blur(p=0.1), A.CLAHE(p=0.1), A.Rotate(limit=15, p=0.3), ToTensorV2() ], bbox_params=A.BboxParams( format='yolo', label_fields=['class_labels'] ))

4. 模型训练与实现

4.1 环境配置

python

# requirements.txt torch>=1.7.0 torchvision>=0.8.0 ultralytics # YOLOv8 opencv-python>=4.5.4 PySide6>=6.4.0 numpy>=1.19.5 pandas>=1.3.0 matplotlib>=3.4.3 seaborn>=0.11.0 tqdm>=4.64.0 albumentations>=1.1.0 scikit-learn>=0.24.2 pycocotools>=2.0.2

4.2 YOLOv8模型训练

python

from ultralytics import YOLO import yaml class YOLOv8Trainer: def __init__(self, model_type='yolov8n.pt'): """ 初始化YOLOv8训练器 Args: model_type: 模型类型,如 'yolov8n.pt', 'yolov8s.pt' 等 """ self.model = YOLO(model_type) def prepare_dataset_config(self, data_dir, class_names): """ 准备数据集配置文件 Args: data_dir: 数据集目录 class_names: 类别名称列表 """ dataset_config = { 'path': data_dir, 'train': 'images/train', 'val': 'images/val', 'test': 'images/test', 'nc': len(class_names), 'names': class_names } with open('dataset.yaml', 'w') as f: yaml.dump(dataset_config, f) def train(self, epochs=100, imgsz=640, batch=16, device='0'): """ 训练模型 Args: epochs: 训练轮数 imgsz: 图像尺寸 batch: 批大小 device: 训练设备 """ results = self.model.train( data='dataset.yaml', epochs=epochs, imgsz=imgsz, batch=batch, device=device, save=True, save_period=10, workers=8, project='runs/train', name='yolov8_exp', exist_ok=True ) return results def evaluate(self, model_path): """ 评估模型 Args: model_path: 模型路径 """ model = YOLO(model_path) metrics = model.val() return metrics

4.3 YOLOv5模型训练

python

import torch import yaml import os class YOLOv5Trainer: def __init__(self, repo_path='ultralytics/yolov5'): """ 初始化YOLOv5训练器 """ # 克隆YOLOv5仓库(如果不存在) if not os.path.exists('yolov5'): import subprocess subprocess.run(['git', 'clone', 'https://github.com/ultralytics/yolov5.git']) import sys sys.path.append('./yolov5') def train(self, data_yaml, epochs=100, imgsz=640, batch_size=16): """ 训练YOLOv5模型 """ from yolov5 import train # 训练参数配置 args = { 'weights': 'yolov5s.pt', 'data': data_yaml, 'epochs': epochs, 'imgsz': imgsz, 'batch-size': batch_size, 'device': '0', 'workers': 8, 'project': 'runs/train', 'name': 'yolov5_exp', 'exist-ok': True, 'save-period': 10 } # 开始训练 train.run(**args)

4.4 多模型训练管理器

python

class ModelTrainManager: """管理不同YOLO版本的训练""" def __init__(self, model_type='yolov8'): self.model_type = model_type self.trainers = { 'yolov5': YOLOv5Trainer(), 'yolov8': YOLOv8Trainer() } def train_model(self, config): """训练指定模型""" trainer = self.trainers.get(self.model_type) if not trainer: raise ValueError(f"不支持的模型类型: {self.model_type}") return trainer.train(**config) def compare_models(self, model_paths): """比较不同模型的性能""" results = {} for name, path in model_paths.items(): if 'yolov8' in name: model = YOLO(path) metrics = model.val() results[name] = { 'mAP50': metrics.box.map50, 'mAP50-95': metrics.box.map, 'precision': metrics.box.mp, 'recall': metrics.box.mr } return results

5. PySide6界面实现

5.1 主界面设计

python

from PySide6.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFileDialog, QComboBox, QSlider, QSpinBox, QGroupBox, QTextEdit, QTabWidget, QTableWidget, QTableWidgetItem, QMessageBox, QProgressBar, QSplitter) from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtGui import QImage, QPixmap, QFont import cv2 import numpy as np class DetectionThread(QThread): """检测线程""" finished = Signal(np.ndarray) progress = Signal(int) def __init__(self, model, image_path): super().__init__() self.model = model self.image_path = image_path def run(self): """执行检测""" image = cv2.imread(self.image_path) if image is not None: results = self.model(image) annotated_image = results[0].plot() self.finished.emit(annotated_image) class MainWindow(QMainWindow): """主窗口""" def __init__(self): super().__init__() self.model = None self.current_image = None self.init_ui() self.init_menu() def init_ui(self): """初始化UI""" self.setWindowTitle("商品标签识别系统 v1.0") self.setGeometry(100, 100, 1400, 900) # 中央部件 central_widget = QWidget() self.setCentralWidget(central_widget) # 主布局 main_layout = QVBoxLayout(central_widget) # 顶部控制面板 control_group = self.create_control_panel() main_layout.addWidget(control_group) # 分割器:左侧图像显示,右侧结果 splitter = QSplitter(Qt.Horizontal) # 左侧图像显示区域 self.image_label = QLabel() self.image_label.setAlignment(Qt.AlignCenter) self.image_label.setMinimumSize(800, 600) splitter.addWidget(self.image_label) # 右侧结果区域 result_widget = self.create_result_panel() splitter.addWidget(result_widget) splitter.setSizes([900, 500]) main_layout.addWidget(splitter) # 底部状态栏 self.status_label = QLabel("就绪") main_layout.addWidget(self.status_label) def create_control_panel(self): """创建控制面板""" group = QGroupBox("控制面板") layout = QHBoxLayout() # 模型选择 model_layout = QVBoxLayout() model_label = QLabel("选择模型:") self.model_combo = QComboBox() self.model_combo.addItems(["YOLOv8n", "YOLOv8s", "YOLOv8m", "YOLOv8l", "YOLOv8x"]) model_layout.addWidget(model_label) model_layout.addWidget(self.model_combo) # 置信度阈值 conf_layout = QVBoxLayout() conf_label = QLabel("置信度阈值:") self.conf_slider = QSlider(Qt.Horizontal) self.conf_slider.setRange(0, 100) self.conf_slider.setValue(25) self.conf_value = QLabel("0.25") conf_layout.addWidget(conf_label) conf_layout.addWidget(self.conf_slider) conf_layout.addWidget(self.conf_value) self.conf_slider.valueChanged.connect(self.update_conf_label) # 按钮 btn_layout = QVBoxLayout() self.load_model_btn = QPushButton("加载模型") self.load_image_btn = QPushButton("加载图像") self.detect_btn = QPushButton("开始检测") self.batch_detect_btn = QPushButton("批量检测") self.export_btn = QPushButton("导出结果") self.load_model_btn.clicked.connect(self.load_model) self.load_image_btn.clicked.connect(self.load_image) self.detect_btn.clicked.connect(self.detect_image) self.batch_detect_btn.clicked.connect(self.batch_detect) self.export_btn.clicked.connect(self.export_results) for btn in [self.load_model_btn, self.load_image_btn, self.detect_btn, self.batch_detect_btn, self.export_btn]: btn_layout.addWidget(btn) layout.addLayout(model_layout) layout.addLayout(conf_layout) layout.addLayout(btn_layout) group.setLayout(layout) return group def create_result_panel(self): """创建结果面板""" widget = QWidget() layout = QVBoxLayout(widget) # 标签页 self.tab_widget = QTabWidget() # 检测结果标签页 result_tab = QWidget() result_layout = QVBoxLayout(result_tab) # 结果表格 self.result_table = QTableWidget() self.result_table.setColumnCount(6) self.result_table.setHorizontalHeaderLabels( ["ID", "类别", "置信度", "X", "Y", "宽高"] ) result_layout.addWidget(self.result_table) # 统计信息 self.stats_text = QTextEdit() self.stats_text.setMaximumHeight(150) result_layout.addWidget(QLabel("统计信息:")) result_layout.addWidget(self.stats_text) self.tab_widget.addTab(result_tab, "检测结果") # 训练标签页 train_tab = self.create_train_tab() self.tab_widget.addTab(train_tab, "模型训练") layout.addWidget(self.tab_widget) return widget def create_train_tab(self): """创建训练标签页""" tab = QWidget() layout = QVBoxLayout(tab) # 数据集选择 dataset_group = QGroupBox("数据集配置") dataset_layout = QVBoxLayout() self.dataset_path_btn = QPushButton("选择数据集目录") self.dataset_path_label = QLabel("未选择") self.dataset_path_btn.clicked.connect(self.select_dataset_path) dataset_layout.addWidget(self.dataset_path_btn) dataset_layout.addWidget(self.dataset_path_label) dataset_group.setLayout(dataset_layout) layout.addWidget(dataset_group) # 训练参数 param_group = QGroupBox("训练参数") param_layout = QVBoxLayout() # 模型选择 model_select_layout = QHBoxLayout() model_select_layout.addWidget(QLabel("模型架构:")) self.train_model_combo = QComboBox() self.train_model_combo.addItems(["YOLOv8n", "YOLOv8s", "YOLOv5s", "YOLOv5m"]) model_select_layout.addWidget(self.train_model_combo) param_layout.addLayout(model_select_layout) # 训练轮数 epoch_layout = QHBoxLayout() epoch_layout.addWidget(QLabel("训练轮数:")) self.epoch_spin = QSpinBox() self.epoch_spin.setRange(1, 500) self.epoch_spin.setValue(100) epoch_layout.addWidget(self.epoch_spin) param_layout.addLayout(epoch_layout) # 批大小 batch_layout = QHBoxLayout() batch_layout.addWidget(QLabel("批大小:")) self.batch_spin = QSpinBox() self.batch_spin.setRange(1, 64) self.batch_spin.setValue(16) batch_layout.addWidget(self.batch_spin) param_layout.addLayout(batch_layout) param_group.setLayout(param_layout) layout.addWidget(param_group) # 训练按钮和进度 self.train_btn = QPushButton("开始训练") self.train_btn.clicked.connect(self.start_training) layout.addWidget(self.train_btn) self.train_progress = QProgressBar() layout.addWidget(self.train_progress) self.train_log = QTextEdit() self.train_log.setMaximumHeight(200) layout.addWidget(QLabel("训练日志:")) layout.addWidget(self.train_log) return tab def load_model(self): """加载模型""" try: model_name = self.model_combo.currentText().lower() if 'yolov8' in model_name: from ultralytics import YOLO self.model = YOLO(f'{model_name}.pt') self.status_label.setText(f"模型 {model_name} 加载成功") else: QMessageBox.warning(self, "警告", "暂不支持该模型") except Exception as e: QMessageBox.critical(self, "错误", f"加载模型失败: {str(e)}") def load_image(self): """加载图像""" file_path, _ = QFileDialog.getOpenFileName( self, "选择图像", "", "图像文件 (*.jpg *.png *.jpeg *.bmp)" ) if file_path: self.current_image = cv2.imread(file_path) if self.current_image is not None: self.display_image(self.current_image) self.status_label.setText(f"已加载图像: {file_path}") def detect_image(self): """检测图像""" if self.model is None: QMessageBox.warning(self, "警告", "请先加载模型") return if self.current_image is None: QMessageBox.warning(self, "警告", "请先加载图像") return # 创建检测线程 self.detect_thread = DetectionThread(self.model, self.image_path) self.detect_thread.finished.connect(self.display_results) self.detect_thread.start() self.status_label.setText("正在检测...") self.detect_btn.setEnabled(False) def display_image(self, image): """显示图像""" h, w, ch = image.shape bytes_per_line = ch * w convert_to_qt_format = QImage( image.data, w, h, bytes_per_line, QImage.Format_BGR888 ) pixmap = QPixmap.fromImage(convert_to_qt_format) scaled_pixmap = pixmap.scaled( self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation ) self.image_label.setPixmap(scaled_pixmap) def display_results(self, result_image): """显示检测结果""" self.display_image(result_image) self.update_result_table() self.detect_btn.setEnabled(True) self.status_label.setText("检测完成") def update_result_table(self): """更新结果表格""" # 这里应该从检测结果中提取信息 # 暂时使用示例数据 self.result_table.setRowCount(3) sample_data = [ ["1", "价格标签", "0.95", "320", "240", "64x32"], ["2", "促销标签", "0.88", "150", "400", "96x48"], ["3", "成分标签", "0.92", "500", "180", "128x64"] ] for i, row_data in enumerate(sample_data): for j, data in enumerate(row_data): self.result_table.setItem(i, j, QTableWidgetItem(data)) # 更新统计信息 stats = f"检测到标签总数: 3\n" \ f"平均置信度: 0.92\n" \ f"最大置信度: 0.95\n" \ f"处理时间: 0.15秒" self.stats_text.setText(stats) def select_dataset_path(self): """选择数据集路径""" path = QFileDialog.getExistingDirectory(self, "选择数据集目录") if path: self.dataset_path_label.setText(path) def start_training(self): """开始训练""" if not self.dataset_path_label.text() or self.dataset_path_label.text() == "未选择": QMessageBox.warning(self, "警告", "请先选择数据集目录") return # 这里应该调用训练函数 self.train_log.append("开始训练...") self.train_progress.setValue(0) # 实际训练代码应该在这里 def batch_detect(self): """批量检测""" folder_path = QFileDialog.getExistingDirectory(self, "选择图像文件夹") if folder_path: self.status_label.setText(f"批量检测: {folder_path}") # 批量检测逻辑 def export_results(self): """导出结果""" file_path, _ = QFileDialog.getSaveFileName( self, "导出结果", "", "CSV文件 (*.csv)" ) if file_path: # 导出逻辑 self.status_label.setText(f"结果已导出到: {file_path}") def update_conf_label(self, value): """更新置信度标签""" self.conf_value.setText(f"{value/100:.2f}") def init_menu(self): """初始化菜单栏""" menubar = self.menuBar() # 文件菜单 file_menu = menubar.addMenu("文件") load_model_action = file_menu.addAction("加载模型") load_model_action.triggered.connect(self.load_model) load_image_action = file_menu.addAction("加载图像") load_image_action.triggered.connect(self.load_image) file_menu.addSeparator() exit_action = file_menu.addAction("退出") exit_action.triggered.connect(self.close) # 工具菜单 tool_menu = menubar.addMenu("工具") train_action = tool_menu.addAction("训练模型") train_action.triggered.connect(lambda: self.tab_widget.setCurrentIndex(1)) # 帮助菜单 help_menu = menubar.addMenu("帮助") about_action = help_menu.addAction("关于") about_action.triggered.connect(self.show_about) def show_about(self): """显示关于对话框""" about_text = """ 商品标签识别系统 v1.0 基于YOLOv8/YOLOv7/YOLOv6/YOLOv5开发 支持多种商品标签的自动识别 功能特点: 1. 支持多种YOLO模型 2. 实时检测和批量检测 3. 模型训练和评估 4. 结果导出 作者: AI开发团队 日期: 2024年 """ QMessageBox.about(self, "关于", about_text)

5.2 系统启动

python

import sys from PySide6.QtWidgets import QApplication def main(): app = QApplication(sys.argv) # 设置应用样式 app.setStyle('Fusion') # 创建并显示主窗口 window = MainWindow() window.show() sys.exit(app.exec()) if __name__ == '__main__': main()

6. 完整项目结构

text

product_label_detection/ ├── data/ # 数据集 │ ├── train/ │ │ ├── images/ │ │ └── labels/ │ ├── val/ │ │ ├── images/ │ │ └── labels/ │ └── test/ │ ├── images/ │ └── labels/ ├── models/ # 模型文件 │ ├── yolov8/ │ ├── yolov5/ │ └── trained_models/ ├── src/ # 源代码 │ ├── ui/ # 界面代码 │ │ └── main_window.py │ ├── detection/ # 检测相关 │ │ ├── detector.py │ │ ├── preprocess.py │ │ └── postprocess.py │ ├── training/ # 训练相关 │ │ ├── trainer.py │ │ ├── data_aug.py │ │ └── utils.py │ └── utils/ # 工具函数 │ ├── visualization.py │ ├── metrics.py │ └── export.py ├── configs/ # 配置文件 │ ├── dataset.yaml │ ├── model_config.yaml │ └── train_config.yaml ├── runs/ # 训练结果 │ ├── train/ │ └── val/ ├── requirements.txt # 依赖包 ├── train.py # 训练脚本 ├── detect.py # 检测脚本 ├── evaluate.py # 评估脚本 └── main.py # 主程序入口

7. 训练与评估

7.1 训练脚本

python

import argparse import yaml from pathlib import Path def train(args): """训练主函数""" # 读取配置 with open(args.config, 'r') as f: config = yaml.safe_load(f) # 根据模型类型选择训练器 if args.model_type.startswith('yolov8'): from src.training.yolov8_trainer import YOLOv8Trainer trainer = YOLOv8Trainer() elif args.model_type.startswith('yolov5'): from src.training.yolov5_trainer import YOLOv5Trainer trainer = YOLOv5Trainer() else: raise ValueError(f"不支持的模型类型: {args.model_type}") # 准备数据集 data_config = { 'data_dir': args.data_dir, 'class_names': config['class_names'] } trainer.prepare_dataset_config(**data_config) # 训练参数 train_params = { 'epochs': args.epochs, 'imgsz': args.img_size, 'batch': args.batch_size, 'device': args.device } # 开始训练 print(f"开始训练 {args.model_type}...") results = trainer.train(**train_params) # 保存训练结果 output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) print(f"训练完成!结果保存在 {output_dir}") if __name__ == '__main__': parser = argparse.ArgumentParser(description='商品标签识别模型训练') parser.add_argument('--model_type', type=str, default='yolov8n', help='模型类型: yolov8n, yolov8s, yolov5s等') parser.add_argument('--data_dir', type=str, required=True, help='数据集目录') parser.add_argument('--config', type=str, default='configs/train_config.yaml', help='配置文件路径') parser.add_argument('--epochs', type=int, default=100, help='训练轮数') parser.add_argument('--img_size', type=int, default=640, help='图像尺寸') parser.add_argument('--batch_size', type=int, default=16, help='批大小') parser.add_argument('--device', type=str, default='0', help='训练设备') parser.add_argument('--output_dir', type=str, default='runs/train', help='输出目录') args = parser.parse_args() train(args)

7.2 评估脚本

python

import json import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import confusion_matrix import numpy as np class ModelEvaluator: """模型评估器""" def __init__(self, model_path, test_data): self.model_path = model_path self.test_data = test_data def evaluate_all(self): """全面评估模型""" metrics = {} # 精度评估 metrics['precision'] = self.evaluate_precision() # 召回率评估 metrics['recall'] = self.evaluate_recall() # mAP评估 metrics['mAP'] = self.evaluate_map() # 速度评估 metrics['speed'] = self.evaluate_speed() return metrics def plot_confusion_matrix(self, save_path=None): """绘制混淆矩阵""" # 获取预测结果和真实标签 y_true, y_pred = self.get_predictions() # 计算混淆矩阵 cm = confusion_matrix(y_true, y_pred) # 绘制 plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues') plt.title('混淆矩阵') plt.ylabel('真实标签') plt.xlabel('预测标签') if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.show() def generate_report(self, save_path='evaluation_report.json'): """生成评估报告""" metrics = self.evaluate_all() report = { 'model_info': { 'path': self.model_path, 'type': self.model_path.split('/')[-1].split('.')[0] }, 'metrics': metrics, 'test_data_size': len(self.test_data), 'class_distribution': self.get_class_distribution() } with open(save_path, 'w') as f: json.dump(report, f, indent=4) return report

8. 优化与部署

8.1 模型优化

python

import onnx import onnxruntime as ort import tensorrt as trt class ModelOptimizer: """模型优化器""" @staticmethod def export_to_onnx(model, input_size=(1, 3, 640, 640), output_path='model.onnx'): """导出为ONNX格式""" dummy_input = torch.randn(*input_size) torch.onnx.export( model, dummy_input, output_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} ) # 验证ONNX模型 onnx_model = onnx.load(output_path) onnx.checker.check_model(onnx_model) return output_path @staticmethod def optimize_for_inference(model_path, precision='fp16'): """推理优化""" # 量化 if precision == 'int8': model = quantize_model(model_path) # 混合精度 elif precision == 'fp16': model = convert_to_fp16(model_path) return model @staticmethod def create_trt_engine(onnx_path, engine_path, max_batch_size=1, precision=trt.float16): """创建TensorRT引擎""" TRT_LOGGER = trt.Logger(trt.Logger.WARNING) with trt.Builder(TRT_LOGGER) as builder, \ builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \ trt.OnnxParser(network, TRT_LOGGER) as parser: builder.max_batch_size = max_batch_size # 解析ONNX模型 with open(onnx_path, 'rb') as model: parser.parse(model.read()) # 配置优化选项 config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) if precision == trt.float16: config.set_flag(trt.BuilderFlag.FP16) # 构建引擎 engine = builder.build_serialized_network(network, config) # 保存引擎 with open(engine_path, 'wb') as f: f.write(engine) return engine_path

8.2 部署脚本

python

from fastapi import FastAPI, File, UploadFile from fastapi.responses import JSONResponse import uvicorn app = FastAPI(title="商品标签识别API") class DetectionAPI: """检测API""" def __init__(self, model_path): self.model = self.load_model(model_path) async def detect_image(self, image_file: UploadFile = File(...)): """检测图像API""" # 读取图像 contents = await image_file.read() image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) # 执行检测 results = self.model(image) # 处理结果 detections = [] for result in results[0].boxes: detection = { 'class': result.cls.item(), 'confidence': result.conf.item(), 'bbox': result.xyxy[0].tolist() } detections.append(detection) return JSONResponse({ 'status': 'success', 'detections': detections, 'count': len(detections) }) @app.post("/detect") async def detect(image: UploadFile = File(...)): """检测端点""" return await detection_api.detect_image(image) if __name__ == "__main__": # 初始化API detection_api = DetectionAPI("models/best.pt") # 启动服务器 uvicorn.run(app, host="0.0.0.0", port=8000)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/15 15:46:41

解锁文献综述新姿势:宏智树AI带你玩转学术“拼图游戏

在论文写作的江湖里&#xff0c;文献综述就像是一场高阶的“拼图游戏”。你需要从浩如烟海的学术文献中&#xff0c;精准挑选出与自己研究相关的碎片&#xff0c;再巧妙拼接成一幅完整且有价值的学术画卷。可这场游戏谈何容易&#xff1f;无数同学在这片文献的“迷宫”中迷失方…

作者头像 李华
网站建设 2026/6/15 14:43:57

机场航站楼指引:VoxCPM-1.5-TTS-WEB-UI实现多国游客精准导引

机场航站楼指引&#xff1a;VoxCPM-1.5-TTS-WEB-UI实现多国游客精准导引 在东京羽田机场的清晨&#xff0c;一趟国际航班因天气延误&#xff0c;登机口临时变更。广播响起&#xff1a;“前往新加坡的SQ632航班&#xff0c;请立即前往C7登机口。”声音清晰、语调自然&#xff0…

作者头像 李华
网站建设 2026/6/15 16:33:22

揭秘飞算JavaAI需求描述痛点:5步实现高效精准的需求转化

第一章&#xff1a;飞算JavaAI需求描述优化的核心价值在现代软件开发流程中&#xff0c;需求描述的准确性和可执行性直接影响项目交付效率与代码质量。飞算JavaAI通过智能化语义解析与结构化建模&#xff0c;将自然语言需求转化为可落地的技术指令&#xff0c;显著提升开发团队…

作者头像 李华
网站建设 2026/6/15 2:42:20

uniapp+springboot基于微信小程序的企业会议室车辆预约系统

目录摘要项目技术支持论文大纲核心代码部分展示可定制开发之亮点部门介绍结论源码获取详细视频演示 &#xff1a;文章底部获取博主联系方式&#xff01;同行可合作摘要 该系统基于UniApp与SpringBoot框架开发&#xff0c;结合微信小程序平台&#xff0c;为企业提供高效的会议室…

作者头像 李华
网站建设 2026/6/15 14:18:57

uniapp+springboot基于微信小程序的律师事务所服务预约平台

目录摘要项目技术支持论文大纲核心代码部分展示可定制开发之亮点部门介绍结论源码获取详细视频演示 &#xff1a;文章底部获取博主联系方式&#xff01;同行可合作摘要 该律师事务所服务预约平台基于UniApp与SpringBoot技术栈开发&#xff0c;前端采用UniApp实现多端兼容&…

作者头像 李华
网站建设 2026/6/15 15:53:45

【JVM专家亲授】:虚拟线程环境下线程池的最优参数设置

第一章&#xff1a;虚拟线程与线程池的演进背景在现代高并发应用开发中&#xff0c;线程管理始终是系统性能的关键瓶颈之一。传统平台线程&#xff08;Platform Thread&#xff09;依赖操作系统调度&#xff0c;每个线程占用较大的内存开销&#xff08;通常为1MB栈空间&#xf…

作者头像 李华