cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
cortexnet/model.py
ADDED
|
@@ -0,0 +1,1360 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet: 超越 Transformer 的新一代神经网络架构
|
|
3
|
+
|
|
4
|
+
CortexNet 融合了五项核心创新:
|
|
5
|
+
1. 多尺度状态空间模块 (MSSM) — O(n) 高效序列建模
|
|
6
|
+
2. 选择性稀疏注意力 — 聚焦重要 token,大幅降低计算量
|
|
7
|
+
3. 突触可塑性记忆 — 受生物启发的快速上下文适应
|
|
8
|
+
4. 混合专家路由 (MoE) — 高效参数扩展
|
|
9
|
+
5. 自适应融合门控 — 动态平衡各处理路径
|
|
10
|
+
|
|
11
|
+
完整架构流程:
|
|
12
|
+
Token IDs → Embedding → [CortexBlock × N] → RMSNorm → LM Head → Logits
|
|
13
|
+
|
|
14
|
+
其中每个 CortexBlock:
|
|
15
|
+
Input ─┬─► SSM ────────┐
|
|
16
|
+
├─► Attention ──┤─► Fusion ─► + ─► MoE FFN ─► + ─► Output
|
|
17
|
+
└─► Memory ─────┘ ↑ ↑
|
|
18
|
+
Residual Residual
|
|
19
|
+
|
|
20
|
+
复杂度对比:
|
|
21
|
+
┌─────────────┬──────────────┬──────────────┐
|
|
22
|
+
│ │ Transformer │ CortexNet │
|
|
23
|
+
├─────────────┼──────────────┼──────────────┤
|
|
24
|
+
│ 序列建模 │ O(n²·d) │ O(n·d) │
|
|
25
|
+
│ 注意力 │ O(n²·d) │ O(n·k·d) │
|
|
26
|
+
│ 上下文学习 │ 隐式 │ 显式记忆 │
|
|
27
|
+
│ 参数利用 │ 全部激活 │ 稀疏激活(MoE) │
|
|
28
|
+
│ 位置编码 │ 绝对/相对 │ RoPE │
|
|
29
|
+
└─────────────┴──────────────┴──────────────┘
|
|
30
|
+
其中 k << n,可选 k = √n 实现亚二次复杂度
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import os
|
|
34
|
+
import json
|
|
35
|
+
import hashlib
|
|
36
|
+
import threading
|
|
37
|
+
import torch
|
|
38
|
+
import torch.nn as nn
|
|
39
|
+
import torch.nn.functional as F
|
|
40
|
+
from typing import Optional, Dict, List, Tuple, Any
|
|
41
|
+
|
|
42
|
+
try:
|
|
43
|
+
from .config import CortexNetConfig
|
|
44
|
+
from .blocks import CortexBlock, CortexBlockV2, CortexBlockV3, RMSNorm
|
|
45
|
+
from .cortex_block_lite import CortexBlockLite
|
|
46
|
+
from .compat import (
|
|
47
|
+
_NoOpEvolutionEngine, _CompatCortexBlockV3,
|
|
48
|
+
)
|
|
49
|
+
except ImportError:
|
|
50
|
+
# 兼容脚本式导入: `from model import CortexNetV3`
|
|
51
|
+
from cortexnet.config import CortexNetConfig
|
|
52
|
+
from cortexnet.blocks import CortexBlock, CortexBlockV2, CortexBlockV3, RMSNorm
|
|
53
|
+
from cortexnet.cortex_block_lite import CortexBlockLite
|
|
54
|
+
from cortexnet.compat import (
|
|
55
|
+
_NoOpEvolutionEngine, _CompatCortexBlockV3,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class CortexNetBase(nn.Module):
    """CortexNet language model (V1 block stack).

    Pipeline: token ids -> embedding -> dropout -> ``CortexBlock`` x N ->
    ``RMSNorm`` -> LM head -> logits.

    Args:
        config: ``CortexNetConfig`` instance. Attributes read by this class:
            ``vocab_size``, ``hidden_size``, ``dropout``, ``num_layers``,
            ``norm_eps``, ``max_seq_len`` (``generate``), ``num_experts`` /
            ``num_active_experts`` (``count_parameters``), and optionally
            ``tie_word_embeddings`` (default True) and ``skip_weight_init``
            (default False).
    """

    def __init__(self, config: CortexNetConfig):
        super().__init__()
        self.config = config

        # Token embedding.
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_dropout = nn.Dropout(config.dropout)

        # Stack of CortexNet blocks.
        self.blocks = nn.ModuleList(
            [CortexBlock(config, layer_idx=i) for i in range(config.num_layers)]
        )

        # Final normalization.
        self.final_norm = RMSNorm(config.hidden_size, config.norm_eps)

        # Language-model head.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Optional weight tying with the embedding (some models, e.g. Qwen3, do not tie by default).
        if getattr(config, "tie_word_embeddings", True):
            self.lm_head.weight = self.embed.weight

        self._maybe_init_weights()

    def _silent_calibrate_compat(self):
        """Compatibility-mode calibration that keeps migrated weights on their baseline path.

        No-op unless ``self.compatibility_mode`` is truthy. For each block that has them, this
        zeroes ``ssm.alpha``, sets ``ssm.fast_skip = True`` and fills ``fusion.attn_bias``
        with 10.0.
        """
        if not getattr(self, "compatibility_mode", False):
            return
        with torch.no_grad():
            for block in self.blocks:
                if hasattr(block, "ssm") and hasattr(block.ssm, "alpha"):
                    block.ssm.alpha.zero_()
                    if hasattr(block.ssm, "fast_skip"):
                        block.ssm.fast_skip = True
                if hasattr(block, "fusion") and hasattr(block.fusion, "attn_bias"):
                    block.fusion.attn_bias.fill_(10.0)

    def _init_weights(self, module: nn.Module):
        """Initialize Linear/Embedding weights from N(0, 0.02^2); Linear biases are zeroed."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _maybe_init_weights(self) -> None:
        """Apply ``_init_weights`` to every submodule unless ``config.skip_weight_init`` is truthy."""
        if bool(getattr(self.config, "skip_weight_init", False)):
            return
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the model.

        Args:
            input_ids: (batch, seq_len) token ids.
            labels: optional (batch, seq_len) target ids; positions equal to -100 are ignored.
                Targets are shifted internally (position t predicts token t + 1).

        Returns:
            Dict with ``"logits"`` (batch, seq_len, vocab). When ``labels`` is given it also
            contains ``"loss"`` (including the summed block auxiliary loss in training mode)
            and ``"aux_loss"``.
        """
        # Unpacking also rejects inputs that are not 2-D.
        _batch, _seq_len = input_ids.shape

        x = self.embed_dropout(self.embed(input_ids))  # (B, L, D)

        aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        for block in self.blocks:
            x = block(x)
            if self.training:
                aux_loss = aux_loss + block.get_aux_loss()

        logits = self.lm_head(self.final_norm(x))  # (B, L, vocab_size)
        result = {"logits": logits}

        if labels is not None:
            # Shift so that position t is scored against token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            # MoE load-balancing auxiliary loss is only added while training.
            if self.training:
                loss = loss + aux_loss
            result["loss"] = loss
            result["aux_loss"] = aux_loss

        return result

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> torch.Tensor:
        """Autoregressive sampling without a cache (the context window is re-encoded each step).

        Switches the model to eval mode and leaves it there.

        Args:
            input_ids: (batch, seq_len) prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by ``max(temperature, 1e-8)``; values close to 0
                approach greedy decoding.
            top_k: keep only the ``top_k`` largest logits; disabled when <= 0.
            top_p: nucleus threshold; disabled when >= 1.0.

        Returns:
            (batch, seq_len + max_new_tokens) token ids.
        """
        self.eval()

        for _ in range(max_new_tokens):
            # Only the last max_seq_len tokens fit in the context window.
            idx_cond = input_ids[:, -self.config.max_seq_len:]

            # Sample in float32: dividing fp16/bf16 logits by a tiny temperature overflows.
            logits = self.forward(idx_cond)["logits"][:, -1, :].float()
            logits = logits / max(temperature, 1e-8)  # (B, vocab)
            # Sanitize before filtering so top-k / sort never see NaN. Logits are deliberately
            # NOT clamped to a fixed range: that made every token above the bound tie at low
            # temperature, and softmax is already numerically stable.
            logits = torch.nan_to_num(logits, nan=0.0)

            if top_k > 0:
                top_k_val = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, top_k_val)[0][..., -1, None]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept; the best token
                # is always kept.
                sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                sorted_remove[..., 0] = False
                remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
                logits = logits.masked_fill(remove, float("-inf"))

            probs = F.softmax(logits, dim=-1)
            # Filtered tokens keep probability exactly 0 (no floor that would re-enable them).
            # Only a row with no usable mass falls back to uniform so multinomial cannot fail.
            probs = torch.nan_to_num(probs, nan=0.0)
            empty_rows = probs.sum(dim=-1, keepdim=True) <= 0
            probs = torch.where(empty_rows, torch.ones_like(probs), probs)

            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters per component.

        Returns:
            Dict with totals summed over all blocks for ``"ssm"``, ``"attention"``,
            ``"memory"``, ``"fusion"``, ``"moe"`` and ``"norm"`` (norm1 + norm2), plus
            ``"embedding"``, ``"final_norm"``, ``"total"`` (``self.parameters()``, so a tied
            embedding/LM-head weight is counted once) and ``"active_per_token"``.
            ``"active_per_token"`` subtracts the inactive experts, estimated from the first
            block's first expert times ``num_experts - num_active_experts`` times ``num_layers``.
        """

        def numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {
            "embedding": numel(self.embed),
            "ssm": 0,
            "attention": 0,
            "memory": 0,
            "fusion": 0,
            "moe": 0,
            "norm": 0,
        }

        for block in self.blocks:
            for key in ("ssm", "attention", "memory", "fusion", "moe"):
                if hasattr(block, key):
                    counts[key] += numel(getattr(block, key))
            if hasattr(block, "norm1") and hasattr(block, "norm2"):
                counts["norm"] += numel(block.norm1) + numel(block.norm2)

        counts["final_norm"] = numel(self.final_norm)
        counts["total"] = numel(self)

        # Active parameters per token, accounting for MoE sparsity.
        counts["active_per_token"] = counts["total"]
        first_moe = getattr(self.blocks[0], "moe", None) if len(self.blocks) > 0 else None
        experts = getattr(first_moe, "experts", None) if first_moe is not None else None
        if experts is not None and len(experts) > 0:
            expert_params_per_block = numel(experts[0])
            inactive_experts = self.config.num_experts - self.config.num_active_experts
            inactive_params = (
                expert_params_per_block * inactive_experts * self.config.num_layers
            )
            counts["active_per_token"] = counts["total"] - inactive_params

        return counts
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
# ═══════════════════════════════════════════════════════════════
|
|
313
|
+
# CortexNet V2 — 全面进化版
|
|
314
|
+
# ═══════════════════════════════════════════════════════════════
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class CortexNetV2(CortexNetBase):
    """CortexNet V2: the CortexNet pipeline built from ``CortexBlockV2`` layers.

    Upgrades listed by the design notes:

        1. Hierarchical memory        - working / episodic / semantic tiers
        2. Graph reasoning            - multi-step message passing
        3. Meta-learning adapter      - FiLM-style fast task adaptation
        4. Task-adaptive controller   - implicit multi-task switching
        5. Collaborative MoE          - knowledge sharing between experts
        6. Chunked parallel-scan SSM  - faster long sequences
        7. Continual-learning support - EWC + memory replay
        8. Multimodal encoding        - unified text / image / audio
        9. Interpretability           - live "thought stream" monitoring

    Architecture::

        [MultiModal Encoder] -> [CortexBlockV2 x N] -> Output

        CortexBlockV2:
            SSM + Attention + HierarchicalMemory + GraphReasoning
            -> AdaptiveFusion (4 paths) -> MetaAdapter -> CollabMoE
            -> TaskController

    Construction mirrors ``CortexNetBase`` except that the layers are ``CortexBlockV2``;
    ``forward``, ``generate`` and ``count_parameters`` are inherited.

    Args:
        config: ``CortexNetConfig`` instance.
    """

    def __init__(self, config: CortexNetConfig):
        # Deliberately skip CortexNetBase.__init__ (it would build V1 CortexBlock layers) and
        # initialize nn.Module directly.
        nn.Module.__init__(self)
        self.config = config

        # Token embedding.
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_dropout = nn.Dropout(config.dropout)

        # V2 blocks.
        self.blocks = nn.ModuleList(
            [CortexBlockV2(config, layer_idx=i) for i in range(config.num_layers)]
        )

        # Final normalization.
        self.final_norm = RMSNorm(config.hidden_size, config.norm_eps)

        # Language-model head, optionally tied to the embedding.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if getattr(config, "tie_word_embeddings", True):
            self.lm_head.weight = self.embed.weight

        self._maybe_init_weights()

    def get_evolution_info(self) -> Dict[str, Any]:
        """Return ``count_parameters()`` extended with static V2 feature flags.

        The added entries are constants describing the design; they are not derived from the
        instantiated modules.
        """
        info = self.count_parameters()
        info["version"] = "V2"
        info["num_paths"] = 4  # SSM + Attention + Memory + Graph
        info["memory_tiers"] = 3  # Working + Episodic + Semantic
        info["meta_learning"] = True
        info["graph_reasoning"] = True
        info["collaborative_moe"] = True
        info["continual_learning_ready"] = True
        info["multimodal_ready"] = True
        return info
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
# ═══════════════════════════════════════════════════════════════
|
|
394
|
+
# CortexNet V3 — 终极进化: 世界顶级架构
|
|
395
|
+
# ═══════════════════════════════════════════════════════════════
|
|
396
|
+
|
|
397
|
+
try:
|
|
398
|
+
from .self_evolution import SelfEvolutionEngine
|
|
399
|
+
except ImportError:
|
|
400
|
+
from cortexnet.self_evolution import SelfEvolutionEngine
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class CortexNetV3(CortexNetBase):
|
|
404
|
+
"""CortexNet V3:终极进化的世界顶级神经网络架构。
|
|
405
|
+
|
|
406
|
+
在 V2 基础上集成 4 项突破性能力:
|
|
407
|
+
|
|
408
|
+
┌─────────────────────────────────────────────────────────────┐
|
|
409
|
+
│ CortexNet V3 终极进化 │
|
|
410
|
+
├─────────────────────────────────────────────────────────────┤
|
|
411
|
+
│ V1 基础 (5 项): SSM + Attn + Memory + MoE + Fusion │
|
|
412
|
+
│ V2 进化 (9 项): +3层记忆 +图推理 +元学习 +协作MoE ... │
|
|
413
|
+
│ V3 突破 (4 项): │
|
|
414
|
+
│ 10. ✦ 因果推理 — 理解因果,支持反事实思考 │
|
|
415
|
+
│ 11. ✦ 自我进化 — 动态架构,输入驱动的 NAS │
|
|
416
|
+
│ 12. ✦ 多智能体 — 专家团队协作决策 │
|
|
417
|
+
│ 13. ✦ 对抗防御 — 三层防护,安全可靠 │
|
|
418
|
+
│ │
|
|
419
|
+
│ 处理路径: 5 条 (SSM + Attn + Memory + Graph + Causal) │
|
|
420
|
+
│ 架构: 动态自适应 (路径可按需激活/禁用) │
|
|
421
|
+
└─────────────────────────────────────────────────────────────┘
|
|
422
|
+
"""
|
|
423
|
+
|
|
424
|
+
def __init__(self, config: CortexNetConfig):
    """Build CortexNet V3.

    Block selection (checked in this order; the first match wins):
      * ``config.lite`` truthy -> ``CortexBlockLite`` layers + ``_NoOpEvolutionEngine``.
      * ``config.compatibility_mode`` truthy -> ``_CompatCortexBlockV3`` layers (constructed
        without ``layer_idx``) + ``_NoOpEvolutionEngine``.
      * otherwise -> full ``CortexBlockV3`` layers + ``SelfEvolutionEngine`` over 5 paths.

    Submodules are constructed and registered in a fixed order (embed, embed_dropout, blocks,
    evolution_engine, final_norm, lm_head). That order fixes ``named_parameters()`` /
    ``state_dict()`` ordering and the order in which ``self.apply`` (used by
    ``_maybe_init_weights``) visits modules, so keep it stable.

    Lazy device-placement state starts disabled and "ready"; ``_setup_lazy_runtime`` arms it.

    Args:
        config: model configuration.
    """
    nn.Module.__init__(self)
    self.config = config

    # Embedding.
    self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
    self.embed_dropout = nn.Dropout(config.dropout)

    self.compatibility_mode = bool(getattr(config, "compatibility_mode", False))

    self.lite_mode = bool(getattr(config, 'lite', False))

    if self.lite_mode:
        # Lite mode: slimmed-down blocks, parameter count comparable to a Transformer.
        self.blocks = nn.ModuleList(
            [CortexBlockLite(config, layer_idx=i) for i in range(config.num_layers)]
        )
        self.evolution_engine = _NoOpEvolutionEngine()
    elif self.compatibility_mode:
        # Compatibility mode: lightweight blocks sized like the open-source source model.
        self.blocks = nn.ModuleList(
            [_CompatCortexBlockV3(config) for _ in range(config.num_layers)]
        )
        self.evolution_engine = _NoOpEvolutionEngine()
    else:
        # Full-featured V3 path.
        self.blocks = nn.ModuleList(
            [CortexBlockV3(config, layer_idx=i) for i in range(config.num_layers)]
        )
        self.evolution_engine = SelfEvolutionEngine(
            config.hidden_size,
            num_paths=5,
            num_blocks=config.num_layers,
        )

    # Final normalization + output head (optionally tied to the embedding).
    self.final_norm = RMSNorm(config.hidden_size, config.norm_eps)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    if getattr(config, "tie_word_embeddings", True):
        self.lm_head.weight = self.embed.weight

    self._maybe_init_weights()

    # Runtime state for lazy device placement (disabled until _setup_lazy_runtime is called).
    self._lazy_enabled: bool = False
    self._lazy_ready: bool = True
    self._lazy_target_device: Optional[str] = None
    self._lazy_target_dtype: Optional[torch.dtype] = None
    self._lazy_cpu_fallback: bool = True
    self._lazy_background_warmup: bool = True
    self._lazy_warmup_started: bool = False
    self._lazy_warmup_error: Optional[str] = None
    # Reentrant lock shared by forward() and the warm-up worker.
    # NOTE(review): a threading.RLock attribute typically prevents pickling / copy.deepcopy of
    # the whole module (e.g. torch.save(model)); confirm whether that matters to callers.
    self._lazy_lock = threading.RLock()
    self._lazy_thread: Optional[threading.Thread] = None
|
|
478
|
+
|
|
479
|
+
def _setup_lazy_runtime(
    self,
    *,
    target_device: str,
    target_dtype: Optional[torch.dtype],
    cpu_fallback: bool = True,
    background_warmup: bool = True,
) -> None:
    """Arm lazy device placement.

    Intended flow:
      1. ``from_pretrained`` hands back a model that is usable on CPU;
      2. the first inference can run on the CPU fallback path;
      3. after that inference a background thread moves the model to ``target_device``
         (e.g. mps / cuda / npu), optionally casting to ``target_dtype``.

    Calling this again resets all warm-up bookkeeping.
    """
    # Mode flags: lazy placement on, not yet on the target device.
    self._lazy_enabled = True
    self._lazy_ready = False

    # Where (and in which dtype) the model should end up.
    self._lazy_target_device = str(target_device)
    self._lazy_target_dtype = target_dtype

    # Policy switches, normalized to plain bools.
    self._lazy_cpu_fallback = bool(cpu_fallback)
    self._lazy_background_warmup = bool(background_warmup)

    # Fresh warm-up bookkeeping.
    self._lazy_warmup_started = False
    self._lazy_warmup_error = None
    self._lazy_thread = None
|
|
503
|
+
|
|
504
|
+
def start_background_warmup(self) -> bool:
    """Move the model to the lazy target device on a daemon thread.

    Only the first successful call has an effect.

    Returns:
        True if this call started the worker. False when lazy placement is disabled, the model
        is already ready, or a warm-up was already started.

    Errors raised inside the worker are stored as text in ``_lazy_warmup_error``. In that case
    ``_lazy_ready`` stays False and the model is moved back to CPU (best effort) so the CPU
    fallback path keeps working.
    """

    def _worker():
        # Holding the lock keeps forward() (same lock) out while parameters are being moved.
        with self._lazy_lock:
            try:
                if self._lazy_target_device is None:
                    self._lazy_warmup_error = "lazy target device is not set"
                    return
                # Single pass: move and cast together (the dtype applies to floating-point
                # tensors only, as with Module.to).
                if self._lazy_target_dtype is not None:
                    self.to(
                        device=self._lazy_target_device,
                        dtype=self._lazy_target_dtype,
                    )
                else:
                    self.to(self._lazy_target_device)
                self._lazy_ready = True
            except Exception as exc:
                self._lazy_warmup_error = str(exc)
                # A failed move can leave parameters split across devices; bring them back so
                # inputs routed to CPU still match the weights.
                try:
                    self.to("cpu")
                except Exception:
                    pass

    # Check-and-set under the lock so concurrent callers cannot spawn two workers.
    with self._lazy_lock:
        if not self._lazy_enabled or self._lazy_ready or self._lazy_warmup_started:
            return False
        self._lazy_warmup_started = True
        self._lazy_thread = threading.Thread(target=_worker, daemon=True)
        # The worker blocks on the lock until this block exits.
        self._lazy_thread.start()
    return True
|
|
526
|
+
|
|
527
|
+
def wait_until_ready(self, timeout: Optional[float] = None) -> bool:
    """Block until the warm-up thread (if one was started) ends or ``timeout`` seconds pass.

    Returns:
        ``bool(self._lazy_ready)`` after waiting. When no thread exists this returns immediately.
    """
    worker = self._lazy_thread
    if worker is None:
        return bool(self._lazy_ready)
    worker.join(timeout=timeout)
    return bool(self._lazy_ready)
|
|
533
|
+
|
|
534
|
+
def get_lazy_status(self) -> Dict[str, Any]:
    """Return a snapshot of the lazy device-placement state.

    ``target_dtype`` is reported without its ``torch.`` prefix (e.g. ``"float16"``), or as
    ``None`` when no dtype was requested.
    """
    dtype = self._lazy_target_dtype
    if dtype is None:
        dtype_label = None
    else:
        dtype_label = str(dtype).replace("torch.", "")
    return dict(
        enabled=self._lazy_enabled,
        ready=self._lazy_ready,
        target_device=self._lazy_target_device,
        target_dtype=dtype_label,
        background_warmup=self._lazy_background_warmup,
        warmup_started=self._lazy_warmup_started,
        warmup_error=self._lazy_warmup_error,
    )
|
|
548
|
+
|
|
549
|
+
def _prepare_input_for_lazy(
    self,
    input_ids: torch.Tensor,
    *,
    start_warmup_after_infer: bool,
) -> Tuple[torch.Tensor, bool]:
    """Route ``input_ids`` according to the lazy device-placement state.

    Returns ``(input_ids, should_start_warmup)``:
      * lazy placement off -> input unchanged, ``False``;
      * warm-up finished -> input moved to the target device when its device string differs,
        ``False``;
      * otherwise -> input moved to CPU (the fallback path), plus whether the caller should
        start the background warm-up after this inference.

    NOTE(review): ``_lazy_cpu_fallback`` is not consulted here; before warm-up completes,
    inputs go to CPU regardless of that setting.
    """
    if not self._lazy_enabled:
        return input_ids, False

    warmed_up = self._lazy_ready and self._lazy_target_device is not None
    if warmed_up:
        target = self._lazy_target_device
        if str(input_ids.device) != target:
            input_ids = input_ids.to(target)
        return input_ids, False

    # Not warmed up yet (or warm-up failed): run on CPU.
    on_cpu = str(input_ids.device) == "cpu"
    if not on_cpu:
        input_ids = input_ids.to("cpu")

    return input_ids, (
        bool(start_warmup_after_infer)
        and self._lazy_background_warmup
        and not self._lazy_warmup_started
    )
|
|
574
|
+
|
|
575
|
+
def _move_cache_to_device(
    self,
    cache: Optional[List[Tuple[Any, Any, Any]]],
    device: torch.device,
) -> Optional[List[Tuple[Any, Any, Any]]]:
    """Move every tensor in an incremental-decoding cache to ``device``.

    This avoids mixing CPU and accelerator tensors. Tuples and lists are rebuilt recursively as
    plain ``tuple`` / ``list`` with their elements moved. Any other leaf (``None``, numbers, ...)
    is returned as-is. A ``None`` cache stays ``None``.
    """
    if cache is None:
        return None

    def relocate(node: Any) -> Any:
        if torch.is_tensor(node):
            return node.to(device)
        if isinstance(node, (tuple, list)):
            moved = [relocate(child) for child in node]
            return tuple(moved) if isinstance(node, tuple) else moved
        return node

    return relocate(cache)
|
|
594
|
+
|
|
595
|
+
def forward_from_embeddings(
    self,
    x_emb: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    past_cache: Optional[List[Tuple[Any, Any, Any]]] = None,
    use_cache: bool = False,
    *,
    ssm_only: bool = False,
) -> Dict[str, torch.Tensor]:
    """Run the model from precomputed embeddings, skipping ``embed`` / ``embed_dropout``.

    Used for scenarios such as adversarial training. Unlike ``forward``, inputs are not rerouted
    by lazy device placement, so ``x_emb`` should already be on the parameters' device.

    Args:
        x_emb: embedded input, passed straight to the block stack.
        labels, past_cache, use_cache: as in ``forward``.
        ssm_only: forwarded to ``_forward_impl``; only honoured by lite-mode blocks
            (defaults to False, matching the previous behaviour).
    """
    # Take the same lock as forward() so a background warm-up cannot move parameters mid-pass.
    with self._lazy_lock:
        return self._forward_impl(x_emb, labels, past_cache, use_cache, ssm_only=ssm_only)
|
|
604
|
+
|
|
605
|
+
def forward(
    self,
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    past_cache: Optional[List[Tuple[Any, Any, Any]]] = None,
    use_cache: bool = False,
    *,
    start_warmup_after_infer: bool = True,
    ssm_only: bool = False,
) -> Dict[str, torch.Tensor]:
    """Forward pass with optional per-layer caches for incremental decoding.

    Under lazy placement the input (and cache) is routed to CPU or the target device first. If
    this inference should trigger the background warm-up, it is started after the pass.

    Args:
        start_warmup_after_infer: allow this call to start the background warm-up.
        ssm_only: in lite mode, decode through the SSM path only.
    """
    with self._lazy_lock:
        input_ids, kick_off_warmup = self._prepare_input_for_lazy(
            input_ids,
            start_warmup_after_infer=start_warmup_after_infer,
        )
        past_cache = self._move_cache_to_device(past_cache, input_ids.device)
        hidden = self.embed(input_ids)
        hidden = self.embed_dropout(hidden)
        output = self._forward_impl(hidden, labels, past_cache, use_cache, ssm_only=ssm_only)

    if kick_off_warmup:
        self.start_background_warmup()
    return output
|
|
632
|
+
|
|
633
|
+
def _forward_impl(
    self,
    x: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    past_cache: Optional[List[Tuple[Any, Any, Any]]] = None,
    use_cache: bool = False,
    ssm_only: bool = False,
) -> Dict[str, torch.Tensor]:
    """Shared forward implementation operating on already-embedded input.

    Args:
        x: embedded input fed to the block stack.
        labels: optional targets; shifted internally, -100 is ignored.
        past_cache: per-layer caches from a previous call (indexed by layer).
        use_cache: request per-layer caches back.
        ssm_only: in lite mode, run blocks SSM-only (skips attention).

    Returns:
        Dict with ``"logits"`` and ``"compute_budget"``. It contains ``"past_cache"`` when
        ``use_cache`` is set and at least one layer returned a cache. With ``labels`` it also
        contains ``"loss"`` (plus auxiliary losses in training mode) and ``"aux_loss"``.
    """
    compute_budget = self.evolution_engine.get_compute_budget(x)

    aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    new_cache: List[Tuple[Any, Any, Any]] = []

    # Only lite-mode blocks (CortexBlockLite) accept ssm_only; the kwargs are loop-invariant.
    block_kwargs: Dict[str, Any] = {"ssm_only": True} if (self.lite_mode and ssm_only) else {}

    for i, block in enumerate(self.blocks):
        past_layer = past_cache[i] if past_cache and i < len(past_cache) else None
        if use_cache:
            x, layer_cache = block(x, past_cache=past_layer, use_cache=True, **block_kwargs)
            new_cache.append(layer_cache)
        else:
            x = block(x, past_cache=past_layer, use_cache=False, **block_kwargs)
        if self.training:
            aux_loss = aux_loss + block.get_aux_loss()

    if self.training:
        aux_loss = aux_loss + self.evolution_engine.get_efficiency_loss()

    logits = self.lm_head(self.final_norm(x))

    result = {"logits": logits, "compute_budget": compute_budget}
    if use_cache and new_cache:
        result["past_cache"] = new_cache

    if labels is not None:
        shift_logits = logits[:, :-1, :].contiguous()
        # forward() may reroute input_ids to CPU / the lazy target device without touching
        # labels, so align the targets with the logits before computing the loss.
        shift_labels = labels[:, 1:].to(shift_logits.device).contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        if self.training:
            loss = loss + aux_loss
        result["loss"] = loss
        result["aux_loss"] = aux_loss

    return result
|
|
689
|
+
|
|
690
|
+
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.0,
) -> torch.Tensor:
    """Autoregressive sampling with per-layer KV/state caches.

    The first (prefill) step encodes the last ``config.max_seq_len`` tokens. Later steps feed
    only the newest token together with ``past_cache``. In lite mode with
    ``config.ssm_decode_after > 0``, cached steps whose index is >= that value run SSM-only.
    Lazy placement: the prompt is routed like ``forward`` routes inputs. If a warm-up should
    start, it is triggered once after generation rather than per step. Switches the model to
    eval mode and leaves it there.

    NOTE(review): after prefill the cache keeps growing; nothing here bounds it by
    ``max_seq_len``.

    Args:
        input_ids: (batch, seq_len) prompt token ids.
        max_new_tokens: number of tokens to append.
        temperature: logits are divided by ``max(temperature, 1e-8)``.
        top_k: keep the ``top_k`` largest logits; disabled when <= 0.
        top_p: nucleus threshold; disabled when >= 1.0.
        repetition_penalty: applied to tokens already in the sequence (positive logits are
            divided by it, others multiplied); disabled when == 1.0 or <= 0.

    Returns:
        (batch, seq_len + max_new_tokens) token ids.
    """
    self.eval()
    input_ids, _ = self._prepare_input_for_lazy(
        input_ids,
        start_warmup_after_infer=False,
    )
    should_start_warmup_after_generate = (
        self._lazy_enabled
        and not self._lazy_ready
        and self._lazy_background_warmup
        and not self._lazy_warmup_started
    )
    past_cache: Optional[List[Tuple[Any, Any, Any]]] = None
    generated = input_ids
    use_repetition_penalty = repetition_penalty != 1.0 and repetition_penalty > 0
    seen_mask: Optional[torch.Tensor] = None

    for step_i in range(max_new_tokens):
        if past_cache is None:
            idx_cond = (
                generated
                if generated.shape[1] <= self.config.max_seq_len
                else generated[:, -self.config.max_seq_len:]
            )
        else:
            idx_cond = generated[:, -1:]  # incremental: only the new token

        # Lite SSM-first decoding: switch to pure SSM after prefill.
        use_ssm_only = (
            self.lite_mode
            and past_cache is not None
            and getattr(self.config, 'ssm_decode_after', 0) > 0
            and step_i >= self.config.ssm_decode_after
        )

        output = self.forward(
            idx_cond,
            past_cache=past_cache,
            use_cache=True,
            start_warmup_after_infer=False,
            ssm_only=use_ssm_only,
        )
        past_cache = output.get("past_cache")

        # Sample in float32: dividing fp16/bf16 logits by a tiny temperature overflows.
        logits = output["logits"][:, -1, :].float() / max(temperature, 1e-8)

        # A warm-up started elsewhere can move the model mid-generation; keep our running
        # tensors on the same device as the logits so torch.cat / scatter do not mix devices.
        if generated.device != logits.device:
            generated = generated.to(logits.device)
            if seen_mask is not None:
                seen_mask = seen_mask.to(logits.device)

        # Sanitize before filtering (NaN -> 0). Logits are not clamped to a fixed range: that
        # made every token above the bound tie at low temperature.
        logits = torch.nan_to_num(logits, nan=0.0)

        # Vectorized repetition penalty (no Python loop over tokens).
        if use_repetition_penalty:
            if seen_mask is None or seen_mask.shape[1] != logits.shape[1]:
                seen_mask = torch.zeros(logits.shape, device=logits.device, dtype=torch.bool)
                seen_ids = generated.clamp_max(logits.shape[1] - 1)
                seen_mask.scatter_(1, seen_ids, True)
            positive = logits > 0
            penalized = torch.where(
                positive, logits / repetition_penalty, logits * repetition_penalty
            )
            logits = torch.where(seen_mask, penalized, logits)

        if top_k > 0:
            top_k_val = min(top_k, logits.size(-1))
            _, top_indices = torch.topk(logits, top_k_val)
            mask = torch.ones_like(logits, dtype=torch.bool)
            mask.scatter_(1, top_indices, False)
            logits.masked_fill_(mask, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Keep the token that crosses the threshold and always keep the best token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        # Filtered tokens keep probability exactly 0 (no floor that would re-enable them);
        # only a row with no usable mass falls back to uniform so multinomial cannot fail.
        probs = torch.nan_to_num(probs, nan=0.0)
        empty_rows = probs.sum(dim=-1, keepdim=True) <= 0
        probs = torch.where(empty_rows, torch.ones_like(probs), probs)

        next_token = torch.multinomial(probs, num_samples=1)
        generated = torch.cat([generated, next_token], dim=1)
        if seen_mask is not None:
            seen_mask.scatter_(1, next_token, True)

    if should_start_warmup_after_generate:
        self.start_background_warmup()
    return generated
|
|
794
|
+
|
|
795
|
+
def compile_model(self, **kwargs) -> "CortexNet":
    """Compile the heavy sub-modules of each block with ``torch.compile``.

    Call once before training/inference: the first forward pass pays the
    compilation cost, later passes benefit from automatic kernel fusion.

    For every entry of ``self.blocks`` the attributes ``ssm``, ``attention``,
    ``moe``, ``graph_reasoning`` and ``causal_reasoning`` are replaced by their
    compiled wrappers when they hold an ``nn.Module``. Missing attributes and
    attributes set to ``None`` are skipped. So are modules that already look
    compiled (they expose ``_orig_mod``), which means calling this twice does
    not nest wrappers.

    Compilation is best-effort: if ``torch.compile`` is unavailable or raises,
    compilation stops, a warning is logged, and modules compiled so far are kept.

    Args:
        **kwargs: Forwarded to ``torch.compile`` (e.g. ``mode="max-autotune"``).

    Returns:
        ``self``, for chaining.
    """
    import logging

    logger = logging.getLogger(__name__)
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        logger.warning("torch.compile is unavailable; compile_model() skipped.")
        return self

    submodule_names = ("ssm", "attention", "moe", "graph_reasoning", "causal_reasoning")
    for block in self.blocks:
        for attr in submodule_names:
            module = getattr(block, attr, None)
            # torch.compile(None) returns a decorator instead of a compiled module,
            # so disabled (None) components must be skipped. Modules already wrapped
            # by torch.compile expose ``_orig_mod``; re-wrapping them would nest.
            if not isinstance(module, torch.nn.Module) or hasattr(module, "_orig_mod"):
                continue
            try:
                setattr(block, attr, compile_fn(module, **kwargs))
            except Exception as exc:  # best-effort: keep the remaining modules eager
                logger.warning(
                    "torch.compile failed on %s; leaving remaining modules eager: %s",
                    attr,
                    exc,
                )
                return self
    return self
|
|
815
|
+
|
|
816
|
+
@torch.no_grad()
def speculative_generate(
    self,
    input_ids: torch.Tensor,
    max_new_tokens: int = 100,
    draft_steps: int = 4,
    temperature: float = 1.0,
    top_k: int = 50,
) -> torch.Tensor:
    """Speculative decoding: a shallow draft proposes tokens, the full model verifies them in one pass.

    Each round:
      1. The draft model (embedding + ``blocks[0]`` + final norm + lm_head)
         autoregressively proposes up to ``draft_steps`` tokens, conditioned on
         the whole sequence generated so far.
      2. The full model runs once over ``generated + drafts``. Its logits at
         position ``len(generated) - 1 + i`` give the next-token distribution
         for draft ``i``.
      3. Drafts are checked left to right. At each position a token is sampled
         from the full model. If it equals the draft, the draft is accepted.
         Otherwise the sampled token is emitted in its place and the round ends.
      4. If every draft is accepted, one bonus token is sampled from the full
         model's last position. A round therefore advances at most
         ``draft_steps + 1`` tokens.

    Every emitted token is a sample from the full model given all preceding
    tokens, so the draft only affects speed, not the output distribution.

    Args:
        input_ids: Prompt token ids, indexed as ``(batch, seq_len)``.
        max_new_tokens: Number of tokens to append (negative values act as 0).
        draft_steps: Draft tokens proposed per round (values < 1 act as 1).
        temperature: Sampling temperature (clamped to at least 1e-8).
        top_k: If > 0, sampling is restricted to the ``top_k`` highest logits.

    Returns:
        The prompt followed by ``max_new_tokens`` generated tokens.
    """
    self.eval()
    input_ids, _ = self._prepare_input_for_lazy(
        input_ids,
        start_warmup_after_infer=False,
    )
    should_start_warmup_after_generate = (
        self._lazy_enabled
        and not self._lazy_ready
        and self._lazy_background_warmup
        and not self._lazy_warmup_started
    )
    max_new_tokens = max(0, max_new_tokens)
    temp = max(temperature, 1e-8)

    def _sample(logits: torch.Tensor) -> torch.Tensor:
        """Sample one token per row from (batch, vocab) logits with temperature / top-k."""
        logits = logits / temp
        if top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        return torch.multinomial(F.softmax(logits, dim=-1), 1)

    prompt_len = input_ids.shape[1]
    generated = input_ids.clone()

    while generated.shape[1] - prompt_len < max_new_tokens:
        remaining = max_new_tokens - (generated.shape[1] - prompt_len)
        n_draft = max(1, min(draft_steps, remaining))

        # === Phase 1: draft with embedding + block 0 + lm_head, over the full context ===
        draft_ctx = generated
        draft_tokens = []
        for _ in range(n_draft):
            x_d = self.embed(draft_ctx)
            x_d = self.blocks[0](x_d)  # only layer 0
            x_d = self.final_norm(x_d)
            d_tok = _sample(self.lm_head(x_d)[:, -1, :])
            draft_tokens.append(d_tok)
            draft_ctx = torch.cat([draft_ctx, d_tok], dim=1)
        draft_seq = torch.cat(draft_tokens, dim=1)  # (B, n_draft)

        # === Phase 2: one full-model pass over context + drafts ===
        ctx_len = generated.shape[1]
        candidate = torch.cat([generated, draft_seq], dim=1)
        all_logits = self.forward(candidate, start_warmup_after_infer=False)["logits"]
        # Position ctx_len - 1 + i predicts draft i; the final position predicts the bonus token.
        verify_logits = all_logits[:, ctx_len - 1 : -1, :]

        # === Phase 3: left-to-right verification ===
        for i in range(n_draft):
            full_tok = _sample(verify_logits[:, i, :])
            if not (full_tok == draft_seq[:, i : i + 1]).all():
                generated = torch.cat([generated, draft_seq[:, :i], full_tok], dim=1)
                break
        else:
            # All drafts accepted: append them plus one bonus token.
            bonus = _sample(all_logits[:, -1, :])
            generated = torch.cat([generated, draft_seq, bonus], dim=1)

    if should_start_warmup_after_generate:
        self.start_background_warmup()
    return generated[:, : prompt_len + max_new_tokens]
|
|
905
|
+
|
|
906
|
+
def get_evolution_info(self) -> Dict[str, object]:
    """Return parameter statistics extended with CortexNet's feature summary.

    Starts from the dict produced by ``self.count_parameters()`` and adds fixed
    descriptive entries: ``version``, ``num_paths``, ``memory_tiers`` and a set
    of boolean capability flags. These values are constants; they are not
    derived from this instance's configuration.

    Returns:
        The ``count_parameters()`` dict, updated with the feature entries.
    """
    info = self.count_parameters()
    info.update(
        {
            "version": "CortexNet",
            "num_paths": 5,
            "memory_tiers": 3,
            "causal_reasoning": True,
            "self_evolution": True,
            "multi_agent": True,
            "adversarial_defense": True,
            "meta_learning": True,
            "graph_reasoning": True,
            "collaborative_moe": True,
            "continual_learning_ready": True,
            "multimodal_ready": True,
        }
    )
    return info
|
|
921
|
+
|
|
922
|
+
@classmethod
def from_pretrained(
    cls,
    model_path: str,
    *,
    model_type: Optional[str] = None,
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    auto_calibrate: Optional[bool] = None,
    calibration_data: Optional[List] = None,
    load_weights: bool = True,
    max_weight_tensors: Optional[int] = None,
    **kwargs,
) -> "CortexNet":
    """Load a HuggingFace checkpoint and convert it into a CortexNet model.

    Pipeline:
      1. Detect the source model type (LLaMA/Qwen/Mistral/...) unless given.
      2. Convert its config into a CortexNetConfig and apply ``kwargs`` overrides.
      3. Instantiate CortexNet.
      4. Map source weights onto CortexNet modules, optionally through an
         on-disk mapped-weights cache.
      5. Run architecture adaptation (skipped in compatibility mode).
      6. Optionally run lightweight calibration.
      7. Move to the target device/dtype, either eagerly or lazily (CPU first,
         background warm-up to the target device).

    Args:
        model_path: Path to a HuggingFace model directory.
        model_type: Source model type; auto-detected when None.
        device: Target device ("cuda", "cpu", "npu", ...); best available when None.
        dtype: Target dtype. Either a ``torch.dtype`` (applied before weights
            are loaded) or a string such as "bfloat16" (resolved for the device
            in Step 7). When None and ``device`` is None, the device's optimal
            dtype is used.
        auto_calibrate: Whether to calibrate. None uses the config default,
            which is always off in lite mode.
        calibration_data: Optional data passed to the calibrator.
        load_weights: If False, only build the architecture (random weights).
        max_weight_tensors: Stop after this many source tensors (quick
            compatibility checks). This also disables the mapped cache.
        **kwargs: Overrides for CortexNetConfig attributes; keys the config
            does not define are ignored.

    Returns:
        The adapted CortexNet model, in eval mode.

    Example:
        >>> model = CortexNet.from_pretrained("/path/to/llama3-7b")
        >>> output = model.generate(input_ids, max_new_tokens=100)
    """
    import gc
    import glob
    import logging

    logger = logging.getLogger(__name__)

    # ═══ Step 1: detect the source model type ═══
    try:
        from .adapter.model_registry import detect_model_type, get_cortexnet_config
    except ImportError:
        from cortexnet.adapter.model_registry import detect_model_type, get_cortexnet_config

    if model_type is None:
        model_type = detect_model_type(model_path)
    logger.info(f"Loading model from {model_path} (type: {model_type})")

    # ═══ Step 2: build the CortexNetConfig ═══
    config = get_cortexnet_config(model_path, model_type)

    # Apply user overrides; keys the config does not define are ignored.
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)

    # compatibility_mode and lite are mutually exclusive. If compatibility_mode is
    # on and the caller did not pass ``lite`` explicitly, turn lite off. This
    # ensures the compatibility path is used rather than the Lite path.
    if bool(getattr(config, "compatibility_mode", False)) and "lite" not in kwargs:
        config.lite = False

    if auto_calibrate is None:
        # Lite mode skips calibration by default (CPU calibration of large models is very slow).
        if getattr(config, "lite", False):
            auto_calibrate = False
        else:
            auto_calibrate = bool(getattr(config, "auto_calibrate", True))

    # Only a real torch.dtype can be applied before the device is known. A string
    # dtype such as "bfloat16" is resolved against the device in Step 7.
    early_dtype = dtype if isinstance(dtype, torch.dtype) else None

    def _discover_weight_files(path: str) -> List[str]:
        """Return the checkpoint shards in ``path``, most specific pattern first."""
        # HF-style ``model*.safetensors`` shards are preferred. Some repos may ship
        # additional safetensors files next to them; globbing every *.safetensors
        # would then load overlapping weights twice.
        for pattern in ("model*.safetensors", "*.safetensors", "pytorch_model*.bin"):
            files = sorted(glob.glob(os.path.join(path, pattern)))
            if files:
                return files
        return []

    def _build_mapped_cache_file(
        *,
        source_path: str,
        source_model_type: str,
        files: List[str],
        cache_dir: str,
    ) -> Optional[str]:
        """Return the mapped-weights cache path for these source files and settings.

        The key hashes the absolute model path, the model type, the config fields
        that shape the mapped tensors, the dtype the weights are cached in, and
        each weight file's name/size/mtime. Changing any of them gives a new file.
        """
        if not files:
            return None
        weight_fingerprints = []
        for wf in files:
            stat = os.stat(wf)
            weight_fingerprints.append(
                {
                    "name": os.path.basename(wf),
                    "size": int(stat.st_size),
                    "mtime_ns": int(stat.st_mtime_ns),
                }
            )
        cache_payload = {
            "model_path": os.path.abspath(source_path),
            "model_type": source_model_type,
            "compatibility_mode": bool(getattr(config, "compatibility_mode", False)),
            "expand_gqa_weights": bool(getattr(config, "expand_gqa_weights", True)),
            "compat_ssm_rank": int(getattr(config, "compat_ssm_rank", 256)),
            "hidden_size": int(getattr(config, "hidden_size", 0)),
            "num_layers": int(getattr(config, "num_layers", 0)),
            "num_heads": int(getattr(config, "num_heads", 0)),
            "num_kv_heads": int(getattr(config, "num_kv_heads", 0)),
            "intermediate_size": int(getattr(config, "intermediate_size", 0)),
            "tie_word_embeddings": bool(getattr(config, "tie_word_embeddings", True)),
            # The cache stores tensors in the dtype the model had while mapping
            # (Step 3 cast), so a different early dtype needs its own cache.
            "dtype": str(early_dtype),
            "weights": weight_fingerprints,
        }
        cache_key = hashlib.sha256(
            json.dumps(cache_payload, ensure_ascii=False, sort_keys=True).encode("utf-8")
        ).hexdigest()[:24]
        return os.path.join(cache_dir, f"{source_model_type}_{cache_key}.safetensors")

    mapped_cache_enabled = bool(getattr(config, "mapped_cache_enabled", False))
    if (
        load_weights
        and bool(getattr(config, "lazy_device_load", False))
        and bool(getattr(config, "mapped_cache_auto_enable_with_lazy", True))
        and not mapped_cache_enabled
    ):
        mapped_cache_enabled = True
        config.mapped_cache_enabled = True
        logger.info("Auto-enable mapped cache because lazy_device_load=True")

    mapped_cache_force_refresh = bool(getattr(config, "mapped_cache_force_refresh", False))
    mapped_cache_dir = getattr(config, "mapped_cache_dir", None) or os.path.join(
        os.path.expanduser("~"), ".cache", "cortexnet", "mapped_weights"
    )
    weight_files: List[str] = _discover_weight_files(model_path) if load_weights else []
    cache_file: Optional[str] = None
    if load_weights and mapped_cache_enabled and weight_files:
        try:
            os.makedirs(mapped_cache_dir, exist_ok=True)
            cache_file = _build_mapped_cache_file(
                source_path=model_path,
                source_model_type=model_type,
                files=weight_files,
                cache_dir=mapped_cache_dir,
            )
        except OSError as exc:
            # An unusable cache location must not prevent loading the model itself.
            logger.warning(f"Mapped cache disabled ({mapped_cache_dir}): {exc}")
            cache_file = None
    cache_hit = bool(
        cache_file
        and os.path.exists(cache_file)
        and not mapped_cache_force_refresh
        and max_weight_tensors is None
    )
    if cache_hit and bool(getattr(config, "mapped_cache_fast_init_on_hit", True)):
        # On a cache hit, skip the extra custom re-initialization to speed up cold start.
        setattr(config, "skip_weight_init", True)
    if (
        cache_hit
        and bool(getattr(config, "lazy_device_load", False))
        and bool(getattr(config, "lazy_disable_on_cache_hit", True))
    ):
        # On a cache hit, loading straight onto the target device usually gives a
        # faster first token and avoids lazy-CPU first-round jitter.
        config.lazy_device_load = False
        logger.info("Mapped cache hit: disable lazy_device_load for better first-token latency")

    # ═══ Step 3: instantiate the model ═══
    model = cls(config)
    if early_dtype is not None:
        model = model.to(dtype=early_dtype)

    # ═══ Step 4: map and load weights ═══
    if load_weights:
        try:
            from .adapter.weight_adapter import WeightAdapter
        except ImportError:
            from cortexnet.adapter.weight_adapter import WeightAdapter

        if weight_files:
            loaded_from_cache = False
            if cache_hit and cache_file:
                try:
                    from safetensors.torch import load_model  # type: ignore

                    missing, unexpected = load_model(model, cache_file, strict=False, device="cpu")
                    if missing or unexpected:
                        logger.warning(
                            "Mapped cache partially matched (missing=%s, unexpected=%s), "
                            "fallback to raw mapping.",
                            len(missing),
                            len(unexpected),
                        )
                    else:
                        loaded_from_cache = True
                        logger.info(
                            f"Loaded mapped weights cache: {cache_file} "
                            f"(missing={len(missing)}, unexpected={len(unexpected)})"
                        )
                except Exception as exc:
                    logger.warning(f"Failed to load mapped cache, fallback to raw mapping: {exc}")

            if not loaded_from_cache:
                # Parameters first; buffers only where no parameter has the same name.
                target_tensors: Dict[str, torch.Tensor] = dict(model.named_parameters())
                for name, buf in model.named_buffers():
                    target_tensors.setdefault(name, buf)

                adapter = WeightAdapter(model_type, config)
                total_source = 0
                total_mapped = 0
                total_shape_mismatch = 0
                total_unexpected = 0
                total_unmapped_source = 0
                mismatch_examples = 0
                stop_loading = False

                def _iter_weight_tensors(weight_file: str):
                    """Yield ``{name: tensor}`` one tensor at a time to lower peak memory."""
                    if weight_file.endswith(".safetensors"):
                        try:
                            from safetensors.torch import safe_open  # type: ignore
                        except ImportError:
                            logger.warning(
                                "safetensors not installed. Install with: pip install safetensors"
                            )
                            return

                        with safe_open(weight_file, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                yield {key: f.get_tensor(key)}
                        return

                    if weight_file.endswith(".bin"):
                        state = torch.load(weight_file, map_location="cpu", weights_only=True)
                        if isinstance(state, dict):
                            for key, value in state.items():
                                if torch.is_tensor(value):
                                    yield {key: value}

                for wf in weight_files:
                    logger.info(f"Loading weight shard: {os.path.basename(wf)}")
                    for raw_chunk in _iter_weight_tensors(wf):
                        if not raw_chunk:
                            continue
                        if max_weight_tensors is not None and total_source >= max_weight_tensors:
                            stop_loading = True
                            break
                        total_source += len(raw_chunk)

                        mapped_chunk = adapter.map_weights(raw_chunk)
                        total_unmapped_source += len(adapter.get_unmapped_weights())

                        for name, param in mapped_chunk.items():
                            target = target_tensors.get(name)
                            if target is None:
                                total_unexpected += 1
                            elif target.shape == param.shape:
                                with torch.no_grad():
                                    target.copy_(param.to(dtype=target.dtype))
                                total_mapped += 1
                            else:
                                total_shape_mismatch += 1
                                if mismatch_examples < 5:
                                    logger.warning(
                                        f"Shape mismatch for {name}: "
                                        f"expected {target.shape}, "
                                        f"got {param.shape}"
                                    )
                                    mismatch_examples += 1

                        # Periodically collect garbage to reduce peak memory on large models.
                        if total_source % 1024 == 0:
                            gc.collect()
                    if stop_loading:
                        logger.info(
                            f"Reached max_weight_tensors={max_weight_tensors}, "
                            "stopping early for compatibility check."
                        )
                        break
                logger.info(
                    f"Weight loading: {total_mapped} mapped, "
                    f"{total_shape_mismatch} shape mismatches, "
                    f"{total_unexpected} unexpected, "
                    f"{total_unmapped_source} source weights unmapped "
                    f"out of {total_source} tensors"
                )

                # Cache the mapped result after the first full mapping so later loads can reuse it.
                if mapped_cache_enabled and cache_file and max_weight_tensors is None:
                    if total_mapped == 0:
                        # Caching a model with nothing mapped would make later loads
                        # "hit" silently, hiding the mapping diagnostics above.
                        logger.warning("No source weights were mapped; skip saving mapped cache.")
                    else:
                        tmp_file = f"{cache_file}.{os.getpid()}.tmp"
                        try:
                            from safetensors.torch import save_model  # type: ignore

                            metadata = {
                                "model_type": model_type,
                                "compatibility_mode": str(
                                    bool(getattr(config, "compatibility_mode", False))
                                ),
                            }
                            # Write to a temp file and rename, so an interrupted save or a
                            # concurrent loader never sees a half-written cache file.
                            save_model(model, tmp_file, metadata=metadata)
                            os.replace(tmp_file, cache_file)
                            logger.info(f"Saved mapped weights cache: {cache_file}")
                        except Exception as exc:
                            logger.warning(f"Failed to save mapped cache: {exc}")
                            try:
                                os.remove(tmp_file)
                            except OSError:
                                pass
        else:
            logger.warning(
                f"No weight files found in {model_path}. "
                "Model initialized with random weights."
            )

    # ═══ Step 5: architecture adaptation (skipped in compatibility mode) ═══
    if not getattr(config, "compatibility_mode", False):
        try:
            from .adapter.arch_adapter import ArchitectureAdapter
        except ImportError:
            from cortexnet.adapter.arch_adapter import ArchitectureAdapter
        arch_adapter = ArchitectureAdapter(model_type, config)
        arch_adapter.adapt(model)

    # ═══ Step 6: optional calibration ═══
    if auto_calibrate:
        if getattr(config, "compatibility_mode", False):
            model._silent_calibrate_compat()
            logger.info("Compatibility calibration applied (silent mode).")
        else:
            try:
                from .adapter.calibrator import LightweightCalibrator
            except ImportError:
                from cortexnet.adapter.calibrator import LightweightCalibrator
            calibrator = LightweightCalibrator(model, model_type)
            model = calibrator.calibrate(calibration_data=calibration_data)

    # ═══ Step 7: device and precision ═══
    try:
        from .ops.device_manager import (
            get_best_device_info,
            resolve_device_string,
            resolve_dtype_for_device,
        )
    except ImportError:
        from cortexnet.ops.device_manager import (
            get_best_device_info,
            resolve_device_string,
            resolve_dtype_for_device,
        )

    if device is None:
        device_info = get_best_device_info()
        device = str(device_info.torch_device)
        if dtype is None:
            dtype = device_info.optimal_dtype
    else:
        device = resolve_device_string(
            str(device),
            auto_priority=("npu", "cuda", "mlu", "mps", "cpu"),
            allow_fallback=True,
        )

    if isinstance(dtype, str):
        dtype = resolve_dtype_for_device(dtype, str(device))

    lazy_device_load = bool(getattr(config, "lazy_device_load", False))
    lazy_cpu_fallback = bool(getattr(config, "lazy_cpu_fallback", True))
    lazy_background_warmup = bool(getattr(config, "lazy_background_warmup", True))
    target_device_type = str(device).split(":", 1)[0]
    use_lazy_runtime = lazy_device_load and target_device_type != "cpu"

    if use_lazy_runtime:
        # Lazy mode: return a CPU model now; it is warmed up to the target device later.
        model = model.to("cpu")
        if dtype is not None:
            model = model.to(dtype)
        model._setup_lazy_runtime(
            target_device=str(device),
            target_dtype=dtype,
            cpu_fallback=lazy_cpu_fallback,
            background_warmup=lazy_background_warmup,
        )
        if bool(getattr(config, "lazy_start_warmup_on_load", True)) and lazy_background_warmup:
            model.start_background_warmup()
    else:
        model = model.to(device)
        if dtype is not None:
            model = model.to(dtype)

    # Attach the inference adapter used by smart_generate().
    try:
        from .adapter.inference_adapter import InferenceAdapter
    except ImportError:
        from cortexnet.adapter.inference_adapter import InferenceAdapter
    model._inference_adapter = InferenceAdapter(model, model_type)

    model.eval()
    if getattr(model, "compatibility_mode", False):
        logger.info("Core Mode: SSM + Sparse Attention + Lite Fusion")
    else:
        logger.info("Core Mode: Full CortexNet")
    if use_lazy_runtime:
        logger.info(
            f"Model loaded in lazy mode (cpu fallback), target={device}, dtype={dtype}"
        )
    else:
        logger.info(f"Model loaded successfully on {device} with dtype {dtype}")
    return model
|
|
1341
|
+
|
|
1342
|
+
def smart_generate(self, input_ids: torch.Tensor, **kwargs):
    """Generate through the attached inference adapter, if there is one.

    ``from_pretrained()`` attaches ``_inference_adapter``, which is meant to
    apply the source model's default generation parameters. Models created any
    other way have no adapter and fall back to the plain ``generate()``. Every
    keyword argument is forwarded unchanged to whichever ``generate`` is used.
    """
    absent = object()
    inference_adapter = getattr(self, "_inference_adapter", absent)
    if inference_adapter is absent:
        return self.generate(input_ids, **kwargs)
    return inference_adapter.generate(input_ids, **kwargs)
|
|
1351
|
+
|
|
1352
|
+
|
|
1353
|
+
class CortexNet(CortexNetV3):
    """The canonical, unified CortexNet model class.

    New code should use ``CortexNet``. It adds nothing on top of
    ``CortexNetV3``; legacy names such as ``CortexNetV3`` remain available
    only so that existing imports keep working.
    """
|