orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/optim/sam.py
CHANGED
|
@@ -34,7 +34,7 @@ class SAM(torch.optim.Optimizer):
|
|
|
34
34
|
|
|
35
35
|
@torch.no_grad()
|
|
36
36
|
def first_step(self, zero_grad=False):
|
|
37
|
-
'''计算并应用参数扰动 epsilon。
|
|
37
|
+
r'''计算并应用参数扰动 epsilon。
|
|
38
38
|
|
|
39
39
|
该步骤根据当前梯度将模型参数 $w$ 更新为 $w + \epsilon$。
|
|
40
40
|
|
|
@@ -55,7 +55,7 @@ class SAM(torch.optim.Optimizer):
|
|
|
55
55
|
|
|
56
56
|
@torch.no_grad()
|
|
57
57
|
def second_step(self, zero_grad=False):
|
|
58
|
-
'''恢复参数并执行真正的梯度更新。
|
|
58
|
+
r'''恢复参数并执行真正的梯度更新。
|
|
59
59
|
|
|
60
60
|
该步骤将参数从 $w + \epsilon$ 恢复回 $w$,并利用扰动位置的梯度更新 $w$。
|
|
61
61
|
|
|
@@ -89,4 +89,11 @@ class SAM(torch.optim.Optimizer):
|
|
|
89
89
|
return norm
|
|
90
90
|
|
|
91
91
|
def step(self, closure=None):
|
|
92
|
-
raise NotImplementedError('SAM requires steps to be run manually: first_step and second_step')
|
|
92
|
+
raise NotImplementedError('SAM requires steps to be run manually: first_step and second_step')
|
|
93
|
+
|
|
94
|
+
def load_state_dict(self, state_dict):
|
|
95
|
+
self.base_optimizer.load_state_dict(state_dict)
|
|
96
|
+
self.param_groups = self.base_optimizer.param_groups
|
|
97
|
+
|
|
98
|
+
def state_dict(self):
|
|
99
|
+
return self.base_optimizer.state_dict()
|
orbit/plugin/__init__.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
1
|
-
try: from
|
|
1
|
+
try: from .classification import ClassificationReport
|
|
2
2
|
except: pass
|
|
3
3
|
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
-
from
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
4
|
+
from .checkpoint import Checkpoint
|
|
5
|
+
from .board import Board
|
|
6
|
+
from .display_model import ModelSummary
|
|
7
|
+
from .warmup import Warmup
|
|
8
|
+
from .early_stopping import EarlyStopping
|
|
9
|
+
from .gradient_accumulation import GradientAccumulation
|
|
10
|
+
from .mentor import Mentor
|
|
11
|
+
from .ema import EMA # Not tested
|
|
12
|
+
from .memory_estimator import MemoryEstimator
|
|
13
|
+
from .overfit import Overfit
|
|
14
|
+
from .lora import LoRA
|
orbit/plugin/board.py
CHANGED
|
@@ -2,8 +2,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|
|
2
2
|
from typing import Optional, TYPE_CHECKING
|
|
3
3
|
from orbit.callback import Callback, Event
|
|
4
4
|
|
|
5
|
-
if TYPE_CHECKING:
|
|
6
|
-
from ..engine import Engine
|
|
5
|
+
if TYPE_CHECKING: from orbit.engine import Engine
|
|
7
6
|
|
|
8
7
|
class Board(Callback):
|
|
9
8
|
def __init__(self, name: str, log_dir: str):
|
orbit/plugin/checkpoint.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
|
+
import json
|
|
3
4
|
from typing import TYPE_CHECKING, List, Tuple
|
|
4
5
|
from orbit.callback import Callback, Event
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
|
|
7
|
+
from safetensors.torch import save_file as safe_save_file
|
|
8
|
+
from safetensors.torch import load_file as safe_load_file
|
|
9
|
+
from safetensors.torch import safe_open
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING: from orbit.engine import Engine
|
|
8
12
|
|
|
9
13
|
class Checkpoint(Callback):
|
|
10
14
|
def __init__(
|
|
@@ -17,6 +21,7 @@ class Checkpoint(Callback):
|
|
|
17
21
|
save_top_k: int = 1,
|
|
18
22
|
save_last: bool = True,
|
|
19
23
|
every_n_train_steps: int = None,
|
|
24
|
+
use_safetensors: bool = False,
|
|
20
25
|
verbose: bool = True
|
|
21
26
|
):
|
|
22
27
|
"""
|
|
@@ -29,6 +34,7 @@ class Checkpoint(Callback):
|
|
|
29
34
|
save_top_k (int): 保存最好的 K 个模型。设为 0 则禁用 Top-K 保存。
|
|
30
35
|
save_last (bool): 是否总是保存 '{name}_last.pt'。
|
|
31
36
|
every_n_train_steps (int): 每隔多少个训练步保存一次。
|
|
37
|
+
use_safetensors (bool): 是否使用 safetensors 格式保存。注意:safetensors 不支持保存优化器状态。
|
|
32
38
|
verbose (bool): 是否打印保存信息。
|
|
33
39
|
"""
|
|
34
40
|
super().__init__()
|
|
@@ -40,8 +46,12 @@ class Checkpoint(Callback):
|
|
|
40
46
|
self.save_top_k = save_top_k
|
|
41
47
|
self.save_last = save_last
|
|
42
48
|
self.every_n_train_steps = every_n_train_steps
|
|
49
|
+
self.use_safetensors = use_safetensors
|
|
43
50
|
self.verbose = verbose
|
|
44
51
|
|
|
52
|
+
if self.use_safetensors and not self.save_weights_only:
|
|
53
|
+
print("[yellow]Warning: safetensors does not support saving optimizer state. Setting save_weights_only=True implicitly.[/]")
|
|
54
|
+
|
|
45
55
|
# 维护 Top-K 模型列表: [(score, filename), ...]
|
|
46
56
|
self.best_k_models: List[Tuple[float, str]] = []
|
|
47
57
|
|
|
@@ -65,12 +75,20 @@ class Checkpoint(Callback):
|
|
|
65
75
|
if self._meta_key in engine.meta:
|
|
66
76
|
self.best_k_models = engine.meta[self._meta_key].get('best_k_models', [])
|
|
67
77
|
|
|
68
|
-
|
|
78
|
+
ext = ".safetensors" if self.use_safetensors else ".pt"
|
|
79
|
+
load_path = os.path.join(self.path, self.name + "_last" + ext).replace("\\", "/")
|
|
69
80
|
|
|
70
81
|
if os.path.exists(load_path):
|
|
71
82
|
self._load(engine, load_path)
|
|
72
83
|
else:
|
|
73
|
-
|
|
84
|
+
# 尝试查找另一种格式
|
|
85
|
+
alt_ext = ".pt" if self.use_safetensors else ".safetensors"
|
|
86
|
+
alt_path = os.path.join(self.path, self.name + "_last" + alt_ext).replace("\\", "/")
|
|
87
|
+
if os.path.exists(alt_path):
|
|
88
|
+
engine.print(f"[yellow]Found checkpoint with alternative extension: {alt_path}[/]", plugin='Checkpointing')
|
|
89
|
+
self._load(engine, alt_path)
|
|
90
|
+
else:
|
|
91
|
+
engine.print(f"[yellow]Warning: Resume checkpoint '{load_path}' not found. Starting from scratch.[/]", plugin='Checkpointing')
|
|
74
92
|
|
|
75
93
|
def on_batch_end(self, event: Event):
|
|
76
94
|
"""
|
|
@@ -81,7 +99,8 @@ class Checkpoint(Callback):
|
|
|
81
99
|
step = event.engine.global_step
|
|
82
100
|
if step > 0 and step % self.every_n_train_steps == 0:
|
|
83
101
|
# 保存 step checkpoint
|
|
84
|
-
|
|
102
|
+
ext = ".safetensors" if self.use_safetensors else ".pt"
|
|
103
|
+
filename = f"{self.name}_step_{step}{ext}"
|
|
85
104
|
|
|
86
105
|
# 传递 is_step=True
|
|
87
106
|
self._save(event.engine, filename, verbose=self.verbose, is_step=True)
|
|
@@ -93,7 +112,7 @@ class Checkpoint(Callback):
|
|
|
93
112
|
|
|
94
113
|
# 同时更新 last checkpoint
|
|
95
114
|
if self.save_last:
|
|
96
|
-
self._save(event.engine, f"{self.name}_last
|
|
115
|
+
self._save(event.engine, f"{self.name}_last{ext}", verbose=False, is_step=True)
|
|
97
116
|
|
|
98
117
|
def on_epoch_end(self, event: Event):
|
|
99
118
|
"""
|
|
@@ -102,9 +121,11 @@ class Checkpoint(Callback):
|
|
|
102
121
|
2. 如果设置了 monitor,保存 top_k
|
|
103
122
|
"""
|
|
104
123
|
engine = event.engine
|
|
124
|
+
ext = ".safetensors" if self.use_safetensors else ".pt"
|
|
125
|
+
|
|
105
126
|
# 1. Save Last
|
|
106
127
|
if self.save_last:
|
|
107
|
-
self._save(engine, f"{self.name}_last
|
|
128
|
+
self._save(engine, f"{self.name}_last{ext}", verbose=False) # last 不需要每次都啰嗦
|
|
108
129
|
|
|
109
130
|
# 2. Save Top K
|
|
110
131
|
if self.monitor and self.save_top_k > 0:
|
|
@@ -119,7 +140,8 @@ class Checkpoint(Callback):
|
|
|
119
140
|
|
|
120
141
|
def _check_and_save_top_k(self, engine: 'Engine', current_score: float):
|
|
121
142
|
"""检查并保存 Top-K 模型"""
|
|
122
|
-
|
|
143
|
+
ext = ".safetensors" if self.use_safetensors else ".pt"
|
|
144
|
+
filename = f"{self.name}_ep{engine.epoch+1}_{self.monitor}_{current_score:.4f}{ext}"
|
|
123
145
|
|
|
124
146
|
# 逻辑简化:总是先加入,然后排序,如果超过 K 个,删除最差的
|
|
125
147
|
self.best_k_models.append((current_score, filename))
|
|
@@ -154,24 +176,44 @@ class Checkpoint(Callback):
|
|
|
154
176
|
|
|
155
177
|
# 获取原始模型 (去除 DataParallel 包装) 以保证 Checkpoint 通用性
|
|
156
178
|
raw_model = engine.unwrap_model()
|
|
157
|
-
|
|
158
|
-
state = {
|
|
159
|
-
'epoch': engine.epoch,
|
|
160
|
-
'global_step': engine.global_step,
|
|
161
|
-
'batch_idx': engine.batch_idx,
|
|
162
|
-
'is_step': is_step,
|
|
163
|
-
'model_state_dict': raw_model.state_dict(),
|
|
164
|
-
'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
|
|
165
|
-
'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
|
|
166
|
-
'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
|
|
167
|
-
'meta': engine.meta,
|
|
168
|
-
}
|
|
169
|
-
if self.save_weights_only:
|
|
170
|
-
state = raw_model.state_dict()
|
|
171
|
-
|
|
172
179
|
file_path = os.path.join(self.path, filename)
|
|
180
|
+
|
|
173
181
|
try:
|
|
174
|
-
|
|
182
|
+
if self.use_safetensors and filename.endswith('.safetensors'):
|
|
183
|
+
# Safetensors 模式:仅保存权重和元数据
|
|
184
|
+
metadata = {
|
|
185
|
+
'epoch': str(engine.epoch),
|
|
186
|
+
'global_step': str(engine.global_step),
|
|
187
|
+
'batch_idx': str(engine.batch_idx),
|
|
188
|
+
'is_step': str(is_step),
|
|
189
|
+
# 注意:safetensors metadata 只能是字符串
|
|
190
|
+
}
|
|
191
|
+
# 尝试序列化 meta
|
|
192
|
+
try:
|
|
193
|
+
metadata['meta'] = json.dumps(engine.meta)
|
|
194
|
+
except:
|
|
195
|
+
pass
|
|
196
|
+
|
|
197
|
+
safe_save_file(raw_model.state_dict(), file_path, metadata=metadata)
|
|
198
|
+
|
|
199
|
+
else:
|
|
200
|
+
# Torch 模式:保存完整状态
|
|
201
|
+
state = {
|
|
202
|
+
'epoch': engine.epoch,
|
|
203
|
+
'global_step': engine.global_step,
|
|
204
|
+
'batch_idx': engine.batch_idx,
|
|
205
|
+
'is_step': is_step,
|
|
206
|
+
'model_state_dict': raw_model.state_dict(),
|
|
207
|
+
'optimizer_state_dict': engine.optimizer.state_dict() if engine.optimizer else None,
|
|
208
|
+
'scheduler_state_dict': engine.scheduler.state_dict() if engine.scheduler else None,
|
|
209
|
+
'scaler_state_dict': engine.scaler.state_dict() if engine.scaler else None,
|
|
210
|
+
'meta': engine.meta,
|
|
211
|
+
}
|
|
212
|
+
if self.save_weights_only:
|
|
213
|
+
state = raw_model.state_dict()
|
|
214
|
+
|
|
215
|
+
torch.save(state, file_path)
|
|
216
|
+
|
|
175
217
|
if verbose:
|
|
176
218
|
# 使用相对路径显示,更简洁
|
|
177
219
|
rel_path = os.path.relpath(file_path)
|
|
@@ -194,52 +236,85 @@ class Checkpoint(Callback):
|
|
|
194
236
|
"""加载 Checkpoint 的核心逻辑"""
|
|
195
237
|
engine.print(f"[cyan]Loading checkpoint from: {file_path}[/]", plugin='Checkpointing')
|
|
196
238
|
try:
|
|
197
|
-
# 加载到设备
|
|
198
|
-
checkpoint = torch.load(file_path, map_location=engine.device)
|
|
199
|
-
|
|
200
239
|
# 获取原始模型以进行加载
|
|
201
240
|
raw_model = engine.unwrap_model()
|
|
202
|
-
|
|
203
|
-
# 1. 加载模型权重
|
|
204
|
-
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
|
205
|
-
raw_model.load_state_dict(checkpoint['model_state_dict'])
|
|
206
|
-
else:
|
|
207
|
-
raw_model.load_state_dict(checkpoint)
|
|
208
|
-
engine.print("[yellow]Loaded model weights only (legacy format).[/]", plugin='Checkpointing')
|
|
209
|
-
return
|
|
210
241
|
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
if engine.scheduler and 'scheduler_state_dict' in checkpoint:
|
|
217
|
-
engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
|
242
|
+
if file_path.endswith('.safetensors'):
|
|
243
|
+
# 加载权重
|
|
244
|
+
state_dict = safe_load_file(file_path, device=str(engine.device))
|
|
245
|
+
raw_model.load_state_dict(state_dict)
|
|
218
246
|
|
|
219
|
-
|
|
220
|
-
|
|
247
|
+
# 尝试恢复元数据
|
|
248
|
+
with safe_open(file_path, framework="pt", device=str(engine.device)) as f:
|
|
249
|
+
metadata = f.metadata()
|
|
250
|
+
if metadata:
|
|
251
|
+
loaded_epoch = int(metadata.get('epoch', 0))
|
|
252
|
+
loaded_batch_idx = int(metadata.get('batch_idx', -1))
|
|
253
|
+
is_step = metadata.get('is_step', 'False') == 'True'
|
|
254
|
+
engine.global_step = int(metadata.get('global_step', 0))
|
|
255
|
+
|
|
256
|
+
if 'meta' in metadata:
|
|
257
|
+
try:
|
|
258
|
+
engine.meta.update(json.loads(metadata['meta']))
|
|
259
|
+
except: pass
|
|
260
|
+
|
|
261
|
+
if is_step:
|
|
262
|
+
engine.start_epoch = loaded_epoch
|
|
263
|
+
engine.start_batch_idx = loaded_batch_idx
|
|
264
|
+
msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
|
|
265
|
+
else:
|
|
266
|
+
engine.start_epoch = loaded_epoch + 1
|
|
267
|
+
engine.start_batch_idx = -1
|
|
268
|
+
msg = f"Epoch {engine.start_epoch}"
|
|
269
|
+
|
|
270
|
+
engine.print(f"[green]Resumed weights from {msg}. Note: Optimizer state not restored (safetensors).[/]", plugin='Checkpointing')
|
|
271
|
+
else:
|
|
272
|
+
engine.print("[yellow]Loaded weights only (no metadata in safetensors).[/]", plugin='Checkpointing')
|
|
221
273
|
|
|
222
|
-
|
|
223
|
-
|
|
274
|
+
else:
|
|
275
|
+
# Torch 加载
|
|
276
|
+
checkpoint = torch.load(file_path, map_location=engine.device)
|
|
224
277
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
if is_step:
|
|
230
|
-
# 如果是 Step Checkpoint,从当前 Epoch 的下一个 Batch 继续
|
|
231
|
-
engine.start_epoch = loaded_epoch
|
|
232
|
-
engine.start_batch_idx = loaded_batch_idx
|
|
233
|
-
msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
|
|
278
|
+
# 1. 加载模型权重
|
|
279
|
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
|
280
|
+
raw_model.load_state_dict(checkpoint['model_state_dict'])
|
|
234
281
|
else:
|
|
235
|
-
|
|
236
|
-
engine.
|
|
237
|
-
|
|
238
|
-
msg = f"Epoch {engine.start_epoch}"
|
|
239
|
-
|
|
240
|
-
engine.global_step = checkpoint.get('global_step', 0)
|
|
282
|
+
raw_model.load_state_dict(checkpoint)
|
|
283
|
+
engine.print("[yellow]Loaded model weights only (legacy format).[/]", plugin='Checkpointing')
|
|
284
|
+
return
|
|
241
285
|
|
|
242
|
-
|
|
286
|
+
# 2. 恢复训练状态
|
|
287
|
+
if not self.save_weights_only:
|
|
288
|
+
if engine.optimizer and 'optimizer_state_dict' in checkpoint:
|
|
289
|
+
engine.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
290
|
+
|
|
291
|
+
if engine.scheduler and 'scheduler_state_dict' in checkpoint:
|
|
292
|
+
engine.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
|
293
|
+
|
|
294
|
+
if engine.scaler and 'scaler_state_dict' in checkpoint:
|
|
295
|
+
engine.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
|
296
|
+
|
|
297
|
+
if 'meta' in checkpoint:
|
|
298
|
+
engine.meta.update(checkpoint['meta'])
|
|
299
|
+
|
|
300
|
+
loaded_epoch = checkpoint.get('epoch', 0)
|
|
301
|
+
loaded_batch_idx = checkpoint.get('batch_idx', -1)
|
|
302
|
+
is_step = checkpoint.get('is_step', False)
|
|
303
|
+
|
|
304
|
+
if is_step:
|
|
305
|
+
# 如果是 Step Checkpoint,从当前 Epoch 的下一个 Batch 继续
|
|
306
|
+
engine.start_epoch = loaded_epoch
|
|
307
|
+
engine.start_batch_idx = loaded_batch_idx
|
|
308
|
+
msg = f"Epoch {engine.start_epoch}, Batch {engine.start_batch_idx + 1}"
|
|
309
|
+
else:
|
|
310
|
+
# 如果是 Epoch Checkpoint,从下一个 Epoch 开始
|
|
311
|
+
engine.start_epoch = loaded_epoch + 1
|
|
312
|
+
engine.start_batch_idx = -1
|
|
313
|
+
msg = f"Epoch {engine.start_epoch}"
|
|
314
|
+
|
|
315
|
+
engine.global_step = checkpoint.get('global_step', 0)
|
|
316
|
+
|
|
317
|
+
engine.print(f"[green]Successfully resumed training from {msg}, Global Step {engine.global_step}[/]", plugin='Checkpointing')
|
|
243
318
|
|
|
244
319
|
except Exception as e:
|
|
245
320
|
engine.print(f"[red]Failed to load checkpoint: {e}[/]", plugin='Checkpointing')
|
orbit/plugin/classification.py
CHANGED
|
@@ -7,8 +7,8 @@ from typing import List, Optional, TYPE_CHECKING
|
|
|
7
7
|
import rich.box as box
|
|
8
8
|
|
|
9
9
|
from orbit.callback import Callback, Event
|
|
10
|
-
|
|
11
|
-
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING: from orbit.engine import Engine
|
|
12
12
|
|
|
13
13
|
class ClassificationReport(Callback):
|
|
14
14
|
def __init__(
|
orbit/plugin/display_model.py
CHANGED
orbit/plugin/early_stopping.py
CHANGED
orbit/plugin/ema.py
CHANGED