orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from rich.panel import Panel
|
|
5
|
+
from rich.table import Table
|
|
6
|
+
from rich import box
|
|
7
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
8
|
+
|
|
9
|
+
from orbit.callback import Callback, Event
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from orbit.engine import Engine
|
|
13
|
+
|
|
14
|
+
class MemoryEstimator(Callback):
    """
    GPU memory estimation plugin.

    Before training starts, runs a single dry-run batch (forward + backward) to
    estimate peak GPU memory usage and prints a report. During training it also
    monitors the per-batch peak. It warns once when ``alert_threshold`` is exceeded
    and calls ``engine.stop`` when ``stop_threshold`` is exceeded. Everything is
    skipped unless the engine runs on an available CUDA device.
    """

    # (suffix, bytes-per-unit) pairs accepted in string thresholds. Units are
    # binary (1 KB = 1024 B). Longer suffixes come first so "GIB" is tried before "B".
    _UNIT_SUFFIXES = (
        ('GIB', 1024 ** 3),
        ('MIB', 1024 ** 2),
        ('KIB', 1024),
        ('GB', 1024 ** 3),
        ('MB', 1024 ** 2),
        ('KB', 1024),
        ('B', 1),
    )

    def __init__(self, verbose: bool = True, alert_threshold: Union[float, str, int] = 0.8, stop_threshold: Union[float, str, int] = 0.95, clean_interval: Optional[int] = None):
        """
        Args:
            verbose (bool): Whether to print the estimation report.
            alert_threshold (Union[float, str, int]): Warning threshold.
                A float <= 1.0 is a fraction of total device memory (e.g. 0.8 = 80%).
                A float > 1.0 or an int is an absolute number of bytes. A str such
                as "4GB", "500MB" or "512MiB" is an absolute size (1 KB = 1024 B).
            stop_threshold (Union[float, str, int]): Stop threshold; same format as above.
            clean_interval (Optional[int]): If given, run ``gc.collect()`` and
                ``torch.cuda.empty_cache()`` every ``clean_interval`` batches. This
                can help with slow memory leaks at a small cost in training speed.
        """
        super().__init__()
        self.verbose = verbose
        self.alert_threshold_arg = alert_threshold
        self.stop_threshold_arg = stop_threshold
        self.clean_interval = clean_interval

        # Byte thresholds are resolved lazily on the first monitored batch,
        # once the device capacity is known.
        self.alert_bytes = None
        self.stop_bytes = None

        self.has_run = False      # the dry run has already been attempted
        self.has_alerted = False  # the warning threshold has already been reported

    @staticmethod
    def _to_mb(num_bytes: float) -> float:
        """Convert a byte count to MiB for display."""
        return num_bytes / (1024 ** 2)

    @staticmethod
    def _cuda_device(engine: 'Engine') -> Optional[torch.device]:
        """Return the engine's device if it is an available CUDA device, else None."""
        if not torch.cuda.is_available():
            return None
        device = torch.device(engine.device)
        return device if device.type == 'cuda' else None

    def _parse_threshold(self, threshold: Union[float, str, int], total_capacity: int) -> int:
        """Convert a threshold argument into an absolute number of bytes.

        Args:
            threshold: A float <= 1.0 (fraction of ``total_capacity``), a float > 1.0
                or an int (absolute bytes), or a string such as "4GB", "500MB",
                "512KiB", "1024B" or "1e9".
            total_capacity: Total device memory in bytes.

        Returns:
            The threshold in bytes.

        Raises:
            ValueError: If ``threshold`` cannot be interpreted.
        """
        if isinstance(threshold, float):
            # Only fractions are relative to capacity. Larger floats are byte counts,
            # as documented in ``__init__``.
            if threshold <= 1.0:
                return int(threshold * total_capacity)
            return int(threshold)
        if isinstance(threshold, int):
            return threshold
        if isinstance(threshold, str):
            text = threshold.upper().strip()
            multiplier = 1
            for suffix, factor in self._UNIT_SUFFIXES:
                if text.endswith(suffix):
                    text = text[:-len(suffix)].strip()
                    multiplier = factor
                    break
            try:
                return int(float(text) * multiplier)
            except (ValueError, OverflowError):
                pass
        raise ValueError(f"Invalid memory threshold format: {threshold}")

    def on_batch_start(self, event: Event):
        """Reset the peak-memory counter so ``on_batch_end`` sees this batch's peak."""
        device = self._cuda_device(event.engine)
        if device is not None:
            torch.cuda.reset_peak_memory_stats(device)

    def on_batch_end(self, event: Event):
        """Compare this batch's peak memory with the thresholds and optionally clean up.

        Prints a one-time warning above ``alert_bytes`` and calls ``engine.stop``
        above ``stop_bytes``. Every ``clean_interval`` batches it also runs
        ``gc.collect()`` and ``torch.cuda.empty_cache()``.
        """
        engine = event.engine
        device = self._cuda_device(engine)
        if device is None:
            return

        peak_memory = torch.cuda.max_memory_allocated(device)
        total_capacity = torch.cuda.get_device_properties(device).total_memory

        # Resolve byte thresholds on first use.
        if self.stop_bytes is None:
            self.stop_bytes = self._parse_threshold(self.stop_threshold_arg, total_capacity)
        if self.alert_bytes is None:
            self.alert_bytes = self._parse_threshold(self.alert_threshold_arg, total_capacity)

        to_mb = self._to_mb
        if peak_memory > self.stop_bytes:
            engine.print(f"[bold red]Memory usage ({to_mb(peak_memory):.2f} MB) exceeded critical threshold ({to_mb(self.stop_bytes):.2f} MB)! Stopping training.[/]", plugin='MemEst')
            engine.stop(source="MemoryEstimator", reason=f"Memory usage exceeded critical threshold ({to_mb(self.stop_bytes):.2f} MB)")
        elif peak_memory > self.alert_bytes and not self.has_alerted:
            engine.print(f"[yellow]Memory usage ({to_mb(peak_memory):.2f} MB) exceeded warning threshold ({to_mb(self.alert_bytes):.2f} MB).[/]", plugin='MemEst')
            self.has_alerted = True

        # Periodic cleanup.
        if self.clean_interval and (engine.batch_idx + 1) % self.clean_interval == 0:
            gc.collect()
            torch.cuda.empty_cache()

    def on_train_start(self, event: Event):
        """Run the dry-run estimation once, before the first training run."""
        if self.has_run:
            return

        engine = event.engine
        if not torch.cuda.is_available():
            if self.verbose:
                engine.print("[yellow]CUDA not available. Skipping memory estimation.[/]", plugin='MemEst')
            return

        # Estimation only makes sense when the model trains on a CUDA device.
        if torch.device(engine.device).type != 'cuda':
            return

        try:
            self._estimate(engine)
        except Exception as e:
            # Estimation is best-effort: report the problem but never block training.
            engine.print(f"[red]Error during memory estimation: {e}[/]", plugin='MemEst')
        finally:
            # Drop gradients and cached blocks left by the dry run, and start real
            # training with clean peak statistics.
            if engine.optimizer:
                engine.optimizer.zero_grad()
            if getattr(engine, 'model', None) is not None:
                engine.model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(engine.device)
            self.has_run = True

    def _estimate(self, engine: 'Engine'):
        """Run a side-effect-free dry run on the first training batch and report it.

        The CPU and CUDA RNG states and the model buffers are restored afterwards.
        Buffers include, e.g., BatchNorm running statistics. This means creating a
        loader iterator (sampler shuffling, worker seeds), dropout, and normalisation
        layers in the dry run do not perturb the real training run.
        """
        if self.verbose:
            engine.print("Running dry run for memory estimation...", plugin='MemEst')

        device = torch.device(engine.device)
        # fork_rng expects CUDA device indices; a bare 'cuda' means the current device.
        device_index = device.index if device.index is not None else torch.cuda.current_device()

        with torch.random.fork_rng(devices=[device_index]):
            # 1. Fetch one batch of data.
            try:
                batch_data = next(iter(engine.train_loader))
            except StopIteration:
                engine.print("[yellow]Train loader is empty. Skipping.[/]", plugin='MemEst')
                return

            # Keep CPU copies of the buffers so the GPU measurement is unaffected.
            saved_buffers = {
                name: buf.detach().to('cpu', copy=True)
                for name, buf in engine.model.named_buffers()
            }
            try:
                self._dry_run(engine, batch_data, device)
            finally:
                with torch.no_grad():
                    for name, buf in engine.model.named_buffers():
                        if name in saved_buffers:
                            buf.copy_(saved_buffers[name])

    def _dry_run(self, engine: 'Engine', batch_data, device: torch.device) -> None:
        """Run one forward/backward pass on ``batch_data`` and print the peak-memory report.

        A CUDA out-of-memory error is reported instead of raised. Any other
        RuntimeError propagates.
        """
        # 2. Prepare a clean measurement window.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        initial_memory = torch.cuda.memory_allocated(device)

        # Static model size (weights + buffers).
        model_stats = self._get_model_size(engine.model)

        # 3. Simulate forward & backward.
        try:
            # Move the data; the engine stores it on engine.data / engine.target.
            engine._process_batch_data(batch_data)

            with torch.amp.autocast(device_type=device.type, enabled=engine.use_amp):
                if isinstance(engine.data, (list, tuple)):
                    output = engine.model(*engine.data)
                else:
                    output = engine.model(engine.data)

                if engine.criterion and engine.target is not None:
                    loss = engine.criterion(output, engine.target)
                else:
                    # No target or criterion: build a scalar just to drive backward().
                    loss = self._surrogate_loss(output, device)

            if engine.use_amp and engine.scaler:
                engine.scaler.scale(loss).backward()
            else:
                loss.backward()

            peak_memory = torch.cuda.max_memory_allocated(device)
            total_capacity = torch.cuda.get_device_properties(device).total_memory

            self._print_report(engine, model_stats, initial_memory, peak_memory, total_capacity)

        except RuntimeError as e:
            # torch.cuda.OutOfMemoryError subclasses RuntimeError.
            if "out of memory" in str(e):
                engine.print("[bold red]OOM detected during memory estimation![/]", plugin='MemEst')
                engine.print("[red]Your batch size is likely too large for this device.[/]", plugin='MemEst')
            else:
                raise

    @staticmethod
    def _surrogate_loss(output, device: torch.device) -> torch.Tensor:
        """Build a scalar loss from the model output when no criterion/target is available.

        Uses the mean of the first tensor that requires grad. The candidates are:
        ``output`` itself, the elements of a list/tuple, or the values of a dict.
        If none is found, a zero scalar that requires grad is returned.
        """
        if isinstance(output, dict):
            candidates = list(output.values())
        elif isinstance(output, (list, tuple)):
            candidates = list(output)
        else:
            candidates = [output]
        for value in candidates:
            if isinstance(value, torch.Tensor) and value.requires_grad:
                return value.mean()
        return torch.tensor(0.0, device=device, requires_grad=True)

    def _get_model_size(self, model: nn.Module) -> int:
        """Return the total number of bytes used by the model's parameters and buffers."""
        mem_params = sum(param.nelement() * param.element_size() for param in model.parameters())
        mem_bufs = sum(buf.nelement() * buf.element_size() for buf in model.buffers())
        return mem_params + mem_bufs

    def _print_report(self, engine: 'Engine', model_size: int, initial: int, peak: int, capacity: int):
        """Print a Rich panel summarising the dry-run memory estimate.

        Usage below 70% of capacity is shown green ("Safe"), below 90% yellow
        ("Warning"), and red ("Critical") otherwise. Does nothing when
        ``verbose`` is False. ``initial`` is accepted but not currently shown.
        """
        if not self.verbose:
            return

        model_mb = self._to_mb(model_size)
        peak_mb = self._to_mb(peak)
        capacity_mb = self._to_mb(capacity)
        usage_percent = (peak / capacity) * 100

        # Colour coding.
        if usage_percent < 70:
            color = "green"
            status = "Safe"
        elif usage_percent < 90:
            color = "yellow"
            status = "Warning"
        else:
            color = "red"
            status = "Critical"

        table = Table(box=box.SIMPLE, show_header=False)
        table.add_column("Item", style="cyan")
        table.add_column("Value", justify="right")

        table.add_row("Model Weights", f"{model_mb:.2f} MB")
        table.add_row("Est. Peak Memory", f"[{color}]{peak_mb:.2f} MB[/]")
        table.add_row("Device Capacity", f"{capacity_mb:.2f} MB")
        table.add_row("Usage", f"[{color}]{usage_percent:.1f}% ({status})[/]")

        panel = Panel(
            table,
            title="[bold]Memory Estimation Report[/]",
            border_style="blue",
            expand=False
        )

        with engine.out_logs:
            engine.console.print(panel)
|
orbit/plugin/mentor.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import locale
|
|
5
|
+
from typing import Optional, List, TYPE_CHECKING, Tuple, Dict
|
|
6
|
+
from rich.panel import Panel
|
|
7
|
+
from orbit.callback import Callback, Event
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from orbit.engine import Engine
|
|
11
|
+
from orbit.plugin.warmup import Warmup
|
|
12
|
+
|
|
13
|
+
class Mentor(Callback):
    """
    Mentor plugin: monitors training and offers suggestions for improvement.

    It focuses on abnormal loss behaviour (stagnation, divergence, overfitting,
    oscillation). It combines this with the engine configuration (warmup plugin,
    scheduler, gradient clipping, batch size / accumulation) to produce advice.
    Message texts are loaded from ``data/mentor_i18n.json`` next to this module.
    """
    def __init__(
        self,
        patience: int = 3,
        threshold: float = 1e-4,
        divergence_threshold: float = 2.0,
        language: Optional[str] = None,
        verbose: bool = True
    ):
        """
        Args:
            patience (int): Number of consecutive epochs without a significant
                training-loss improvement after which training counts as stagnating.
            threshold (float): Minimum decrease below the best training loss seen
                so far that counts as an improvement. Smaller decreases count
                towards stagnation.
            divergence_threshold (float): The loss is considered diverging when it
                exceeds the (positive) minimum loss by this factor.
            language (str): Language code ('en' or 'zh'). If None, the system
                language is detected automatically.
            verbose (bool): Whether to print advice.
        """
        super().__init__()
        self.patience = patience
        self.threshold = threshold
        self.divergence_threshold = divergence_threshold
        self.language = language if language else self._detect_language()
        self.verbose = verbose

        self.loss_history: List[float] = []
        self.val_loss_history: List[float] = []
        self.min_loss = np.inf
        self.stagnation_counter = 0
        self.increase_counter = 0

        # Warmup information (filled in by on_train_start).
        self.has_warmup = False
        self.warmup_plugin: Optional['Warmup'] = None

        # Other configuration (filled in by on_train_start).
        self.has_scheduler = False
        self.grad_clip_norm = None
        self.batch_size = 0
        self.accum_steps = 1

        # Translations for the selected language.
        self.i18n = self._load_i18n()

    def _detect_language(self) -> str:
        """Return 'zh' if the system locale is Chinese, otherwise 'en'."""
        import warnings

        # locale.getdefaultlocale is deprecated (Python 3.11+) and slated for
        # removal, so fall back to locale.getlocale when it no longer exists.
        getter = getattr(locale, 'getdefaultlocale', None)
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                lang_code = getter()[0] if getter is not None else locale.getlocale()[0]
        except Exception:
            # Locale detection is best-effort; default to English.
            lang_code = None
        if lang_code and lang_code.startswith('zh'):
            return 'zh'
        return 'en'

    def _load_i18n(self) -> Dict[str, str]:
        """Load the message table for ``self.language``, falling back to English.

        Returns an empty dict (so keys are shown verbatim) if the file cannot be read.
        """
        try:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            json_path = os.path.join(current_dir, 'data', 'mentor_i18n.json')

            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            return data.get(self.language, data.get('en', {}))
        except Exception as e:
            print(f"[Mentor] Warning: Failed to load i18n file: {e}. Fallback to keys.")
            return {}

    def _t(self, key: str, **kwargs) -> str:
        """Look up the translation for ``key`` and format it with ``kwargs``.

        If formatting fails, for example because of a missing placeholder argument
        or a malformed template, the unformatted template is returned.
        """
        template = self.i18n.get(key, key)
        try:
            return template.format(**kwargs)
        except (KeyError, IndexError, ValueError):
            return template

    def on_train_start(self, event: Event):
        """Record the warmup, scheduler, gradient-clipping and batch-size configuration."""
        engine = event.engine
        # 1. Check for a Warmup plugin.
        for p in engine.plugins:
            if p.__class__.__name__ == 'Warmup':
                self.has_warmup = True
                self.warmup_plugin = p
                break

        # 2. Check the scheduler and gradient clipping.
        self.has_scheduler = engine.scheduler is not None
        self.grad_clip_norm = getattr(engine, 'grad_clip_norm', None)

        # 3. Check the batch size and gradient accumulation.
        self.accum_steps = getattr(engine, 'accumulation_steps', 1) or 1
        loader = getattr(engine, 'train_loader', None)
        if hasattr(loader, 'batch_size'):
            bs = loader.batch_size
            if bs is None:
                # DataLoader(batch_sampler=...) sets batch_size=None; ask the sampler.
                bs = getattr(getattr(loader, 'batch_sampler', None), 'batch_size', None)
            self.batch_size = bs if isinstance(bs, int) and bs > 0 else 1

        if self.verbose:
            eff_bs = self.batch_size * self.accum_steps
            msg = self._t("mentor_watching", eff_bs=eff_bs, bs=self.batch_size, accum=self.accum_steps)
            engine.print(msg, plugin='Mentor')

    def on_epoch_end(self, event: Event):
        """Update the loss history, run every check, and print the collected advice."""
        engine = event.engine
        current_loss = engine.metrics.get('train_loss')
        val_loss = engine.metrics.get('val_loss')

        if current_loss is None:
            return

        # Metrics may be 0-d tensors or numpy scalars. Normalise to float so the
        # comparisons and numpy statistics below behave uniformly.
        current_loss = float(current_loss)
        self.loss_history.append(current_loss)
        if val_loss is not None:
            self.val_loss_history.append(float(val_loss))

        # All advice for this epoch as (title, content, style).
        advice_list: List[Tuple[str, str, str]] = []

        # 1. NaN / Inf.
        if not np.isfinite(current_loss):
            advice_list.append((
                self._t("nan_loss_title"),
                self._t("nan_loss_msg"),
                "bold red"
            ))
        else:
            # Update the minimum loss and the counters. Only a decrease larger than
            # `threshold` resets the stagnation counter.
            if current_loss < self.min_loss - self.threshold:
                self.min_loss = current_loss
                self.stagnation_counter = 0
                self.increase_counter = 0
            else:
                self.stagnation_counter += 1
                # Still track the true minimum for the divergence check.
                self.min_loss = min(self.min_loss, current_loss)
                # Did the loss go up compared with the previous epoch?
                if len(self.loss_history) > 1 and current_loss > self.loss_history[-2]:
                    self.increase_counter += 1
                else:
                    self.increase_counter = 0

            # 2. Divergence. The ratio test is only meaningful for a positive minimum;
            # for zero or negative minima it would flag nearly every epoch.
            is_diverging = self.min_loss > 0 and current_loss > self.min_loss * self.divergence_threshold
            is_unstable = self.increase_counter >= 2

            if (is_diverging or is_unstable) and len(self.loss_history) > 1:
                res = self._analyze_divergence(engine, current_loss, is_unstable)
                if res:
                    advice_list.append(res)

                # Reset so the same problem is not reported every epoch.
                if is_diverging:
                    self.min_loss = current_loss
                if is_unstable:
                    self.increase_counter = 0

        # 3. Stagnation.
        if self.stagnation_counter >= self.patience:
            res = self._analyze_stagnation(engine, current_loss)
            if res:
                advice_list.append(res)
            self.stagnation_counter = 0

        # 4. Overfitting.
        if len(self.val_loss_history) >= 3:
            res = self._analyze_overfitting(engine)
            if res:
                advice_list.append(res)

        # 5. Oscillation.
        if len(self.loss_history) >= 5:
            res = self._analyze_oscillation(engine)
            if res:
                advice_list.append(res)

        # Print everything together.
        if advice_list and self.verbose:
            with engine.out_logs:
                for title, content, style in advice_list:
                    engine.console.print(Panel(content, title=title, border_style=style, expand=False))

    def _analyze_divergence(self, engine: 'Engine', current_loss: float, is_unstable: bool = False) -> Tuple[str, str, str]:
        """Build advice for a diverging loss (a spike) or one rising for several epochs."""
        in_warmup = engine.is_in_warmup()
        eff_bs = self.batch_size * self.accum_steps

        title = self._t("divergence_title")
        if is_unstable:
            msg = self._t("divergence_msg_unstable", count=self.increase_counter, loss=current_loss)
        else:
            msg = self._t("divergence_msg_spike", loss=current_loss, min_loss=self.min_loss)

        advice = []

        # 1. Warmup advice.
        if not self.has_warmup:
            advice.append(self._t("advice_add_warmup"))
            advice.append(self._t("advice_lower_lr"))
        elif in_warmup:
            advice.append(self._t("advice_warmup_start_lr"))
        else:
            advice.append(self._t("advice_post_warmup_lr"))

        # 2. Batch size & accumulation advice (skipped when the batch size is unknown).
        if 0 < eff_bs < 32:
            advice.append(self._t("advice_small_bs", eff_bs=eff_bs))
            advice.append(self._t("advice_increase_accum", accum_steps=self.accum_steps))

        # 3. General advice.
        advice.append(self._t("advice_grad_clip"))

        return title, msg + "\n\n" + "\n".join(advice), "red"

    def _analyze_stagnation(self, engine: 'Engine', current_loss: float) -> Tuple[str, str, str]:
        """Build advice for a loss that has not improved for ``patience`` epochs."""
        in_warmup = engine.is_in_warmup()
        eff_bs = self.batch_size * self.accum_steps

        title = self._t("stagnation_title")
        msg = self._t("stagnation_msg", patience=self.patience)

        advice = []

        if in_warmup:
            # Stagnating during warmup.
            warmup_epochs = getattr(self.warmup_plugin, 'warmup_epochs', 0)
            total_epochs = engine.num_epochs

            advice.append(self._t("advice_warmup_duration", epoch=engine.epoch + 1))
            if warmup_epochs > total_epochs * 0.2:
                advice.append(self._t("advice_warmup_too_long", warmup_epochs=warmup_epochs))
            advice.append(self._t("advice_check_start_lr"))

        else:
            # Stagnating after warmup.
            advice.append(self._t("advice_lr_general"))

            # Scheduler advice.
            if not self.has_scheduler:
                advice.append(self._t("advice_add_scheduler"))
            else:
                advice.append(self._t("advice_check_scheduler"))

            # Batch size advice.
            if eff_bs > 4096:
                advice.append(self._t("advice_large_bs", eff_bs=eff_bs))
                advice.append(self._t("advice_reduce_bs"))

        # 4. Initialisation advice (early epochs only).
        if engine.epoch < 10:
            advice.append(self._t("advice_check_init"))

        # 5. Debugging and data advice.
        advice.append(self._t("advice_overfit_single_batch"))
        advice.append(self._t("advice_data_hard"))

        return title, msg + "\n\n" + "\n".join(advice), "yellow"

    def _analyze_overfitting(self, engine: 'Engine') -> Optional[Tuple[str, str, str]]:
        """Detect overfitting: training loss still falling while validation loss rises.

        Uses the last three recorded values of each history.
        """
        if len(self.val_loss_history) < 3 or len(self.loss_history) < 3:
            return None

        recent_val = self.val_loss_history[-3:]
        recent_train = self.loss_history[-3:]

        # Validation loss strictly rising over the last three records.
        val_rising = (recent_val[-1] > recent_val[-2] > recent_val[-3])
        # Training loss still dropping (or flat) in the latest epoch.
        train_dropping = (recent_train[-1] <= recent_train[-2])

        if val_rising and train_dropping:
            title = self._t("overfitting_title")
            msg = self._t("overfitting_msg")
            advice = [
                self._t("advice_regularization"),
                self._t("advice_data_aug"),
                self._t("advice_early_stopping")
            ]
            return title, msg + "\n\n" + "\n".join(advice), "magenta"
        return None

    def _analyze_oscillation(self, engine: 'Engine') -> Optional[Tuple[str, str, str]]:
        """Detect oscillation: a large spread of the last five losses with no clear trend."""
        if len(self.loss_history) < 5:
            return None

        recent_loss = self.loss_history[-5:]
        # A NaN/Inf in the window is reported by the NaN check, not as oscillation.
        if not np.all(np.isfinite(recent_loss)):
            return None

        std_dev = float(np.std(recent_loss))
        mean_loss = float(np.mean(recent_loss))

        # Simple trend check: the first and last values differ by less than one std.
        is_flat_trend = abs(recent_loss[-1] - recent_loss[0]) < std_dev

        # Heuristic: std above 10% of the mean magnitude. abs() keeps the test
        # meaningful for negative losses.
        if std_dev > 0.1 * abs(mean_loss) and is_flat_trend:
            title = self._t("oscillation_title")
            msg = self._t("oscillation_msg", std=std_dev)
            advice = [
                self._t("advice_lower_lr_oscillation")
            ]

            if not self.has_scheduler:
                advice.append(self._t("advice_oscillation_scheduler"))

            if not self.grad_clip_norm:
                advice.append(self._t("advice_oscillation_grad_clip"))

            return title, msg + "\n\n" + "\n".join(advice), "cyan"
        return None
|
orbit/plugin/overfit.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from orbit.callback import Callback, Event
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
class Overfit(Callback):
    '''Sanity-check plugin that deliberately repeats the first training batch.

    Every later training batch is replaced by the first one captured. This is
    useful for checking that the model architecture can fit the data at all. If
    the model cannot overfit a single batch, there is probably a bug in the code
    or a problem with the architecture.
    '''

    def __init__(self):
        super().__init__()
        self.fixed_data: Any = None
        self.fixed_target: Any = None
        self.has_captured = False

    def on_batch_start(self, event: Event):
        """Capture the first training batch, then substitute it for every later one.

        NOTE(review): this relies on ``engine.data`` / ``engine.target`` already
        holding the current batch when ``on_batch_start`` fires — confirm against
        the engine's batch loop.
        """
        engine = event.engine
        # Only act during training; other phases (e.g. validation) are untouched.
        if engine.state != "TRAIN":
            return

        if self.has_captured:
            # Replace the current batch with the captured one.
            engine.data = self.fixed_data
            engine.target = self.fixed_target
            return

        # Capture the first batch.
        self.fixed_data = engine.data
        self.fixed_target = engine.target
        self.has_captured = True
        engine.print("[yellow]Overfit Plugin: Captured first batch. All subsequent batches will be replaced by this one.[/]", plugin='Overfit')
|
orbit/plugin/warmup.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from typing import Optional, List, TYPE_CHECKING
|
|
2
|
+
from orbit.callback import Callback, Event
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from orbit.engine import Engine
|
|
6
|
+
|
|
7
|
+
class Warmup(Callback):
    """
    Learning-rate warmup plugin.

    Supports three modes: 'linear', 'constant' and 'noam' (Transformer).
    The learning rate is adjusted at batch granularity in ``on_batch_start``.
    """

    _SUPPORTED_MODES = ('linear', 'constant', 'noam')

    def __init__(
        self,
        warmup_steps: int = 0,
        warmup_epochs: int = 0,
        mode: str = 'linear',
        min_lr: float = 0.0,
        model_dim: Optional[int] = None,
        scale: float = 1.0,
    ):
        """
        Args:
            warmup_steps (int): Total number of warmup steps (batches).
            warmup_epochs (int): Total number of warmup epochs. When > 0 it takes
                precedence over ``warmup_steps``.
            mode (str): 'linear' | 'constant' | 'noam' (case-insensitive).
            min_lr (float): Starting learning rate (linear / constant modes).
            model_dim (int): Model dimension (required by noam mode only).
            scale (float): Scale factor (noam mode only).

        Raises:
            ValueError: If ``mode`` is unknown, or noam mode is used without ``model_dim``.
        """
        super().__init__()
        self.warmup_steps = warmup_steps
        self.warmup_epochs = warmup_epochs
        self.mode = mode.lower()
        self.min_lr = min_lr
        self.model_dim = model_dim
        self.scale = scale

        self.total_warmup_steps = 0
        self.base_lrs: List[float] = []
        # Set once constant-mode warmup has handed the LR back to base_lrs.
        self._warmup_finished = False

        # Validate arguments.
        if self.mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported warmup mode: {mode!r}. Expected one of {self._SUPPORTED_MODES}.")
        if self.mode == 'noam' and self.model_dim is None:
            raise ValueError("Noam mode requires 'model_dim' to be specified.")

    def on_train_start(self, event: Event):
        """
        Compute the total number of warmup steps and record the base learning rates.

        Raises:
            ValueError: If the engine has no optimizer, or ``warmup_epochs`` is used
                without a train_loader whose length can be determined.
        """
        engine = event.engine
        if engine.optimizer is None:
            raise ValueError("Warmup plugin requires an optimizer in the Engine.")

        # 1. Record each param group's target LR. Prefer 'initial_lr' (set by some
        # schedulers, and stable when resuming) over the current 'lr'.
        self.base_lrs = [group.get('initial_lr', group['lr']) for group in engine.optimizer.param_groups]
        self._warmup_finished = False

        # 2. Compute total_warmup_steps.
        if self.warmup_epochs > 0:
            # Compare with None explicitly: truthiness of a DataLoader calls
            # __len__, which raises TypeError for length-less iterable datasets.
            if engine.train_loader is None:
                raise ValueError("warmup_epochs requires train_loader to be available.")
            try:
                steps_per_epoch = len(engine.train_loader)
            except TypeError as err:
                # The loader has no length (e.g. an iterable dataset), so steps must be given.
                raise ValueError("Could not determine length of train_loader. Please use 'warmup_steps' instead.") from err
            self.total_warmup_steps = self.warmup_epochs * steps_per_epoch
        else:
            self.total_warmup_steps = self.warmup_steps

        # Report the configuration.
        if self.total_warmup_steps > 0 or self.mode == 'noam':
            engine.print(f"[magenta]Strategy activated: {self.mode}[/]", plugin='Warmup')
            if self.mode != 'noam':
                engine.print(f"[magenta]Steps: {self.total_warmup_steps} (Epochs: {self.warmup_epochs})[/]", plugin='Warmup')

    def on_batch_start(self, event: Event):
        """
        Set the learning rate of every param group before each batch.
        """
        engine = event.engine
        # 1-based step index, which keeps the formulas simple.
        current_step = engine.global_step + 1

        # Linear/constant only act inside the warmup window. Noam is a full-run
        # schedule and keeps running.
        if self.mode != 'noam' and current_step > self.total_warmup_steps:
            # Constant warmup held min_lr for the whole window. Restore the base LRs
            # once on the first step after it, otherwise the LR stays at min_lr.
            if (self.mode == 'constant'
                    and not self._warmup_finished
                    and self.total_warmup_steps > 0
                    and current_step == self.total_warmup_steps + 1):
                for group, base_lr in zip(engine.optimizer.param_groups, self.base_lrs):
                    group['lr'] = base_lr
                self._warmup_finished = True
            return

        if self.mode == 'noam':
            # Noam: scale * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).
            # It does not depend on base_lr, so compute it once for all groups.
            step = max(1, current_step)
            decay = step ** -0.5
            if self.total_warmup_steps > 0:
                factor = min(decay, step * (self.total_warmup_steps ** -1.5))
            else:
                # No warmup: the ramp term is unbounded, so min() reduces to the decay
                # term (this also avoids 0 ** -1.5 raising ZeroDivisionError).
                factor = decay
            noam_lr = self.scale * (self.model_dim ** -0.5) * factor
            for group in engine.optimizer.param_groups:
                group['lr'] = noam_lr
            return

        for i, group in enumerate(engine.optimizer.param_groups):
            base_lr = self.base_lrs[i]
            if self.mode == 'linear':
                # Linear: min_lr -> base_lr, reaching base_lr at the last warmup step.
                alpha = current_step / self.total_warmup_steps
                group['lr'] = self.min_lr + (base_lr - self.min_lr) * alpha
            else:
                # Constant: hold min_lr during warmup.
                group['lr'] = self.min_lr
|