orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/engine.py CHANGED
@@ -8,20 +8,21 @@ from typing import Any, List, Optional, Union, Dict, Tuple
8
8
  try: from torch.utils.tensorboard import SummaryWriter
9
9
  except: pass
10
10
 
11
- from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, MofNCompleteColumn
12
- from rich.console import Console
11
+ from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, MofNCompleteColumn
12
+ from rich.console import Console
13
+ from accelerate import Accelerator
13
14
 
14
15
  from orbit.callback import Callback, Forward, Event
15
- from orbit.plugin.checkpoint import Checkpoint
16
- from orbit.plugin.board import Board
17
- from orbit.plugin.display_model import ModelSummary
16
+ from orbit.plugin import Checkpoint, Board, ModelSummary
17
+ from orbit.utils import load_model
18
+
18
19
 
19
20
  class Engine:
20
21
  '''训练循环控制器,负责协调模型训练、验证及回调事件。
21
22
 
22
- Engine 封装了 PyTorch 的训练循环,提供了插件机制(Callback),
23
- 支持自动混合精度训练(AMP)、梯度裁剪、梯度累积、Checkpoint 保存、
24
- TensorBoard 可视化等功能。
23
+ Engine 封装了 PyTorch 的训练循环,并深度集成了 Accelerate 库。
24
+ 它提供了开箱即用的分布式训练支持(DDP, FSDP, DeepSpeed 等)、
25
+ 自动混合精度(AMP)、梯度裁剪、梯度累积、Checkpoint 管理以及 TensorBoard 可视化等功能。
25
26
  '''
26
27
 
27
28
  class _OutLogs:
@@ -39,63 +40,70 @@ class Engine:
39
40
  model: nn.Module,
40
41
  optimizer: torch.optim.Optimizer = None,
41
42
  criterion: nn.Module = None,
42
- device: Optional[str] = None,
43
- device_ids: Optional[List[int]] = None,
44
- use_amp: bool = False,
43
+ accelerator: Optional[Accelerator] = None,
44
+ mixed_precision: str = 'no', # 'no', 'fp16', 'bf16'
45
45
  grad_clip_norm: float = None,
46
46
  scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
47
47
  plugins: List[Callback] = None,
48
48
  forward_step: Optional[Forward] = None,
49
49
  checkpoint_dir: str = None,
50
50
  console: Console = None,
51
+ **accelerator_kwargs
51
52
  ):
52
53
  '''初始化 Engine 实例。
53
54
 
54
55
  Args:
55
56
  model (nn.Module): 要训练的 PyTorch 模型。
56
- optimizer (torch.optim.Optimizer, optional): 优化器。如果为 None,则需要在其他地方(如插件)手动处理或稍后赋值。
57
+ optimizer (torch.optim.Optimizer, optional): 优化器。如果为 None,则需要在其他地方手动处理。
57
58
  criterion (nn.Module, optional): 损失函数。如果为 None,则假设模型输出包含 loss 或自定义 loss 计算。
58
- device (Optional[str], optional): 运行设备 ('cpu', 'cuda', 'cuda:0' 等)。如果为 None,则自动检测。
59
- device_ids (Optional[List[int]], optional): GPU 设备 ID 列表。如果提供且长度 > 1,将启用 DataParallel。
60
- use_amp (bool, optional): 是否启用自动混合精度 (Automatic Mixed Precision) 训练。默认为 False
59
+ accelerator (Optional[Accelerator], optional): 预初始化的 Accelerator 实例。
60
+ 如果为 None,将使用 mixed_precision accelerator_kwargs 创建一个新的实例。
61
+ mixed_precision (str, optional): 混合精度模式。可选值为 'no', 'fp16', 'bf16'。默认为 'no'
61
62
  grad_clip_norm (float, optional): 梯度裁剪的范数阈值。如果为 None,则不进行梯度裁剪。
62
63
  scheduler (Optional[torch.optim.lr_scheduler._LRScheduler], optional): 学习率调度器。
63
64
  plugins (List[Callback], optional): 初始化时要挂载的回调插件列表。
64
65
  forward_step (Optional[Forward], optional): 自定义前向传播和 Loss 计算逻辑的实现。
65
66
  checkpoint_dir (str, optional): 快速设置 Checkpoint 保存目录的快捷参数。
66
67
  console (Console, optional): 用于输出日志的 Rich Console 实例。如果为 None,则创建一个新的。
68
+ **accelerator_kwargs: 传递给 Accelerator 构造函数的其他参数(如 cpu, split_batches 等)。
67
69
  '''
68
- # --- 基础组件 ---
69
- self.device_ids = device_ids
70
-
71
- if self.device_ids and len(self.device_ids) > 0:
72
- self.device = torch.device(f"cuda:{self.device_ids[0]}")
73
- elif device is None:
74
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ # --- Accelerator 初始化 ---
71
+ if accelerator:
72
+ self.accelerator = accelerator
75
73
  else:
76
- self.device = torch.device(device)
77
-
78
- # 移动模型到主设备
79
- model = model.to(self.device)
80
-
81
- # 多显卡处理 (DataParallel)
82
- if self.device_ids and len(self.device_ids) > 1:
83
- self.model = nn.DataParallel(model, device_ids=self.device_ids)
74
+ self.accelerator = Accelerator(mixed_precision=mixed_precision, **accelerator_kwargs)
75
+
76
+ # --- 基础组件 ---
77
+ self.model = model
78
+ self.optimizer = optimizer
79
+ self.criterion = criterion
80
+ self.scheduler = scheduler
81
+
82
+ # 准备模型、优化器和调度器
83
+ # 注意:DataLoader 将在 run/train/eval 中准备
84
+ if self.optimizer:
85
+ if self.scheduler:
86
+ self.model, self.optimizer, self.scheduler = self.accelerator.prepare(
87
+ self.model, self.optimizer, self.scheduler
88
+ )
89
+ else:
90
+ self.model, self.optimizer = self.accelerator.prepare(
91
+ self.model, self.optimizer
92
+ )
84
93
  else:
85
- self.model = model
94
+ self.model = self.accelerator.prepare(self.model)
86
95
 
87
96
  self.model_name = self.unwrap_model().__class__.__name__
88
- self.optimizer = optimizer
89
- self.criterion = criterion
90
97
 
91
98
  # --- 训练配置 ---
92
- self.use_amp = use_amp
93
99
  self.grad_clip_norm = grad_clip_norm
94
- self.scheduler = scheduler
95
100
  self.forward_step = forward_step
96
- self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
97
101
 
98
102
  # --- 交互与回调 ---
103
+ self.prepare = self.accelerator.prepare
104
+ self.backward = self.accelerator.backward
105
+ self.autocast = self.accelerator.autocast
106
+
99
107
  self.console = console if console else Console()
100
108
  self.out_logs = self._OutLogs(self)
101
109
  self.writer: Optional[SummaryWriter] = None
@@ -145,6 +153,21 @@ class Engine:
145
153
  # 触发初始化回调
146
154
  self._fire_event("on_init")
147
155
 
156
+ @property
157
+ def device(self):
158
+ '''获取当前使用的设备。'''
159
+ return self.accelerator.device
160
+
161
+ @property
162
+ def use_amp(self) -> bool:
163
+ '''是否启用了混合精度训练。'''
164
+ return self.accelerator.mixed_precision != 'no'
165
+
166
+ @property
167
+ def scaler(self):
168
+ '''获取 GradScaler 对象(如果存在)。'''
169
+ return self.accelerator.scaler
170
+
148
171
  def stop(self, source: str = "User", reason: str = "Unknown"):
149
172
  '''请求停止训练。
150
173
 
@@ -157,10 +180,15 @@ class Engine:
157
180
  self.stop_reason = reason
158
181
 
159
182
  def unwrap_model(self) -> nn.Module:
160
- '''获取原始模型对象 (去除 DataParallel/DistributedDataParallel 包装)。'''
161
- if isinstance(self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
162
- return self.model.module
163
- return self.model
183
+ '''获取原始模型对象。
184
+
185
+ 使用 accelerator.unwrap_model 去除 DDP/FSDP/DeepSpeed 等分布式包装,
186
+ 返回原始的 nn.Module 实例。
187
+
188
+ Returns:
189
+ nn.Module: 原始模型对象。
190
+ '''
191
+ return self.accelerator.unwrap_model(self.model)
164
192
 
165
193
  def is_in_warmup(self) -> bool:
166
194
  '''检查当前是否处于 Warmup 阶段。
@@ -179,20 +207,25 @@ class Engine:
179
207
  def init_board(self, log_dir: str = 'runs') -> 'Engine':
180
208
  '''初始化 TensorBoard 可视化插件 (Board)。
181
209
 
210
+ 注意:此操作仅在主进程 (Main Process) 中执行,以避免多进程写入冲突。
211
+
182
212
  Args:
183
213
  log_dir (str, optional): TensorBoard 日志保存目录。默认为 'runs'。
184
214
 
185
215
  Returns:
186
216
  Engine: 返回 Engine 实例自身以支持链式调用。
187
217
  '''
188
- board = Board(name=self.model_name, log_dir=log_dir)
189
- self.attach(board, init=True)
218
+ # 仅在主进程初始化 Board
219
+ if self.accelerator.is_main_process:
220
+ board = Board(name=self.model_name, log_dir=log_dir)
221
+ self.attach(board, init=True)
190
222
  return self
191
223
 
192
224
  def set_checkpoint(self, dir: str, name: Optional[str] = None, **kwargs) -> 'Engine':
193
225
  '''配置 Checkpoint 插件。
194
226
 
195
227
  如果已存在 Checkpoint 插件,将被新配置替换。
228
+ 注意:此操作仅在主进程 (Main Process) 中执行。
196
229
 
197
230
  Args:
198
231
  dir (str): 模型保存目录。
@@ -202,36 +235,38 @@ class Engine:
202
235
  Returns:
203
236
  Engine: 返回 Engine 实例自身以支持链式调用。
204
237
  '''
238
+ # 仅在主进程配置 Checkpoint
239
+ if not self.accelerator.is_main_process:
240
+ return self
241
+
205
242
  if name is None:
206
243
  name = self.model_name
207
244
 
208
- # 1. 移除旧的 Checkpoint 插件 (如果存在)
209
245
  self.plugins = [p for p in self.plugins if not isinstance(p, Checkpoint)]
210
246
 
211
- # 2. 创建新插件
212
247
  ckpt = Checkpoint(name=name, path=dir, **kwargs)
213
-
214
- # 3. 调用 ckpt.on_init(event)
215
- # 注意:这里我们手动构造 Event,因为此时可能不在 run 循环中
216
248
  ckpt.on_init(Event(engine=self, name="on_init"))
217
-
218
- # 4. 挂载
219
249
  self.attach(ckpt)
220
250
  return self
221
251
 
222
252
  def _print_edge(self, top=True):
253
+ if not self.accelerator.is_main_process: return
223
254
  char = '┬' if top else '┴'
224
255
  self.console.print(' ' + '─' * 15 + char + '─' * 35)
225
256
 
226
257
  def print(self, *args, plugin: Optional[str] = None, **kwargs):
227
258
  '''统一日志打印方法,支持插件前缀。
228
259
 
260
+ 注意:此方法仅在主进程 (Main Process) 中输出日志,其他进程的调用将被忽略。
261
+
229
262
  Args:
230
263
  *args: 要打印的内容。
231
264
  plugin (str, optional): 插件名称。如果提供,将以固定宽度和特定颜色打印前缀,
232
265
  用于区分不同来源的日志。
233
266
  **kwargs: 传递给 console.print 的其他参数。
234
267
  '''
268
+ if not self.accelerator.is_main_process: return
269
+
235
270
  if plugin:
236
271
  # 宽度 15, 右对齐, 青色加粗
237
272
  prefix = f"[[bold cyan]{plugin:>15}[/]] "
@@ -273,9 +308,6 @@ class Engine:
273
308
  for cb in self.plugins:
274
309
  method = getattr(cb, event_name, None)
275
310
  if method:
276
- # [修改] 移除 try-except pass。
277
- # 我们需要看到 Callback 里的错误,否则调试是地狱。
278
- # 如果一定要防御性编程,可以使用 console.print_exception()
279
311
  method(event)
280
312
 
281
313
  def _process_batch_data(self, batch_data: Any):
@@ -283,12 +315,30 @@ class Engine:
283
315
 
284
316
  支持 Tensor, List[Tensor], Dict[str, Tensor] 等常见格式。
285
317
  自动解析并设置 self.data 和 self.target。
318
+
319
+ 注意:如果 DataLoader 已经被 accelerator.prepare 处理,数据通常已经在正确设备上。
320
+ 此方法包含额外的检查以确保数据位于 self.device 上。
286
321
 
287
322
  Args:
288
323
  batch_data (Any): DataLoader 产生的一个 Batch 数据。
289
324
  '''
325
+ def to_device(data):
326
+ if isinstance(data, torch.Tensor):
327
+ return data.to(self.device)
328
+ elif isinstance(data, (list, tuple)):
329
+ return [to_device(x) for x in data]
330
+ elif isinstance(data, dict):
331
+ return {k: to_device(v) for k, v in data.items()}
332
+ return data
333
+
334
+ # 如果数据不在当前设备上,移动它
335
+ # 注意:accelerator.prepare 后的 DataLoader 通常已经处理了设备放置
336
+
290
337
  if isinstance(batch_data, (list, tuple)):
291
- batch_data = [x.to(self.device) if isinstance(x, torch.Tensor) else x for x in batch_data]
338
+ # 简单检查第一个元素
339
+ if len(batch_data) > 0 and isinstance(batch_data[0], torch.Tensor) and batch_data[0].device != self.device:
340
+ batch_data = to_device(batch_data)
341
+
292
342
  if len(batch_data) == 2:
293
343
  self.data, self.target = batch_data
294
344
  elif len(batch_data) == 1:
@@ -298,36 +348,44 @@ class Engine:
298
348
  self.data = batch_data[:-1]
299
349
  self.target = batch_data[-1]
300
350
  elif isinstance(batch_data, dict):
301
- self.data = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch_data.items()}
351
+ # 简单检查第一个值
352
+ first_val = next(iter(batch_data.values()), None)
353
+ if isinstance(first_val, torch.Tensor) and first_val.device != self.device:
354
+ batch_data = to_device(batch_data)
355
+
356
+ self.data = batch_data
302
357
  self.target = None
303
358
  else:
304
- self.data = batch_data.to(self.device)
359
+ if isinstance(batch_data, torch.Tensor) and batch_data.device != self.device:
360
+ batch_data = batch_data.to(self.device)
361
+ self.data = batch_data
305
362
  self.target = None
306
363
 
307
364
  def update(self, loss: torch.Tensor):
308
365
  '''执行反向传播及参数更新。
309
366
 
367
+ 使用 accelerator.backward 处理反向传播,支持自动混合精度和梯度累积。
368
+ 同时支持 SAM (Sharpness-Aware Minimization) 优化器的两步更新逻辑。
369
+
310
370
  Args:
311
371
  loss (torch.Tensor): 当前 Step 的 Loss。
312
372
  '''
313
373
  if self.is_epoch_end: return
314
374
 
315
- # 保存原始 Loss 用于日志 (因为 SAM 需要第二次 forward 会覆盖 self.loss)
316
375
  original_loss = loss
317
376
  self.loss = loss
318
377
 
319
- # 1. 梯度累积:Loss 缩放 (仅用于 Backward)
378
+ if self.optimizer is None:
379
+ self.global_step += 1
380
+ return
381
+
382
+ # 梯度累积处理
320
383
  backward_loss = loss
321
384
  if self.accumulation_steps > 1:
322
385
  backward_loss = loss / self.accumulation_steps
323
386
 
324
- # 2. Backward 1 (计算梯度)
325
- if self.use_amp and self.scaler:
326
- self.scaler.scale(backward_loss).backward()
327
- else:
328
- backward_loss.backward()
387
+ self.accelerator.backward(backward_loss)
329
388
 
330
- # 3. Optimizer Step (仅在累积步数到达或 Epoch 结束时执行)
331
389
  if (self.batch_idx + 1) % self.accumulation_steps == 0 or self.is_last_batch:
332
390
 
333
391
  # 检测是否为 SAM 优化器 (Duck Typing)
@@ -335,73 +393,43 @@ class Engine:
335
393
 
336
394
  if is_sam:
337
395
  # --- SAM Optimizer Logic ---
338
- if self.use_amp and self.scaler:
339
- # AMP 下的 SAM 处理
340
- # 3.1. Unscale 梯度以便 first_step 计算正确的 epsilon
341
- self.scaler.unscale_(self.optimizer)
342
-
343
- # 3.2. 梯度裁剪 (可选)
344
- if self.grad_clip_norm:
345
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
346
-
347
- # 3.3. SAM First Step: w -> w + e
348
- # 注意: 我们假设 unscale 后梯度有效。
349
- self.optimizer.first_step(zero_grad=True)
350
-
351
- # 3.4. Second Forward: 计算 w + e 处的 Loss
352
- # _forward_pass 会更新 self.loss, 所以我们最后需要恢复
353
- self._forward_pass()
354
-
355
- # 3.5. Second Backward: 计算 w + e 处的梯度
356
- self.scaler.scale(self.loss).backward()
357
-
358
- # 3.6. SAM Second Step: 恢复 w, 并更新 w
359
- # 需要再次 unscale 第二次计算的梯度
360
- self.scaler.unscale_(self.optimizer)
361
- if self.grad_clip_norm:
362
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
363
-
364
- self.optimizer.second_step(zero_grad=True)
365
- self.scaler.update()
366
-
367
- else:
368
- # 普通模式下的 SAM 处理
369
- if self.grad_clip_norm:
370
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
371
-
372
- self.optimizer.first_step(zero_grad=True)
373
-
374
- self._forward_pass()
375
- self.loss.backward()
376
-
377
- if self.grad_clip_norm:
378
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
379
-
380
- self.optimizer.second_step(zero_grad=True)
396
+ # SAM 需要两次 backward,accelerate 支持多次 backward
397
+
398
+ if self.grad_clip_norm:
399
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
400
+
401
+ # First Step: Compute e_w
402
+ self.optimizer.first_step(zero_grad=True)
403
+
404
+ # Second Forward-Backward
405
+ self._forward_pass()
406
+ self.accelerator.backward(self.loss)
407
+
408
+ if self.grad_clip_norm:
409
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
410
+
411
+ # Second Step: Update weights
412
+ self.optimizer.second_step(zero_grad=True)
381
413
 
382
- # 恢复原始 Loss 以保证日志记录的一致性
383
414
  self.loss = original_loss
384
415
 
385
416
  else:
386
417
  # --- Standard Optimizer Logic ---
387
- if self.use_amp and self.scaler:
388
- if self.grad_clip_norm:
389
- self.scaler.unscale_(self.optimizer)
390
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
391
- self.scaler.step(self.optimizer)
392
- self.scaler.update()
393
- else:
394
- if self.grad_clip_norm:
395
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
396
- self.optimizer.step()
418
+ if self.grad_clip_norm:
419
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
397
420
 
421
+ self.optimizer.step()
398
422
  self.optimizer.zero_grad()
399
423
 
400
424
  self.global_step += 1
401
425
 
402
426
  def _forward_pass(self) -> torch.Tensor:
403
- '''执行前向传播并计算 Loss (内部方法)。'''
404
- with torch.amp.autocast(device_type=self.device.type, enabled=self.use_amp):
427
+ '''执行前向传播并计算 Loss (内部方法)。
428
+
429
+ 使用 accelerator.autocast() 上下文管理器自动处理混合精度。
430
+ '''
431
+ # 使用 accelerator.autocast() 处理混合精度
432
+ with self.accelerator.autocast():
405
433
  if self.forward_step:
406
434
  self.loss = self.forward_step.forward(self, self.data, self.target)
407
435
  else:
@@ -446,6 +474,8 @@ class Engine:
446
474
  ):
447
475
  '''启动训练循环。
448
476
 
477
+ 此方法会自动使用 accelerator.prepare 包装数据加载器以支持分布式训练。
478
+
449
479
  Args:
450
480
  train_loader (Any): 训练数据加载器 (通常是 torch.utils.data.DataLoader)。
451
481
  val_loader (Optional[Any], optional): 验证数据加载器。
@@ -454,6 +484,12 @@ class Engine:
454
484
  如果为 None,则从 0 开始。用于断点续训。
455
485
  with_eval (bool, optional): 是否在每个 Epoch 结束后执行验证。默认为 True。
456
486
  '''
487
+ # 准备 DataLoaders
488
+ if val_loader:
489
+ train_loader, val_loader = self.accelerator.prepare(train_loader, val_loader)
490
+ else:
491
+ train_loader = self.accelerator.prepare(train_loader)
492
+
457
493
  self.train_loader = train_loader
458
494
  self.val_loader = val_loader
459
495
  self.num_epochs = num_epochs
@@ -525,19 +561,24 @@ class Engine:
525
561
  self._fire_event("on_train_end")
526
562
 
527
563
  def _train_epoch_iterator(self, loader: Any, total_steps: Optional[int] = None, prefix: str = "Train", color: str = "blue"):
528
- '''生成器:执行单个 Epoch 的训练循环。'''
564
+ '''生成器:执行单个 Epoch 的训练循环。
565
+
566
+ 注意:进度条仅在主进程中显示。
567
+ '''
529
568
  self.model.train()
530
569
  self.is_epoch_end = False
570
+ torch.cuda.empty_cache()
531
571
 
532
- # 尝试获取真实的 loader 长度
533
572
  try:
534
573
  real_len = len(loader)
535
574
  except:
536
575
  real_len = None
537
576
 
538
- # 确定进度条的总步数
539
577
  num_batches = total_steps if total_steps is not None else real_len
540
578
 
579
+ # 仅主进程显示进度条
580
+ disable_progress = not self.accelerator.is_main_process
581
+
541
582
  with Progress(
542
583
  TextColumn(f"[{color}]{prefix}"),
543
584
  TextColumn("[progress.description]{task.description}"),
@@ -545,21 +586,21 @@ class Engine:
545
586
  MofNCompleteColumn(),
546
587
  TimeRemainingColumn(),
547
588
  console=self.console,
548
- transient=True
589
+ transient=True,
590
+ disable=disable_progress
549
591
  ) as progress:
550
592
 
551
593
  task = progress.add_task(f"[Ep {self.epoch+1}/{self.num_epochs}]", total=num_batches)
552
594
 
553
595
  for batch_idx, batch_data in enumerate(loader):
554
- # 断点续训:跳过已训练的 Batch
555
596
  if self.epoch == self.start_epoch and batch_idx <= self.start_batch_idx:
556
- progress.update(task, advance=1, description=f"[dim]Skipping batch {batch_idx}...[/]")
597
+ if not disable_progress:
598
+ progress.update(task, advance=1, description=f"[dim]Skipping batch {batch_idx}...[/]")
557
599
  continue
558
600
 
559
601
  self.batch_idx = batch_idx
560
602
  self.is_first_batch = (batch_idx == 0)
561
603
 
562
- # 优先使用真实长度判断 is_last_batch
563
604
  if real_len is not None:
564
605
  self.is_last_batch = (batch_idx == real_len - 1)
565
606
  elif num_batches is not None:
@@ -570,25 +611,24 @@ class Engine:
570
611
  self._process_batch_data(batch_data)
571
612
  self._fire_event("on_batch_start")
572
613
 
573
- # Yield self to allow external control (e.g., engine.step())
574
614
  yield self
575
615
 
576
- # 更新进度条
577
616
  loss_val = self.loss.item() if self.loss is not None else 0.0
578
617
 
579
- lr_str = ""
580
- if self.optimizer:
581
- current_lr = self.optimizer.param_groups[0]['lr']
582
- if current_lr < 1e-6:
583
- lr_str = f" LR: {current_lr:.2e}"
584
- else:
585
- lr_str = f" LR: {current_lr:.6f}"
586
-
587
- if self.is_in_warmup():
588
- lr_str += " [Warmup]"
618
+ if not disable_progress:
619
+ lr_str = ""
620
+ if self.optimizer:
621
+ current_lr = self.optimizer.param_groups[0]['lr']
622
+ if current_lr < 1e-6:
623
+ lr_str = f" LR: {current_lr:.2e}"
624
+ else:
625
+ lr_str = f" LR: {current_lr:.6f}"
626
+
627
+ if self.is_in_warmup():
628
+ lr_str += " [Warmup]"
589
629
 
590
- logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
591
- progress.update(task, advance=1, description=logs)
630
+ logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
631
+ progress.update(task, advance=1, description=logs)
592
632
 
593
633
  self._fire_event("on_batch_end")
594
634
 
@@ -600,7 +640,10 @@ class Engine:
600
640
  self.is_epoch_end = False
601
641
 
602
642
  def _eval_epoch_iterator(self, loader: Any, total_steps: Optional[int] = None, prefix: str = "Eval ", color: str = "yellow"):
603
- '''生成器:执行单个 Epoch 的验证/测试循环。'''
643
+ '''生成器:执行单个 Epoch 的验证/测试循环。
644
+
645
+ 注意:进度条仅在主进程中显示。
646
+ '''
604
647
  self.model.eval()
605
648
 
606
649
  try:
@@ -609,6 +652,7 @@ class Engine:
609
652
  real_len = None
610
653
 
611
654
  num_batches = total_steps if total_steps is not None else real_len
655
+ disable_progress = not self.accelerator.is_main_process
612
656
 
613
657
  with Progress(
614
658
  TextColumn(f"[{color}]{prefix}"),
@@ -617,7 +661,8 @@ class Engine:
617
661
  MofNCompleteColumn(),
618
662
  TimeRemainingColumn(),
619
663
  console=self.console,
620
- transient=True
664
+ transient=True,
665
+ disable=disable_progress
621
666
  ) as progress:
622
667
  lr_str = ""
623
668
  if self.optimizer:
@@ -629,6 +674,7 @@ class Engine:
629
674
 
630
675
  if self.is_in_warmup():
631
676
  lr_str += " [Warmup]"
677
+
632
678
  task = progress.add_task(f"{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]", total=num_batches)
633
679
 
634
680
  with torch.no_grad():
@@ -649,8 +695,9 @@ class Engine:
649
695
  yield self
650
696
 
651
697
  loss_val = self.loss.item() if self.loss is not None else 0.0
652
- logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
653
- progress.update(task, advance=1, description=logs)
698
+ if not disable_progress:
699
+ logs = f"Loss: {loss_val:.4f}{lr_str} [Ep {self.epoch+1}/{self.num_epochs}]"
700
+ progress.update(task, advance=1, description=logs)
654
701
  self._fire_event("on_batch_end")
655
702
 
656
703
  def train(
@@ -662,12 +709,15 @@ class Engine:
662
709
  ):
663
710
  '''生成器:启动训练循环,允许用户自定义 Step 逻辑。
664
711
 
712
+ 此方法会自动使用 accelerator.prepare 包装数据加载器。
713
+
665
714
  Args:
666
715
  train_loader (Any): 训练数据加载器。
667
716
  num_epochs (int, optional): 总训练轮数。
668
717
  start_epoch (int, optional): 起始 Epoch。
669
718
  total_steps (int, optional): 手动指定进度条的总步数 (用于特殊 Loader)。
670
719
  '''
720
+ train_loader = self.accelerator.prepare(train_loader)
671
721
  self.train_loader = train_loader
672
722
  self.num_epochs = num_epochs
673
723
  if start_epoch is not None:
@@ -682,7 +732,6 @@ class Engine:
682
732
 
683
733
  self._fire_event("on_epoch_start")
684
734
 
685
- # 使用生成器迭代
686
735
  epoch_loss_sum = 0.0
687
736
  count = 0
688
737
 
@@ -700,7 +749,6 @@ class Engine:
700
749
 
701
750
  self._fire_event("on_epoch_end")
702
751
 
703
- # 计算并打印 Epoch 总结
704
752
  avg_loss = epoch_loss_sum / count if count > 0 else 0.0
705
753
  self.metrics['train_loss'] = avg_loss
706
754
 
@@ -736,11 +784,14 @@ class Engine:
736
784
  ):
737
785
  '''生成器:启动评估循环。
738
786
 
787
+ 此方法会自动使用 accelerator.prepare 包装数据加载器。
788
+
739
789
  Args:
740
790
  val_loader (Any): 验证数据加载器。
741
791
  total_steps (int, optional): 手动指定进度条的总步数。
742
792
  description (str, optional): 进度条描述。
743
793
  '''
794
+ val_loader = self.accelerator.prepare(val_loader)
744
795
  self.val_loader = val_loader
745
796
  self.state = "EVAL"
746
797
  self._fire_event("on_eval_start")
@@ -800,3 +851,16 @@ class Engine:
800
851
  plugin: Checkpoint = [p for p in self.plugins if not isinstance(p, Checkpoint)][0]
801
852
  plugin._load(self, path)
802
853
  return self
854
+
855
+ def load(self, path: str, strict: bool = True) -> 'Engine':
856
+ '''从文件加载模型权重。
857
+
858
+ Args:
859
+ path (str): 权重文件路径。
860
+ strict (bool): 是否严格匹配键值。默认为 True。
861
+
862
+ Returns:
863
+ Engine: 返回 Engine 实例自身以支持链式调用。
864
+ '''
865
+ load_model(self.unwrap_model(), path, strict=strict, map_location=self.device)
866
+ return self
orbit/kit/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .interface import ChatInterface
2
+ from .wrapper import AutoRegressiveWrapper