主题
字号
CHAPTER 06 ≈ 45 MIN READ

注意力机制——突破瓶颈的那一刻

解码器不再盯着一个压缩过的摘要苦苦辨认,而是可以回头看整个源序列。从这以后,一切都是注意力加工程。

为什么要学这个

第 09 课以一个"有意义的失败"收尾。一个 GRU 编码器-解码器在 toy copy 任务上,长度 5 时有 89% 的准确率,长度 80 时接近瞎猜。原因是结构性的,不是训练 bug:编码器学到的所有信息必须塞进一个定长隐藏状态,而解码器永远只能看到这一个向量。

Bahdanau、Cho 和 Bengio 在 2014 年发了一个三行修复。与其只把编码器最后的状态给解码器,不如把每一步的编码器状态都留着。在每个解码器步骤中,计算编码器状态的加权平均——权重表示"解码器此刻需要多关注编码器位置 i?"这个加权平均就是上下文(context),而且它每步都在变。

整个想法就是这么回事。Transformer 扩展了它。自注意力(Self-attention)把它用在单个序列上。多头注意力(Multi-head attention)并行跑。但 2014 版已经打破了瓶颈,之后走向 Transformer 的过程是工程,不是概念突破。

核心概念

Bahdanau 注意力:解码器查询所有编码器状态

在每个解码器步骤 t

  1. 用前一步的解码器隐藏状态 s_{t-1} 作为查询(Query)
  2. 把它和每个编码器隐藏状态 h_1, ..., h_T 做打分。每个编码器位置一个标量。
  3. 对分数做 softmax,得到注意力权重(Attention Weights) α_{t,1}, ..., α_{t,T},加起来等于 1。
  4. 上下文向量 c_t = Σ α_{t,i} * h_i。编码器状态的加权平均。
  5. 解码器拿 c_t 加上前一个输出 token,生成下一个 token。

加权平均是关键。当解码器需要把"Je"翻译成"I"时,它把编码器中"Je"对应的状态权重拉高,其他拉低。需要翻译"not"时,把"pas"的权重拉高。上下文向量每步都在重塑。

形状(每个人第一次写注意力都在这翻车)

每个注意力实现第一次出 bug 都是因为形状对不上。慢慢看。

东西 形状 说明
编码器隐藏状态 H (T_enc, d_h) 如果是 BiLSTM,d_h = 2 * d_hidden
解码器隐藏状态 s_{t-1} (d_s,) 一个向量
注意力分数 e_{t,i} 标量 每个编码器位置一个
注意力权重 α_{t,i} 标量 对所有 i 做 softmax 之后
上下文向量 c_t (d_h,) 和一个编码器状态形状一样

Bahdanau(加性)打分。 e_{t,i} = v_α^T * tanh(W_a * s_{t-1} + U_a * h_i)

Luong(乘性)打分。 三种变体:

一个 Bahdanau / Luong 的坑值得点名。 Bahdanau 用的是 s_{t-1}(生成当前词之前的解码器状态)。Luong 用的是 s_t(生成当前词之后的状态)。搞混了会产生微妙的梯度错误,极难调试。选一篇论文的约定,然后坚持用它。

从零实现

第 1 步:加性(Bahdanau)注意力

import numpy as np


def additive_attention(decoder_state, encoder_states, W_a, U_a, v_a):
    projected_dec = W_a @ decoder_state
    projected_enc = encoder_states @ U_a.T
    combined = np.tanh(projected_enc + projected_dec)
    scores = combined @ v_a
    weights = softmax(scores)
    context = weights @ encoder_states
    return context, weights


def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

对着上面的表检查你的形状。encoder_states 形状 (T_enc, d_h)projected_enc 形状 (T_enc, d_attn)projected_dec 形状 (d_attn,),广播。combined 形状 (T_enc, d_attn)scores 形状 (T_enc,)weights 形状 (T_enc,)context 形状 (d_h,)。搞定。

第 2 步:Luong dot 和 general

def dot_attention(decoder_state, encoder_states):
    scores = encoder_states @ decoder_state
    weights = softmax(scores)
    return weights @ encoder_states, weights


def general_attention(decoder_state, encoder_states, W):
    projected = W.T @ decoder_state
    scores = encoder_states @ projected
    weights = softmax(scores)
    return weights @ encoder_states, weights

各三行。这就是 Luong 那篇论文影响力大的原因。大多数任务精度一样,代码少很多。

第 3 步:一个带数字的例子

给三个编码器状态(大致对应"cat"、"sat"、"mat")和一个跟第一个最对齐的解码器状态,注意力分布会集中在位置 0。如果解码器状态转向跟最后一个对齐,注意力就跑到位置 2。上下文向量跟着走。

H = np.array([
    [1.0, 0.0, 0.2],
    [0.5, 0.5, 0.1],
    [0.1, 0.9, 0.3],
])

s_close_to_cat = np.array([0.9, 0.1, 0.2])
ctx, w = dot_attention(s_close_to_cat, H)
print("weights:", w.round(3))
weights: [0.464 0.305 0.231]

第一行赢了。然后把解码器状态挪到接近第三个编码器状态,观察权重转移。就这样。注意力就是显式对齐。

第 4 步:为什么这是通往 Transformer 的桥梁

把上面的语言翻译成 Q/K/V:

在经典注意力里,键和值是同一个东西。自注意力(Self-attention)把它们分开:你可以让一个序列查询自身,K 和 V 用不同的学习投影。多头注意力(Multi-head attention)用不同的学习投影并行跑。Transformer 把整个阶段堆很多层,然后扔掉 RNN。

数学是一样的。形状是一样的。从 Bahdanau 注意力到缩放点积注意力(Scaled Dot-Product Attention)的教学跨度,主要是符号变了。

实战用法

PyTorch 和 TensorFlow 都直接提供了注意力模块。

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
query = torch.randn(2, 5, 128)
key = torch.randn(2, 10, 128)
value = torch.randn(2, 10, 128)

output, weights = mha(query, key, value)
print(output.shape, weights.shape)
torch.Size([2, 5, 128]) torch.Size([2, 5, 10])

这就是一个 Transformer 注意力层。Query 批次 5 个位置,Key/Value 批次 10 个位置,128 维,8 头。output 是经过上下文增强的新查询。weights 是 5×10 的对齐矩阵,可以可视化。

经典注意力仍然重要的场景

注意力权重当"解释"用的陷阱

注意力权重看起来很好解释。它们是跨位置求和为 1 的权重;你可以画出来;高意味着"看了这里"。审稿人很喜欢。

但它们没有看起来那么可解释。Jain 和 Wallace(2019)证明了在某些任务上,注意力分布可以被打乱、替换成任意分布,而模型预测不变。永远不要把注意力权重当作推理证据报告,除非你做了消融或反事实检验。

Ship It

保存为 outputs/prompt-attention-shapes.md

---
name: attention-shapes
description: Debug shape bugs in attention implementations.
phase: 5
lesson: 10
---

Given a broken attention implementation, you identify the shape mismatch. Output:

1. Which matrix has the wrong shape. Name the tensor.
2. What its shape should be, derived from (d_s, d_h, d_attn, T_enc, T_dec, batch_size).
3. One-line fix. Transpose, reshape, or project.
4. A test to catch regressions. Typically: assert `output.shape == (batch, T_dec, d_h)` and `weights.shape == (batch, T_dec, T_enc)` and `weights.sum(dim=-1) close to 1`.

Refuse to recommend fixes that silently broadcast. Broadcast-hiding bugs surface later as silent accuracy degradation, the worst kind of attention bug.

For Bahdanau confusion, insist the decoder input is `s_{t-1}` (pre-step state). For Luong, `s_t` (post-step state). For dot-product, flag dimension mismatch between query and key as the most common first-time error.

练习

  1. 简单。 实现 softmax 掩码(masking),让编码器中的 padding token 注意力权重为零。在一个变长序列的 batch 上测试。
  2. 中等。 给 Luong general 形式加上多头注意力(Multi-head Attention)。把 d_h 拆成 n_heads 组,每头单独跑注意力,拼接。验证单头情况和你之前的实现一致。
  3. 困难。 在第 09 课的 toy copy 任务上训练一个带 Bahdanau 注意力的 GRU 编码器-解码器。画出准确率 vs 序列长度的曲线。和无注意力基线对比。你应该看到随着长度增加差距越来越大,证明注意力确实打破了瓶颈。

术语表

术语 口语说法 实际意思
注意力(Attention) 看东西 值序列的加权平均,权重由查询-键相似度算出。
查询、键、值(Query, Key, Value) QKV 三个投影:Q 提问,K 用来匹配,V 是返回的内容。
加性注意力(Additive Attention) Bahdanau 前馈网络打分:v^T tanh(W q + U k)
乘性注意力(Multiplicative Attention) Luong dot / general 分数是 q^T kq^T W k。更便宜,大多数任务精度一样。
对齐矩阵(Alignment Matrix) 那张好看的图 注意力权重作为 (T_dec, T_enc) 网格。看它就能知道模型关注了什么。

延伸阅读


自测题

Q1Bahdanau 注意力解决了 seq2seq 中的什么问题?
Q2解码器步骤 t 时的注意力上下文向量是什么?
Q3在 Bahdanau(加性)注意力中,向量 v_a 起什么作用?
Q4Luong 的 'dot' 注意力变体有什么硬性约束?
Q5哪种 Q/K/V 映射描述的是经典(Bahdanau/Luong)注意力?
Q6为什么把原始注意力权重当"解释"报告被认为是脆弱的?
Q7哪一步把 Bahdanau 注意力桥接到了 Transformer 的自注意力?
Q8注意力中掩码(masking)的一个实际用途是什么?