cortexnet-3.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,272 @@
+ """
+ Inference Adapter
+
+ Core functionality:
+ Provides a unified generation interface that automatically adapts each
+ model family's default parameters, generation strategy, and acceleration
+ configuration.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Dict, Iterator, Optional
+
+ import torch
+ import torch.nn.functional as F
+
+ logger = logging.getLogger(__name__)
+
+
+ # Default generation parameters for each model family
+ DEFAULT_GENERATION_PARAMS: Dict[str, Dict[str, Any]] = {
+     "llama": {
+         "temperature": 0.6,
+         "top_p": 0.9,
+         "top_k": 50,
+         "repetition_penalty": 1.1,
+         "max_new_tokens": 512,
+     },
+     "qwen2": {
+         "temperature": 0.7,
+         "top_p": 0.8,
+         "top_k": 20,
+         "repetition_penalty": 1.05,
+         "max_new_tokens": 512,
+     },
+     "qwen3": {
+         "temperature": 0.7,
+         "top_p": 0.8,
+         "top_k": 20,
+         "repetition_penalty": 1.05,
+         "max_new_tokens": 512,
+     },
+     "mistral": {
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "top_k": 50,
+         "repetition_penalty": 1.0,
+         "max_new_tokens": 512,
+     },
+     "chatglm": {
+         "temperature": 0.8,
+         "top_p": 0.8,
+         "top_k": 50,
+         "repetition_penalty": 1.02,
+         "max_new_tokens": 512,
+     },
+     "baichuan": {
+         "temperature": 0.3,
+         "top_p": 0.85,
+         "top_k": 5,
+         "repetition_penalty": 1.05,
+         "max_new_tokens": 512,
+     },
+     "phi": {
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "top_k": 50,
+         "repetition_penalty": 1.0,
+         "max_new_tokens": 256,
+     },
+     "default": {
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "top_k": 50,
+         "repetition_penalty": 1.0,
+         "max_new_tokens": 512,
+     },
+ }
+
+
+ class InferenceAdapter:
+     """Inference adapter: a unified generation interface.
+
+     Automatically applies the source model's default generation parameters
+     and enables CortexNet acceleration strategies.
+
+     Args:
+         cortex_model: CortexNet model instance
+         model_type: source model type
+     """
+
+     def __init__(self, cortex_model: Any, model_type: str = "default"):
+         self.model = cortex_model
+         self.model_type = model_type
+         self.default_params = self._get_default_params()
+
+     def _get_default_params(self) -> Dict[str, Any]:
+         """Return the default generation parameters for this model family."""
+         params = DEFAULT_GENERATION_PARAMS.get(
+             self.model_type,
+             DEFAULT_GENERATION_PARAMS["default"],
+         )
+         return dict(params)
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         *,
+         max_new_tokens: Optional[int] = None,
+         temperature: Optional[float] = None,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         repetition_penalty: Optional[float] = None,
+         stream: bool = False,
+         use_speculative: bool = False,
+         **kwargs,
+     ):
+         """Unified generation entry point.
+
+         Args:
+             input_ids: (batch, seq_len) input token indices
+             max_new_tokens: maximum number of new tokens to generate
+             temperature: sampling temperature
+             top_k: top-k sampling cutoff
+             top_p: nucleus sampling threshold
+             repetition_penalty: repetition penalty
+             stream: whether to stream tokens one at a time
+             use_speculative: whether to use speculative decoding
+             **kwargs: additional generation parameters
+
+         Returns:
+             stream=False: (batch, seq_len + max_new_tokens) token indices
+             stream=True: an iterator over generated tokens
+         """
+         # Merge parameters (user-supplied values override family defaults)
+         params = dict(self.default_params)
+         if max_new_tokens is not None:
+             params["max_new_tokens"] = max_new_tokens
+         if temperature is not None:
+             params["temperature"] = temperature
+         if top_k is not None:
+             params["top_k"] = top_k
+         if top_p is not None:
+             params["top_p"] = top_p
+         if repetition_penalty is not None:
+             params["repetition_penalty"] = repetition_penalty
+         params.update(kwargs)
+
+         if stream:
+             return self._stream_generate(input_ids, params)
+
+         # Speculative decoding (if enabled and the model supports it)
+         if use_speculative and hasattr(self.model, "speculative_generate"):
+             return self.model.speculative_generate(
+                 input_ids,
+                 max_new_tokens=params["max_new_tokens"],
+                 temperature=params["temperature"],
+                 top_k=params["top_k"],
+             )
+
+         # Standard generation
+         return self.model.generate(
+             input_ids,
+             max_new_tokens=params["max_new_tokens"],
+             temperature=params["temperature"],
+             top_k=params["top_k"],
+             top_p=params["top_p"],
+             repetition_penalty=params.get("repetition_penalty", 1.0),
+         )
+
+     def _stream_generate(
+         self,
+         input_ids: torch.Tensor,
+         params: Dict[str, Any],
+     ) -> Iterator[torch.Tensor]:
+         """Streaming generation: yield one token per step.
+
+         Yields:
+             The newly generated token at each step, shape (batch, 1)
+         """
+         self.model.eval()
+         generated = input_ids
+         max_new_tokens = params["max_new_tokens"]
+         temperature = params["temperature"]
+         top_k = params["top_k"]
+         top_p = params["top_p"]
+         repetition_penalty = params.get("repetition_penalty", 1.0)
+         defer_lazy_warmup = bool(
+             getattr(self.model, "_lazy_enabled", False)
+             and not getattr(self.model, "_lazy_ready", True)
+         )
+         supports_cache_forward = (
+             hasattr(self.model, "forward")
+             and hasattr(self.model.forward, "__code__")
+             and "past_cache" in self.model.forward.__code__.co_varnames
+         )
+         use_repetition_penalty = repetition_penalty != 1.0 and repetition_penalty > 0
+
+         past_cache = None
+         seen_mask: Optional[torch.Tensor] = None
+
+         try:
+             with torch.no_grad():
+                 for _ in range(max_new_tokens):
+                     if past_cache is None:
+                         idx_cond = generated
+                     else:
+                         idx_cond = generated[:, -1:]
+
+                     # Forward pass
+                     if supports_cache_forward:
+                         output = self.model.forward(
+                             idx_cond,
+                             past_cache=past_cache,
+                             use_cache=True,
+                             start_warmup_after_infer=not defer_lazy_warmup,
+                         )
+                         past_cache = output.get("past_cache")
+                     else:
+                         output = self.model.forward(idx_cond)
+
+                     logits = output["logits"][:, -1, :] / max(temperature, 1e-8)
+
+                     # Vectorized repetition penalty: avoids a Python loop over tokens
+                     if use_repetition_penalty:
+                         if seen_mask is None or seen_mask.shape[1] != logits.shape[1]:
+                             seen_mask = torch.zeros(
+                                 logits.shape[0],
+                                 logits.shape[1],
+                                 device=logits.device,
+                                 dtype=torch.bool,
+                             )
+                         seen_ids = generated.clamp_max(logits.shape[1] - 1)
+                         seen_mask.scatter_(1, seen_ids, True)
+                         positive = logits > 0
+                         logits = torch.where(
+                             seen_mask & positive, logits / repetition_penalty, logits
+                         )
+                         logits = torch.where(
+                             seen_mask & ~positive, logits * repetition_penalty, logits
+                         )
+
+                     # Top-k filtering
+                     if top_k > 0:
+                         top_k_val = min(top_k, logits.size(-1))
+                         _, top_indices = torch.topk(logits, top_k_val)
+                         mask = torch.ones_like(logits, dtype=torch.bool)
+                         mask.scatter_(1, top_indices, False)
+                         logits.masked_fill_(mask, float("-inf"))
+
+                     # Top-p (nucleus) filtering
+                     if top_p < 1.0:
+                         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                         cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                         sorted_remove = cumprobs > top_p
+                         sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
+                         sorted_remove[..., 0] = 0
+                         remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
+                         logits[remove] = float("-inf")
+
+                     # NaN/Inf guard (float16 can overflow)
+                     logits = logits.clamp(-100, 100)
+                     logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
+                     probs = F.softmax(logits, dim=-1)
+                     probs = probs.clamp(min=1e-8)
+                     next_token = torch.multinomial(probs, num_samples=1)
+                     generated = torch.cat([generated, next_token], dim=1)
+                     if seen_mask is not None:
+                         seen_mask.scatter_(1, next_token, True)
+                     yield next_token
+         finally:
+             if defer_lazy_warmup and hasattr(self.model, "start_background_warmup"):
+                 self.model.start_background_warmup()
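
Below is a minimal usage sketch for the adapter above. It is illustrative only: the model object, the token IDs, and the import path (cortexnet.adapter.inference_adapter) are assumptions, since the diff does not show file paths.

    # Hypothetical usage of InferenceAdapter; model object and import path are assumed.
    import torch
    from cortexnet.adapter.inference_adapter import InferenceAdapter  # assumed path

    model = ...  # placeholder: any CortexNet model exposing forward()/generate()
    adapter = InferenceAdapter(model, model_type="qwen2")  # picks the qwen2 defaults

    prompt_ids = torch.tensor([[1, 2, 3]])  # placeholder token IDs, shape (1, 3)

    # Non-streaming: returns (batch, seq_len + max_new_tokens) token indices
    out_ids = adapter.generate(prompt_ids, max_new_tokens=64, temperature=0.8)

    # Streaming: yields one (batch, 1) tensor per generated token
    for token in adapter.generate(prompt_ids, stream=True, max_new_tokens=64):
        print(token.item())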
@@ -0,0 +1,378 @@
+ """
+ Model Registry & Auto-Detection
+
+ Core functionality:
+ 1. Automatically detect the model type from a HuggingFace config.json
+ 2. Convert a HuggingFace config into a CortexNetConfig
+ 3. Maintain a registry of supported models
+
+ Supported model families:
+ LLaMA 2/3, Qwen 1.5/2, Mistral, Baichuan 1/2, ChatGLM 3/4,
+ Phi 2/3, Yi, DeepSeek, InternLM, Gemma, CodeLlama, Falcon
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import logging
+ from dataclasses import dataclass
+ from typing import Dict, Any, Optional, List
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ModelInfo:
+     """Metadata for a model family."""
+     family: str                          # family identifier (e.g. "llama")
+     display_name: str                    # human-readable name
+     architectures: List[str]             # matched against the HuggingFace architectures field
+     model_type_patterns: List[str]       # patterns matched against model_type in config.json
+     weight_format: str                   # weight-naming format family ("llama", "qwen", "glm", ...)
+     default_rope_theta: float = 10000.0  # default RoPE theta
+     supports_gqa: bool = False           # whether the family uses GQA
+     default_norm_type: str = "rmsnorm"   # default normalization type
+     position_encoding: str = "rope"      # positional-encoding type
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # Model registry
+ # ═══════════════════════════════════════════════════════════════
+
+ MODEL_REGISTRY: Dict[str, ModelInfo] = {
+     "llama": ModelInfo(
+         family="llama",
+         display_name="LLaMA / LLaMA 2 / LLaMA 3",
+         architectures=["LlamaForCausalLM"],
+         model_type_patterns=["llama"],
+         weight_format="llama",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "qwen2": ModelInfo(
+         family="qwen2",
+         display_name="Qwen 1.5 / Qwen 2",
+         architectures=["Qwen2ForCausalLM", "QWenLMHeadModel"],
+         model_type_patterns=["qwen2", "qwen"],
+         weight_format="qwen",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "qwen3": ModelInfo(
+         family="qwen3",
+         display_name="Qwen 3",
+         architectures=["Qwen3ForCausalLM"],
+         model_type_patterns=["qwen3", "qwen"],
+         weight_format="qwen",
+         default_rope_theta=1000000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "mistral": ModelInfo(
+         family="mistral",
+         display_name="Mistral / Mixtral",
+         architectures=["MistralForCausalLM", "MixtralForCausalLM"],
+         model_type_patterns=["mistral", "mixtral"],
+         weight_format="llama",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "baichuan": ModelInfo(
+         family="baichuan",
+         display_name="Baichuan 1 / Baichuan 2",
+         architectures=["BaichuanForCausalLM", "BaiChuanForCausalLM"],
+         model_type_patterns=["baichuan"],
+         weight_format="baichuan",
+         default_rope_theta=10000.0,
+         supports_gqa=False,
+         default_norm_type="rmsnorm",
+     ),
+     "chatglm": ModelInfo(
+         family="chatglm",
+         display_name="ChatGLM 3 / GLM-4",
+         architectures=["ChatGLMModel", "ChatGLMForConditionalGeneration"],
+         model_type_patterns=["chatglm"],
+         weight_format="glm",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+         position_encoding="rope",
+     ),
+     "phi": ModelInfo(
+         family="phi",
+         display_name="Phi 2 / Phi 3",
+         architectures=["PhiForCausalLM", "Phi3ForCausalLM"],
+         model_type_patterns=["phi", "phi3"],
+         weight_format="phi",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="layernorm",
+     ),
+     "yi": ModelInfo(
+         family="yi",
+         display_name="Yi",
+         architectures=["YiForCausalLM"],
+         model_type_patterns=["yi"],
+         weight_format="llama",
+         default_rope_theta=5000000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "deepseek": ModelInfo(
+         family="deepseek",
+         display_name="DeepSeek / DeepSeek V2",
+         architectures=["DeepSeekForCausalLM", "DeepseekV2ForCausalLM"],
+         model_type_patterns=["deepseek"],
+         weight_format="llama",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "internlm": ModelInfo(
+         family="internlm",
+         display_name="InternLM / InternLM 2",
+         architectures=["InternLMForCausalLM", "InternLM2ForCausalLM"],
+         model_type_patterns=["internlm", "internlm2"],
+         weight_format="llama",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "gemma": ModelInfo(
+         family="gemma",
+         display_name="Gemma / Gemma 2",
+         architectures=["GemmaForCausalLM", "Gemma2ForCausalLM"],
+         model_type_patterns=["gemma", "gemma2"],
+         weight_format="gemma",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+     "falcon": ModelInfo(
+         family="falcon",
+         display_name="Falcon",
+         architectures=["FalconForCausalLM", "RWForCausalLM"],
+         model_type_patterns=["falcon"],
+         weight_format="falcon",
+         default_rope_theta=10000.0,
+         supports_gqa=True,
+         default_norm_type="layernorm",
+     ),
+     # Key spelled "codellama" so the directory-name heuristic in
+     # detect_model_type() can match checkpoints such as "CodeLlama-7b".
+     "codellama": ModelInfo(
+         family="codellama",
+         display_name="Code Llama",
+         architectures=["LlamaForCausalLM"],
+         model_type_patterns=["llama"],
+         weight_format="llama",
+         default_rope_theta=1000000.0,
+         supports_gqa=True,
+         default_norm_type="rmsnorm",
+     ),
+ }
+
+
+ class ModelRegistry:
+     """Model registry manager."""
+
+     @staticmethod
+     def list_supported() -> List[str]:
+         """Return all supported model family names."""
+         return list(MODEL_REGISTRY.keys())
+
+     @staticmethod
+     def get_info(family: str) -> Optional[ModelInfo]:
+         """Return the metadata for the given model family."""
+         return MODEL_REGISTRY.get(family)
+
+     @staticmethod
+     def register(family: str, info: ModelInfo):
+         """Register a new model family."""
+         MODEL_REGISTRY[family] = info
+         logger.info(f"Registered model family: {family} ({info.display_name})")
+
+
+ def _load_hf_config(model_path: str) -> Dict[str, Any]:
+     """Load a HuggingFace config.json."""
+     config_path = os.path.join(model_path, "config.json")
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(
+             f"No config.json found at {model_path}. "
+             "Please provide a valid HuggingFace model directory."
+         )
+     with open(config_path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def detect_model_type(model_path: str) -> str:
+     """Auto-detect the model type from a HuggingFace model directory.
+
+     Detection logic (in priority order):
+     1. Exact match on the architectures field in config.json
+     2. Pattern match on the model_type field in config.json
+     3. Heuristic match on the model directory name
+
+     Args:
+         model_path: path to the model directory
+
+     Returns:
+         Model family identifier (e.g. "llama", "qwen2", "mistral")
+
+     Raises:
+         ValueError: if the model type cannot be detected
+     """
+     hf_config = _load_hf_config(model_path)
+
+     # Strategy 1: exact match on the architectures field
+     architectures = hf_config.get("architectures", [])
+     for family, info in MODEL_REGISTRY.items():
+         for arch in architectures:
+             if arch in info.architectures:
+                 logger.info(f"Detected model type '{family}' via architecture: {arch}")
+                 return family
+
+     # Strategy 2: pattern match on the model_type field
+     model_type = hf_config.get("model_type", "").lower()
+     for family, info in MODEL_REGISTRY.items():
+         for pattern in info.model_type_patterns:
+             if pattern in model_type:
+                 logger.info(f"Detected model type '{family}' via model_type: {model_type}")
+                 return family
+
+     # Strategy 3: directory-name heuristic
+     dir_name = os.path.basename(model_path).lower()
+     for family in MODEL_REGISTRY:
+         if family in dir_name:
+             logger.info(f"Detected model type '{family}' via directory name: {dir_name}")
+             return family
+
+     raise ValueError(
+         f"Cannot detect model type from {model_path}. "
+         f"Supported model families: {list(MODEL_REGISTRY.keys())}. "
+         f"Found architectures={architectures}, model_type='{model_type}'."
+     )
+
+
+ def get_cortexnet_config(model_path: str, model_type: Optional[str] = None):
+     """Convert a HuggingFace config into a CortexNetConfig.
+
+     Extracts all required parameters from the HuggingFace config.json
+     and maps them onto CortexNet's configuration format.
+
+     Args:
+         model_path: path to the model directory
+         model_type: model type (optional; auto-detected when omitted)
+
+     Returns:
+         The resulting CortexNetConfig
+     """
+     try:
+         from ..config import CortexNetConfig
+     except ImportError:
+         # Compatibility with script-style imports: `from adapter.model_registry import ...`
+         from cortexnet.config import CortexNetConfig
+
+     hf_config = _load_hf_config(model_path)
+
+     if model_type is None:
+         model_type = detect_model_type(model_path)
+
+     model_info = MODEL_REGISTRY.get(model_type)
+
+     # Extract common parameters (field names differ between model families)
+     vocab_size = hf_config.get("vocab_size", 32000)
+     hidden_size = hf_config.get("hidden_size", 4096)
+     num_layers = hf_config.get(
+         "num_hidden_layers",
+         hf_config.get("num_layers", 32),
+     )
+     num_heads = hf_config.get(
+         "num_attention_heads",
+         hf_config.get("num_heads", 32),
+     )
+     num_kv_heads = hf_config.get(
+         "num_key_value_heads",
+         hf_config.get("multi_query_group_num", num_heads),
+     )
+     intermediate_size = hf_config.get(
+         "intermediate_size",
+         hf_config.get("ffn_hidden_size", hidden_size * 4),
+     )
+     max_seq_len = hf_config.get(
+         "max_position_embeddings",
+         hf_config.get("seq_length", 8192),
+     )
+     rope_theta = hf_config.get(
+         "rope_theta",
+         model_info.default_rope_theta if model_info else 10000.0,
+     )
+     rope_scaling = hf_config.get("rope_scaling")
+     norm_eps = hf_config.get(
+         "rms_norm_eps",
+         hf_config.get("layer_norm_epsilon", 1e-6),
+     )
+     tie_word_embeddings = hf_config.get("tie_word_embeddings", True)
+     use_qk_norm = bool(model_type in {"qwen2", "qwen3"})
+
+     # Sliding-window attention (Mistral and similar models)
+     sliding_window = hf_config.get("sliding_window", 0) or 0
+
+     # For compatibility with existing large-model inference, do not inflate the
+     # FFN parameter count by default: degrade MoE to a single expert, which is
+     # equivalent to a plain feed-forward layer and keeps the parameter count on
+     # the same order as the source model.
+     num_experts = 1
+     num_active_experts = 1
+     expert_ff_dim = intermediate_size
+
+     # Build the CortexNetConfig
+     config = CortexNetConfig(
+         vocab_size=vocab_size,
+         hidden_size=hidden_size,
+         num_layers=num_layers,
+         num_heads=num_heads,
+         num_kv_heads=num_kv_heads,
+         use_qk_norm=use_qk_norm,
+         max_seq_len=max_seq_len,
+         rope_theta=rope_theta,
+         rope_scaling=rope_scaling,
+         norm_eps=norm_eps,
+         dropout=0.0,  # disable dropout at inference time
+         # SSM parameters scale with model size
+         num_scales=min(4, max(2, hidden_size // 1024)),
+         ssm_state_size=16,
+         ssm_expand_factor=2,
+         # Attention parameters
+         top_k_ratio=0.25,
+         attention_k_mode="ratio",
+         sliding_window_size=sliding_window,
+         # Memory parameters
+         memory_dim=min(128, hidden_size // 4),
+         memory_decay_init=0.95,
+         # MoE parameters (single expert in compatibility mode)
+         expert_ff_dim=expert_ff_dim,
+         num_experts=num_experts,
+         num_active_experts=num_active_experts,
+         moe_aux_loss_weight=0.02,
+         # Adapter metadata
+         model_type=model_type,
+         source_model_path=model_path,
+         intermediate_size=intermediate_size,
+         tie_word_embeddings=tie_word_embeddings,
+         # Lite mode (default): SSM + Attention + Memory + FFN, parameter-efficient
+         compatibility_mode=False,
+         lite=True,
+         # GQA models (Qwen, LLaMA, ...) keep their native KV head count; no expansion to full heads
+         expand_gqa_weights=False,
+         # Low-rank SSM parameterization to limit the extra parameter overhead
+         compat_ssm_rank=min(256, max(64, hidden_size // 16)),
+     )
+
+     logger.info(
+         f"Created CortexNetConfig for {model_type}: "
+         f"hidden={hidden_size}, layers={num_layers}, heads={num_heads}, "
+         f"kv_heads={num_kv_heads}, vocab={vocab_size}"
+     )
+
+     return config
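
A short sketch of the registry workflow above. The checkpoint path, the import path (cortexnet.adapter.model_registry), and the custom "mymodel" family are placeholders invented for illustration, not part of this diff.

    # Hypothetical usage of the registry helpers; paths and names are assumed.
    from cortexnet.adapter.model_registry import (  # assumed path
        ModelInfo, ModelRegistry, detect_model_type, get_cortexnet_config,
    )

    family = detect_model_type("/models/Qwen2-7B-Instruct")     # -> "qwen2"
    config = get_cortexnet_config("/models/Qwen2-7B-Instruct")  # HF config -> CortexNetConfig

    # Registering a custom family makes it visible to all three detection strategies.
    ModelRegistry.register("mymodel", ModelInfo(
        family="mymodel",
        display_name="MyModel",
        architectures=["MyModelForCausalLM"],   # matched against config.json architectures
        model_type_patterns=["mymodel"],        # matched against config.json model_type
        weight_format="llama",                  # reuse the llama weight-naming scheme
    ))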