cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/blocks.py ADDED
@@ -0,0 +1,682 @@
1
+ from __future__ import annotations
2
+
3
+ """
4
+ CortexNet 核心构建块 (Core Building Blocks)
5
+
6
+ 将所有组件组合成 CortexBlock——CortexNet 架构的基本重复单元。
7
+
8
+ 每个 CortexBlock 包含三条并行处理路径和一个 MoE 前馈层:
9
+ ┌──────────────── CortexBlock ────────────────┐
10
+ │ │
11
+ │ Input │
12
+ │ │ │
13
+ │ ├──► RMSNorm ──┬──► Multi-Scale SSM ──┐ │
14
+ │ │ │ │ │
15
+ │ │ ├──► Sparse Attention ──┤ │
16
+ │ │ │ │ │
17
+ │ │ └──► Synaptic Memory ───┤ │
18
+ │ │ │ │
19
+ │ │ ┌── Adaptive Fusion ◄───┘ │
20
+ │ │ │ │
21
+ │ ├──── + ◄──────┘ (残差连接) │
22
+ │ │ │
23
+ │ ├──► RMSNorm ──► MoE FFN ──┐ │
24
+ │ │ │ │
25
+ │ └──── + ◄──────────────────┘ │
26
+ │ │ (残差连接) │
27
+ │ Output │
28
+ └───────────────────────────────────────────────┘
29
+ """
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ try:
36
+ from .config import CortexNetConfig
37
+ from .ssm import MultiScaleSSM
38
+ from .attention import SelectiveSparseAttention
39
+ from .memory import SynapticMemory
40
+ from .routing import MixtureOfExperts
41
+ except ImportError:
42
+ # 兼容脚本式导入: `from blocks import CortexBlockV3`
43
+ from cortexnet.config import CortexNetConfig
44
+ from cortexnet.ssm import MultiScaleSSM
45
+ from cortexnet.attention import SelectiveSparseAttention
46
+ from cortexnet.memory import SynapticMemory
47
+ from cortexnet.routing import MixtureOfExperts
48
+
49
+
50
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Unlike standard LayerNorm, RMSNorm skips the mean-subtraction/shift step
    and only rescales, which is cheaper to compute with comparable quality.

    Formula:
        RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma

    Args:
        dim: size of the normalized (last) dimension
        eps: small constant guarding against division by zero
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in FP32 for half-precision stability
        # (same convention as mainstream LLM implementations).
        orig_dtype = x.dtype
        hidden = x.float()
        inv_rms = torch.rsqrt(hidden.square().mean(dim=-1, keepdim=True) + self.eps)
        normed = (hidden * inv_rms).to(orig_dtype)
        return normed * self.weight.to(orig_dtype)
76
+
77
+
78
class AdaptiveFusionGate(nn.Module):
    """Adaptive fusion gate.

    Learns per-token weights that balance the contributions of the parallel
    paths (SSM / attention / memory, ...). Different tokens and contexts may
    favor different path mixes:
      - local patterns        -> higher SSM weight
      - long-range reference  -> higher attention weight
      - context adaptation    -> higher memory weight

    A bottleneck MLP keeps the parameter count small:
        Concat(P*D) -> Linear(P*D, max(D//4, 32)) -> SiLU -> Linear(., P) -> Softmax

    Args:
        d_model: model dimension
        num_paths: number of fused paths (default 3)
    """

    def __init__(self, d_model: int, num_paths: int = 3):
        super().__init__()
        hidden = max(d_model // 4, 32)
        self.gate = nn.Sequential(
            nn.Linear(num_paths * d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, num_paths),
        )

    def forward(self, *path_outputs: torch.Tensor) -> torch.Tensor:
        """Fuse path outputs.

        Args:
            *path_outputs: path outputs, each (batch, seq_len, d_model)
        Returns:
            fused: (batch, seq_len, d_model)
        """
        # Per-token softmax weights computed from the concatenated paths.
        weights = F.softmax(self.gate(torch.cat(path_outputs, dim=-1)), dim=-1)  # (B, L, P)

        # Weighted sum over the stacked path dimension.
        stacked = torch.stack(path_outputs, dim=-1)  # (B, L, D, P)
        return torch.einsum("bldp,blp->bld", stacked, weights)
126
+
127
+
128
class CortexBlock(nn.Module):
    """Core CortexNet building block.

    Each block runs three parallel processing paths:
        1. Multi-scale SSM — captures sequence patterns at several time
           scales, O(n)
        2. Selective sparse attention — focuses on the most relevant
           tokens, O(n·k)
        3. Synaptic memory — fast in-context adaptation, O(n)

    The path outputs are merged by an adaptive fusion gate, then passed
    through a mixture-of-experts feed-forward network; both sub-layers are
    pre-normalized and wrapped in residual connections.

    Args:
        config: CortexNet configuration
        layer_idx: index of this layer (used for debugging)
    """

    def __init__(self, config: CortexNetConfig, layer_idx: int = 0):
        super().__init__()
        self.layer_idx = layer_idx
        d_model = config.hidden_size

        # Pre-normalization for the two residual sub-layers.
        self.norm1 = RMSNorm(d_model, config.norm_eps)
        self.norm2 = RMSNorm(d_model, config.norm_eps)

        # Path 1: multi-scale state-space model — O(n).
        self.ssm = MultiScaleSSM(
            d_model=d_model,
            num_scales=config.num_scales,
            state_size=config.ssm_state_size,
            expand_factor=config.ssm_expand_factor,
        )

        # Path 2: selective sparse attention — O(n·k).
        self.attention = SelectiveSparseAttention(
            d_model=d_model,
            num_heads=config.num_heads,
            top_k_ratio=config.top_k_ratio,
            k_mode=config.attention_k_mode,
            max_seq_len=config.max_seq_len,
            rope_theta=config.rope_theta,
            dropout=config.dropout,
        )

        # Path 3: synaptic memory for fast context adaptation — O(n).
        self.memory = SynapticMemory(
            d_model=d_model,
            memory_dim=config.memory_dim,
            decay_init=config.memory_decay_init,
        )

        # Learned per-token gate mixing the three path outputs.
        self.fusion = AdaptiveFusionGate(d_model, num_paths=3)

        # Mixture-of-experts feed-forward network.
        self.moe = MixtureOfExperts(
            d_model=d_model,
            d_ff=config.expert_ff_dim,
            num_experts=config.num_experts,
            num_active=config.num_active_experts,
            aux_loss_weight=config.moe_aux_loss_weight,
        )

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, hidden_size)
        Returns:
            output: (batch, seq_len, hidden_size)
        """
        # --- Multi-path sub-layer + residual ---
        normed = self.norm1(x)
        fused = self.fusion(
            self.ssm(normed),        # O(n)
            self.attention(normed),  # O(n·k)
            self.memory(normed),     # O(n)
        )
        x = x + self.dropout(fused)

        # --- MoE FFN sub-layer + residual ---
        x = x + self.dropout(self.moe(self.norm2(x)))
        return x

    def get_aux_loss(self) -> torch.Tensor:
        """Auxiliary load-balancing loss from the MoE router."""
        return self.moe.aux_loss
219
+
220
+
221
+ # ═══════════════════════════════════════════════════════════════
222
+ # CortexBlock V2 — 进化版
223
+ # ═══════════════════════════════════════════════════════════════
224
+
225
+ try:
226
+ from .hierarchical_memory import HierarchicalMemorySystem
227
+ from .graph_reasoning import GraphReasoningModule
228
+ from .meta_learning import MetaLearningAdapter, TaskAdaptiveController
229
+ from .routing import CollaborativeMoE
230
+ except ImportError:
231
+ from cortexnet.hierarchical_memory import HierarchicalMemorySystem
232
+ from cortexnet.graph_reasoning import GraphReasoningModule
233
+ from cortexnet.meta_learning import MetaLearningAdapter, TaskAdaptiveController
234
+ from cortexnet.routing import CollaborativeMoE
235
+
236
+
237
class CortexBlockV2(nn.Module):
    """CortexNet V2 evolved building block.

    Extends V1 with four parallel paths plus meta-learning and task control:

        Input
          ├─► RMSNorm ──┬─► Multi-Scale SSM (parallel scan)
          │             ├─► Selective Sparse Attention
          │             ├─► Hierarchical Memory (3 tiers)
          │             └─► Graph Reasoning (multi-step)
          │                        │
          │        Adaptive Fusion (4-way) ─► Meta-Learning Adapter (FiLM)
          ├──── + ◄────────────────┘            (residual)
          ├─► RMSNorm ──► Collaborative MoE FFN ──┐
          └──── + ◄─────────────────────────────┘ (residual)
                 │
                 └─► Task Adaptive Controller ─► Output

    New in V2:
        1. Hierarchical memory system replaces the single synaptic memory
        2. Graph reasoning module for multi-step relational inference
        3. Meta-learning adapter for fast task adaptation
        4. Collaborative MoE for knowledge sharing between experts
        5. Task-adaptive controller for implicit task switching

    Args:
        config: CortexNet configuration
        layer_idx: layer index
    """

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-normalization for the two residual sub-layers.
        self.norm1 = RMSNorm(config.hidden_size, config.norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, config.norm_eps)

        # ═══ Four parallel processing paths ═══

        # Path 1: multi-scale SSM (with chunked parallel scan).
        self.ssm = MultiScaleSSM(
            d_model=config.hidden_size,
            num_scales=config.num_scales,
            state_size=config.ssm_state_size,
            expand_factor=config.ssm_expand_factor,
        )

        # Path 2: selective sparse attention.
        self.attention = SelectiveSparseAttention(
            d_model=config.hidden_size,
            num_heads=config.num_heads,
            top_k_ratio=config.top_k_ratio,
            k_mode=config.attention_k_mode,
            max_seq_len=config.max_seq_len,
            rope_theta=config.rope_theta,
            dropout=config.dropout,
        )

        # Path 3: hierarchical memory system (new in V2).
        # getattr defaults keep older configs without these fields working.
        self.memory = HierarchicalMemorySystem(
            d_model=config.hidden_size,
            working_dim=getattr(config, "memory_dim", 64),
            episodic_slots=getattr(config, "episodic_slots", 32),
            semantic_slots=getattr(config, "semantic_slots", 64),
            num_heads=min(4, config.num_heads),
        )

        # Path 4: graph reasoning module (new in V2).
        self.graph_reasoning = GraphReasoningModule(
            d_model=config.hidden_size,
            num_neighbors=getattr(config, "graph_neighbors", 16),
            num_iterations=getattr(config, "graph_iterations", 2),
        )

        # ═══ Adaptive fusion over the 4 paths ═══
        self.fusion = AdaptiveFusionGate(config.hidden_size, num_paths=4)

        # ═══ Meta-learning adapter (new in V2) ═══
        self.meta_adapter = MetaLearningAdapter(config.hidden_size)

        # ═══ Collaborative MoE (upgraded in V2) ═══
        self.moe = CollaborativeMoE(
            d_model=config.hidden_size,
            d_ff=config.expert_ff_dim,
            num_experts=config.num_experts,
            num_active=config.num_active_experts,
            aux_loss_weight=config.moe_aux_loss_weight,
        )

        # ═══ Task-adaptive controller (new in V2) ═══
        self.task_controller = TaskAdaptiveController(
            config.hidden_size,
            num_modes=getattr(config, "num_task_modes", 4),
        )

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, hidden_size)
        Returns:
            output: (batch, seq_len, hidden_size)
        """
        # ═══ Four parallel paths + residual ═══
        residual = x
        x_norm = self.norm1(x)

        ssm_out = self.ssm(x_norm)
        attn_out = self.attention(x_norm)
        mem_out = self.memory(x_norm)
        graph_out = self.graph_reasoning(x_norm)

        # Adaptive fusion across the 4 paths.
        fused = self.fusion(ssm_out, attn_out, mem_out, graph_out)

        # Meta-learning adaptation (FiLM-style modulation).
        fused = self.meta_adapter(fused)

        x = residual + self.dropout(fused)

        # ═══ Collaborative MoE FFN + residual ═══
        residual = x
        x = residual + self.dropout(self.moe(self.norm2(x)))

        # ═══ Task-adaptive control ═══
        x = self.task_controller(x)

        return x

    def get_aux_loss(self) -> torch.Tensor:
        """Auxiliary load-balancing loss from the MoE router."""
        return self.moe.aux_loss
382
+
383
+
384
+ # ═══════════════════════════════════════════════════════════════
385
+ # CortexBlock V3 — 终极进化
386
+ # ═══════════════════════════════════════════════════════════════
387
+
388
+ from typing import Optional, Tuple, Any
389
+
390
+ import torch.utils.checkpoint as ckpt
391
+
392
+ try:
393
+ from .causal_reasoning import CausalReasoningModule
394
+ from .self_evolution import DynamicPathController
395
+ from .multi_agent import MultiAgentSystem
396
+ from .adversarial import AdversarialShield
397
+ except ImportError:
398
+ from cortexnet.causal_reasoning import CausalReasoningModule
399
+ from cortexnet.self_evolution import DynamicPathController
400
+ from cortexnet.multi_agent import MultiAgentSystem
401
+ from cortexnet.adversarial import AdversarialShield
402
+
403
+
404
class MixtureOfDepths(nn.Module):
    """Mixture of Depths: token-level adaptive layer skipping.

    A learned router decides which tokens need the full block computation.
    "Easy" tokens skip the heavy path and only pass through a lightweight
    linear bypass; the top-`capacity` fraction of tokens (by router score)
    receive the full 5-path processing. This reduces average compute
    without giving up modeling capacity.

    Fix vs. the previous revision: the selected tokens' outputs are now
    scaled by the router probability. Top-k index selection alone is
    non-differentiable, so previously `self.router` received no gradient
    and could never learn; scaling by sigmoid(score) restores end-to-end
    training of the router, as in Raposo et al., "Mixture-of-Depths" (2024).

    Args:
        d_model: model dimension
        capacity: fraction of tokens (0~1) given full processing per layer
    """

    def __init__(self, d_model: int, capacity: float = 0.5):
        super().__init__()
        self.capacity = capacity
        self.router = nn.Linear(d_model, 1, bias=False)
        # Lightweight bypass: skipped tokens are not a pure identity but go
        # through a simple learned transform, preserving expressive power.
        self.bypass = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
        )
        nn.init.eye_(self.bypass[0].weight)  # start as identity; learn gradually

    def forward(
        self,
        x: torch.Tensor,
        block_fn,
    ) -> torch.Tensor:
        """Route tokens through `block_fn` or the lightweight bypass.

        Args:
            x: (B, L, D)
            block_fn: callable mapping (B, k, D) -> (B, k, D) — the full block
        Returns:
            output: (B, L, D)
        """
        B, L, D = x.shape
        k = max(1, int(L * self.capacity))

        # Degenerate case: capacity covers the whole sequence.
        if k >= L:
            return block_fn(x)

        scores = self.router(x).squeeze(-1)  # (B, L)
        top_scores, top_idx = scores.topk(k, dim=-1)  # (B, k)
        # Restore sequence order so block_fn sees tokens in causal order;
        # keep the matching scores aligned with the reordered indices.
        top_idx, order = top_idx.sort(dim=-1)
        top_scores = top_scores.gather(-1, order)

        # Gather the selected tokens into a contiguous tensor.
        idx_exp = top_idx.unsqueeze(-1).expand(-1, -1, D)
        x_selected = x.gather(1, idx_exp)  # (B, k, D)

        # Full processing for selected tokens, scaled by the router
        # probability so the router is trained through the selection
        # (without this scaling the router weight got zero gradient).
        out_selected = block_fn(x_selected) * torch.sigmoid(top_scores).unsqueeze(-1)

        # Every token goes through the lightweight bypass; selected
        # positions are then overwritten with the fully processed outputs.
        output = self.bypass(x)  # (B, L, D)
        output = output.scatter(1, idx_exp, out_selected)

        return output
468
+
469
+
470
class CortexBlockV3(nn.Module):
    """CortexNet V3 ultimate-evolution building block.

    Adds four breakthrough capabilities on top of V2:
        1. Causal reasoning — a 5th processing path, modeling causation
           rather than mere correlation
        2. Self-evolution — dynamic path gating: input-driven activation /
           deactivation of individual paths
        3. Multi-agent — collaborative multi-expert decision making
        4. Adversarial defense — input and feature shielding for robustness

    Data flow (see `forward` / `_full_forward`):
        Input ─► AdversarialShield.defend_input
              ─► [optional Mixture-of-Depths token routing]
              ─► norm1 ─► 5 gated parallel paths
                          (SSM / sparse attention / hierarchical memory /
                           graph reasoning / causal reasoning)
              ─► AdaptiveFusion (5-way) ─► MetaLearningAdapter
              ─► AdversarialShield.defend_features ─► residual add
              ─► norm2 ─► CollaborativeMoE ─► residual add
              ─► MultiAgentSystem ─► TaskAdaptiveController ─► Output

    Args:
        config: CortexNet configuration
        layer_idx: layer index
    """

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.layer_idx = layer_idx
        # Trade compute for memory during training when enabled.
        self.use_gradient_checkpointing = getattr(config, "use_gradient_checkpointing", False)
        D = config.hidden_size

        # Pre-normalization for the two residual sub-layers.
        self.norm1 = RMSNorm(D, config.norm_eps)
        self.norm2 = RMSNorm(D, config.norm_eps)

        # ═══ Adversarial defense (V3) ═══
        self.adversarial_shield = AdversarialShield(D)

        # ═══ Mixture of Depths (optional) ═══
        use_mod = getattr(config, "use_mixture_of_depths", False)
        mod_cap = getattr(config, "mod_capacity", 0.5)
        self.mixture_of_depths = MixtureOfDepths(D, mod_cap) if use_mod else None

        # ═══ Five parallel processing paths ═══
        self.ssm = MultiScaleSSM(
            d_model=D, num_scales=config.num_scales,
            state_size=config.ssm_state_size, expand_factor=config.ssm_expand_factor,
        )
        self.attention = SelectiveSparseAttention(
            d_model=D, num_heads=config.num_heads,
            top_k_ratio=config.top_k_ratio, k_mode=config.attention_k_mode,
            max_seq_len=config.max_seq_len, rope_theta=config.rope_theta,
            dropout=config.dropout,
            sliding_window_size=getattr(config, "sliding_window_size", 0),
        )
        self.memory = HierarchicalMemorySystem(
            d_model=D, working_dim=getattr(config, "memory_dim", 64),
            episodic_slots=getattr(config, "episodic_slots", 32),
            semantic_slots=getattr(config, "semantic_slots", 64),
            num_heads=min(4, config.num_heads),
        )
        self.graph_reasoning = GraphReasoningModule(
            d_model=D,
            num_neighbors=getattr(config, "graph_neighbors", 16),
            num_iterations=getattr(config, "graph_iterations", 2),
        )
        # Path 5: causal reasoning (new in V3).
        self.causal_reasoning = CausalReasoningModule(
            d_model=D,
            num_heads=min(4, config.num_heads),
            num_counterfactuals=getattr(config, "num_counterfactuals", 4),
        )

        # ═══ Self-evolution: dynamic path control (V3) ═══
        self.path_controller = DynamicPathController(D, num_paths=5)

        # ═══ 5-way adaptive fusion ═══
        self.fusion = AdaptiveFusionGate(D, num_paths=5)

        # ═══ Meta-learning adapter ═══
        self.meta_adapter = MetaLearningAdapter(D)

        # ═══ Collaborative MoE ═══
        self.moe = CollaborativeMoE(
            d_model=D, d_ff=config.expert_ff_dim,
            num_experts=config.num_experts, num_active=config.num_active_experts,
            aux_loss_weight=config.moe_aux_loss_weight,
        )

        # ═══ Multi-agent collaboration (V3) ═══
        self.multi_agent = MultiAgentSystem(
            d_model=D,
            num_agents=getattr(config, "num_agents", 4),
        )

        # ═══ Task-adaptive control ═══
        self.task_controller = TaskAdaptiveController(
            D, num_modes=getattr(config, "num_task_modes", 4),
        )

        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_cache: Optional[Tuple[Any, Any, Any]] = None,
        use_cache: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, Optional[Tuple[Any, Any, Any]]]:
        """Forward with optional cache for incremental decoding.

        Args:
            x: (batch, seq_len, hidden_size)
            past_cache: (ssm_state, (mem, z), (K, V, top_k)) or None
            use_cache: if True, return (output, new_cache)
        """
        x = self.adversarial_shield.defend_input(x)

        # Mixture of Depths: choose which tokens need full processing.
        # NOTE(review): this branch discards any `past_cache`, so MoD is
        # only applied for non-incremental (use_cache=False) forwards.
        if self.mixture_of_depths is not None and not use_cache:
            return self.mixture_of_depths(x, lambda z: self._full_forward(z, None, False))

        return self._full_forward(x, past_cache, use_cache)

    def _full_forward(
        self,
        x: torch.Tensor,
        past_cache: Optional[Tuple[Any, Any, Any]] = None,
        use_cache: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, Optional[Tuple[Any, Any, Any]]]:
        """Full block forward pass, with optional gradient checkpointing."""
        residual = x
        x_norm = self.norm1(x)
        # Per-sample, per-path activation gates — assumed shape (batch,
        # num_paths=5) given the pg[:, i] indexing below; confirm against
        # DynamicPathController.
        path_gates = self.path_controller(x_norm)

        past_ssm, past_mem, past_kv = (None, None, None)
        if past_cache is not None:
            past_ssm, past_mem, past_kv = past_cache

        # ═══ 5-way parallel processing (path-aware skipping + checkpointing) ═══
        # During training no path is skipped (threshold 0.0 keeps gradients
        # flowing); at eval time near-zero-gated paths are skipped entirely.
        _SKIP_THRESH = 0.0 if self.training else 0.01

        def _compute_paths(xn, pg, p_ssm, p_mem, p_kv, uc):
            # Runs the 5 paths, gates each by pg, fuses and meta-adapts.
            # Kept as a closure so ckpt.checkpoint can recompute it.
            # The cached sub-modules are assumed to return (out, new_state)
            # tuples when uc=True and a bare tensor otherwise — the
            # isinstance checks below tolerate either.
            zero = torch.zeros_like(xn)
            n_ssm = n_kv = n_mem = None

            # Path-aware skipping: a path whose gate is ~0 for the whole
            # batch outputs zeros without doing any compute.
            if pg[:, 0].max() > _SKIP_THRESH:
                s_res = self.ssm(xn, past_state=p_ssm, use_cache=uc)
                s_out = (s_res[0] if uc and isinstance(s_res, tuple) else s_res) * pg[:, 0:1].unsqueeze(1)
                n_ssm = s_res[1] if uc and isinstance(s_res, tuple) else None
            else:
                s_out = zero

            if pg[:, 1].max() > _SKIP_THRESH:
                a_res = self.attention(xn, past_key_value=p_kv, use_cache=uc)
                a_out = (a_res[0] if uc and isinstance(a_res, tuple) else a_res) * pg[:, 1:2].unsqueeze(1)
                n_kv = a_res[1] if uc and isinstance(a_res, tuple) else None
            else:
                a_out = zero

            if pg[:, 2].max() > _SKIP_THRESH:
                m_res = self.memory(xn, past_working_memory=p_mem, use_cache=uc)
                m_out = (m_res[0] if uc and isinstance(m_res, tuple) else m_res) * pg[:, 2:3].unsqueeze(1)
                n_mem = m_res[1] if uc and isinstance(m_res, tuple) else None
            else:
                m_out = zero

            # Graph and causal paths are stateless (no cache handling).
            g_out = self.graph_reasoning(xn) * pg[:, 3:4].unsqueeze(1) if pg[:, 3].max() > _SKIP_THRESH else zero
            c_out = self.causal_reasoning(xn) * pg[:, 4:5].unsqueeze(1) if pg[:, 4].max() > _SKIP_THRESH else zero

            fused = self.fusion(s_out, a_out, m_out, g_out, c_out)
            fused = self.meta_adapter(fused)
            return fused, n_ssm, n_mem, n_kv

        if self.use_gradient_checkpointing and self.training and not use_cache:
            # Non-reentrant checkpointing; activations of the 5 paths are
            # recomputed in backward instead of being stored.
            fused, new_ssm, new_mem, new_kv = ckpt.checkpoint(
                _compute_paths, x_norm, path_gates,
                past_ssm, past_mem, past_kv, use_cache,
                use_reentrant=False,
            )
        else:
            fused, new_ssm, new_mem, new_kv = _compute_paths(
                x_norm, path_gates, past_ssm, past_mem, past_kv, use_cache
            )

        fused = self.adversarial_shield.defend_features(fused)
        x = residual + self.dropout(fused)

        # ═══ Collaborative MoE FFN + residual ═══
        residual = x
        x = residual + self.dropout(self.moe(self.norm2(x)))
        # ═══ Multi-agent coordination, then task-adaptive control ═══
        x = self.multi_agent(x)
        x = self.task_controller(x)

        if use_cache:
            return x, (new_ssm, new_mem, new_kv)
        return x

    def get_aux_loss(self) -> torch.Tensor:
        """Auxiliary load-balancing loss from the MoE router."""
        return self.moe.aux_loss