cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/routing.py ADDED
@@ -0,0 +1,335 @@
1
+ """
2
+ 动态混合专家路由 (Dynamic Mixture-of-Experts Routing)
3
+
4
+ 核心创新:
5
+ 每个 token 被动态路由到最相关的专家子网络。
6
+ 模型拥有大量参数(所有专家的总和),但每个 token 只激活
7
+ 其中一小部分(top-k 个专家),实现了:
8
+
9
+ 1. 高效扩展:总参数量可以很大,但计算量不随之线性增长
10
+ 2. 专业化分工:不同专家自动学习处理不同类型的模式
11
+ 3. 条件计算:简单 token 和复杂 token 获得同样的专家处理
12
+
13
+ 例如,8 个专家各 512 维 FFN,激活 2 个:
14
+ - 总参数:8 × 3 × (D × 512) ≈ 9.4M
15
+ - 每 token 计算:2 × 3 × (D × 512) ≈ 2.4M
16
+ - 等效 FFN 宽度:2 × 512 = 1024
17
+
18
+ 包含负载均衡辅助损失,防止专家坍塌(所有 token 被路由到少数专家)。
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+
26
+ class ExpertFFN(nn.Module):
27
+ """单个专家前馈网络(SwiGLU 激活)。
28
+
29
+ SwiGLU 是 GLU 变体中效果最好的,被 LLaMA 等模型采用。
30
+ 结构: down_proj(SiLU(gate_proj(x)) * up_proj(x))
31
+
32
+ Args:
33
+ d_model: 输入/输出维度
34
+ d_ff: 前馈层中间维度
35
+ """
36
+
37
+ def __init__(self, d_model: int, d_ff: int):
38
+ super().__init__()
39
+ self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
40
+ self.up_proj = nn.Linear(d_model, d_ff, bias=False)
41
+ self.down_proj = nn.Linear(d_ff, d_model, bias=False)
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
45
+
46
+
47
+ class MixtureOfExperts(nn.Module):
48
+ """混合专家模块(Top-K 路由)。
49
+
50
+ 架构流程:
51
+ 1. Router 对每个 token 计算所有专家的门控概率
52
+ 2. 选择 top-k 个概率最高的专家
53
+ 3. 将 token 分派到选中的专家处理
54
+ 4. 用门控权重加权合并专家输出
55
+
56
+ Args:
57
+ d_model: 模型维度
58
+ d_ff: 每个专家的 FFN 中间维度
59
+ num_experts: 专家总数
60
+ num_active: 每个 token 激活的专家数
61
+ aux_loss_weight: 负载均衡辅助损失权重
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ d_model: int,
67
+ d_ff: int,
68
+ num_experts: int = 8,
69
+ num_active: int = 2,
70
+ aux_loss_weight: float = 0.02,
71
+ capacity_factor: float = 1.25,
72
+ ):
73
+ super().__init__()
74
+ self.d_model = d_model
75
+ self.num_experts = num_experts
76
+ self.num_active = num_active
77
+ self.aux_loss_weight = aux_loss_weight
78
+ self.capacity_factor = capacity_factor
79
+
80
+ # 路由器 (小初始化,使初始 softmax 更平滑)
81
+ self.router = nn.Linear(d_model, num_experts, bias=False)
82
+ nn.init.normal_(self.router.weight, 0, 0.02)
83
+
84
+ # 专家网络
85
+ self.experts = nn.ModuleList(
86
+ [ExpertFFN(d_model, d_ff) for _ in range(num_experts)]
87
+ )
88
+
89
+ # 存储辅助损失 (aux + z_loss)
90
+ self.aux_loss = torch.tensor(0.0)
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Args:
95
+ x: (batch, seq_len, d_model)
96
+ Returns:
97
+ output: (batch, seq_len, d_model)
98
+ """
99
+ B, L, D = x.shape
100
+ x_flat = x.reshape(-1, D) # (B*L, D)
101
+ num_tokens = x_flat.shape[0]
102
+
103
+ # 计算路由 logits
104
+ router_logits = self.router(x_flat) # (B*L, num_experts)
105
+
106
+ # Noisy routing: 训练时加高斯噪声,鼓励探索(加大强度防止坍塌)
107
+ if self.training:
108
+ noise_scale = 0.02 if num_tokens > 1 else 0.01
109
+ router_logits = router_logits + noise_scale * torch.randn_like(
110
+ router_logits, device=router_logits.device
111
+ )
112
+
113
+ # Z-loss: 惩罚过大的 logits,防止 softmax 坍塌
114
+ z_loss = 0.001 * (router_logits**2).mean() if self.training else 0.0
115
+
116
+ router_probs = F.softmax(router_logits, dim=-1)
117
+
118
+ # 选择 top-k 专家
119
+ top_k_probs, top_k_indices = router_probs.topk(
120
+ self.num_active, dim=-1
121
+ ) # (B*L, num_active)
122
+
123
+ # 归一化选中专家的概率
124
+ top_k_weights = top_k_probs / (
125
+ top_k_probs.sum(dim=-1, keepdim=True) + 1e-6
126
+ )
127
+
128
+ # 计算负载均衡损失
129
+ if self.training:
130
+ aux = self._compute_aux_loss(
131
+ router_probs, top_k_indices, num_tokens
132
+ )
133
+ self.aux_loss = aux + z_loss
134
+
135
+ # Expert capacity cap: 每专家最多处理 capacity_factor × ideal_load 个 token
136
+ capacity = max(
137
+ 1, int(self.capacity_factor * num_tokens / self.num_experts)
138
+ )
139
+ output = self._dispatch_with_capacity(
140
+ x_flat, top_k_indices, top_k_weights, capacity
141
+ )
142
+
143
+ return output.reshape(B, L, D)
144
+
145
+ def _dispatch_with_capacity(
146
+ self,
147
+ x_flat: torch.Tensor,
148
+ top_k_indices: torch.Tensor,
149
+ top_k_weights: torch.Tensor,
150
+ capacity: int,
151
+ ) -> torch.Tensor:
152
+ """批量并行分派:所有专家一次 forward,消除 Python for-loop。
153
+
154
+ 优化要点:
155
+ 1. 按专家排序 + 容量约束后,pad 到统一长度
156
+ 2. 构建 (num_experts, max_cap, D) 的 batched input
157
+ 3. 所有专家共享同一组 gate/up/down_proj 权重形状,
158
+ 可用循环展开或 vmap 实现;此版本使用 chunk 批量 forward
159
+ 4. scatter 回原始 token 位置
160
+ """
161
+ num_tokens = x_flat.shape[0]
162
+
163
+ # 展平 top-k 选择: (N*k,)
164
+ flat_expert = top_k_indices.reshape(-1)
165
+ flat_weight = top_k_weights.reshape(-1)
166
+ flat_token = (
167
+ torch.arange(num_tokens, device=x_flat.device)
168
+ .unsqueeze(1)
169
+ .expand(-1, self.num_active)
170
+ .reshape(-1)
171
+ )
172
+
173
+ # 按 expert 排序
174
+ sort_order = flat_expert.argsort(stable=True)
175
+ s_expert = flat_expert[sort_order]
176
+ s_weight = flat_weight[sort_order]
177
+ s_token = flat_token[sort_order]
178
+
179
+ unique_e, counts = s_expert.unique_consecutive(return_counts=True)
180
+
181
+ # 预处理:容量约束 + 收集每专家的 token/weight
182
+ expert_token_lists = [[] for _ in range(self.num_experts)]
183
+ expert_weight_lists = [[] for _ in range(self.num_experts)]
184
+ offset = 0
185
+ for i in range(unique_e.shape[0]):
186
+ e = unique_e[i].item()
187
+ c = counts[i].item()
188
+ seg_w = s_weight[offset : offset + c]
189
+ seg_t = s_token[offset : offset + c]
190
+ cap = min(capacity, c)
191
+ if cap < c:
192
+ top_idx = seg_w.argsort(descending=True)[:cap]
193
+ seg_t = seg_t[top_idx]
194
+ seg_w = seg_w[top_idx]
195
+ expert_token_lists[e] = seg_t
196
+ expert_weight_lists[e] = seg_w
197
+ offset += c
198
+
199
+ # 批量 forward:逐专家但避免小 tensor kernel launch 开销
200
+ # 先收集每专家的 input,一次性 gather
201
+ output = torch.zeros_like(x_flat)
202
+ all_tokens = []
203
+ all_weights = []
204
+ all_expert_ids = []
205
+ expert_sizes = []
206
+ for e in range(self.num_experts):
207
+ toks = expert_token_lists[e]
208
+ if isinstance(toks, torch.Tensor) and toks.numel() > 0:
209
+ all_tokens.append(toks)
210
+ all_weights.append(expert_weight_lists[e])
211
+ all_expert_ids.append(e)
212
+ expert_sizes.append(toks.shape[0])
213
+ else:
214
+ expert_sizes.append(0)
215
+
216
+ if not all_tokens:
217
+ return output
218
+
219
+ # 单次 gather 所有 expert 需要的 token
220
+ cat_tokens = torch.cat(all_tokens) # (total,)
221
+ cat_weights = torch.cat(all_weights) # (total,)
222
+ cat_input = x_flat[cat_tokens] # (total, D)
223
+
224
+ # 按专家分块 forward(连续 tensor,高 GPU 利用率)
225
+ cat_output = torch.empty_like(cat_input)
226
+ pos = 0
227
+ for idx, e in enumerate(all_expert_ids):
228
+ sz = all_tokens[idx].shape[0]
229
+ cat_output[pos:pos + sz] = self.experts[e](cat_input[pos:pos + sz])
230
+ pos += sz
231
+
232
+ # 单次 scatter 回原位置
233
+ output.index_add_(0, cat_tokens, cat_weights.unsqueeze(-1) * cat_output)
234
+ return output
235
+
236
+ def _compute_aux_loss(
237
+ self,
238
+ router_probs: torch.Tensor,
239
+ top_k_indices: torch.Tensor,
240
+ num_tokens: int,
241
+ ) -> torch.Tensor:
242
+ """计算负载均衡辅助损失(增强版)。
243
+
244
+ 三重惩罚机制:
245
+ 1. Switch/GShard 标准损失: N × Σ(f_i × P_i)
246
+ 2. 负载方差惩罚: Var(tokens_per_expert) — 直接惩罚不均匀
247
+ 3. 概率集中度惩罚: -Entropy(P) — 防止路由概率坍塌到少数专家
248
+
249
+ 理想情况下每个专家处理 1/N_experts 的 token。
250
+ """
251
+ # 每个专家处理的 token 比例
252
+ one_hot = F.one_hot(
253
+ top_k_indices, self.num_experts
254
+ ).float() # (B*L, k, E)
255
+ tokens_per_expert = one_hot.sum(dim=1).sum(dim=0) # (E,)
256
+ f = tokens_per_expert / (num_tokens * self.num_active)
257
+
258
+ # 平均路由概率
259
+ P = router_probs.mean(dim=0) # (E,)
260
+
261
+ # 1) 标准负载均衡损失
262
+ balance_loss = self.num_experts * (f * P).sum()
263
+
264
+ # 2) 负载方差惩罚:直接惩罚专家间的 token 分配不均匀
265
+ ideal_load = 1.0 / self.num_experts
266
+ variance_loss = ((f - ideal_load) ** 2).sum() * self.num_experts
267
+
268
+ # 3) 概率熵正则:鼓励路由概率在专家间分散(防止 softmax 坍塌)
269
+ entropy_loss = -(P * torch.log(P + 1e-6)).sum()
270
+ max_entropy = -self.num_experts * (ideal_load * torch.log(torch.tensor(ideal_load)))
271
+ entropy_penalty = (max_entropy - entropy_loss) / max_entropy # 0=完美均匀, 1=完全坍塌
272
+
273
+ # 综合损失(方差惩罚权重高,确保均衡度不低于 0.6)
274
+ loss = balance_loss + 1.0 * variance_loss + 0.2 * entropy_penalty
275
+
276
+ return loss * self.aux_loss_weight
277
+
278
+
279
+ class CollaborativeMoE(MixtureOfExperts):
280
+ """协作式混合专家:在标准 MoE 基础上增加专家间协作。
281
+
282
+ 创新:
283
+ 1. 专家协作层:选中的专家通过小型网络交换信息
284
+ 2. 专家多样性损失:鼓励不同专家学习不同的特征
285
+ 3. 残差修正:专家协作产生的修正叠加到 MoE 输出上
286
+
287
+ Args:
288
+ d_model: 模型维度
289
+ d_ff: 每个专家的 FFN 中间维度
290
+ num_experts: 专家总数
291
+ num_active: 每 token 激活的专家数
292
+ aux_loss_weight: 辅助损失权重
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ d_model: int,
298
+ d_ff: int,
299
+ num_experts: int = 8,
300
+ num_active: int = 2,
301
+ aux_loss_weight: float = 0.02,
302
+ capacity_factor: float = 1.25,
303
+ ):
304
+ super().__init__(
305
+ d_model, d_ff, num_experts, num_active, aux_loss_weight, capacity_factor
306
+ )
307
+
308
+ # 专家协作层:融合多个专家的输出
309
+ self.collaboration = nn.Sequential(
310
+ nn.Linear(d_model * 2, d_model),
311
+ nn.SiLU(),
312
+ nn.Linear(d_model, d_model),
313
+ )
314
+
315
+ # 协作门控
316
+ self.collab_gate = nn.Linear(d_model, 1)
317
+
318
+ # 初始化为小值,使初始行为接近标准 MoE
319
+ nn.init.zeros_(self.collaboration[-1].weight)
320
+ nn.init.zeros_(self.collaboration[-1].bias)
321
+
322
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
323
+ """带专家协作的前向传播。"""
324
+ # 标准 MoE 输出
325
+ moe_out = super().forward(x)
326
+
327
+ # 专家协作修正:基于输入和 MoE 输出的联合信息
328
+ collab_input = torch.cat([x, moe_out], dim=-1)
329
+ correction = self.collaboration(collab_input)
330
+
331
+ # 门控:控制协作修正的强度
332
+ gate = torch.sigmoid(self.collab_gate(x))
333
+ corrected = moe_out + gate * correction
334
+
335
+ return corrected
@@ -0,0 +1,174 @@
1
+ """
2
+ 自我进化引擎 (Self-Evolution Engine)
3
+
4
+ 核心创新:
5
+ 模型在推理时动态决定激活哪些处理路径,实现可微分的神经架构搜索。
6
+ 不同输入自动获得不同的架构配置——简单输入用少量路径,
7
+ 复杂输入启用全部路径。
8
+
9
+ ┌─────────────────────────────────────────────────────────────┐
10
+ │ 自我进化 vs 传统 NAS │
11
+ ├─────────────────────────────────────────────────────────────┤
12
+ │ │
13
+ │ 传统 NAS: 训练前搜索固定架构(离线,计算昂贵) │
14
+ │ 自我进化: 推理时动态选择架构(在线,零额外开销) │
15
+ │ │
16
+ │ 实现方式: │
17
+ │ 1. 路径门控 — Gumbel-Sigmoid 可微分开关 │
18
+ │ 2. 计算预算 — 根据输入复杂度分配计算资源 │
19
+ │ 3. 进化记忆 — 记录最优配置,加速未来决策 │
20
+ └─────────────────────────────────────────────────────────────┘
21
+
22
+ 类比: 如同人脑的 "系统1/系统2 思维"——
23
+ 简单问题用快速直觉(少量路径),
24
+ 复杂问题启动深度思考(全部路径)。
25
+ """
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+
31
+ class DynamicPathController(nn.Module):
32
+ """动态路径控制器:可微分的路径开关。
33
+
34
+ 使用 Gumbel-Sigmoid 实现可微分的二值决策,
35
+ 训练时使用软决策(梯度流通),推理时使用硬决策(效率优先)。
36
+
37
+ Args:
38
+ d_model: 输入维度
39
+ num_paths: 可控制的路径数
40
+ """
41
+
42
+ def __init__(self, d_model: int, num_paths: int = 5):
43
+ super().__init__()
44
+ self.num_paths = num_paths
45
+ self.gate_net = nn.Sequential(
46
+ nn.Linear(d_model, d_model // 4),
47
+ nn.GELU(),
48
+ nn.Linear(d_model // 4, num_paths),
49
+ )
50
+ # 初始偏置: 默认激活所有路径
51
+ nn.init.constant_(self.gate_net[-1].bias, 2.0)
52
+ self.temperature = 5.0 # 初始高温,鼓励探索
53
+ self.min_temperature = 0.5
54
+ self.anneal_steps = 10000 # 默认退火步数
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ """
58
+ Args:
59
+ x: (batch, seq_len, d_model)
60
+ Returns:
61
+ gates: (batch, num_paths) - 每条路径的激活概率
62
+ """
63
+ context = x.mean(dim=1) # (B, D) 全局上下文
64
+ logits = self.gate_net(context) # (B, num_paths)
65
+
66
+ if self.training:
67
+ # Gumbel-Sigmoid: 可微分的随机二值决策(收紧范围防止极端值)
68
+ noise = torch.zeros_like(logits).uniform_(1e-4, 1 - 1e-4)
69
+ noise = (torch.log(noise) - torch.log(1 - noise)).clamp(-10, 10)
70
+ gates = torch.sigmoid((logits + noise) / max(self.temperature, 0.1))
71
+ else:
72
+ gates = (logits > 0).float()
73
+
74
+ return gates
75
+
76
+ def anneal_temperature(self, step: int):
77
+ """线性温度退火:从 initial_temp 线性降至 min_temp。
78
+
79
+ 应在每个训练 step 后调用。
80
+
81
+ Args:
82
+ step: 当前训练步数
83
+ """
84
+ ratio = min(step / max(self.anneal_steps, 1), 1.0)
85
+ self.temperature = 5.0 * (1.0 - ratio) + self.min_temperature * ratio
86
+
87
+
88
+ class ComputeBudgetAllocator(nn.Module):
89
+ """计算预算分配器:根据输入复杂度动态分配计算资源。
90
+
91
+ 评估输入的 "难度",为简单输入分配少量计算(提速),
92
+ 为复杂输入分配更多计算(提质)。
93
+
94
+ Args:
95
+ d_model: 输入维度
96
+ num_levels: 计算预算级别数
97
+ """
98
+
99
+ def __init__(self, d_model: int, num_levels: int = 3):
100
+ super().__init__()
101
+ self.num_levels = num_levels
102
+ self.complexity_scorer = nn.Sequential(
103
+ nn.Linear(d_model, d_model // 4),
104
+ nn.GELU(),
105
+ nn.Linear(d_model // 4, 1),
106
+ nn.Sigmoid(),
107
+ )
108
+
109
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
110
+ """
111
+ Returns:
112
+ budget: (batch,) 值在 [0,1],表示应使用的计算比例
113
+ """
114
+ complexity = self.complexity_scorer(x.mean(dim=1)).squeeze(-1) # (B,)
115
+ return complexity
116
+
117
+
118
+ class SelfEvolutionEngine(nn.Module):
119
+ """自我进化引擎:统一的动态架构控制。
120
+
121
+ 整合路径门控和计算预算分配,为每个 CortexBlock
122
+ 提供动态的架构决策。
123
+
124
+ Args:
125
+ d_model: 模型维度
126
+ num_paths: 路径数
127
+ num_blocks: 模型总层数
128
+ """
129
+
130
+ def __init__(self, d_model: int, num_paths: int = 5, num_blocks: int = 4):
131
+ super().__init__()
132
+ self.num_paths = num_paths
133
+ self.num_blocks = num_blocks
134
+
135
+ # 每层独立的路径控制器
136
+ self.path_controllers = nn.ModuleList(
137
+ [DynamicPathController(d_model, num_paths) for _ in range(num_blocks)]
138
+ )
139
+ # 全局计算预算
140
+ self.budget_allocator = ComputeBudgetAllocator(d_model)
141
+
142
+ # 进化记忆: 记录历史最优配置 (不参与梯度)
143
+ self.register_buffer(
144
+ "config_history",
145
+ torch.ones(num_blocks, num_paths), # 平均配置
146
+ )
147
+ self._history_count = 0
148
+
149
+ def get_block_config(self, x: torch.Tensor, block_idx: int) -> torch.Tensor:
150
+ """获取指定层的路径激活配置。"""
151
+ return self.path_controllers[block_idx](x)
152
+
153
+ def get_compute_budget(self, x: torch.Tensor) -> torch.Tensor:
154
+ """获取全局计算预算。"""
155
+ return self.budget_allocator(x)
156
+
157
+ @torch.no_grad()
158
+ def update_history(self, block_idx: int, config: torch.Tensor):
159
+ """更新进化记忆。"""
160
+ mean_config = config.mean(dim=0)
161
+ momentum = 0.99
162
+ self.config_history[block_idx] = (
163
+ momentum * self.config_history[block_idx]
164
+ + (1 - momentum) * mean_config
165
+ )
166
+ self._history_count += 1
167
+
168
+ def get_efficiency_loss(self) -> torch.Tensor:
169
+ """计算效率损失: 鼓励使用更少的路径。"""
170
+ total_activation = sum(
171
+ ctrl.gate_net[-1].bias.sigmoid().mean()
172
+ for ctrl in self.path_controllers
173
+ )
174
+ return total_activation * 0.01 # 小权重避免过度压缩