orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
orbit/__init__.py
ADDED
orbit/callback.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING: from .engine import Engine
|
|
7
|
+
|
|
8
|
+
@dataclass
class Event:
    """Payload handed to every ``Callback`` hook.

    Attributes:
        engine: The ``Engine`` instance the event belongs to. Hooks may read
            data from it or modify its state.
        name: Name of the event (presumably the hook being fired, e.g.
            ``"on_epoch_end"`` -- confirm against ``Engine``).
        source: Optional origin of the event; ``None`` unless supplied.
        reason: Optional free-form reason attached to the event; ``None``
            unless supplied.
    """

    # Required, positional fields come first so Event(engine, name) works.
    engine: Engine
    name: str

    # Optional context; both fall back to None when omitted.
    source: Optional[str] = None
    reason: Optional[str] = None
|
|
14
|
+
|
|
15
|
+
class Callback:
    """Base class for callbacks.

    Every hook receives an ``Event`` instance and is allowed to modify the
    engine's state or read data from it. All hooks here are no-ops that
    return ``None``; subclasses override only the hooks they care about.
    """

    # --- lifecycle -------------------------------------------------------
    def on_init(self, event: Event):
        """Hook for the ``init`` event. No-op by default."""

    # --- training run ----------------------------------------------------
    def on_train_start(self, event: Event):
        """Hook for the ``train_start`` event. No-op by default."""

    def on_train_end(self, event: Event):
        """Hook for the ``train_end`` event. No-op by default."""

    # --- epochs ----------------------------------------------------------
    def on_epoch_start(self, event: Event):
        """Hook for the ``epoch_start`` event. No-op by default."""

    def on_epoch_end(self, event: Event):
        """Hook for the ``epoch_end`` event. No-op by default."""

    # --- batches ---------------------------------------------------------
    def on_batch_start(self, event: Event):
        """Hook for the ``batch_start`` event. No-op by default."""

    def on_batch_end(self, event: Event):
        """Hook for the ``batch_end`` event. No-op by default."""

    # --- evaluation ------------------------------------------------------
    def on_eval_start(self, event: Event):
        """Hook for the ``eval_start`` event. No-op by default."""

    def on_eval_end(self, event: Event):
        """Hook for the ``eval_end`` event. No-op by default."""

    # --- abnormal flow ---------------------------------------------------
    def on_requested_stop(self, event: Event):
        """Hook for the ``requested_stop`` event. No-op by default."""

    def on_exception(self, event: Event):
        """Hook for the ``exception`` event. No-op by default."""
|
|
36
|
+
|
|
37
|
+
class Forward:
    """Interface for a custom forward pass and loss computation.

    Implement this interface to take over the Engine's default forward-pass
    logic. Subclasses must override :meth:`forward`.
    """

    def forward(self, engine: Engine, data: Any, target: Any) -> torch.Tensor:
        """Run the forward pass and return the loss.

        Args:
            engine (Engine): The current Engine instance. The model is
                reachable as ``engine.model`` and the loss function as
                ``engine.criterion``.
            data (Any): Input data of the current batch.
            target (Any): Target data (labels) of the current batch.

        Returns:
            torch.Tensor: The computed scalar loss. The Engine uses this loss
            for back-propagation.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        # Falling through here would return None, violating the Tensor
        # contract above; the Engine would only fail later when it tries to
        # back-propagate that "loss". Fail loudly at the call site instead.
        raise NotImplementedError(
            f"{type(self).__name__} must implement "
            "forward(engine, data, target) and return the loss tensor."
        )
|