orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/optim/sam.py CHANGED
@@ -34,7 +34,7 @@ class SAM(torch.optim.Optimizer):
34
34
 
35
35
  @torch.no_grad()
36
36
  def first_step(self, zero_grad=False):
37
- '''计算并应用参数扰动 epsilon。
37
+ r'''计算并应用参数扰动 epsilon。
38
38
 
39
39
  该步骤根据当前梯度将模型参数 $w$ 更新为 $w + \epsilon$。
40
40
 
@@ -55,7 +55,7 @@ class SAM(torch.optim.Optimizer):
55
55
 
56
56
  @torch.no_grad()
57
57
  def second_step(self, zero_grad=False):
58
- '''恢复参数并执行真正的梯度更新。
58
+ r'''恢复参数并执行真正的梯度更新。
59
59
 
60
60
  该步骤将参数从 $w + \epsilon$ 恢复回 $w$,并利用扰动位置的梯度更新 $w$。
61
61
 
@@ -89,4 +89,11 @@ class SAM(torch.optim.Optimizer):
89
89
  return norm
90
90
 
91
91
  def step(self, closure=None):
92
- raise NotImplementedError('SAM requires steps to be run manually: first_step and second_step')
92
+ raise NotImplementedError('SAM requires steps to be run manually: first_step and second_step')
93
+
94
+ def load_state_dict(self, state_dict):
95
+ self.base_optimizer.load_state_dict(state_dict)
96
+ self.param_groups = self.base_optimizer.param_groups
97
+
98
+ def state_dict(self):
99
+ return self.base_optimizer.state_dict()
orbit/plugin/__init__.py CHANGED
@@ -1,10 +1,14 @@
1
- try: from orbit.plugin.classification import ClassificationReport
1
+ try: from .classification import ClassificationReport
2
2
  except: pass
3
3
 
4
- from orbit.plugin.warmup import Warmup
5
- from orbit.plugin.early_stopping import EarlyStopping
6
- from orbit.plugin.gradient_accumulation import GradientAccumulation
7
- from orbit.plugin.mentor import Mentor
8
- from orbit.plugin.ema import EMA # Not tested
9
- from orbit.plugin.memory_estimator import MemoryEstimator
10
- from orbit.plugin.overfit import Overfit
4
+ from .checkpoint import Checkpoint
5
+ from .board import Board
6
+ from .display_model import ModelSummary
7
+ from .warmup import Warmup
8
+ from .early_stopping import EarlyStopping
9
+ from .gradient_accumulation import GradientAccumulation
10
+ from .mentor import Mentor
11
+ from .ema import EMA # Not tested
12
+ from .memory_estimator import MemoryEstimator
13
+ from .overfit import Overfit
14
+ from .lora import LoRA
orbit/plugin/board.py CHANGED
@@ -2,8 +2,7 @@ from torch.utils.tensorboard import SummaryWriter
2
2
  from typing import Optional, TYPE_CHECKING
3
3
  from orbit.callback import Callback, Event
4
4
 
5
- if TYPE_CHECKING:
6
- from ..engine import Engine
5
+ if TYPE_CHECKING: from orbit.engine import Engine
7
6
 
8
7
  class Board(Callback):
9
8
  def __init__(self, name: str, log_dir: str):
orbit/plugin/checkpoint.py CHANGED
@@ -1,10 +1,14 @@
1
1
  import os
2
2
  import torch
3
+ import json
3
4
  from typing import TYPE_CHECKING, List, Tuple
4
5
  from orbit.callback import Callback, Event
5
6
 
6
- if TYPE_CHECKING:
7
- from orbit.engine import Engine
7
+ from safetensors.torch import save_file as safe_save_file
8
+ from safetensors.torch import load_file as safe_load_file
9
+ from safetensors.torch import safe_open
10
+
11
+ if TYPE_CHECKING: from orbit.engine import Engine
8
12
 
9
13
  class Checkpoint(Callback):
10
14
  def __init__(
@@ -17,6 +21,7 @@ class Checkpoint(Callback):
17
21
  save_top_k: int = 1,
18
22
  save_last: bool = True,
19
23
  every_n_train_steps: int = None,
24
+ use_safetensors: bool = False,
20
25
  verbose: bool = True
21
26
  ):
22
27
  """
@@ -29,6 +34,7 @@ class Checkpoint(Callback):
29
34
  save_top_k (int): 保存最好的 K 个模型。设为 0 则禁用 Top-K 保存。
30
35
  save_last (bool): 是否总是保存 '{name}_last.pt'。
31
36
  every_n_train_steps (int): 每隔多少个训练步保存一次。
37
+ use_safetensors (bool): 是否使用 safetensors 格式保存。注意:safetensors 不支持保存优化器状态。
32
38
  verbose (bool): 是否打印保存信息。
33
39
  """
34
40
  super().__init__()
@@ -40,8 +46,12 @@ class Checkpoint(Callback):
40
46
  self.save_top_k = save_top_k
41
47
  self.save_last = save_last
42
48
  self.every_n_train_steps = every_n_train_steps
49
+ self.use_safetensors = use_safetensors
43
50
  self.verbose = verbose
44
51
 
52
+ if self.use_safetensors and not self.save_weights_only:
53
+ print("[yellow]Warning: safetensors does not support saving optimizer state. Setting save_weights_only=True implicitly.[/]")
54
+
45
55
  # 维护 Top-K 模型列表: [(score, filename), ...]
46
56
  self.best_k_models: List[Tuple[float, str]] = []
47
57
 
@@ -65,12 +75,20 @@ class Checkpoint(Callback):
65
75
  if self._meta_key in engine.meta:
66
76
  self.best_k_models = engine.meta[self._meta_key].get('best_k_models', [])
67
77
 
68
- load_path = os.path.join(self.path, self.name + "_last.pt").replace("\\", "/")
78
+ ext = ".safetensors" if self.use_safetensors else ".pt"
79
+ load_path = os.path.join(self.path, self.name + "_last" + ext).replace("\\", "/")
69
80
 
70
81
  if os.path.exists(load_path):
71
82
  self._load(engine, load_path)
72
83
  else:
73
- engine.print(f"[yellow]Warning: Resume checkpoint '{load_path}' not found. Starting from scratch.[/]", plugin='Checkpointing')
84
+ # 尝试查找另一种格式
85
+ alt_ext = ".pt" if self.use_safetensors else ".safetensors"
86
+ alt_path = os.path.join(self.path, self.name + "_last" + alt_ext).replace("\\", "/")
87
+ if os.path.exists(alt_path):
88
+ engine.print(f"[yellow]Found checkpoint with alternative extension: {alt_path}[/]", plugin='Checkpointing')
89
+ self._load(engine, alt_path)
90
+ else:
91
+ engine.print(f"[yellow]Warning: Resume checkpoint '{load_path}' not found. Starting from scratch.[/]", plugin='Checkpointing')
74
92
 
75
93
  def on_batch_end(self, event: Event):
76
94
  """
@@ -81,7 +99,8 @@ class Checkpoint(Callback):
81
99
  step = event.engine.global_step
82
100
  if step > 0 and step % self.every_n_train_steps == 0:
83
101
  # 保存 step checkpoint
84
- filename = f"{self.name}_step_{step}.pt"
102
+ ext = ".safetensors" if self.use_safetensors else ".pt"
103
+ filename = f"{self.name}_step_{step}{ext}"
85
104
 
86
105
  # 传递 is_step=True
87
106
  self._save(event.engine, filename, verbose=self.verbose, is_step=True)
@@ -93,7 +112,7 @@ class Checkpoint(Callback):
93
112
 
94
113
  # 同时更新 last checkpoint
95
114
  if self.save_last:
96
- self._save(event.engine, f"{self.name}_last.pt", verbose=False, is_step=True)
115
+ self._save(event.engine, f"{self.name}_last{ext}", verbose=False, is_step=True)
97
116
 
98
117
  def on_epoch_end(self, event: Event):
99
118
  """
@@ -102,9 +121,11 @@ class Checkpoint(Callback):
102
121
  2. 如果设置了 monitor,保存 top_k
103
122
  """
104
123
  engine = event.engine
124
+ ext = ".safetensors" if self.use_safetensors else ".pt"
125
+
105
126
  # 1. Save Last
106
127
  if self.save_last:
107
- self._save(engine, f"{self.name}_last.pt", verbose=False) # last 不需要每次都啰嗦
128
+ self._save(engine, f"{self.name}_last{ext}", verbose=False) # last 不需要每次都啰嗦
108
129
 
109
130
  # 2. Save Top K
110
131
  if self.monitor and self.save_top_k > 0:
@@ -119,7 +140,8 @@ class Checkpoint(Callback):
119
140
 
120
141
  def _check_and_save_top_k(self, engine: 'Engine', current_score: float):
121
142
  """检查并保存 Top-K 模型"""
122
- filename = f"{self.name}_ep{engine.epoch+1}_{self.monitor}_{current_score:.4f}.pt"
143
+ ext = ".safetensors" if self.use_safetensors else ".pt"
144
+ filename = f"{self.name}_ep{engine.epoch+1}_{self.monitor}_{current_score:.4f}{ext}"
123
145
 
124
146
  # 逻辑简化:总是先加入,然后排序,如果超过 K 个,删除最差的
125
147
  self.best_k_models.append((current_score, filename))
@@ -154,24 +176,44 @@ class Checkpoint(Callback):
154
176
 
155
177
  # 获取原始模型 (去除 DataParallel 包装) 以保证 Checkpoint 通用性
156
178
  raw_model = engine.unwrap_model()
157
-
158
- state = {
159
- 'epoch': engine.epoch,
160
- 'global_step': engine.global_step,
161
- 'batch_idx': engine.batch_idx,
162
- 'is_step': is_step,
163
- 'model_state_dict': raw_model.state_dict(),
164
- 'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
165
- 'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
166
- 'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
167
- 'meta': engine.meta,
168
- }
169
- if self.save_weights_only:
170
- state = raw_model.state_dict()
171
-
172
179
  file_path = os.path.join(self.path, filename)
180
+
173
181
  try:
174
- torch.save(state, file_path)
182
+ if self.use_safetensors and filename.endswith('.safetensors'):
183
+ # Safetensors 模式:仅保存权重和元数据
184
+ metadata = {
185
+ 'epoch': str(engine.epoch),
186
+ 'global_step': str(engine.global_step),
187
+ 'batch_idx': str(engine.batch_idx),
188
+ 'is_step': str(is_step),
189
+ # 注意:safetensors metadata 只能是字符串
190
+ }
191
+ # 尝试序列化 meta
192
+ try:
193
+ metadata['meta'] = json.dumps(engine.meta)
194
+ except:
195
+ pass
196
+
197
+ safe_save_file(raw_model.state_dict(), file_path, metadata=metadata)
198
+
199
+ else:
200
+ # Torch 模式:保存完整状态
201
+ state = {
202
+ 'epoch': engine.epoch,
203
+ 'global_step': engine.global_step,
204
+ 'batch_idx': engine.batch_idx,
205
+ 'is_step': is_step,
206
+ 'model_state_dict': raw_model.state_dict(),
207
+ 'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
208
+ 'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
209
+ 'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
210
+ 'meta': engine.meta,
211
+ }
212
+ if self.save_weights_only:
213
+ state = raw_model.state_dict()
214
+
215
+ torch.save(state, file_path)
216
+
175
217
  if verbose:
176
218
  # 使用相对路径显示,更简洁
177
219
  rel_path = os.path.relpath(file_path)
@@ -194,52 +236,85 @@ class Checkpoint(Callback):
194
236
  """加载 Checkpoint 的核心逻辑"""
195
237
  engine.print(f"[cyan]Loading checkpoint from: {file_path}[/]", plugin='Checkpointing')
196
238
  try:
197
- # 加载到设备
198
- checkpoint = torch.load(file_path, map_location=engine.device)
199
-
200
239
  # 获取原始模型以进行加载
201
240
  raw_model = engine.unwrap_model()
202
-
203
- # 1. 加载模型权重
204
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
205
- raw_model.load_state_dict(checkpoint['model_state_dict'])
206
- else:
207
- raw_model.load_state_dict(checkpoint)
208
- engine.print("[yellow]Loaded model weights only (legacy format).[/]", plugin='Checkpointing')
209
- return
210
241
 
211
- # 2. 恢复训练状态
212
- if not self.save_weights_only:
213
- if engine.optimizer and 'optimizer_state_dict' in checkpoint:
214
- engine.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
215
-
216
- if engine.scheduler and 'scheduler_state_dict' in checkpoint:
217
- engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
242
+ if file_path.endswith('.safetensors'):
243
+ # 加载权重
244
+ state_dict = safe_load_file(file_path, device=str(engine.device))
245
+ raw_model.load_state_dict(state_dict)
218
246
 
219
- if engine.scaler and 'scaler_state_dict' in checkpoint:
220
- engine.scaler.load_state_dict(checkpoint['scaler_state_dict'])
247
+ # 尝试恢复元数据
248
+ with safe_open(file_path, framework="pt", device=str(engine.device)) as f:
249
+ metadata = f.metadata()
250
+ if metadata:
251
+ loaded_epoch = int(metadata.get('epoch', 0))
252
+ loaded_batch_idx = int(metadata.get('batch_idx', -1))
253
+ is_step = metadata.get('is_step', 'False') == 'True'
254
+ engine.global_step = int(metadata.get('global_step', 0))
255
+
256
+ if 'meta' in metadata:
257
+ try:
258
+ engine.meta.update(json.loads(metadata['meta']))
259
+ except: pass
260
+
261
+ if is_step:
262
+ engine.start_epoch = loaded_epoch
263
+ engine.start_batch_idx = loaded_batch_idx
264
+ msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
265
+ else:
266
+ engine.start_epoch = loaded_epoch + 1
267
+ engine.start_batch_idx = -1
268
+ msg = f"Epoch {engine.start_epoch}"
269
+
270
+ engine.print(f"[green]Resumed weights from {msg}. Note: Optimizer state not restored (safetensors).[/]", plugin='Checkpointing')
271
+ else:
272
+ engine.print("[yellow]Loaded weights only (no metadata in safetensors).[/]", plugin='Checkpointing')
221
273
 
222
- if 'meta' in checkpoint:
223
- engine.meta.update(checkpoint['meta'])
274
+ else:
275
+ # Torch 加载
276
+ checkpoint = torch.load(file_path, map_location=engine.device)
224
277
 
225
- loaded_epoch = checkpoint.get('epoch', 0)
226
- loaded_batch_idx = checkpoint.get('batch_idx', -1)
227
- is_step = checkpoint.get('is_step', False)
228
-
229
- if is_step:
230
- # 如果是 Step Checkpoint,从当前 Epoch 的下一个 Batch 继续
231
- engine.start_epoch = loaded_epoch
232
- engine.start_batch_idx = loaded_batch_idx
233
- msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
278
+ # 1. 加载模型权重
279
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
280
+ raw_model.load_state_dict(checkpoint['model_state_dict'])
234
281
  else:
235
- # 如果是 Epoch Checkpoint,从下一个 Epoch 开始
236
- engine.start_epoch = loaded_epoch + 1
237
- engine.start_batch_idx = -1
238
- msg = f"Epoch {engine.start_epoch}"
239
-
240
- engine.global_step = checkpoint.get('global_step', 0)
282
+ raw_model.load_state_dict(checkpoint)
283
+ engine.print("[yellow]Loaded model weights only (legacy format).[/]", plugin='Checkpointing')
284
+ return
241
285
 
242
- engine.print(f"[green]Successfully resumed training from {msg}, Global Step {engine.global_step}[/]", plugin='Checkpointing')
286
+ # 2. 恢复训练状态
287
+ if not self.save_weights_only:
288
+ if engine.optimizer and 'optimizer_state_dict' in checkpoint:
289
+ engine.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
290
+
291
+ if engine.scheduler and 'scheduler_state_dict' in checkpoint:
292
+ engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
293
+
294
+ if engine.scaler and 'scaler_state_dict' in checkpoint:
295
+ engine.scaler.load_state_dict(checkpoint['scaler_state_dict'])
296
+
297
+ if 'meta' in checkpoint:
298
+ engine.meta.update(checkpoint['meta'])
299
+
300
+ loaded_epoch = checkpoint.get('epoch', 0)
301
+ loaded_batch_idx = checkpoint.get('batch_idx', -1)
302
+ is_step = checkpoint.get('is_step', False)
303
+
304
+ if is_step:
305
+ # 如果是 Step Checkpoint,从当前 Epoch 的下一个 Batch 继续
306
+ engine.start_epoch = loaded_epoch
307
+ engine.start_batch_idx = loaded_batch_idx
308
+ msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
309
+ else:
310
+ # 如果是 Epoch Checkpoint,从下一个 Epoch 开始
311
+ engine.start_epoch = loaded_epoch + 1
312
+ engine.start_batch_idx = -1
313
+ msg = f"Epoch {engine.start_epoch}"
314
+
315
+ engine.global_step = checkpoint.get('global_step', 0)
316
+
317
+ engine.print(f"[green]Successfully resumed training from {msg}, Global Step {engine.global_step}[/]", plugin='Checkpointing')
243
318
 
244
319
  except Exception as e:
245
320
  engine.print(f"[red]Failed to load checkpoint: {e}[/]", plugin='Checkpointing')
orbit/plugin/classification.py CHANGED
@@ -7,8 +7,8 @@ from typing import List, Optional, TYPE_CHECKING
7
7
  import rich.box as box
8
8
 
9
9
  from orbit.callback import Callback, Event
10
- if TYPE_CHECKING:
11
- from ..engine import Engine
10
+
11
+ if TYPE_CHECKING: from orbit.engine import Engine
12
12
 
13
13
  class ClassificationReport(Callback):
14
14
  def __init__(
orbit/plugin/display_model.py CHANGED
@@ -6,8 +6,7 @@ import rich.box as box
6
6
 
7
7
  from orbit.callback import Callback, Event
8
8
 
9
- if TYPE_CHECKING:
10
- from ..engine import Engine
9
+ if TYPE_CHECKING: from orbit.engine import Engine
11
10
 
12
11
  class ModelSummary(Callback):
13
12
  def __init__(self, max_depth: int = 3):
orbit/plugin/early_stopping.py CHANGED
@@ -2,8 +2,7 @@ import numpy as np
2
2
  from typing import TYPE_CHECKING
3
3
  from orbit.callback import Callback, Event
4
4
 
5
- if TYPE_CHECKING:
6
- from orbit.engine import Engine
5
+ if TYPE_CHECKING: from orbit.engine import Engine
7
6
 
8
7
  class EarlyStopping(Callback):
9
8
  """
orbit/plugin/ema.py CHANGED
@@ -3,8 +3,7 @@ import torch
3
3
  from orbit.callback import Callback, Event
4
4
  from typing import TYPE_CHECKING, Dict
5
5
 
6
- if TYPE_CHECKING:
7
- from orbit.engine import Engine
6
+ if TYPE_CHECKING: from orbit.engine import Engine
8
7
 
9
8
  class EMA(Callback):
10
9
  """
orbit/plugin/gradient_accumulation.py CHANGED
@@ -1,8 +1,7 @@
1
1
  from typing import TYPE_CHECKING
2
2
  from orbit.callback import Callback, Event
3
3
 
4
- if TYPE_CHECKING:
5
- from orbit.engine import Engine
4
+ if TYPE_CHECKING: from orbit.engine import Engine
6
5
 
7
6
  class GradientAccumulation(Callback):
8
7
  """