cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/cache.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ CortexNet 推理缓存 (CortexNet Inference Cache)
3
+
4
+ 用于自回归生成时的增量解码,避免每步重算完整序列。
5
+
6
+ 缓存层级:
7
+ CortexNetCache
8
+ └── LayerCache (per block)
9
+ ├── ssm_state: Tensor — SSM hidden state
10
+ ├── memory_state: (Tensor, Tensor) — WorkingMemory (memory, z)
11
+ └── kv_cache: (Tensor, Tensor, Tensor) — Attention (K, V, indices)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass, field
17
+ from typing import List, Optional, Tuple
18
+
19
+ import torch
20
+
21
+
22
@dataclass
class LayerCache:
    """Incremental-decoding cache for a single CortexBlock layer.

    Every field defaults to ``None``, which represents "nothing cached yet"
    (for example before the first forward pass).

    Attributes:
        ssm_state: SSM hidden state; documented upstream as having shape
            ``(B, d_inner, N)``.
        memory_state: WorkingMemory state as a ``(memory, z)`` pair.
        kv_cache: Attention cache as a ``(K, V, top_k_indices)`` triple.
    """

    ssm_state: Optional[torch.Tensor] = None
    memory_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None

    def as_tuple(self) -> Tuple:
        """Return ``(ssm_state, memory_state, kv_cache)`` for the legacy tuple API."""
        return self.ssm_state, self.memory_state, self.kv_cache

    @staticmethod
    def from_tuple(t: Tuple) -> LayerCache:
        """Build a ``LayerCache`` from a legacy tuple.

        ``None`` yields an empty cache. Tuples shorter than three entries
        leave the missing trailing fields as ``None``; entries beyond the
        third are ignored.
        """
        if t is None:
            return LayerCache()
        available = len(t)
        # Pad positionally so short legacy tuples still map onto all three fields.
        ssm, memory, kv = (t[i] if i < available else None for i in range(3))
        return LayerCache(ssm_state=ssm, memory_state=memory, kv_cache=kv)
50
+
51
+
52
@dataclass
class CortexNetCache:
    """Whole-model cache for CortexNet incremental decoding.

    Attributes:
        layers: One ``LayerCache`` per block.
        position_offset: Number of tokens already processed.
    """

    layers: List[LayerCache] = field(default_factory=list)
    position_offset: int = 0

    @staticmethod
    def init_empty(num_layers: int) -> CortexNetCache:
        """Create a cache holding ``num_layers`` empty layer caches."""
        blank_layers = []
        for _ in range(num_layers):
            blank_layers.append(LayerCache())
        return CortexNetCache(layers=blank_layers, position_offset=0)

    def as_list(self) -> List[Tuple]:
        """Convert to the legacy ``List[Tuple]`` layout (one tuple per layer)."""
        legacy = []
        for layer_cache in self.layers:
            legacy.append(layer_cache.as_tuple())
        return legacy

    @staticmethod
    def from_list(lst: List[Tuple]) -> CortexNetCache:
        """Build a cache from the legacy ``List[Tuple]`` layout.

        ``None`` yields a cache with no layers. ``position_offset`` is left at
        its default of 0 because the legacy layout does not carry it.
        """
        if lst is None:
            return CortexNetCache()
        converted = []
        for legacy_entry in lst:
            converted.append(LayerCache.from_tuple(legacy_entry))
        return CortexNetCache(layers=converted)
@@ -0,0 +1,232 @@
1
+ """
2
+ 因果推理模块 (Causal Reasoning Module)
3
+
4
+ 核心创新:
5
+ 传统注意力只学习相关性 P(Y|X),而因果推理学习因果关系 P(Y|do(X))。
6
+ 这使模型不仅知道"什么和什么一起出现",还知道"什么导致了什么"。
7
+
8
+ ┌─────────────────────────────────────────────────────────────┐
9
+ │ 因果推理 vs 传统注意力 │
10
+ ├─────────────────────────────────────────────────────────────┤
11
+ │ │
12
+ │ 传统注意力: P(Y|X) — "X 出现时 Y 也出现" │
13
+ │ → 只捕获相关性,无法区分因果与混淆 │
14
+ │ │
15
+ │ 因果推理: P(Y|do(X)) — "如果我们主动设置 X,Y 会如何" │
16
+ │ → 捕获真正的因果关系,支持反事实推理 │
17
+ │ │
18
+ │ 三大组件: │
19
+ │ 1. 因果强度估计器 — 评估每个 token 的因果重要性 │
20
+ │ 2. 干预注意力 — 基于因果方向选择高因果 token 计算注意力 │
21
+ │ 3. 反事实分支 — "如果这个原因不同,结果会如何?" │
22
+ └─────────────────────────────────────────────────────────────┘
23
+
24
+ 灵感来源: Pearl's do-calculus, 结构因果模型 (SCM)
25
+
26
+ 优化 (v3.2):
27
+ - InterventionalAttention 使用选择性稀疏注意力替代 O(n²) 全量注意力,
28
+ 仅对因果强度 Top-K 的 token 计算注意力,复杂度降至 O(n·k)
29
+ - CounterfactualBranch 使用合并的批量线性变换替代 for 循环,
30
+ 单次 matmul 完成所有反事实分支计算
31
+ """
32
+
33
+ import logging
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
class CausalStrengthEstimator(nn.Module):
    """Score each token with a causal-strength value in (0, 1).

    A two-layer MLP (``d_model -> d_model // 4 -> 1``, GELU in between)
    produces one logit per token, which is squashed with a sigmoid.
    """

    def __init__(self, d_model: int):
        super().__init__()
        hidden_dim = d_model // 4
        stages = [
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.scorer = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-token strengths; the last dimension of ``x`` becomes 1."""
        logits = self.scorer(x)
        return logits.sigmoid()  # (B, L, 1) for a (B, L, d_model) input
54
+
55
+
56
class InterventionalAttention(nn.Module):
    """Interventional attention: selective sparse attention driven by causal strength.

    Only the ``k`` tokens with the highest causal strength are used as keys
    and values, so the score matrix is ``(L, k)`` instead of ``(L, L)``.

    Scores are computed as ``softmax(Q_i . K_j * head_dim**-0.5 + log(strength_j))``:
    the log-strength bias is added after scaling, so tokens with a higher
    causal strength receive more attention. A causal mask restricts each query
    to selected keys at positions ``<=`` its own.

    Queries that precede every selected key have nothing to attend to. Their
    attention weights are set to zero (so their output is zero) instead of
    letting softmax over an all ``-inf`` row produce NaN.

    Args:
        d_model: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        top_k_ratio: Fraction of the sequence kept as keys/values.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int = 4, top_k_ratio: float = 0.25,
                 dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5
        self.top_k_ratio = top_k_ratio

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = nn.Dropout(dropout)

    def _compute_k_count(self, seq_len: int) -> int:
        """Return how many tokens to select: at least 1, at most ``seq_len``."""
        k = max(1, int(seq_len * self.top_k_ratio))
        return min(k, seq_len)

    def forward(
        self, x: torch.Tensor, causal_strength: torch.Tensor
    ) -> torch.Tensor:
        """Apply sparse interventional attention.

        Args:
            x: Input of shape ``(B, L, D)``.
            causal_strength: Per-token strength of shape ``(B, L, 1)``.

        Returns:
            Tensor of shape ``(B, L, D)``.
        """
        B, L, D = x.shape
        H, hd = self.num_heads, self.head_dim

        q = self.q_proj(x).view(B, L, H, hd).transpose(1, 2)  # (B, H, L, hd)
        k = self.k_proj(x).view(B, L, H, hd).transpose(1, 2)
        v = self.v_proj(x).view(B, L, H, hd).transpose(1, 2)

        k_count = self._compute_k_count(L)

        if k_count >= L:
            # Short sequences fall back to dense attention (still biased).
            return self._full_attention(q, k, v, causal_strength, B, L, D, H, hd)

        # Select the top-k tokens by causal strength to act as keys/values.
        scores = causal_strength.squeeze(-1)  # (B, L)
        _, top_indices = scores.topk(k_count, dim=-1, sorted=False)  # (B, k)

        top_idx_kv = top_indices.unsqueeze(1).unsqueeze(-1).expand(-1, H, -1, hd)
        k_sel = k.gather(2, top_idx_kv)  # (B, H, k, hd)
        v_sel = v.gather(2, top_idx_kv)  # (B, H, k, hd)

        attn = (q @ k_sel.transpose(-2, -1)) * self.scale  # (B, H, L, k)

        # Log-strength bias for the selected tokens, clamped to stay finite.
        causal_bias_sel = scores.gather(1, top_indices)  # (B, k)
        causal_bias = torch.log(causal_bias_sel + 1e-6).clamp(min=-10)
        attn = attn + causal_bias.unsqueeze(1).unsqueeze(2)  # broadcast (B, 1, 1, k)

        # Causal mask: a query may only see selected keys at positions <= its own.
        positions = torch.arange(L, device=x.device).unsqueeze(1)  # (L, 1)
        key_positions = top_indices.unsqueeze(1)  # (B, 1, k)
        visible = (positions >= key_positions).unsqueeze(1)  # (B, 1, L, k)
        attn = attn.masked_fill(~visible, float("-inf"))

        # Rows with no visible key would be all -inf and softmax would return
        # NaN (and NaN gradients). Neutralise them before softmax and zero
        # their weights afterwards, so those queries contribute nothing.
        has_visible_key = visible.any(dim=-1, keepdim=True)  # (B, 1, L, 1)
        attn = attn.masked_fill(~has_visible_key, 0.0)

        attn = F.softmax(attn, dim=-1)
        attn = attn.masked_fill(~has_visible_key, 0.0)
        attn = self.attn_dropout(attn)
        out = (attn @ v_sel).transpose(1, 2).contiguous().view(B, L, D)
        return self.o_proj(out)

    def _full_attention(self, q, k, v, causal_strength, B, L, D, H, hd):
        """Dense causal attention with the log-strength bias (short-sequence path)."""
        attn = (q @ k.transpose(-2, -1)) * self.scale

        causal_bias = causal_strength.squeeze(-1).unsqueeze(1).unsqueeze(2)
        attn = attn + torch.log(causal_bias + 1e-6).clamp(min=-10)

        # Lower-triangular mask; the diagonal is always visible, so no row is empty.
        causal_mask = torch.tril(torch.ones(L, L, device=q.device))
        attn = attn.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.o_proj(out)
148
+
149
+
150
class CounterfactualBranch(nn.Module):
    """Counterfactual branch: a gated mixture of K linear "what-if" transforms.

    All K transforms live in one bias-free linear layer of output width
    ``num_counterfactuals * d_model``, so a single matmul evaluates every
    branch. A small gating MLP yields softmax weights over the K branches
    for each token.

    The merged weight starts at exactly zero, so the branch outputs zeros
    until that weight has been updated.

    Args:
        d_model: Model width.
        num_counterfactuals: Number of branches K.
        dropout: Dropout probability applied to the blended result.
    """

    def __init__(self, d_model: int, num_counterfactuals: int = 4, dropout: float = 0.0):
        super().__init__()
        self.num_cf = num_counterfactuals
        self.d_model = d_model

        # One matrix holding all K transforms, zero-initialised.
        merged = nn.Linear(d_model, num_counterfactuals * d_model, bias=False)
        nn.init.zeros_(merged.weight)
        self.merged_transform = merged

        gate_hidden = d_model // 4
        self.gate = nn.Sequential(
            nn.Linear(d_model, gate_hidden),
            nn.GELU(),
            nn.Linear(gate_hidden, num_counterfactuals),
        )
        # Learnable strength of the counterfactual contribution.
        self.cf_scale = nn.Parameter(torch.tensor(0.1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, causal_strength: torch.Tensor) -> torch.Tensor:
        """Return the gated counterfactual signal, shape ``(B, L, D)``.

        Args:
            x: Input of shape ``(B, L, D)``.
            causal_strength: Per-token multiplier; multiplied into the result
                by broadcasting.
        """
        batch, seq_len, width = x.shape

        mix = F.softmax(self.gate(x), dim=-1)  # (B, L, K)
        branches = self.merged_transform(x).view(batch, seq_len, self.num_cf, width)

        # Weighted sum over the K branches, then dropout.
        blended = self.dropout((branches * mix.unsqueeze(-1)).sum(dim=2))

        # Tokens with higher causal strength contribute more.
        return blended * causal_strength * self.cf_scale
195
+
196
+
197
class CausalReasoningModule(nn.Module):
    """End-to-end causal reasoning pipeline.

    Steps performed by ``forward``:
        1. Estimate a causal strength for every token.
        2. Run interventional (sparse, strength-biased) attention.
        3. Run the counterfactual branch.
        4. Add the two results, layer-normalise, and project.

    Args:
        d_model: Model width.
        num_heads: Number of heads for the interventional attention.
        num_counterfactuals: Number of counterfactual branches.
        top_k_ratio: Top-k ratio passed to the interventional attention.
        dropout: Dropout probability passed to both sub-modules.
    """

    def __init__(self, d_model: int, num_heads: int = 4, num_counterfactuals: int = 4,
                 top_k_ratio: float = 0.25, dropout: float = 0.0):
        super().__init__()
        self.causal_estimator = CausalStrengthEstimator(d_model)
        self.interventional_attn = InterventionalAttention(
            d_model,
            num_heads,
            top_k_ratio=top_k_ratio,
            dropout=dropout,
        )
        self.counterfactual = CounterfactualBranch(
            d_model,
            num_counterfactuals,
            dropout=dropout,
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fused observational + counterfactual representation."""
        strength = self.causal_estimator(x)  # (B, L, 1)
        # Attention runs before the counterfactual branch, as in the pipeline above.
        fused = self.interventional_attn(x, strength) + self.counterfactual(x, strength)
        normalised = self.norm(fused)
        return self.out_proj(normalised)
cortexnet/compat.py ADDED
@@ -0,0 +1,245 @@
1
+ """
2
+ CortexNet 兼容模式组件 (Compatibility Mode Components)
3
+
4
+ 为开源 LLM (LLaMA/Qwen/Mistral...) 的权重无损迁移提供轻量兼容组件:
5
+ - _CompatAttention: GQA + KV cache 注意力
6
+ - _CompatLiteSSM: 低秩 SSM 旁路(默认零影响)
7
+ - _CompatFusionGate: 轻量两路融合(强偏向 Attention)
8
+ - _CompatExpert/_CompatMoE: 单专家 FFN 兼容壳
9
+ - _CompatCortexBlockV3: 完整的 V3 兼容块
10
+ - _NoOpEvolutionEngine: 占位进化引擎
11
+
12
+ 这些组件在 compatibility_mode=True 时替代完整 V3 模块使用。
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from typing import Optional, Tuple
19
+
20
+ try:
21
+ from .config import CortexNetConfig
22
+ from .blocks import RMSNorm
23
+ from .attention import precompute_rope_freqs, apply_rope, apply_rope_with_positions
24
+ except ImportError:
25
+ from cortexnet.config import CortexNetConfig
26
+ from cortexnet.blocks import RMSNorm
27
+ from cortexnet.attention import precompute_rope_freqs, apply_rope, apply_rope_with_positions
28
+
29
+
30
class _NoOpEvolutionEngine(nn.Module):
    """Lightweight placeholder engine for compatibility mode.

    It has no parameters: the compute budget is always 1 for every batch
    element and the efficiency loss is always 0.
    """

    def __init__(self):
        super().__init__()

    def get_compute_budget(self, x: torch.Tensor) -> torch.Tensor:
        """Return a vector of ones, one per batch element, matching ``x``'s dtype and device."""
        batch_size = x.size(0)
        return x.new_ones(batch_size)

    def get_efficiency_loss(self) -> float:
        """Return the constant efficiency loss ``0.0``."""
        return 0.0
41
+
42
+
43
class _CompatAttention(nn.Module):
    """Lightweight attention aligned with mainstream HF layouts (GQA + KV cache).

    Query, key and value come from bias-free projections; keys and values use
    ``num_kv_heads`` heads that are repeated to match ``num_heads`` queries.
    Per-head RMSNorm is applied to queries and keys before the rotary
    embedding (the original comment notes q_norm/k_norm are common in
    Qwen2/3-style models).

    The rotary helpers and ``RMSNorm`` are imported from sibling modules;
    their exact semantics are not visible here.
    """

    def __init__(self, config: CortexNetConfig):
        super().__init__()
        self.d_model = config.hidden_size
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = self.d_model // self.num_heads

        kv_dim = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj = nn.Linear(self.d_model, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.d_model, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
        self.k_norm = RMSNorm(self.head_dim, config.norm_eps)

        # Non-persistent: recomputed from the config, never stored in checkpoints.
        self.register_buffer(
            "rope_freqs",
            precompute_rope_freqs(self.head_dim, config.max_seq_len, config.rope_theta),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]":
        """Run attention over ``x``, optionally extending a KV cache.

        Args:
            x: Input of shape ``(B, L, D)``.
            past_key_value: Cached ``(K, V)`` from earlier steps, or ``None``.
                The cached length is read from ``K.shape[2]``.
            use_cache: When true, also return the updated ``(K, V)`` pair
                (before GQA head repetition).

        Returns:
            The attention output ``(B, L, D)``, or ``(output, (K, V))`` when
            ``use_cache`` is true.
        """
        # The return annotation is quoted so it is never evaluated at import
        # time; this module has no ``from __future__ import annotations``.
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        rope = self.rope_freqs.to(x.device)
        attn_mask = None
        if past_key_value is None:
            q = apply_rope(q, rope)
            k = apply_rope(k, rope)
            use_builtin_causal = True
        else:
            past_k, past_v = past_key_value
            pos_offset = past_k.shape[2]
            # Absolute positions of the new tokens, continuing after the cache.
            new_positions = torch.arange(pos_offset, pos_offset + L, device=x.device)
            pos = new_positions.unsqueeze(0).expand(B, -1)
            q = apply_rope_with_positions(q, rope, pos)
            k = apply_rope_with_positions(k, rope, pos)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
            use_builtin_causal = False
            if L > 1:
                # Several new tokens arrive with a cache (chunked prefill or
                # multi-token decoding). Without a mask each new token would
                # attend to later tokens in the same chunk, so allow key j
                # for query i only when j <= pos_offset + i.
                key_positions = torch.arange(pos_offset + L, device=x.device)
                attn_mask = key_positions.unsqueeze(0) <= new_positions.unsqueeze(1)  # (L, S)

        # GQA: repeat KV heads so they line up with the query heads.
        if self.num_kv_heads != self.num_heads:
            repeat = self.num_heads // self.num_kv_heads
            k_attn = k.repeat_interleave(repeat, dim=1)
            v_attn = v.repeat_interleave(repeat, dim=1)
        else:
            k_attn = k
            v_attn = v

        attn_out = F.scaled_dot_product_attention(
            q, k_attn, v_attn,
            attn_mask=attn_mask, dropout_p=0.0,
            is_causal=use_builtin_causal,
        )
        out = self.o_proj(attn_out.transpose(1, 2).contiguous().view(B, L, D))

        if use_cache:
            return out, (k, v)
        return out
117
+
118
+
119
class _CompatLiteSSM(nn.Module):
    """Low-rank SSM side path with zero effect by default.

    The path computes ``alpha * out_proj(tanh(in_proj(x)))``. ``alpha`` starts
    at 0, so the output is zero until ``alpha`` changes. In eval mode with
    ``fast_skip`` set, the projections are skipped entirely and zeros are
    returned.

    NOTE(review): ``skip_threshold`` is stored but not read anywhere in this
    class; confirm whether external code consumes it.

    Args:
        d_model: Model width.
        rank: Requested bottleneck width; capped at ``d_model`` and then
            raised to at least 16.
        skip_threshold: Stored as a float attribute.
    """

    def __init__(self, d_model: int, rank: int = 256, skip_threshold: float = 1e-6):
        super().__init__()
        capped_rank = min(rank, d_model)
        self.rank = max(16, capped_rank)
        self.skip_threshold = float(skip_threshold)
        self.fast_skip = True
        self.in_proj = nn.Linear(d_model, self.rank, bias=False)
        self.out_proj = nn.Linear(self.rank, d_model, bias=False)
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def is_effectively_disabled(self) -> bool:
        """Return True only in eval mode while ``fast_skip`` is truthy."""
        if self.training:
            return False
        return bool(self.fast_skip)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank path, or zeros when the path is skipped."""
        if self.is_effectively_disabled():
            return torch.zeros_like(x)
        bottleneck = torch.tanh(self.in_proj(x))
        return self.alpha * self.out_proj(bottleneck)
138
+
139
+
140
class _CompatFusionGate(nn.Module):
    """Two-way fusion gate for the SSM and attention paths, biased towards attention.

    The gate is ``sigmoid(attn_bias + context_proj(mean over sequence of x_norm))``.
    ``context_proj`` starts with all weights and biases at zero and
    ``attn_bias`` at 10, so the initial gate is ``sigmoid(10)`` for every
    input. For sequences at least ``long_context_threshold`` long, the gate
    is capped at ``1 - long_context_ssm_ratio``.

    Args:
        d_model: Model width.
        long_context_threshold: Sequence length from which the cap applies;
            values ``<= 0`` disable the cap.
        long_context_ssm_ratio: Minimum SSM share for long sequences, clamped
            to ``[0, 0.95]``.
    """

    def __init__(
        self, d_model: int,
        long_context_threshold: int = 2048,
        long_context_ssm_ratio: float = 0.35,
    ):
        super().__init__()
        bottleneck = max(d_model // 32, 32)
        self.long_context_threshold = int(long_context_threshold)
        bounded_ratio = min(max(long_context_ssm_ratio, 0.0), 0.95)
        self.long_context_ssm_ratio = float(bounded_ratio)
        self.context_proj = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.SiLU(),
            nn.Linear(bottleneck, 1),
        )
        self.attn_bias = nn.Parameter(torch.tensor(10.0))
        # Zero the context projection so only attn_bias drives the gate at first.
        for layer in self.context_proj:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.zeros_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x_norm, ssm_out, attn_out, *, ssm_enabled=True):
        """Blend the two paths; return ``attn_out`` untouched when SSM is disabled."""
        if not ssm_enabled:
            return attn_out
        pooled = x_norm.mean(dim=1)  # average over the sequence dimension
        gate = torch.sigmoid(self.attn_bias + self.context_proj(pooled)).unsqueeze(1)
        threshold = self.long_context_threshold
        if threshold > 0 and x_norm.size(1) >= threshold:
            gate_ceiling = 1.0 - self.long_context_ssm_ratio
            gate = torch.clamp(gate, min=0.0, max=gate_ceiling)
        return gate * attn_out + (1.0 - gate) * ssm_out
174
+
175
+
176
class _CompatExpert(nn.Module):
    """Single gated feed-forward expert whose parameter names match the source model's FFN.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))`` with three
    bias-free linear layers.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated feed-forward transform to ``x``."""
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
187
+
188
+
189
class _CompatMoE(nn.Module):
    """MoE-shaped shell around one FFN expert.

    It keeps the ``experts.0`` parameter path and a ``router`` layer, but the
    forward pass only runs the single expert; the router is never called here.
    ``aux_loss`` is always a zero scalar.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        sole_expert = _CompatExpert(d_model, d_ff)
        self.experts = nn.ModuleList([sole_expert])
        self.router = nn.Linear(d_model, 1, bias=False)
        self.aux_loss = torch.tensor(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the single expert and reset ``aux_loss`` to a zero scalar like ``x``."""
        # new_zeros keeps the auxiliary loss on x's device and dtype.
        self.aux_loss = x.new_zeros(())
        expert = self.experts[0]
        return expert(x)
201
+
202
+
203
class _CompatCortexBlockV3(nn.Module):
    """V3 compatibility block: lite SSM + attention + fusion gate, followed by an FFN.

    Layout (pre-norm with residual connections):
        1. ``norm1`` -> attention, plus the SSM path when it is active,
           blended by the fusion gate -> dropout -> residual add.
        2. ``norm2`` -> single-expert MoE shell -> dropout -> residual add.

    Optional config attributes ``compat_ssm_rank``,
    ``fusion_long_context_threshold`` and ``fusion_long_context_ssm_ratio``
    fall back to 256, 2048 and 0.35 when absent.
    """

    def __init__(self, config: CortexNetConfig):
        super().__init__()
        hidden = config.hidden_size
        self.norm1 = RMSNorm(hidden, config.norm_eps)
        self.norm2 = RMSNorm(hidden, config.norm_eps)
        self.ssm = _CompatLiteSSM(
            hidden,
            rank=getattr(config, "compat_ssm_rank", 256),
        )
        self.attention = _CompatAttention(config)
        self.fusion = _CompatFusionGate(
            hidden,
            long_context_threshold=getattr(config, "fusion_long_context_threshold", 2048),
            long_context_ssm_ratio=getattr(config, "fusion_long_context_ssm_ratio", 0.35),
        )
        self.moe = _CompatMoE(hidden, config.intermediate_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, past_cache=None, use_cache=False):
        """Run the block; return ``(x, new_cache)`` when ``use_cache`` is truthy, else ``x``."""
        normed = self.norm1(x)

        # The SSM path is skipped entirely when it reports itself disabled.
        ssm_active = not self.ssm.is_effectively_disabled()
        ssm_out = self.ssm(normed) if ssm_active else None

        new_cache = None
        if use_cache:
            attn_out, new_cache = self.attention(
                normed, past_key_value=past_cache, use_cache=True
            )
        else:
            attn_out = self.attention(normed, past_key_value=past_cache, use_cache=False)

        if ssm_active and ssm_out is not None:
            mixed = self.fusion(normed, ssm_out, attn_out, ssm_enabled=True)
        else:
            mixed = attn_out

        x = x + self.dropout(mixed)
        x = x + self.dropout(self.moe(self.norm2(x)))

        return (x, new_cache) if use_cache else x

    def get_aux_loss(self):
        """Return the auxiliary loss currently stored on the MoE shell."""
        return self.moe.aux_loss