syckpt 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
syckpt-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,134 @@
+ Metadata-Version: 2.4
+ Name: syckpt
+ Version: 0.0.1
+ Summary: Git-like experiment tracking for deep learning with exact computational resumption
+ Home-page: https://github.com/sykchw/syckpt
+ Author: Sayak Chowdhury
+ Author-email: Sayak Chowdhury <sayak.iiitb@gmail.com>
+ License: MIT
+ Project-URL: Bug Reports, https://github.com/sykchw/syckpt/issues
+ Project-URL: Source, https://github.com/sykchw/syckpt
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: numpy>=1.20.0
+ Requires-Dist: safetensors>=0.4.0
+ Requires-Dist: fsspec>=2023.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
+ Requires-Dist: black>=23.0.0; extra == "dev"
+ Requires-Dist: mypy>=1.0.0; extra == "dev"
+ Requires-Dist: twine>=4.0.0; extra == "dev"
+ Requires-Dist: build>=1.0.0; extra == "dev"
+ Dynamic: author
+ Dynamic: home-page
+ Dynamic: requires-python
+
+ # Syckpt v0.0.1
+
+ **Git-like experiment tracking for deep learning with exact computational resumption, zero-copy safetensors memory-mapping, and delta-compression.**
+
+ `syckpt` is a lightweight, local-first experiment version control system designed to exactly reconstruct massive computational states (model weights, optimizer momentum, mixed-precision GradScalers, random number generators, and stateful DataLoaders) without perturbing the loss curve.
+
+ ---
+
+ ## How `syckpt` Works (The Architecture)
+
+ When training massive deep learning models, saving a full checkpoint at every epoch typically costs gigabytes of duplicated disk space and high write latency. `syckpt` solves this by treating machine-learning checkpoints like a Git repository.
+
+ 1. **Content-Addressable Storage (CAS) & Delta-Compression**:
+    Instead of saving a full 5 GB `.pt` weight file at every step, `syckpt` finds the most similar historical checkpoint and stores only the `float32` difference (`delta = current - base`). Because gradient steps are small, this delta is highly compressible. `syckpt` keeps these deltas in a hidden `.syckpt/objects/` directory, saving up to 90% of disk space.
+
+ 2. **Locality-Sensitive Hashing (LSH)**:
+    To find the "most similar" historical checkpoint quickly, `syckpt` uses LSH over your hyperparameters (learning rate, batch size, seed, and so on). Similar hyperparameter vectors collide into identical hash prefixes, letting the system rapidly query the Git-tree.
+
+ 3. **Zero-Copy Memory-Mapping via Safetensors**:
+    `syckpt` bypasses Python's insecure and memory-heavy `pickle` module. It uses the Rust-backed `safetensors` format to memory-map delta-blobs straight from disk ("zero-copy"), so tensors are paged in lazily instead of being fully deserialized into CPU RAM first, avoiding loader-side out-of-memory (OOM) errors.
+
+ 4. **Exact Mathematical Resumption**:
+    Standard PyTorch training loops show "resumption spikes" in the loss curve because the DataLoader indices and random number generators (RNGs) are reset on restart. `syckpt` captures the PyTorch, CUDA, and NumPy generators and preserves the internal `RandomSampler` permutations of your DataLoaders, so a resumed run is mathematically identical to one that was never interrupted.
+
+ ## Installation
+
+ Install from PyPI with `pip`, or with the Rust-accelerated `uv` package manager:
+
+ ```bash
+ pip install syckpt
+ # Or using uv
+ uv pip install syckpt
+ ```
+
+ ## Quick Start
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from syckpt import CheckpointManager
+ from syckpt.dataloader import StatefulDataLoader
+ from torch.utils.data import DataLoader, TensorDataset
+
+ # Typical PyTorch components
+ model = nn.Linear(10, 2)
+ optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+ dummy_data = TensorDataset(torch.randn(100, 10), torch.randn(100, 2))
+ # Wrap the standard (non-deterministic) DataLoader so it can be resumed exactly
+ loader = StatefulDataLoader(DataLoader(dummy_data, batch_size=32, shuffle=True))
+
+ # Point at an S3 or local URL; atomic locks handle concurrent writers natively
+ with CheckpointManager("s3://my-experiments-bucket/.syckpt") as ckpt:
+     # 1. Register stateful objects (automatically mapped via flattening)
+     ckpt.model = model
+     ckpt.optimizer = optimizer
+     ckpt.dataloader = loader
+
+     # 2. Hyperparameters automatically generate the unique LSH hash
+     ckpt.config.lr = 0.01
+     ckpt.config.batch_size = 32
+
+     # 3. The training loop tracks the step and epoch counters for you
+     for epoch in ckpt.loop(epochs=10):
+         for batch_idx, batch in enumerate(loader):
+             loss = torch.randn(1)  # Fake loss
+             ckpt.step_up()
+
+         # Delta-compression kicks in automatically
+         if epoch % 2 == 0:
+             ckpt.save(metric=loss.item())
+
+     print(f"Run state saved at LSH commit: {ckpt.hash}")
+ ```
+
+ ## Feature Reference
+
+ ### Exporting Monolithic Assets (`.ckpt`)
+ If you are deploying your model and no longer need `.syckpt` branching, you can collapse the Git-tree into a standard monolithic PyTorch `.ckpt` file for deployment or upload to Hugging Face:
+ ```python
+ with CheckpointManager("./experiments") as ckpt:
+     # Recursively loads the flat delta-tensors and reconstitutes a standard state dict
+     ckpt.export_ckpt(hash_or_branch="main", output_path="final-model.ckpt")
+ ```
+
+ ### Full Distributed Resumption (DDP)
+ `syckpt` broadcasts LSH hashes and uses `dist.gather` to collect the per-rank RNG states across your entire multi-GPU cluster.
+ ```python
+ import numpy as np
+
+ with CheckpointManager("./") as ckpt:
+     # Register your modern NumPy Generator; the state manager captures its bit-generator state
+     ckpt.numpy_rng = np.random.default_rng()
+     ckpt.save()
+ ```
+
+ ---
+
+ ## Architectural Deep-Dive
+ Curious how `syckpt v0.0.1` leverages Git-style pointers and `fsspec`'s atomic cloud primitives, manages PyTorch tensors, and accelerates loading via zero-copy safetensors?
+
+ Read the definitive educational walkthrough: [Implementation Guide (`implementation.md`)](./implementation.md).
syckpt-0.0.1/README.md ADDED
@@ -0,0 +1,103 @@
+ # Syckpt v0.0.1
+
+ **Git-like experiment tracking for deep learning with exact computational resumption, zero-copy safetensors memory-mapping, and delta-compression.**
+
+ `syckpt` is a lightweight, local-first experiment version control system designed to exactly reconstruct massive computational states (model weights, optimizer momentum, mixed-precision GradScalers, random number generators, and stateful DataLoaders) without perturbing the loss curve.
+
+ ---
+
+ ## How `syckpt` Works (The Architecture)
+
+ When training massive deep learning models, saving a full checkpoint at every epoch typically costs gigabytes of duplicated disk space and high write latency. `syckpt` solves this by treating machine-learning checkpoints like a Git repository.
+
+ 1. **Content-Addressable Storage (CAS) & Delta-Compression**:
+    Instead of saving a full 5 GB `.pt` weight file at every step, `syckpt` finds the most similar historical checkpoint and stores only the `float32` difference (`delta = current - base`). Because gradient steps are small, this delta is highly compressible. `syckpt` keeps these deltas in a hidden `.syckpt/objects/` directory, saving up to 90% of disk space (see the delta-encoding sketch after this list).
+
+ 2. **Locality-Sensitive Hashing (LSH)**:
+    To find the "most similar" historical checkpoint quickly, `syckpt` uses LSH over your hyperparameters (learning rate, batch size, seed, and so on). Similar hyperparameter vectors collide into identical hash prefixes, letting the system rapidly query the Git-tree (sketched after this list).
+
+ 3. **Zero-Copy Memory-Mapping via Safetensors**:
+    `syckpt` bypasses Python's insecure and memory-heavy `pickle` module. It uses the Rust-backed `safetensors` format to memory-map delta-blobs straight from disk ("zero-copy"), so tensors are paged in lazily instead of being fully deserialized into CPU RAM first, avoiding loader-side out-of-memory (OOM) errors.
+
+ 4. **Exact Mathematical Resumption**:
+    Standard PyTorch training loops show "resumption spikes" in the loss curve because the DataLoader indices and random number generators (RNGs) are reset on restart. `syckpt` captures the PyTorch, CUDA, and NumPy generators and preserves the internal `RandomSampler` permutations of your DataLoaders, so a resumed run is mathematically identical to one that was never interrupted.
+
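+ To make the first two mechanisms concrete, here is a minimal, illustrative sketch of delta-encoding a checkpoint against a base with `safetensors` (an editor's sketch, not `syckpt`'s internal code):
+
+ ```python
+ import torch
+ from safetensors.torch import save_file, load_file
+
+ def save_delta(current: dict, base: dict, path: str) -> None:
+     # Persist only the float32 difference against the base checkpoint;
+     # because gradient steps are small, these tensors compress well.
+     delta = {k: (current[k] - base[k]).to(torch.float32) for k in current}
+     save_file(delta, path)
+
+ def apply_delta(base: dict, path: str) -> dict:
+     # load_file memory-maps the blob, so tensors are paged in lazily
+     # rather than round-tripping through pickle.
+     delta = load_file(path)
+     return {k: base[k] + delta[k] for k in base}
+ ```
+
+ And a SimHash-style random-hyperplane LSH over a hyperparameter vector, showing why nearby configurations tend to share hash prefixes (again a sketch under assumed details, not the shipped `LSHHashGenerator`):
+
+ ```python
+ import numpy as np
+
+ def lsh_prefix(hparams: dict, n_bits: int = 16, seed: int = 0) -> str:
+     # Project the sorted, numeric hyperparameter vector onto fixed random
+     # hyperplanes; each sign contributes one bit of the hash prefix.
+     vec = np.array([float(v) for _, v in sorted(hparams.items())])
+     planes = np.random.default_rng(seed).standard_normal((n_bits, vec.size))
+     return "".join(str(int(b)) for b in (planes @ vec > 0))
+ ```
+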
+ ## Installation
+
+ Install from PyPI with `pip`, or with the Rust-accelerated `uv` package manager:
+
+ ```bash
+ pip install syckpt
+ # Or using uv
+ uv pip install syckpt
+ ```
+
+ ## Quick Start
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from syckpt import CheckpointManager
+ from syckpt.dataloader import StatefulDataLoader
+ from torch.utils.data import DataLoader, TensorDataset
+
+ # Typical PyTorch components
+ model = nn.Linear(10, 2)
+ optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+ dummy_data = TensorDataset(torch.randn(100, 10), torch.randn(100, 2))
+ # Wrap the standard (non-deterministic) DataLoader so it can be resumed exactly
+ loader = StatefulDataLoader(DataLoader(dummy_data, batch_size=32, shuffle=True))
+
+ # Point at an S3 or local URL; atomic locks handle concurrent writers natively
+ with CheckpointManager("s3://my-experiments-bucket/.syckpt") as ckpt:
+     # 1. Register stateful objects (automatically mapped via flattening)
+     ckpt.model = model
+     ckpt.optimizer = optimizer
+     ckpt.dataloader = loader
+
+     # 2. Hyperparameters automatically generate the unique LSH hash
+     ckpt.config.lr = 0.01
+     ckpt.config.batch_size = 32
+
+     # 3. The training loop tracks the step and epoch counters for you
+     for epoch in ckpt.loop(epochs=10):
+         for batch_idx, batch in enumerate(loader):
+             loss = torch.randn(1)  # Fake loss
+             ckpt.step_up()
+
+         # Delta-compression kicks in automatically
+         if epoch % 2 == 0:
+             ckpt.save(metric=loss.item())
+
+     print(f"Run state saved at LSH commit: {ckpt.hash}")
+ ```
+
+ ## Feature Reference
+
+ ### Exporting Monolithic Assets (`.ckpt`)
+ If you are deploying your model and no longer need `.syckpt` branching, you can collapse the Git-tree into a standard monolithic PyTorch `.ckpt` file for deployment or upload to Hugging Face:
+ ```python
+ with CheckpointManager("./experiments") as ckpt:
+     # Recursively loads the flat delta-tensors and reconstitutes a standard state dict
+     ckpt.export_ckpt(hash_or_branch="main", output_path="final-model.ckpt")
+ ```
+
+ ### Full Distributed Resumption (DDP)
+ `syckpt` broadcasts LSH hashes and uses `dist.gather` to collect the per-rank RNG states across your entire multi-GPU cluster (see the sketch after this code block).
+ ```python
+ import numpy as np
+
+ with CheckpointManager("./") as ckpt:
+     # Register your modern NumPy Generator; the state manager captures its bit-generator state
+     ckpt.numpy_rng = np.random.default_rng()
+     ckpt.save()
+ ```
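+
+ For reference, the rank-state collection described above can be sketched with plain `torch.distributed` (illustrative only, assuming an initialized process group; `syckpt`'s internals may differ):
+
+ ```python
+ import torch
+ import torch.distributed as dist
+
+ def gather_rng_states(dst: int = 0):
+     # Each rank contributes its CPU RNG state (a ByteTensor); rank `dst`
+     # receives the full list and can embed it in the checkpoint commit.
+     state = torch.get_rng_state()
+     states = [None] * dist.get_world_size() if dist.get_rank() == dst else None
+     dist.gather_object(state, states, dst=dst)
+     return states  # list of per-rank states on dst, None elsewhere
+ ```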
+
+ ---
+
+ ## Architectural Deep-Dive
+ Curious how `syckpt v0.0.1` leverages Git-style pointers and `fsspec`'s atomic cloud primitives, manages PyTorch tensors, and accelerates loading via zero-copy safetensors?
+
+ Read the definitive educational walkthrough: [Implementation Guide (`implementation.md`)](./implementation.md).
syckpt-0.0.1/pyproject.toml ADDED
@@ -0,0 +1,44 @@
+ [build-system]
+ requires = ["setuptools>=61.0.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "syckpt"
+ version = "0.0.1"
+ description = "Git-like experiment tracking for deep learning with exact computational resumption"
+ readme = "README.md"
+ requires-python = ">=3.8"
+ license = { text = "MIT" }
+ authors = [
+     { name = "Sayak Chowdhury", email = "sayak.iiitb@gmail.com" }
+ ]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "Intended Audience :: Science/Research",
+     "License :: OSI Approved :: MIT License",
+ ]
+ dependencies = [
+     "torch>=2.0.0",
+     "numpy>=1.20.0",
+     "safetensors>=0.4.0",
+     "fsspec>=2023.0.0"
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.0.0",
+     "pytest-cov>=4.0.0",
+     "black>=23.0.0",
+     "mypy>=1.0.0",
+     "twine>=4.0.0",
+     "build>=1.0.0"
+ ]
+
+ [project.urls]
+ "Bug Reports" = "https://github.com/sykchw/syckpt/issues"
+ "Source" = "https://github.com/sykchw/syckpt"
+
+ [tool.setuptools]
+ packages = ["syckpt"]
+
syckpt-0.0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
syckpt-0.0.1/setup.py ADDED
@@ -0,0 +1,58 @@
+ """Setup for syckpt package - Git-like experiment tracking for DL."""
+
+ from setuptools import setup, find_packages
+ from pathlib import Path
+
+ this_directory = Path(__file__).parent
+ long_description = (
+     (this_directory / "README.md").read_text()
+     if (this_directory / "README.md").exists()
+     else ""
+ )
+
+ setup(
+     name="syckpt",
+     version="0.0.1",
+     description="Git-like experiment tracking for deep learning with LSH hashing",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     author="Sayak Chowdhury",
+     author_email="sayak.iiitb@gmail.com",
+     url="https://github.com/sykchw/syckpt",
+     packages=find_packages(exclude=["tests*", "docs*"]),
+     python_requires=">=3.8",
+     install_requires=[
+         "torch>=2.0.0",
+         "numpy>=1.20.0",
+         "safetensors>=0.4.0",
+         "fsspec>=2023.0.0"
+     ],
+     extras_require={
+         "dev": [
+             "pytest>=7.0.0",
+             "pytest-cov>=4.0.0",
+             "black>=23.0.0",
+             "mypy>=1.0.0",
+             "twine>=4.0.0",
+             "build>=1.0.0",
+         ],
+     },
+     classifiers=[
+         "Development Status :: 3 - Alpha",
+         "Intended Audience :: Developers",
+         "Intended Audience :: Science/Research",
+         "License :: OSI Approved :: MIT License",
+         "Programming Language :: Python :: 3",
+         "Programming Language :: Python :: 3.8",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Programming Language :: Python :: 3.11",
+         "Programming Language :: Python :: 3.12",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     ],
+     keywords="machine-learning deep-learning experiment-tracking checkpoint reproducibility lsh",
+     project_urls={
+         "Bug Reports": "https://github.com/sykchw/syckpt/issues",
+         "Source": "https://github.com/sykchw/syckpt",
+     },
+ )
syckpt-0.0.1/syckpt/__init__.py ADDED
@@ -0,0 +1,20 @@
+ """Checkpoint - Git-like experiment tracking for deep learning with LSH hashing."""
+
+ from syckpt.manager import CheckpointManager, create_checkpoint, Commit
+ from syckpt.config import HyperConfig
+ from syckpt.hash import LSHHashGenerator, DEFAULT_HASH_FACTORS
+ from syckpt.state import set_seed, get_rng_state, set_rng_state
+
+ __version__ = "0.0.1"
+
+ __all__ = [
+     "CheckpointManager",
+     "create_checkpoint",
+     "Commit",
+     "HyperConfig",
+     "LSHHashGenerator",
+     "DEFAULT_HASH_FACTORS",
+     "set_seed",
+     "get_rng_state",
+     "set_rng_state",
+ ]
syckpt-0.0.1/syckpt/config.py ADDED
@@ -0,0 +1,161 @@
+ """Configuration system with attribute and dict access."""
+
+ import copy
+ from typing import Any, Dict, Optional, Union
+ from collections.abc import Mapping
+
+
+ class HyperConfig(Mapping):
+     """A configuration object that supports both attribute and dict access.
+
+     Supports nested configuration via dot notation (e.g., config.a.b.c)
+     and provides full dict-like access (config['key'], config.get('key')).
+     """
+
+     def __init__(self, data: Optional[Dict[str, Any]] = None, **kwargs):
+         self._data: Dict[str, Any] = {}
+         if data:
+             self._data = self._flatten_dict(data) if isinstance(data, dict) else {}
+         self._data.update(kwargs)
+
+     def _flatten_dict(
+         self, d: Dict[str, Any], parent_key: str = "", sep: str = "."
+     ) -> Dict[str, Any]:
+         """Flatten nested dict into dot-notation keys."""
+         items = []
+         for k, v in d.items():
+             new_key = f"{parent_key}{sep}{k}" if parent_key else k
+             if isinstance(v, dict):
+                 items.extend(self._flatten_dict(v, new_key, sep=sep).items())
+             else:
+                 items.append((new_key, v))
+         return dict(items)
+
+     def _unflatten_dict(self, d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
+         """Unflatten dot-notation keys back to nested dict."""
+         result = {}
+         for key, value in d.items():
+             parts = key.split(sep)
+             d_obj = result
+             for part in parts[:-1]:
+                 if part not in d_obj:
+                     d_obj[part] = {}
+                 d_obj = d_obj[part]
+             d_obj[parts[-1]] = value
+         return result
+
+     def __getattr__(self, name: str) -> Any:
+         if name.startswith("_"):
+             return object.__getattribute__(self, name)
+         unflattened = self._unflatten_dict(self._data)
+         if name in unflattened:
+             val = unflattened[name]
+             if isinstance(val, dict):
+                 # Nested section: wrap in a HyperConfig so chained dot
+                 # access (config.a.b.c) keeps working.
+                 return HyperConfig(val)
+             return val
+         if name in self._data:
+             return self._data[name]
+         raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
+
+     def __setattr__(self, name: str, value: Any) -> None:
+         if name.startswith("_"):
+             object.__setattr__(self, name, value)
+         else:
+             if isinstance(value, dict):
+                 for k, v in self._flatten_dict({name: value}).items():
+                     self._data[k] = v
+             else:
+                 self._data[name] = value
+
+     def __delattr__(self, name: str) -> None:
+         if name in self._data:
+             del self._data[name]
+             return  # deleted a flat dot-key; nothing left to do
+         unflattened = self._unflatten_dict(self._data)
+         if name in unflattened:
+             del unflattened[name]
+             self._data = self._flatten_dict(unflattened)
+             return
+         raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
+
+     def __getitem__(self, key: str) -> Any:
+         # Accept both flattened dot-keys ("a.b") and top-level nested keys
+         # ("a") so that iteration and item access agree (Mapping protocol).
+         if key in self._data:
+             return self._data[key]
+         unflattened = self._unflatten_dict(self._data)
+         if key in unflattened:
+             val = unflattened[key]
+             return HyperConfig(val) if isinstance(val, dict) else val
+         raise KeyError(key)
+
+     def __setitem__(self, key: str, value: Any) -> None:
+         if isinstance(value, dict):
+             for k, v in self._flatten_dict({key: value}).items():
+                 self._data[k] = v
+         else:
+             self._data[key] = value
+
+     def __delitem__(self, key: str) -> None:
+         del self._data[key]
+
+     def __contains__(self, key: object) -> bool:
+         if not isinstance(key, str):
+             return False
+         return key in self._data or key in self._unflatten_dict(self._data)
+
+     def __iter__(self):
+         return iter(self._unflatten_dict(self._data))
+
+     def __len__(self) -> int:
+         return len(self._unflatten_dict(self._data))
+
+     def __repr__(self) -> str:
+         return f"{type(self).__name__}({self._unflatten_dict(self._data)})"
+
+     def __str__(self) -> str:
+         import json
+
+         return json.dumps(self._unflatten_dict(self._data), indent=2)
+
+     def __bool__(self) -> bool:
+         return bool(self._data)
+
+     def get(self, key: str, default: Any = None) -> Any:
+         # Mirror __getitem__ so flat dot-keys and top-level keys both resolve.
+         try:
+             return self[key]
+         except KeyError:
+             return default
+
+     def update(
+         self, other: Union[Dict[str, Any], "HyperConfig"], **kwargs
+     ) -> "HyperConfig":
+         """Update config with new values."""
+         if other:
+             if isinstance(other, HyperConfig):
+                 self._data.update(other._data)
+             elif isinstance(other, dict):
+                 self._data.update(self._flatten_dict(other))
+         self._data.update(self._flatten_dict(kwargs))
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return unflattened dict representation."""
+         return self._unflatten_dict(self._data)
+
+     def copy(self) -> "HyperConfig":
+         """Return a shallow copy."""
+         return HyperConfig(copy.copy(self._data))
+
+     def deep_copy(self) -> "HyperConfig":
+         """Return a deep copy."""
+         return HyperConfig(copy.deepcopy(self._data))
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "HyperConfig":
+         """Create config from dict."""
+         return cls(data)
+
+     def items(self):
+         """Return unflattened items."""
+         return self._unflatten_dict(self._data).items()
+
+     def keys(self):
+         """Return unflattened keys."""
+         return self._unflatten_dict(self._data).keys()
+
+     def values(self):
+         """Return unflattened values."""
+         return self._unflatten_dict(self._data).values()
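+
+
+ # --- Usage sketch (editor's illustration, not part of the shipped file) ---
+ # HyperConfig supports both attribute access and flattened dot-key access;
+ # behavior follows the class defined above.
+ if __name__ == "__main__":
+     cfg = HyperConfig({"optim": {"lr": 0.01, "momentum": 0.9}}, batch_size=32)
+     assert cfg.optim.lr == 0.01      # attribute access on a nested section
+     assert cfg["optim.lr"] == 0.01   # flattened dot-key access
+     assert "batch_size" in cfg and cfg.batch_size == 32
+     cfg["optim.weight_decay"] = 1e-4  # writes persist via the flat key
+     print(cfg.to_dict())  # nested view: {'optim': {...}, 'batch_size': 32}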
syckpt-0.0.1/syckpt/dataloader.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+
+
+ class StatefulDataLoader:
+     """A wrapper around a PyTorch DataLoader that allows exact resumption.
+
+     It tracks the batch index and the exact permuted index list produced by
+     the RandomSampler. Upon resumption it skips the batches that were
+     already consumed, so only the remaining batches of the epoch are yielded.
+     """
+
+     def __init__(self, dataloader: DataLoader, base_seed: int = 42):
+         self.dataloader = dataloader
+         self.base_seed = base_seed
+         self.batch_idx = 0
+         self.epoch = 0
+         self._indices = []
+         self._iterator = None
+         self._generator = torch.Generator()
+
+     def __iter__(self):
+         # Seed deterministically per epoch so each epoch's shuffle is
+         # exactly reproducible after a restart
+         self._generator.manual_seed(self.base_seed + self.epoch)
+
+         sampler = self.dataloader.sampler
+         if isinstance(sampler, RandomSampler):
+             # Point the RandomSampler at our seeded generator so the
+             # DataLoader replays exactly the permutation we record below
+             sampler.generator = self._generator
+             # Record the permutation with an identically seeded generator;
+             # recording must not consume the sampler's generator state
+             recorder = torch.Generator()
+             recorder.manual_seed(self.base_seed + self.epoch)
+             n = len(self.dataloader.dataset)
+             self._indices = torch.randperm(n, generator=recorder).tolist()
+         elif isinstance(sampler, SequentialSampler):
+             self._indices = list(range(len(self.dataloader.dataset)))
+         else:
+             # For custom samplers, extract indices if we can
+             try:
+                 self._indices = list(sampler)
+             except Exception:
+                 self._indices = []
+
+         self._iterator = iter(self.dataloader)
+
+         # Fast-forward if we resumed mid-epoch
+         if self.batch_idx > 0 and self._indices:
+             # Drop the items we already yielded from the recorded permutation
+             items_to_skip = self.batch_idx * self.dataloader.batch_size
+             self._indices = self._indices[items_to_skip:]
+
+             # A standard DataLoader cannot be sliced, so skip batch by batch.
+             # This wastes I/O on the skipped batches; the scalable fix is a
+             # sampler subclass that slices the permutation directly.
+             for _ in range(self.batch_idx):
+                 next(self._iterator, None)
+
+         return self
+
+     def __next__(self):
+         try:
+             batch = next(self._iterator)
+             self.batch_idx += 1
+             return batch
+         except StopIteration:
+             self.epoch += 1
+             self.batch_idx = 0
+             raise
+
+     def state_dict(self):
+         return {
+             "batch_idx": self.batch_idx,
+             "epoch": self.epoch,
+             "indices": self._indices,
+             "base_seed": self.base_seed,
+         }
+
+     def load_state_dict(self, state):
+         self.batch_idx = state.get("batch_idx", 0)
+         self.epoch = state.get("epoch", 0)
+         self._indices = state.get("indices", [])
+         self.base_seed = state.get("base_seed", 42)
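+
+
+ # --- Resumption sketch (editor's illustration, not part of the shipped file) ---
+ # Round-trips the loader's position through state_dict()/load_state_dict(),
+ # presumably the hooks CheckpointManager uses when a dataloader is registered.
+ if __name__ == "__main__":
+     from torch.utils.data import TensorDataset
+
+     dataset = TensorDataset(torch.randn(10, 3))
+     loader = StatefulDataLoader(DataLoader(dataset, batch_size=2, shuffle=True))
+
+     it = iter(loader)
+     next(it)                        # consume one batch...
+     snapshot = loader.state_dict()  # ...and snapshot the position
+
+     resumed = StatefulDataLoader(DataLoader(dataset, batch_size=2, shuffle=True))
+     resumed.load_state_dict(snapshot)
+     remaining = list(resumed)  # yields only the 4 batches left in the epoch
+     assert len(remaining) == 4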