cortexnet-3.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/config.py ADDED
@@ -0,0 +1,234 @@
+ """
+ CortexNet configuration module.
+
+ Defines all hyperparameters for the CortexNet architecture and training.
+ Uses dataclasses for type-safe configuration management.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass
+ from typing import Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class CortexNetConfig:
+     """CortexNet architecture configuration.
+
+     Controls the model's size, structure, and behavior. All CortexNet modules share this configuration.
+
+     Attributes:
+         vocab_size: Vocabulary size
+         hidden_size: Hidden dimension (d_model)
+         num_layers: Number of CortexBlock layers
+         num_heads: Number of attention heads
+         num_scales: Number of SSM scales
+         ssm_state_size: SSM state dimension
+         ssm_expand_factor: SSM internal expansion factor
+         top_k_ratio: Sparse-attention top-k ratio
+         attention_k_mode: Top-k computation mode ("ratio", "sqrt", "log", "fixed")
+         max_seq_len: Maximum sequence length
+         rope_theta: RoPE frequency base
+         dropout: Dropout rate
+         memory_dim: Memory module dimension
+         memory_decay_init: Initial memory decay value
+         expert_ff_dim: Expert FFN intermediate dimension
+         num_experts: Total number of experts
+         num_active_experts: Experts activated per token
+         moe_aux_loss_weight: MoE auxiliary loss weight
+         moe_capacity_factor: MoE capacity factor
+         norm_eps: Normalization-layer epsilon
+     """
+
+     # ═══ Core architecture ═══
+     vocab_size: int = 32000
+     hidden_size: int = 512
+     num_layers: int = 4
+     num_heads: int = 8
+     max_seq_len: int = 8192
+     dropout: float = 0.0
+     norm_eps: float = 1e-6
+
+     # ═══ SSM ═══
+     num_scales: int = 4
+     ssm_state_size: int = 16
+     ssm_expand_factor: int = 2
+
+     # ═══ Attention ═══
+     top_k_ratio: float = 0.25
+     attention_k_mode: str = "ratio"
+     rope_theta: float = 10000.0
+     sliding_window_size: int = 0
+
+     # ═══ Memory ═══
+     memory_dim: int = 64
+     memory_decay_init: float = 0.95
+     episodic_slots: int = 32
+     semantic_slots: int = 64
+
+     # ═══ MoE ═══
+     expert_ff_dim: int = 1024
+     num_experts: int = 8
+     num_active_experts: int = 2
+     moe_aux_loss_weight: float = 0.02
+     moe_capacity_factor: float = 1.25  # MoE routing capacity factor
+
+     # ═══ V2/V3 extensions ═══
+     graph_neighbors: int = 16
+     graph_iterations: int = 2
+     num_task_modes: int = 4
+     num_counterfactuals: int = 4
+     num_agents: int = 4
+     use_gradient_checkpointing: bool = False
+     use_mixture_of_depths: bool = False
+     mod_capacity: float = 0.5
+     causal_top_k_ratio: float = 0.25  # top-k ratio for causal-reasoning intervention attention
+
+     # ═══ Adapter parameters (new: for loading open-source models) ═══
+     model_type: str = "cortexnet"  # source model type identifier
+     source_model_path: Optional[str] = None  # HuggingFace model path
+     auto_calibrate: bool = True  # auto-calibrate after loading
+     intermediate_size: int = 0  # source model intermediate dimension (if any)
+     num_kv_heads: int = 0  # number of KV heads for GQA (0 = same as num_heads)
+     use_qk_norm: bool = False  # enable per-head Q/K RMSNorm in attention (Qwen2/3)
+     rope_scaling: Optional[dict] = None  # RoPE scaling configuration
+     tie_word_embeddings: bool = True  # tie input embeddings and output weights
+     compatibility_mode: bool = False  # lightweight V3 path compatible with large models
+     lite: bool = True  # Lite mode: SSM+Attention+Memory+FFN only (60-70% fewer parameters)
+     ssm_decode_after: int = 0  # SSM-only decoding threshold (>0: decode with SSM alone past this many tokens)
+     expand_gqa_weights: bool = True  # expand KV projections to all heads (legacy-structure compatibility)
+     compat_ssm_rank: int = 256  # low-rank SSM dimension in compatibility mode (controls added parameter count)
+     fusion_long_context_threshold: int = 2048  # sequence length above which the minimum SSM ratio applies
+     fusion_long_context_ssm_ratio: float = 0.35  # minimum SSM ratio for long sequences
+     mapped_cache_enabled: bool = False  # cache mapped weights (optional, off by default)
+     mapped_cache_dir: Optional[str] = None  # mapped-cache directory (None = default directory)
+     mapped_cache_force_refresh: bool = False  # ignore the cache and rebuild
+     mapped_cache_auto_enable_with_lazy: bool = True  # auto-enable the mapped cache with lazy_device_load
+     mapped_cache_fast_init_on_hit: bool = True  # skip extra custom re-initialization on a mapped-cache hit
+     lazy_device_load: bool = False  # lazy device placement (from_pretrained returns quickly)
+     lazy_cpu_fallback: bool = True  # allow CPU fallback inference during the lazy phase
+     lazy_background_warmup: bool = True  # warm up to the target device in the background after the first inference
+     lazy_start_warmup_on_load: bool = True  # start the background warmup thread before from_pretrained returns
+     lazy_disable_on_cache_hit: bool = True  # disable lazy mode on a mapped-cache hit, favoring first-token latency
+
+     def __post_init__(self):
+         # ═══ Derived defaults ═══
+         if self.num_kv_heads == 0:
+             self.num_kv_heads = self.num_heads
+         if self.intermediate_size == 0:
+             self.intermediate_size = self.expert_ff_dim
+
+         # ═══ Parameter validation ═══
+         self._validate()
+
+     def _validate(self):
+         """Comprehensive parameter validation to ensure the config is legal and catch subtle errors."""
+         # Basic dimension checks
+         if self.hidden_size % self.num_heads != 0:
+             raise ValueError(
+                 f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
+             )
+         if self.num_kv_heads > self.num_heads:
+             raise ValueError(
+                 f"num_kv_heads ({self.num_kv_heads}) cannot exceed num_heads ({self.num_heads})"
+             )
+         if self.num_heads % self.num_kv_heads != 0:
+             raise ValueError(
+                 f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}) (required for GQA)"
+             )
+
+         # MoE checks
+         if self.num_active_experts > self.num_experts:
+             raise ValueError(
+                 f"num_active_experts ({self.num_active_experts}) cannot exceed num_experts ({self.num_experts})"
+             )
+         if self.moe_capacity_factor <= 0:
+             raise ValueError(
+                 f"moe_capacity_factor ({self.moe_capacity_factor}) must be positive"
+             )
+
+         # Attention checks
+         if not (0 < self.top_k_ratio <= 1.0):
+             raise ValueError(
+                 f"top_k_ratio ({self.top_k_ratio}) must be in (0, 1.0]"
+             )
+         _valid_k_modes = {"ratio", "sqrt", "log", "fixed"}
+         if self.attention_k_mode not in _valid_k_modes:
+             raise ValueError(
+                 f"attention_k_mode ('{self.attention_k_mode}') must be one of {_valid_k_modes}"
+             )
+
+         # Causal-reasoning checks
+         if not (0 < self.causal_top_k_ratio <= 1.0):
+             raise ValueError(
+                 f"causal_top_k_ratio ({self.causal_top_k_ratio}) must be in (0, 1.0]"
+             )
+
+         # Positivity checks
+         if self.hidden_size <= 0:
+             raise ValueError(f"hidden_size ({self.hidden_size}) must be a positive integer")
+         if self.num_layers <= 0:
+             raise ValueError(f"num_layers ({self.num_layers}) must be a positive integer")
+         if self.max_seq_len <= 0:
+             raise ValueError(f"max_seq_len ({self.max_seq_len}) must be a positive integer")
+         if self.vocab_size <= 0:
+             raise ValueError(f"vocab_size ({self.vocab_size}) must be a positive integer")
+
+         # Range checks
+         if not (0 <= self.dropout < 1.0):
+             raise ValueError(
+                 f"dropout ({self.dropout}) must be in [0, 1.0)"
+             )
+
+         # Fusion-ratio check
+         if self.fusion_long_context_ssm_ratio < 0 or self.fusion_long_context_ssm_ratio > 1:
+             raise ValueError(
+                 f"fusion_long_context_ssm_ratio ({self.fusion_long_context_ssm_ratio}) must be in [0, 1]"
+             )
+
+         # Soft warnings
+         if self.hidden_size < 64:
+             logger.warning(f"hidden_size={self.hidden_size} is very small and may limit model capacity")
+         if self.num_experts > 1 and self.num_active_experts < 1:
+             logger.warning("num_active_experts < 1; MoE routing may be a no-op")
+
+     @classmethod
+     def from_dict(cls, d: dict) -> "CortexNetConfig":
+         """Create a config from a dict, ignoring unknown fields."""
+         valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
+         filtered = {k: v for k, v in d.items() if k in valid_fields}
+         return cls(**filtered)
+
+
+ @dataclass
+ class TrainingConfig:
+     """Training hyperparameter configuration.
+
+     Attributes:
+         learning_rate: Learning rate
+         weight_decay: Weight decay
+         num_epochs: Number of training epochs
+         batch_size: Batch size
+         gradient_accumulation_steps: Gradient accumulation steps
+         max_grad_norm: Maximum gradient norm for clipping
+         warmup_steps: Learning-rate warmup steps
+         eval_interval: Evaluation interval in steps
+         save_interval: Checkpoint interval in steps
+         mixed_precision: Mixed-precision mode ("no", "fp16", "bf16")
+     """
+
+     learning_rate: float = 3e-4
+     weight_decay: float = 0.01
+     num_epochs: int = 3
+     batch_size: int = 8
+     gradient_accumulation_steps: int = 1
+     max_grad_norm: float = 1.0
+     warmup_steps: int = 100
+     eval_interval: int = 500
+     save_interval: int = 1000
+     mixed_precision: str = "no"
+     seed: int = 42
+     log_interval: int = 10
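
For orientation, here is a minimal usage sketch of the two dataclasses above (an illustration, not part of the package diff; the import path follows the file name cortexnet/config.py):

    from cortexnet.config import CortexNetConfig, TrainingConfig

    # from_dict() silently drops unknown keys, so saved configs survive schema drift.
    cfg = CortexNetConfig.from_dict(
        {"hidden_size": 768, "num_heads": 12, "unknown_key": 1}
    )
    assert cfg.num_kv_heads == cfg.num_heads           # derived in __post_init__
    assert cfg.intermediate_size == cfg.expert_ff_dim  # derived in __post_init__

    # Validation runs at construction time via __post_init__ -> _validate().
    try:
        CortexNetConfig(hidden_size=100, num_heads=8)
    except ValueError as err:
        print(err)  # hidden_size (100) must be divisible by num_heads (8)

    train_cfg = TrainingConfig(mixed_precision="bf16", batch_size=16)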
@@ -0,0 +1,256 @@
+ """
+ Continual learning module.
+
+ Core innovation:
+ Addresses "catastrophic forgetting", where a neural network learning a new
+ task loses knowledge of old ones. Implements two complementary
+ anti-forgetting strategies:
+
+ 1. Elastic Weight Consolidation (EWC)
+    - Computes the Fisher information matrix to identify parameters important to old tasks
+    - Penalizes large changes to those important parameters while training a new task
+    - Effect: old knowledge is "protected"; new knowledge is learned in the unimportant regions of parameter space
+
+ 2. Progressive Memory Replay
+    - Maintains an experience buffer storing key samples from old tasks
+    - Mixes old samples back in for replay while training a new task
+    - Effect: periodic "review" of old knowledge prevents forgetting
+
+ Analogy to human learning:
+ - EWC ≈ the brain's mechanism for protecting important neural connections
+ - Replay ≈ sleep consolidation and review in humans
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Dict, Optional
+
+
+ class ElasticWeightConsolidation:
+     """Elastic Weight Consolidation (EWC).
+
+     Estimates each parameter's importance to previously learned tasks via the
+     Fisher information matrix, and adds a regularization term while learning
+     a new task to protect important parameters.
+
+     Math:
+         L_total = L_new_task + λ · Σ_i F_i · (θ_i - θ*_i)²
+
+     where F_i is the Fisher information (importance) of parameter i and
+     θ*_i is the optimal value of that parameter on the old task.
+
+     Args:
+         model: CortexNet model
+         lambda_ewc: EWC regularization strength
+     """
+
+     def __init__(self, model: nn.Module, lambda_ewc: float = 1000.0):
+         self.model = model
+         self.lambda_ewc = lambda_ewc
+         self.fisher: Dict[str, torch.Tensor] = {}
+         self.optimal_params: Dict[str, torch.Tensor] = {}
+         self._consolidated = False
+         self._num_consolidations = 0
+
+     def consolidate(self, dataloader, num_samples: int = 200):
+         """Compute Fisher information and snapshot the current optimal parameters.
+
+         Call after finishing training on a task to "consolidate" its knowledge.
+
+         Args:
+             dataloader: data loader for the current task
+             num_samples: number of batches used to estimate the Fisher information
+         """
+         self.model.train()  # gradient computation is required
+         fisher = {
+             n: torch.zeros_like(p)
+             for n, p in self.model.named_parameters()
+             if p.requires_grad
+         }
+
+         count = 0
+         for batch in dataloader:
+             if count >= num_samples:
+                 break
+
+             input_ids, labels = batch
+             input_ids = input_ids.to(
+                 next(self.model.parameters()).device
+             )
+             labels = labels.to(next(self.model.parameters()).device)
+
+             self.model.zero_grad()
+             output = self.model(input_ids, labels=labels)
+
+             # Squared gradients of the log-likelihood approximate the Fisher
+             output["loss"].backward()
+
+             for n, p in self.model.named_parameters():
+                 if p.requires_grad and p.grad is not None:
+                     fisher[n] += p.grad.pow(2)
+
+             count += 1  # NOTE: counts batches, not individual samples
+
+         # Average the Fisher information
+         for n in fisher:
+             fisher[n] /= max(count, 1)
+
+         # If consolidated before, update online (accumulate Fisher)
+         if self._consolidated:
+             for n in fisher:
+                 if n in self.fisher:
+                     fisher[n] = (
+                         self.fisher[n] * self._num_consolidations
+                         + fisher[n]
+                     ) / (self._num_consolidations + 1)
+
+         self.fisher = fisher
+         self.optimal_params = {
+             n: p.clone().detach()
+             for n, p in self.model.named_parameters()
+             if p.requires_grad
+         }
+         self._consolidated = True
+         self._num_consolidations += 1
+
+     def penalty(self) -> torch.Tensor:
+         """Compute the EWC regularization penalty.
+
+         Added to the loss function while training a new task.
+
+         Returns:
+             penalty: scalar tensor
+         """
+         if not self._consolidated:
+             device = next(self.model.parameters()).device
+             return torch.tensor(0.0, device=device)
+
+         loss = torch.tensor(
+             0.0, device=next(self.model.parameters()).device
+         )
+         for n, p in self.model.named_parameters():
+             if n in self.fisher and p.requires_grad:
+                 loss = loss + (
+                     self.fisher[n] * (p - self.optimal_params[n]).pow(2)
+                 ).sum()
+
+         return self.lambda_ewc * loss
+
+
+ class ProgressiveMemoryReplay:
+     """Progressive memory replay.
+
+     Maintains an experience buffer and mixes old samples in while training a
+     new task. Uses reservoir sampling so the buffer covers all old tasks uniformly.
+
+     Args:
+         buffer_size: maximum number of samples in the buffer
+         replay_ratio: fraction of each batch drawn from old samples
+     """
+
+     def __init__(self, buffer_size: int = 5000, replay_ratio: float = 0.3):
+         self.buffer_size = buffer_size
+         self.replay_ratio = replay_ratio
+         self.buffer_inputs = []
+         self.buffer_labels = []
+         self._count = 0
+
+     def add_samples(
+         self, input_ids: torch.Tensor, labels: torch.Tensor
+     ):
+         """Add samples to the buffer (via reservoir sampling)."""
+         batch_size = input_ids.shape[0]
+
+         for i in range(batch_size):
+             if len(self.buffer_inputs) < self.buffer_size:
+                 self.buffer_inputs.append(input_ids[i].cpu())
+                 self.buffer_labels.append(labels[i].cpu())
+             else:
+                 # Reservoir sampling: keep the new item with probability
+                 # buffer_size / (number of items seen so far)
+                 idx = torch.randint(0, self._count + 1, (1,)).item()
+                 if idx < self.buffer_size:
+                     self.buffer_inputs[idx] = input_ids[i].cpu()
+                     self.buffer_labels[idx] = labels[i].cpu()
+             self._count += 1
+
+     def get_replay_batch(
+         self, batch_size: int, device: torch.device
+     ) -> Optional[tuple]:
+         """Sample a replay batch."""
+         if len(self.buffer_inputs) == 0:
+             return None
+
+         num_replay = max(1, int(batch_size * self.replay_ratio))
+         num_replay = min(num_replay, len(self.buffer_inputs))
+
+         indices = torch.randperm(len(self.buffer_inputs))[:num_replay]
+
+         replay_inputs = torch.stack(
+             [self.buffer_inputs[i] for i in indices]
+         ).to(device)
+         replay_labels = torch.stack(
+             [self.buffer_labels[i] for i in indices]
+         ).to(device)
+
+         return replay_inputs, replay_labels
+
+     @property
+     def size(self) -> int:
+         return len(self.buffer_inputs)
+
+
+ class ContinualLearningManager:
+     """Continual-learning manager: combines EWC with memory replay.
+
+     Provides a unified interface over the continual-learning components.
+
+     Usage:
+         manager = ContinualLearningManager(model)
+
+         # Train task 1
+         for batch in task1_loader:
+             loss = model(batch)
+             loss = loss + manager.get_regularization_loss()
+             loss.backward()
+
+         # Consolidate task 1's knowledge
+         manager.consolidate_task(task1_loader)
+
+         # Train task 2 (task 1's knowledge is protected automatically)
+         for batch in task2_loader:
+             loss = model(batch)
+             loss = loss + manager.get_regularization_loss()
+             replay = manager.get_replay_batch(batch_size, device)
+             if replay is not None:
+                 replay_loss = model(replay)
+                 loss = loss + replay_loss * 0.5
+             loss.backward()
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         lambda_ewc: float = 1000.0,
+         buffer_size: int = 5000,
+         replay_ratio: float = 0.3,
+     ):
+         self.ewc = ElasticWeightConsolidation(model, lambda_ewc)
+         self.replay = ProgressiveMemoryReplay(buffer_size, replay_ratio)
+         self.task_count = 0
+
+     def consolidate_task(self, dataloader, num_samples: int = 200):
+         """Consolidate the current task's knowledge."""
+         self.ewc.consolidate(dataloader, num_samples)
+         self.task_count += 1
+
+     def get_regularization_loss(self) -> torch.Tensor:
+         """Anti-forgetting regularization loss."""
+         return self.ewc.penalty()
+
+     def add_experience(
+         self, input_ids: torch.Tensor, labels: torch.Tensor
+     ):
+         """Record training samples into the experience buffer."""
+         self.replay.add_samples(input_ids, labels)
+
+     def get_replay_batch(self, batch_size: int, device: torch.device):
+         """Fetch replay samples."""
+         return self.replay.get_replay_batch(batch_size, device)
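
To tie the pieces together, here is a minimal continual-learning loop built on the classes above (a sketch, not package code: model, optimizer, task1_loader, and task2_loader are assumed to exist, and the model is assumed to return a dict with a "loss" key, as ElasticWeightConsolidation.consolidate() expects):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    manager = ContinualLearningManager(model, lambda_ewc=1000.0, replay_ratio=0.3)

    def train_task(loader):
        for input_ids, labels in loader:
            input_ids, labels = input_ids.to(device), labels.to(device)
            loss = model(input_ids, labels=labels)["loss"]
            # EWC penalty is zero until the first consolidate_task() call.
            loss = loss + manager.get_regularization_loss()

            replay = manager.get_replay_batch(input_ids.shape[0], device)
            if replay is not None:  # mix old-task samples back in
                r_inputs, r_labels = replay
                loss = loss + 0.5 * model(r_inputs, labels=r_labels)["loss"]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            manager.add_experience(input_ids, labels)  # feed the reservoir buffer

    train_task(task1_loader)
    manager.consolidate_task(task1_loader)  # snapshot Fisher + optimal params
    train_task(task2_loader)                # task 1 knowledge is now protected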