orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/utils/lora.py ADDED
@@ -0,0 +1,480 @@
+ import torch
+ import torch.nn as nn
+ import json
+
+ from safetensors.torch import save_file as safe_save_file
+ from safetensors.torch import load_file as safe_load_file
+
+ from orbit.model.block import LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA
+
+ lora_models = [LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA]
+
+ def freeze_backbone_only(
+     model: nn.Module,
+     unlock_head_keywords: list = None,
+     verbose: bool = True
+ ):
+     '''Freeze the backbone, leaving only LoRA layers and the specified head layers trainable.
+
+     First freezes all parameters, then unfreezes the parameters of LoRA modules
+     (LinearLoRA, Conv2dLoRA, etc.), and finally unfreezes any parameter whose name
+     contains one of the keywords in unlock_head_keywords.
+
+     Args:
+         model (nn.Module): Target model.
+         unlock_head_keywords (list, optional): Keywords identifying head layers that
+             should stay unfrozen, e.g. ['head', 'fc', 'classifier']. Defaults to None.
+         verbose (bool): Whether to print freezing statistics. Defaults to True.
+     '''
+     for param in model.parameters():
+         param.requires_grad = False
+
+     if not unlock_head_keywords: unlock_head_keywords = []
+
+     lora_types = tuple(lora_models)
+
+     lora_counter = 0
+     for name, module in model.named_modules():
+         if isinstance(module, lora_types):
+             lora_counter += 1
+
+             if hasattr(module, 'lora_a') and module.lora_a is not None:
+                 if isinstance(module.lora_a, nn.Module):
+                     for p in module.lora_a.parameters(): p.requires_grad = True
+                 elif isinstance(module.lora_a, nn.Parameter):
+                     module.lora_a.requires_grad = True
+
+             if hasattr(module, 'lora_b') and module.lora_b is not None:
+                 if isinstance(module.lora_b, nn.Module):
+                     for p in module.lora_b.parameters(): p.requires_grad = True
+                 elif isinstance(module.lora_b, nn.Parameter):
+                     module.lora_b.requires_grad = True
+
+             if hasattr(module, 'dora_m') and module.dora_m is not None:
+                 module.dora_m.requires_grad = True
+
+             if hasattr(module, 'lora_gate') and module.lora_gate is not None:
+                 module.lora_gate.requires_grad = True
+
+     for name, param in model.named_parameters():
+         if any(k in name for k in unlock_head_keywords):
+             param.requires_grad = True
+
+     if verbose:
+         trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+         total = sum(p.numel() for p in model.parameters())
+         print("Backbone frozen.")
+         print(f"- Active LoRA blocks: {lora_counter}")
+         print(f"- Trainable params: {trainable:,} / {total:,} ({trainable/total:.2%})")
+         if unlock_head_keywords:
+             print(f"- Extra unlocked layers: {unlock_head_keywords}")
+
+
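A minimal usage sketch for the function above; `Net` is a stand-in model, and it assumes LoRA layers have already been injected with `inject_lora` (defined later in this file):

import torch.nn as nn
from orbit.utils.lora import inject_lora, freeze_backbone_only

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 256))
        self.head = nn.Linear(256, 10)

    def forward(self, x):
        return self.head(self.backbone(x))

model = Net()
inject_lora(model, r=8, lora_alpha=16, target_names=['backbone'])  # adapt the backbone only
# Everything frozen except the LoRA matrices and parameters whose name contains 'head'.
freeze_backbone_only(model, unlock_head_keywords=['head'], verbose=True)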
+ def merge_lora(model: nn.Module, verbose: bool = False):
+     '''Merge the weights of every LoRA layer in the model into its original layer.
+
+     Used for inference speed-up. Walks all modules in the model, finds the LoRA
+     wrapper layers, and calls their merge method.
+
+     Args:
+         model (nn.Module): Model containing LoRA layers.
+         verbose (bool): Whether to print merge statistics. Defaults to False.
+
+     Returns:
+         nn.Module: The model with merged weights.
+     '''
+     lora_types = tuple(lora_models)
+     count = 0
+
+     for module in model.modules():
+         if isinstance(module, lora_types):
+             # Only merge when r > 0 and the layer has not been merged yet
+             if hasattr(module, 'r') and module.r > 0 and hasattr(module, 'merged') and not module.merged:
+                 module.merge()
+                 count += 1
+
+     if verbose:
+         print(f"Merged LoRA weights in {count} modules.")
+
+     return model
+
+ def unmerge_lora(model: nn.Module, verbose: bool = False):
+     '''Undo the weight merge of every LoRA layer in the model.
+
+     Used to restore the training state after inference. Note: under DoRA the
+     original weights cannot be recovered exactly.
+
+     Args:
+         model (nn.Module): Model containing LoRA layers.
+         verbose (bool): Whether to print operation info. Defaults to False.
+
+     Returns:
+         nn.Module: The model with merges undone.
+     '''
+     lora_types = tuple(lora_models)
+     count = 0
+
+     for module in model.modules():
+         if isinstance(module, lora_types):
+             if hasattr(module, 'merged') and module.merged:
+                 module.unmerge()
+                 count += 1
+
+     if verbose:
+         print(f"Unmerged LoRA weights in {count} modules.")
+
+     return model
+
+
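A sketch of the intended merge/unmerge round trip around evaluation (assuming `model` already contains injected LoRA layers; recall the DoRA caveat above, where unmerge is not exact):

from orbit.utils.lora import merge_lora, unmerge_lora

merge_lora(model, verbose=True)    # fold the low-rank update into the base weights
model.eval()
# ... run inference / evaluation ...
unmerge_lora(model, verbose=True)  # restore the unmerged state to continue training
model.train()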
+ def unload_lora(model: nn.Module, merge: bool = False, verbose: bool = False):
+     '''Remove the LoRA layers from the model and restore the original layers.
+
+     Args:
+         model (nn.Module): Model containing LoRA layers.
+         merge (bool): Whether to merge the weights before unloading. Defaults to False.
+         verbose (bool): Whether to print operation info. Defaults to False.
+
+     Returns:
+         nn.Module: The model with LoRA unloaded.
+     '''
+     lora_types = tuple(lora_models)
+     count = 0
+
+     def _unload(parent):
+         nonlocal count
+         for name, child in parent.named_children():
+             if isinstance(child, lora_types):
+                 if merge:
+                     child.merge()
+
+                 if hasattr(child, 'original_layer'):
+                     setattr(parent, name, child.original_layer)
+                     count += 1
+             else:
+                 _unload(child)
+
+     _unload(model)
+
+     if verbose:
+         action = "Merged and unloaded" if merge else "Unloaded"
+         print(f"{action} LoRA layers in {count} modules.")
+
+     return model
+
+
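To export a deployment copy with no LoRA wrappers left in the module tree, one would merge and unload in one step; a sketch, with 'export.pt' as a placeholder path:

import torch
from orbit.utils.lora import unload_lora

unload_lora(model, merge=True, verbose=True)  # replaces each wrapper with its original_layer
torch.save(model.state_dict(), 'export.pt')   # plain state_dict, no lora_* keys remain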
+ def inject_lora(
+     model: nn.Module,
+     r: int = 8,
+     lora_alpha: int = 16,
+     lora_dropout: float = 0.05,
+     gate: bool = False,
+     dora: bool = False,
+     target_names: list = None,
+     exclude_names: list = None,
+     verbose: bool = False,
+     prefix: str = ""
+ ):
+     '''Inject LoRA layers into a model.
+
+     Recursively walks the model and replaces Linear, Conv2d, Conv1d, and Embedding
+     layers with the corresponding LoRA wrapper layers. Supports plain LoRA,
+     Gated LoRA, and DoRA.
+
+     Args:
+         model (nn.Module): Target model.
+         r (int): LoRA rank. Defaults to 8.
+         lora_alpha (int): LoRA scaling factor. Defaults to 16.
+         lora_dropout (float): Dropout probability. Defaults to 0.05.
+         gate (bool): Whether to enable Gated LoRA (adds a learnable gating parameter). Defaults to False.
+         dora (bool): Whether to enable DoRA (Weight-Decomposed Low-Rank Adaptation). Defaults to False.
+         target_names (list, optional): Only inject into layers whose name matches one of
+             these entries; each entry is a string (substring match) or a compiled regex
+             object. Defaults to None (inject into all supported layers).
+         exclude_names (list, optional): Skip layers whose name matches one of these
+             entries; same matching rules as target_names. Defaults to None.
+         verbose (bool): Whether to print each injected layer. Defaults to False.
+         prefix (str): Path prefix used internally during recursion.
+
+     Returns:
+         nn.Module: The model with LoRA injected.
+     '''
+     def check_match(name, patterns):
+         if patterns is None: return False
+         for p in patterns:
+             if isinstance(p, str):
+                 if p in name: return True
+             elif hasattr(p, 'search'):
+                 if p.search(name): return True
+         return False
+
+     for name, child in model.named_children():
+         full_name = f"{prefix}.{name}" if prefix else name
+
+         should_inject = target_names is None or check_match(full_name, target_names)
+         is_excluded = check_match(full_name, exclude_names)
+
+         if not should_inject or is_excluded:
+             inject_lora(child, r, lora_alpha, lora_dropout, gate, dora, target_names, exclude_names, verbose, full_name)
+             continue
+
+         new_layer = None
+         if isinstance(child, nn.Linear):
+             new_layer = LinearLoRA(
+                 child, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                 gate=gate, dora=dora
+             )
+
+         elif isinstance(child, nn.Conv2d):
+             # Skip 1x1 convolutions (nn.Conv2d always stores kernel_size as a tuple)
+             if child.kernel_size == (1, 1):
+                 pass
+             else:
+                 new_layer = Conv2dLoRA(
+                     child, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                     gate=gate, dora=dora
+                 )
+
+         elif isinstance(child, nn.Conv1d):
+             new_layer = Conv1dLoRA(
+                 child, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                 gate=gate, dora=dora
+             )
+
+         elif isinstance(child, nn.Embedding):
+             new_layer = EmbeddingLoRA(
+                 child, r=r, lora_alpha=lora_alpha,
+                 gate=gate, dora=dora
+             )
+
+         if new_layer is not None:
+             setattr(model, name, new_layer)
+             if verbose:
+                 print(f"LoRA injected: {full_name} ({child.__class__.__name__})")
+         else:
+             # The layer matched but is not a supported leaf (or is a container); keep recursing
+             inject_lora(child, r, lora_alpha, lora_dropout, gate, dora, target_names, exclude_names, verbose, full_name)
+
+     return model
+
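A hedged example of targeted injection; the module names ('q_proj', 'v_proj', 'lm_head') are hypothetical and depend entirely on the model being adapted:

import re
from orbit.utils.lora import inject_lora

inject_lora(
    model,
    r=8, lora_alpha=16, lora_dropout=0.05,
    dora=True,                                          # DoRA variant
    target_names=['q_proj', re.compile(r'\.v_proj$')],  # substring or compiled regex
    exclude_names=['lm_head'],
    verbose=True
)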
+ def inject_lora_file(
+     model: nn.Module,
+     path: str,
+     merge_and_unload: bool = False,
+     verbose: bool = False
+ ):
+     '''Automatically inject LoRA from a weight file and load the weights.
+
+     Analyzes the weight file, infers the LoRA hyperparameters (e.g. r, gate, dora),
+     injects the matching LoRA layers into the model, and loads the weights.
+
+     Args:
+         model (nn.Module): Target model.
+         path (str): Path to the weight file.
+         merge_and_unload (bool): Whether to merge the weights and unload the LoRA
+             layers after loading. Defaults to False.
+         verbose (bool): Whether to print details. Defaults to False.
+
+     Returns:
+         nn.Module: The model with LoRA injected and weights loaded.
+     '''
+     import os
+     if not os.path.exists(path):
+         raise FileNotFoundError(f"File not found: {path}")
+
+     device = next(model.parameters()).device
+
+     config = {}
+     state_dict = {}
+
+     if path.endswith('.safetensors'):
+         # Load the safetensors file
+         state_dict = safe_load_file(path, device=str(device))
+
+         # Try to read the metadata
+         from safetensors import safe_open
+         with safe_open(path, framework="pt", device=str(device)) as f:
+             metadata = f.metadata()
+             if metadata and 'orbit_lora_config' in metadata:
+                 try:
+                     config = json.loads(metadata['orbit_lora_config'])
+                 except json.JSONDecodeError:
+                     pass
+     else:
+         checkpoint = torch.load(path, map_location=device)
+         if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+             state_dict = checkpoint['model_state_dict']
+             config = checkpoint.get('orbit_lora_config', {})
+         else:
+             state_dict = checkpoint
+             config = {}
+
+     r = config.get('r', 8)
+     alpha = config.get('alpha', 16)
+     target_names = config.get('target_names', None)
+
+     # Auto-detect gate and dora from the key names
+     has_gate = any('lora_gate' in k for k in state_dict.keys())
+     has_dora = any('dora_m' in k for k in state_dict.keys())
+
+     if verbose:
+         print(f"Injecting LoRA from {path}...")
+         print(f"Config: r={r}, alpha={alpha}, gate={has_gate}, dora={has_dora}, targets={target_names}")
+
+     inject_lora(
+         model,
+         r=r,
+         lora_alpha=alpha,
+         gate=has_gate,
+         dora=has_dora,
+         target_names=target_names,
+         verbose=verbose
+     )
+
+     missing, unexpected = model.load_state_dict(state_dict, strict=False)
+
+     if verbose:
+         lora_missing = [k for k in missing if 'lora_' in k or 'dora_' in k]
+         if lora_missing:
+             print(f"Warning: Missing LoRA keys: {lora_missing}")
+         else:
+             print("LoRA weights loaded successfully.")
+
+     if merge_and_unload:
+         unload_lora(model, merge=True, verbose=verbose)
+
+     return model
+
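Usage sketch; the file path is a placeholder, and r/gate/dora are recovered from the file itself as described above:

from orbit.utils.lora import inject_lora_file

# Deployment path: inject, load, merge, and strip the wrappers in one call.
inject_lora_file(model, 'runs/adapter.safetensors', merge_and_unload=True, verbose=True)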
+ def save_lora(model: nn.Module, path: str):
+     '''Save only the LoRA weights of a model.
+
+     Walks the model's state_dict and extracts every parameter whose key contains
+     'lora_' (or 'dora_', for DoRA magnitude vectors), then saves them to disk.
+
+     Args:
+         model (nn.Module): Model containing LoRA layers.
+         path (str): Destination path.
+     '''
+     lora_state_dict = {}
+     full_state_dict = model.state_dict()
+
+     for key, value in full_state_dict.items():
+         # 'dora_' keeps the DoRA magnitudes, so inject_lora_file can detect them later
+         if 'lora_' in key or 'dora_' in key:
+             lora_state_dict[key] = value
+
+     if path.endswith('.safetensors'):
+         safe_save_file(lora_state_dict, path)
+     else:
+         torch.save(lora_state_dict, path)
+     print(f"LoRA weights saved to {path}. Size: {len(lora_state_dict)} keys.")
+
+ def load_lora(model: nn.Module, path: str):
+     '''Load LoRA weights into a model.
+
+     Loads the weights with strict=False and prints warnings for missing or
+     unexpected keys. Supports plain weight files as well as dictionaries saved by
+     the Checkpoint plugin that contain a 'model_state_dict' entry.
+
+     Args:
+         model (nn.Module): Target model.
+         path (str): Path to the weight file.
+     '''
+     device = next(model.parameters()).device
+
+     if path.endswith('.safetensors'):
+         lora_state_dict = safe_load_file(path, device=str(device))
+     else:
+         checkpoint = torch.load(path, map_location=device)
+         if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+             lora_state_dict = checkpoint['model_state_dict']
+         else:
+             lora_state_dict = checkpoint
+
+     missing, unexpected = model.load_state_dict(lora_state_dict, strict=False)
+
+     if unexpected:
+         print(f"Warning: Unexpected keys found: {unexpected}")
+
+     lora_missing = [k for k in missing if 'lora_' in k]
+     if lora_missing:
+         print(f"Warning: Missing LoRA keys: {lora_missing}")
+     else:
+         print("LoRA weights loaded successfully.")
+
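A sketch of the adapter save/load cycle (placeholder path; since only adapter keys are written, the fresh model must be LoRA-injected with the same settings before loading):

from orbit.utils.lora import inject_lora, save_lora, load_lora

# After training:
save_lora(model, 'adapter.safetensors')

# Later, on a freshly built copy of the base model:
inject_lora(fresh_model, r=8, lora_alpha=16)
load_lora(fresh_model, 'adapter.safetensors')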
+ class LoRADiagnoser:
+     @staticmethod
+     def get_status(model: nn.Module, verbose: bool = False) -> dict:
+         """
+         Call this inside the train loop to get the current health status of the LoRA layers.
+         """
+         stats = {
+             "total_lora_modules": 0,
+             "active_grads": 0,
+             "avg_update_magnitude": 0.0,
+             "max_grad_norm": 0.0,
+             "dead_neurons": 0
+         }
+
+         update_ratios = []
+
+         for name, module in model.named_modules():
+             if hasattr(module, 'lora_a') and hasattr(module, 'lora_b') and module.r > 0:
+                 stats["total_lora_modules"] += 1
+
+                 wa = module.lora_a.weight if isinstance(module.lora_a, nn.Module) else module.lora_a
+                 wb = module.lora_b.weight if isinstance(module.lora_b, nn.Module) else module.lora_b
+
+                 if wa.grad is not None and wb.grad is not None:
+                     stats["active_grads"] += 1
+                     g_norm = wa.grad.norm().item() + wb.grad.norm().item()
+                     stats["max_grad_norm"] = max(stats["max_grad_norm"], g_norm)
+
+                 s = module.scaling
+
+                 norm_a = wa.data.norm().item()
+                 norm_b = wb.data.norm().item()
+                 norm_delta = norm_a * norm_b * s
+
+                 if hasattr(module, 'original_layer'):
+                     # Wrapper layers (e.g. Conv LoRA) keep the base weight on original_layer
+                     norm_w = module.original_layer.weight.data.norm().item()
+                 elif hasattr(module, 'weight'):
+                     norm_w = module.weight.data.norm().item()
+                 else:
+                     norm_w = 1.0  # Fallback
+
+                 ratio = norm_delta / (norm_w + 1e-6)
+                 update_ratios.append(ratio)
+
+                 if norm_b < 1e-9:
+                     stats["dead_neurons"] += 1
+
+         if update_ratios:
+             stats["avg_update_magnitude"] = sum(update_ratios) / len(update_ratios)
+             stats["min_ratio"] = min(update_ratios)
+             stats["max_ratio"] = max(update_ratios)
+
+         if verbose:
+             print("--- LoRA Diagnosis ---")
+             print(f"Modules: {stats['total_lora_modules']} | Active Grads: {stats['active_grads']}")
+             print(f"Update Ratio (Perturbation): {stats['avg_update_magnitude']:.6f} (Target ~0.001 - 0.01)")
+             print(f"Max Gradient Norm: {stats['max_grad_norm']:.6f}")
+             if stats['dead_neurons'] > 0:
+                 print(f"Warning: {stats['dead_neurons']} modules have near-zero output (initialization issue?)")
+             print("----------------------")
+
+         return stats
+
+     @staticmethod
+     def check_collapse(model: nn.Module, threshold: float = 1e-4):
+         """
+         Check the LoRA matrices for severe rank collapse.
+         """
+         print("Running SVD analysis on LoRA layers...")
+         for name, module in model.named_modules():
+             if hasattr(module, 'lora_a') and hasattr(module, 'lora_b'):
+                 wa = module.lora_a.weight if isinstance(module.lora_a, nn.Module) else module.lora_a
+
+                 if wa.dim() > 2:
+                     wa_flat = wa.view(wa.shape[0], -1)  # (r, in*k*k)
+                 else:
+                     wa_flat = wa
+
+                 try:
+                     S = torch.linalg.svdvals(wa_flat.detach())
+                     # Normalize the singular values
+                     S = S / S[0]
+                     effective_rank = (S > threshold).sum().item()
+                     print(f"[{name}] Rank: {module.r} | Effective Rank: {effective_rank} | Top/Bottom Ratio: {S[0]/S[-1]:.1f}")
+                 except RuntimeError:
+                     pass
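How the diagnoser might be wired into a training loop (a sketch; the call frequency is arbitrary, the variable names are placeholders, and it must run after backward() but before zero_grad() so that gradients are still present):

from orbit.utils.lora import LoRADiagnoser

loss.backward()
if step % 100 == 0:
    stats = LoRADiagnoser.get_status(model, verbose=True)
    if stats['dead_neurons'] > 0:
        LoRADiagnoser.check_collapse(model)  # SVD-based rank inspection
optimizer.step()
optimizer.zero_grad()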
orbit/utils/moe.py ADDED
@@ -0,0 +1,55 @@
+ import torch.nn as nn
+ from typing import Literal
+ from orbit.utils.freeze import set_trainable
+
+ def set_moe_training_mode(
+     model: nn.Module,
+     mode: Literal['all', 'router_only', 'experts_only'] = 'all',
+     verbose: bool = True
+ ) -> None:
+     '''Set the training mode of an MoE (Mixture of Experts) model.
+
+     Identifies the MoE module instances in the model and precisely controls the
+     frozen/unfrozen state of the Router (gating network) and the Experts.
+
+     Args:
+         model (nn.Module): Model containing MoE blocks.
+         mode (str): Training mode.
+             - 'all': Jointly train the Router and the Experts (default).
+             - 'router_only': Train only the Router, freeze the Experts.
+             - 'experts_only': Train only the Experts, freeze the Router.
+         verbose (bool): Whether to print the state change.
+     '''
+     from orbit.model.block.moe import MoE
+
+     moe_modules = [m for m in model.modules() if isinstance(m, MoE)]
+
+     if not moe_modules:
+         if verbose:
+             print("Warning: No MoE modules found in the provided model.")
+         return
+
+     count = len(moe_modules)
+
+     for moe in moe_modules:
+         if mode == 'all':
+             set_trainable(moe.router, trainable=True)
+             set_trainable(moe.experts, trainable=True)
+
+         elif mode == 'router_only':
+             set_trainable(moe.experts, trainable=False)
+             set_trainable(moe.router, trainable=True)
+
+         elif mode == 'experts_only':
+             set_trainable(moe.router, trainable=False)
+             set_trainable(moe.experts, trainable=True)
+
+         else:
+             raise ValueError(f"Unknown mode: {mode}. Supported modes: 'all', 'router_only', 'experts_only'")
+
+     if verbose:
+         mode_desc = {
+             'all': "ALL (Router: Unfrozen, Experts: Unfrozen)",
+             'router_only': "ROUTER_ONLY (Router: Unfrozen, Experts: Frozen)",
+             'experts_only': "EXPERTS_ONLY (Router: Frozen, Experts: Unfrozen)"
+         }
+         print(f"MoE Training Mode set to: {mode_desc[mode]} for {count} MoE module(s).")
orbit/utils/seed.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
  import random
  import os
  
- def seed_everything(seed=42, strict=False):
+ def seed_everything(seed=42, strict=False, warn_only=True):
      """
      Set all random seeds to make PyTorch experiments reproducible.
  
@@ -16,38 +16,22 @@ def seed_everything(seed=42, strict=False):
      """
      import orbit
      orbit.seed_info = seed
-     # 1. Seed Python's built-in random
      random.seed(seed)
+     np.random.seed(seed)
  
-     # 2. Set the Python hash seed (affects dict/set iteration order)
      os.environ['PYTHONHASHSEED'] = str(seed)
  
-     # 3. Seed NumPy
-     np.random.seed(seed)
-
-     # 4. Seed PyTorch on CPU/GPU
      torch.manual_seed(seed)
      if torch.cuda.is_available():
          torch.cuda.manual_seed_all(seed)
  
-     # 5. Configure the cuDNN backend (standard reproducibility settings)
      if torch.cuda.is_available():
-         # Disable cuDNN autotuning (the fastest algorithm can vary with hardware state)
          torch.backends.cudnn.benchmark = False
-         # Force deterministic algorithms
          torch.backends.cudnn.deterministic = True
-     # 6. Strict mode
      if strict:
          try:
-             # Enable strictly deterministic algorithms
-             # Note: ops with no deterministic PyTorch implementation will raise RuntimeError
-             torch.use_deterministic_algorithms(True)
-
-             # For use_deterministic_algorithms to work on CUDA, CUBLAS_WORKSPACE_CONFIG
-             # must be set, otherwise cuBLAS raises an error. ":4096:8" is the officially
-             # recommended setting, at a small extra memory cost.
+             torch.use_deterministic_algorithms(True, warn_only=warn_only)
              os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-
              print(f"[Info] Strict deterministic mode enabled. (seed={seed})")
          except AttributeError:
              print("[Warning] torch.use_deterministic_algorithms is not available in your PyTorch version.")
orbit/utils/sft.py ADDED
@@ -0,0 +1,93 @@
+ from typing import Dict, Any, TYPE_CHECKING
+
+ if TYPE_CHECKING: from orbit.engine import Engine
+
+ def build_sft(
+     user_content: str,
+     model_content: str,
+     tokenizer: Any,
+     max_length: int = 2048,
+     model_role: str = 'model',
+     padding: bool = True,
+     ignore_index: int = -100
+ ) -> Dict[str, Any]:
+     '''Build an SFT (Supervised Fine-Tuning) dataset sample.
+
+     Combines the user input and the model output, applies the chat template, then
+     tokenizes, truncates, and pads. Also generates labels in which the user-input
+     part and the padding part are masked (set to ignore_index).
+
+     Args:
+         user_content (str): The user's input.
+         model_content (str): The model's expected reply.
+         tokenizer (Any): Tokenizer instance; must support apply_chat_template and __call__.
+         max_length (int, optional): Maximum sequence length. Defaults to 2048.
+         model_role (str, optional): Role name of the model, used to build the chat
+             messages. Defaults to 'model'.
+         padding (bool, optional): Whether to pad to max_length. Defaults to True.
+         ignore_index (int, optional): Index used to mask the labels, consistent with
+             PyTorch loss functions. Defaults to -100.
+
+     Returns:
+         Dict[str, Any]: Dictionary with the processed data:
+             - 'input_ids': Input token-ID tensor.
+             - 'attention_mask': Attention-mask tensor.
+             - 'labels': Label tensor for the loss, with the user part masked.
+     '''
+     messages = [
+         {'role': 'user', 'content': user_content},
+         {'role': model_role, 'content': model_content}
+     ]
+
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=False
+     )
+
+     user_messages = [{'role': 'user', 'content': user_content}]
+     prompt_text = tokenizer.apply_chat_template(
+         user_messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     encodings = tokenizer(
+         text,
+         max_length=max_length,
+         truncation=True,
+         padding='max_length' if padding else False,
+         return_tensors='pt'
+     )
+
+     input_ids = encodings['input_ids'][0]
+     attention_mask = encodings['attention_mask'][0]
+     labels = input_ids.clone()
+
+     prompt_encodings = tokenizer(
+         prompt_text,
+         truncation=True,
+         max_length=max_length,
+         add_special_tokens=True
+     )
+     prompt_len = len(prompt_encodings['input_ids'])
+
+     if prompt_len > len(input_ids):
+         prompt_len = len(input_ids)
+
+     labels[:prompt_len] = ignore_index
+
+     if tokenizer.pad_token_id is not None:
+         labels[input_ids == tokenizer.pad_token_id] = ignore_index
+
+     return {
+         'input_ids': input_ids,
+         'attention_mask': attention_mask,
+         'labels': labels
+     }
+
+ def train_sft(engine: 'Engine'):
+     '''Single SFT training step: forward pass with labels, then update the Engine with the loss.'''
+     output = engine.unwrap_model()(
+         input_ids=engine.data['input_ids'],
+         attention_mask=engine.data['attention_mask'],
+         labels=engine.data['labels']
+     )
+
+     engine.update(output.loss)
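A sketch of how build_sft and train_sft might plug into a Dataset; the pair list and tokenizer are placeholders, and the Engine wiring is only indicated in comments since its construction is outside this file:

from torch.utils.data import Dataset
from orbit.utils.sft import build_sft, train_sft

class SFTDataset(Dataset):
    def __init__(self, pairs, tokenizer):
        self.pairs = pairs          # list of (user, reply) string tuples
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        user, reply = self.pairs[idx]
        return build_sft(user, reply, self.tokenizer, max_length=1024)

# The Engine is expected to place each batch on engine.data, so a training
# step reduces to:
#     train_sft(engine)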