cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
cortexnet/attention.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
选择性稀疏注意力 (Selective Sparse Attention) + 旋转位置编码 (RoPE)
|
|
5
|
+
|
|
6
|
+
核心创新:
|
|
7
|
+
传统 Transformer 的注意力机制对所有 token 对计算注意力权重,
|
|
8
|
+
复杂度为 O(n²)。CortexNet 通过学习一个重要性评分函数,
|
|
9
|
+
只选择最关键的 token 作为 Key/Value,所有 token 的 Query
|
|
10
|
+
只与这些选中的 token 交互。
|
|
11
|
+
|
|
12
|
+
复杂度分析:
|
|
13
|
+
- ratio 模式:O(n·k),k = ratio·n,例如 ratio=0.25 → 4x 加速
|
|
14
|
+
- sqrt 模式:O(n·√n),真正的亚二次复杂度
|
|
15
|
+
- log 模式:O(n·log(n)),近似线性
|
|
16
|
+
|
|
17
|
+
重要性评分器通过软门控接收梯度,实现端到端学习。
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import math
|
|
21
|
+
import torch
|
|
22
|
+
import torch.nn as nn
|
|
23
|
+
import torch.nn.functional as F
|
|
24
|
+
from typing import Optional, Tuple
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _mps_safe() -> bool:
|
|
28
|
+
"""检测是否在 MPS 设备上(MPS 对复数运算支持有限)。"""
|
|
29
|
+
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class HeadRMSNorm(nn.Module):
    """Per-head RMSNorm for attention Q/K, normalizing over the last dimension.

    Args:
        dim: size of the last (normalized) dimension, i.e. the head dimension.
        eps: small constant added to the mean square before the reciprocal
            square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-channel gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2, -1) + eps) * weight``."""
        # Prefer the fused kernel when this PyTorch build provides it.
        rms_norm_fn = getattr(F, "rms_norm", None)
        if rms_norm_fn is not None:
            return rms_norm_fn(x, (x.shape[-1],), self.weight, self.eps)
        # Manual fallback for older PyTorch versions.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def precompute_rope_freqs(
    dim: int, max_seq_len: int, theta: float = 10000.0
) -> torch.Tensor:
    """Build the rotary position embedding (RoPE) cos/sin lookup table.

    Real-valued cos/sin are returned instead of complex numbers so the table
    works the same on CUDA, MPS and CPU.

    Args:
        dim: head dimension (expected to be even).
        max_seq_len: number of positions to precompute.
        theta: base of the geometric frequency progression.

    Returns:
        float32 tensor of shape (max_seq_len, dim // 2, 2) where
        ``[..., 0]`` holds cos(angle) and ``[..., 1]`` holds sin(angle),
        with angle = position * theta ** (-2i / dim).
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (theta ** exponents)  # (dim // 2,)
    positions = torch.arange(max_seq_len, dtype=torch.float32)  # (max_seq_len,)
    angles = positions[:, None] * inv_freq[None, :]  # (max_seq_len, dim // 2)
    cos_table = angles.cos()
    sin_table = angles.sin()
    return torch.stack((cos_table, sin_table), dim=-1)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _apply_rope_real(x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
|
|
68
|
+
"""用实数 cos/sin 旋转(全设备兼容),替代复数乘法。
|
|
69
|
+
|
|
70
|
+
x: (B, H, L, D) where D is head_dim (even)
|
|
71
|
+
cos_sin: (L, D//2, 2)
|
|
72
|
+
"""
|
|
73
|
+
D = x.shape[-1]
|
|
74
|
+
x1, x2 = x[..., : D // 2], x[..., D // 2 :] # 各 (B, H, L, D//2)
|
|
75
|
+
cos = cos_sin[..., 0] # (L, D//2)
|
|
76
|
+
sin = cos_sin[..., 1] # (L, D//2)
|
|
77
|
+
# 广播到 (1, 1, L, D//2)
|
|
78
|
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
|
79
|
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
|
80
|
+
out1 = x1 * cos - x2 * sin
|
|
81
|
+
out2 = x1 * sin + x2 * cos
|
|
82
|
+
return torch.cat([out1, out2], dim=-1).type_as(x)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply RoPE to ``x`` assuming its tokens sit at positions 0..seq_len-1.

    Args:
        x: (batch, num_heads, seq_len, head_dim) tensor.
        freqs: cos/sin table of shape (max_seq_len, head_dim // 2, 2);
            seq_len must not exceed max_seq_len.

    Returns:
        Rotated tensor with the same shape (and dtype) as ``x``.
    """
    seq_len = x.size(2)
    leading_rows = freqs[:seq_len]  # table rows for positions 0..seq_len-1
    return _apply_rope_real(x, leading_rows)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def apply_rope_with_positions(
    x: torch.Tensor, freqs: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
    """Apply RoPE where each token carries its own (original) position.

    Args:
        x: (batch, num_heads, k, head_dim) tensor.
        freqs: cos/sin table of shape (max_seq_len, head_dim // 2, 2).
        positions: (batch, k) integer positions indexing rows of ``freqs``.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``; the same
        per-token angles are shared by every head.
    """
    batch, _, num_tok, head_dim = x.shape
    half = head_dim // 2

    # Gather one table row per (batch, token): (batch, k, head_dim//2, 2).
    gathered = freqs[positions.reshape(-1)].reshape(batch, num_tok, freqs.shape[1], 2)
    cos, sin = gathered.unbind(dim=-1)
    # Insert a head axis so the angles broadcast across all heads.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)

    x_lo = x[..., :half]
    x_hi = x[..., half:]
    rotated = torch.cat(
        (x_lo * cos - x_hi * sin, x_lo * sin + x_hi * cos), dim=-1
    )
    return rotated.type_as(x)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class SelectiveSparseAttention(nn.Module):
|
|
123
|
+
"""选择性稀疏注意力机制。
|
|
124
|
+
|
|
125
|
+
架构流程:
|
|
126
|
+
1. 对所有 token 计算重要性分数(学习得到)
|
|
127
|
+
2. 选择 top-k 个最重要的 token 作为 Key/Value
|
|
128
|
+
3. 所有 token 的 Query 与选中的 Key/Value 交互
|
|
129
|
+
4. 应用 RoPE 保持位置信息
|
|
130
|
+
5. 通过重要性软门控传递梯度
|
|
131
|
+
|
|
132
|
+
相比全注意力的优势:
|
|
133
|
+
- 计算量:O(n·k) vs O(n²)
|
|
134
|
+
- 显存:O(n·k) vs O(n²)
|
|
135
|
+
- 自动学习哪些 token 最值得关注
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
d_model: 模型维度
|
|
139
|
+
num_heads: 注意力头数
|
|
140
|
+
top_k_ratio: 选择比例(ratio 模式下使用)
|
|
141
|
+
k_mode: 选择模式 - 'ratio', 'sqrt', 'log'
|
|
142
|
+
max_seq_len: 最大序列长度
|
|
143
|
+
rope_theta: RoPE 基数
|
|
144
|
+
dropout: 注意力 dropout 率
|
|
145
|
+
"""
|
|
146
|
+
|
|
147
|
+
def __init__(
|
|
148
|
+
self,
|
|
149
|
+
d_model: int,
|
|
150
|
+
num_heads: int,
|
|
151
|
+
top_k_ratio: float = 0.25,
|
|
152
|
+
k_mode: str = "ratio",
|
|
153
|
+
max_seq_len: int = 8192,
|
|
154
|
+
rope_theta: float = 10000.0,
|
|
155
|
+
dropout: float = 0.0,
|
|
156
|
+
sliding_window_size: int = 0,
|
|
157
|
+
num_kv_heads: int = 0,
|
|
158
|
+
use_qk_norm: bool = False,
|
|
159
|
+
norm_eps: float = 1e-6,
|
|
160
|
+
):
|
|
161
|
+
super().__init__()
|
|
162
|
+
self.d_model = d_model
|
|
163
|
+
self.num_heads = num_heads
|
|
164
|
+
self.num_kv_heads = num_kv_heads if num_kv_heads > 0 else num_heads
|
|
165
|
+
self.head_dim = d_model // num_heads
|
|
166
|
+
self.kv_dim = self.num_kv_heads * self.head_dim
|
|
167
|
+
self.top_k_ratio = top_k_ratio
|
|
168
|
+
self.k_mode = k_mode
|
|
169
|
+
self.scale = self.head_dim ** -0.5
|
|
170
|
+
self.sliding_window_size = sliding_window_size
|
|
171
|
+
|
|
172
|
+
assert d_model % num_heads == 0, (
|
|
173
|
+
f"d_model ({d_model}) 必须能被 num_heads ({num_heads}) 整除"
|
|
174
|
+
)
|
|
175
|
+
assert self.head_dim % 2 == 0, (
|
|
176
|
+
f"head_dim ({self.head_dim}) 必须为偶数以支持 RoPE"
|
|
177
|
+
)
|
|
178
|
+
assert self.num_heads % self.num_kv_heads == 0, (
|
|
179
|
+
f"num_heads ({self.num_heads}) 必须能被 num_kv_heads ({self.num_kv_heads}) 整除"
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Q/KV 投影(GQA: KV 使用更少的头)
|
|
183
|
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
|
184
|
+
self.kv_proj = nn.Linear(d_model, self.kv_dim * 2, bias=False)
|
|
185
|
+
self.o_proj = nn.Linear(d_model, d_model, bias=False)
|
|
186
|
+
|
|
187
|
+
# 学习的重要性评分器
|
|
188
|
+
self.importance_scorer = nn.Sequential(
|
|
189
|
+
nn.Linear(d_model, d_model // 4),
|
|
190
|
+
nn.GELU(),
|
|
191
|
+
nn.Linear(d_model // 4, 1),
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
self.attn_dropout = nn.Dropout(dropout)
|
|
195
|
+
|
|
196
|
+
# 可选的 Q/K 归一化(Qwen3 等模型使用 per-head RMSNorm)
|
|
197
|
+
if use_qk_norm:
|
|
198
|
+
self.q_norm = HeadRMSNorm(self.head_dim, eps=norm_eps)
|
|
199
|
+
self.k_norm = HeadRMSNorm(self.head_dim, eps=norm_eps)
|
|
200
|
+
else:
|
|
201
|
+
self.q_norm = None
|
|
202
|
+
self.k_norm = None
|
|
203
|
+
|
|
204
|
+
# 滑动窗口门控(与稀疏全局注意力并行融合)
|
|
205
|
+
if sliding_window_size > 0:
|
|
206
|
+
self.window_gate = nn.Linear(d_model, 1)
|
|
207
|
+
else:
|
|
208
|
+
self.window_gate = None
|
|
209
|
+
|
|
210
|
+
# 预计算 RoPE 频率
|
|
211
|
+
self.register_buffer(
|
|
212
|
+
"rope_freqs",
|
|
213
|
+
precompute_rope_freqs(self.head_dim, max_seq_len, rope_theta),
|
|
214
|
+
persistent=False,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
def _project_kv(self, x: torch.Tensor):
|
|
218
|
+
"""融合 KV 投影(GQA: 输出 kv_dim 而非 d_model)。"""
|
|
219
|
+
kv = self.kv_proj(x)
|
|
220
|
+
return kv.chunk(2, dim=-1) # each (B, L, kv_dim)
|
|
221
|
+
|
|
222
|
+
def _compute_k_count(self, seq_len: int) -> int:
|
|
223
|
+
"""根据模式计算选择的 token 数量。"""
|
|
224
|
+
if self.k_mode == "sqrt":
|
|
225
|
+
return max(1, int(math.sqrt(seq_len)))
|
|
226
|
+
elif self.k_mode == "log":
|
|
227
|
+
return max(1, int(math.log2(max(seq_len, 2)) * self.num_heads))
|
|
228
|
+
else: # ratio
|
|
229
|
+
return max(1, int(seq_len * self.top_k_ratio))
|
|
230
|
+
|
|
231
|
+
def forward(
|
|
232
|
+
self,
|
|
233
|
+
x: torch.Tensor,
|
|
234
|
+
causal_mask: bool = True,
|
|
235
|
+
past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
|
|
236
|
+
use_cache: bool = False,
|
|
237
|
+
) -> torch.Tensor | Tuple[torch.Tensor, Optional[Tuple]]:
|
|
238
|
+
"""
|
|
239
|
+
Args:
|
|
240
|
+
x: (batch, seq_len, d_model)
|
|
241
|
+
causal_mask: 是否应用因果掩码
|
|
242
|
+
past_key_value:
|
|
243
|
+
新版: (past_K, past_V, past_top_k_indices, cache_len, total_seen_len)
|
|
244
|
+
旧版兼容: (past_K, past_V, past_top_k_indices, cache_len)
|
|
245
|
+
用于增量解码
|
|
246
|
+
use_cache: 若 True 返回 (output, new_cache)
|
|
247
|
+
Returns:
|
|
248
|
+
output: (batch, seq_len, d_model)
|
|
249
|
+
new_cache (可选): (K, V, top_k_indices) 当 use_cache=True
|
|
250
|
+
"""
|
|
251
|
+
B, L, D = x.shape
|
|
252
|
+
rope_freqs = self.rope_freqs.to(x.device)
|
|
253
|
+
|
|
254
|
+
cache_len = 0
|
|
255
|
+
total_seen_len = 0
|
|
256
|
+
if past_key_value is not None:
|
|
257
|
+
if len(past_key_value) == 5:
|
|
258
|
+
past_K, past_V, past_top_k, cache_len, total_seen_len = past_key_value
|
|
259
|
+
elif len(past_key_value) == 4:
|
|
260
|
+
past_K, past_V, past_top_k, cache_len = past_key_value
|
|
261
|
+
inferred_seen = int(past_top_k.max().item()) + 1 if past_top_k.numel() > 0 else 0
|
|
262
|
+
total_seen_len = max(int(cache_len), inferred_seen)
|
|
263
|
+
elif len(past_key_value) == 3:
|
|
264
|
+
past_K, past_V, past_top_k = past_key_value
|
|
265
|
+
cache_len = int(past_K.shape[2])
|
|
266
|
+
inferred_seen = int(past_top_k.max().item()) + 1 if past_top_k.numel() > 0 else 0
|
|
267
|
+
total_seen_len = max(cache_len, inferred_seen)
|
|
268
|
+
else:
|
|
269
|
+
raise ValueError(
|
|
270
|
+
"past_key_value must be a tuple of length 3, 4, or 5"
|
|
271
|
+
)
|
|
272
|
+
position_offset = int(total_seen_len)
|
|
273
|
+
else:
|
|
274
|
+
position_offset = 0
|
|
275
|
+
|
|
276
|
+
importance = self.importance_scorer(x).squeeze(-1) # (B, L)
|
|
277
|
+
|
|
278
|
+
if past_key_value is not None:
|
|
279
|
+
# 增量模式:静态缓存插入,避免 torch.cat 复制开销
|
|
280
|
+
new_positions = torch.arange(
|
|
281
|
+
position_offset, position_offset + L, device=x.device
|
|
282
|
+
).unsqueeze(0).expand(B, -1)
|
|
283
|
+
k_raw, v_raw = self._project_kv(x)
|
|
284
|
+
k_new = k_raw.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
285
|
+
v_new = v_raw.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
286
|
+
if self.k_norm is not None:
|
|
287
|
+
k_new = self.k_norm(k_new)
|
|
288
|
+
k_new = apply_rope_with_positions(k_new, rope_freqs, new_positions)
|
|
289
|
+
|
|
290
|
+
# 静态缓存:如果空间不够则扩展
|
|
291
|
+
new_len = cache_len + L
|
|
292
|
+
max_cache = past_K.shape[2]
|
|
293
|
+
if new_len > max_cache:
|
|
294
|
+
# 按 2x 扩展,分摊分配成本
|
|
295
|
+
new_max = max(new_len, max_cache * 2)
|
|
296
|
+
pad_k = torch.zeros(B, self.num_kv_heads, new_max - max_cache, self.head_dim,
|
|
297
|
+
device=past_K.device, dtype=past_K.dtype)
|
|
298
|
+
pad_v = torch.zeros_like(pad_k)
|
|
299
|
+
past_K = torch.cat([past_K, pad_k], dim=2)
|
|
300
|
+
past_V = torch.cat([past_V, pad_v], dim=2)
|
|
301
|
+
|
|
302
|
+
# 原地插入新 KV(零拷贝)
|
|
303
|
+
past_K[:, :, cache_len:new_len] = k_new
|
|
304
|
+
past_V[:, :, cache_len:new_len] = v_new
|
|
305
|
+
|
|
306
|
+
k = past_K[:, :, :new_len]
|
|
307
|
+
v = past_V[:, :, :new_len]
|
|
308
|
+
top_k_indices_full = torch.cat(
|
|
309
|
+
[past_top_k[:, :cache_len], new_positions], dim=1
|
|
310
|
+
)
|
|
311
|
+
else:
|
|
312
|
+
k_count = min(self._compute_k_count(L), L)
|
|
313
|
+
_, top_k_indices = importance.topk(k_count, dim=-1)
|
|
314
|
+
top_k_indices_sorted, _ = top_k_indices.sort(dim=-1)
|
|
315
|
+
idx_expanded = top_k_indices_sorted.unsqueeze(-1).expand(-1, -1, D)
|
|
316
|
+
x_selected = x.gather(1, idx_expanded)
|
|
317
|
+
k, v = self._project_kv(x_selected)
|
|
318
|
+
k = k.view(B, k_count, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
319
|
+
v = v.view(B, k_count, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
320
|
+
if self.k_norm is not None:
|
|
321
|
+
k = self.k_norm(k)
|
|
322
|
+
k = apply_rope_with_positions(k, rope_freqs, top_k_indices_sorted)
|
|
323
|
+
top_k_indices_full = top_k_indices_sorted
|
|
324
|
+
|
|
325
|
+
q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
|
|
326
|
+
# 可选 Q/K 归一化(Qwen3 per-head RMSNorm)
|
|
327
|
+
if self.q_norm is not None:
|
|
328
|
+
q = self.q_norm(q)
|
|
329
|
+
if position_offset > 0:
|
|
330
|
+
q_positions = torch.arange(
|
|
331
|
+
position_offset, position_offset + L, device=x.device
|
|
332
|
+
).unsqueeze(0).expand(B, -1)
|
|
333
|
+
q = apply_rope_with_positions(q, rope_freqs, q_positions)
|
|
334
|
+
else:
|
|
335
|
+
q = apply_rope(q, rope_freqs)
|
|
336
|
+
|
|
337
|
+
# 构建因果掩码 (bool)
|
|
338
|
+
attn_mask = None
|
|
339
|
+
if causal_mask and (L > 1 or position_offset > 0):
|
|
340
|
+
q_positions = torch.arange(
|
|
341
|
+
position_offset, position_offset + L, device=x.device
|
|
342
|
+
).view(1, 1, L, 1)
|
|
343
|
+
k_positions = top_k_indices_full.unsqueeze(1).unsqueeze(1)
|
|
344
|
+
attn_mask = q_positions >= k_positions # (B, 1, L, k)
|
|
345
|
+
|
|
346
|
+
# GQA 扩展:KV 头数 < Q 头数时,expand 不复制内存(比 repeat_interleave 快 2-3x)
|
|
347
|
+
if self.num_kv_heads != self.num_heads:
|
|
348
|
+
kv_repeat = self.num_heads // self.num_kv_heads
|
|
349
|
+
# expand 创建视图而非拷贝
|
|
350
|
+
k_attn = k.unsqueeze(2).expand(-1, -1, kv_repeat, -1, -1).reshape(B, self.num_heads, -1, self.head_dim)
|
|
351
|
+
v_attn = v.unsqueeze(2).expand(-1, -1, kv_repeat, -1, -1).reshape(B, self.num_heads, -1, self.head_dim)
|
|
352
|
+
else:
|
|
353
|
+
k_attn = k
|
|
354
|
+
v_attn = v
|
|
355
|
+
|
|
356
|
+
# 尝试使用 PyTorch SDPA(自动选择 Flash/Memory-efficient/Math 后端)
|
|
357
|
+
out = self._sdpa_attention(q, k_attn, v_attn, attn_mask)
|
|
358
|
+
out = out.transpose(1, 2).contiguous().view(B, L, D)
|
|
359
|
+
|
|
360
|
+
# 滑动窗口注意力(与稀疏全局并行,门控融合)
|
|
361
|
+
if self.window_gate is not None and past_key_value is None and L > 1:
|
|
362
|
+
window_out = self._sliding_window_attention(x, rope_freqs, L, position_offset)
|
|
363
|
+
gate_w = torch.sigmoid(self.window_gate(x)) # (B, L, 1)
|
|
364
|
+
out = out * (1 - gate_w) + window_out * gate_w
|
|
365
|
+
|
|
366
|
+
out = self.o_proj(out)
|
|
367
|
+
|
|
368
|
+
if use_cache:
|
|
369
|
+
# 返回静态缓存 + 实际长度
|
|
370
|
+
if past_key_value is not None:
|
|
371
|
+
new_cache_len = int(cache_len) + L
|
|
372
|
+
new_total_seen = int(total_seen_len) + L
|
|
373
|
+
cache_k = past_K
|
|
374
|
+
cache_v = past_V
|
|
375
|
+
else:
|
|
376
|
+
new_cache_len = k.shape[2]
|
|
377
|
+
new_total_seen = L
|
|
378
|
+
cache_k = k
|
|
379
|
+
cache_v = v
|
|
380
|
+
return out, (
|
|
381
|
+
cache_k,
|
|
382
|
+
cache_v,
|
|
383
|
+
top_k_indices_full,
|
|
384
|
+
new_cache_len,
|
|
385
|
+
new_total_seen,
|
|
386
|
+
)
|
|
387
|
+
return out
|
|
388
|
+
|
|
389
|
+
def _sdpa_attention(
|
|
390
|
+
self,
|
|
391
|
+
q: torch.Tensor,
|
|
392
|
+
k: torch.Tensor,
|
|
393
|
+
v: torch.Tensor,
|
|
394
|
+
attn_mask: Optional[torch.Tensor] = None,
|
|
395
|
+
) -> torch.Tensor:
|
|
396
|
+
"""使用 PyTorch SDPA 计算注意力(自动选择最优后端)。
|
|
397
|
+
|
|
398
|
+
优先: FlashAttention > Memory-Efficient > Math fallback
|
|
399
|
+
"""
|
|
400
|
+
try:
|
|
401
|
+
# SDPA 的 bool mask 语义:True=允许关注,False=屏蔽。
|
|
402
|
+
if attn_mask is not None:
|
|
403
|
+
out = F.scaled_dot_product_attention(
|
|
404
|
+
q, k, v,
|
|
405
|
+
attn_mask=attn_mask.to(dtype=torch.bool),
|
|
406
|
+
dropout_p=self.attn_dropout.p if self.training else 0.0,
|
|
407
|
+
scale=self.scale,
|
|
408
|
+
)
|
|
409
|
+
else:
|
|
410
|
+
out = F.scaled_dot_product_attention(
|
|
411
|
+
q, k, v,
|
|
412
|
+
dropout_p=self.attn_dropout.p if self.training else 0.0,
|
|
413
|
+
scale=self.scale,
|
|
414
|
+
)
|
|
415
|
+
except RuntimeError:
|
|
416
|
+
# 回退到手动实现(兼容旧 PyTorch 或不支持的形状)
|
|
417
|
+
out = self._manual_attention(q, k, v, attn_mask)
|
|
418
|
+
return out
|
|
419
|
+
|
|
420
|
+
def _sliding_window_attention(
|
|
421
|
+
self,
|
|
422
|
+
x: torch.Tensor,
|
|
423
|
+
rope_freqs: torch.Tensor,
|
|
424
|
+
L: int,
|
|
425
|
+
position_offset: int,
|
|
426
|
+
) -> torch.Tensor:
|
|
427
|
+
"""分块滑动窗口注意力,避免构建 O(L²) 掩码。
|
|
428
|
+
|
|
429
|
+
将序列按 window_size 分块,每块只与自身 + 上一块的后半交互,
|
|
430
|
+
总内存 O(L × W) 而非 O(L²)。
|
|
431
|
+
"""
|
|
432
|
+
B, _, D = x.shape
|
|
433
|
+
W = self.sliding_window_size
|
|
434
|
+
|
|
435
|
+
# 全量 Q/K/V 投影 + RoPE
|
|
436
|
+
q_w = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
|
|
437
|
+
k_raw, v_raw = self._project_kv(x)
|
|
438
|
+
k_w = k_raw.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
439
|
+
v_w = v_raw.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
440
|
+
q_w = apply_rope(q_w, rope_freqs)
|
|
441
|
+
k_w = apply_rope(k_w, rope_freqs)
|
|
442
|
+
|
|
443
|
+
# GQA 扩展
|
|
444
|
+
if self.num_kv_heads != self.num_heads:
|
|
445
|
+
kv_repeat = self.num_heads // self.num_kv_heads
|
|
446
|
+
k_w = k_w.repeat_interleave(kv_repeat, dim=1)
|
|
447
|
+
v_w = v_w.repeat_interleave(kv_repeat, dim=1)
|
|
448
|
+
|
|
449
|
+
# 如果序列短于 2 倍窗口,直接用全量注意力
|
|
450
|
+
if L <= W * 2:
|
|
451
|
+
positions = torch.arange(L, device=x.device)
|
|
452
|
+
diff = positions.unsqueeze(0) - positions.unsqueeze(1)
|
|
453
|
+
window_mask = (diff >= 0) & (diff < W)
|
|
454
|
+
window_mask = window_mask.unsqueeze(0).unsqueeze(0)
|
|
455
|
+
try:
|
|
456
|
+
out_w = F.scaled_dot_product_attention(
|
|
457
|
+
q_w, k_w, v_w,
|
|
458
|
+
attn_mask=window_mask,
|
|
459
|
+
scale=self.scale,
|
|
460
|
+
)
|
|
461
|
+
except RuntimeError:
|
|
462
|
+
attn = (q_w @ k_w.transpose(-2, -1)) * self.scale
|
|
463
|
+
attn = attn.masked_fill(~window_mask, float("-inf"))
|
|
464
|
+
attn = F.softmax(attn, dim=-1)
|
|
465
|
+
out_w = attn @ v_w
|
|
466
|
+
return out_w.transpose(1, 2).contiguous().view(B, L, D)
|
|
467
|
+
|
|
468
|
+
# 分块处理: 每块处理 W 个 query,KV 范围为 [start-W+1, start+W]
|
|
469
|
+
outputs = []
|
|
470
|
+
for start in range(0, L, W):
|
|
471
|
+
end = min(start + W, L)
|
|
472
|
+
q_chunk = q_w[:, :, start:end, :] # (B, H, chunk, D)
|
|
473
|
+
|
|
474
|
+
# KV 范围: [max(0, start-W+1), end]
|
|
475
|
+
kv_start = max(0, start - W + 1)
|
|
476
|
+
k_chunk = k_w[:, :, kv_start:end, :]
|
|
477
|
+
v_chunk = v_w[:, :, kv_start:end, :]
|
|
478
|
+
|
|
479
|
+
# 构建局部因果 + 窗口掩码
|
|
480
|
+
q_pos = torch.arange(start, end, device=x.device).view(1, 1, -1, 1)
|
|
481
|
+
k_pos = torch.arange(kv_start, end, device=x.device).view(1, 1, 1, -1)
|
|
482
|
+
local_mask = (q_pos >= k_pos) & (q_pos - k_pos < W)
|
|
483
|
+
|
|
484
|
+
try:
|
|
485
|
+
out_chunk = F.scaled_dot_product_attention(
|
|
486
|
+
q_chunk, k_chunk, v_chunk,
|
|
487
|
+
attn_mask=local_mask,
|
|
488
|
+
scale=self.scale,
|
|
489
|
+
)
|
|
490
|
+
except RuntimeError:
|
|
491
|
+
attn = (q_chunk @ k_chunk.transpose(-2, -1)) * self.scale
|
|
492
|
+
attn = attn.masked_fill(~local_mask, float("-inf"))
|
|
493
|
+
attn = F.softmax(attn, dim=-1)
|
|
494
|
+
out_chunk = attn @ v_chunk
|
|
495
|
+
outputs.append(out_chunk)
|
|
496
|
+
|
|
497
|
+
out_w = torch.cat(outputs, dim=2)
|
|
498
|
+
return out_w.transpose(1, 2).contiguous().view(B, L, D)
|
|
499
|
+
|
|
500
|
+
def _manual_attention(
|
|
501
|
+
self,
|
|
502
|
+
q: torch.Tensor,
|
|
503
|
+
k: torch.Tensor,
|
|
504
|
+
v: torch.Tensor,
|
|
505
|
+
attn_mask: Optional[torch.Tensor] = None,
|
|
506
|
+
) -> torch.Tensor:
|
|
507
|
+
"""手动注意力计算(SDPA 不可用时的回退)。"""
|
|
508
|
+
attn_weights = (q @ k.transpose(-2, -1)) * self.scale
|
|
509
|
+
if attn_mask is not None:
|
|
510
|
+
attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
|
|
511
|
+
attn_max = attn_weights.max(dim=-1, keepdim=True).values
|
|
512
|
+
attn_max = torch.where(
|
|
513
|
+
attn_max == float("-inf"), torch.zeros_like(attn_max), attn_max
|
|
514
|
+
)
|
|
515
|
+
attn_weights = attn_weights - attn_max
|
|
516
|
+
all_masked = (attn_weights == float("-inf")).all(dim=-1, keepdim=True)
|
|
517
|
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
518
|
+
attn_weights = attn_weights.masked_fill(all_masked, 0.0)
|
|
519
|
+
attn_weights = self.attn_dropout(attn_weights)
|
|
520
|
+
return attn_weights @ v
|