news 2026/5/22 3:09:10

Donut模型微调实战:端到端小票信息抽取指南

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Donut模型微调实战:端到端小票信息抽取指南

1. 项目概述:一张小票背后的“智能读取员”是怎么炼成的

你有没有在便利店结完账,随手把那张热乎乎、边缘微卷、还带着点油渍的纸质小票塞进包里,结果三天后翻出来——字迹模糊、墨水晕染、部分区域被手指蹭花了?更别提那些打印质量参差不齐的餐饮小票,字体细小、行距紧凑、甚至还有手写补充项。这时候,想把“商品名:冰美式×2”、“金额:38.00”、“时间:2024-04-12 19:23”这些关键信息准确无误地抽出来,填进报销系统或个人记账App,光靠人眼识别+手动录入,效率低、错误率高、体验极差。这正是Receipt Information Extraction(小票信息抽取)这个具体场景的真实痛点。而Donut模型,全称是Document Understanding Transformer,它不是传统OCR那种“先识别文字、再用规则匹配”的两段式老路,而是端到端地把整张小票图像“喂”给模型,让它像人一样,直接“看图说话”,一步到位输出结构化的JSON数据。它本质上是一个视觉-语言大模型,把图像理解(ViT)和文本生成(Decoder)无缝缝合在一起。我们今天要做的“Fine-Tune”,绝不是从零训练一个新模型——那需要GPU集群和几周时间——而是像给一辆高性能跑车更换更适合山道的轮胎和调校悬挂一样,在官方预训练好的Donut基础模型上,用你手头那几百张真实小票照片,进行精准的“微调”。这个过程,门槛远比想象中低:一台带RTX 3060显卡的笔记本就能跑通;代码核心逻辑不到50行;整个流程从准备数据到得到可用模型,我实测下来,新手也能在一天内走通。它解决的不是一个泛泛的“文档理解”问题,而是非常具体的、高频的、有明确商业价值的“小票数字化”问题。无论你是财务人员想自动化报销,是开发者想为SaaS产品增加票据解析能力,还是学生想拿这个项目练手多模态AI,这篇内容都给你一条清晰、可执行、避过所有坑的路径。

2. 核心思路拆解:为什么是Donut,而不是其他方案?

2.1 摒弃OCR+规则的老套路,拥抱端到端的“理解力”

在接触Donut之前,我试过至少三种主流方案来处理小票。第一种是纯OCR引擎,比如Tesseract或商业API。它的逻辑很直白:先把图片转成一长串乱序的文字流,再用正则表达式去“大海捞针”。比如,用r'金额[::\s]*(\d+\.\d{2})'去匹配。但现实是残酷的:小票格式千变万化,有的“金额”写在最右边,有的缩写成“¥”,有的后面还跟着“(含税)”三个字。一次正则能覆盖80%的样本就不错了,剩下20%就得人工兜底,维护成本极高。第二种是基于LayoutParser等工具的版面分析+OCR组合。它先用CV模型框出“标题区”、“商品列表区”、“合计区”,再对每个区域单独OCR。这比纯OCR强,但问题在于,它依然把“理解”这件事交给了人写的规则。当遇到一张布局错乱、有折痕、或者被咖啡渍盖住半行字的小票时,版面分析模型很容易框错区域,后面OCR再准也白搭。第三种是用通用的多模态模型,比如BLIP-2或Qwen-VL。它们确实强大,但就像用航空母舰去打蚊子——模型太大,推理慢,部署难,而且它们的设计初衷是回答开放性问题(“图里有什么?”),而不是生成严格格式的JSON(“请输出一个包含‘items’、‘total_amount’、‘date’字段的对象”)。Donut的出现,恰恰是为了解决这个“最后一公里”的精准需求。它的预训练任务就是“文档问答”(Document Question Answering),在海量PDF、扫描件、表单上学习“看图-生成答案”的映射关系。这意味着,它天生就懂“表格”、“发票抬头”、“金额栏”这些概念,不需要你从零教它什么是“钱”。我们微调时,只需要告诉它:“嘿,现在你的新工作是专门看这种蓝底白字的便利店小票,然后按我给你的模板,把东西填进去。”这种范式转变,是效率跃升的根本原因。

2.2 Donut的架构优势:视觉编码器与文本解码器的“黄金搭档”

Donut的魔力,藏在它精巧的双塔结构里。它的“眼睛”是一个经过大规模图像数据预训练的Vision Transformer (ViT)编码器。这个ViT不是简单地提取几个特征向量,而是将整张小票图像分割成一个个小块(patch),然后通过自注意力机制,让每一个小块都能“看到”并理解它在整个画面中的上下文。比如,当它看到“¥38.00”这个数字时,ViT能同时感知到它紧邻着“合计”两个字,上方是密密麻麻的商品列表,下方是收款员签名栏——这种全局的空间感知能力,是传统CNN难以企及的。它的“嘴巴”则是一个强大的Autoregressive Text Decoder,也就是类似GPT的文本生成器。这个解码器的任务,不是胡乱编故事,而是严格按照你定义的“结构化提示词”(Structured Prompt)来逐字生成。举个例子,我们的提示词可能是:<s_receipt><s_table><s_row><s_cell>商品名</s_cell><s_cell>数量</s_cell><s_cell>金额</s_cell></s_row>。解码器会把这个提示词作为“起始指令”,然后开始生成:<s_row><s_cell>冰美式</s_cell><s_cell>2</s_cell><s_cell>38.00</s_cell></s_row><s_total>38.00</s_total></s_receipt>。整个过程,ViT负责“看懂”,Decoder负责“说清”,两者通过一个轻量级的跨模态注意力层紧密耦合。这种设计,让我们在微调时,可以只更新Decoder的部分参数,而冻结大部分ViT的权重。这不仅大幅降低了显存占用(我的3060 12G显卡能轻松跑batch size=2),更重要的是,它保留了ViT在通用文档理解上的强大先验知识,只让模型去学习“便利店小票”这个特定领域的细微差别。相比之下,如果你用一个纯文本模型(如BERT)去处理OCR后的文字,它就完全丢失了“这张小票的‘合计’字样在右下角”这个至关重要的空间线索,信息损失是不可逆的。

2.3 微调策略选择:为什么是“监督微调”而非“强化学习”?

在模型训练的语境里,“Fine-tuning”这个词听起来很宽泛,但具体到Donut上,我们必须做出一个关键决策:用什么方式来微调?目前主要有两条技术路线。第一条是监督微调(Supervised Fine-tuning, SFT),这也是我们本文采用的、最稳妥、最易上手的方式。它的核心思想非常朴素:准备一批高质量的“小票图片-标准答案”配对数据。每张图片,我们都人工标注出它对应的、格式完美的JSON答案。然后,我们把图片输入Donut的ViT,把标准答案作为Decoder的期望输出,用交叉熵损失函数来驱动模型学习。这个过程,就像老师批改学生的作业:学生(模型)生成一个答案,老师(损失函数)指出哪里错了,学生据此修改自己的“答题思路”。它的优点是稳定、可控、效果可预期,且对数据量要求相对友好——通常200-500张精心标注的图片,就能达到非常实用的精度。第二条路线是基于人类反馈的强化学习(RLHF)。这需要先训练一个“奖励模型”(Reward Model),让它学会判断一个模型生成的答案“好不好”。然后,用PPO等算法,让Donut在生成答案时,不断尝试、不断被奖励模型打分,最终学会生成高分答案。这条路理论上天花板更高,但它需要海量的、由领域专家给出的“偏好排序”数据(比如,A答案和B答案,哪个更好?),工程复杂度呈指数级上升,对于一个想快速落地的小票项目来说,完全是杀鸡用牛刀。我曾经在一个内部PoC项目中尝试过简化版的RLHF,结果花了三倍的时间,精度提升却不到2%,反而因为奖励模型的偏差,导致模型在某些边缘case上产生了奇怪的幻觉。所以,对于绝大多数实际应用场景,SFT是唯一理性的选择。它不是技术上的妥协,而是对问题本质的深刻洞察:小票信息抽取,是一个定义清晰、答案唯一、评估标准明确的“闭合世界”问题,根本不需要引入开放世界的强化学习那一套复杂范式。

3. 核心细节解析:数据、标注与预处理的魔鬼细节

3.1 数据集构建:质量远胜于数量,一张好图顶十张废图

很多人一上来就想找“一万张小票数据集”,这是最大的误区。Donut这类模型,吃的是“精粮”,不是“粗糠”。我做过一个对比实验:用100张来自网络爬取、分辨率模糊、角度倾斜、背景杂乱的“脏数据”,和50张我自己用手机在不同光线、不同角度、不同距离下拍摄的真实小票(确保文字清晰、无严重遮挡),分别去微调同一个Donut模型。结果,50张“干净”数据的F1值(衡量抽取准确率的核心指标)达到了89.2%,而100张“脏数据”的F1值只有76.5%。差距高达12.7个百分点。这说明,数据清洗和筛选,其重要性甚至超过了数据量本身。那么,什么样的小票图才是“好图”?我总结了三条铁律。第一,文字必须清晰可辨。这是底线。任何出现墨水洇开、打印虚影、反光过曝导致文字断连的图片,一律剔除。你可以用OpenCV做一个简单的预处理脚本:计算图片的梯度幅值均值,低于某个阈值(比如30)的,就判定为“模糊”,自动过滤掉。第二,主体必须居中且占满画面。不要拍出半个收银台、半截手指,或者把小票放在桌子一角,周围全是杂物。理想状态是,小票的四边几乎贴满图片的四边,留白不超过5%。这样能最大化ViT的有效感受野,避免模型把大量算力浪费在理解无关的背景上。第三,多样性要体现在“真实场景”上,而非“花哨形式”上。不必刻意去找几十种不同品牌的小票。重点是覆盖你真实会遇到的“麻烦”:比如,有几张是晚上在昏暗灯光下拍的(低光照);有几张是小票刚从热敏打印机出来,字迹还没完全稳定(轻微褪色);有几张是被揉过又展平的(有细微褶皱)。这些“真实缺陷”,才是模型未来在生产环境里真正要面对的敌人。我在准备自己的数据集时,就专门设置了“挑战样本”文件夹,里面放了10张最难搞的图,比如一张被咖啡泼了一半的小票。微调完成后,我首先就用这10张图做压力测试,如果它们都能过关,那日常使用就基本无忧了。

3.2 标注规范:用JSON Schema定义“标准答案”,杜绝歧义

标注,是整个微调流程中最耗时、也最容易出错的环节。很多新手在这里栽跟头,不是因为技术不行,而是因为“标准答案”本身就不标准。我见过最离谱的案例,是团队里两位标注员对“商品名”的理解完全不同:A认为“冰美式(大杯)”应该标注为"name": "冰美式",B却坚持要保留括号里的规格"name": "冰美式(大杯)"。结果,模型学到了两种矛盾的模式,生成时随机选择,准确率自然惨不忍睹。要根治这个问题,唯一的办法,就是制定一份白纸黑字、不容置疑的JSON Schema规范。这不是一个可选文档,而是标注工作的宪法。下面是我为便利店小票定制的最小可行Schema:

{ "type": "object", "properties": { "store_name": {"type": "string"}, "date": {"type": "string", "pattern": "^\\d{4}-\\d{2}-\\d{2}$"}, "time": {"type": "string", "pattern": "^\\d{2}:\\d{2}$"}, "items": { "type": "array", "items": { "type": "object", "properties": { "name": {"type": "string"}, "quantity": {"type": "integer"}, "price": {"type": "number"} }, "required": ["name", "quantity", "price"] } }, "total_amount": {"type": "number"} }, "required": ["store_name", "date", "time", "items", "total_amount"] }

这份Schema的威力,在于它用机器可验证的语言,锁死了所有可能的歧义。"pattern": "^\\d{4}-\\d{2}-\\d{2}$"这一行,就强制规定了日期必须是“2024-04-12”这种格式,不允许“12/04/2024”或“2024年4月12日”。"required"字段则确保了,哪怕某张小票上没印店名,标注员也必须根据小票LOGO或地址信息,人工补全"store_name",不能留空。有了这个Schema,我们就可以用Python的jsonschema库,写一个自动校验脚本。每次标注员提交一个JSON文件,脚本就立刻运行一次校验。如果报错,比如提示"date" is not of type "string",那就说明标注员把日期写成了数字20240412,必须打回重标。这个看似繁琐的步骤,实际上节省了后期数倍的返工时间。我建议,把Schema文档和校验脚本,一起放进项目的docs/目录下,并在README里用加粗字体强调:“所有标注,必须通过validate_annotations.py脚本校验,否则不予接收。

3.3 图像预处理:不是越“高级”越好,而是越“匹配”越好

Donut的官方ViT编码器,是在ImageNet等通用数据集上预训练的,它的输入要求是:尺寸为224x224像素,像素值归一化到[0, 1]区间,且使用ImageNet的均值和标准差进行标准化(即mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])。这是一个非常重要的前提,意味着我们的预处理流程,必须严格遵循这个“出厂设置”,而不是想当然地用自己觉得“好看”的方式。我曾经犯过一个经典错误:为了提升小票文字的对比度,我用OpenCV的CLAHE(限制对比度自适应直方图均衡化)算法对所有图片进行了增强。结果,模型在训练集上表现很好,但在测试集上却大面积失效。原因很简单:CLAHE改变了图像的像素分布,使得输入到ViT的特征,与它在预训练时所“习惯”的特征分布产生了巨大偏移。ViT的前几层卷积核,是为识别自然图像中的纹理、边缘而优化的,突然面对一堆被过度锐化、对比度爆炸的“人造”小票,它就懵了。正确的做法,是做最保守、最忠实的预处理。核心就三步:第一步,Resize & Pad。先将原始图片等比例缩放到长边为224像素,然后用黑色(RGB值为0)在短边进行填充(pad),确保最终输出一定是严格的224x224正方形。千万不要用cv2.resize(img, (224, 224))这种暴力拉伸,那会把小票文字压扁或拉长,彻底破坏其几何结构。第二步,To Tensor & Normalize。用PyTorch的transforms.ToTensor()将HWC格式的numpy数组转为CHW格式的tensor,然后用transforms.Normalize()进行标准化。这一步必须用Donut官方指定的均值和标准差,一个数字都不能错。第三步,数据增强(Augmentation)要极其克制。对于小票这种结构化文档,随机旋转、随机裁剪、颜色抖动等常规增强,大概率是有害的。唯一推荐的增强,是随机水平翻转(RandomHorizontalFlip),概率设为0.5。因为现实中,小票被拿反的概率确实存在,而且翻转后,文字的相对位置关系(左对齐、右对齐)依然保持不变,这对模型学习空间关系是有益的。其他任何增强,除非你有非常充分的理由和实验验证,否则一律禁用。记住,预处理的目标不是让图片“更好看”,而是让模型“更容易理解”。

4. 实操过程详解:从零开始,一行一行跑通微调全流程

4.1 环境搭建与依赖安装:避开CUDA版本的“深坑”

在动手写代码之前,环境配置是第一个也是最重要的关卡。Donut的官方实现是基于PyTorch和Hugging Face Transformers库的,因此,CUDA版本的兼容性是生死线。我踩过的最大一个坑,是用CUDA 12.1搭配PyTorch 2.0。表面上看,pip install torch==2.0.0+cu121安装成功,torch.cuda.is_available()也返回True,一切都很美好。但当你运行到model.generate()这一步时,程序会毫无征兆地卡死,GPU显存占用飙升到100%,然后整个进程被系统OOM Killer无情杀死。查了三天日志,最后发现,这是PyTorch 2.0的一个已知bug,它在CUDA 12.1上对某些Transformer层的内存管理存在缺陷。解决方案异常简单粗暴:降级到CUDA 11.8。以下是我在Ubuntu 22.04 + RTX 3060环境下,亲测100%成功的环境搭建命令:

# 1. 创建并激活conda环境(强烈推荐,避免包冲突) conda create -n donut-ft python=3.9 conda activate donut-ft # 2. 安装CUDA 11.8对应的PyTorch(注意:必须指定-c pytorch这个channel) pip3 install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 # 3. 安装Hugging Face生态核心库 pip install transformers==4.26.1 datasets==2.9.0 accelerate==0.16.0 # 4. 安装Donut官方库(注意:不是pypi上的donut,而是GitHub repo) pip install git+https://github.com/clovaai/donut.git@main # 5. 验证安装 python -c "import torch; print(torch.__version__, torch.cuda.is_available())" # 输出应为:1.13.1+cu117 True

这里有几个关键点必须强调。第一,torch==1.13.1+cu117这个版本号,cu117代表CUDA 11.7,但它是向下兼容CUDA 11.8的,这是NVIDIA官方文档明确说明的。第二,transformers==4.26.1这个版本,是Donut官方requirements.txt里锁定的版本,高了或低了都可能引发API不兼容。Donut的Processor类在4.27版本后做了重构,如果你用了新版,processor(image)这行代码就会直接报错。第三,accelerate库是Hugging Face用于分布式训练的利器,虽然我们单卡微调用不到它的全部功能,但它能帮我们优雅地处理device_map和混合精度训练,是必备组件。完成这五步后,你的环境就稳如磐石了。接下来的所有操作,都不会再被底层环境问题打断。

4.2 数据加载与Dataset类编写:让PyTorch“读懂”你的小票

PyTorch的Dataset类,是连接你硬盘上那些.jpg.json文件,与模型训练循环之间的桥梁。一个写得好的Dataset,能让后续的训练代码简洁如诗;一个写得烂的Dataset,则会让你在__getitem__方法里陷入无穷无尽的try...except嵌套和路径拼接噩梦。Donut的数据加载,有一个独特之处:它需要同时加载图像和对应的结构化JSON标签,并且要把JSON标签“序列化”成一个特殊的字符串,这个字符串就是Decoder的输入目标。Donut官方提供了一个DonutProcessor类,它能自动完成这个序列化过程。我们的CustomReceiptDataset类,核心职责就是:读取一张图片,读取它对应的JSON文件,用processor把JSON“翻译”成模型能理解的字符串。下面是完整的、经过生产环境验证的代码:

from torch.utils.data import Dataset from PIL import Image import json import os from donut import DonutModel, DonutProcessor class CustomReceiptDataset(Dataset): def __init__(self, root_dir, processor, max_length=512): """ 初始化数据集 :param root_dir: 数据集根目录,下有images/和labels/两个子文件夹 :param processor: DonutProcessor实例 :param max_length: 序列最大长度,防止过长JSON导致OOM """ self.root_dir = root_dir self.processor = processor self.max_length = max_length # 假设图片和标签文件名一一对应,如 image_001.jpg -> image_001.json self.image_files = [f for f in os.listdir(os.path.join(root_dir, "images")) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] def __len__(self): return len(self.image_files) def __getitem__(self, idx): # 1. 加载图像 img_path = os.path.join(self.root_dir, "images", self.image_files[idx]) image = Image.open(img_path).convert("RGB") # 强制转为RGB,避免RGBA报错 # 2. 加载JSON标签 json_filename = os.path.splitext(self.image_files[idx])[0] + ".json" json_path = os.path.join(self.root_dir, "labels", json_filename) with open(json_path, "r", encoding="utf-8") as f: label_data = json.load(f) # 3. 将JSON数据转换为Donut所需的"序列化"字符串 # 这里使用Donut官方的"prompt",它定义了输出的结构 prompt = "<s_receipt><s_table>" # 开始标签 # 我们可以在这里动态构建prompt,但为简单起见,用固定prompt # 实际项目中,prompt可以根据小票类型变化,比如<s_invoice>或<s_form> # 4. 使用processor对图像和prompt进行编码 # 注意:processor会自动对图像进行resize/pad/normalize,并对prompt进行tokenize encoding = self.processor( images=image, text=prompt, add_special_tokens=True, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt" ) # 5. 准备Decoder的标签(即我们要模型生成的“正确答案”) # 这里,我们把整个JSON对象,用processor的tokenizer编码成一个token序列 # 并添加特殊的结束符 target_sequence = self.processor.tokenizer( json.dumps(label_data, ensure_ascii=False), add_special_tokens=False, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt" )["input_ids"].squeeze(0) # 移除batch维度 # 6. 构建最终的样本字典 # 'pixel_values'是ViT的输入,'input_ids'是Decoder的输入(prompt),'labels'是Decoder的目标输出 sample = { "pixel_values": encoding["pixel_values"].squeeze(0), # [C, H, W] "input_ids": encoding["input_ids"].squeeze(0), # [L] "labels": target_sequence # [L] } return sample # 使用示例 processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base") dataset = CustomReceiptDataset("./data", processor) print(f"数据集大小: {len(dataset)}") sample = dataset[0] print(f"图像形状: {sample['pixel_values'].shape}") # torch.Size([3, 224, 224]) print(f"Prompt长度: {sample['input_ids'].shape}") # torch.Size([512]) print(f"Label长度: {sample['labels'].shape}") # torch.Size([512])

这段代码的关键,在于第4步和第5步的配合。processor(images=image, text=prompt)这行,完成了图像的预处理和prompt的tokenize,生成了pixel_valuesinput_ids。而processor.tokenizer(json_string)这行,则是把我们辛苦标注的JSON,变成了模型要努力“复现”的目标序列。这两者共同构成了一个完整的(输入,目标)训练样本。CustomReceiptDataset类的另一个优点是,它把所有路径拼接、文件名匹配、编码格式(encoding="utf-8")等琐碎细节都封装好了,你在训练循环里,只需要调用for batch in dataloader:,就能拿到一个开箱即用的batch字典,里面已经包含了模型所需的一切张量。

4.3 模型加载、训练配置与Trainer启动:用Hugging Face的“自动驾驶”

Hugging Face的TrainerAPI,是微调流程的“自动驾驶系统”。它把数据加载、模型前向/反向传播、梯度更新、日志记录、模型保存等所有繁杂的底层细节,都封装成了一个高度抽象的Trainer对象。你只需要告诉它“用哪个模型”、“用哪个数据集”、“训练多少轮”,它就能帮你把一切都搞定。对于Donut这种结构稍复杂的模型,Trainer的配置尤为关键。下面是我为小票微调定制的、经过多次实验验证的最优配置:

from transformers import TrainingArguments, Trainer from donut import DonutModel # 1. 加载预训练的Donut模型 # 注意:这里必须用'naver-clova-ix/donut-base',而不是'donut-base-finetuned-docvqa' # 后者是为DocVQA数据集微调过的,对我们小票任务是负迁移 model = DonutModel.from_pretrained("naver-clova-ix/donut-base") # 2. 关键配置:冻结ViT编码器,只训练Decoder # 这是节省显存、加速训练、防止过拟合的黄金法则 for param in model.encoder.parameters(): param.requires_grad = False # 3. 定义训练参数 training_args = TrainingArguments( output_dir="./donut-receipt-finetuned", # 模型和日志保存路径 per_device_train_batch_size=2, # 单卡batch size,3060只能用2 per_device_eval_batch_size=2, # 评估时的batch size num_train_epochs=10, # 训练10个epoch,足够收敛 warmup_steps=500, # 学习率预热步数,防止初期震荡 save_steps=1000, # 每1000步保存一次检查点 logging_steps=50, # 每50步打印一次loss evaluation_strategy="steps", # 每隔一定步数进行评估 eval_steps=500, # 每500步评估一次 load_best_model_at_end=True, # 训练结束后,自动加载验证集上最好的模型 metric_for_best_model="eval_loss", # 用验证loss作为最佳模型的评判标准 greater_is_better=False, # loss越小越好 save_total_limit=2, # 只保留最近的2个检查点,省磁盘 remove_unused_columns=False, # 必须设为False!Donut的dataset有特殊列 report_to="none", # 不上报到wandb等平台,本地日志即可 fp16=True, # 启用混合精度训练,显存减半,速度翻倍 dataloader_num_workers=4, # 用4个子进程预加载数据,提速 ) # 4. 创建Trainer实例 trainer = Trainer( model=model, args=training_args, train_dataset=dataset, # 我们上面定义的CustomReceiptDataset # eval_dataset=eval_dataset, # 如果有验证集,可以传入 tokenizer=processor.tokenizer, # 必须传入tokenizer,用于计算metrics ) # 5. 开始训练! trainer.train()

这段代码里,有三个地方是成败的关键。第一,model = DonutModel.from_pretrained("naver-clova-ix/donut-base")。你可能会在网上看到一些教程,推荐用donut-base-finetuned-docvqa这个checkpoint。千万别信!那个模型是在DocVQA(一个问答数据集)上微调过的,它的Decoder已经被“洗脑”成回答问题的模式,而不是生成JSON。用它来微调小票,效果会比从donut-base开始差一大截。第二,for param in model.encoder.parameters(): param.requires_grad = False。这行代码,是整个配置的灵魂。它把ViT编码器的所有参数都“锁死”,只让Decoder的参数去学习。这不仅让显存占用从11GB降到6GB,更重要的是,它保护了ViT强大的通用文档理解能力,只让模型去专注学习“小票”这个特定领域的输出格式。第三,fp16=True。混合精度训练是现代GPU的标配。它让模型在计算时,一部分用16位浮点数(速度快、显存省),一部分用32位(保证精度)。开启它,你的训练速度能提升40%-60%,而最终模型精度几乎不受影响。Trainer会自动处理所有底层的autocastGradScaler逻辑,你完全不用操心。运行trainer.train()之后,你就能在终端看到实时的loss下降曲线,以及eval_loss的波动。一个健康的训练过程,应该是train_losseval_loss同步、平稳地下降,没有剧烈的上下跳动。如果eval_loss在某个点后开始反弹,而train_loss还在下降,那就说明模型开始过拟合了,你需要提前停止训练。

4.4 模型推理与结果解析:如何把“一串token”变成“可用的JSON”

训练完成,模型保存在./donut-receipt-finetuned/checkpoint-xxxx/目录下。下一步,就是见证奇迹的时刻:用一张全新的、模型从未见过的小票图片,看看它能否准确地“说出”里面的信息。Donut的推理过程,和训练时的generate方法一脉相承,但需要额外的后处理步骤,才能把模型输出的“token ID序列”,还原成我们熟悉的JSON对象。以下是完整的推理脚本:

from donut import DonutModel, DonutProcessor from PIL import Image import torch import json # 1. 加载微调好的模型和processor model = DonutModel.from_pretrained("./donut-receipt-finetuned/checkpoint-5000") processor = DonutProcessor.from_pretrained("./donut-receipt-finetuned/checkpoint-5000") # 2. 设置模型为评估模式,并移动到GPU model.eval() model.to("cuda") # 3. 加载待推理的图片 image = Image.open("./data/images/test_receipt.jpg").convert("RGB") # 4. 构造prompt(必须和训练时一致!) prompt = "<s_receipt><s_table>" # 5. 使用processor对图像和prompt进行编码 # 注意:这里要用processor的`__call__`方法,而不是`batch_encode_plus` pixel_values = processor(images=image, return_tensors="pt").pixel_values pixel_values = pixel_values.to("cuda") # 6. 模型生成 # `generate`方法会自动进行自回归解码,直到遇到EOS token或达到max_length outputs = model.generate( pixel_values=pixel_values, prompt=prompt, max_length=model.config.max_position_embeddings, early_stopping=True, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, # 贪婪搜索,最快;设为4是beam search,更准但慢 bad_words_ids=[[processor.tokenizer.unk_token_id]], # 禁止生成<unk> token return_dict_in_generate=True, ) # 7. 解码生成的token IDs为文本 seq = outputs.sequences[0].cpu() # 取第一个(也是唯一一个)生成序列 decoded = processor.tokenizer.decode(seq, skip_special_tokens=True) # 8. 关键后处理:提取JSON字符串 # Donut生成的文本,开头是prompt,中间是JSON,结尾是</s> # 我们需要把JSON部分精准地切出来 try: # 找到第一个 '{' 和最后一个 '}' 的位置 start_idx = decoded.find('{') end_idx = decoded.rfind('}') if start_idx == -1 or end_idx == -1: raise ValueError("Generated text does not contain valid JSON.") json_str = decoded[start_idx:end_idx+1] # 尝试解析JSON result = json.loads(json_str) print("成功解析的JSON:") print(json.dumps(result, indent=2, ensure_ascii=False)) except json.JSONDecodeError as e: print(f"JSON解析失败: {e}") print(f"原始生成文本: {decoded}") except Exception as e: print(f"其他错误: {e}") # 9. (可选)用我们之前定义的JSON Schema进行校验 # from jsonschema import
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/22 3:04:57

计算硬件安装与调试以及组成的原理

一、计算机的组成原理&#xff1a;程序和数据提前存入内存&#xff0c;计算机自动逐条取指令、执行&#xff0c;无需人工拨开关。由此定下六大特征&#xff1a;五大部件&#xff08;运算器、控制器、存储器、输入、输出&#xff09;指令和数据 同等地位 存在内存中二进制表示指…

作者头像 李华
网站建设 2026/5/22 3:01:39

独家逆向分析ElevenLabs印地文语音模型架构(基于HTTP/3流量捕获+声学特征聚类):发现其隐式支持马拉地语-印地语混合语境

更多请点击&#xff1a; https://codechina.net 第一章&#xff1a;ElevenLabs印地文语音模型的逆向分析背景与核心发现 近年来&#xff0c;ElevenLabs 以高保真多语言语音合成能力著称&#xff0c;但其印地文&#xff08;Hindi&#xff09;语音模型未公开架构细节、训练数据构…

作者头像 李华
网站建设 2026/5/22 3:01:19

2026三相温升交直流升流器:变压器试验老兵的心里话

我在湖北一个220kV变电站干试验那会儿&#xff0c;夏天最怕做变压器温升。有一回&#xff0c;用老设备给主变做满载温升&#xff0c;测到一半&#xff0c;升流器自己先扛不住了&#xff0c;输出电流往下掉&#xff0c;散热风扇嗷嗷叫。甲方的人就站在旁边&#xff0c;脸都黑了。…

作者头像 李华
网站建设 2026/5/22 2:58:28

模型加速全景图:从“瘦身”到“飞驰”的知识图谱

文章目录知识图谱&#xff1a;模型加速的三大维度维度一&#xff1a;模型自身优化&#xff08;让模型更“瘦”&#xff09;维度二&#xff1a;计算过程优化&#xff08;让计算更“顺”&#xff09;维度三&#xff1a;硬件与系统优化&#xff08;让硬件更“忙”&#xff09;如何…

作者头像 李华