news 2026/5/1 7:16:51

Mosaic:面向超长序列的多GPU注意力分片方案

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Mosaic:面向超长序列的多GPU注意力分片方案

Transformer的"二次方注意力瓶颈"的问题是老生常谈了。这个瓶颈到底卡在哪实际工程里怎么绕过去?本文从一个具体问题出发,介绍Mosaic这套多轴注意力分片方案的设计思路。

注意力的内存困境

注意力机制的计算公式:

Attention(Q, K, V) = softmax(QKᵀ / √d) × V

问题出在QKᵀ这个矩阵上,它的形状是

(序列长度 × 序列长度)

拿150,000个token的序列算一下:

Memory = 150,000² × 4 bytes = 90 billion bytes ≈ 84 GB

这只是注意力权重本身的开销,而且还是单层、单头。A100的显存上限是80GB,放不下就是放不下。

现有方案的局限

FlashAttention它通过分块计算,不需要把完整的注意力矩阵实例化出来,内存复杂度从O(n²)降到O(n)。单卡场景下效果很好,但问题是整个序列还是得塞进同一张GPU。

Ring Attention换了个思路:把序列切片分到多张GPU上,每张卡持有一部分Q,K和V在GPU之间像传令牌一样轮转,一维序列处理起来是很不错的。

但是多维怎么办?

比如处理表格数据的Transformer,输入张量形状是

(batch, rows, features, embed)

。模型需要在不同维度上做注意力:features维度只有5个token,rows维度却有150,000个。前者单卡轻松搞定,后者则必须分片。

现有的库都没法干净地处理这种多轴场景。手写的话,每个轴要单独写分片逻辑,进程组管理、张量reshape全得自己来。代码会变得很脏。

Mosaic的设计

Mosaic本质上是个协调层,负责把不同的注意力轴路由到合适的计算后端:

import mosaic # Small axis: run locally feature_attn = mosaic.MultiAxisAttention( embed_dim=96, num_heads=4, attention_axis=2, # features dimension backend="local" # no communication needed ) # Large axis: shard across GPUs row_attn = mosaic.MultiAxisAttention( embed_dim=96, num_heads=4, attention_axis=1, # rows dimension backend="ring" # ring attention across GPUs )

底层Mosaic会自动处理轴的置换、QKV投影前的reshape、后端分发、以及计算完成后张量形状的还原。模型代码保持清晰,分布式的复杂性被封装掉了。

Ring Attention的工作机制

核心思想其实很直接:不需要同时持有全部的K和V。可以分批计算注意力分数,逐步累积,最后再做归一化。

比如说4张GPU的情况下流程是这样的:

Initial state: GPU 0: Q₀, K₀, V₀ GPU 1: Q₁, K₁, V₁ GPU 2: Q₂, K₂, V₂ GPU 3: Q₃, K₃, V₃ Step 1: Each GPU computes attention with its local K, V GPU 0: score₀₀ = Q₀ @ K₀ᵀ ... Step 2: Pass K, V to the next GPU in the ring GPU 0 receives K₃, V₃ from GPU 3 GPU 0 sends K₀, V₀ to GPU 1 Step 3: Compute attention with received K, V GPU 0: score₀₃ = Q₀ @ K₃ᵀ Accumulate with score₀₀ Repeat for all chunks... Final: Each GPU has complete attention output for its Q chunk

单卡内存占用变成O(n²/p),p是GPU数量。8张卡的话内存需求直接砍到1/8。150k序列从84GB降到约10GB每卡。

Mesh2D:更激进的分片

序列特别长的时候Ring Attention的线性分片可能还不够,这时候可以用Mesh2D把Q和K都切分了:

4 GPUs arranged in 2×2 mesh: K₀ K₁ ┌──────┬──────┐ Q₀ │GPU 0 │GPU 1 │ ├──────┼──────┤ Q₁ │GPU 2 │GPU 3 │ └──────┴──────┘ Each GPU computes one tile of QKᵀ

内存复杂度降到O(n²/p²)。64张卡组成8×8网格时,每卡内存需求下降64倍。

attn=mosaic.MultiAxisAttention( embed_dim=128, num_heads=8, attention_axis=1, backend="mesh2d", mesh_shape=(8, 8) )

感知集群拓扑的组合策略

在实际部署环境里,不同GPU之间的通信带宽差异很大。节点内GPU走NVLink能到900 GB/s,跨节点通过InfiniBand通常只有200 GB/s左右。

ComposedAttention

就是针对这种拓扑特征设计的:

# 4 nodes × 8 GPUs = 32 total composed = mosaic.ComposedAttention( mesh_shape=(4, 8), # (nodes, gpus_per_node) head_parallel=True, # Split heads across nodes (slow link) seq_parallel="ring" # Ring within nodes (fast link) )

需要更精细控制的话,可以用

HierarchicalAttention

hier = mosaic.HierarchicalAttention( intra_node_size=8, intra_node_strategy="local", # Compute locally within node inter_node_strategy="ring" # Ring between node leaders )

重通信走快链路轻通信才跨节点。

实现细节

整个库大约800行Python,核心代码如下:

class MultiAxisAttention(nn.Module): def forward(self, x): # 1. Move attention axis to seq position x, inv_perm = self._permute_to_seq(x) # 2. Flatten batch dims, project QKV x = x.view(-1, seq_len, embed_dim) qkv = self.qkv_proj(x).view(batch, seq, 3, heads, head_dim) q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) # 3. Dispatch to backend out = self._attn_fn(q, k, v) # local, ring, or mesh2d # 4. Project output, restore shape out = self.out_proj(out.transpose(1, 2).reshape(...)) return out.permute(inv_perm)

后端封装了现有的成熟实现:

local

后端调用

F.scaled_dot_product_attention

(也就是FlashAttention),

ring

后端用ring-flash-attn库的

ring_flash_attn_func

mesh2d

是自定义的all-gather加SDPA,所有的底层都跑的是FlashAttention内核。

所有后端统一用FlashAttention的融合GEMM+softmax实现。后端函数在初始化时就绑定好,前向传播不做分支判断。张量操作尽量用

x.view()

而不是

x.reshape()

,保持内存连续性。集合通信的目标张量预分配好,避免

torch.cat

的开销。模块级别做导入不在每次前向传播时产生import开销。

快速上手

安装:

pip install git+https://github.com/stprnvsh/mosaic.git # With ring attention support pip install flash-attn ring-flash-attn

单节点启动:

torchrun --nproc_per_node=4 train.py

多节点的话:

# Node 0 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \ --master_addr=192.168.1.100 --master_port=29500 train.py # Node 1 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \ --master_addr=192.168.1.100 --master_port=29500 train.py

训练脚本示例:

import mosaic import torch.distributed as dist dist.init_process_group("nccl") ctx = mosaic.init(sp_size=dist.get_world_size()) model = MyModel().to(ctx.device) # Data is pre-sharded: each GPU has seq_total / world_size tokens x_local = load_my_shard() out = model(x_local) # Communication handled by Mosaic

总结

最后,Mosaic不会自动并行化模型(这个用nnScaler),不管数据并行(PyTorch DDP/FSDP的事),也不处理模型分片(交给FSDP或Megatron)。

Mosaic专注于一件事:多轴注意力的分片路由,这套方案最初是给nanoTabPFN做的,一个表格数据Transformer。

这个模型要同时在rows(150k个)和features(5个)两个维度做注意力。标准Ring Attention对维度语义没有感知,它只认序列这个概念,分不清rows和features的区别。

所以Mosaic需求很明确:小轴本地算,大轴分布式算,轴的路由逻辑不能侵入模型代码,有兴趣的可以试试。

https://avoid.overfit.cn/post/791e0f30540e4d289a43d01d383e8ab2

作者:Pranav Sateesh

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/1 6:08:40

微信小程序 PHP_uniapp的农产品质量追溯系统_gkm0juhi

微信小程序 PHPUniapp 农产品质量追溯系统摘要 该系统基于微信小程序开发,采用PHP后端与Uniapp前端框架结合,实现农产品从生产到销售的全流程质量追溯。系统通过区块链技术确保数据不可篡改,提升消费者对农产品安全的信任度。 核心功能模块 生…

作者头像 李华
网站建设 2026/5/1 5:00:21

微信小程序 PHP_uniapp的汽车销售库存管理系统785h00gj

系统概述 微信小程序结合PHP和UniApp开发的汽车销售库存管理系统(编号785h00gj)是一款针对汽车经销商设计的数字化管理工具。该系统通过移动端与后台协同,实现车辆信息管理、库存监控、销售跟踪、客户管理等功能,提升汽车销售流程…

作者头像 李华
网站建设 2026/4/15 3:42:24

PyQt5(十一)如何打包成exe

1、在pycharm的终端输入pip install pyinstaller -i https://pypi.tuna.tsinghua.edu.cn/simple然后输入pyinstaller -F -w main.py在文件夹的dist中就会出现main.exe,打开速度看电脑配置。

作者头像 李华
网站建设 2026/5/1 4:59:59

基于大数据爬虫+Python+机器学习的电商农产品销售预测系统设计与实现(精品源码+论文+答辩PPT)

博主介绍:CSDN毕设辅导第一人、靠谱第一人、全网粉丝50W,csdn特邀作者、博客专家、腾讯云社区合作讲师、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交…

作者头像 李华
网站建设 2026/4/24 4:48:25

全新UI简易漂流瓶系统源码

源码介绍:全新UI简易漂流瓶系统源码 更新日志: ✔ 功能优化:新增「自己不能捡到自己瓶子」的防误触机制 ✔ 交互升级:优化页面细节,操作体验更流畅下载地址(无套路,无须解压密码)htt…

作者头像 李华
网站建设 2026/5/1 5:02:13

格雷厄姆的工作资本策略:关注流动性

格雷厄姆的工作资本策略:关注流动性关键词:格雷厄姆、工作资本策略、流动性、价值投资、财务分析、股票估值、投资决策摘要:本文深入探讨了格雷厄姆的工作资本策略,重点关注其对流动性的考量。首先介绍了该策略提出的背景和目的&a…

作者头像 李华