orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import seaborn as sns
|
|
4
|
+
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
|
5
|
+
from rich.table import Table
|
|
6
|
+
from typing import List, Optional, TYPE_CHECKING
|
|
7
|
+
import rich.box as box
|
|
8
|
+
|
|
9
|
+
from orbit.callback import Callback, Event
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..engine import Engine
|
|
12
|
+
|
|
13
|
+
class ClassificationReport(Callback):
    """Evaluation / visualisation callback for classification tasks.

    During evaluation it caches the model output and target of every batch. When
    evaluation ends it computes accuracy (and optionally Top-K accuracy), writes them
    to ``engine.metrics``, prints a per-class precision / recall / F1 table, and, if
    the engine exposes a non-None ``writer``, logs a confusion-matrix figure.
    """

    def __init__(
        self,
        num_classes: int,
        class_names: Optional[List[str]] = None,
        top_k: int = 1,
        cm_cmap: str = 'Blues'
    ):
        """
        Args:
            num_classes (int): Total number of classes.
            class_names (List[str]): Optional display names, e.g. ["Cat", "Dog", ...].
                Entry i names class index i. Defaults to ["0", ..., str(num_classes - 1)].
            top_k (int): When > 1, Top-K accuracy is computed in addition to accuracy.
            cm_cmap (str): Colormap name used for the confusion-matrix heatmap.
        """
        super().__init__()
        self.num_classes = num_classes
        self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]
        self.top_k = top_k
        self.cm_cmap = cm_cmap

        # Per-batch caches (CPU tensors) filled during evaluation.
        self.preds = []
        self.targets = []

    def on_eval_start(self, event: Event):
        """Clear the caches before each evaluation pass."""
        self.preds = []
        self.targets = []

    def on_batch_end(self, event: Event):
        """Collect outputs and targets of evaluation batches."""
        engine = event.engine
        if engine.state == "EVAL":
            # Assumes engine.output holds logits [Batch, NumClasses] and engine.target
            # holds integer labels [Batch] -- TODO confirm against Engine.
            # Raw outputs are kept (needed for Top-K) and moved to CPU to save device memory.
            self.preds.append(engine.output.detach().cpu())
            self.targets.append(engine.target.detach().cpu())

    def on_eval_end(self, event: Event):
        """Compute metrics, print the report and (optionally) log a confusion matrix."""
        if not self.preds:
            return
        engine = event.engine

        # 1. Concatenate all batches, then drop the per-batch caches (no longer needed).
        all_logits = torch.cat(self.preds)     # [N, C]
        all_targets = torch.cat(self.targets)  # [N]
        self.preds = []
        self.targets = []

        # Predicted class index per sample [N].
        all_preds_idx = all_logits.argmax(dim=1)

        # numpy arrays for sklearn.
        y_true = all_targets.numpy()
        y_pred = all_preds_idx.numpy()

        # Pin the label set to every named class. Without it sklearn infers labels from
        # the classes that actually occur, which raises in classification_report
        # (target_names length mismatch) when a class is missing from this eval set.
        labels = list(range(len(self.class_names)))

        # --- A. Top-1 accuracy -> metrics ---
        acc = accuracy_score(y_true, y_pred)
        engine.metrics['val_acc'] = acc

        # --- Top-K accuracy ---
        topk_acc = None
        if self.top_k > 1:
            # topk() raises when k exceeds the number of output columns; clamp it
            # (every valid target is then trivially inside the top-k).
            k = min(self.top_k, all_logits.size(1))
            _, indices = all_logits.topk(k, dim=1)
            correct = indices.eq(all_targets.view(-1, 1).expand_as(indices))
            topk_acc = correct.sum().item() / len(all_targets)
            engine.metrics[f'val_acc_top{self.top_k}'] = topk_acc

        # --- B. Console classification report ---
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=self.class_names,
            output_dict=True,
            zero_division=0
        )
        self._print_rich_table(engine, report, acc, topk_acc)

        # --- C. Confusion matrix, only when a (TensorBoard) writer is attached ---
        if getattr(engine, 'writer', None) is not None:
            fig = self._plot_confusion_matrix(y_true, y_pred)
            engine.writer.add_figure("Eval/Confusion_Matrix", fig, global_step=engine.epoch)
            plt.close(fig)  # release the figure's memory

    def _print_rich_table(self, engine, report: dict, acc: float, topk_acc: Optional[float] = None):
        """Print the per-class report as a Rich table, followed by the accuracy lines.

        Args:
            engine: The Engine; its ``epoch``, ``out_logs`` and ``print`` are used.
            report (dict): Output of ``classification_report(..., output_dict=True)``.
            acc (float): Top-1 accuracy as a fraction.
            topk_acc (Optional[float]): Top-K accuracy, or None when not computed.
        """
        table = Table(title=f"[bold]Evaluation Report (Ep {engine.epoch+1})[/]", box=box.HORIZONTALS)
        table.add_column("Class", style="cyan")
        table.add_column("Precision", justify="right")
        table.add_column("Recall", justify="right")
        table.add_column("F1-Score", justify="right")

        # Limit the number of rows to avoid flooding the console: beyond max_display
        # classes show the first and last halves with an ellipsis row (None) between.
        max_display = 20
        if len(self.class_names) > max_display:
            half = max_display // 2
            items_to_show = [*self.class_names[:half], None, *self.class_names[-half:]]
        else:
            items_to_show = self.class_names

        for class_name in items_to_show:
            if class_name is None:
                table.add_row("...", "...", "...", "...")
                continue

            if class_name in report:
                row = report[class_name]
                table.add_row(
                    class_name,
                    f"{row['precision']:.3f}",
                    f"{row['recall']:.3f}",
                    f"{row['f1-score']:.3f}",
                )

        avg = report['weighted avg']
        table.add_row(
            "[bold]Weighted Avg[/]",
            f"[bold]{avg['precision']:.3f}[/]",
            f"[bold]{avg['recall']:.3f}[/]",
            f"[bold]{avg['f1-score']:.3f}[/]",
            end_section=True
        )

        with engine.out_logs:
            engine.print(table)

        engine.print(f"Accuracy: [green]{acc*100:.2f}%[/]", plugin='ClassReport')
        if topk_acc is not None:
            engine.print(f"Top-{self.top_k} Accuracy: [green]{topk_acc*100:.2f}%[/]", plugin='ClassReport')

    def _plot_confusion_matrix(self, y_true, y_pred):
        """Draw the confusion matrix as a Seaborn heatmap and return the Figure.

        Rows are true classes, columns predicted classes; one row/column per entry of
        ``class_names`` (classes absent from this evaluation produce all-zero rows).
        """
        num_classes = len(self.class_names)
        cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

        # Figure size grows with the class count: at least 10 inches, at most 50.
        fig_base = 10
        fig_scale = 0.3
        figsize_dim = max(fig_base, min(50, num_classes * fig_scale))

        fig, ax = plt.subplots(figsize=(figsize_dim, figsize_dim))

        # Cell annotations become unreadable with many classes.
        do_annot = num_classes <= 20

        sns.heatmap(
            cm,
            annot=do_annot,
            fmt='d',
            cmap=self.cm_cmap,
            xticklabels=self.class_names,
            yticklabels=self.class_names,
            ax=ax,
            square=True
        )

        # Smaller, rotated tick labels for large class counts (applied to this axes,
        # not to pyplot's implicit "current" axes).
        if num_classes > 20:
            ax.tick_params(axis='x', labelrotation=90, labelsize=8)
            ax.tick_params(axis='y', labelrotation=0, labelsize=8)

        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f'Confusion Matrix ({num_classes} classes)')
        fig.tight_layout()
        return fig
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
{
|
|
2
|
+
"en": {
|
|
3
|
+
"nan_loss_title": "Loss is NaN or Infinity!",
|
|
4
|
+
"nan_loss_msg": "Gradient explosion detected. Try:\n1. Lower learning rate significantly.\n2. Enable gradient clipping (engine.grad_clip_norm).\n3. Check your data for anomalies.",
|
|
5
|
+
|
|
6
|
+
"divergence_title": "⚠️ Loss Divergence/Instability Detected",
|
|
7
|
+
"divergence_msg_unstable": "Train loss has increased for {count} consecutive epochs ({loss:.4f}).",
|
|
8
|
+
"divergence_msg_spike": "Train loss has spiked significantly ({loss:.4f} vs min {min_loss:.4f}).",
|
|
9
|
+
|
|
10
|
+
"stagnation_title": "🛑 Loss Stagnation Detected",
|
|
11
|
+
"stagnation_msg": "Train loss hasn't improved for {patience} epochs.",
|
|
12
|
+
|
|
13
|
+
"advice_add_warmup": "- Add Warmup: You are not using Warmup. Unstable gradients at start can cause divergence.",
|
|
14
|
+
"advice_lower_lr": "- Lower LR: Try reducing your initial learning rate.",
|
|
15
|
+
"advice_warmup_start_lr": "- Warmup Start LR: Your warmup starting learning rate might be too high.",
|
|
16
|
+
"advice_post_warmup_lr": "- Post-Warmup LR: The learning rate after warmup might be too high.",
|
|
17
|
+
|
|
18
|
+
"advice_small_bs": "- Small Batch Size: Effective Batch Size is {eff_bs}. Small batches have noisy gradients.",
|
|
19
|
+
"advice_increase_accum": " -> Try increasing `accumulation_steps` (currently {accum_steps}) or batch size.",
|
|
20
|
+
|
|
21
|
+
"advice_grad_clip": "- Gradient Clipping: Enable `grad_clip_norm` in Engine if not already used.",
|
|
22
|
+
|
|
23
|
+
"advice_warmup_duration": "- Warmup Duration: You are in warmup (Epoch {epoch}).",
|
|
24
|
+
"advice_warmup_too_long": " -> Your warmup is {warmup_epochs} epochs (>20% of total). Consider shortening it.",
|
|
25
|
+
"advice_check_start_lr": "- Start LR: Check if your warmup start LR is too small.",
|
|
26
|
+
|
|
27
|
+
"advice_lr_general": "- Learning Rate: LR might be too small (slow convergence) or too large (bouncing).",
|
|
28
|
+
"advice_add_scheduler": "- Add Scheduler: You are not using an LR Scheduler. Dynamic LR reduction often breaks stagnation.",
|
|
29
|
+
"advice_check_scheduler": "- Scheduler: Check if your scheduler decayed LR too early or too aggressively.",
|
|
30
|
+
|
|
31
|
+
"advice_large_bs": "- Large Batch Size: Effective BS is {eff_bs}. Large batches generalize worse.",
|
|
32
|
+
"advice_reduce_bs": " -> Try increasing LR (Linear Scaling Rule) or reducing batch size.",
|
|
33
|
+
|
|
34
|
+
"advice_check_init": "- Bad Initialization: Weights might be initialized poorly. Try Xavier or Kaiming initialization.",
|
|
35
|
+
"advice_overfit_single_batch": "- Debugging: Try overfitting a single batch. If the model can't learn even one batch, there's a bug in the model or data pipeline.",
|
|
36
|
+
"advice_data_hard": "- Data Complexity: The data might be too hard or noisy. Check labels and input features.",
|
|
37
|
+
|
|
38
|
+
"overfitting_title": "📉 Overfitting Detected",
|
|
39
|
+
"overfitting_msg": "Validation loss is rising while training loss decreases.",
|
|
40
|
+
"advice_regularization": "- Regularization: Add Dropout or Weight Decay (L2 regularization).",
|
|
41
|
+
"advice_data_aug": "- Data Augmentation: Increase data augmentation to improve generalization.",
|
|
42
|
+
"advice_early_stopping": "- Early Stopping: Consider stopping training now to prevent further degradation.",
|
|
43
|
+
|
|
44
|
+
"oscillation_title": "〰️ Loss Oscillation Detected",
|
|
45
|
+
"oscillation_msg": "Training loss is oscillating significantly (Std Dev: {std:.4f}).",
|
|
46
|
+
"advice_lower_lr_oscillation": "- Lower LR: The learning rate is likely too high, causing the model to overshoot minima.",
|
|
47
|
+
"advice_oscillation_scheduler": "- Add Scheduler: Dynamic LR reduction helps stabilize training when loss oscillates.",
|
|
48
|
+
"advice_oscillation_grad_clip": "- Gradient Clipping: Enable `grad_clip_norm` to prevent large gradient updates from destabilizing the model.",
|
|
49
|
+
|
|
50
|
+
"mentor_watching": "[dim]Mentor watching: Eff. Batch Size={eff_bs} (BS={bs} * Accum={accum})[/]"
|
|
51
|
+
},
|
|
52
|
+
"zh": {
|
|
53
|
+
"nan_loss_title": "Loss 变为 NaN 或无穷大!",
|
|
54
|
+
"nan_loss_msg": "检测到梯度爆炸。尝试:\n1. 显著降低学习率。\n2. 启用梯度裁剪 (engine.grad_clip_norm)。\n3. 检查数据是否存在异常。",
|
|
55
|
+
|
|
56
|
+
"divergence_title": "⚠️ 检测到 Loss 发散/不稳定",
|
|
57
|
+
"divergence_msg_unstable": "训练 Loss 已连续 {count} 个 Epoch 上升 ({loss:.4f})。",
|
|
58
|
+
"divergence_msg_spike": "训练 Loss 显著飙升 ({loss:.4f} vs 最小值 {min_loss:.4f})。",
|
|
59
|
+
|
|
60
|
+
"stagnation_title": "🛑 检测到 Loss 停滞",
|
|
61
|
+
"stagnation_msg": "训练 Loss 已连续 {patience} 个 Epoch 未改善。",
|
|
62
|
+
|
|
63
|
+
"advice_add_warmup": "- 添加预热 (Warmup): 您未使用 Warmup。初始阶段的不稳定梯度可能导致发散。",
|
|
64
|
+
"advice_lower_lr": "- 降低 LR: 尝试降低初始学习率。",
|
|
65
|
+
"advice_warmup_start_lr": "- Warmup 初始 LR: 您的 Warmup 起始学习率可能过高。",
|
|
66
|
+
"advice_post_warmup_lr": "- Warmup 后 LR: Warmup 结束后的学习率可能过高。",
|
|
67
|
+
|
|
68
|
+
"advice_small_bs": "- Batch Size 过小: 有效 Batch Size 为 {eff_bs}。小批量会导致梯度噪声大。",
|
|
69
|
+
"advice_increase_accum": " -> 尝试增加 `accumulation_steps` (当前为 {accum_steps}) 或 Batch Size。",
|
|
70
|
+
|
|
71
|
+
"advice_grad_clip": "- 梯度裁剪: 如果尚未启用,请在 Engine 中启用 `grad_clip_norm`。",
|
|
72
|
+
|
|
73
|
+
"advice_warmup_duration": "- Warmup 持续时间: 当前处于 Warmup 阶段 (Epoch {epoch})。",
|
|
74
|
+
"advice_warmup_too_long": " -> 您的 Warmup 长达 {warmup_epochs} 个 Epoch (>总数的 20%)。考虑缩短它。",
|
|
75
|
+
"advice_check_start_lr": "- 初始 LR: 检查 Warmup 起始 LR 是否过小。",
|
|
76
|
+
|
|
77
|
+
"advice_lr_general": "- 学习率: LR 可能太小 (收敛慢) 或太大 (在最小值附近震荡)。",
|
|
78
|
+
"advice_add_scheduler": "- 添加调度器: 您未使用 LR Scheduler。动态降低 LR 通常能打破停滞。",
|
|
79
|
+
"advice_check_scheduler": "- 检查调度器: 检查 Scheduler 是否过早或过激进地降低了 LR。",
|
|
80
|
+
|
|
81
|
+
"advice_large_bs": "- Batch Size 过大: 有效 BS 为 {eff_bs}。大批量通常泛化能力较差。",
|
|
82
|
+
"advice_reduce_bs": " -> 尝试增加 LR (线性缩放规则) 或减小 Batch Size。",
|
|
83
|
+
|
|
84
|
+
"advice_check_init": "- 初始化糟糕: 权重初始化可能不当。尝试使用 Xavier 或 Kaiming 初始化。",
|
|
85
|
+
"advice_overfit_single_batch": "- 调试建议: 尝试使用单 Batch 过拟合。如果模型连一个 Batch 都学不会,说明模型或数据管道有 Bug。",
|
|
86
|
+
"advice_data_hard": "- 数据难度: 数据可能太难或噪声太大。检查标签和输入特征。",
|
|
87
|
+
|
|
88
|
+
"overfitting_title": "📉 检测到过拟合",
|
|
89
|
+
"overfitting_msg": "验证集 Loss 正在上升,而训练集 Loss 仍在下降。",
|
|
90
|
+
"advice_regularization": "- 正则化: 添加 Dropout 或权重衰减 (Weight Decay/L2 正则)。",
|
|
91
|
+
"advice_data_aug": "- 数据增强: 增加数据增强强度以提高泛化能力。",
|
|
92
|
+
"advice_early_stopping": "- 早停 (Early Stopping): 考虑立即停止训练以防止性能进一步恶化。",
|
|
93
|
+
|
|
94
|
+
"oscillation_title": "〰️ 检测到 Loss 震荡",
|
|
95
|
+
"oscillation_msg": "训练 Loss 波动剧烈 (标准差: {std:.4f})。",
|
|
96
|
+
"advice_lower_lr_oscillation": "- 降低 LR: 学习率可能过高,导致模型在极小值附近跳跃。",
|
|
97
|
+
"advice_oscillation_scheduler": "- 添加调度器: 动态降低 LR 有助于在 Loss 震荡时稳定训练。",
|
|
98
|
+
"advice_oscillation_grad_clip": "- 梯度裁剪: 启用 `grad_clip_norm` 以防止大梯度更新导致模型不稳定。",
|
|
99
|
+
|
|
100
|
+
"mentor_watching": "[dim]Mentor 监控中: 有效 Batch Size={eff_bs} (BS={bs} * Accum={accum})[/]"
|
|
101
|
+
}
|
|
102
|
+
}
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from rich.table import Table
|
|
3
|
+
from rich.console import Console
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
import rich.box as box
|
|
6
|
+
|
|
7
|
+
from orbit.callback import Callback, Event
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from ..engine import Engine
|
|
11
|
+
|
|
12
|
+
class ModelSummary(Callback):
    """Print a summary table of a model's direct children and its parameter counts.

    The summary is printed once, when the engine fires ``on_init``.
    """

    def __init__(self, max_depth: int = 3):
        """
        Args:
            max_depth (int): Stored on the instance. The current implementation only
                lists the model's direct children (``named_children``) and does not
                read this value.
        """
        super().__init__()
        self.max_depth = max_depth

    def on_init(self, event: Event):
        """Print the model structure automatically when the Engine is initialised."""
        engine = event.engine
        self.display(engine.model, engine.console)

    def display(self, model: nn.Module, console: Console):
        """Print one row per direct child of ``model`` plus a totals table.

        Args:
            model (nn.Module): Model to summarise.
            console (Console): Rich console to print to.
        """
        table = Table(title=f"[bold]Model Summary: {model.__class__.__name__}[/]", box=box.HORIZONTALS)

        table.add_column("Layer (Type)", style="cyan", no_wrap=True)
        table.add_column("Output Shape", style="magenta")
        table.add_column("Param #", justify="right", style="green")
        table.add_column("Trainable", justify="right", style="yellow")

        # Only the first level of sub-modules is shown, to keep the table readable.
        for name, module in model.named_children():
            child_params = list(module.parameters())
            num_params = sum(p.numel() for p in child_params)
            num_trainable = sum(p.numel() for p in child_params if p.requires_grad)

            is_trainable = "[bold green]Yes[/]" if num_trainable > 0 else "[dim]No[/]"
            layer_name = f"{name} ({module.__class__.__name__})"

            table.add_row(
                layer_name,
                "-",  # output shapes are not computed (would need a forward pass)
                f"{num_params:,}",
                is_trainable
            )

        # Totals come from the whole model rather than the sum of the child rows, so
        # parameters registered directly on the root module are included and parameters
        # shared between children are counted once (Module.parameters() de-duplicates).
        all_params = list(model.parameters())
        total_params = sum(p.numel() for p in all_params)
        trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
        non_trainable_params = total_params - trainable_params
        trainable_ratio = trainable_params / total_params if total_params > 0 else 0.0

        # Parameter memory in MB using each tensor's actual element size (dtype-aware).
        total_size_mb = sum(p.numel() * p.element_size() for p in all_params) / (1024 ** 2)

        console.print(table)

        # Totals table.
        summary_table = Table(show_header=False, box=None)
        summary_table.add_row("Total Params:", f"[bold cyan]{total_params:,}[/]")
        summary_table.add_row("Trainable Params:", f"[bold green]{trainable_params:,}[/] ({trainable_ratio:.1%})")
        summary_table.add_row("Non-trainable Params:", f"[dim]{non_trainable_params:,}[/]")
        summary_table.add_row("Est. Params Size (MB):", f"[bold blue]{total_size_mb:.2f} MB[/]")

        console.print(summary_table)
        console.print()
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
from orbit.callback import Callback, Event
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from orbit.engine import Engine
|
|
7
|
+
|
|
8
|
+
class EarlyStopping(Callback):
    """
    Early-stopping plugin.

    Stops training when the monitored metric has not improved for ``patience``
    consecutive checked epochs (epochs during warmup are skipped). The internal state is
    mirrored into ``engine.meta`` so a checkpoint can persist it, and it is restored
    from there on ``on_train_start``.
    """
    def __init__(
        self,
        monitor: str = 'val_loss',
        mode: str = 'min',
        patience: int = 5,
        min_delta: float = 0.0,
        verbose: bool = True
    ):
        """
        Args:
            monitor (str): Name of the metric in ``engine.metrics`` (e.g. 'val_loss', 'val_acc').
            mode (str): 'min' (lower is better) or 'max' (higher is better); case-insensitive.
            patience (int): Number of epochs without improvement tolerated before stopping.
            min_delta (float): Minimum change that counts as an improvement.
            verbose (bool): Whether to print progress messages.

        Raises:
            ValueError: If ``mode`` is neither 'min' nor 'max'.
        """
        super().__init__()
        normalized_mode = str(mode).lower()
        if normalized_mode not in ('min', 'max'):
            # Previously any other value silently behaved as 'max'.
            raise ValueError(f"EarlyStopping mode must be 'min' or 'max', got {mode!r}")

        self.monitor = monitor
        self.mode = normalized_mode
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose

        self.wait_count = 0
        self.best_score = np.inf if normalized_mode == 'min' else -np.inf

        # Key under which the state is stored in engine.meta.
        self._meta_key = 'early_stopping'

    def on_train_start(self, event: Event):
        """Restore best score / wait counter from ``engine.meta`` if present."""
        engine = event.engine
        if self._meta_key in engine.meta:
            state = engine.meta[self._meta_key]
            self.best_score = state.get('best_score', self.best_score)
            self.wait_count = state.get('wait_count', 0)
            if self.verbose:
                engine.print(f"[cyan]Resumed: Best Score={self.best_score:.4f}, Wait={self.wait_count}/{self.patience}[/]", plugin='EarlyStopping')

    def on_epoch_end(self, event: Event):
        """Check the monitored metric once per epoch and stop the engine if patience runs out."""
        engine = event.engine
        # 0. Early stopping is not applied during warmup.
        if engine.is_in_warmup():
            if self.verbose:
                engine.print("[dim]Skipping EarlyStopping during warmup.[/]", plugin='EarlyStopping')
            return

        # 1. Current metric value.
        current_score = engine.metrics.get(self.monitor)
        if current_score is None:
            # Metric not available (e.g. no evaluation ran this epoch): skip the check.
            return
        # Metrics may be 0-dim tensors / numpy scalars; keep plain floats so the state
        # written to engine.meta is device-independent and cheap to checkpoint.
        current_score = float(current_score)

        # 2. Improvement test.
        if self.mode == 'min':
            improved = current_score < self.best_score - self.min_delta
        else:
            improved = current_score > self.best_score + self.min_delta

        # 3. State update.
        if improved:
            old_best = self.best_score
            self.best_score = current_score
            self.wait_count = 0
            if self.verbose:
                if np.isinf(old_best):
                    engine.print(f"{self.monitor} improved to [green]{current_score:.4f}[/]", plugin='EarlyStopping')
                else:
                    engine.print(f"{self.monitor} improved [green]{old_best:.4f} -> {current_score:.4f}[/]", plugin='EarlyStopping')
        else:
            self.wait_count += 1
            if self.verbose:
                engine.print(f"[yellow]{self.monitor} did not improve ({self.wait_count}/{self.patience}). Best: {self.best_score:.4f}[/]", plugin='EarlyStopping')

            if self.wait_count >= self.patience:
                engine.stop(source="EarlyStopping", reason=f"No improvement in {self.patience} epochs")
                engine.print(f"[red][bold]Stopping training (no improvement in {self.patience} epochs).[/]", plugin='EarlyStopping')

        # 4. Mirror the state into meta so a checkpoint can persist it.
        engine.meta[self._meta_key] = {
            'best_score': self.best_score,
            'wait_count': self.wait_count
        }
|
orbit/plugin/ema.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
import torch
|
|
3
|
+
from orbit.callback import Callback, Event
|
|
4
|
+
from typing import TYPE_CHECKING, Dict
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from orbit.engine import Engine
|
|
8
|
+
|
|
9
|
+
class EMA(Callback):
    """
    Exponential Moving Average (EMA) plugin.

    During training it maintains a moving average ("shadow" copy) of the trainable
    parameters. For evaluation it swaps the shadow weights into the model and restores
    the training weights afterwards.
    """
    def __init__(self, decay: float = 0.999, start_step: int = 0):
        """
        Args:
            decay (float): Decay rate, typically close to 1 (e.g. 0.999, 0.9999).
            start_step (int): Global step from which EMA updates begin. At the first
                training batch with ``global_step >= start_step`` the shadow is re-seeded
                from the current weights, so the average does not stay anchored to the
                weights captured at ``on_init``.
        """
        super().__init__()
        self.decay = decay
        self.start_step = start_step
        self.shadow: Dict[str, torch.Tensor] = {}
        self.backup: Dict[str, torch.Tensor] = {}

        # Whether the shadow has been seeded for averaging. With start_step <= 0 the copy
        # taken in on_init already is the correct seed.
        self._started = start_step <= 0

        # Key used to store / restore the shadow weights via engine.meta (checkpointing).
        self._meta_key = 'ema_state'

    def on_init(self, event: Event):
        """Create the shadow weights as copies of the trainable parameters."""
        # NOTE: assumes the model is already on its target device at this point.
        engine = event.engine
        for name, param in engine.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

        engine.print(f"[magenta]Enabled (decay={self.decay})[/]", plugin='EMA')

    def on_train_start(self, event: Event):
        """Restore shadow weights from ``engine.meta`` (checkpoint) if present."""
        engine = event.engine
        if self._meta_key in engine.meta:
            saved_shadow = engine.meta[self._meta_key]
            # Move restored tensors onto the engine's device.
            for k, v in saved_shadow.items():
                if k in self.shadow:
                    self.shadow[k] = v.to(engine.device)
            # Restored averages must not be overwritten by the start_step re-seed.
            self._started = True
            engine.print(f"[green]Resumed EMA state from checkpoint[/]", plugin='EMA')

    def on_batch_end(self, event: Event):
        """Update the shadow weights after every training batch."""
        engine = event.engine
        if engine.state != 'TRAIN' or engine.global_step < self.start_step:
            return

        if not self._started:
            # First update step: seed the average from the current weights.
            for name, param in engine.model.named_parameters():
                if param.requires_grad:
                    self.shadow[name] = param.data.clone()
            self._started = True
            return

        for name, param in engine.model.named_parameters():
            if not param.requires_grad:
                continue
            shadow = self.shadow.get(name)
            if shadow is None:
                # Parameter became trainable after on_init (e.g. unfrozen): track it from now.
                self.shadow[name] = param.data.clone()
                continue
            # shadow = decay * shadow + (1 - decay) * param
            shadow.data.mul_(self.decay).add_(param.data, alpha=1.0 - self.decay)

    def on_eval_start(self, event: Event):
        """Before evaluation: back up the current weights and load the EMA weights."""
        engine = event.engine
        if engine.global_step < self.start_step or not self._started:
            return

        self.backup = {
            name: p.data.clone()
            for name, p in engine.model.named_parameters()
            if p.requires_grad
        }

        for name, param in engine.model.named_parameters():
            if param.requires_grad:
                shadow = self.shadow.get(name)
                # Parameters without a shadow yet keep their live weights.
                if shadow is not None:
                    param.data.copy_(shadow)

        engine.print("[dim]Switched to EMA weights for evaluation[/]", plugin='EMA')

    def on_eval_end(self, event: Event):
        """After evaluation: restore the backed-up training weights."""
        engine = event.engine
        if not self.backup:
            return

        for name, param in engine.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])

        self.backup = {}  # drop the backup
        engine.print("[dim]Restored training weights[/]", plugin='EMA')

    def on_epoch_end(self, event: Event):
        """
        Expose the shadow weights via ``engine.meta`` so a checkpoint plugin can save them.

        Note: this stores an extra copy of the trainable parameters in the checkpoint
        (roughly doubling its size when the model weights are saved too).
        """
        engine = event.engine
        engine.meta[self._meta_key] = self.shadow
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
from orbit.callback import Callback, Event
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from orbit.engine import Engine
|
|
6
|
+
|
|
7
|
+
class GradientAccumulation(Callback):
    """
    Gradient-accumulation plugin.

    It does no per-batch work itself: during ``on_init`` it writes the configured step
    count into the Engine's ``accumulation_steps`` attribute.
    """
    def __init__(self, steps: int = 1):
        """
        Args:
            steps (int): Number of batches to accumulate before an update; 1 (default)
                disables accumulation. E.g. ``steps=4`` updates once every 4 batches,
                giving an effective batch size of 4x the loader batch size.

        Raises:
            ValueError: If ``steps`` is smaller than 1.
        """
        super().__init__()
        if steps < 1:
            raise ValueError("Gradient accumulation steps must be >= 1")
        self.steps = steps

    def on_init(self, event: Event):
        """Configure the Engine's accumulation step count at initialisation time."""
        eng = event.engine
        eng.accumulation_steps = self.steps
        if self.steps > 1:
            eng.print(f"[magenta]Enabled: steps={self.steps}[/]", plugin='GradAccum')
|