orbit-torch 0.0.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -0
- orbit/callback.py +54 -0
- orbit/engine.py +802 -0
- orbit/optim/__init__.py +2 -0
- orbit/optim/muon.py +193 -0
- orbit/optim/sam.py +92 -0
- orbit/plugin/__init__.py +10 -0
- orbit/plugin/board.py +61 -0
- orbit/plugin/checkpoint.py +245 -0
- orbit/plugin/classification.py +190 -0
- orbit/plugin/data/mentor_i18n.json +102 -0
- orbit/plugin/display_model.py +75 -0
- orbit/plugin/early_stopping.py +101 -0
- orbit/plugin/ema.py +97 -0
- orbit/plugin/gradient_accumulation.py +32 -0
- orbit/plugin/memory_estimator.py +234 -0
- orbit/plugin/mentor.py +313 -0
- orbit/plugin/overfit.py +30 -0
- orbit/plugin/warmup.py +119 -0
- orbit/utils/__init__.py +29 -0
- orbit/utils/freeze.py +59 -0
- orbit/utils/initialization.py +501 -0
- orbit/utils/layer_io.py +55 -0
- orbit/utils/mask.py +92 -0
- orbit/utils/seed.py +66 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +25 -0
- orbit_torch-0.0.4a1.dist-info/RECORD +29 -0
- orbit_torch-0.0.4a1.dist-info/WHEEL +5 -0
- orbit_torch-0.0.4a1.dist-info/top_level.txt +1 -0
orbit/utils/mask.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def make_padding_mask(src: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    '''
    Build a boolean mask that flags the non-padding tokens of a batch.

    Args:
        src (torch.Tensor): Source token ids, expected shape [B, L_src].
        pad_idx (int, optional): Token id used for padding. Defaults to 0.

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, 1, L_src] for a [B, L_src]
            input. True marks a real (non-padding) token that may be
            attended to; False marks padding.
    '''
    not_padding = src.ne(pad_idx)
    # Insert two singleton axes right after the batch axis so the mask
    # broadcasts against attention scores.
    return not_padding[:, None, None]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def make_lookahead_mask(size: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    '''
    Build a look-ahead mask (lower-triangular boolean matrix).

    Args:
        size (int): Sequence length.
        device (torch.device, optional): Device on which the mask is
            allocated. Defaults to CPU.

    Returns:
        torch.Tensor: Boolean mask of shape [size, size]. Entry (i, j) is
            True when j <= i, i.e. position i may attend to position j.
    '''
    # Allocate the matrix as bool from the start instead of building a
    # float32 matrix and casting it afterwards; the result is identical but
    # no float temporary is created.
    ones = torch.ones((size, size), dtype=torch.bool, device=device)
    return torch.tril(ones)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def make_causal_mask(tgt: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    '''
    Build a causal mask by combining the padding and look-ahead masks.

    Args:
        tgt (torch.Tensor): Target token ids, expected shape [B, L_tgt].
        pad_idx (int, optional): Token id used for padding. Defaults to 0.

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, L_tgt, L_tgt]. An entry is
            True only when the attended position is not padding and is not
            ahead of the attending position.
    '''
    length = tgt.shape[1]
    future_ok = make_lookahead_mask(length, device=tgt.device)  # [L, L]
    token_ok = make_padding_mask(tgt, pad_idx)                  # [B, 1, 1, L]
    # Broadcasting [B, 1, 1, L] against [L, L] yields [B, 1, L, L].
    return torch.logical_and(token_ok, future_ok)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def make_sliding_window_mask(
    tensor: torch.Tensor, window_size: int, pad_idx: int = 0, causal: bool = True
) -> torch.Tensor:
    '''
    Build a sliding-window attention mask combined with the padding mask.

    Args:
        tensor (torch.Tensor): Input token ids, expected shape [B, L].
        window_size (int): One-sided window width. Must be >= 0.
        pad_idx (int, optional): Token id used for padding. Defaults to 0.
        causal (bool, optional): Whether the window is one-directional.
            Defaults to True.
            If True, position i may attend to [i - window_size, i].
            If False, position i may attend to
            [i - window_size, i + window_size].

    Returns:
        torch.Tensor: Boolean mask of shape [B, 1, L, L]. True marks a
            position that may be attended to.

    Raises:
        ValueError: If ``window_size`` is negative. A negative width would
            produce an empty band in which no position can attend to
            anything, not even itself.
    '''
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    pad_mask = make_padding_mask(tensor, pad_idx)  # [B, 1, 1, L]
    seq_len = tensor.size(1)
    ones = torch.ones((seq_len, seq_len), device=tensor.device, dtype=torch.bool)

    # Band between two diagonals: j >= i - window_size (lower bound) and
    # j <= i (causal) or j <= i + window_size (bidirectional).
    upper_diagonal = 0 if causal else window_size
    window_mask = torch.tril(ones, diagonal=upper_diagonal) & torch.triu(
        ones, diagonal=-window_size
    )

    # [B, 1, 1, L] & [L, L] broadcasts to [B, 1, L, L].
    return pad_mask & window_mask
|
orbit/utils/seed.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
import random
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
def seed_everything(seed=42, strict=False):
    """
    Seed every random number generator used by a PyTorch experiment.

    Seeds Python's ``random``, NumPy and PyTorch (plus all CUDA devices when
    CUDA is available), stores the seed on the ``orbit`` package as
    ``orbit.seed_info``, and, when CUDA is available, disables cuDNN
    benchmarking and requests deterministic cuDNN kernels.

    Args:
        seed (int): Seed value. Defaults to 42.
        strict (bool): Whether to enable strict deterministic mode.
            If True, ``torch.use_deterministic_algorithms(True)`` is called.
            Operations without a deterministic implementation will then
            raise, and training may be slower.
    """
    import orbit

    # Remember the seed so other helpers can reuse it.
    orbit.seed_info = seed

    # 1. Python's built-in random module.
    random.seed(seed)

    # 2. Hash seed. NOTE: PYTHONHASHSEED is only read at interpreter start-up,
    #    so this does not change hashing in the current process; it only
    #    takes effect in Python processes launched afterwards that inherit
    #    this environment.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # 3. NumPy.
    np.random.seed(seed)

    # 4. PyTorch CPU / GPU generators, and 5. cuDNN backend settings.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Auto-tuning may pick different kernels from run to run.
        torch.backends.cudnn.benchmark = False
        # Ask cuDNN for deterministic kernels.
        torch.backends.cudnn.deterministic = True

    # 6. Strict mode.
    if strict:
        try:
            # Operations lacking a deterministic implementation will raise
            # a RuntimeError once this is enabled.
            torch.use_deterministic_algorithms(True)

            # Deterministic cuBLAS needs CUBLAS_WORKSPACE_CONFIG. Use
            # setdefault so a value the user already exported (for example
            # ":16:8") is respected instead of being overwritten.
            os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

            print(f"[Info] Strict deterministic mode enabled. (seed={seed})")
        except AttributeError:
            print("[Warning] torch.use_deterministic_algorithms is not available in your PyTorch version.")
    else:
        print(f"[Info] Random seed set as {seed}")
|
|
56
|
+
|
|
57
|
+
def worker_init_fn(worker_id):
    """
    Seed NumPy and Python's ``random`` from the current PyTorch seed.

    Args:
        worker_id: Worker index. Accepted but not used; the seed is taken
            from ``torch.initial_seed()`` instead.
    """
    # Keep only the low 32 bits of the PyTorch seed.
    seed_32bit = torch.initial_seed() & 0xFFFFFFFF
    for reseed in (np.random.seed, random.seed):
        reseed(seed_32bit)
|
|
61
|
+
|
|
62
|
+
def create_generator() -> torch.Generator:
    """
    Create a ``torch.Generator`` seeded with the recorded orbit seed.

    Returns:
        torch.Generator: A new generator seeded with ``orbit.seed_info`` when
            that attribute exists, otherwise with 42.
    """
    import orbit

    generator = torch.Generator()
    return generator.manual_seed(getattr(orbit, 'seed_info', 42))
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: orbit-torch
|
|
3
|
+
Version: 0.0.4a1
|
|
4
|
+
Summary: A PyTorch training engine with plugin system
|
|
5
|
+
Author: Aiden Hopkins
|
|
6
|
+
Author-email: acdphc@qq.com
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Requires-Python: >=3.8
|
|
11
|
+
Requires-Dist: torch>=1.10.0
|
|
12
|
+
Requires-Dist: rich
|
|
13
|
+
Requires-Dist: tensorboard
|
|
14
|
+
Requires-Dist: matplotlib
|
|
15
|
+
Requires-Dist: seaborn
|
|
16
|
+
Requires-Dist: numpy
|
|
17
|
+
Requires-Dist: scikit-learn
|
|
18
|
+
Requires-Dist: einops
|
|
19
|
+
Requires-Dist: tokenizers
|
|
20
|
+
Dynamic: author
|
|
21
|
+
Dynamic: author-email
|
|
22
|
+
Dynamic: classifier
|
|
23
|
+
Dynamic: requires-dist
|
|
24
|
+
Dynamic: requires-python
|
|
25
|
+
Dynamic: summary
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
orbit/__init__.py,sha256=xikJCuBzK0VSwJ8gvPt0cZijw5TzVI2t3a62_aK-uEo,51
|
|
2
|
+
orbit/callback.py,sha256=FXq-bOVfoYW0S5S95ry55yCn1QgYkecpHonO35BNKwE,1738
|
|
3
|
+
orbit/engine.py,sha256=4Vskd39cfVa7jUWucikva8oDc68om4DtGH6OZLUDhKE,34230
|
|
4
|
+
orbit/optim/__init__.py,sha256=jp3TZFLM5LRcHeeYjR6qNwzwo1dfeR65A6DSGCEPnFg,67
|
|
5
|
+
orbit/optim/muon.py,sha256=IT0l1b0mAtLpB-SUuPyo9uaYHO2x_wkqg1TLcDJ6Bvk,7342
|
|
6
|
+
orbit/optim/sam.py,sha256=lDfou1jQYM_qtlWgYzfUKp2Zdi-uBSKZS7iYxB-yGkU,3702
|
|
7
|
+
orbit/plugin/__init__.py,sha256=9EcMoS__kzldM4MBNUwv5xUs21rliqrBALVmzIk_ADw,435
|
|
8
|
+
orbit/plugin/board.py,sha256=dK3aaTmznq9vSYvpSPgfE9kt30tieR93oK0oDaC8yJw,2267
|
|
9
|
+
orbit/plugin/checkpoint.py,sha256=w9R5O2-jzbRPSGPiuGuUycq21EebAdqigEFK0T0waOE,10898
|
|
10
|
+
orbit/plugin/classification.py,sha256=iVOFNUaL4SUMNEtnmbnioprb2nH1sOV4IsGH7UPhU8U,7019
|
|
11
|
+
orbit/plugin/display_model.py,sha256=Uch563Jq0R78lGHYiDPotXB6dmWfqdClSOiuOIKw7i8,2907
|
|
12
|
+
orbit/plugin/early_stopping.py,sha256=eDTjIzxDZvnCXANR2RXn8vrMh5wX8F2rQjl0BNp-6dU,4065
|
|
13
|
+
orbit/plugin/ema.py,sha256=NtQVB3Rz5YreQwWicXCNy_CfD6d7pu0crGGrz6-IJI8,3892
|
|
14
|
+
orbit/plugin/gradient_accumulation.py,sha256=JaSl-U9Zo_jeEaJE-KAXEPhQdjFA_He729VaCw2W8n4,1066
|
|
15
|
+
orbit/plugin/memory_estimator.py,sha256=YxSXbjAwl6jJh6yFEod5Tu_pgZh1Qadb0JsgmyfF7Mc,9743
|
|
16
|
+
orbit/plugin/mentor.py,sha256=sx3gLHnxYJX1tJhqnqHPTW2KuTeNM7h0PqOIquxMaxQ,12474
|
|
17
|
+
orbit/plugin/overfit.py,sha256=lpfxX4igi4uNdY4OR2_fpayw52iEfkCQnQ0HgzodZd4,1194
|
|
18
|
+
orbit/plugin/warmup.py,sha256=jS3Q52zuhTdqN9Bw85zcoH1YXZoqyRvnD7JHSSLRUwQ,4996
|
|
19
|
+
orbit/plugin/data/mentor_i18n.json,sha256=KHnEKRKHeLga_nqEUZVgiTh4jb4j2f_yD0MJT-7sdis,8005
|
|
20
|
+
orbit/utils/__init__.py,sha256=gXNfTz1a_kL5UZXT2o5BhoItc4Dgrk8w4K2hsfkbiYk,589
|
|
21
|
+
orbit/utils/freeze.py,sha256=ujR5wl2GHuqIQi0rm2i_yz6W8dj6X4YqgyZ3tBbv4yU,2405
|
|
22
|
+
orbit/utils/initialization.py,sha256=X0-OP0FPWX60_SRuB-x_Od3KdDh4EAD9UbOCZFj6NtI,20560
|
|
23
|
+
orbit/utils/layer_io.py,sha256=E-YnYw7av2ZdczDzhs08hUQtbZrcHPvqDdzEvWgs7Gs,1872
|
|
24
|
+
orbit/utils/mask.py,sha256=DayMxmKlWMYMYPceYhYlf9lIx4FD9s4Tp1WvzfVpsu0,3072
|
|
25
|
+
orbit/utils/seed.py,sha256=tbjF2jIPQfb0M_-wXOMVH3_XzWVsDkuhzD5lctGqalA,2578
|
|
26
|
+
orbit_torch-0.0.4a1.dist-info/METADATA,sha256=g2NSH_PVm9afSyVAfvct5xvk2hFColq85tKdbGQfNFc,700
|
|
27
|
+
orbit_torch-0.0.4a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
28
|
+
orbit_torch-0.0.4a1.dist-info/top_level.txt,sha256=emrF0of931NzTSL4R5yBKpGoewFCB-cAwYNcUF5cqBs,6
|
|
29
|
+
orbit_torch-0.0.4a1.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
orbit
|