orbit-torch 0.0.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,234 @@
+ import gc
+ import torch
+ import torch.nn as nn
+ from rich.panel import Panel
+ from rich.table import Table
+ from rich import box
+ from typing import TYPE_CHECKING, Optional, Union
+
+ from orbit.callback import Callback, Event
+
+ if TYPE_CHECKING:
+     from orbit.engine import Engine
+
+ class MemoryEstimator(Callback):
+     """
+     GPU memory estimation plugin.
+     Runs a dry-run batch before training starts to estimate peak memory usage,
+     and can also monitor memory usage while training is running.
+     """
+     def __init__(self, verbose: bool = True, alert_threshold: Union[float, str, int] = 0.8, stop_threshold: Union[float, str, int] = 0.95, clean_interval: Optional[int] = None):
+         '''
+         Args:
+             verbose (bool): Whether to print the estimation report.
+             alert_threshold (Union[float, str, int]): Warning threshold.
+                 A float is treated as a fraction of total device memory (e.g. 0.8 = 80%).
+                 A str (e.g. "4GB", "500MB") or int (bytes) is treated as an absolute size.
+             stop_threshold (Union[float, str, int]): Stop threshold. Same formats as above.
+             clean_interval (Optional[int]): If set, run a memory cleanup (gc.collect + empty_cache)
+                 every N batches. This helps with slow memory leaks but may slightly slow training.
+         '''
+         super().__init__()
+         self.verbose = verbose
+         self.alert_threshold_arg = alert_threshold
+         self.stop_threshold_arg = stop_threshold
+         self.clean_interval = clean_interval
+
+         self.alert_bytes = None
+         self.stop_bytes = None
+
+         self.has_run = False
+         self.has_alerted = False
+
+     def _parse_threshold(self, threshold: Union[float, str, int], total_capacity: int) -> int:
+         if isinstance(threshold, float):
+             return int(threshold * total_capacity)
+         if isinstance(threshold, int):
+             return threshold
+         if isinstance(threshold, str):
+             s = threshold.upper().strip()
+             if s.endswith('GB'):
+                 return int(float(s[:-2]) * (1024**3))
+             elif s.endswith('MB'):
+                 return int(float(s[:-2]) * (1024**2))
+             elif s.endswith('KB'):
+                 return int(float(s[:-2]) * 1024)
+             elif s.endswith('B'):
+                 return int(float(s[:-1]))
+             else:
+                 try:
+                     return int(float(s))
+                 except ValueError:
+                     pass
+         raise ValueError(f"Invalid memory threshold format: {threshold}")
+
+     def on_batch_start(self, event: Event):
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+
+     def on_batch_end(self, event: Event):
+         if not torch.cuda.is_available():
+             return
+
+         engine = event.engine
+         peak_memory = torch.cuda.max_memory_allocated()
+         total_capacity = torch.cuda.get_device_properties(engine.device).total_memory
+
+         # Resolve the thresholds to byte counts on first use
+         if self.stop_bytes is None:
+             self.stop_bytes = self._parse_threshold(self.stop_threshold_arg, total_capacity)
+         if self.alert_bytes is None:
+             self.alert_bytes = self._parse_threshold(self.alert_threshold_arg, total_capacity)
+
+         # Formatting helper
+         to_mb = lambda x: x / (1024 ** 2)
+
+         if peak_memory > self.stop_bytes:
+             engine.print(f"[bold red]Memory usage ({to_mb(peak_memory):.2f} MB) exceeded critical threshold ({to_mb(self.stop_bytes):.2f} MB)! Stopping training.[/]", plugin='MemEst')
+             engine.stop(source="MemoryEstimator", reason=f"Memory usage exceeded critical threshold ({to_mb(self.stop_bytes):.2f} MB)")
+         elif peak_memory > self.alert_bytes and not self.has_alerted:
+             engine.print(f"[yellow]Memory usage ({to_mb(peak_memory):.2f} MB) exceeded warning threshold ({to_mb(self.alert_bytes):.2f} MB).[/]", plugin='MemEst')
+             self.has_alerted = True
+
+         # Periodic memory cleanup
+         if self.clean_interval and (engine.batch_idx + 1) % self.clean_interval == 0:
+             gc.collect()
+             torch.cuda.empty_cache()
+
+     def on_train_start(self, event: Event):
+         if self.has_run:
+             return
+
+         engine = event.engine
+         if not torch.cuda.is_available():
+             if self.verbose:
+                 engine.print("[yellow]CUDA not available. Skipping memory estimation.[/]", plugin='MemEst')
+             return
+
+         # Make sure the model is on a CUDA device
+         device = engine.device
+         if device.type != 'cuda':
+             return
+
+         try:
+             self._estimate(engine)
+         except Exception as e:
+             engine.print(f"[red]Error during memory estimation: {e}[/]", plugin='MemEst')
+         finally:
+             # Clean up after the dry run
+             if engine.optimizer:
+                 engine.optimizer.zero_grad()
+             torch.cuda.empty_cache()
+             torch.cuda.reset_peak_memory_stats()
+             self.has_run = True
+
+     def _estimate(self, engine: 'Engine'):
+         if self.verbose:
+             engine.print("Running dry run for memory estimation...", plugin='MemEst')
+
+         # 1. Fetch a single batch
+         try:
+             batch_data = next(iter(engine.train_loader))
+         except StopIteration:
+             engine.print("[yellow]Train loader is empty. Skipping.[/]", plugin='MemEst')
+             return
+
+         # 2. Prepare the environment
+         torch.cuda.empty_cache()
+         torch.cuda.reset_peak_memory_stats()
+         initial_memory = torch.cuda.memory_allocated()
+
+         # Static model size (weights + buffers)
+         model_stats = self._get_model_size(engine.model)
+
+         # 3. Simulate forward & backward
+         try:
+             # Move the batch to the device
+             engine._process_batch_data(batch_data)
+
+             # Forward
+             with torch.amp.autocast(device_type=engine.device.type, enabled=engine.use_amp):
+                 if isinstance(engine.data, (list, tuple)):
+                     output = engine.model(*engine.data)
+                 else:
+                     output = engine.model(engine.data)
+
+                 # Build a dummy loss
+                 if engine.criterion and engine.target is not None:
+                     loss = engine.criterion(output, engine.target)
+                 else:
+                     # Without a target or criterion, build a scalar loss just for backward
+                     if isinstance(output, torch.Tensor):
+                         loss = output.mean()
+                     elif isinstance(output, (list, tuple)) and isinstance(output[0], torch.Tensor):
+                         loss = output[0].mean()
+                     elif isinstance(output, dict):
+                         loss = list(output.values())[0].mean()
+                     else:
+                         loss = torch.tensor(0.0, device=engine.device, requires_grad=True)
+
+             # Backward
+             if engine.use_amp and engine.scaler:
+                 engine.scaler.scale(loss).backward()
+             else:
+                 loss.backward()
+
+             # Read the peak memory for this dry run
+             peak_memory = torch.cuda.max_memory_allocated()
+             total_capacity = torch.cuda.get_device_properties(engine.device).total_memory
+
+             self._print_report(engine, model_stats, initial_memory, peak_memory, total_capacity)
+
+         except RuntimeError as e:
+             if "out of memory" in str(e):
+                 engine.print("[bold red]OOM detected during memory estimation![/]", plugin='MemEst')
+                 engine.print(f"[red]Your batch size is likely too large for this device.[/]", plugin='MemEst')
+             else:
+                 raise e
+
+     def _get_model_size(self, model: nn.Module) -> int:
+         """Total number of bytes used by the model's parameters and buffers."""
+         mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
+         mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
+         return mem_params + mem_bufs
+
+     def _print_report(self, engine: 'Engine', model_size: int, initial: int, peak: int, capacity: int):
+         if not self.verbose: return
+
+         # Convert to MB
+         to_mb = lambda x: x / (1024 ** 2)
+
+         model_mb = to_mb(model_size)
+         peak_mb = to_mb(peak)
+         capacity_mb = to_mb(capacity)
+         usage_percent = (peak / capacity) * 100
+
+         # Colour coding
+         if usage_percent < 70:
+             color = "green"
+             status = "Safe"
+         elif usage_percent < 90:
+             color = "yellow"
+             status = "Warning"
+         else:
+             color = "red"
+             status = "Critical"
+
+         table = Table(box=box.SIMPLE, show_header=False)
+         table.add_column("Item", style="cyan")
+         table.add_column("Value", justify="right")
+
+         table.add_row("Model Weights", f"{model_mb:.2f} MB")
+         table.add_row("Est. Peak Memory", f"[{color}]{peak_mb:.2f} MB[/]")
+         table.add_row("Device Capacity", f"{capacity_mb:.2f} MB")
+         table.add_row("Usage", f"[{color}]{usage_percent:.1f}% ({status})[/]")
+
+         panel = Panel(
+             table,
+             title="[bold]Memory Estimation Report[/]",
+             border_style="blue",
+             expand=False
+         )
+
+         with engine.out_logs:
+             engine.console.print(panel)
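The threshold arguments accept either a fraction of the device's total memory or an absolute size, as implemented by _parse_threshold above. The standalone sketch below mirrors that parsing arithmetic (no orbit import needed) to show how the different formats resolve to byte counts; the 16 GiB capacity is just an example value.

# Standalone check of the threshold semantics implemented by _parse_threshold above.
total_capacity = 16 * 1024**3          # pretend a 16 GiB device

def parse(threshold, total=total_capacity):
    if isinstance(threshold, float):   # fraction of total device memory
        return int(threshold * total)
    if isinstance(threshold, int):     # already bytes
        return threshold
    s = threshold.upper().strip()      # "4GB", "500MB", "64KB", "123B"
    units = {"GB": 1024**3, "MB": 1024**2, "KB": 1024, "B": 1}
    for suffix, mul in units.items():
        if s.endswith(suffix):
            return int(float(s[:-len(suffix)]) * mul)
    return int(float(s))               # bare number -> bytes

print(parse(0.8))        # 13743895347 (80% of 16 GiB)
print(parse("4GB"))      # 4294967296
print(parse("500MB"))    # 524288000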
orbit/plugin/mentor.py ADDED
@@ -0,0 +1,313 @@
+ import numpy as np
+ import json
+ import os
+ import locale
+ from typing import Optional, List, TYPE_CHECKING, Tuple, Dict
+ from rich.panel import Panel
+ from orbit.callback import Callback, Event
+
+ if TYPE_CHECKING:
+     from orbit.engine import Engine
+     from orbit.plugin.warmup import Warmup
+
+ class Mentor(Callback):
+     """
+     Mentor plugin: watches the training run and offers suggestions for improvement.
+     It focuses on abnormal loss behaviour (stagnation, divergence, overfitting,
+     oscillation) and combines that with the run's configuration to give advice.
+     """
+     def __init__(
+         self,
+         patience: int = 3,
+         threshold: float = 1e-4,
+         divergence_threshold: float = 2.0,
+         language: Optional[str] = None,
+         verbose: bool = True
+     ):
+         """
+         Args:
+             patience (int): Number of consecutive epochs without a significant loss
+                 improvement before the run is considered stagnant.
+             threshold (float): Minimum loss change that counts as an improvement.
+             divergence_threshold (float): The run is considered diverging once the current
+                 loss exceeds this multiple of the best loss.
+             language (str): Language code ('en' or 'zh'). If None, the system language is detected.
+             verbose (bool): Whether to print the advice.
+         """
+         super().__init__()
+         self.patience = patience
+         self.threshold = threshold
+         self.divergence_threshold = divergence_threshold
+         self.language = language if language else self._detect_language()
+         self.verbose = verbose
+
+         self.loss_history: List[float] = []
+         self.val_loss_history: List[float] = []
+         self.min_loss = np.inf
+         self.stagnation_counter = 0
+         self.increase_counter = 0
+
+         # Warmup-related state
+         self.has_warmup = False
+         self.warmup_plugin: Optional['Warmup'] = None
+
+         # Other configuration picked up from the engine
+         self.has_scheduler = False
+         self.grad_clip_norm = None
+         self.batch_size = 0
+         self.accum_steps = 1
+
+         # Load translations
+         self.i18n = self._load_i18n()
+
+     def _detect_language(self) -> str:
+         try:
+             lang_code, _ = locale.getdefaultlocale()
+             if lang_code and lang_code.startswith('zh'):
+                 return 'zh'
+         except Exception:
+             pass
+         return 'en'
+
+     def _load_i18n(self) -> Dict[str, str]:
+         try:
+             # Directory containing this file
+             current_dir = os.path.dirname(os.path.abspath(__file__))
+             json_path = os.path.join(current_dir, 'data', 'mentor_i18n.json')
+
+             with open(json_path, 'r', encoding='utf-8') as f:
+                 data = json.load(f)
+
+             return data.get(self.language, data.get('en', {}))
+         except Exception as e:
+             print(f"[Mentor] Warning: Failed to load i18n file: {e}. Fallback to keys.")
+             return {}
+
+     def _t(self, key: str, **kwargs) -> str:
+         """Look up a translation and format it."""
+         template = self.i18n.get(key, key)
+         try:
+             return template.format(**kwargs)
+         except KeyError:
+             return template
+
+     def on_train_start(self, event: Event):
+         engine = event.engine
+         # 1. Check for a Warmup plugin
+         for p in engine.plugins:
+             if p.__class__.__name__ == 'Warmup':
+                 self.has_warmup = True
+                 self.warmup_plugin = p
+                 break
+
+         # 2. Check scheduler & gradient clipping
+         self.has_scheduler = engine.scheduler is not None
+         self.grad_clip_norm = getattr(engine, 'grad_clip_norm', None)
+
+         # 3. Check batch size & accumulation
+         self.accum_steps = engine.accumulation_steps
+         if hasattr(engine, 'train_loader') and hasattr(engine.train_loader, 'batch_size'):
+             self.batch_size = engine.train_loader.batch_size or 1  # Handle None
+
+         if self.verbose:
+             eff_bs = self.batch_size * self.accum_steps
+             msg = self._t("mentor_watching", eff_bs=eff_bs, bs=self.batch_size, accum=self.accum_steps)
+             engine.print(msg, plugin='Mentor')
+
+     def on_epoch_end(self, event: Event):
+         engine = event.engine
+         current_loss = engine.metrics.get('train_loss')
+         val_loss = engine.metrics.get('val_loss')
+
+         if current_loss is None:
+             return
+
+         self.loss_history.append(current_loss)
+         if val_loss is not None:
+             self.val_loss_history.append(val_loss)
+
+         # Collect all advice for this epoch (title, content, style)
+         advice_list: List[Tuple[str, str, str]] = []
+
+         # 1. Check for NaN / Inf
+         if not np.isfinite(current_loss):
+             advice_list.append((
+                 self._t("nan_loss_title"),
+                 self._t("nan_loss_msg"),
+                 "bold red"
+             ))
+         else:
+             # Update the best loss and the consecutive-increase counter
+             if current_loss < self.min_loss:
+                 self.min_loss = current_loss
+                 self.stagnation_counter = 0
+                 self.increase_counter = 0
+             else:
+                 self.stagnation_counter += 1
+                 # Did the loss increase compared with the previous epoch?
+                 if len(self.loss_history) > 1 and current_loss > self.loss_history[-2]:
+                     self.increase_counter += 1
+                 else:
+                     self.increase_counter = 0
+
+         # 2. Check for divergence
+         is_diverging = (current_loss > self.min_loss * self.divergence_threshold)
+         is_unstable = (self.increase_counter >= 2)
+
+         if (is_diverging or is_unstable) and len(self.loss_history) > 1:
+             res = self._analyze_divergence(engine, current_loss, is_unstable)
+             if res: advice_list.append(res)
+
+             # Reset so the same advice is not repeated every epoch
+             if is_diverging:
+                 self.min_loss = current_loss
+             if is_unstable:
+                 self.increase_counter = 0
+
+         # 3. Check for stagnation
+         if self.stagnation_counter >= self.patience:
+             res = self._analyze_stagnation(engine, current_loss)
+             if res: advice_list.append(res)
+             self.stagnation_counter = 0  # Reset the counter
+
+         # 4. Check for overfitting
+         if len(self.val_loss_history) >= 3:
+             res = self._analyze_overfitting(engine)
+             if res: advice_list.append(res)
+
+         # 5. Check for oscillation
+         if len(self.loss_history) >= 5:
+             res = self._analyze_oscillation(engine)
+             if res: advice_list.append(res)
+
+         # Print everything in one place
+         if advice_list and self.verbose:
+             with engine.out_logs:
+                 for title, content, style in advice_list:
+                     engine.console.print(Panel(content, title=title, border_style=style, expand=False))
+
+     def _analyze_divergence(self, engine: 'Engine', current_loss: float, is_unstable: bool = False) -> Tuple[str, str, str]:
+         in_warmup = engine.is_in_warmup()
+         eff_bs = self.batch_size * self.accum_steps
+
+         title = self._t("divergence_title")
+         if is_unstable:
+             msg = self._t("divergence_msg_unstable", count=self.increase_counter, loss=current_loss)
+         else:
+             msg = self._t("divergence_msg_spike", loss=current_loss, min_loss=self.min_loss)
+
+         advice = []
+
+         # 1. Warmup advice
+         if not self.has_warmup:
+             advice.append(self._t("advice_add_warmup"))
+             advice.append(self._t("advice_lower_lr"))
+         elif in_warmup:
+             advice.append(self._t("advice_warmup_start_lr"))
+         else:
+             advice.append(self._t("advice_post_warmup_lr"))
+
+         # 2. Batch size & accumulation advice
+         if eff_bs < 32:
+             advice.append(self._t("advice_small_bs", eff_bs=eff_bs))
+             advice.append(self._t("advice_increase_accum", accum_steps=self.accum_steps))
+
+         # 3. General advice
+         advice.append(self._t("advice_grad_clip"))
+
+         return title, msg + "\n\n" + "\n".join(advice), "red"
+
+     def _analyze_stagnation(self, engine: 'Engine', current_loss: float) -> Tuple[str, str, str]:
+         in_warmup = engine.is_in_warmup()
+         eff_bs = self.batch_size * self.accum_steps
+
+         title = self._t("stagnation_title")
+         msg = self._t("stagnation_msg", patience=self.patience)
+
+         advice = []
+
+         if in_warmup:
+             # Stagnating during warmup
+             warmup_epochs = getattr(self.warmup_plugin, 'warmup_epochs', 0)
+             total_epochs = engine.num_epochs
+
+             advice.append(self._t("advice_warmup_duration", epoch=engine.epoch+1))
+             if warmup_epochs > total_epochs * 0.2:
+                 advice.append(self._t("advice_warmup_too_long", warmup_epochs=warmup_epochs))
+             advice.append(self._t("advice_check_start_lr"))
+
+         else:
+             # Stagnating outside warmup
+             advice.append(self._t("advice_lr_general"))
+
+             # Scheduler advice
+             if not self.has_scheduler:
+                 advice.append(self._t("advice_add_scheduler"))
+             else:
+                 advice.append(self._t("advice_check_scheduler"))
+
+             # Batch size advice
+             if eff_bs > 4096:
+                 advice.append(self._t("advice_large_bs", eff_bs=eff_bs))
+                 advice.append(self._t("advice_reduce_bs"))
+
+         # 4. Initialisation advice (early epochs only)
+         if engine.epoch < 10:
+             advice.append(self._t("advice_check_init"))
+
+         # 5. Debugging and data advice
+         advice.append(self._t("advice_overfit_single_batch"))
+         advice.append(self._t("advice_data_hard"))
+
+         return title, msg + "\n\n" + "\n".join(advice), "yellow"
+
+     def _analyze_overfitting(self, engine: 'Engine') -> Optional[Tuple[str, str, str]]:
+         """Detect overfitting: train loss falling while validation loss rises."""
+         if len(self.val_loss_history) < 3 or len(self.loss_history) < 3:
+             return None
+
+         # Look at the last 3 epochs
+         recent_val = self.val_loss_history[-3:]
+         recent_train = self.loss_history[-3:]
+
+         # Validation loss keeps rising
+         val_rising = (recent_val[-1] > recent_val[-2] > recent_val[-3])
+         # Train loss keeps falling (or stays low)
+         train_dropping = (recent_train[-1] <= recent_train[-2])
+
+         if val_rising and train_dropping:
+             title = self._t("overfitting_title")
+             msg = self._t("overfitting_msg")
+             advice = [
+                 self._t("advice_regularization"),
+                 self._t("advice_data_aug"),
+                 self._t("advice_early_stopping")
+             ]
+             return title, msg + "\n\n" + "\n".join(advice), "magenta"
+         return None
+
+     def _analyze_oscillation(self, engine: 'Engine') -> Optional[Tuple[str, str, str]]:
+         """Detect oscillation: the loss standard deviation is too large."""
+         if len(self.loss_history) < 5:
+             return None
+
+         recent_loss = self.loss_history[-5:]
+         std_dev = np.std(recent_loss)
+         mean_loss = np.mean(recent_loss)
+
+         # Flag when the standard deviation exceeds 10% of the mean (a rule of thumb)
+         # and there is no sustained downward trend.
+         # Simple trend check: the first and last values differ by less than the std.
+         is_flat_trend = abs(recent_loss[-1] - recent_loss[0]) < std_dev
+
+         if std_dev > 0.1 * mean_loss and is_flat_trend:
+             title = self._t("oscillation_title")
+             msg = self._t("oscillation_msg", std=std_dev)
+             advice = [
+                 self._t("advice_lower_lr_oscillation")
+             ]
+
+             if not self.has_scheduler:
+                 advice.append(self._t("advice_oscillation_scheduler"))
+
+             if not self.grad_clip_norm:
+                 advice.append(self._t("advice_oscillation_grad_clip"))
+
+             return title, msg + "\n\n" + "\n".join(advice), "cyan"
+         return None
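Most of the heuristics above reduce to simple statistics over the recent loss history. As an illustration, the oscillation check in _analyze_oscillation fires when the standard deviation of the last five epoch losses exceeds 10% of their mean while the first-to-last change stays within one standard deviation. A standalone sketch of that rule, with made-up loss values:

# Standalone illustration of the oscillation heuristic in _analyze_oscillation above.
import numpy as np

def is_oscillating(loss_history):
    if len(loss_history) < 5:
        return False
    recent = loss_history[-5:]
    std_dev, mean_loss = np.std(recent), np.mean(recent)
    is_flat_trend = abs(recent[-1] - recent[0]) < std_dev   # no clear trend
    return std_dev > 0.1 * mean_loss and is_flat_trend

print(is_oscillating([0.9, 0.7, 0.55, 0.45, 0.40]))    # False: steadily decreasing
print(is_oscillating([0.50, 0.70, 0.45, 0.72, 0.48]))  # True: bouncing around ~0.57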
@@ -0,0 +1,30 @@
+ from orbit.callback import Callback, Event
+ from typing import Any
+
+ class Overfit(Callback):
+     '''Plugin that deliberately repeats the first batch so the model can overfit it.
+
+     This is useful as a sanity check that the architecture is capable of fitting the data.
+     If the model cannot overfit a single batch, there is probably a bug in the code or a
+     problem with the architecture.
+     '''
+
+     def __init__(self):
+         super().__init__()
+         self.fixed_data: Any = None
+         self.fixed_target: Any = None
+         self.has_captured = False
+
+     def on_batch_start(self, event: Event):
+         # Only active during the training phase
+         if event.engine.state != "TRAIN":
+             return
+
+         if not self.has_captured:
+             # Capture the first batch
+             self.fixed_data = event.engine.data
+             self.fixed_target = event.engine.target
+             self.has_captured = True
+             event.engine.print("[yellow]Overfit Plugin: Captured first batch. All subsequent batches will be replaced by this one.[/]", plugin='Overfit')
+         else:
+             # Replace every subsequent batch
+             event.engine.data = self.fixed_data
+             event.engine.target = self.fixed_target
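A single-batch overfit run is typically used as a quick sanity check before a full training run: with the plugin enabled, the training loss should fall towards zero within a few epochs, and failure to do so points to a bug in the model, loss, or data pipeline. The sketch below shows that workflow; the Engine constructor arguments, the plugins= keyword, the engine.train() call, and this file's module path are all assumptions, since the Engine API is not part of this diff.

# Hypothetical usage sketch (Engine API assumed, not shown in this diff).
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from orbit.engine import Engine
from orbit.plugin.overfit import Overfit   # module path assumed

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))), batch_size=16)

# ASSUMPTION: Engine accepts these arguments plus a plugins list and exposes train();
# adjust to the actual orbit-torch API.
engine = Engine(model, optimizer, criterion, train_loader=train_loader, plugins=[Overfit()])
engine.train()   # train loss should approach zero; if it does not, suspect a bug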
orbit/plugin/warmup.py ADDED
@@ -0,0 +1,119 @@
+ from typing import Optional, List, TYPE_CHECKING
+ from orbit.callback import Callback, Event
+
+ if TYPE_CHECKING:
+     from orbit.engine import Engine
+
+ class Warmup(Callback):
+     """
+     Learning-rate warmup plugin.
+     Supports three modes: linear, constant and Noam (Transformer).
+     The learning rate is adjusted dynamically at batch granularity.
+     """
+     def __init__(
+         self,
+         warmup_steps: int = 0,
+         warmup_epochs: int = 0,
+         mode: str = 'linear',
+         min_lr: float = 0.0,
+         model_dim: Optional[int] = None,
+         scale: float = 1.0,
+     ):
+         """
+         Args:
+             warmup_steps (int): Total number of warmup steps (batches).
+             warmup_epochs (int): Total number of warmup epochs. If set, it takes precedence over warmup_steps.
+             mode (str): 'linear' | 'constant' | 'noam'.
+             min_lr (float): Starting learning rate (linear/constant modes).
+             model_dim (int): Model dimension (required for noam mode only).
+             scale (float): Scaling factor (noam mode only).
+         """
+         super().__init__()
+         self.warmup_steps = warmup_steps
+         self.warmup_epochs = warmup_epochs
+         self.mode = mode.lower()
+         self.min_lr = min_lr
+         self.model_dim = model_dim
+         self.scale = scale
+
+         self.total_warmup_steps = 0
+         self.base_lrs: List[float] = []
+
+         # Validate arguments
+         if self.mode == 'noam' and self.model_dim is None:
+             raise ValueError("Noam mode requires 'model_dim' to be specified.")
+
+     def on_train_start(self, event: Event):
+         """
+         At the start of training, compute the total number of warmup steps and record the initial learning rates.
+         """
+         engine = event.engine
+         if not engine.optimizer:
+             raise ValueError("Warmup plugin requires an optimizer in the Engine.")
+
+         # 1. Record the optimizer's base learning rates.
+         # When resuming from a checkpoint the base LRs may change, but for warmup purposes
+         # the target is param_group['initial_lr'] if present, otherwise the current ['lr'].
+         self.base_lrs = []
+         for group in engine.optimizer.param_groups:
+             # Prefer initial_lr (set by some schedulers), fall back to the current lr
+             self.base_lrs.append(group.get('initial_lr', group['lr']))
+
+         # 2. Compute total_warmup_steps
+         if self.warmup_epochs > 0:
+             if not engine.train_loader:
+                 raise ValueError("warmup_epochs requires train_loader to be available.")
+             try:
+                 steps_per_epoch = len(engine.train_loader)
+                 self.total_warmup_steps = self.warmup_epochs * steps_per_epoch
+             except TypeError:
+                 # If the train_loader has no length (e.g. an iterable dataset), steps must be given explicitly
+                 raise ValueError("Could not determine length of train_loader. Please use 'warmup_steps' instead.")
+         else:
+             self.total_warmup_steps = self.warmup_steps
+
+         # Announce the strategy
+         if self.total_warmup_steps > 0 or self.mode == 'noam':
+             engine.print(f"[magenta]Strategy activated: {self.mode}[/]", plugin='Warmup')
+             if self.mode != 'noam':
+                 engine.print(f"[magenta]Steps: {self.total_warmup_steps} (Epochs: {self.warmup_epochs})[/]", plugin='Warmup')
+
+     def on_batch_start(self, event: Event):
+         """
+         Adjust the learning rate before each batch.
+         """
+         engine = event.engine
+         # Current step (counted from 1 to keep the formulas simple)
+         current_step = engine.global_step + 1
+
+         # Outside the warmup range (and not in Noam mode) we do not interfere.
+         # Noam is a full-run schedule, so it keeps running.
+         if self.mode != 'noam' and current_step > self.total_warmup_steps:
+             return
+
+         for i, group in enumerate(engine.optimizer.param_groups):
+             base_lr = self.base_lrs[i]
+             new_lr = base_lr
+
+             if self.mode == 'noam':
+                 # Noam scheduler: scale * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
+                 warmup = self.total_warmup_steps
+                 # Avoid step = 0
+                 step = max(1, current_step)
+
+                 term1 = step ** -0.5
+                 term2 = step * (warmup ** -1.5)
+
+                 new_lr = self.scale * (self.model_dim ** -0.5) * min(term1, term2)
+
+             elif self.mode == 'linear':
+                 # Linear: min_lr -> base_lr
+                 alpha = current_step / self.total_warmup_steps
+                 new_lr = self.min_lr + (base_lr - self.min_lr) * alpha
+
+             elif self.mode == 'constant':
+                 # Constant: min_lr
+                 new_lr = self.min_lr
+
+             # Apply the new learning rate
+             group['lr'] = new_lr
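The Noam branch in on_batch_start computes the Transformer schedule lr = scale * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5): the rate climbs linearly for `warmup` steps and then decays with the inverse square root of the step. The standalone sketch below evaluates the same arithmetic to show the shape of the curve; model_dim=512 and 4000 warmup steps are example values (the usual Transformer-base setting), not defaults of this plugin.

# Standalone reproduction of the Noam learning-rate curve computed in Warmup.on_batch_start.
def noam_lr(step, model_dim=512, warmup=4000, scale=1.0):
    step = max(1, step)
    return scale * (model_dim ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 1000, 4000, 8000, 40000):
    print(f"step {step:>6}: lr = {noam_lr(step):.6f}")
# step      1: lr = 0.000000  (about 1.7e-07, start of warmup)
# step   1000: lr = 0.000175
# step   4000: lr = 0.000699  (peak, end of warmup)
# step   8000: lr = 0.000494
# step  40000: lr = 0.000221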