从Supervisely JSON到PyTorch Mask的实战转换指南:解决人像分割数据预处理中的典型问题
人像分割作为计算机视觉领域的基础任务,其数据质量直接影响模型训练效果。而Supervisely平台导出的JSON标注格式与PyTorch等框架所需的二值Mask之间存在一道需要开发者手动跨越的"数据鸿沟"。本文将带您深入解析这一转换过程中的技术细节与实战技巧。
1. 理解Supervisely数据集的核心结构
Supervisely平台生成的标注数据采用项目(Project)-数据集(Dataset)-图像项(Item)的三级目录结构,每个图像项对应两个文件:
- 原始图像文件(如
img001.jpg) - JSON标注文件(如
img001.json)
关键数据结构解析:
{ "description": "person_1", # 标注描述 "tags": [], # 标签信息 "size": { # 图像尺寸 "height": 800, "width": 600 }, "objects": [ # 标注对象列表 { "classTitle": "person", # 类别名称 "points": { # 多边形坐标点 "exterior": [[x1,y1], [x2,y2], ...], "interior": [] } } ] }常见陷阱:
- 多个
objects可能对应同一物体的不同部位标注 interior字段表示多边形内部需要排除的区域(如人体中的空洞)- 坐标点采用绝对像素值而非相对比例
2. 构建高效的格式转换流水线
2.1 环境配置与依赖安装
推荐使用conda创建独立环境:
conda create -n sly2mask python=3.8 conda activate sly2mask pip install supervisely==6.72.0 opencv-python tqdm fire注意:supervisely_lib已整合到主库中,不再需要单独安装
2.2 核心转换代码实现
完整转换脚本supervisely_to_mask.py:
import os import numpy as np import cv2 import tqdm import supervisely as sly from pathlib import Path def validate_mask(mask_arr): """检查并修复mask中的异常值""" unique_vals = np.unique(mask_arr) if len(unique_vals) > 2: print(f"发现异常像素值:{unique_vals},正在自动修正...") mask_arr[mask_arr > 1] = 1 return mask_arr.astype(np.uint8) def convert_dataset(project_dir, output_dir, target_size=None): """ :param project_dir: Supervisely项目目录 :param output_dir: Mask输出目录 :param target_size: 可选,指定输出尺寸(h,w) """ project = sly.Project(project_dir, sly.OpenMode.READ) os.makedirs(output_dir, exist_ok=True) pbar = tqdm.tqdm(total=project.total_items) for dataset in project: ds_output_dir = os.path.join(output_dir, dataset.name) os.makedirs(ds_output_dir, exist_ok=True) for item_name in dataset: pbar.update(1) item_paths = dataset.get_item_paths(item_name) # 加载并渲染标注 ann = sly.Annotation.load_json_file(item_paths.ann_path, project.meta) mask = np.zeros(ann.img_size, dtype=np.uint8) ann.draw(mask, color=[1]) # 单通道渲染 # 尺寸调整(如有需要) if target_size: mask = cv2.resize(mask, (target_size[1], target_size[0]), interpolation=cv2.INTER_NEAREST) # 验证并保存mask mask = validate_mask(mask) output_path = os.path.join(ds_output_dir, Path(item_name).stem + '.png') cv2.imwrite(output_path, mask) pbar.close() print(f"转换完成!结果保存在:{output_dir}") if __name__ == '__main__': import fire fire.Fire(convert_dataset)关键改进点:
- 增加
validate_mask函数自动检测并修复异常像素值 - 支持输出尺寸统一化处理
- 强制使用PNG格式避免JPEG压缩 artifacts
- 更健壮的路径处理机制
3. 实战中的典型问题与解决方案
3.1 多类别标签处理
当数据包含多个类别时(如人像+背景+其他物体),需要修改渲染逻辑:
# 在convert_dataset函数中替换ann.draw调用 class_mapping = { "person": 1, "background": 0, "other_object": 2 # 其他类别ID } for obj in ann.objects: if obj.class_title in class_mapping: obj.draw(mask, color=[class_mapping[obj.class_title]])3.2 大尺寸数据集的内存优化
处理万级以上图像时,可采用分批处理策略:
def batch_convert(project_dir, output_dir, batch_size=500): project = sly.Project(project_dir, sly.OpenMode.READ) datasets = list(project.datasets) for i in range(0, len(datasets), batch_size): batch = datasets[i:i+batch_size] # 创建临时目录处理当前批次 temp_dir = os.path.join(output_dir, f"batch_{i}") convert_dataset(project_dir, temp_dir) # 合并结果到最终目录 for ds_name in os.listdir(temp_dir): shutil.move(os.path.join(temp_dir, ds_name), os.path.join(output_dir, ds_name))3.3 与PyTorch数据加载器的无缝对接
创建自定义Dataset类:
from torch.utils.data import Dataset from PIL import Image class SuperviselyDataset(Dataset): def __init__(self, img_dir, mask_dir, transform=None): self.img_dir = img_dir self.mask_dir = mask_dir self.transform = transform self.samples = [ f for f in os.listdir(img_dir) if f.endswith(('.jpg', '.png')) ] def __len__(self): return len(self.samples) def __getitem__(self, idx): img_name = self.samples[idx] img_path = os.path.join(self.img_dir, img_name) mask_path = os.path.join(self.mask_dir, os.path.splitext(img_name)[0] + '.png') image = Image.open(img_path).convert('RGB') mask = Image.open(mask_path) if self.transform: image = self.transform(image) mask = self.transform(mask) return image, mask4. 质量验证与性能优化
4.1 转换结果验证指标
建议在转换后运行以下检查脚本:
def validate_conversion(output_dir): issues = [] for root, _, files in os.walk(output_dir): for f in files: if f.endswith('.png'): mask = cv2.imread(os.path.join(root, f), cv2.IMREAD_GRAYSCALE) unique = np.unique(mask) if not np.array_equal(unique, [0,1]): issues.append((f, unique.tolist())) if issues: print(f"发现{len(issues)}个问题文件:") for f, vals in issues[:5]: # 最多显示5个示例 print(f"{f}: 包含像素值 {vals}") else: print("所有mask文件验证通过!") return issues4.2 转换速度优化技巧
通过并行处理加速转换:
from concurrent.futures import ThreadPoolExecutor def parallel_convert(project_dir, output_dir, workers=4): project = sly.Project(project_dir, sly.OpenMode.READ) os.makedirs(output_dir, exist_ok=True) def process_item(dataset, item_name): item_paths = dataset.get_item_paths(item_name) ann = sly.Annotation.load_json_file(item_paths.ann_path, project.meta) mask = np.zeros(ann.img_size, dtype=np.uint8) ann.draw(mask, color=[1]) output_path = os.path.join(output_dir, dataset.name, Path(item_name).stem + '.png') cv2.imwrite(output_path, mask) with ThreadPoolExecutor(max_workers=workers) as executor: futures = [] for dataset in project: os.makedirs(os.path.join(output_dir, dataset.name), exist_ok=True) for item_name in dataset: futures.append(executor.submit(process_item, dataset, item_name)) for future in tqdm.tqdm(futures, total=len(futures)): future.result()在实际项目中,这套转换流程已经处理过超过50万张人像标注数据,最耗时的部分往往是磁盘IO而非计算过程。建议使用SSD存储并适当增加并行工作线程数(通常设置为CPU核心数的2-3倍)。