Gated DeltaNet && Linear Attention
Gated DeltaNet
Gated DeltaNet论文:https://arxiv.org/pdf/2412.06464
linear attention的一种变体
linear attention
(以下内容部分摘自 苏剑林的博客)
标准的self-attention可以写成:
\[Attn(Q, K, V) = softmax(\frac{QK^\top}{\sqrt{d}})V\]Q的shape是(L, d),K的shape是(d, L),V的shape是(L, d),其中L是序列长度,d是hidden size。
其中$QK^\top$的计算复杂度是$O(L^2d)$,softmax得到的矩阵P(L, L)与V做矩阵乘法的计算复杂度是$O(L^2d)$,总的计算复杂度是$O(L^2d + L^2d) = O(L^2d)$。
如果没有softmax,根据矩阵乘法的结合律,其实可以先算后面的$K^\top V$,计算复杂度是$O(Ld^2)$,得到的矩阵(d, L)再与Q做矩阵乘法,计算复杂度是$O(Ld^2)$,总的计算复杂度是$O(Ld^2 + Ld^2) = O(Ld^2)$。这样复杂度就随序列长度L线性增长了。
linear attention通过把$softmax(QK^T)$拆成某种线性形式,使得后面的KV乘法可以先计算:
标准的self-attention可以写成:
\[Attn(Q, K, V)_i = \frac{\sum^n_{j=1} e^{\frac{q_i^\top k_j}{\sqrt{d}}} v_j }{\sum^n_{j=1}e^{\frac{q_i^\top k_j}{\sqrt{d}}}}\]第i个token的attention其实可以视为对v的加权平均,权重是$e^{\frac{q_i^\top k_j}{\sqrt{d}}}$。
$exp(q^\top k)$是一个常见的核函数(指数点积核)(见附录A)。
(核函数:如果存在某个(可能是高维的甚至无限维的)特征映射 $\phi(\cdot)$,使得$K(x,y) = \phi(x)^\top \phi(y)$,则称$K(x,y)$是一个核函数。)
根据核函数的定义,$\exp(q^\top k)$ 可以写成$ \phi(q)^\top \phi(k)$。
那么$Attn(Q, K, V)_i$的分子就可以写成:
\[\sum_{j=1}^n \exp\!\left(\tfrac{q_i^\top k_j}{\sqrt{d}}\right) v_j = \sum_{j=1}^n \phi(q_i)^\top \phi(k_j) v_j = \phi(q_i)^\top \Big(\sum_{j=1}^n \phi(k_j) v_j^\top \Big)\]定义状态矩阵(State Space Model, SSM,下文也称记忆矩阵)$S = \sum_{j=1}^n \phi(k_j)v_j^\top \in \mathbb{R}^{d’ \times d_v}$
分母可以写成:
\[\sum_{j=1}^n \exp\!\left(\tfrac{q_i^\top k_j}{\sqrt{d}}\right) = \sum_{j=1}^n \phi(q_i)^\top \phi(k_j) = \phi(q_i)^\top \Big(\sum_{j=1}^n \phi(k_j)\Big)\]定义归一化项$z = \sum_{j=1}^n \phi(k_j) \in \mathbb{R}^{d’}$
那么$Attn(Q, K, V)_i$可以写成:
\[Attn(Q, K, V)_i = \frac{\phi(q_i)^\top S}{\phi(q_i)^\top z}\](PS:对于第n+1个token(decode场景),$S_{n+1} = S_n + \phi(k_{n+1})v_{n+1}^\top$,$z_{n+1} = z_n + \phi(k_{n+1})$,一次递推的时间复杂度为$O(d^2)$。)
推广成矩阵的形式:
把所有 key 堆成矩阵 $\phi(K) \in \mathbb{R}^{n \times d’}$
把所有 value 堆成矩阵 $V \in \mathbb{R}^{n \times d_v}$
把所有 query 堆成矩阵 $\phi(Q) \in \mathbb{R}^{n \times d’}$
那么:$S = \phi(K)^\top V \quad \in \mathbb{R}^{d’ \times d_v}$,$z = \phi(K)^\top \mathbf{1} \quad \in \mathbb{R}^{d’}$
得到:
\[Attn(Q, K, V) = \frac{\phi(Q)(\phi(K)^\top V)}{\phi(Q)\phi(K)^\top\mathbf{1}}\]其中$\phi(·)$是一个无限维的核函数,实际应用中有其他近似的替代,比如:
-
$\phi(x) = \text{elu}(x)+1$ ,见论文Transformers are RNNs:Fast Autoregressive Transformers with Linear Attention
-
$\phi(Q) = \text{softmax}_2(Q)$ ,$\phi(K)^\top = \text{softmax}_1(K)^\top$ ,Q在d的维度做softmax,K在L的维度做softmax,见论文Efficient Attention: Attention with Linear Complexities
小结: 线性注意力的基本思想是用一个 状态矩阵(SSM) $S = \sum_{j=1}^n \phi(k_j)v_j^\top$ 来累积过去的所有token的信息,然后用 $\phi(q_i)$ 去查询它,由于状态矩阵每次递推更新的复杂度是$O(d^2)$,每次decode的时间复杂度为$O(d^2)$, prefill的时间复杂度为线性的$O(Ld^2)$。
Gated DeltaNet
上述朴素的linear attention存在几个问题:
-
遗忘机制不足:旧的信息一直堆在 SSM 里,很难被“忘掉”。 → 容易被噪声污染。
-
更新不够精细:SSM 每次更新是简单加法(outer product 累积),不能对已有记忆做校正。
对于第一个问题,Mamba2提出用一个 门控/衰减因子 $\alpha_t$,对旧记忆做指数衰减:$S_t = \alpha_t S_{t-1} + \text{新信息}$
对于第二个问题,DeltaNet提出每次更新不是简单地加 $\phi(k_j)v_j^\top$,而是根据 误差 (delta) $ \Delta_t$ 来调整记忆矩阵。
Gated DeltaNet可以理解为带有门控机制的DeltaNet:
\[S_t = \alpha_t S_{t-1} + \Delta_t\]门控因子$\alpha_t$
门控因子$\alpha_t$依赖输入动态地计算,可以理解为一个小的FFN:
\[\alpha_t = \sigma(W_\alpha x_t + b_\alpha)\]其中:
-
$x_t$:输入 token 的embedding
-
$W_\alpha, b_\alpha$:可学习参数
-
$\sigma(\cdot)$:Sigmoid 或 Softplus → 保证 $\alpha_t \in (0,1)$
$\alpha_t$的值越小,旧记忆的权重越小,遗忘得越快。
误差$\Delta_t$
在 DeltaNet 中,SSM的更新公式是:
\[S_t = S_{t-1} + \eta \cdot \phi(k_t)\,\delta_t^\top\]其中:
-
$\eta$:学习率/缩放因子(可学习或固定)
-
$\phi(k_t)$:当前的 key 映射
-
$\delta_t = v_t - \hat{v}_t$:预测误差 (delta)
-
$\hat{v}_t = \frac{\phi(k_t)^\top S_{t-1}}{\phi(k_t)^\top z_{t-1}}$:用旧状态预测当前 value
即模型用旧记忆 $S_{t-1}$ 预测出 $\hat{v}_t$,然后用预测和真实 $v_t$ 的差 $\delta_t = v_t - \hat{v}_t$ 来调整记忆矩阵。
附录
A. 证明$exp(q^\top k)$ 是一个核函数
$exp(q^\top k)$ 可以展开成泰勒级数:
\[\exp(q^\top k) = \sum_{n=0}^\infty \frac{(q^\top k)^n}{n!}\]注意到 $(q^\top k)^n$ 其实是 $\langle q^{\otimes n}, k^{\otimes n}\rangle$
(这里 $q^{\otimes n}$ 表示 $q$ 的 $n$ 阶张量积)。
所以我们可以定义一个无限维的特征映射:
\[\phi(q) = \Big[1, \; q, \; \frac{q^{\otimes 2}}{\sqrt{2!}}, \; \frac{q^{\otimes 3}}{\sqrt{3!}}, \;\dots \Big]\]这样我们就有:
\[\exp(q^\top k) = \phi(q)^\top \phi(k)\]参考链接
-
线性注意力的基本推导:https://spaces.ac.cn/archives/7546
-
各版本线性注意力串讲: https://zhuanlan.zhihu.com/p/718156896
-
Gated DeltaNet解读: https://zhuanlan.zhihu.com/p/672824235