orbit-torch 0.0.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orbit/utils/mask.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+
3
+
4
def make_padding_mask(src: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Build a padding mask that marks the non-padding entries of ``src``.

    Args:
        src (torch.Tensor): Source sequence tensor, expected shape [B, L_src].
        pad_idx (int, optional): Index of the padding symbol. Defaults to 0.

    Returns:
        torch.Tensor: Boolean padding mask. For a [B, L_src] input the shape
            is [B, 1, 1, L_src]. True means the position is not padding and
            should be attended to.
    """
    is_token = src != pad_idx
    # Insert two singleton axes right after the batch axis:
    # [B, L_src] -> [B, 1, 1, L_src].
    return is_token[:, None, None]
18
+
19
+
20
def make_lookahead_mask(size: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """Build a look-ahead mask (lower-triangular boolean matrix).

    Args:
        size (int): Sequence length.
        device (torch.device, optional): Device on which the mask is created.
            Defaults to cpu.

    Returns:
        torch.Tensor: Boolean look-ahead mask of shape [size, size]. True
            marks the positions that may be attended to (the lower triangle,
            main diagonal included).
    """
    # Allocate the matrix as bool from the start instead of building a
    # float matrix and casting it afterwards; the result is the same.
    all_true = torch.ones((size, size), dtype=torch.bool, device=device)
    return torch.tril(all_true)
34
+
35
+
36
def make_causal_mask(tgt: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Build a causal mask combining the padding and look-ahead masks.

    Args:
        tgt (torch.Tensor): Target sequence tensor. Shape [B, L_tgt].
        pad_idx (int, optional): Index of the padding symbol. Defaults to 0.

    Returns:
        torch.Tensor: Causal mask of shape [B, 1, L_tgt, L_tgt]. An entry is
            True only where both the padding mask and the look-ahead mask
            are True.
    """
    key_mask = make_padding_mask(tgt, pad_idx)  # [B, 1, 1, L]
    length = tgt.shape[1]
    future_mask = make_lookahead_mask(length, device=tgt.device)  # [L, L]

    # Broadcasting [B, 1, 1, L] against [L, L] yields [B, 1, L, L].
    return key_mask & future_mask
56
+
57
+
58
def make_sliding_window_mask(
    tensor: torch.Tensor, window_size: int, pad_idx: int = 0, causal: bool = True
) -> torch.Tensor:
    """Build a sliding-window mask combined with the padding mask.

    Args:
        tensor (torch.Tensor): Input sequence tensor. Shape [B, L].
        window_size (int): Window size (one-sided).
        pad_idx (int, optional): Index of the padding symbol. Defaults to 0.
        causal (bool, optional): Whether the window is causal
            (unidirectional). Defaults to True.
            If True, position i may attend to [i - window_size, i].
            If False, position i may attend to
            [i - window_size, i + window_size].

    Returns:
        torch.Tensor: Sliding-window mask of shape [B, 1, L, L].
    """
    key_mask = make_padding_mask(tensor, pad_idx)  # [B, 1, 1, L]
    length = tensor.shape[1]
    all_true = torch.ones((length, length), device=tensor.device, dtype=torch.bool)

    # Both modes share the lower bound j >= i - window_size; only the upper
    # bound differs: j <= i when causal, j <= i + window_size otherwise.
    upper_diagonal = 0 if causal else window_size
    band = torch.tril(all_true, diagonal=upper_diagonal) & torch.triu(
        all_true, diagonal=-window_size
    )

    return key_mask & band
orbit/utils/seed.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+ import os
6
+
7
def seed_everything(seed=42, strict=False):
    """Seed all random number generators for reproducible PyTorch runs.

    Args:
        seed (int): Random seed value. Defaults to 42.
        strict (bool): Whether to enable strict deterministic mode.
            If True, torch.use_deterministic_algorithms(True) is called,
            which may make operations without a deterministic implementation
            raise an error, and may slow training down.
    """
    import orbit

    # Record the seed on the package; create_generator() reads this attribute.
    orbit.seed_info = seed

    # 1. Python's built-in random module.
    random.seed(seed)

    # 2. Python hash seed.
    # NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so this
    # presumably only affects processes launched afterwards - confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # 3. NumPy global RNG.
    np.random.seed(seed)

    # 4. PyTorch CPU generator, plus GPU generators when CUDA is available.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # 5. cuDNN backend (usual reproducibility settings): disable the
        # autotuner, whose choice of algorithm may vary with hardware state,
        # and force deterministic algorithms.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if not strict:
        print(f"[Info] Random seed set as {seed}")
        return

    # 6. Strict mode.
    try:
        # Operations for which PyTorch has no deterministic implementation
        # will raise a RuntimeError once this is enabled.
        torch.use_deterministic_algorithms(True)
    except AttributeError:
        print("[Warning] torch.use_deterministic_algorithms is not available in your PyTorch version.")
        return

    # use_deterministic_algorithms needs CUBLAS_WORKSPACE_CONFIG on CUDA,
    # otherwise cuBLAS errors are raised. ":4096:8" is the officially
    # recommended value, at the cost of slightly more GPU memory.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    print(f"[Info] Strict deterministic mode enabled. (seed={seed})")
56
+
57
def worker_init_fn(worker_id):
    """Re-seed NumPy and Python's ``random`` from the current torch seed.

    Args:
        worker_id: Not used by this function.
            NOTE(review): presumably the worker index passed in by a
            DataLoader - confirm against the caller.
    """
    # Keep only the low 32 bits of the torch seed (same value as % 2**32),
    # since NumPy's legacy seeding accepts only 32-bit seeds.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
61
+
62
def create_generator() -> torch.Generator:
    """Create a torch random number generator seeded with the recorded seed.

    Returns:
        torch.Generator: A new generator seeded with ``orbit.seed_info`` when
            that attribute exists, otherwise with 42.
    """
    import orbit

    generator = torch.Generator()
    # manual_seed returns the generator itself, so it can be returned directly.
    return generator.manual_seed(getattr(orbit, 'seed_info', 42))
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.4
2
+ Name: orbit-torch
3
+ Version: 0.0.4a1
4
+ Summary: A PyTorch training engine with plugin system
5
+ Author: Aiden Hopkins
6
+ Author-email: acdphc@qq.com
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.8
11
+ Requires-Dist: torch>=1.10.0
12
+ Requires-Dist: rich
13
+ Requires-Dist: tensorboard
14
+ Requires-Dist: matplotlib
15
+ Requires-Dist: seaborn
16
+ Requires-Dist: numpy
17
+ Requires-Dist: scikit-learn
18
+ Requires-Dist: einops
19
+ Requires-Dist: tokenizers
20
+ Dynamic: author
21
+ Dynamic: author-email
22
+ Dynamic: classifier
23
+ Dynamic: requires-dist
24
+ Dynamic: requires-python
25
+ Dynamic: summary
@@ -0,0 +1,29 @@
1
+ orbit/__init__.py,sha256=xikJCuBzK0VSwJ8gvPt0cZijw5TzVI2t3a62_aK-uEo,51
2
+ orbit/callback.py,sha256=FXq-bOVfoYW0S5S95ry55yCn1QgYkecpHonO35BNKwE,1738
3
+ orbit/engine.py,sha256=4Vskd39cfVa7jUWucikva8oDc68om4DtGH6OZLUDhKE,34230
4
+ orbit/optim/__init__.py,sha256=jp3TZFLM5LRcHeeYjR6qNwzwo1dfeR65A6DSGCEPnFg,67
5
+ orbit/optim/muon.py,sha256=IT0l1b0mAtLpB-SUuPyo9uaYHO2x_wkqg1TLcDJ6Bvk,7342
6
+ orbit/optim/sam.py,sha256=lDfou1jQYM_qtlWgYzfUKp2Zdi-uBSKZS7iYxB-yGkU,3702
7
+ orbit/plugin/__init__.py,sha256=9EcMoS__kzldM4MBNUwv5xUs21rliqrBALVmzIk_ADw,435
8
+ orbit/plugin/board.py,sha256=dK3aaTmznq9vSYvpSPgfE9kt30tieR93oK0oDaC8yJw,2267
9
+ orbit/plugin/checkpoint.py,sha256=w9R5O2-jzbRPSGPiuGuUycq21EebAdqigEFK0T0waOE,10898
10
+ orbit/plugin/classification.py,sha256=iVOFNUaL4SUMNEtnmbnioprb2nH1sOV4IsGH7UPhU8U,7019
11
+ orbit/plugin/display_model.py,sha256=Uch563Jq0R78lGHYiDPotXB6dmWfqdClSOiuOIKw7i8,2907
12
+ orbit/plugin/early_stopping.py,sha256=eDTjIzxDZvnCXANR2RXn8vrMh5wX8F2rQjl0BNp-6dU,4065
13
+ orbit/plugin/ema.py,sha256=NtQVB3Rz5YreQwWicXCNy_CfD6d7pu0crGGrz6-IJI8,3892
14
+ orbit/plugin/gradient_accumulation.py,sha256=JaSl-U9Zo_jeEaJE-KAXEPhQdjFA_He729VaCw2W8n4,1066
15
+ orbit/plugin/memory_estimator.py,sha256=YxSXbjAwl6jJh6yFEod5Tu_pgZh1Qadb0JsgmyfF7Mc,9743
16
+ orbit/plugin/mentor.py,sha256=sx3gLHnxYJX1tJhqnqHPTW2KuTeNM7h0PqOIquxMaxQ,12474
17
+ orbit/plugin/overfit.py,sha256=lpfxX4igi4uNdY4OR2_fpayw52iEfkCQnQ0HgzodZd4,1194
18
+ orbit/plugin/warmup.py,sha256=jS3Q52zuhTdqN9Bw85zcoH1YXZoqyRvnD7JHSSLRUwQ,4996
19
+ orbit/plugin/data/mentor_i18n.json,sha256=KHnEKRKHeLga_nqEUZVgiTh4jb4j2f_yD0MJT-7sdis,8005
20
+ orbit/utils/__init__.py,sha256=gXNfTz1a_kL5UZXT2o5BhoItc4Dgrk8w4K2hsfkbiYk,589
21
+ orbit/utils/freeze.py,sha256=ujR5wl2GHuqIQi0rm2i_yz6W8dj6X4YqgyZ3tBbv4yU,2405
22
+ orbit/utils/initialization.py,sha256=X0-OP0FPWX60_SRuB-x_Od3KdDh4EAD9UbOCZFj6NtI,20560
23
+ orbit/utils/layer_io.py,sha256=E-YnYw7av2ZdczDzhs08hUQtbZrcHPvqDdzEvWgs7Gs,1872
24
+ orbit/utils/mask.py,sha256=DayMxmKlWMYMYPceYhYlf9lIx4FD9s4Tp1WvzfVpsu0,3072
25
+ orbit/utils/seed.py,sha256=tbjF2jIPQfb0M_-wXOMVH3_XzWVsDkuhzD5lctGqalA,2578
26
+ orbit_torch-0.0.4a1.dist-info/METADATA,sha256=g2NSH_PVm9afSyVAfvct5xvk2hFColq85tKdbGQfNFc,700
27
+ orbit_torch-0.0.4a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
+ orbit_torch-0.0.4a1.dist-info/top_level.txt,sha256=emrF0of931NzTSL4R5yBKpGoewFCB-cAwYNcUF5cqBs,6
29
+ orbit_torch-0.0.4a1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ orbit