cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/__init__.py ADDED
@@ -0,0 +1,197 @@
"""CortexNet public package API."""

from __future__ import annotations

# ---------------------------------------------------------------------------
# Core model, configuration and building blocks
# ---------------------------------------------------------------------------
from .config import CortexNetConfig, TrainingConfig
from .model import (
    CortexNet,
    CortexNetBase,
    CortexNetV2 as _CortexNetV2,
    CortexNetV3 as _CortexNetV3,
)
from .blocks import (
    CortexBlock,
    CortexBlockV2 as _CortexBlockV2,
    CortexBlockV3 as _CortexBlockV3,
    RMSNorm,
    AdaptiveFusionGate,
)
from .ssm import MultiScaleSSM
from .attention import SelectiveSparseAttention
from .memory import SynapticMemory
from .routing import MixtureOfExperts, ExpertFFN, CollaborativeMoE
from .cache import CortexNetCache
from .transformer_baseline import TransformerLM
from .quantization import quantize_dynamic, QuantizationWrapper

# ---------------------------------------------------------------------------
# Advanced modules (memory, reasoning, learning, agents, robustness)
# ---------------------------------------------------------------------------
from .hierarchical_memory import (
    HierarchicalMemorySystem,
    WorkingMemory,
    EpisodicMemory,
    SemanticMemory,
    MemoryController,
)
from .graph_reasoning import GraphReasoningModule
from .meta_learning import (
    MetaLearningAdapter,
    TaskAdaptiveController,
    ContextEncoder,
)
from .multimodal import (
    MultiModalEncoder,
    PatchEmbedding,
    AudioEmbedding,
    CrossModalFusion,
)
from .continual_learning import (
    ElasticWeightConsolidation,
    ProgressiveMemoryReplay,
    ContinualLearningManager,
)
from .interpretability import ThoughtFlowMonitor
from .causal_reasoning import (
    CausalReasoningModule,
    InterventionalAttention,
    CounterfactualBranch,
)
from .self_evolution import (
    SelfEvolutionEngine,
    DynamicPathController,
    ComputeBudgetAllocator,
)
from .multi_agent import (
    MultiAgentSystem,
    SpecialistAgent,
    AgentCoordinator,
    SharedMessageBoard,
)
from .adversarial import (
    AdversarialShield,
    AdversarialTrainer,
    InputShield,
    FeatureShield,
)
from .training_utils import (
    GradientMonitor,
    check_gradients_finite,
    set_seed,
    get_best_device,
)

# ---------------------------------------------------------------------------
# Model adaptation, device operators and distributed helpers
# ---------------------------------------------------------------------------
from .adapter import (
    WeightAdapter,
    ArchitectureAdapter,
    InferenceAdapter,
    LightweightCalibrator,
    ModelRegistry,
    detect_model_type,
    get_cortexnet_config,
)
from .ops import (
    DeviceManager,
    get_best_device_info,
    is_npu_available,
    is_mlu_available,
    get_device_type,
    resolve_device_string,
    resolve_dtype_for_device,
    NPUOperators,
    get_operators,
)
from .distributed import (
    setup_distributed,
    cleanup_distributed,
    wrap_fsdp,
    wrap_ddp,
    get_rank,
    get_world_size,
    is_main_process,
)

# ---------------------------------------------------------------------------
# Optional data utilities: when importing ``.data`` raises ImportError the
# names below are simply not exported and ``_HAS_DATA`` stays False.
# ---------------------------------------------------------------------------
try:
    from .data import (
        SimpleTokenizer,
        MiniMindTokenizer,
        TextCorpusDataset,
        StreamingDataset,
        ConversationDataset,
        PretrainDataset,
        CodeCompletionDataset,
        CodeGenerationDataset,
        download_wikitext2,
        download_minimind_data,
    )
except ImportError:
    _HAS_DATA = False
    _DATA_EXPORTS = ()
else:
    _HAS_DATA = True
    # Presumably kept so linters see the optional imports as used.
    _DATA_EXPORTS = (
        SimpleTokenizer,
        MiniMindTokenizer,
        TextCorpusDataset,
        StreamingDataset,
        ConversationDataset,
        PretrainDataset,
        CodeCompletionDataset,
        CodeGenerationDataset,
        download_wikitext2,
        download_minimind_data,
    )

# Legacy compatibility aliases; deliberately left out of ``__all__``.
CortexNetV2 = _CortexNetV2
CortexNetV3 = _CortexNetV3
CortexBlockV2 = _CortexBlockV2
CortexBlockV3 = _CortexBlockV3

__version__ = "3.2.1"

__all__ = [
    # core
    "CortexNet",
    "CortexNetBase",
    "CortexNetConfig",
    "TrainingConfig",
    "CortexBlock",
    "RMSNorm",
    "AdaptiveFusionGate",
    "MultiScaleSSM",
    "SelectiveSparseAttention",
    "SynapticMemory",
    "MixtureOfExperts",
    "ExpertFFN",
    "CollaborativeMoE",
    "CortexNetCache",
    "TransformerLM",
    "quantize_dynamic",
    "QuantizationWrapper",
    # advanced modules
    "HierarchicalMemorySystem",
    "WorkingMemory",
    "EpisodicMemory",
    "SemanticMemory",
    "MemoryController",
    "GraphReasoningModule",
    "MetaLearningAdapter",
    "TaskAdaptiveController",
    "ContextEncoder",
    "MultiModalEncoder",
    "PatchEmbedding",
    "AudioEmbedding",
    "CrossModalFusion",
    "ElasticWeightConsolidation",
    "ProgressiveMemoryReplay",
    "ContinualLearningManager",
    "ThoughtFlowMonitor",
    "CausalReasoningModule",
    "InterventionalAttention",
    "CounterfactualBranch",
    "SelfEvolutionEngine",
    "DynamicPathController",
    "ComputeBudgetAllocator",
    "MultiAgentSystem",
    "SpecialistAgent",
    "AgentCoordinator",
    "SharedMessageBoard",
    "AdversarialShield",
    "AdversarialTrainer",
    "InputShield",
    "FeatureShield",
    "GradientMonitor",
    "check_gradients_finite",
    "set_seed",
    "get_best_device",
    # adaptation
    "WeightAdapter",
    "ArchitectureAdapter",
    "InferenceAdapter",
    "LightweightCalibrator",
    "ModelRegistry",
    "detect_model_type",
    "get_cortexnet_config",
    # devices / operators
    "DeviceManager",
    "get_best_device_info",
    "is_npu_available",
    "is_mlu_available",
    "get_device_type",
    "resolve_device_string",
    "resolve_dtype_for_device",
    "NPUOperators",
    "get_operators",
    # distributed
    "setup_distributed",
    "cleanup_distributed",
    "wrap_fsdp",
    "wrap_ddp",
    "get_rank",
    "get_world_size",
    "is_main_process",
]

if _HAS_DATA:
    __all__ += [
        "SimpleTokenizer",
        "MiniMindTokenizer",
        "TextCorpusDataset",
        "StreamingDataset",
        "ConversationDataset",
        "PretrainDataset",
        "CodeCompletionDataset",
        "CodeGenerationDataset",
        "download_wikitext2",
        "download_minimind_data",
    ]
cortexnet/adapter/__init__.py ADDED
@@ -0,0 +1,26 @@
"""CortexNet adapter sub-package.

Automatic adaptation of open-source large language models to CortexNet:

* model detection and registry
* weight mapping
* architecture adaptation
* unified inference interface
* lightweight calibration
"""

from .model_registry import (
    ModelRegistry,
    detect_model_type,
    get_cortexnet_config,
)
from .weight_adapter import WeightAdapter
from .arch_adapter import ArchitectureAdapter
from .inference_adapter import InferenceAdapter
from .calibrator import LightweightCalibrator

__all__ = [
    # registry / detection
    "ModelRegistry",
    "detect_model_type",
    "get_cortexnet_config",
    # adaptation pipeline
    "WeightAdapter",
    "ArchitectureAdapter",
    "InferenceAdapter",
    "LightweightCalibrator",
]
cortexnet/adapter/arch_adapter.py ADDED
@@ -0,0 +1,209 @@
"""
Architecture adapter layer.

Purpose:
    Based on the source model's configuration, automatically adjust the
    parameters and behaviour of CortexNet's internal modules so that the
    converted model keeps as much of the original model's capability as
    possible.

Adapted aspects:
    1. Sparse attention top-k ratio / sliding window
    2. SSM multi-scale parameters (currently only a logged recommendation)
    3. Dynamic path gating threshold (train/inference consistency)
    4. Position encoding parameters (RoPE scaling factor/type and theta)
    5. MoE router initialisation
"""

from __future__ import annotations

import logging
from typing import Any

import torch
from torch import nn

# Module-level logger shared by the adapter.
logger = logging.getLogger(__name__)
class ArchitectureAdapter:
    """Architecture adapter.

    After a CortexNet model has been built, adjust the parameters of its
    internal modules according to the source model's configuration.

    Args:
        model_type: Source model family; used here only for logging.
        config: Configuration object. Attributes read by this class:
            ``source_model_path`` (optional), ``sliding_window_size``
            (optional; ``None`` or 0 disables it), ``max_seq_len``,
            ``hidden_size``, ``rope_scaling`` and ``rope_theta``.
    """

    # Fusion logits written into a 3-way ``fusion_weights`` parameter when
    # pretrained weights are transferred: strongly favour the middle path.
    # NOTE(review): the middle entry is assumed to be the attention path
    # (order SSM / attention / memory) — confirm against the block definition.
    _ATTENTION_ONLY_FUSION = (-8.0, 8.0, -8.0)
    # At inference time, path gates at or below this value are zeroed.
    _SOFT_THRESHOLD = 0.1
    # Value added to expert 0's router bias so routing initially favours it.
    _EXPERT0_ROUTER_BIAS = 0.5

    def __init__(self, model_type: str, config: Any):
        self.model_type = model_type
        self.config = config

    def adapt(self, cortex_model: nn.Module) -> nn.Module:
        """Run every adaptation step on *cortex_model* in place.

        Args:
            cortex_model: An initialised CortexNet model exposing ``blocks``.

        Returns:
            The same model object, modified in place.
        """
        logger.info("Starting architecture adaptation for model type: %s", self.model_type)

        self._adapt_attention(cortex_model)
        self._adapt_ssm(cortex_model)
        self._adapt_position_encoding(cortex_model)
        self._adapt_dynamic_gating(cortex_model)
        self._adapt_moe(cortex_model)
        self._adapt_normalization(cortex_model)

        logger.info("Architecture adaptation complete.")
        return cortex_model

    def _adapt_attention(self, model: nn.Module):
        """Adapt attention sparsity, sliding window and path fusion per block."""
        cfg = self.config
        preserve_pretrained_behavior = bool(getattr(cfg, "source_model_path", None))
        # HF-style configs often use ``None`` for "no sliding window"; treat it
        # like 0 instead of failing on ``None > 0``.
        sliding_window = getattr(cfg, "sliding_window_size", 0) or 0

        for block in model.blocks:
            attn = block.attention

            # Fidelity first when transferring pretrained weights: degrade the
            # Lite sparse attention to full attention so randomly initialised
            # importance scores cannot corrupt the output.
            if preserve_pretrained_behavior and hasattr(attn, "top_k_ratio"):
                attn.top_k_ratio = 1.0

            # Sliding window (Mistral and Mistral-derived models).
            if sliding_window > 0 and hasattr(attn, "sliding_window_size"):
                attn.sliding_window_size = sliding_window
                logger.debug("Set sliding window size to %s", sliding_window)

            # Long-context models: shrink the top-k ratio (floor 0.1) to save memory.
            if not preserve_pretrained_behavior and cfg.max_seq_len > 8192:
                new_ratio = max(0.1, 0.25 * (8192 / cfg.max_seq_len))
                if hasattr(attn, "top_k_ratio"):
                    attn.top_k_ratio = new_ratio
                    logger.debug(
                        "Adjusted top_k_ratio to %.3f for max_seq_len=%s",
                        new_ratio,
                        cfg.max_seq_len,
                    )

            # The SSM/memory paths have no pretrained weights, so start from a
            # fusion that relies almost entirely on attention as a stable baseline.
            if preserve_pretrained_behavior and hasattr(block, "fusion_weights"):
                self._favour_attention_fusion(block.fusion_weights)

    def _favour_attention_fusion(self, fusion_weights: torch.Tensor) -> None:
        """Overwrite *fusion_weights* in place with the attention-favouring logits.

        Skips with a warning when the trailing dimension does not match the
        three fusion paths, instead of letting ``copy_`` raise mid-adaptation.
        Leading dimensions, if any, receive the same logits via broadcasting.
        """
        target = torch.tensor(
            self._ATTENTION_ONLY_FUSION,
            device=fusion_weights.device,
            dtype=fusion_weights.dtype,
        )
        if fusion_weights.dim() == 0 or fusion_weights.shape[-1] != target.numel():
            logger.warning(
                "Skipping fusion-weight override: expected trailing dim %d, got shape %s",
                target.numel(),
                tuple(fusion_weights.shape),
            )
            return
        with torch.no_grad():
            fusion_weights.copy_(target)

    def _adapt_ssm(self, model: nn.Module):
        """Log the SSM scale count recommended for large hidden sizes.

        The recommended value is only logged; no SSM module is modified.
        NOTE(review): assigning ``num_scales`` after construction would
        presumably not rebuild any per-scale submodules, so the value has to
        be applied when the model is built — confirm against MultiScaleSSM.
        """
        for block in model.blocks:
            ssm = block.ssm
            # Large models benefit from more scales to capture dependencies
            # at different granularities.
            if self.config.hidden_size >= 4096 and hasattr(ssm, "num_scales"):
                target_scales = min(8, self.config.hidden_size // 512)
                logger.debug(
                    "Recommended SSM scales for large model: %d "
                    "(not applied; module keeps num_scales=%s)",
                    target_scales,
                    ssm.num_scales,
                )

    def _adapt_position_encoding(self, model: nn.Module):
        """Copy RoPE scaling settings and ``rope_theta`` onto attention modules.

        NOTE(review): this only sets attributes; whether the attention modules
        re-derive any cached rotary frequencies from them is not visible here.
        """
        rope_scaling = self.config.rope_scaling
        for block in model.blocks:
            attn = block.attention

            # RoPE scaling (dynamic NTK, linear interpolation, ...). Newer
            # Hugging Face configs store the kind under "rope_type" rather than "type".
            if rope_scaling:
                scaling_type = rope_scaling.get("type", rope_scaling.get("rope_type", "linear"))
                scaling_factor = rope_scaling.get("factor", 1.0)

                if hasattr(attn, "rope_scaling_factor"):
                    attn.rope_scaling_factor = scaling_factor
                    attn.rope_scaling_type = scaling_type
                    logger.debug(
                        "Applied RoPE scaling: type=%s, factor=%s",
                        scaling_type,
                        scaling_factor,
                    )

            if hasattr(attn, "rope_theta"):
                attn.rope_theta = self.config.rope_theta

    def _adapt_dynamic_gating(self, model: nn.Module):
        """Replace each block's path-controller ``forward`` with soft gating.

        The replacement is bound on the controller instance. In training mode
        it returns ``sigmoid((logits + noise) / max(temperature, 0.1))`` with
        clamped logistic noise; in eval mode it returns ``sigmoid(logits)``
        with gates at or below ``_SOFT_THRESHOLD`` set to zero.
        NOTE(review): the original docstring states that the stock
        DynamicPathController hard-binarises gates in eval mode (> 0 -> 1.0);
        that code is not visible here.
        """
        import types

        threshold = self._SOFT_THRESHOLD

        def _soft_forward(ctrl, x, _thresh=threshold):
            # assumes x is (batch, seq, hidden) and dim=1 is the sequence — TODO confirm
            context = x.mean(dim=1)
            logits = ctrl.gate_net(context)

            if ctrl.training:
                # Logistic noise log(u) - log(1 - u), clamped for stability.
                u = torch.empty_like(logits).uniform_(1e-4, 1 - 1e-4)
                noise = (torch.log(u) - torch.log(1 - u)).clamp(-10, 10)
                return torch.sigmoid((logits + noise) / max(ctrl.temperature, 0.1))

            gates = torch.sigmoid(logits)
            # Keep the mask in the gates' dtype; a float32 mask would promote
            # half/bfloat16 gates to float32.
            return gates * (gates > _thresh).to(gates.dtype)

        for block in model.blocks:
            controller = getattr(block, "path_controller", None)
            if controller is None:
                continue
            controller.forward = types.MethodType(_soft_forward, controller)
            logger.debug("Applied soft gating threshold: %s", threshold)

    def _adapt_moe(self, model: nn.Module):
        """Bias each MoE router toward expert 0.

        For non-MoE source models (e.g. LLaMA) the original FFN is expected to
        live in expert 0; this method does not copy any weights (presumably the
        WeightAdapter does — confirm). Blocks without ``moe`` (Lite mode uses a
        plain FFN) or without experts are skipped.

        The preference is added to ``router.bias[0]``. Routers without a bias
        are left untouched: adding a constant to a weight row shifts the logit
        by that constant times ``sum(x)``, which favours or disfavours expert 0
        depending on each token's sign, so it is not a bias.
        """
        for block in model.blocks:
            moe = getattr(block, "moe", None)
            if moe is None:
                continue  # Lite mode: plain FFN, nothing to adapt
            if not getattr(moe, "experts", None):
                continue

            router = getattr(moe, "router", None)
            bias = getattr(router, "bias", None)
            if not isinstance(bias, torch.Tensor) or bias.numel() == 0:
                logger.debug("MoE router has no bias; skipping expert-0 routing preference")
                continue

            with torch.no_grad():
                bias[0].add_(self._EXPERT0_ROUTER_BIAS)
            logger.debug("MoE router initialized with bias toward adapted expert")

    def _adapt_normalization(self, model: nn.Module):
        """Warn about any ``nn.LayerNorm`` left in the model.

        CortexNet uses RMSNorm; per the original design note, LayerNorm
        parameters of the source model are handled by the WeightAdapter, so
        this is only a final check.
        """
        for name, module in model.named_modules():
            if isinstance(module, nn.LayerNorm):
                logger.warning(
                    "Found unexpected LayerNorm at %s. "
                    "CortexNet expects RMSNorm. This may cause subtle differences.",
                    name,
                )
cortexnet/adapter/calibrator.py ADDED
@@ -0,0 +1,244 @@
"""
Lightweight calibrator.

Purpose:
    Using a very small amount of data (about 100 samples, one epoch), fine-tune
    CortexNet's adaptation-layer parameters so that the adapted model behaves
    as closely as possible to the native source model.

Key design points:
    1. Only a small subset of parameters is optimised: the adaptation layers
       selected by name (fusion gates, MoE routers, path/task controllers,
       meta-learning adapters).
    2. All core weights (Q/K/V, FFN, ...) are frozen.
    3. Only the autoregressive cross-entropy loss is used (no auxiliary losses).
    4. Calibration results are cached on disk for reuse.
"""

from __future__ import annotations

import hashlib
import json
import logging
import os
from typing import Dict, List, Optional

import torch
from torch import nn

# Module-level logger shared by the calibrator.
logger = logging.getLogger(__name__)
class LightweightCalibrator:
    """Lightweight calibrator.

    Fine-tunes only the adaptation-layer parameters of a CortexNet model (those
    whose names contain one of ``_TRAINABLE_PATTERNS``) with one short pass of
    next-token cross-entropy, and caches the result on disk.

    Args:
        cortex_model: CortexNet model to calibrate (modified in place).
        model_type: Source model type; part of the cache key.
        cache_dir: Directory for calibration caches; defaults to
            ``~/.cache/cortexnet/calibration`` and is created if missing.
    """

    # Name substrings marking adaptation-layer parameters. Shared by freezing
    # and cache saving so the two selections cannot drift apart.
    _TRAINABLE_PATTERNS = (
        "fusion",           # fusion gates
        "router",           # MoE routers
        "gate_net",         # dynamic path controller gate networks
        "meta_adapter",     # meta-learning adapter
        "task_controller",  # task controller
        "path_controller",  # path controller
    )

    def __init__(
        self,
        cortex_model: nn.Module,
        model_type: str = "default",
        cache_dir: Optional[str] = None,
    ):
        self.model = cortex_model
        self.model_type = model_type
        self.cache_dir = cache_dir or os.path.join(
            os.path.expanduser("~"), ".cache", "cortexnet", "calibration"
        )
        os.makedirs(self.cache_dir, exist_ok=True)

    def calibrate(
        self,
        calibration_data: Optional[List[Dict[str, torch.Tensor]]] = None,
        n_samples: int = 100,
        lr: float = 1e-5,
        use_cache: bool = True,
    ) -> nn.Module:
        """Run lightweight calibration.

        Args:
            calibration_data: List of batches, each a mapping with an
                ``"input_ids"`` tensor (1-D or 2-D). Synthetic random-token
                data is generated when omitted.
            n_samples: Maximum number of batches used (and number generated
                when no data is given).
            lr: AdamW learning rate.
            use_cache: Load a cached calibration when available, and save the
                result after a calibration that completed at least one step.

        Returns:
            The same model object. It is left in eval mode when a calibration
            pass ran. Every parameter's ``requires_grad`` flag is restored to
            its value from before the call.
        """
        cache_key = self._get_cache_key()
        if use_cache and self._load_cached(cache_key):
            logger.info("Loaded calibration parameters from cache.")
            return self.model

        if calibration_data is None:
            calibration_data = self._generate_calibration_data(n_samples)

        if not calibration_data:
            logger.warning("No calibration data available. Skipping calibration.")
            return self.model

        # Snapshot requires_grad so it is restored exactly, including on early
        # return or an unexpected exception. Previously every parameter was
        # forced to True afterwards, and the "no trainable params" early return
        # left the whole model frozen.
        saved_flags = [(p, p.requires_grad) for p in self.model.parameters()]
        try:
            trainable_params = self._freeze_core_weights()
            if not trainable_params:
                logger.info("No trainable adaptation parameters found. Skipping calibration.")
                return self.model

            trainable_count = sum(p.numel() for p in trainable_params)
            total_count = sum(p.numel() for p in self.model.parameters())
            logger.info(
                f"Calibration: optimizing {trainable_count:,} / {total_count:,} params "
                f"({100 * trainable_count / max(total_count, 1):.2f}%)"
            )

            n_steps, avg_loss = self._run_epoch(calibration_data[:n_samples], trainable_params, lr)
        finally:
            for param, flag in saved_flags:
                param.requires_grad_(flag)

        logger.info(f"Calibration complete: {n_steps} steps, avg loss={avg_loss:.4f}")

        # Never cache an uncalibrated state: if every step failed, a cached
        # copy would make all later calls skip calibration.
        if use_cache and n_steps > 0:
            self._save_cached(cache_key)

        self.model.eval()
        return self.model

    def _run_epoch(
        self,
        batches: List[Dict[str, torch.Tensor]],
        trainable_params: List[nn.Parameter],
        lr: float,
    ) -> tuple[int, float]:
        """Run one optimisation pass over *batches*.

        Returns:
            ``(successful_steps, mean_loss)``; the mean is 0.0 when no step
            succeeded. Batches with a non-finite loss or that raise
            ``RuntimeError`` are skipped.
        """
        optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
        loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        device = next(self.model.parameters()).device

        # Per the original design note, eval() avoids leftover auxiliary-loss
        # graphs; gradients are still enabled explicitly for the forward pass.
        self.model.eval()

        total_loss = 0.0
        n_steps = 0
        for batch in batches:
            input_ids = batch["input_ids"]
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            input_ids = input_ids.to(device)
            labels = input_ids.clone()

            # Clear gradients before the forward pass so no stale graph refs remain.
            optimizer.zero_grad(set_to_none=True)

            try:
                with torch.enable_grad():
                    logits = self.model(input_ids)["logits"]
                    # Shifted next-token cross-entropy.
                    shift_logits = logits[:, :-1, :].contiguous()
                    shift_labels = labels[:, 1:].contiguous()
                    loss = loss_fn(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                    )

                # Skip NaN/Inf losses to protect the adaptation parameters.
                if not torch.isfinite(loss):
                    logger.debug("Skipping calibration step with non-finite loss.")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                optimizer.step()
                total_loss += loss.detach().item()
                n_steps += 1
            except RuntimeError as err:
                # Best effort: drop partial gradients and move on, but say why.
                optimizer.zero_grad(set_to_none=True)
                logger.warning(f"Skipping calibration step after RuntimeError: {err}")
                continue

        return n_steps, total_loss / max(n_steps, 1)

    @classmethod
    def _is_adaptation_param(cls, name: str) -> bool:
        """Return True if parameter *name* belongs to an adaptation layer."""
        return any(pattern in name for pattern in cls._TRAINABLE_PATTERNS)

    def _freeze_core_weights(self) -> List[nn.Parameter]:
        """Freeze every parameter except adaptation layers; return the latter."""
        trainable_params = []
        for name, param in self.model.named_parameters():
            keep = self._is_adaptation_param(name)
            param.requires_grad = keep
            if keep:
                trainable_params.append(param)
        return trainable_params

    def _generate_calibration_data(
        self,
        n_samples: int,
    ) -> List[Dict[str, torch.Tensor]]:
        """Generate synthetic calibration data (used when no real data is given).

        Each sample is a 1-D sequence of random token ids in ``[1, vocab_size)``
        of length ``min(max_seq_len, 256)``. Per the original design note, this
        is weaker than real text but considered sufficient for gating/routing
        parameters.
        """
        config = getattr(self.model, "config", None)
        vocab_size = getattr(config, "vocab_size", 32000)
        max_len = min(getattr(config, "max_seq_len", 512), 256)

        data = [
            {"input_ids": torch.randint(1, vocab_size, (max_len,), dtype=torch.long)}
            for _ in range(n_samples)
        ]

        logger.info(f"Generated {n_samples} synthetic calibration samples")
        return data

    def _get_cache_key(self) -> str:
        """Build the cache key from the model type and config dimensions.

        ``source_model_path`` is included only when the config carries one,
        so that different checkpoints of the same architecture do not share a
        cache while keys for configs without it stay unchanged.
        """
        config = getattr(self.model, "config", None)
        key_fields = {
            "model_type": self.model_type,
            "hidden_size": getattr(config, "hidden_size", 0),
            "num_layers": getattr(config, "num_layers", 0),
            "vocab_size": getattr(config, "vocab_size", 0),
        }
        source_path = getattr(config, "source_model_path", None)
        if source_path:
            key_fields["source_model_path"] = str(source_path)
        config_str = json.dumps(key_fields, sort_keys=True)
        return hashlib.md5(config_str.encode()).hexdigest()[:12]

    def _save_cached(self, cache_key: str):
        """Save the adaptation-layer parameters (as CPU tensors) to the cache."""
        cache_path = os.path.join(self.cache_dir, f"calibration_{cache_key}.pt")
        adapt_state = {
            name: param.detach().cpu().clone()
            for name, param in self.model.named_parameters()
            if self._is_adaptation_param(name)
        }
        torch.save(adapt_state, cache_path)
        logger.info(f"Saved calibration cache: {cache_path}")

    def _load_cached(self, cache_key: str) -> bool:
        """Load cached adaptation parameters into the model.

        Returns:
            True only if at least one cached tensor matched a model parameter
            by name and shape; False if the file is missing, unreadable, or
            matched nothing.
        """
        cache_path = os.path.join(self.cache_dir, f"calibration_{cache_key}.pt")
        if not os.path.exists(cache_path):
            return False

        try:
            adapt_state = torch.load(cache_path, map_location="cpu", weights_only=True)
            params = dict(self.model.named_parameters())
            loaded = 0
            with torch.no_grad():
                for name, value in adapt_state.items():
                    target = params.get(name)
                    if target is not None and target.shape == value.shape:
                        target.copy_(value)  # copy_ converts device/dtype
                        loaded += 1
        except Exception as e:  # best-effort cache: never fail calibration on it
            logger.warning(f"Failed to load calibration cache: {e}")
            return False

        if not loaded:
            logger.warning(f"Calibration cache {cache_path} matched no parameters; ignoring it.")
            return False
        return True