cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/ssm.py ADDED
@@ -0,0 +1,340 @@
1
+ from __future__ import annotations
2
+
3
+ """
4
+ 多尺度状态空间模块 (Multi-Scale State Space Module, MSSM)
5
+
6
+ 核心创新:
7
+ 多个并行的 SSM 通道以不同的时间尺度运行,使模型能同时捕获
8
+ 从局部 token 交互到长距离依赖的多种模式。
9
+
10
+ - 快速尺度:捕获局部的、细粒度的模式(如语法结构)
11
+ - 慢速尺度:捕获长距离的、宏观的模式(如主题、上下文)
12
+
13
+ 计算复杂度:O(n),线性于序列长度。
14
+
15
+ 理论基础:
16
+ 基于状态空间模型 (SSM) 的连续时间动力学:
17
+ dh/dt = A·h + B·x (连续状态方程)
18
+ y = C·h (观测方程)
19
+
20
+ 通过零阶保持 (ZOH) 离散化:
21
+ h_t = Ā·h_{t-1} + B̄·x_t 其中 Ā = exp(Δ·A), B̄ ≈ Δ·B(精确 ZOH 为 B̄ = (Δ·A)⁻¹·(exp(Δ·A) − I)·Δ·B,此处取其一阶近似)
22
+
23
+ 不同尺度通过 A 矩阵的不同特征值初始化实现:
24
+ 尺度 i 的 A 矩阵以 2^i 倍的频率初始化。
25
+
26
+ 优化 (v3.2):
27
+ - 添加 Triton 自定义 kernel 接口:当 triton 可用时自动使用
28
+ 高效的 GPU kernel,否则回退到 PyTorch 分块并行实现
29
+ - 添加 logging 支持
30
+ """
31
+
32
+
33
+ import math
34
+ import logging
35
+ from typing import Optional, Tuple
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
logger = logging.getLogger(__name__)

# Triton availability probe: the flag ends up True only when both the
# ``triton`` package and ``triton.language`` import successfully.
try:
    import triton
    import triton.language as tl
except ImportError:
    _TRITON_AVAILABLE = False
else:
    # Presumably here only so the two imported names count as "used".
    _ = (triton, tl)
    _TRITON_AVAILABLE = True
    logger.info("Triton available: SSM will use custom GPU kernels")
53
+
54
+
55
class MultiScaleSSM(nn.Module):
    """Multi-scale selective state space module.

    The ``d_inner`` channels are split into ``num_scales`` equal groups. Group
    ``i`` has its ``A`` matrix initialised to ``(1..state_size) * 2**i``, so the
    groups start out with different decay rates.

    Data flow:
        Input -> Linear(d_model, 2 * d_inner) -> split into [x, z]
        x -> selective scan (multi-scale A) -> y
        z -> SiLU -> gate
        (y + D * x) * gate -> Linear(d_inner, d_model) -> scalar sigmoid gate

    Args:
        d_model: Input/output feature dimension.
        num_scales: Number of channel groups (time scales). Must divide
            ``d_model * expand_factor``.
        state_size: Size N of the per-channel SSM state vector.
        expand_factor: ``d_inner = d_model * expand_factor``.
    """

    def __init__(
        self,
        d_model: int,
        num_scales: int = 4,
        state_size: int = 16,
        expand_factor: int = 2,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_model * expand_factor
        self.num_scales = num_scales
        self.state_size = state_size
        self.d_per_scale = self.d_inner // num_scales

        assert self.d_inner % num_scales == 0, (
            f"d_inner ({self.d_inner}) 必须能被 num_scales ({num_scales}) 整除"
        )

        # Input projection: first half feeds the SSM, second half is the gate.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # A is stored in log space; forward() uses -exp(A_log), which is
        # always negative. Scale i starts at (1..state_size) * 2**i.
        A_parts = []
        for i in range(num_scales):
            base_freqs = torch.arange(1, state_size + 1, dtype=torch.float32)
            scale_factor = 2.0 ** i
            A_part = (
                (base_freqs * scale_factor)
                .unsqueeze(0)
                .expand(self.d_per_scale, -1)
            )
            A_parts.append(A_part)
        A = torch.cat(A_parts, dim=0)  # (d_inner, state_size)
        self.A_log = nn.Parameter(torch.log(A))

        # Input-dependent B and C (the "selective" part of the SSM).
        self.B_proj = nn.Linear(self.d_inner, state_size, bias=False)
        self.C_proj = nn.Linear(self.d_inner, state_size, bias=False)

        # Input-dependent step size dt (passed through softplus in forward()).
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner, bias=True)

        # The bias is the log of a log-uniform sample v in [0.001, 0.1];
        # softplus(log v) = log(1 + v), which is close to v for small v.
        dt_init = torch.exp(
            torch.rand(self.d_inner) * (math.log(0.1) - math.log(0.001))
            + math.log(0.001)
        )
        with torch.no_grad():
            self.dt_proj.bias.copy_(dt_init.log())

        # Per-channel skip connection weight.
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Scalar output gate; initialised to 0 so sigmoid(gate) starts at 0.5.
        self.output_gate = nn.Parameter(torch.zeros(1))

        # Output projection.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        """Apply the module to a sequence.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            past_state: Optional initial SSM state of shape
                (batch, d_inner, state_size), e.g. the state returned by a
                previous call during incremental decoding.
            use_cache: When True, the final SSM state is returned as well.

        Returns:
            ``output`` of shape (batch, seq_len, d_model) when ``use_cache`` is
            False; otherwise the tuple ``(output, new_state)`` with
            ``new_state`` of shape (batch, d_inner, state_size).
        """
        B, L, D = x.shape
        input_dtype = x.dtype

        # Input projection, then split into SSM input and gate.
        xz = self.in_proj(x)  # (B, L, 2 * d_inner)
        x_ssm, z = xz.chunk(2, dim=-1)  # each (B, L, d_inner)

        # Input-dependent SSM parameters. exp() is evaluated in float32 and
        # the result is cast back to the input dtype.
        A = -torch.exp(self.A_log.float()).to(input_dtype)  # (d_inner, N)
        B_mat = self.B_proj(x_ssm)  # (B, L, N)
        C_mat = self.C_proj(x_ssm)  # (B, L, N)
        dt = F.softplus(self.dt_proj(x_ssm))  # (B, L, d_inner), positive

        # Scan dispatch: Triton path (CUDA + triton importable) -> chunked
        # parallel scan -> plain sequential scan for L <= 1.
        if L > 1:
            if _TRITON_AVAILABLE and x_ssm.is_cuda:
                y, new_state = self._triton_scan(
                    x_ssm, A, B_mat, C_mat, dt,
                    past_state=past_state,
                    use_cache=use_cache,
                )
            else:
                y, new_state = self._chunk_parallel_scan(
                    x_ssm, A, B_mat, C_mat, dt,
                    chunk_size=min(max(16, L), 64),
                    past_state=past_state,
                    use_cache=use_cache,
                )
        else:
            y, new_state = self._selective_scan(
                x_ssm, A, B_mat, C_mat, dt,
                past_state=past_state,
                use_cache=use_cache,
            )

        # Make sure y has the input dtype before mixing it with x_ssm.
        y = y.to(input_dtype)

        # Skip connection.
        y = y + x_ssm * self.D.to(input_dtype).unsqueeze(0).unsqueeze(0)

        # Gated output.
        y = y * F.silu(z)
        out = self.out_proj(y) * torch.sigmoid(self.output_gate.to(input_dtype))

        if use_cache and new_state is not None:
            return out, new_state
        return out

    def _selective_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the selective scan one time step at a time (pure PyTorch).

        Recurrence: ``h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t`` and
        ``y_t = sum_n h_t[..., n] * C_t[n]``.

        Args:
            x: (batch, L, d_inner) SSM input.
            A: (d_inner, N) state matrix (negative values).
            B: (batch, L, N) input matrix.
            C: (batch, L, N) output matrix.
            dt: (batch, L, d_inner) positive step sizes.
            past_state: Optional initial state (batch, d_inner, N); zeros when
                omitted.
            use_cache: Whether to return the final state.

        Returns:
            y: (batch, L, d_inner) in the dtype of ``x``.
            new_state: (batch, d_inner, N) final state in the dtype of ``x``
                when ``use_cache`` is True, otherwise None.
        """
        batch, L, d_inner = x.shape
        N = A.shape[1]
        orig_dtype = x.dtype

        # All arithmetic is done in float32 to avoid float16 overflow.
        x = x.float()
        A = A.float()
        B = B.float()
        C = C.float()
        dt = dt.float()

        # Discretised parameters; the clamp guards exp() against overflow.
        A_bar = torch.exp((dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)).clamp(max=20))
        B_bar = dt.unsqueeze(-1) * B.unsqueeze(2)

        h = (
            past_state.float()
            if past_state is not None
            else torch.zeros(batch, d_inner, N, device=x.device, dtype=torch.float32)
        )
        outputs = []

        for t in range(L):
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t].unsqueeze(-1)
            y_t = (h * C[:, t].unsqueeze(1)).sum(-1)
            outputs.append(y_t)

        # torch.stack rejects an empty list, so a zero-length sequence gets an
        # explicitly empty output instead of an error.
        if outputs:
            y = torch.stack(outputs, dim=1)
        else:
            y = x.new_zeros((batch, 0, d_inner))
        y = y.to(orig_dtype)
        new_state = h.to(orig_dtype) if use_cache else None
        return y, new_state

    @staticmethod
    def _scan_chunk(
        decay: torch.Tensor,
        drive: torch.Tensor,
        h0: torch.Tensor,
    ) -> torch.Tensor:
        """Return every state of ``h_t = decay_t * h_{t-1} + drive_t`` in a chunk.

        Uses a doubling (Hillis-Steele) scan over the affine maps
        ``h -> decay_t * h + drive_t``. Only products and sums are involved, so
        the result stays accurate even when the cumulative decay underflows.

        Args:
            decay: (batch, T, d_inner, N) per-step multiplicative factors.
            drive: (batch, T, d_inner, N) per-step additive inputs.
            h0: (batch, d_inner, N) state entering the chunk.

        Returns:
            (batch, T, d_inner, N) tensor whose slice ``[:, t]`` is ``h_t``.
        """
        # Fold the incoming state into the first step so position t ends up
        # holding h_t itself.
        first = drive[:, :1] + decay[:, :1] * h0.unsqueeze(1)
        state = torch.cat([first, drive[:, 1:]], dim=1)
        coeff = decay
        length = state.shape[1]
        offset = 1
        while offset < length:
            # Compose the map ending at t with the one ending at t - offset.
            state = torch.cat(
                [
                    state[:, :offset],
                    state[:, offset:] + coeff[:, offset:] * state[:, :-offset],
                ],
                dim=1,
            )
            next_offset = offset * 2
            if next_offset < length:
                # The combined coefficient is only needed for another round.
                coeff = torch.cat(
                    [coeff[:, :offset], coeff[:, offset:] * coeff[:, :-offset]],
                    dim=1,
                )
            offset = next_offset
        return state

    def _chunk_parallel_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        chunk_size: int = 64,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Chunked scan: parallel inside a chunk, sequential across chunks.

        Computes the same recurrence as ``_selective_scan``. Inside each chunk
        the states come from ``_scan_chunk``, which never divides by the
        cumulative decay, so strongly decaying channels are handled correctly.

        Args:
            x: (batch, L, d_inner) SSM input.
            A: (d_inner, N) state matrix (negative values).
            B: (batch, L, N) input matrix.
            C: (batch, L, N) output matrix.
            dt: (batch, L, d_inner) positive step sizes.
            chunk_size: Number of time steps processed per chunk.
            past_state: Optional initial state (batch, d_inner, N); zeros when
                omitted.
            use_cache: Whether to return the final state.

        Returns:
            y: (batch, L, d_inner) in the dtype of ``x``.
            new_state: (batch, d_inner, N) final state in the dtype of ``x``
                when ``use_cache`` is True, otherwise None.
        """
        batch, L, d_inner = x.shape
        N = A.shape[1]
        orig_dtype = x.dtype

        # All arithmetic is done in float32 to avoid float16 overflow.
        x = x.float()
        A = A.float()
        B = B.float()
        C = C.float()
        dt = dt.float()

        # Per-step decay and input terms, both (batch, L, d_inner, N).
        decay = torch.exp((dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)).clamp(max=20))
        drive = dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

        h = (
            past_state.float()
            if past_state is not None
            else torch.zeros(batch, d_inner, N, device=x.device, dtype=torch.float32)
        )
        all_outputs = []

        for start in range(0, L, chunk_size):
            stop = min(start + chunk_size, L)
            h_all = self._scan_chunk(decay[:, start:stop], drive[:, start:stop], h)
            y_chunk = (h_all * C[:, start:stop].unsqueeze(2)).sum(-1)
            all_outputs.append(y_chunk)
            # The last state of this chunk seeds the next one.
            h = h_all[:, -1]

        # torch.cat rejects an empty list, so handle L == 0 explicitly.
        if all_outputs:
            y = torch.cat(all_outputs, dim=1)
        else:
            y = x.new_zeros((batch, 0, d_inner))
        y = y.to(orig_dtype)
        new_state = h.to(orig_dtype) if use_cache else None
        return y, new_state

    def _triton_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        past_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Selective scan entry point for the Triton path.

        ``forward`` calls this when Triton is importable and the input is on
        CUDA. This is currently an interface placeholder: it logs a debug
        message and delegates to ``_chunk_parallel_scan`` with the same chunk
        size ``forward`` would use.

        TODO: implement a native fused Triton kernel.

        Args and return values are the same as for ``_chunk_parallel_scan``
        (without ``chunk_size``).
        """
        logger.debug("Using Triton scan path (delegating to chunk_parallel)")
        return self._chunk_parallel_scan(
            x, A, B, C, dt,
            chunk_size=min(max(16, x.shape[1]), 64),
            past_state=past_state,
            use_cache=use_cache,
        )
@@ -0,0 +1,204 @@
1
+ """
2
+ CortexNet 训练工具 (Training Utilities)
3
+
4
+ 提供梯度监控、种子设置、设备选择等训练辅助功能。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import random
10
+ from typing import Dict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+
16
+
17
class GradientMonitor:
    """Record per-parameter gradient statistics during backward passes.

    A tensor hook is registered on every parameter that requires a gradient.
    Each time a gradient is computed for that parameter, its mean, standard
    deviation, max, min and norm are stored under the parameter's name,
    replacing the previous entry.

    Usage:
        monitor = GradientMonitor(model)
        loss.backward()
        stats = monitor.get_stats()
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self._hooks = []
        self._grad_stats: Dict[str, Dict[str, float]] = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach a recording hook to every trainable parameter."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # ``n=name`` binds the current name; a plain closure would see
                # only the last value of the loop variable.
                hook = param.register_hook(
                    lambda grad, n=name: self._record_grad(n, grad)
                )
                self._hooks.append(hook)

    def _record_grad(self, name: str, grad: torch.Tensor):
        """Store summary statistics of ``grad`` under ``name``."""
        count = grad.numel()
        if count == 0:
            # max()/min() raise on empty tensors; there is nothing to record.
            return
        # The sample standard deviation of a single value is undefined
        # (torch returns NaN and warns), so report 0.0 in that case.
        std = grad.std().item() if count > 1 else 0.0
        self._grad_stats[name] = {
            "mean": grad.mean().item(),
            "std": std,
            "max": grad.max().item(),
            "min": grad.min().item(),
            "norm": grad.norm().item(),
        }

    def get_stats(self) -> Dict[str, Dict[str, float]]:
        """Return a shallow copy of the most recently recorded statistics."""
        return dict(self._grad_stats)

    def get_summary(self) -> Dict[str, float]:
        """Return mean/max/min of the recorded gradient norms.

        Returns:
            An empty dict when nothing has been recorded yet; otherwise a dict
            with ``grad_norm_mean``, ``grad_norm_max``, ``grad_norm_min`` and
            ``num_params_tracked``.
        """
        if not self._grad_stats:
            return {}
        norms = [s["norm"] for s in self._grad_stats.values()]
        return {
            "grad_norm_mean": sum(norms) / len(norms),
            "grad_norm_max": max(norms),
            "grad_norm_min": min(norms),
            "num_params_tracked": len(norms),
        }

    def remove_hooks(self):
        """Remove every hook registered by this monitor."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __del__(self):
        # __init__ may have failed before ``_hooks`` was created; only clean
        # up when there is something to clean up.
        if getattr(self, "_hooks", None):
            self.remove_hooks()
75
+
76
+
77
def check_gradients_finite(model: nn.Module) -> bool:
    """Report whether every existing gradient in ``model`` is finite.

    Args:
        model: Model whose parameter gradients are inspected.

    Returns:
        True when no parameter gradient contains NaN or Inf (parameters
        without a gradient are ignored); False otherwise.
    """
    return all(
        bool(torch.isfinite(p.grad).all())
        for p in model.parameters()
        if p.grad is not None
    )
91
+
92
+
93
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch random number generators.

    When CUDA is available, all CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
106
+
107
+
108
def get_best_device() -> torch.device:
    """Pick the preferred available compute device.

    Preference order: CUDA GPU, then Apple MPS, then CPU.

    Returns:
        A ``torch.device`` for the first available backend in that order.
    """
    if torch.cuda.is_available():
        kind = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        kind = "mps"
    else:
        kind = "cpu"
    return torch.device(kind)
121
+
122
+
123
def create_optimizer_and_scheduler(
    model: torch.nn.Module,
    lr: float = 3e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 500,
    total_steps: int = 10000,
    min_lr_ratio: float = 0.1,
    betas: tuple = (0.9, 0.95),
):
    """Create an AdamW optimizer with linear warmup and cosine decay.

    Schedule (as a multiplier on ``lr``):
        - steps ``[0, warmup_steps)``: linear ramp ``step / warmup_steps``
        - steps ``[warmup_steps, total_steps]``: cosine decay from 1 down to
          ``min_lr_ratio``
        - steps beyond ``total_steps``: held at ``min_lr_ratio``

    Args:
        model: Model whose trainable parameters are optimised.
        lr: Peak learning rate.
        weight_decay: Weight decay for the decayed group. Parameters that are
            1-D, or whose name contains "bias" or "norm" (case-insensitive),
            get no weight decay.
        warmup_steps: Number of warmup steps.
        total_steps: Step at which the cosine decay reaches its minimum.
        min_lr_ratio: Final learning rate as a fraction of ``lr``.
        betas: Adam beta coefficients.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    import math

    # Split trainable parameters into decayed / non-decayed groups.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 1 or "bias" in name or "norm" in name.lower():
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=betas)

    decay_span = max(total_steps - warmup_steps, 1)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        # Clamp progress to 1.0: without it the cosine keeps oscillating
        # after total_steps and the learning rate climbs back up.
        progress = min((step - warmup_steps) / decay_span, 1.0)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
179
+
180
+
181
def safe_clip_grad_norm_(
    model: torch.nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2.0,
) -> float:
    """Zero out non-finite gradients, then clip the gradient norm.

    Any parameter gradient containing NaN or Inf is replaced in place by
    zeros before ``torch.nn.utils.clip_grad_norm_`` is applied.

    Args:
        model: Model whose gradients are sanitised and clipped.
        max_norm: Upper bound for the total gradient norm.
        norm_type: Type of norm (2.0 for L2).

    Returns:
        The total norm reported by ``clip_grad_norm_`` as a Python float
        (computed after non-finite gradients were zeroed).
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        if torch.isfinite(grad).all():
            continue
        # A NaN/Inf entry invalidates the whole gradient tensor.
        grad.zero_()

    clipped_from = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm, norm_type=norm_type,
    )
    return float(clipped_from)
@@ -0,0 +1,157 @@
1
+ """
2
+ Transformer 基线模型 (Transformer Baseline)
3
+
4
+ 标准 Transformer 语言模型,用于与 CortexNet 进行公平对比。
5
+ 使用 Pre-LN (Pre-LayerNorm) 结构 + RoPE 位置编码。
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional, Dict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from .blocks import RMSNorm
17
+ from .attention import precompute_rope_freqs, apply_rope
18
+
19
+
20
class TransformerBlock(nn.Module):
    """Pre-norm Transformer decoder block: causal self-attention + SwiGLU FFN.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads; ``head_dim = d_model // num_heads``.
        d_ff: Hidden width of the feed-forward network.
        max_seq_len: Length passed to ``precompute_rope_freqs``.
        dropout: Dropout probability applied to both residual branches.
        rope_theta: Base passed to ``precompute_rope_freqs``.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int,
                 max_seq_len: int = 8192, dropout: float = 0.0,
                 rope_theta: float = 10000.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # One normalisation layer in front of each sub-layer.
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        # Self-attention projections.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        # SwiGLU feed-forward projections.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

        # Rotary position table; non-persistent, so not part of the state dict.
        rope_table = precompute_rope_freqs(self.head_dim, max_seq_len, rope_theta)
        self.register_buffer("rope_freqs", rope_table, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (batch, seq_len, d_model)."""
        B, L, D = x.shape

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, L, D) -> (B, num_heads, L, head_dim)
            return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # --- attention sub-layer ---
        normed = self.norm1(x)
        q = apply_rope(to_heads(self.q_proj(normed)), self.rope_freqs)
        k = apply_rope(to_heads(self.k_proj(normed)), self.rope_freqs)
        v = to_heads(self.v_proj(normed))

        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).contiguous().view(B, L, D)
        x = x + self.dropout(self.o_proj(attn))

        # --- SwiGLU feed-forward sub-layer ---
        normed = self.norm2(x)
        hidden = F.silu(self.gate_proj(normed)) * self.up_proj(normed)
        return x + self.dropout(self.down_proj(hidden))
79
+
80
+
81
class TransformerLM(nn.Module):
    """Plain Transformer language model used as a comparison baseline.

    The output head shares its weight matrix with the token embedding.

    Args:
        vocab_size: Vocabulary size.
        d_model: Model width.
        num_layers: Number of ``TransformerBlock`` layers.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        max_seq_len: Maximum sequence length passed to each block.
        dropout: Dropout probability.
        rope_theta: RoPE base passed to each block.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        d_model: int = 512,
        num_layers: int = 4,
        num_heads: int = 8,
        d_ff: int = 1024,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        rope_theta: float = 10000.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.embed_dropout = nn.Dropout(dropout)

        layers = []
        for _ in range(num_layers):
            layers.append(
                TransformerBlock(d_model, num_heads, d_ff, max_seq_len, dropout, rope_theta)
            )
        self.blocks = nn.ModuleList(layers)

        self.final_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the head reuses the embedding matrix.
        self.lm_head.weight = self.embed.weight

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialise Linear/Embedding weights with N(0, 0.02); zero Linear biases."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute logits and, when labels are given, the next-token loss.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            labels: Optional target ids of shape (batch, seq_len). Position
                ``t`` of the logits is scored against ``labels[:, t + 1]``;
                entries equal to -100 are ignored.

        Returns:
            Dict with ``"logits"`` (batch, seq_len, vocab_size) and, when
            ``labels`` is provided, ``"loss"``.
        """
        # Unpacking doubles as a check that input_ids is 2-D.
        _batch, _seq_len = input_ids.shape

        hidden = self.embed_dropout(self.embed(input_ids))
        for layer in self.blocks:
            hidden = layer(hidden)
        logits = self.lm_head(self.final_norm(hidden))

        result = {"logits": logits}
        if labels is None:
            return result

        # Shift so that each position predicts the following token.
        pred = logits[:, :-1, :].contiguous().view(-1, self.vocab_size)
        target = labels[:, 1:].contiguous().view(-1)
        result["loss"] = F.cross_entropy(pred, target, ignore_index=-100)
        return result