cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
cortexnet/cache.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet 推理缓存 (CortexNet Inference Cache)
|
|
3
|
+
|
|
4
|
+
用于自回归生成时的增量解码,避免每步重算完整序列。
|
|
5
|
+
|
|
6
|
+
缓存层级:
|
|
7
|
+
CortexNetCache
|
|
8
|
+
└── LayerCache (per block)
|
|
9
|
+
├── ssm_state: Tensor — SSM hidden state
|
|
10
|
+
├── memory_state: (Tensor, Tensor) — WorkingMemory (memory, z)
|
|
11
|
+
└── kv_cache: (Tensor, Tensor, Tensor) — Attention (K, V, indices)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from typing import List, Optional, Tuple
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class LayerCache:
    """Incremental-decoding cache held for one CortexBlock.

    Attributes:
        ssm_state: SSM hidden state; ``None`` until the first forward pass.
            (Upstream documents the shape as ``(B, d_inner, N)``.)
        memory_state: WorkingMemory pair ``(memory, z)``, or ``None``.
        kv_cache: Attention triple ``(K, V, top_k_indices)``, or ``None``.
    """

    ssm_state: Optional[torch.Tensor] = None
    memory_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None

    def as_tuple(self) -> Tuple:
        """Return ``(ssm_state, memory_state, kv_cache)`` for the legacy tuple API."""
        legacy = (self.ssm_state, self.memory_state, self.kv_cache)
        return legacy

    @staticmethod
    def from_tuple(t: Tuple) -> LayerCache:
        """Build a cache from a legacy tuple.

        ``None`` yields an empty cache. Slots missing from a short tuple stay
        ``None``; entries beyond the third are ignored.
        """
        cache = LayerCache()
        if t is None:
            return cache

        size = len(t)
        if size > 0:
            cache.ssm_state = t[0]
        if size > 1:
            cache.memory_state = t[1]
        if size > 2:
            cache.kv_cache = t[2]
        return cache
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class CortexNetCache:
    """Whole-model cache used for incremental decoding.

    Attributes:
        layers: One ``LayerCache`` per block.
        position_offset: Number of tokens processed so far.
    """

    layers: List[LayerCache] = field(default_factory=list)
    position_offset: int = 0

    @staticmethod
    def init_empty(num_layers: int) -> CortexNetCache:
        """Create a cache with ``num_layers`` empty layer entries and offset 0."""
        fresh = CortexNetCache()
        fresh.layers.extend(LayerCache() for _ in range(num_layers))
        return fresh

    def as_list(self) -> List[Tuple]:
        """Convert to the legacy ``List[Tuple]`` layout (one tuple per layer)."""
        legacy = []
        for layer_cache in self.layers:
            legacy.append(layer_cache.as_tuple())
        return legacy

    @staticmethod
    def from_list(lst: List[Tuple]) -> CortexNetCache:
        """Build a cache from the legacy ``List[Tuple]`` layout.

        ``None`` yields an empty cache. ``position_offset`` is left at its
        default because the legacy layout does not carry it.
        """
        if lst is None:
            return CortexNetCache()
        restored = []
        for entry in lst:
            restored.append(LayerCache.from_tuple(entry))
        return CortexNetCache(layers=restored)
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""
|
|
2
|
+
因果推理模块 (Causal Reasoning Module)
|
|
3
|
+
|
|
4
|
+
核心创新:
|
|
5
|
+
传统注意力只学习相关性 P(Y|X),而因果推理学习因果关系 P(Y|do(X))。
|
|
6
|
+
这使模型不仅知道"什么和什么一起出现",还知道"什么导致了什么"。
|
|
7
|
+
|
|
8
|
+
┌─────────────────────────────────────────────────────────────┐
|
|
9
|
+
│ 因果推理 vs 传统注意力 │
|
|
10
|
+
├─────────────────────────────────────────────────────────────┤
|
|
11
|
+
│ │
|
|
12
|
+
│ 传统注意力: P(Y|X) — "X 出现时 Y 也出现" │
|
|
13
|
+
│ → 只捕获相关性,无法区分因果与混淆 │
|
|
14
|
+
│ │
|
|
15
|
+
│ 因果推理: P(Y|do(X)) — "如果我们主动设置 X,Y 会如何" │
|
|
16
|
+
│ → 捕获真正的因果关系,支持反事实推理 │
|
|
17
|
+
│ │
|
|
18
|
+
│ 三大组件: │
|
|
19
|
+
│ 1. 因果强度估计器 — 评估每个 token 的因果重要性 │
|
|
20
|
+
│ 2. 干预注意力 — 基于因果方向选择高因果 token 计算注意力 │
|
|
21
|
+
│ 3. 反事实分支 — "如果这个原因不同,结果会如何?" │
|
|
22
|
+
└─────────────────────────────────────────────────────────────┘
|
|
23
|
+
|
|
24
|
+
灵感来源: Pearl's do-calculus, 结构因果模型 (SCM)
|
|
25
|
+
|
|
26
|
+
优化 (v3.2):
|
|
27
|
+
- InterventionalAttention 使用选择性稀疏注意力替代 O(n²) 全量注意力,
|
|
28
|
+
仅对因果强度 Top-K 的 token 计算注意力,复杂度降至 O(n·k)
|
|
29
|
+
- CounterfactualBranch 使用合并的批量线性变换替代 for 循环,
|
|
30
|
+
单次 matmul 完成所有反事实分支计算
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import logging
|
|
34
|
+
import torch
|
|
35
|
+
import torch.nn as nn
|
|
36
|
+
import torch.nn.functional as F
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CausalStrengthEstimator(nn.Module):
    """Score each token with a scalar "causal strength" in (0, 1).

    A two-layer MLP (d_model -> d_model // 4 -> 1) followed by a sigmoid.
    """

    def __init__(self, d_model: int):
        super().__init__()
        hidden = d_model // 4
        stages = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        ]
        self.scorer = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (..., d_model) to sigmoid scores of shape (..., 1)."""
        logits = self.scorer(x)
        return torch.sigmoid(logits)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class InterventionalAttention(nn.Module):
    """Selective sparse attention steered by per-token causal strength.

    Only the ``k`` tokens with the highest causal strength are used as keys and
    values, so the score tensor is (B, H, L, k) rather than (B, H, L, L).
    ``log(strength)`` of each selected key is added to its attention logit, so
    stronger tokens receive more attention. A causal mask restricts every
    query to selected keys at positions <= its own.

    When ``k`` would cover the whole sequence, the module falls back to dense
    causal attention with the same bias.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        top_k_ratio: Fraction of the sequence kept as keys/values.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model: int, num_heads: int = 4, top_k_ratio: float = 0.25,
                 dropout: float = 0.0):
        super().__init__()
        # Fail early: a non-divisible width would otherwise only surface as an
        # opaque ``view`` error inside forward().
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5
        self.top_k_ratio = top_k_ratio

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = nn.Dropout(dropout)

    def _compute_k_count(self, seq_len: int) -> int:
        """Return how many tokens to keep: at least 1, at most ``seq_len``."""
        k = max(1, int(seq_len * self.top_k_ratio))
        return min(k, seq_len)

    def forward(
        self, x: torch.Tensor, causal_strength: torch.Tensor
    ) -> torch.Tensor:
        """Apply causal-strength-biased attention.

        Args:
            x: Input of shape (B, L, d_model).
            causal_strength: Per-token strength of shape (B, L, 1); its log is
                taken, so values are expected to be non-negative.

        Returns:
            Tensor of shape (B, L, d_model). In the sparse path, a query with
            no selected key at or before its own position yields an all-zero
            attention row (zero output before ``o_proj``) rather than NaN.
        """
        B, L, D = x.shape
        H, hd = self.num_heads, self.head_dim

        q = self.q_proj(x).view(B, L, H, hd).transpose(1, 2)  # (B, H, L, hd)
        k = self.k_proj(x).view(B, L, H, hd).transpose(1, 2)
        v = self.v_proj(x).view(B, L, H, hd).transpose(1, 2)

        k_count = self._compute_k_count(L)
        if k_count >= L:
            # Short sequence: every token would be selected, so use dense attention.
            return self._full_attention(q, k, v, causal_strength, B, L, D, H, hd)

        # Pick the top-k tokens by causal strength to serve as keys/values.
        scores = causal_strength.squeeze(-1)  # (B, L)
        top_scores, top_indices = scores.topk(k_count, dim=-1, sorted=False)  # (B, k)

        gather_idx = top_indices[:, None, :, None].expand(-1, H, -1, hd)
        k_sel = k.gather(2, gather_idx)  # (B, H, k, hd)
        v_sel = v.gather(2, gather_idx)  # (B, H, k, hd)

        attn = (q @ k_sel.transpose(-2, -1)) * self.scale  # (B, H, L, k)

        # Additive log-strength bias for the selected keys, floored at -10.
        causal_bias = torch.log(top_scores + 1e-6).clamp(min=-10)
        attn = attn + causal_bias[:, None, None, :]

        # Causal visibility: query i may only see selected keys at positions <= i.
        positions = torch.arange(L, device=x.device).unsqueeze(1)  # (L, 1)
        visible = positions >= top_indices.unsqueeze(1)  # (B, L, k)
        has_visible = visible.any(dim=-1, keepdim=True)  # (B, L, 1)

        # Rows with no visible key are left unmasked here: filling a whole row
        # with -inf makes softmax return NaN and poisons the backward pass.
        # Their weights are zeroed after the softmax instead.
        blocked = (~visible) & has_visible
        attn = attn.masked_fill(blocked.unsqueeze(1), float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = attn.masked_fill((~has_visible).unsqueeze(1), 0.0)
        attn = self.attn_dropout(attn)
        out = (attn @ v_sel).transpose(1, 2).contiguous().view(B, L, D)
        return self.o_proj(out)

    def _full_attention(self, q, k, v, causal_strength, B, L, D, H, hd):
        """Dense causal attention with the same log-strength bias (short sequences)."""
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, L, L)

        causal_bias = causal_strength.squeeze(-1).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, L)
        attn = attn + torch.log(causal_bias + 1e-6).clamp(min=-10)

        # Mask strictly-future keys; the diagonal is always kept, so no row is empty.
        idx = torch.arange(L, device=q.device)
        future = idx.unsqueeze(0) > idx.unsqueeze(1)  # (L, L)
        attn = attn.masked_fill(future.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.o_proj(out)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class CounterfactualBranch(nn.Module):
    """Counterfactual branch: a gated mixture of K linear "what-if" transforms.

    All K transforms live in one weight matrix (d_model -> K * d_model), so a
    single matmul produces every branch. A softmax gate mixes the branches per
    token, and the result is scaled by the token's causal strength and by a
    learnable scalar.

    Args:
        d_model: Model width.
        num_counterfactuals: Number of branches (K).
        dropout: Dropout probability applied to the mixed output.
    """

    def __init__(self, d_model: int, num_counterfactuals: int = 4, dropout: float = 0.0):
        super().__init__()
        self.num_cf = num_counterfactuals
        self.d_model = d_model

        # One fused projection for all branches; zero-initialised so the
        # branch contributes nothing at the start of training.
        fused = nn.Linear(d_model, num_counterfactuals * d_model, bias=False)
        nn.init.zeros_(fused.weight)
        self.merged_transform = fused

        gate_hidden = d_model // 4
        self.gate = nn.Sequential(
            nn.Linear(d_model, gate_hidden),
            nn.GELU(),
            nn.Linear(gate_hidden, num_counterfactuals),
        )
        # Learnable blend strength, initialised to 0.1.
        self.cf_scale = nn.Parameter(torch.tensor(0.1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, causal_strength: torch.Tensor) -> torch.Tensor:
        """Return the gated counterfactual term for ``x``.

        Args:
            x: Input of shape (B, L, D).
            causal_strength: Tensor multiplied elementwise into the result;
                the caller in this file passes shape (B, L, 1).

        Returns:
            Tensor of shape (B, L, D).
        """
        B, L, D = x.shape
        K = self.num_cf

        gate_logits = self.gate(x)
        gate_weights = F.softmax(gate_logits, dim=-1)  # (B, L, K)

        branches = self.merged_transform(x).view(B, L, K, D)  # (B, L, K, D)

        weighted = branches * gate_weights.unsqueeze(-1)
        mixed = self.dropout(weighted.sum(dim=2))  # (B, L, D)

        # Tokens with higher causal strength contribute more.
        return mixed * causal_strength * self.cf_scale
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class CausalReasoningModule(nn.Module):
    """End-to-end causal reasoning pipeline.

    Steps performed by ``forward``:
        1. Estimate a causal strength for every token.
        2. Run strength-biased sparse attention over the input.
        3. Compute the gated counterfactual term.
        4. Sum both, layer-normalise, and project.

    Args:
        d_model: Model width.
        num_heads: Heads used by the interventional attention.
        num_counterfactuals: Number of counterfactual branches.
        top_k_ratio: Top-k ratio used by the interventional attention.
        dropout: Dropout probability passed to both sub-modules.
    """

    def __init__(self, d_model: int, num_heads: int = 4, num_counterfactuals: int = 4,
                 top_k_ratio: float = 0.25, dropout: float = 0.0):
        super().__init__()
        self.causal_estimator = CausalStrengthEstimator(d_model)
        self.interventional_attn = InterventionalAttention(
            d_model,
            num_heads,
            top_k_ratio=top_k_ratio,
            dropout=dropout,
        )
        self.counterfactual = CounterfactualBranch(
            d_model,
            num_counterfactuals,
            dropout=dropout,
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fused observational + counterfactual representation of ``x``."""
        strength = self.causal_estimator(x)  # (B, L, 1)
        attended = self.interventional_attn(x, strength)
        hypothetical = self.counterfactual(x, strength)
        normalised = self.norm(attended + hypothetical)
        return self.out_proj(normalised)
|
cortexnet/compat.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet 兼容模式组件 (Compatibility Mode Components)
|
|
3
|
+
|
|
4
|
+
为开源 LLM (LLaMA/Qwen/Mistral...) 的权重无损迁移提供轻量兼容组件:
|
|
5
|
+
- _CompatAttention: GQA + KV cache 注意力
|
|
6
|
+
- _CompatLiteSSM: 低秩 SSM 旁路(默认零影响)
|
|
7
|
+
- _CompatFusionGate: 轻量两路融合(强偏向 Attention)
|
|
8
|
+
- _CompatExpert/_CompatMoE: 单专家 FFN 兼容壳
|
|
9
|
+
- _CompatCortexBlockV3: 完整的 V3 兼容块
|
|
10
|
+
- _NoOpEvolutionEngine: 占位进化引擎
|
|
11
|
+
|
|
12
|
+
这些组件在 compatibility_mode=True 时替代完整 V3 模块使用。
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
import torch.nn.functional as F
|
|
18
|
+
from typing import Optional, Tuple
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from .config import CortexNetConfig
|
|
22
|
+
from .blocks import RMSNorm
|
|
23
|
+
from .attention import precompute_rope_freqs, apply_rope, apply_rope_with_positions
|
|
24
|
+
except ImportError:
|
|
25
|
+
from cortexnet.config import CortexNetConfig
|
|
26
|
+
from cortexnet.blocks import RMSNorm
|
|
27
|
+
from cortexnet.attention import precompute_rope_freqs, apply_rope, apply_rope_with_positions
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _NoOpEvolutionEngine(nn.Module):
    """Lightweight placeholder evolution engine for compatibility mode.

    Holds no parameters: it reports a budget of 1 per batch item and an
    efficiency loss of zero.
    """

    def __init__(self):
        super().__init__()

    def get_compute_budget(self, x: torch.Tensor) -> torch.Tensor:
        """Return a vector of ones of length ``x.size(0)``, with ``x``'s dtype and device."""
        batch = x.size(0)
        return x.new_ones(batch)

    def get_efficiency_loss(self) -> float:
        """Return a constant ``0.0``."""
        return 0.0
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class _CompatAttention(nn.Module):
    """Lightweight attention aligned with mainstream HF layouts (GQA + KV cache).

    Queries use ``config.num_heads`` heads; keys/values use
    ``config.num_kv_heads`` heads and are repeated to match the query heads
    before attention. Per-head RMSNorm is applied to Q and K, followed by RoPE.
    """

    def __init__(self, config: CortexNetConfig):
        super().__init__()
        self.d_model = config.hidden_size
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = self.d_model // self.num_heads

        kv_dim = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj = nn.Linear(self.d_model, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.d_model, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        # q_norm / k_norm as commonly found in Qwen2/3-style checkpoints.
        self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
        self.k_norm = RMSNorm(self.head_dim, config.norm_eps)

        # Non-persistent: rebuilt from config, never stored in the state dict.
        self.register_buffer(
            "rope_freqs",
            precompute_rope_freqs(self.head_dim, config.max_seq_len, config.rope_theta),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """Run attention over ``x``, optionally extending a KV cache.

        Args:
            x: Input of shape (B, L, d_model).
            past_key_value: Cached ``(K, V)``, each (B, num_kv_heads, S_past,
                head_dim), or ``None`` for a fresh sequence.
            use_cache: When true, also return the updated ``(K, V)``.

        Returns:
            The output tensor (B, L, d_model), or ``(output, (K, V))`` when
            ``use_cache`` is true. The cached K/V are the un-repeated KV heads.
        """
        # The return annotation is a string because this module has no
        # ``from __future__ import annotations``; evaluating ``X | Y`` with
        # typing generics at definition time fails on Python < 3.10.
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        rope = self.rope_freqs.to(x.device)
        attn_mask = None
        is_causal = True
        if past_key_value is None:
            q = apply_rope(q, rope)
            k = apply_rope(k, rope)
        else:
            past_k, past_v = past_key_value
            pos_offset = past_k.shape[2]
            # Absolute positions of the new tokens continue after the cache.
            pos = torch.arange(
                pos_offset, pos_offset + L, device=x.device
            ).unsqueeze(0).expand(B, -1)
            q = apply_rope_with_positions(q, rope, pos)
            k = apply_rope_with_positions(k, rope, pos)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

            # SDPA's ``is_causal`` aligns the mask to the top-left corner, which
            # is wrong once keys include a cached prefix, so it is disabled.
            is_causal = False
            if L > 1:
                # Several new tokens at once (chunked prefill): without a mask
                # each would attend to later tokens in the same chunk. Query i
                # sits at absolute position pos_offset + i and may see keys at
                # positions <= that.
                key_pos = torch.arange(k.shape[2], device=x.device)
                query_pos = torch.arange(pos_offset, pos_offset + L, device=x.device)
                attn_mask = key_pos.unsqueeze(0) <= query_pos.unsqueeze(1)  # (L, S)
            # L == 1: the single new query may see every cached key; no mask needed.

        # GQA: repeat KV heads so they line up with the query heads.
        if self.num_kv_heads != self.num_heads:
            repeat = self.num_heads // self.num_kv_heads
            k_attn = k.repeat_interleave(repeat, dim=1)
            v_attn = v.repeat_interleave(repeat, dim=1)
        else:
            k_attn = k
            v_attn = v

        attn_out = F.scaled_dot_product_attention(
            q, k_attn, v_attn,
            attn_mask=attn_mask, dropout_p=0.0,
            is_causal=is_causal,
        )
        out = self.o_proj(attn_out.transpose(1, 2).contiguous().view(B, L, D))

        if use_cache:
            return out, (k, v)
        return out
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class _CompatLiteSSM(nn.Module):
    """Low-rank side path: ``alpha * out_proj(tanh(in_proj(x)))``.

    ``alpha`` starts at 0, so the path contributes nothing until it is
    trained or calibrated. In eval mode with ``fast_skip`` set, the
    projections are skipped entirely and zeros are returned.

    Args:
        d_model: Model width.
        rank: Requested bottleneck width; capped at ``d_model``, then floored at 16.
        skip_threshold: Stored as a float attribute.
    """

    def __init__(self, d_model: int, rank: int = 256, skip_threshold: float = 1e-6):
        super().__init__()
        capped = min(rank, d_model)
        self.rank = max(16, capped)
        # NOTE(review): skip_threshold is stored but not read anywhere in this
        # class; confirm whether an external caller uses it.
        self.skip_threshold = float(skip_threshold)
        self.fast_skip = True
        self.in_proj = nn.Linear(d_model, self.rank, bias=False)
        self.out_proj = nn.Linear(self.rank, d_model, bias=False)
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def is_effectively_disabled(self) -> bool:
        """Return True only in eval mode with ``fast_skip`` truthy."""
        if self.training:
            return False
        return bool(self.fast_skip)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank projection of ``x``, or zeros when skipped."""
        if not self.is_effectively_disabled():
            hidden = torch.tanh(self.in_proj(x))
            return self.alpha * self.out_proj(hidden)
        return torch.zeros_like(x)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class _CompatFusionGate(nn.Module):
    """Two-way gate blending the SSM and attention outputs, biased toward attention.

    The gate is ``sigmoid(attn_bias + context_proj(mean over sequence))``.
    ``context_proj`` is zero-initialised and ``attn_bias`` starts at 10, so the
    initial gate is ``sigmoid(10)`` for every input. For sequences at least
    ``long_context_threshold`` long, the gate is capped at
    ``1 - long_context_ssm_ratio``, which guarantees the SSM path a minimum share.

    Args:
        d_model: Model width.
        long_context_threshold: Sequence length from which the cap applies;
            values <= 0 disable the cap.
        long_context_ssm_ratio: Minimum SSM share for long sequences, clamped
            to [0, 0.95].
    """

    def __init__(
        self, d_model: int,
        long_context_threshold: int = 2048,
        long_context_ssm_ratio: float = 0.35,
    ):
        super().__init__()
        bottleneck = max(d_model // 32, 32)
        self.long_context_threshold = int(long_context_threshold)
        bounded_ratio = min(max(long_context_ssm_ratio, 0.0), 0.95)
        self.long_context_ssm_ratio = float(bounded_ratio)
        self.context_proj = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.SiLU(),
            nn.Linear(bottleneck, 1),
        )
        self.attn_bias = nn.Parameter(torch.tensor(10.0))
        # Zero every linear layer so the context term starts out as exactly 0.
        linear_layers = [m for m in self.context_proj if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.zeros_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x_norm, ssm_out, attn_out, *, ssm_enabled=True):
        """Blend ``attn_out`` and ``ssm_out``.

        Args:
            x_norm: Normalised block input; averaged over dim 1 to form the gate context.
            ssm_out: SSM path output (ignored when ``ssm_enabled`` is false).
            attn_out: Attention path output.
            ssm_enabled: When false, ``attn_out`` is returned unchanged.

        Returns:
            ``gate * attn_out + (1 - gate) * ssm_out``.
        """
        if not ssm_enabled:
            return attn_out

        pooled = x_norm.mean(dim=1)
        gate_logit = self.attn_bias + self.context_proj(pooled)
        gate = torch.sigmoid(gate_logit).unsqueeze(1)

        threshold = self.long_context_threshold
        is_long_context = threshold > 0 and x_norm.size(1) >= threshold
        if is_long_context:
            ceiling = 1.0 - self.long_context_ssm_ratio
            gate = torch.clamp(gate, min=0.0, max=ceiling)

        return gate * attn_out + (1.0 - gate) * ssm_out
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class _CompatExpert(nn.Module):
    """Single gated FFN expert using ``gate_proj`` / ``up_proj`` / ``down_proj`` names.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``; all three linear
    layers are bias-free.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        activated = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(activated * lifted)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class _CompatMoE(nn.Module):
    """Thin MoE-shaped wrapper around one FFN expert.

    Keeps the ``experts.0`` parameter path. The forward pass is exactly that
    expert; ``router`` is created but not used in ``forward``.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        sole_expert = _CompatExpert(d_model, d_ff)
        self.experts = nn.ModuleList([sole_expert])
        self.router = nn.Linear(d_model, 1, bias=False)
        self.aux_loss = torch.tensor(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the single expert and reset ``aux_loss`` to a zero scalar like ``x``."""
        # Recreated on every call so it follows the input's dtype and device.
        self.aux_loss = torch.zeros((), dtype=x.dtype, device=x.device)
        only_expert = self.experts[0]
        return only_expert(x)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class _CompatCortexBlockV3(nn.Module):
    """V3 compatibility block: lite SSM + attention + fusion gate, then a single-expert FFN.

    Layout is pre-norm with two residual connections:
    ``x + dropout(fuse(ssm, attn))`` followed by ``x + dropout(moe(norm2(x)))``.
    """

    def __init__(self, config: CortexNetConfig):
        super().__init__()
        width = config.hidden_size
        self.norm1 = RMSNorm(width, config.norm_eps)
        self.norm2 = RMSNorm(width, config.norm_eps)
        self.ssm = _CompatLiteSSM(
            width,
            rank=getattr(config, "compat_ssm_rank", 256),
        )
        self.attention = _CompatAttention(config)
        self.fusion = _CompatFusionGate(
            width,
            long_context_threshold=getattr(config, "fusion_long_context_threshold", 2048),
            long_context_ssm_ratio=getattr(config, "fusion_long_context_ssm_ratio", 0.35),
        )
        self.moe = _CompatMoE(width, config.intermediate_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, past_cache=None, use_cache=False):
        """Run the block.

        Args:
            x: Block input.
            past_cache: Cached attention ``(K, V)`` passed through to the attention module.
            use_cache: When truthy, also return the updated attention cache.

        Returns:
            ``x`` after both residual branches, or ``(x, new_cache)`` when
            ``use_cache`` is truthy.
        """
        normed = self.norm1(x)

        # The SSM path is skipped entirely (no compute) when it reports itself disabled.
        ssm_out = None if self.ssm.is_effectively_disabled() else self.ssm(normed)

        if use_cache:
            attn_out, new_cache = self.attention(
                normed, past_key_value=past_cache, use_cache=True
            )
        else:
            new_cache = None
            attn_out = self.attention(
                normed, past_key_value=past_cache, use_cache=False
            )

        if ssm_out is None:
            fused = attn_out
        else:
            fused = self.fusion(normed, ssm_out, attn_out, ssm_enabled=True)

        x = x + self.dropout(fused)
        x = x + self.dropout(self.moe(self.norm2(x)))

        if use_cache:
            return x, new_cache
        return x

    def get_aux_loss(self):
        """Return the auxiliary loss recorded by the MoE wrapper."""
        return self.moe.aux_loss
|