用于线性注意力的 Gated DeltaNet | Sebastian Raschka
作者:Sebastian Raschka | 日期:2025年12月17日
最近,Qwen3-Next 和 Kimi Linear 提出了混合式 Transformer 架构,它们实现了对注意力机制的替代方案,使其在上下文长度上的计算复杂度呈线性增长,而不是二次增长。
Qwen3-Next 和 Kimi Linear 都采用了 3:1 的比例,这意味着每使用三个采用线性 Gated DeltaNet 变体的 Transformer 块,就会插入一个使用完整注意力的块,如下图所示。
引言与概览
Gated DeltaNet 是一种线性注意力变体,其灵感来源于循环神经网络,并引入了来自《Gated Delta Networks: Improving Mamba2 with Delta Rule》论文中的门控机制。从某种意义上说,Gated DeltaNet 是带有 Mamba 风格门控的 DeltaNet,而 DeltaNet 本身是一种线性注意力机制。
Kimi Linear 通过 Kimi Delta Attention(KDA)机制对 Qwen3-Next 的线性注意力机制进行了修改,该机制本质上是对 Gated DeltaNet 的一种改进。Qwen3-Next 使用的是标量门控(每个注意力头一个值)来控制记忆衰减率,而 Kimi Linear 则将其替换为针对每个特征维度的通道级门控。根据作者的说法,这种方式能够对记忆进行更精细的控制,从而提升长上下文推理能力。
此外,在完整注意力层中,Kimi Linear 用多头潜在注意力(Multi-Head Latent Attention,MLA)替换了 Qwen3-Next 的门控注意力层(本质上是带输出门控的标准多头注意力层)。这与我们之前在 DeepSeek V3/R1 部分讨论的 MLA 机制相同,只是额外加入了一个门控。(回顾一下,MLA 通过压缩 key/value 空间来减少 KV cache 的大小。)
Kimi Linear 中的 MLA 并未使用该门控,这是作者有意为之,以便更直接地将该架构与标准 MLA 进行对比;不过他们表示,未来计划将其加入。
由于我们已经在 ../05_mla 中实现了 MLA,因此本补充内容将重点放在 Gated DeltaNet 这一部分。
门控注意力
在讨论 Gated DeltaNet 本身之前,我们先简要介绍一下门控。如前一张图中 Qwen3-Next 架构的上半部分所示,Qwen3-Next 使用了“门控注意力”。这本质上是在常规完整注意力的基础上增加了一个 sigmoid 门控。
这种门控是一种简单的修改,下面的示例代码是在第 3 章的 MultiHeadAttention 代码基础上添加的,用于说明其实现方式:
import torch
from torch import nn
class GatedMultiHeadAttention(nn.Module):
def __init__(
self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False
):
super().__init__()
assert d_out % num_heads == 0
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
####################################################
### NEW: Add gate
self.W_gate = nn.Linear(d_in, d_out, bias=qkv_bias)
####################################################
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1),
persistent=False,
)
def forward(self, x):
b, num_tokens, _ = x.shape
queries = self.W_query(x)
####################################################
### NEW: Add gate
gate = self.W_gate(x)
####################################################
keys = self.W_key(x)
values = self.W_value(x)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
attn_scores = queries @ keys.transpose(2, 3)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(
mask_bool, torch.finfo(attn_scores.dtype).min
)
attn_weights = torch.softmax(
attn_scores / (self.head_dim ** 0.5), dim=-1
)
attn_weights = self.dropout(attn_weights)
context = (attn_weights @ values).transpose(1, 2)
context = context.reshape(b, num_tokens, self.d_out)
####################################################
### NEW: Add gate
context = context * torch.sigmoid(gate)
####################################################
out = self.out_proj(context)
return out
如上所示,在按常规方式计算注意力之后,模型会从同一输入中生成一个独立的门控信号,对其施加 sigmoid 激活函数以将其限制在 0 到 1 之间,然后将其与注意力输出相乘。这使得模型能够动态地放大或缩小某些特征。Qwen3-Next 的开发者指出,这种机制有助于提升训练稳定性:
[…] 注意力输出门控机制有助于消除注意力汇聚(Attention Sink)和大规模激活(Massive Activation)等问题,从而确保模型整体的数值稳定性。
Gated DeltaNet
那么,什么是 Gated DeltaNet?Gated DeltaNet(即 Gated Delta Network)是 Qwen3-Next 所采用的一种线性注意力层,旨在作为标准 softmax 注意力的替代方案。如前所述,它源自《Gated Delta Networks: Improving Mamba2 with Delta Rule》论文。
Gated DeltaNet 最初被提出作为 Mamba2 的一种改进版本,将 Mamba2 的门控衰减机制与 delta 规则结合在一起。
Mamba 是一种状态空间模型(作为 Transformer 的替代方案),这是一个值得在未来单独深入探讨的重要主题。
所谓 delta 规则,是指通过计算新值与预测值之间的差值(delta,Δ)来更新隐藏状态,该隐藏状态被用作记忆状态(稍后会详细说明)。
(旁注:熟悉经典机器学习文献的读者可以将其视为一种受生物学启发的 Hebbian 学习形式:“一起激活的神经元会连接在一起。”它基本上是感知机更新规则和基于梯度下降学习的前身,但不依赖监督信号。)
Gated DeltaNet 拥有一个与前文门控注意力中类似的门控结构,不过它使用的是 SiLU 激活函数,而不是逻辑 sigmoid,如下图所示。(选择 SiLU 可能有助于改善梯度流动性和数值稳定性。)
不过,如上图所示,Gated DeltaNet 中的“门控”还包括多个额外的门:
-
• α(衰减门)控制记忆随时间衰减或重置的速度, -
• β(更新门)控制新输入对状态更新的强度。
在代码层面,下方给出了一个与上图所示 Gated DeltaNet 对应的简化实现(省略了卷积混合部分),以提高代码可读性并聚焦于其循环特性。该实现受 Qwen3 官方实现启发:
import torch
from torch import nn
import torch.nn.functional as F
def l2norm(x, dim=-1, eps=1e-6):
return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
class GatedDeltaNet(nn.Module):
def __init__(
self, d_in, d_out, dropout, num_heads, qkv_bias=False
):
super().__init__()
assert d_out % num_heads == 0
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
####################################################
### NEW: Gates for delta rule and output gating
self.W_gate = nn.Linear(d_in, d_out, bias=False)
self.W_beta = nn.Linear(d_in, d_out, bias=False)
# Note: The decay gate alpha corresponds to
# A_log + W_alpha(x) + dt_bias
self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
self.dt_bias = nn.Parameter(torch.ones(num_heads))
A_init = torch.empty(num_heads).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A_init))
# We could implement this as
# W_alpha = nn.Linear(d_in, num_heads, bias=True)
# but the bias is separate for interpretability and
# to mimic the official implementation
self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
####################################################
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
b, num_tokens, _ = x.shape
queries = self.W_query(x)
keys = self.W_key(x)
values = self.W_value(x)
####################################################
### NEW: Compute delta rule gates
beta = torch.sigmoid(self.W_beta(x))
alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
self.W_alpha(x) + self.dt_bias
)
gate = self.W_gate(x)
####################################################
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
beta = beta.view(b, num_tokens, self.num_heads, self.head_dim)
gate = gate.view(b, num_tokens, self.num_heads, self.head_dim) # NEW
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
beta = beta.transpose(1, 2)
gate = gate.transpose(1, 2) # NEW
####################################################
### NEW: QKNorm-like normalization for delta rule
queries = l2norm(queries, dim=-1) / (self.head_dim ** 0.5)
keys = l2norm(keys, dim=-1)
####################################################
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
outs = []
####################################################
### NEW: Gated delta rule update
for t in range(num_tokens):
k_t = keys[:, :, t]
q_t = queries[:, :, t]
v_t = values[:, :, t]
b_t = beta[:, :, t]
a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
S = S * a_t.exp()
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * b_t
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)
####################################################
outs.append(y_t)
context = torch.stack(outs, dim=2).transpose(1, 2).contiguous()
context = context.view(b, num_tokens, self.num_heads, self.head_dim)
####################################################
### NEW: Apply RMSNorm and SiLU gate
context = self.norm(context)
context = context * F.silu(gate)
####################################################
context = context.view(b, num_tokens, self.d_out)
context = self.dropout(context)
out = self.out_proj(context)
return out
(需要注意的是,为了简化说明,我省略了 Qwen3-Next 和 Kimi Linear 中使用的卷积混合部分。)
由此可见,该机制与标准(或门控)注意力存在诸多差异。
在门控注意力中,模型仍然对所有 token 之间执行标准注意力计算(每个 token 都会关注其他所有 token)。随后,在得到注意力输出之后,通过一个 sigmoid 门控来决定保留多少输出。关键点在于,这仍然是常规的缩放点积注意力,其计算复杂度随上下文长度呈二次增长。
回顾一下,缩放点积注意力的计算形式为 softmax(QKᵀ)V,其中 Q 和 K 是 n×d 的矩阵,n 为输入 token 的数量,d 为嵌入维度。因此,QKᵀ 会生成一个 n×n 的注意力矩阵,并与 n×d 的 value 矩阵 V 相乘:
attn_scores = queries @ keys.transpose(2, 3)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(
mask_bool, torch.finfo(attn_scores.dtype).min
)
attn_weights = torch.softmax(
attn_scores / (self.head_dim ** 0.5), dim=-1
)
context = (attn_weights @ values).transpose(1, 2)
context = context.reshape(b, num_tokens, self.d_out)
而在 Gated DeltaNet 中,并不存在 n×n 的注意力矩阵。相反,模型会逐 token 进行处理,维护一个随每个新 token 更新的运行时记忆状态。这正是下方代码中所实现的逻辑,其中 S 是在每个时间步 t 递归更新的状态:
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
outs = []
for t in range(num_tokens):
k_t = keys[:, :, t]
q_t = queries[:, :, t]
v_t = values[:, :, t]
b_t = beta[:, :, t]
a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
S = S * a_t.exp()
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * b_t
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)
这些门控参数控制了记忆的更新方式:
-
• α(alpha)调节旧记忆的遗忘程度(衰减), -
• β(beta)调节当前时间步 t 的 token 对记忆更新的强度。
(最终的输出门控未在上述代码片段中展示,其作用与门控注意力类似,用于控制输出的保留比例。)
因此,从某种意义上说,Gated DeltaNet 中的状态更新方式与循环神经网络(RNN)的工作机制相似。其优势在于计算复杂度随上下文长度呈线性增长(通过 for-loop 实现),而非二次增长。
然而,这种递归状态更新的代价在于,相较于常规(或门控)注意力,它牺牲了通过全局成对注意力所获得的全局上下文建模能力。
Gated DeltaNet 在一定程度上仍能捕捉上下文信息,但必须通过记忆状态 S 这一瓶颈来实现。该记忆是固定大小的,因此更高效,但会将过去的上下文压缩到单一隐藏状态中,这一点与 RNN 非常相似。
这正是 Qwen3-Next 和 Kimi Linear 架构没有完全用 DeltaNet 层替换所有注意力层,而是采用前文提到的 3:1 比例的原因。
DeltaNet 的内存节省
在上一节中,我们讨论了 DeltaNet 相较于完整注意力在计算复杂度方面的优势,即上下文长度上线性增长而非二次增长。
除了线性计算复杂度之外,DeltaNet 的另一个重要优势在于内存节省,因为 DeltaNet 模块不会随着上下文长度增长 KV cache。(有关 KV cache 的更多信息,请参见 ../03_kv-cache。)相反,它们维护的是一个固定大小的递归状态,因此内存使用量与上下文长度无关。
对于常规多头注意力(MHA)层,其 KV cache 的大小可按如下方式计算:
KV_cache_MHA ≈ batch_size × n_tokens × n_heads × d_head × 2 × bytes
(其中系数 2 是因为需要同时存储 key 和 value。)
对于上文所实现的简化版 DeltaNet,其内存占用为:
KV_cache_DeltaNet = batch_size × n_heads × d_head × d_head × bytes
需要注意的是,KV_cache_DeltaNet 的内存大小不依赖上下文长度(n_tokens)。同时,我们只存储单一的状态 S,而不是分别存储 key 和 value,因此不再需要 2 × bytes。不过,这里引入了 d_head × d_head 的二次项,这是由于状态 S 的定义方式:
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
通常这并不是问题,因为 head_dim 通常较小。例如,在 Qwen3-Next 中其值为 128。
包含卷积混合的完整版本会更加复杂,还涉及卷积核大小等因素,但上述公式已经能够很好地说明 Gated DeltaNet 的整体趋势与设计动机。
我们可以通过以下辅助脚本,对不同上下文长度下的内存估算与节省情况进行可视化:
uv run plot_memory_estimates_gated_deltanet.py \
--emb_dim 2048 \
--n_heads 16 \
--n_layers 48 \
--dtype "bf16"
需要注意的是,上述脚本中 head_dim 的计算方式为 emb_dim / n_heads,即 2048 / 16 = 128。
本杂志是一个个人热情驱动的项目,你的支持有助于它持续发展。
如果你愿意支持我的工作,可以考虑我的书 《Build a Large Language Model (From Scratch)》,或它的续作 《Build a Reasoning Model (From Scratch)》。(我相信你会从中收获颇多;它们以你在其他地方很难看到的深度,系统讲解了 LLM 的工作原理。)
感谢你的阅读,也感谢你对独立研究的支持。
如果你已经阅读了这本书,并且有几分钟时间,我将非常感激你能留下一个 简短的评价。这对我们作者来说帮助非常大。
https://www.amazon.com/Build-Large-Language-Model-Scratch/dp/1633437167
你的支持意义重大!谢谢你!
https://sebastianraschka.com/llms-from-scratch/ch04/08_deltanet/
如果觉得内容不错,欢迎你点一下「在看」,或是将文章分享给其他有需要的人^^
相关好文推荐:
DeepSeek的多头潜在注意力(MLA) | Sebastian Raschka
嵌入模型检索面临严重限制 | DeepLearning.AI
理解用于评估大语言模型(LLM)的四种主要方法 | Sebastian Raschka
从 DeepSeek V3 到 Mistral 3 Large:现代大语言模型(LLM)架构设计概览(三)| Sebastian Raschka
从 DeepSeek V3 到 Mistral 3 Large:现代大语言模型(LLM)架构设计概览(二)| Sebastian Raschka
从 DeepSeek V3 到 Mistral 3 Large:现代大语言模型(LLM)架构设计概览(一)| Sebastian Raschka
递归语言模型(Recursive Language Models) | Alex Zhang
重新构想 LLM 记忆:将上下文作为训练数据,使模型能够在测试时学习 | Nvidia
引入嵌套学习(Nested Learning):一种用于持续学习的全新机器学习范式

0条留言