主题
字号
CHAPTER 12 ≈ 90 MIN READ

JAX 入门

PyTorch 原地修改张量,TensorFlow 构建计算图,JAX 编译纯函数。最后这个会改变你对深度学习的思考方式。

学习目标

为什么要学这个

你知道怎么用 PyTorch 搭神经网络。定义 nn.Module,调 .backward(),step 优化器。它能跑。数百万人在用。

但 PyTorch 有一个刻在 DNA 里的约束:它逐个操作、在 Python 中即时追踪。每个 tensor + tensor 都是一次单独的内核调用。每个训练步骤都重新解释同一段 Python 代码。这在小规模没问题,但当你要在 2048 块 TPU 上训练 5400 亿参数的模型时,这些开销就要了命。

Google DeepMind 用 JAX 训练 Gemini。Anthropic 用 JAX 训练 Claude。这不是小打小闹——它们是地球上最大规模的神经网络训练。它们选 JAX 是因为 JAX 把你的训练循环当成一个可编译的程序,而不是一连串 Python 调用。

JAX 本质上就是带三种超能力的 NumPy:自动微分、JIT 编译到 XLA、自动向量化。你写一个处理单个样本的函数,JAX 给你一个能处理批量、计算梯度、编译成机器码、在多设备上运行的函数。而且原始函数一个字不用改。

核心概念

JAX 的哲学

JAX 是函数式框架。没有类,没有可变状态,没有 .backward() 方法。取而代之:

PyTorch JAX
nn.Module 类带状态 纯函数:f(params, x) -> y
loss.backward() jax.grad(loss_fn)(params, x, y)
即时执行 通过 XLA 做 JIT 编译
for x in batch: 手动循环 jax.vmap(f) 自动向量化
DataParallel / FSDP jax.pmap(f) 自动并行
可变的 model.parameters() 不可变的 pytree 数组

这不是代码风格偏好,而是编译器约束。JIT 编译要求纯函数——同样的输入永远产生同样的输出,没有副作用。这个限制恰恰是 100 倍加速的前提。

jax.numpy:熟悉的表面

JAX 在加速器上重新实现了 NumPy API:

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)

函数名一样,广播规则一样,切片语义一样。但数组跑在 GPU/TPU 上,而且每一步操作都可以被编译器追踪。

一个关键区别:JAX 数组是不可变的。不能写 a[0] = 5,而要写 a = a.at[0].set(5)。刚开始会觉得别扭,但一周后就想通了——不可变性正是 gradjitvmap 能互相组合的原因。

jax.grad:函数式自动微分

PyTorch 把梯度附着在张量上(.grad)。JAX 把梯度附着在函数上。

import jax

def f(x):
    return x ** 2

df = jax.grad(f)
df(3.0)  # 6.0

jax.grad 接收一个函数,返回一个计算其梯度的新函数。没有 .backward() 调用,张量上也不存计算图。梯度就是另一个可以调用、组合、JIT 编译的函数。

任意组合:

d2f = jax.grad(jax.grad(f))  # 二阶导数
d2f(3.0)  # 2.0

二阶导、三阶导、Jacobian、Hessian——都通过组合 grad 实现。PyTorch 也能做(torch.autograd.functional.hessian),但那是后来加上去的。在 JAX,这是根基。

约束:grad 只对纯函数有效。内部不能 print(会在 tracing 时执行而非运行时),不能修改外部状态,不能不带 key 就生成随机数。

jit:编译到 XLA

@jax.jit
def train_step(params, x, y):
    loss = loss_fn(params, x, y)
    return loss

fast_step = jax.jit(train_step)

第一次调用时,JAX 追踪(trace)函数——记录发生了哪些操作,但不真正执行。然后把这个 trace 交给 XLA(Accelerated Linear Algebra),Google 为 TPU 和 GPU 打造的编译器。XLA 融合操作、消除冗余内存拷贝、生成优化的机器码。

后续调用完全跳过 Python。编译后的代码以 C++ 速度跑在加速器上。

JIT 有用的情况:

JIT 有害的情况:

控制流限制是真实的。jax.lax.cond 替代 if/elsejax.lax.scan 替代 for 循环。这不是可选的——这是编译的代价。

vmap:自动向量化

你写一个处理单样本的函数:

def predict(params, x):
    return jnp.dot(params['w'], x) + params['b']

vmap 把它提升为处理批量:

batch_predict = jax.vmap(predict, in_axes=(None, 0))

in_axes=(None, 0) 意思是:不要在 params 上批处理(共享的),在 x 的第 0 维上批处理。没有手写 for 循环,没有 reshape,没有手动穿引 batch 维度。JAX 自己搞定 batch 维并向量化整个计算。

这不是语法糖。vmap 生成融合的向量化代码,比 Python 循环快 10–100 倍。而且它跟 jitgrad 可以组合:

per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

逐样本梯度,一行代码。这在 PyTorch 中几乎不可能不用 hack 实现。

pmap:跨设备数据并行

parallel_step = jax.pmap(train_step, axis_name='devices')

pmap 把函数复制到所有可用设备(GPU/TPU)上并切分 batch。函数内部用 jax.lax.pmeanjax.lax.psum 同步梯度。

Google 用 pmap(以及后续的 shard_map)在数千块 TPU v5e 上训练 Gemini。编程模型:写单设备版本,包一层 pmap,搞定。

Pytree:通用数据结构

JAX 操作"pytree"——列表、元组、字典和数组的嵌套组合。你的模型参数就是一个 pytree:

params = {
    'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
    'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
    'layer3': {'w': jnp.zeros((128, 10)),  'b': jnp.zeros(10)},
}

每个 JAX 变换——gradjitvmap——都知道怎么遍历 pytree。jax.tree.map(f, tree)f 应用到每个叶子。优化器就这样一次性更新所有参数:

params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

不需要 .parameters() 方法,不需要参数注册。树结构本身就是模型。

函数式 vs 面向对象

PyTorch 把状态存在对象里:

class Model(nn.Module):
    def __init__(self):
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        return self.linear(x)

JAX 用纯函数加显式状态:

def predict(params, x):
    return jnp.dot(x, params['w']) + params['b']

参数是传进来的。没有东西被存储,没有东西被修改。这让每个函数都可测试、可组合、可编译。代价是你要自己管理参数——或者用 Flax、Equinox 这类库。

JAX 生态系统

JAX 给你基础原语,库给你工程体验:

角色 风格
Flax(Google) 神经网络层 nn.Module + 显式状态
Equinox(Patrick Kidger) 神经网络层 基于 pytree,更 Pythonic
Optax(DeepMind) 优化器 + 学习率调度 可组合的梯度变换
Orbax(Google) Checkpointing 保存/恢复 pytree
CLU(Google) 指标 + 日志 训练循环工具

Optax 是标准优化器库。它把梯度变换(Adam、SGD、梯度裁剪)和参数更新分开,组合起来极其方便:

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

什么时候该用 JAX vs PyTorch

因素 JAX PyTorch
TPU 支持 一等公民(Google 两个都造) 社区维护(torch_xla)
GPU 支持 不错(通过 XLA 支持 CUDA) 业界最佳(原生 CUDA)
调试 难(tracing + 编译) 容易(eager,逐行执行)
生态 研究向(Flax、Equinox) 巨大(HuggingFace、torchvision 等)
招聘 小众(Google/DeepMind/Anthropic) 主流(到处都是)
大规模训练 更强(XLA、pmap、mesh) 不错(FSDP、DeepSpeed)
原型速度 更慢(函数式开销) 更快(改了就跑)
生产推理 TF Serving, Vertex AI TorchServe, Triton, ONNX
谁在用 DeepMind(Gemini)、Anthropic(Claude) Meta(Llama)、OpenAI(GPT)、Stability AI

坦诚答案:除非有明确理由,否则用 PyTorch。那些理由是——有 TPU、需要逐样本梯度、超大规模多设备训练,或者你在 Google/DeepMind/Anthropic 工作。

JAX 中的随机数

JAX 没有全局随机状态。每个随机操作都需要显式的 PRNG key:

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))

一开始觉得烦。但它保证了跨设备和跨编译的可复现性——这是 PyTorch 的 torch.manual_seed 在多 GPU 环境下做不到的。

从零实现

第一步:准备数据

我们用 JAX 和 Optax 在 MNIST 上训练一个 3 层 MLP。784 输入,两个隐藏层分别 256 和 128 神经元,10 个输出类别。

import jax
import jax.numpy as jnp
from jax import random
import optax

def get_mnist_data():
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist.data.astype('float32') / 255.0
    y = mnist.target.astype('int')
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    return X_train, y_train, X_test, y_test

第二步:初始化参数

没有类。只是一个返回 pytree 的函数:

def init_params(key):
    k1, k2, k3 = random.split(key, 3)
    scale1 = jnp.sqrt(2.0 / 784)
    scale2 = jnp.sqrt(2.0 / 256)
    scale3 = jnp.sqrt(2.0 / 128)
    params = {
        'layer1': {
            'w': scale1 * random.normal(k1, (784, 256)),
            'b': jnp.zeros(256),
        },
        'layer2': {
            'w': scale2 * random.normal(k2, (256, 128)),
            'b': jnp.zeros(128),
        },
        'layer3': {
            'w': scale3 * random.normal(k3, (128, 10)),
            'b': jnp.zeros(10),
        },
    }
    return params

He 初始化,手动完成。三个 PRNG key 从一个种子 split 出来。每个权重都是嵌套字典里的不可变数组。

第三步:前向传播

def forward(params, x):
    x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
    return x

def loss_fn(params, x, y):
    logits = forward(params, x)
    one_hot = jax.nn.one_hot(y, 10)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))

纯函数。参数传入,预测传出。没有 self,没有存储状态。loss_fn 从头计算交叉熵——softmax、log、取负均值。

第四步:JIT 编译的训练步骤

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

@jax.jit
def accuracy(params, x, y):
    logits = forward(params, x)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == y)

jax.value_and_grad 一次调用同时返回 loss 值和梯度。@jax.jit 装饰器把两个函数都编译到 XLA。首次调用后,每一步训练都不再碰 Python。

第五步:训练循环

optimizer = optax.adam(learning_rate=1e-3)

X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)

key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)

batch_size = 128
n_epochs = 10

for epoch in range(n_epochs):
    key, subkey = random.split(key)
    perm = random.permutation(subkey, len(X_train))
    X_shuffled = X_train[perm]
    y_shuffled = y_train[perm]

    epoch_loss = 0.0
    n_batches = len(X_train) // batch_size
    for i in range(n_batches):
        start = i * batch_size
        xb = X_shuffled[start:start + batch_size]
        yb = y_shuffled[start:start + batch_size]
        params, opt_state, loss = train_step(params, opt_state, xb, yb)
        epoch_loss += loss

    train_acc = accuracy(params, X_train[:5000], y_train[:5000])
    test_acc = accuracy(params, X_test, y_test)
    print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

10 个 epoch,测试准确率约 97%。第一个 epoch 慢(JIT 编译),第 2–10 个 epoch 很快。

注意这里缺了什么:没有 .zero_grad(),没有 .backward(),没有 .step()。整个更新是一次组合好的函数调用。梯度计算、Adam 变换、应用到参数——全在 train_step 里完成。

实战用法

Flax:Google 标准库

Flax 是最常用的 JAX 神经网络库。它加回了 nn.Module,但保持显式状态管理:

import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)

结构跟 PyTorch 类似,但 params 跟模型是分开的。model.init() 创建参数,model.apply(params, x) 跑前向传播。模型对象本身无状态。

Equinox:更 Pythonic 的替代品

Equinox(Patrick Kidger 开发)把模型表示为 pytree:

import equinox as eqx

model = eqx.nn.MLP(
    in_size=784, out_size=10, width_size=256, depth=2,
    activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)

模型本身就是 pytree。不需要 .apply()。参数就是模型的叶节点。这更接近 JAX 的思维方式。

Optax:可组合的优化器

Optax 把梯度变换和参数更新解耦:

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3,
    warmup_steps=1000, decay_steps=50000
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)

梯度裁剪、学习率 warmup、weight decay——全部作为变换链组合起来。每个变换看到梯度,修改它,传给下一个。没有巨大的单体优化器类。

练习

  1. 给 MLP 加 Dropout。 在 JAX 中,Dropout 需要 PRNG key——把 key 穿过前向传播,每层 Dropout 都 split 一次。对比有无 Dropout 的测试准确率。

  2. jax.vmap 计算逐样本梯度。 对一批 32 张 MNIST 图片计算每个样本的梯度,算出每个样本的梯度范数。哪些样本梯度最大?为什么?

  3. 写一个通用的 MLP forward 函数。jax.tree.leaves 自动判断层数,让 mlp_forward(params, x) 能处理任意深度的网络。

  4. Benchmark JIT 加速。 有无 @jax.jit 分别跑 100 步训练,计时对比。你的硬件上加速多少倍?第一次调用的编译开销是多少?

  5. 实现梯度裁剪。 组合 optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))。有无裁剪分别训练,画出训练过程中梯度范数的变化曲线。

术语表

术语 通俗说法 真正含义
XLA "让 JAX 快的东西" Accelerated Linear Algebra——从计算图生成优化 GPU/TPU 内核的编译器
JIT "即时编译" JAX 在首次调用时追踪函数,编译到 XLA,后续调用直接执行编译版本
Pure function(纯函数) "没有副作用" 输出只取决于输入——无全局状态、无修改、无隐式随机
vmap "自动批处理" 把处理单个样本的函数变成处理整批的函数,无需重写
pmap "自动并行" 把函数复制到多设备并切分输入 batch
Pytree "嵌套字典里装数组" 列表、元组、字典、数组的嵌套结构,JAX 能遍历和变换它们
Tracing(追踪) "记录计算过程" JAX 用抽象值执行函数来构建计算图,不真正计算
函数式自动微分 "函数的梯度" 通过变换函数来计算导数,而非在张量上附着梯度存储
Optax "JAX 的优化器库" 可组合的梯度变换库——Adam、SGD、裁剪、调度——可以链式组合
Flax "JAX 的 nn.Module" Google 的 JAX 神经网络库,添加层抽象同时保持状态显式

自测题

Q1PyTorch 和 JAX 之间的根本设计差异是什么?
Q2jax.jit 做了什么?
Q3jax.vmap 做了什么?
Q4JAX 处理模型状态(权重)和 PyTorch 有什么不同?
Q5什么时候应该选 JAX 而不是 PyTorch?