cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
cortexnet/ssm.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""
Multi-Scale State Space Module (MSSM).

Core idea:
    Several SSM channel groups run in parallel at different time scales, so the
    model can capture patterns ranging from local token interactions to
    long-range dependencies at the same time.

    - Fast-decaying scales: local, fine-grained patterns (e.g. syntax).
    - Slow-decaying scales: long-range, global patterns (e.g. topic, context).

Complexity: O(n), linear in sequence length (the chunked scan only adds a
log2(chunk_size) factor inside each fixed-size chunk).

Background:
    Continuous-time state space dynamics:
        dh/dt = A·h + B·x      (state equation)
        y     = C·h            (observation equation)

    Discretization as implemented here (Mamba-style simplified ZOH):
        h_t = Ā·h_{t-1} + B̄·x_t,   Ā = exp(Δ·A),   B̄ ≈ Δ·B
    Exact ZOH would use B̄ = (Δ·A)^{-1}(exp(Δ·A) - I)·Δ·B; the code uses the
    first-order approximation Δ·B.

    Scales differ only through the initialization of A: for scale i the decay
    magnitudes are the base values 1..state_size multiplied by 2^i, so
    higher-index scales decay faster per step (shorter memory).

Changes in v3.2:
    - A Triton scan entry point (``MultiScaleSSM._triton_scan``) that is
      selected when Triton is importable and the input lives on CUDA. It has
      no native kernel yet and delegates to the chunked PyTorch scan.
    - logging support.
"""

from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
import math
|
|
34
|
+
import logging
|
|
35
|
+
from typing import Optional, Tuple
|
|
36
|
+
|
|
37
|
+
import torch
|
|
38
|
+
import torch.nn as nn
|
|
39
|
+
import torch.nn.functional as F
|
|
40
|
+
|
|
41
|
+
logger = logging.getLogger(__name__)

# Triton availability probe. Note: MultiScaleSSM._triton_scan has no native
# kernel yet and delegates to the chunked PyTorch scan, so this flag currently
# only selects a code path. Logged at DEBUG so importing the library stays quiet.
_TRITON_AVAILABLE = False
try:
    import triton
    import triton.language as tl

    _ = (triton, tl)  # keep linters from flagging the probe imports as unused
    _TRITON_AVAILABLE = True
    logger.debug(
        "Triton available: SSM Triton scan path enabled "
        "(currently delegates to the chunked PyTorch scan)"
    )
except ImportError:
    pass
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class MultiScaleSSM(nn.Module):
    """Multi-scale selective state space module.

    Each scale runs at a different temporal resolution, controlled by how its
    slice of the A matrix is initialized: fast-decaying scales capture local
    patterns, slow-decaying scales capture global dependencies.

    Architecture:
        Input -> Linear(d, 2*d_inner) -> split into [x, z]
        x -> selective scan (multi-scale A) -> y
        z -> SiLU -> gate
        y * gate -> Linear(d_inner, d) -> output

    Args:
        d_model: Input/output dimension.
        num_scales: Number of time scales (each with its own memory decay
            rate). ``d_model * expand_factor`` must be divisible by it.
        state_size: Dimension N of the SSM state per inner channel.
        expand_factor: Inner dimension expansion factor
            (``d_inner = d_model * expand_factor``).
    """

    def __init__(
        self,
        d_model: int,
        num_scales: int = 4,
        state_size: int = 16,
        expand_factor: int = 2,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_model * expand_factor
        self.num_scales = num_scales
        self.state_size = state_size
        self.d_per_scale = self.d_inner // num_scales

        assert self.d_inner % num_scales == 0, (
            f"d_inner ({self.d_inner}) 必须能被 num_scales ({num_scales}) 整除"
        )

        # Input projection: x feeds the SSM, z feeds the output gate.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # A is stored in log space; forward uses A = -exp(A_log), which is
        # always negative, so every discretized decay exp(dt * A) is in (0, 1].
        # Scale i uses magnitudes (1..state_size) * 2**i, i.e. higher-index
        # scales decay faster and therefore have shorter memory.
        A_parts = []
        for i in range(num_scales):
            base_freqs = torch.arange(1, state_size + 1, dtype=torch.float32)
            scale_factor = 2.0 ** i
            A_parts.append(
                (base_freqs * scale_factor).unsqueeze(0).expand(self.d_per_scale, -1)
            )
        A = torch.cat(A_parts, dim=0)  # (d_inner, state_size)
        self.A_log = nn.Parameter(torch.log(A))

        # Input-dependent SSM parameters (the core of the selective mechanism).
        self.B_proj = nn.Linear(self.d_inner, state_size, bias=False)
        self.C_proj = nn.Linear(self.d_inner, state_size, bias=False)

        # Input-dependent discretization step, letting the model choose per
        # token how much to remember or forget.
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner, bias=True)

        # Initialize the dt bias with log(dt0), dt0 log-uniform in [0.001, 0.1].
        # softplus(log(dt0)) = log(1 + dt0) ~= dt0, so the bias alone puts the
        # initial step sizes roughly in that range.
        dt_init = torch.exp(
            torch.rand(self.d_inner) * (math.log(0.1) - math.log(0.001))
            + math.log(0.001)
        )
        with torch.no_grad():
            self.dt_proj.bias.copy_(dt_init.log())

        # D skip connection (direct path, similar to Mamba).
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Scalar output gate; zero init gives sigmoid(0) = 0.5 at the start.
        self.output_gate = nn.Parameter(torch.zeros(1))

        # Output projection.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        """Run the multi-scale selective SSM.

        Args:
            x: (batch, seq_len, d_model) input.
            past_state: SSM hidden state (batch, d_inner, N) from the previous
                call, used for incremental decoding; ``None`` starts from zeros.
            use_cache: If True, also return the final hidden state.

        Returns:
            output: (batch, seq_len, d_model).
            new_state: (batch, d_inner, N); returned as the second element of
                a tuple whenever ``use_cache`` is True.
        """
        seq_len = x.shape[1]
        input_dtype = x.dtype

        # Input projection, then split into the SSM branch and the gate branch.
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x_ssm, z = xz.chunk(2, dim=-1)  # each (B, L, d_inner)

        # A stays in float32: every scan path computes in float32 anyway, so
        # rounding it to a low-precision input dtype first only loses accuracy.
        A = -torch.exp(self.A_log.float())  # (d_inner, N)
        B_mat = self.B_proj(x_ssm)  # (B, L, N)
        C_mat = self.C_proj(x_ssm)  # (B, L, N)
        dt = F.softplus(self.dt_proj(x_ssm))  # (B, L, d_inner), strictly positive

        # Scan dispatch: multi-token inputs use the Triton entry point (CUDA)
        # or the chunked parallel scan; single tokens use the plain recurrence.
        if seq_len > 1:
            if _TRITON_AVAILABLE and x_ssm.is_cuda:
                y, new_state = self._triton_scan(
                    x_ssm, A, B_mat, C_mat, dt,
                    past_state=past_state,
                    use_cache=use_cache,
                )
            else:
                y, new_state = self._chunk_parallel_scan(
                    x_ssm, A, B_mat, C_mat, dt,
                    chunk_size=min(max(16, seq_len), 64),
                    past_state=past_state,
                    use_cache=use_cache,
                )
        else:
            y, new_state = self._selective_scan(
                x_ssm, A, B_mat, C_mat, dt,
                past_state=past_state,
                use_cache=use_cache,
            )

        # Make sure y is in the input dtype (the scans compute in float32).
        y = y.to(input_dtype)

        # D skip connection.
        y = y + x_ssm * self.D.to(input_dtype).unsqueeze(0).unsqueeze(0)

        # Gated output.
        y = y * F.silu(z)
        out = self.out_proj(y) * torch.sigmoid(self.output_gate.to(input_dtype))

        if use_cache and new_state is not None:
            return out, new_state
        return out

    def _selective_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Selective scan via the plain sequential recurrence (pure PyTorch).

        For each step t (computed in float32 to avoid float16 overflow):
            h_t = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * x_t
            y_t = sum_n h_t[..., n] * C_t[n]
        starting from ``past_state`` (or zeros) for incremental decoding.

        Args:
            x: (B, L, d_inner) scan input.
            A: (d_inner, N) continuous-time state matrix.
            B, C: (B, L, N) input-dependent projections.
            dt: (B, L, d_inner) step sizes.
            past_state: Optional (B, d_inner, N) initial state.
            use_cache: Whether to return the final state.

        Returns:
            y: (B, L, d_inner) output in ``x``'s dtype.
            new_state: (B, d_inner, N) final state if ``use_cache`` else None.
        """
        batch, L, d_inner = x.shape
        N = A.shape[1]
        orig_dtype = x.dtype

        x = x.float()
        A = A.float()
        B = B.float()
        C = C.float()
        dt = dt.float()

        # Discretized parameters; the clamp guards exp() against overflow.
        A_bar = torch.exp((dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)).clamp(max=20))
        B_bar = dt.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, d_inner, N)

        h = (
            past_state.float()
            if past_state is not None
            else torch.zeros(batch, d_inner, N, device=x.device, dtype=torch.float32)
        )
        outputs = []
        for t in range(L):
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t].unsqueeze(-1)
            outputs.append((h * C[:, t].unsqueeze(1)).sum(-1))

        # An empty sequence yields an empty output instead of failing in stack().
        y = torch.stack(outputs, dim=1) if outputs else x.new_zeros(batch, 0, d_inner)
        y = y.to(orig_dtype)
        new_state = h.to(orig_dtype) if use_cache else None
        return y, new_state

    @staticmethod
    def _inclusive_linear_scan(
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parallel (Hillis-Steele) inclusive scan of h_t = a_t * h_{t-1} + b_t.

        Scans along dim 1 from a zero initial state in ceil(log2(T)) doubling
        steps. Only products and sums are used (no division or log), so the
        result stays finite and accurate when decay products underflow to 0.

        Returns:
            a_cum: Cumulative products, a_cum[:, t] = a[:, 0] * ... * a[:, t].
            h: Scanned states reached from a zero initial state.
        """
        length = a.shape[1]
        offset = 1
        while offset < length:
            # Compose each element with the one ``offset`` steps earlier:
            # (a1, b1) followed by (a2, b2) == (a2 * a1, a2 * b1 + b2).
            # Both updates read the previous iteration's ``a``.
            b = torch.cat(
                (b[:, :offset], a[:, offset:] * b[:, :-offset] + b[:, offset:]),
                dim=1,
            )
            a = torch.cat((a[:, :offset], a[:, offset:] * a[:, :-offset]), dim=1)
            offset *= 2
        return a, b

    def _chunk_parallel_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        chunk_size: int = 64,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Chunked scan: parallel within each chunk, sequential across chunks.

        Computes the same recurrence as ``_selective_scan`` (up to float
        rounding) and supports ``past_state`` for incremental decoding.

        Inside a chunk a division-free associative scan is used. A formulation
        that divides by the cumulative decay product breaks once that product
        underflows (fast-decaying channels drop below 1e-8 within a few steps)
        and loses the current input's contribution.

        Args / Returns: as for ``_selective_scan``; ``chunk_size`` is the
        number of time steps scanned in parallel (values < 1 are treated as 1).
        """
        batch, L, d_inner = x.shape
        N = A.shape[1]
        orig_dtype = x.dtype
        chunk_size = max(1, int(chunk_size))

        # Compute in float32 to avoid float16 overflow.
        x = x.float()
        A = A.float()
        B = B.float()
        C = C.float()
        dt = dt.float()

        A_bar = torch.exp((dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)).clamp(max=20))
        B_bar = dt.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, d_inner, N)

        h = (
            past_state.float()
            if past_state is not None
            else torch.zeros(batch, d_inner, N, device=x.device, dtype=torch.float32)
        )
        all_outputs = []

        for start in range(0, L, chunk_size):
            end = min(start + chunk_size, L)
            a_cum, h_from_zero = self._inclusive_linear_scan(
                A_bar[:, start:end],
                B_bar[:, start:end] * x[:, start:end].unsqueeze(-1),
            )
            # Add the state carried in from previous chunks, decayed by the
            # cumulative product of this chunk's decays.
            h_all = h_from_zero + a_cum * h.unsqueeze(1)  # (B, T, d_inner, N)
            all_outputs.append((h_all * C[:, start:end].unsqueeze(2)).sum(-1))
            h = h_all[:, -1]

        y = torch.cat(all_outputs, dim=1) if all_outputs else x.new_zeros(batch, 0, d_inner)
        y = y.to(orig_dtype)
        new_state = h.to(orig_dtype) if use_cache else None
        return y, new_state

    def _triton_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Triton-accelerated selective scan entry point.

        Selected by ``forward`` when Triton is importable and the input is on
        CUDA. The intended design fuses the chunked scan's inner and outer
        loops into one Triton kernel, avoiding repeated kernel launches and
        intermediate allocations.

        This is currently an interface placeholder that delegates to
        ``_chunk_parallel_scan``.
        TODO: implement the native Triton kernel body.
        """
        logger.debug("Using Triton scan path (delegating to chunk_parallel)")
        return self._chunk_parallel_scan(
            x, A, B, C, dt,
            chunk_size=min(max(16, x.shape[1]), 64),
            past_state=past_state,
            use_cache=use_cache,
        )
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet 训练工具 (Training Utilities)
|
|
3
|
+
|
|
4
|
+
提供梯度监控、种子设置、设备选择等训练辅助功能。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import random
|
|
10
|
+
from typing import Dict
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GradientMonitor:
    """Gradient monitor: records per-parameter gradient statistics in training.

    Useful for diagnosing vanishing/exploding gradients and tuning the
    learning rate.

    Usage:
        monitor = GradientMonitor(model)
        loss.backward()
        stats = monitor.get_stats()

    The registered hooks hold only a weak reference to the monitor, so
    dropping the last reference to the monitor frees it and removes its hooks
    (via ``__del__``) instead of keeping it alive through the model.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self._hooks = []
        self._grad_stats: Dict[str, Dict[str, float]] = {}
        self._register_hooks()

    def _register_hooks(self):
        """Register a gradient hook on every trainable parameter of the model."""
        import weakref

        monitor_ref = weakref.ref(self)

        def make_hook(param_name: str):
            # The closure captures only the weak reference, never ``self``.
            def hook(grad: torch.Tensor) -> None:
                monitor = monitor_ref()
                if monitor is not None:
                    monitor._record_grad(param_name, grad)

            return hook

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self._hooks.append(param.register_hook(make_hook(name)))

    def _record_grad(self, name: str, grad: torch.Tensor):
        """Store mean/std/max/min/norm of ``grad`` under ``name``."""
        g = grad.detach()
        # Half-precision reductions (especially norm) overflow easily.
        if g.dtype in (torch.float16, torch.bfloat16):
            g = g.float()
        # The std of a single element is undefined (torch returns NaN).
        std = g.std() if g.numel() > 1 else g.new_zeros(())
        # One host transfer instead of five separate .item() synchronizations.
        values = torch.stack([g.mean(), std, g.max(), g.min(), g.norm()]).tolist()
        self._grad_stats[name] = dict(zip(("mean", "std", "max", "min", "norm"), values))

    def get_stats(self) -> Dict[str, Dict[str, float]]:
        """Return a copy of the per-parameter stats from the latest backward."""
        return dict(self._grad_stats)

    def get_summary(self) -> Dict[str, float]:
        """Return aggregate gradient-norm statistics ({} if nothing recorded)."""
        if not self._grad_stats:
            return {}
        norms = [s["norm"] for s in self._grad_stats.values()]
        return {
            "grad_norm_mean": sum(norms) / len(norms),
            "grad_norm_max": max(norms),
            "grad_norm_min": min(norms),
            "num_params_tracked": len(norms),
        }

    def remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __del__(self):
        # __init__ may have failed before _hooks existed.
        if getattr(self, "_hooks", None):
            self.remove_hooks()
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def check_gradients_finite(model: nn.Module) -> bool:
    """Report whether every existing parameter gradient is free of NaN/Inf.

    Args:
        model: Model whose parameter gradients are inspected.

    Returns:
        False as soon as one gradient contains a NaN or Inf value; True
        otherwise, including when no parameter has a gradient yet.
    """
    present = (p.grad for p in model.parameters() if p.grad is not None)
    return all(bool(torch.isfinite(g).all()) for g in present)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs.

    When CUDA is available this also seeds every GPU and switches cuDNN to
    deterministic mode with benchmark autotuning disabled.

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_best_device() -> torch.device:
    """Pick the most capable available compute device.

    Preference order: CUDA GPU, then Apple MPS, then CPU.

    Returns:
        A ``torch.device`` for the chosen backend.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        backend = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def create_optimizer_and_scheduler(
    model: torch.nn.Module,
    lr: float = 3e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 500,
    total_steps: int = 10000,
    min_lr_ratio: float = 0.1,
    betas: tuple = (0.9, 0.95),
):
    """Create an AdamW optimizer plus a warmup + cosine-decay LR scheduler.

    Typical LLM training setup:
        - AdamW (beta1=0.9, beta2=0.95 by default)
        - linear warmup from 0 to ``lr`` over ``warmup_steps``, then cosine
          decay down to ``min_lr_ratio * lr`` at ``total_steps``; the LR stays
          at that floor for any later step.

    Parameters with ``ndim <= 1`` (biases, norm weights, scalars) or whose
    name contains "bias" or "norm" get no weight decay.

    Args:
        model: Model whose trainable parameters are optimized.
        lr: Peak learning rate.
        weight_decay: Decoupled weight decay for the decayed group.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the cosine decay reaches its floor.
        min_lr_ratio: Final learning rate as a fraction of ``lr``.
        betas: Adam beta coefficients.

    Returns:
        (optimizer, scheduler): the AdamW optimizer (param group 0 is decayed,
        group 1 is not) and a ``LambdaLR`` scheduler.
    """
    import math

    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # ndim <= 1 also covers 0-dim scalar parameters (gates, temperatures).
        if param.ndim <= 1 or "bias" in name or "norm" in name.lower():
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=betas)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        # Clamp: past total_steps the cosine would otherwise wrap around and
        # push the learning rate back up toward its peak.
        progress = min(progress, 1.0)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def safe_clip_grad_norm_(
    model: torch.nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2.0,
) -> float:
    """Clip gradients in place after neutralising non-finite ones.

    Any parameter whose gradient contains a NaN or Inf has its *entire*
    gradient tensor zeroed in place before clipping, so a single bad value
    cannot poison the global norm.

    Args:
        model: Model whose parameter gradients are clipped in place.
        max_norm: Upper bound on the total gradient norm.
        norm_type: Order of the norm (default 2.0, the L2 norm).

    Returns:
        The total gradient norm as computed by ``clip_grad_norm_``, i.e. after
        non-finite gradients were zeroed but before rescaling.
    """
    params = list(model.parameters())
    for grad in (p.grad for p in params if p.grad is not None):
        if not bool(torch.isfinite(grad).all()):
            grad.zero_()
    total = torch.nn.utils.clip_grad_norm_(params, max_norm, norm_type=norm_type)
    return float(total)
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Transformer 基线模型 (Transformer Baseline)
|
|
3
|
+
|
|
4
|
+
标准 Transformer 语言模型,用于与 CortexNet 进行公平对比。
|
|
5
|
+
使用 Pre-LN (Pre-LayerNorm) 结构 + RoPE 位置编码。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Optional, Dict
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
from .blocks import RMSNorm
|
|
17
|
+
from .attention import precompute_rope_freqs, apply_rope
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TransformerBlock(nn.Module):
    """Standard Transformer decoder block with pre-normalization (Pre-LN).

    Layout: x + Dropout(Attn(RMSNorm(x))), then x + Dropout(SwiGLU(RMSNorm(x))).
    Attention is causal multi-head self-attention with RoPE applied to the
    queries and keys.

    Args:
        d_model: Residual-stream width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the SwiGLU feed-forward network.
        max_seq_len: Passed to ``precompute_rope_freqs`` (presumably the
            number of precomputed positions — confirm against that helper).
        dropout: Dropout probability on the attention and FFN outputs.
        rope_theta: RoPE base frequency.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads`` (the head
            reshape in ``forward`` could never succeed).
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int,
                 max_seq_len: int = 8192, dropout: float = 0.0,
                 rope_theta: float = 10000.0):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        # Self-attention projections.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        # Feed-forward network (SwiGLU).
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

        # RoPE table; non-persistent so it is not saved in the state_dict.
        self.register_buffer(
            "rope_freqs",
            precompute_rope_freqs(self.head_dim, max_seq_len, rope_theta),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (batch, seq_len, d_model)."""
        batch, seq_len, width = x.shape
        residual = x
        x_norm = self.norm1(x)

        # Multi-head attention with RoPE: (B, L, D) -> (B, H, L, head_dim).
        q = self.q_proj(x_norm).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x_norm).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x_norm).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = apply_rope(q, self.rope_freqs)
        k = apply_rope(k, self.rope_freqs)

        # Causal attention.
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        x = residual + self.dropout(self.o_proj(attn_out))

        # SwiGLU feed-forward.
        residual = x
        x_norm = self.norm2(x)
        x = residual + self.dropout(
            self.down_proj(F.silu(self.gate_proj(x_norm)) * self.up_proj(x_norm))
        )

        return x
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class TransformerLM(nn.Module):
    """Standard Transformer language model used as a comparison baseline.

    Token embedding -> ``num_layers`` Pre-LN ``TransformerBlock``s -> RMSNorm
    -> LM head tied to the embedding matrix.

    Args:
        vocab_size: Vocabulary size.
        d_model: Model width.
        num_layers: Number of Transformer blocks.
        num_heads: Number of attention heads.
        d_ff: FFN hidden width.
        max_seq_len: Maximum sequence length (forwarded to every block).
        dropout: Dropout probability.
        rope_theta: RoPE base frequency.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        d_model: int = 512,
        num_layers: int = 4,
        num_heads: int = 8,
        d_ff: int = 1024,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        rope_theta: float = 10000.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.embed_dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            TransformerBlock(d_model, num_heads, d_ff, max_seq_len, dropout, rope_theta)
            for _ in range(num_layers)
        )

        self.final_norm = RMSNorm(d_model)
        # Output projection shares its weight matrix with the input embedding.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """N(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Return ``{"logits": ...}`` plus ``"loss"`` when ``labels`` is given.

        The loss is next-token cross-entropy (position t predicts label t + 1);
        label positions set to -100 are ignored.
        """
        batch_size, seq_len = input_ids.shape  # also enforces 2-D input_ids

        hidden = self.embed_dropout(self.embed(input_ids))
        for block in self.blocks:
            hidden = block(hidden)
        logits = self.lm_head(self.final_norm(hidden))

        result = {"logits": logits}
        if labels is None:
            return result

        result["loss"] = F.cross_entropy(
            logits[:, :-1, :].reshape(-1, self.vocab_size),
            labels[:, 1:].reshape(-1),
            ignore_index=-100,
        )
        return result
|