syckpt 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syckpt-0.0.1/PKG-INFO +134 -0
- syckpt-0.0.1/README.md +103 -0
- syckpt-0.0.1/pyproject.toml +44 -0
- syckpt-0.0.1/setup.cfg +4 -0
- syckpt-0.0.1/setup.py +58 -0
- syckpt-0.0.1/syckpt/__init__.py +20 -0
- syckpt-0.0.1/syckpt/config.py +161 -0
- syckpt-0.0.1/syckpt/dataloader.py +81 -0
- syckpt-0.0.1/syckpt/hash.py +249 -0
- syckpt-0.0.1/syckpt/manager.py +818 -0
- syckpt-0.0.1/syckpt/state.py +189 -0
- syckpt-0.0.1/syckpt/storage.py +237 -0
- syckpt-0.0.1/syckpt.egg-info/PKG-INFO +134 -0
- syckpt-0.0.1/syckpt.egg-info/SOURCES.txt +15 -0
- syckpt-0.0.1/syckpt.egg-info/dependency_links.txt +1 -0
- syckpt-0.0.1/syckpt.egg-info/requires.txt +12 -0
- syckpt-0.0.1/syckpt.egg-info/top_level.txt +1 -0
syckpt-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: syckpt
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Git-like experiment tracking for deep learning with exact computational resumption
|
|
5
|
+
Home-page: https://github.com/sykchw/syckpt
|
|
6
|
+
Author: Sayak Chowdhury
|
|
7
|
+
Author-email: Sayak Chowdhury <sayak.iiitb@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Bug Reports, https://github.com/sykchw/syckpt/issues
|
|
10
|
+
Project-URL: Source, https://github.com/sykchw/syckpt
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Requires-Python: >=3.8
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
Requires-Dist: torch>=2.0.0
|
|
18
|
+
Requires-Dist: numpy>=1.20.0
|
|
19
|
+
Requires-Dist: safetensors>=0.4.0
|
|
20
|
+
Requires-Dist: fsspec>=2023.0.0
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
23
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
24
|
+
Requires-Dist: black>=23.0.0; extra == "dev"
|
|
25
|
+
Requires-Dist: mypy>=1.0.0; extra == "dev"
|
|
26
|
+
Requires-Dist: twine>=4.0.0; extra == "dev"
|
|
27
|
+
Requires-Dist: build>=1.0.0; extra == "dev"
|
|
28
|
+
Dynamic: author
|
|
29
|
+
Dynamic: home-page
|
|
30
|
+
Dynamic: requires-python
|
|
31
|
+
|
|
32
|
+
# Syckpt v0.0.1
|
|
33
|
+
|
|
34
|
+
**Git-like experiment tracking for deep learning with exact computational resumption, zero-copy safetensors memory-mapping, and delta-compression.**
|
|
35
|
+
|
|
36
|
+
`syckpt` is a lightweight, local-first experiment version control system designed to perfectly reconstruct massive computational states—model weights, optimizer momentum, mixed-precision GradScalers, Random Number Generators, and Stateful DataLoaders—without perturbing the loss curve.
|
|
37
|
+
|
|
38
|
+
---
|
|
39
|
+
|
|
40
|
+
## How `syckpt` Works (The Architecture)
|
|
41
|
+
|
|
42
|
+
When training massive Deep Learning models, saving a full checkpoint at every epoch typically results in gigabytes of duplicated disk space and high latency. `syckpt` solves this by treating machine learning checkpoints like a Git repository.
|
|
43
|
+
|
|
44
|
+
1. **Content-Addressable Storage (CAS) & Delta-Compression**:
|
|
45
|
+
Instead of saving full 5GB `.pt` weight files at every step, `syckpt` finds the most mathematically similar historical checkpoint and computes the pure `float32` difference (`delta = current - base`). Because gradient steps are small, this delta is highly compressible. `syckpt` stores these deltas in a hidden `.syckpt/objects/` directory, saving up to 90% of disk space.
|
|
46
|
+
|
|
47
|
+
2. **Locality-Sensitive Hashing (LSH)**:
|
|
48
|
+
To instantly find the "most similar" historical checkpoint, `syckpt` uses LSH to hash your hyperparameters (like learning rate, batch size, and seed). Similar hyperparameters mathematically collide to produce identical hash prefixes, allowing the system to rapidly query the Git-tree.
|
|
49
|
+
|
|
50
|
+
3. **Zero-Copy memory mapping via Safetensors**:
|
|
51
|
+
`syckpt` bypasses Python's insecure and memory-heavy `pickle` module. It uses Rust-backed `safetensors` to memory-map the delta-blobs directly from your SSD into the GPU's VRAM ("Zero-Copy"), completely eliminating CPU RAM Out-Of-Memory (OOM) errors during loading.
|
|
52
|
+
|
|
53
|
+
4. **Exact Mathematical Resumption**:
|
|
54
|
+
Standard PyTorch training loops suffer from "resumption spikes" in the loss curve because the DataLoader indices and Random Number Generators (RNG) get reset. `syckpt` intercepts PyTorch, CUDA, and Numpy generators, as well as preserving the internal `RandomSampler` permutations of your DataLoaders. When you resume, it is mathematically identical to if the process was never interrupted.
|
|
55
|
+
|
|
56
|
+
## Installation
|
|
57
|
+
|
|
58
|
+
We utilize the Rust-accelerated `uv` package manager.
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
pip install syckpt
|
|
62
|
+
# Or using uv
|
|
63
|
+
uv pip install syckpt
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## Quick Start
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
import torch
|
|
70
|
+
import torch.nn as nn
|
|
71
|
+
import torch.optim as optim
|
|
72
|
+
from syckpt import CheckpointManager
|
|
73
|
+
from syckpt.dataloader import StatefulDataLoader
|
|
74
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
75
|
+
|
|
76
|
+
# Typical PyTorch components
|
|
77
|
+
model = nn.Linear(10, 2)
|
|
78
|
+
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
|
79
|
+
|
|
80
|
+
dummy_data = TensorDataset(torch.randn(100, 10), torch.randn(100, 2))
|
|
81
|
+
# Wrap standard non-deterministic DataLoader internally
|
|
82
|
+
loader = StatefulDataLoader(DataLoader(dummy_data, batch_size=32, shuffle=True))
|
|
83
|
+
|
|
84
|
+
# Specify an S3 or Local URL: Atomic locks handles concurrency natively
|
|
85
|
+
with CheckpointManager("s3://my-experiments-bucket/.syckpt") as ckpt:
|
|
86
|
+
# 1. Register dynamic objects (Automatically mapped via flattening)
|
|
87
|
+
ckpt.model = model
|
|
88
|
+
ckpt.optimizer = optimizer
|
|
89
|
+
ckpt.dataloader = loader
|
|
90
|
+
|
|
91
|
+
# 2. Hyperparameters automatically generate the unique LSH Hash
|
|
92
|
+
ckpt.config.lr = 0.01
|
|
93
|
+
ckpt.config.batch_size = 32
|
|
94
|
+
|
|
95
|
+
# 3. Training Loop inherently traps the step and epoch parameters
|
|
96
|
+
for epoch in ckpt.loop(epochs=10):
|
|
97
|
+
for batch_idx, batch in enumerate(loader):
|
|
98
|
+
loss = torch.randn(1) # Fake loss
|
|
99
|
+
ckpt.step_up()
|
|
100
|
+
|
|
101
|
+
# Delta-Compression kicks in automatically
|
|
102
|
+
if epoch % 2 == 0:
|
|
103
|
+
ckpt.save(metric=loss.item())
|
|
104
|
+
|
|
105
|
+
print(f"Mathematical execution saved at LSH Commit: {ckpt.hash}")
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
## Feature Reference
|
|
109
|
+
|
|
110
|
+
### Exporting Monolithic Assets (`.ckpt`)
|
|
111
|
+
If you deploy your model and no longer need `.syckpt` branching, you can securely collapse the Git-tree into a standard monolithic PyTorch `.ckpt` file for Hugging Face or deployment:
|
|
112
|
+
```python
|
|
113
|
+
with CheckpointManager("./experiments") as ckpt:
|
|
114
|
+
# Recursively loads flat delta-tensors and reconstitutes standard dict
|
|
115
|
+
ckpt.export_ckpt(hash_or_branch="main", output_path="final-model.ckpt")
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
### Full Distributed Resumption (DDP)
|
|
119
|
+
`syckpt` seamlessly broadcasts LSH hashes and uses `dist.gather` to collect highly volatile RNG seeds across your entire multi-GPU cluster.
|
|
120
|
+
```python
|
|
121
|
+
import numpy as np
|
|
122
|
+
|
|
123
|
+
with CheckpointManager("./") as ckpt:
|
|
124
|
+
# Simply register your Modern Numpy generator and the state_manager intercepts the memory bytes
|
|
125
|
+
ckpt.numpy_rng = np.random.default_rng()
|
|
126
|
+
ckpt.save()
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
---
|
|
130
|
+
|
|
131
|
+
## Architectural Deep-Dive
|
|
132
|
+
Curious how `syckpt v0.0.1` leverages Git pointers, `fsspec` atomic cloud mechanisms, manages PyTorch tensors, and accelerates training via Zero-Copy Safetensors?
|
|
133
|
+
|
|
134
|
+
Read the definitive educational walkthrough: [Implementation Guide (`implementation.md`)](./implementation.md).
|
syckpt-0.0.1/README.md
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Syckpt v0.0.1
|
|
2
|
+
|
|
3
|
+
**Git-like experiment tracking for deep learning with exact computational resumption, zero-copy safetensors memory-mapping, and delta-compression.**
|
|
4
|
+
|
|
5
|
+
`syckpt` is a lightweight, local-first experiment version control system designed to perfectly reconstruct massive computational states—model weights, optimizer momentum, mixed-precision GradScalers, Random Number Generators, and Stateful DataLoaders—without perturbing the loss curve.
|
|
6
|
+
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
## How `syckpt` Works (The Architecture)
|
|
10
|
+
|
|
11
|
+
When training massive Deep Learning models, saving a full checkpoint at every epoch typically results in gigabytes of duplicated disk space and high latency. `syckpt` solves this by treating machine learning checkpoints like a Git repository.
|
|
12
|
+
|
|
13
|
+
1. **Content-Addressable Storage (CAS) & Delta-Compression**:
|
|
14
|
+
Instead of saving full 5GB `.pt` weight files at every step, `syckpt` finds the most mathematically similar historical checkpoint and computes the pure `float32` difference (`delta = current - base`). Because gradient steps are small, this delta is highly compressible. `syckpt` stores these deltas in a hidden `.syckpt/objects/` directory, saving up to 90% of disk space.
|
|
15
|
+
|
|
16
|
+
2. **Locality-Sensitive Hashing (LSH)**:
|
|
17
|
+
To instantly find the "most similar" historical checkpoint, `syckpt` uses LSH to hash your hyperparameters (like learning rate, batch size, and seed). Similar hyperparameters mathematically collide to produce identical hash prefixes, allowing the system to rapidly query the Git-tree.
|
|
18
|
+
|
|
19
|
+
3. **Zero-Copy memory mapping via Safetensors**:
|
|
20
|
+
`syckpt` bypasses Python's insecure and memory-heavy `pickle` module. It uses the Rust-backed `safetensors` format to memory-map the delta-blobs from disk ("Zero-Copy"), so tensors are loaded on demand instead of being read into CPU RAM all at once. This keeps peak host memory low and greatly reduces the risk of CPU Out-Of-Memory (OOM) errors during loading.
|
|
21
|
+
|
|
22
|
+
4. **Exact Mathematical Resumption**:
|
|
23
|
+
Standard PyTorch training loops suffer from "resumption spikes" in the loss curve because the DataLoader indices and Random Number Generators (RNG) get reset. `syckpt` captures the PyTorch, CUDA, and NumPy generator states and preserves the internal `RandomSampler` permutations of your DataLoaders. A resumed run therefore proceeds exactly as if the process had never been interrupted.
|
|
24
|
+
|
|
25
|
+
## Installation
|
|
26
|
+
|
|
27
|
+
Install from PyPI with `pip`, or with the faster Rust-based `uv` package manager:
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install syckpt
|
|
31
|
+
# Or using uv
|
|
32
|
+
uv pip install syckpt
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
## Quick Start
|
|
36
|
+
|
|
37
|
+
```python
|
|
38
|
+
import torch
|
|
39
|
+
import torch.nn as nn
|
|
40
|
+
import torch.optim as optim
|
|
41
|
+
from syckpt import CheckpointManager
|
|
42
|
+
from syckpt.dataloader import StatefulDataLoader
|
|
43
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
44
|
+
|
|
45
|
+
# Typical PyTorch components
|
|
46
|
+
model = nn.Linear(10, 2)
|
|
47
|
+
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
|
48
|
+
|
|
49
|
+
dummy_data = TensorDataset(torch.randn(100, 10), torch.randn(100, 2))
|
|
50
|
+
# Wrap standard non-deterministic DataLoader internally
|
|
51
|
+
loader = StatefulDataLoader(DataLoader(dummy_data, batch_size=32, shuffle=True))
|
|
52
|
+
|
|
53
|
+
# Specify an S3 or local URL; atomic locking handles concurrency natively
|
|
54
|
+
with CheckpointManager("s3://my-experiments-bucket/.syckpt") as ckpt:
|
|
55
|
+
# 1. Register dynamic objects (Automatically mapped via flattening)
|
|
56
|
+
ckpt.model = model
|
|
57
|
+
ckpt.optimizer = optimizer
|
|
58
|
+
ckpt.dataloader = loader
|
|
59
|
+
|
|
60
|
+
# 2. Hyperparameters automatically generate the unique LSH Hash
|
|
61
|
+
ckpt.config.lr = 0.01
|
|
62
|
+
ckpt.config.batch_size = 32
|
|
63
|
+
|
|
64
|
+
# 3. The training loop tracks the step and epoch counters for you
|
|
65
|
+
for epoch in ckpt.loop(epochs=10):
|
|
66
|
+
for batch_idx, batch in enumerate(loader):
|
|
67
|
+
loss = torch.randn(1) # Fake loss
|
|
68
|
+
ckpt.step_up()
|
|
69
|
+
|
|
70
|
+
# Delta-Compression kicks in automatically
|
|
71
|
+
if epoch % 2 == 0:
|
|
72
|
+
ckpt.save(metric=loss.item())
|
|
73
|
+
|
|
74
|
+
print(f"Mathematical execution saved at LSH Commit: {ckpt.hash}")
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
## Feature Reference
|
|
78
|
+
|
|
79
|
+
### Exporting Monolithic Assets (`.ckpt`)
|
|
80
|
+
If you deploy your model and no longer need `.syckpt` branching, you can securely collapse the Git-tree into a standard monolithic PyTorch `.ckpt` file for Hugging Face or deployment:
|
|
81
|
+
```python
|
|
82
|
+
with CheckpointManager("./experiments") as ckpt:
|
|
83
|
+
# Recursively loads flat delta-tensors and reconstitutes standard dict
|
|
84
|
+
ckpt.export_ckpt(hash_or_branch="main", output_path="final-model.ckpt")
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
### Full Distributed Resumption (DDP)
|
|
88
|
+
`syckpt` seamlessly broadcasts LSH hashes and uses `dist.gather` to collect highly volatile RNG seeds across your entire multi-GPU cluster.
|
|
89
|
+
```python
|
|
90
|
+
import numpy as np
|
|
91
|
+
|
|
92
|
+
with CheckpointManager("./") as ckpt:
|
|
93
|
+
# Simply register your Modern Numpy generator and the state_manager intercepts the memory bytes
|
|
94
|
+
ckpt.numpy_rng = np.random.default_rng()
|
|
95
|
+
ckpt.save()
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
---
|
|
99
|
+
|
|
100
|
+
## Architectural Deep-Dive
|
|
101
|
+
Curious how `syckpt v0.0.1` leverages Git pointers, `fsspec` atomic cloud mechanisms, manages PyTorch tensors, and accelerates training via Zero-Copy Safetensors?
|
|
102
|
+
|
|
103
|
+
Read the definitive educational walkthrough: [Implementation Guide (`implementation.md`)](./implementation.md).
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "syckpt"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "Git-like experiment tracking for deep learning with exact computational resumption"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.8"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "Sayak Chowdhury", email = "sayak.iiitb@gmail.com" }
|
|
14
|
+
]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Development Status :: 3 - Alpha",
|
|
17
|
+
"Intended Audience :: Developers",
|
|
18
|
+
"Intended Audience :: Science/Research",
|
|
19
|
+
"License :: OSI Approved :: MIT License",
|
|
20
|
+
]
|
|
21
|
+
dependencies = [
|
|
22
|
+
"torch>=2.0.0",
|
|
23
|
+
"numpy>=1.20.0",
|
|
24
|
+
"safetensors>=0.4.0",
|
|
25
|
+
"fsspec>=2023.0.0"
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
[project.optional-dependencies]
|
|
29
|
+
dev = [
|
|
30
|
+
"pytest>=7.0.0",
|
|
31
|
+
"pytest-cov>=4.0.0",
|
|
32
|
+
"black>=23.0.0",
|
|
33
|
+
"mypy>=1.0.0",
|
|
34
|
+
"twine>=4.0.0",
|
|
35
|
+
"build>=1.0.0"
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[project.urls]
|
|
39
|
+
"Bug Reports" = "https://github.com/sykchw/syckpt/issues"
|
|
40
|
+
"Source" = "https://github.com/sykchw/syckpt"
|
|
41
|
+
|
|
42
|
+
[tool.setuptools]
|
|
43
|
+
packages = ["syckpt"]
|
|
44
|
+
|
syckpt-0.0.1/setup.cfg
ADDED
syckpt-0.0.1/setup.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Setup for syckpt package - Git-like experiment tracking for DL.

NOTE(review): pyproject.toml also declares a [project] table with the same
metadata (name, version, description, dependencies). Keep both files in sync;
the description below mirrors pyproject.toml.
"""

from setuptools import setup, find_packages
from pathlib import Path

this_directory = Path(__file__).parent
readme_path = this_directory / "README.md"

# The README contains non-ASCII characters (e.g. em dashes), so decode it
# explicitly as UTF-8 instead of relying on the platform's locale encoding,
# which breaks or mis-decodes the build on e.g. cp1252 Windows systems.
long_description = (
    readme_path.read_text(encoding="utf-8") if readme_path.exists() else ""
)

setup(
    name="syckpt",
    version="0.0.1",
    description=(
        "Git-like experiment tracking for deep learning with exact "
        "computational resumption"
    ),
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Sayak Chowdhury",
    author_email="sayak.iiitb@gmail.com",
    url="https://github.com/sykchw/syckpt",
    packages=find_packages(exclude=["tests*", "docs*"]),
    python_requires=">=3.8",
    install_requires=[
        "torch>=2.0.0",
        "numpy>=1.20.0",
        "safetensors>=0.4.0",
        "fsspec>=2023.0.0",
    ],
    extras_require={
        "dev": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
            "black>=23.0.0",
            "mypy>=1.0.0",
            "twine>=4.0.0",
            "build>=1.0.0",
        ],
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="machine-learning deep-learning experiment-tracking checkpoint reproducibility lsh",
    project_urls={
        "Bug Reports": "https://github.com/sykchw/syckpt/issues",
        "Source": "https://github.com/sykchw/syckpt",
    },
)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Checkpoint - Git-like experiment tracking for deep learning with LSH hashing."""
|
|
2
|
+
|
|
3
|
+
from syckpt.manager import CheckpointManager, create_checkpoint, Commit
|
|
4
|
+
from syckpt.config import HyperConfig
|
|
5
|
+
from syckpt.hash import LSHHashGenerator, DEFAULT_HASH_FACTORS
|
|
6
|
+
from syckpt.state import set_seed, get_rng_state, set_rng_state
|
|
7
|
+
|
|
8
|
+
__version__ = "0.0.1"

# Names re-exported at package level, grouped by the submodule they come from
# (see the imports above). Order is kept stable for ``from syckpt import *``.
__all__ = [
    # syckpt.manager
    "CheckpointManager",
    "create_checkpoint",
    "Commit",
    # syckpt.config
    "HyperConfig",
    # syckpt.hash
    "LSHHashGenerator",
    "DEFAULT_HASH_FACTORS",
    # syckpt.state
    "set_seed",
    "get_rng_state",
    "set_rng_state",
]
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""Configuration system with attribute and dict access."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from typing import Any, Dict, Optional, Union
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HyperConfig(Mapping):
    """A configuration object that supports both attribute and dict access.

    Values are stored internally in a flat dict keyed by dot-joined paths
    (``{"model.lr": 0.1}``). Nested dicts/mappings are flattened on the way
    in and rebuilt on the way out, so all of these read the same value::

        cfg = HyperConfig({"model": {"lr": 0.1}})
        cfg.model.lr        # 0.1
        cfg["model.lr"]     # 0.1
        cfg["model"]["lr"]  # 0.1

    Iteration, ``len()``, ``keys()``, ``items()`` and ``values()`` operate on
    the top level of the nested view. Item lookup, ``in``, ``get()`` and
    deletion accept either a full dotted path or a section name.

    Note:
        Reading a section (``cfg.model`` / ``cfg["model"]``) returns a *new*
        ``HyperConfig`` built from a snapshot of that subtree. Assigning
        through it (``cfg.model.lr = 1``) does not modify ``cfg``; write with
        the full path instead (``cfg["model.lr"] = 1``) or assign a dict
        (``cfg.model = {"lr": 1}``), which merges into the existing section.
    """

    def __init__(self, data: Optional[Mapping] = None, **kwargs):
        """Build a config from an optional mapping plus keyword arguments.

        Args:
            data: A (possibly nested) mapping or another ``HyperConfig``.
                Values that are not mappings are ignored, as before.
            **kwargs: Extra top-level entries; dict values are flattened the
                same way as ``update()`` and item assignment do.
        """
        self._data: Dict[str, Any] = {}
        if isinstance(data, HyperConfig):
            # Already flat: copy its leaves directly.
            for key, value in data._data.items():
                self._set_leaf(key, value)
        elif isinstance(data, Mapping):
            self._assign_all(data)
        self._assign_all(kwargs)

    def _flatten_dict(
        self, d: Dict[str, Any], parent_key: str = "", sep: str = "."
    ) -> Dict[str, Any]:
        """Flatten a nested mapping into dot-notation keys."""
        flat: Dict[str, Any] = {}
        for key, value in d.items():
            new_key = f"{parent_key}{sep}{key}" if parent_key else key
            if isinstance(value, Mapping):
                flat.update(self._flatten_dict(value, new_key, sep=sep))
            else:
                flat[new_key] = value
        return flat

    def _unflatten_dict(self, d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
        """Rebuild a nested dict from dot-notation keys.

        If a key is both a leaf and a section prefix (only possible when the
        flat dict was filled directly), the entry that appears later wins
        instead of raising ``TypeError``. Non-string keys stay top-level.
        """
        result: Dict[str, Any] = {}
        for key, value in d.items():
            parts = key.split(sep) if isinstance(key, str) else [key]
            node = result
            for part in parts[:-1]:
                child = node.get(part)
                if not isinstance(child, dict):
                    child = node[part] = {}
                node = child
            node[parts[-1]] = value
        return result

    def _set_leaf(self, key: Any, value: Any) -> None:
        """Store one flat entry, first evicting entries that would conflict.

        A leaf at ``"a.b"`` cannot coexist with a leaf at its ancestor ``"a"``
        or with leaves under ``"a.b."``; those are removed so the nested view
        stays well-defined. Sibling entries are kept.
        """
        data = self._data
        if isinstance(key, str):
            prefix = key + "."
            for stale in [k for k in data if isinstance(k, str) and k.startswith(prefix)]:
                del data[stale]
            parts = key.split(".")
            for depth in range(1, len(parts)):
                data.pop(".".join(parts[:depth]), None)
        data[key] = value

    def _assign(self, key: Any, value: Any) -> None:
        """Assign ``value`` at ``key``, flattening mappings into leaves."""
        if isinstance(value, Mapping):
            for flat_key, leaf in self._flatten_dict({key: value}).items():
                self._set_leaf(flat_key, leaf)
        else:
            self._set_leaf(key, value)

    def _assign_all(self, mapping: Mapping) -> None:
        """Assign every top-level entry of ``mapping``."""
        for key, value in mapping.items():
            self._assign(key, value)

    def _subtree(self, key: str) -> Dict[str, Any]:
        """Return the flat entries under section ``key``, prefix stripped."""
        prefix = key + "."
        cut = len(prefix)
        return {
            k[cut:]: v
            for k, v in self._data.items()
            if isinstance(k, str) and k.startswith(prefix)
        }

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. Private names are
        # never treated as config keys (keeps copy/pickle probes working).
        if name.startswith("_"):
            return object.__getattribute__(self, name)
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith("_"):
            object.__setattr__(self, name, value)
        else:
            self._assign(name, value)

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' has no attribute '{name}'"
            ) from None

    def __getitem__(self, key: str) -> Any:
        """Return a leaf by full path, or a section as a new ``HyperConfig``."""
        if key in self._data:
            return self._data[key]
        if isinstance(key, str):
            section = self._subtree(key)
            if section:
                return HyperConfig(section)
        raise KeyError(key)

    def __setitem__(self, key: str, value: Any) -> None:
        self._assign(key, value)

    def __delitem__(self, key: str) -> None:
        """Delete a leaf by full path, or every leaf under a section."""
        if key in self._data:
            del self._data[key]
            return
        if isinstance(key, str):
            prefix = key + "."
            doomed = [k for k in self._data if isinstance(k, str) and k.startswith(prefix)]
            if doomed:
                for k in doomed:
                    del self._data[k]
                return
        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        if not isinstance(key, str):
            return False
        if key in self._data:
            return True
        prefix = key + "."
        return any(isinstance(k, str) and k.startswith(prefix) for k in self._data)

    def __iter__(self):
        return iter(self._unflatten_dict(self._data))

    def __len__(self) -> int:
        return len(self._unflatten_dict(self._data))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._unflatten_dict(self._data)})"

    def __str__(self) -> str:
        import json

        # default=str so non-JSON values (tensors, numpy scalars, ...) are
        # shown via str() instead of making print(config) raise TypeError.
        return json.dumps(self._unflatten_dict(self._data), indent=2, default=str)

    def __bool__(self) -> bool:
        return bool(self._data)

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``self[key]`` (leaf or section), or ``default`` if absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def update(
        self, other: Optional[Union[Dict[str, Any], "HyperConfig"]] = None, **kwargs
    ) -> "HyperConfig":
        """Merge ``other`` and ``kwargs`` into this config and return ``self``."""
        if other:
            if isinstance(other, HyperConfig):
                for key, value in other._data.items():
                    self._set_leaf(key, value)
            elif isinstance(other, Mapping):
                self._assign_all(other)
        self._assign_all(kwargs)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return unflattened dict representation."""
        return self._unflatten_dict(self._data)

    def copy(self) -> "HyperConfig":
        """Return a shallow copy."""
        return HyperConfig(copy.copy(self._data))

    def deep_copy(self) -> "HyperConfig":
        """Return a deep copy."""
        return HyperConfig(copy.deepcopy(self._data))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HyperConfig":
        """Create config from dict."""
        return cls(data)

    def items(self):
        """Return unflattened items."""
        return self._unflatten_dict(self._data).items()

    def keys(self):
        """Return unflattened keys."""
        return self._unflatten_dict(self._data).keys()

    def values(self):
        """Return unflattened values."""
        return self._unflatten_dict(self._data).values()
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
|
3
|
+
|
|
4
|
+
class StatefulDataLoader:
    """Wrap a PyTorch ``DataLoader`` so iteration can resume mid-epoch.

    The wrapper tracks the current ``epoch`` and the number of batches already
    yielded in it (``batch_idx``). When the wrapped loader uses a
    ``RandomSampler``, that sampler is driven by a private ``torch.Generator``
    reseeded with ``base_seed + epoch`` at the start of every pass, so an
    epoch's shuffle order is reproducible in a fresh process. On resumption
    (``batch_idx > 0``) the first ``batch_idx`` batches of the epoch are
    fetched and discarded. This restores the exact position but still pays
    the loading cost of the skipped batches.

    Args:
        dataloader: The loader to wrap. If its sampler is a ``RandomSampler``,
            the sampler's ``generator`` attribute is replaced by this
            wrapper's generator.
        base_seed: Seed offset; epoch ``e`` is shuffled with ``base_seed + e``.
    """

    def __init__(self, dataloader: DataLoader, base_seed: int = 42):
        self.dataloader = dataloader
        self.base_seed = base_seed
        # Batches already yielded in the current epoch.
        self.batch_idx = 0
        self.epoch = 0
        # Full sample order of the current epoch, when it can be determined.
        self._indices = []
        self._iterator = None
        # Private RNG that shuffles RandomSampler-backed loaders.
        self._generator = torch.Generator()

    def __iter__(self):
        """Start (or resume) the current epoch and return ``self``."""
        seed = self.base_seed + self.epoch
        self._generator.manual_seed(seed)
        self._indices = self._record_indices(seed)

        self._iterator = iter(self.dataloader)
        # Resuming mid-epoch: drop the batches yielded before the checkpoint.
        exhausted = object()
        for _ in range(self.batch_idx):
            if next(self._iterator, exhausted) is exhausted:
                break
        return self

    def _record_indices(self, seed: int) -> list:
        """Return this epoch's sample order and prime the sampler to replay it."""
        sampler = self.dataloader.sampler
        if isinstance(sampler, RandomSampler):
            # Without this the sampler seeds itself from the global torch RNG
            # and the order could not be reproduced after a restart.
            sampler.generator = self._generator
            indices = list(sampler)
            # Rewind so the DataLoader's own pass draws the same permutation.
            self._generator.manual_seed(seed)
            return indices
        if isinstance(sampler, SequentialSampler):
            return list(range(len(self.dataloader.dataset)))
        if not hasattr(sampler, "__len__"):
            # Length-less samplers (e.g. the one DataLoader uses for an
            # IterableDataset) may be infinite; never materialise them.
            return []
        # Best effort for custom samplers.
        # NOTE(review): a sampler that draws from a shared RNG is advanced by
        # this extra pass, so the recorded order may differ from what the
        # DataLoader then yields.
        try:
            return list(sampler)
        except Exception:
            return []

    def __next__(self):
        """Return the next batch; at epoch end advance ``epoch`` once."""
        if self._iterator is None:
            # Not started, or this epoch already ended: keep raising instead
            # of incrementing the epoch counter again.
            raise StopIteration
        try:
            batch = next(self._iterator)
        except StopIteration:
            self.epoch += 1
            self.batch_idx = 0
            self._iterator = None
            raise
        self.batch_idx += 1
        return batch

    def state_dict(self):
        """Return the state needed to resume: position, recorded order, seed."""
        return {
            "batch_idx": self.batch_idx,
            "epoch": self.epoch,
            "indices": self._indices,
            "base_seed": self.base_seed
        }

    def load_state_dict(self, state):
        """Restore state from ``state_dict()``; applied on the next ``iter()``.

        The recorded ``indices`` are informational: the next ``iter()``
        recomputes them from ``base_seed + epoch``.
        """
        self.batch_idx = state.get("batch_idx", 0)
        self.epoch = state.get("epoch", 0)
        self._indices = state.get("indices", [])
        self.base_seed = state.get("base_seed", 42)
|