news 2026/4/30 14:28:52

JAX性能优化实战:7个变换让TPU/GPU吃满算力

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
JAX性能优化实战:7个变换让TPU/GPU吃满算力

JAX跑得快的技巧其实很简单:通过组合变换让XLA能看到大块连续的计算,比如说批处理、融合、分片,让每一步在单设备或多设备同步时都像一个干净的kernel。

我们今天就来总结7个能够提高运行速度的JAX变换组合

1、 jit 优先,形状稳定

jit

对函数做一次追踪后XLA负责融合算子,形状稳定、无副作用时,Python处理的开销就被分摊掉,可以提高运行速度。

形状创建和静态参数要么挪到step外部,要么显式标记为static。

donate_argnums

能让JAX复用缓冲区,省掉不必要的内存拷贝。step之间保持dtype和shape一致,trace结果才能被缓存下来。

import jax, jax.numpy as jnp @jax.jit(donate_argnums=(0,)) def sgd_step(params, batch, lr): x, y = batch def loss_fn(p): preds = model_apply(p, x) # pure function return jnp.mean((preds - y) ** 2) grads = jax.grad(loss_fn)(params) return jax.tree_map(lambda p, g: p - lr * g, params, grads)

每个(shape, dtype, static-arg)组合只追踪一次。频繁retrace多半是输入shape在变,或者Python逻辑泄漏进了计算图。

2、vmap替换Python循环

vmap

在leading axis上做向量化,XLA直接把batch融进kernel。for循环没了设备launch就少了,内存访问也更连续。

# per-example loss def example_loss(params, x, y): pred = model_apply(params, x) return jnp.mean((pred - y) ** 2) # batch it without writing loops batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0)) # params broadcasted

嵌套

vmap

可以搞2D batch,比如time × batch,只要别超HBM容量。

vmap

适合做内层微批处理,比如ensemble或MC sampling这类场景,外层维度留给分片。

3、长循环的融合利器Scan

RNN、展开解码、迭代求解器,这些场景用

scan

比Python循环快。

scan

只编译一次循环体跑在XLA的while-loop里,Python开销基本为0,融合和内存复用也更激进。

from jax import lax def rnn_cell(carry, x): h = carry h = jnp.tanh(W_hh @ h + W_xh @ x + b) y = W_hy @ h return h, y # (carry, output) def rnn_forward(h0, xs): hT, ys = lax.scan(rnn_cell, h0, xs) # xs: [T, B, D] return hT, ys

循环状态用

carry

传递,body保持小而纯净,要注意保持形状不要变,比如:序列模型、diffusion step循环、定点迭代、beam解码(形状稳定时)都适用。

4、remat可以用计算换内存

批次大了TPU/GPU的FLOP利用率往往更高。

remat

(也叫checkpoint)会丢掉部分中间激活,反向时重算这样峰值显存下来batch就能开的更大。

from jax import remat def block(params, x): x = jax.nn.gelu(x @ params['w1']) x = x @ params['w2'] return x fast_block = remat(block) # checkpointed @jax.jit def forward(params, x): for _ in range(6): x = x + fast_block(params, x) return x

只包最重的子块就行,比如attention加MLP那几层。同时配合

vmap

或分片,全局batch能再往上拉。不过需要一些额外FLOPs,但如果换来1.3到2倍的batch increase,wall-clock往往更短。

5、pmap单机多卡数据并行

pmap

把函数复制到单主机的多个设备上(8卡工作站、单节点8核TPU),梯度可以自动all-reduce,并且每设备只编译一次。

from jax import pmap, lax @pmap(axis_name='d') def train_step(params, batch, lr): x, y = batch # each device sees [local_B, ...] def loss_fn(p): pred = model_apply(p, x) loss = jnp.mean((pred - y) ** 2) return loss loss, grads = jax.value_and_grad(loss_fn)(params) loss = lax.pmean(loss, axis_name='d') grads = lax.pmean(grads, axis_name='d') params = jax.tree_map(lambda p, g: p - lr * g, params, grads) return params, loss

batch在leading axis分片,

lax.pmean

聚合loss和grads。单机场景下

pmap

简单可靠。跨主机扩展或者想做张量级细粒度分片可以成换

pjit

6、pjit+ 命名分片:SPMD并行

pjit

编译出单一SPMD程序可以跨设备跨主机运行。用mesh和

PartitionSpec

描述数组怎么切,JAX处理collective通信,这样数据并行、张量并行、混合并行都能做。

import jax from jax.sharding import Mesh, PartitionSpec as P import numpy as np devices = np.array(jax.devices()).reshape(2, 4) # 2 × 4 mesh (dp × mp) mesh = Mesh(devices, ('dp', 'mp')) @jax.jit # jit is optional when using pjit; shown when composing def model_apply_sharded(params, x): return model_apply(params, x) from jax.experimental.pjit import pjit with mesh: in_shard = (P('mp',), P('dp',)) # example; tailor to your shapes out_shard = P('dp',) # e.g., shard batch across dp step = pjit(model_apply_sharded, in_shardings=(P('mp',), P('dp',)), out_shardings=out_shard) y = step(params_sharded, x_sharded)

一般都是batch轴走

dp

,大矩阵维度(hidden size、heads)走

mp

。分片数需要跟设备拓扑对齐,跨主机流量才少。

7、value_and_grad的正确堆叠方式

规范写法是

jit(value_and_grad(loss, has_aux=True))

,外面可以再套一层

pmap

pjit

。这样forward只跑一遍metrics留在aux里带出来。

def loss_with_aux(params, batch): x, y = batch pred = model_apply(params, x) loss = jnp.mean((pred - y) ** 2) aux = {'mse': loss, 'mean_pred': jnp.mean(pred)} return loss, aux @jax.jit def train_step(params, opt_state, batch, lr): (loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, batch) updates, opt_state = optimizer_update(grads, opt_state, params, lr) params = optax_apply(updates, params) return params, opt_state, loss, aux
value_and_grad

jit

里面,JAX会把forward和backward一起stage。返回

(loss, aux)

日志指标不用再跑一遍forward。

这套组合很灵活:

vmap

做微批次,

scan

跑时序循环,外面套

pmap

pjit

donate_argnums

标上buffer。

总结

变长序列pad加mask,shape稳定是前提条件。traced代码里不要添加Python随机性,比如PRNG key要在外面split好。矩阵乘用

bfloat16

,这样数值稳定性也够用,吞吐量在TPU/GPU上表现的也很好。性能profile要重点看warm-up之后的tokens/sec或samples/sec。日志只看标量aux metrics就行,每step把大数组传回host是性能杀手。

JAX的性能不是黑盒:

jit
  • shape可以稳定打底,
vmap

做batch,

scan

融合循环,

remat

回收显存,

pmap

pjit

做扩展,

value_and_grad(..., has_aux=True)

让每一步只跑一次forward一次backward。

https://avoid.overfit.cn/post/84e4e28e3ca8473488a0e9248d1ec51b

作者:Nexumo

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 3:17:41

收藏必读:大模型架构演进全解析——从GPT-4到智能体的三大技术支柱

文章分析了2023-2025年大模型技术的演进,从GPT-4的"唯参数规模论"到效率、推理和智能体三大新支柱的确立。技术演进包括MoE稀疏架构、线性注意力机制、推理时计算(Thinking)以及智能体工具使用。未来趋势指向具身智能、世界模型和后Transformer架构探索&a…

作者头像 李华
网站建设 2026/4/25 3:01:14

华为OD机考双机位C卷 - 完美走位 (Java Python JS C/C++ GO )

最新华为上机考试 真题目录:点击查看目录 华为OD面试真题精选:点击立即查看 2025华为od机试双机位C卷 题目描述 在第一人称射击游戏中,玩家通过键盘的A、S、D、W四个按键控制游戏人物分别向左、向后、向右、向前进行移动,从而…

作者头像 李华
网站建设 2026/4/28 4:04:16

LangGraph 是什么?一文秒懂且通俗易懂!

“「 嘿兄弟,我好想交女朋友但都交不到,怎么办? 」” 身为 AI 工程师,为了帮他,当然是画个流程图啊! 交女朋友要分步骤,每个步骤都有单一目的。 如果失败也没关系,流程上我们退回去反…

作者头像 李华
网站建设 2026/4/15 6:40:02

win11蓝屏dump日志无法定位到具体应用终极解决方案

❤️博客主页: iknow181 🔥系列专栏: 网络安全、 Python、JavaSE、JavaWeb、CCNP 🎉欢迎大家点赞👍收藏⭐评论✍ 一、出现问题 不知道啥时候开始的,电脑有时莫名蓝屏,好像是去年开始的&#xff…

作者头像 李华
网站建设 2026/4/29 22:41:53

亲测好用!8款一键生成论文工具测评:本科生毕业论文必备

亲测好用!8款一键生成论文工具测评:本科生毕业论文必备 2026年学术写作工具测评:为何值得一看? 随着AI技术的不断进步,越来越多的学术写作工具涌现出来,帮助学生和研究者提升效率、优化内容质量。然而&…

作者头像 李华
网站建设 2026/5/1 4:49:07

springboot+vue基于spring的药品销售商城进销存管理系统的设计与实现

目录摘要技术要点开发技术核心代码参考示例1.建立用户稀疏矩阵,用于用户相似度计算【相似度矩阵】2.计算目标用户与其他用户的相似度总结源码文档获取/同行可拿货,招校园代理 :文章底部获取博主联系方式!摘要 该系统基于SpringBoot和Vue.js技…

作者头像 李华