pyspiral-0.6.6-cp312-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.6.6.dist-info/METADATA +51 -0
- pyspiral-0.6.6.dist-info/RECORD +102 -0
- pyspiral-0.6.6.dist-info/WHEEL +4 -0
- pyspiral-0.6.6.dist-info/entry_points.txt +2 -0
- spiral/__init__.py +35 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +411 -0
- spiral/api/__init__.py +78 -0
- spiral/api/admin.py +15 -0
- spiral/api/client.py +164 -0
- spiral/api/filesystems.py +134 -0
- spiral/api/key_space_indexes.py +23 -0
- spiral/api/organizations.py +77 -0
- spiral/api/projects.py +219 -0
- spiral/api/telemetry.py +19 -0
- spiral/api/text_indexes.py +56 -0
- spiral/api/types.py +22 -0
- spiral/api/workers.py +40 -0
- spiral/api/workloads.py +52 -0
- spiral/arrow_.py +216 -0
- spiral/cli/__init__.py +88 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +14 -0
- spiral/cli/app.py +104 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +76 -0
- spiral/cli/iceberg.py +97 -0
- spiral/cli/key_spaces.py +89 -0
- spiral/cli/login.py +24 -0
- spiral/cli/orgs.py +89 -0
- spiral/cli/printer.py +53 -0
- spiral/cli/projects.py +147 -0
- spiral/cli/state.py +5 -0
- spiral/cli/tables.py +174 -0
- spiral/cli/telemetry.py +17 -0
- spiral/cli/text.py +115 -0
- spiral/cli/types.py +50 -0
- spiral/cli/workloads.py +58 -0
- spiral/client.py +178 -0
- spiral/core/__init__.pyi +0 -0
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/authn/__init__.pyi +27 -0
- spiral/core/client/__init__.pyi +237 -0
- spiral/core/table/__init__.pyi +101 -0
- spiral/core/table/manifests/__init__.pyi +35 -0
- spiral/core/table/metastore/__init__.pyi +58 -0
- spiral/core/table/spec/__init__.pyi +213 -0
- spiral/dataloader.py +285 -0
- spiral/dataset.py +255 -0
- spiral/datetime_.py +27 -0
- spiral/debug/__init__.py +0 -0
- spiral/debug/manifests.py +87 -0
- spiral/debug/metrics.py +56 -0
- spiral/debug/scan.py +266 -0
- spiral/expressions/__init__.py +276 -0
- spiral/expressions/base.py +157 -0
- spiral/expressions/http.py +86 -0
- spiral/expressions/io.py +100 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/mp4.py +62 -0
- spiral/expressions/png.py +18 -0
- spiral/expressions/qoi.py +18 -0
- spiral/expressions/refs.py +58 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +59 -0
- spiral/expressions/text.py +62 -0
- spiral/expressions/tiff.py +223 -0
- spiral/expressions/udf.py +46 -0
- spiral/grpc_.py +32 -0
- spiral/iceberg.py +31 -0
- spiral/iterable_dataset.py +106 -0
- spiral/key_space_index.py +44 -0
- spiral/project.py +199 -0
- spiral/protogen/_/__init__.py +0 -0
- spiral/protogen/_/arrow/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
- spiral/protogen/_/google/__init__.py +0 -0
- spiral/protogen/_/google/protobuf/__init__.py +2310 -0
- spiral/protogen/_/message_pool.py +3 -0
- spiral/protogen/_/py.typed +0 -0
- spiral/protogen/_/scandal/__init__.py +190 -0
- spiral/protogen/_/spfs/__init__.py +72 -0
- spiral/protogen/_/spql/__init__.py +61 -0
- spiral/protogen/_/substrait/__init__.py +6196 -0
- spiral/protogen/_/substrait/extensions/__init__.py +169 -0
- spiral/protogen/__init__.py +0 -0
- spiral/protogen/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan.py +285 -0
- spiral/server.py +17 -0
- spiral/settings.py +114 -0
- spiral/snapshot.py +56 -0
- spiral/streaming_/__init__.py +3 -0
- spiral/streaming_/reader.py +133 -0
- spiral/streaming_/stream.py +157 -0
- spiral/substrait_.py +274 -0
- spiral/table.py +293 -0
- spiral/text_index.py +17 -0
- spiral/transaction.py +58 -0
- spiral/types_.py +6 -0
spiral/core/table/spec/__init__.pyi
ADDED
@@ -0,0 +1,213 @@
"""Type definitions for the spiral.core.spec module shipped as part of the native library."""

import pyarrow as pa

class ColumnGroup:
    def __init__(self, path: list[str]): ...
    @property
    def table_id(self) -> str: ...
    @property
    def path(self) -> list[str]: ...
    def identifier(self, salt: int) -> str:
        """Return the column group identifier based on the given salt."""

    @staticmethod
    def from_str(path: str) -> ColumnGroup: ...

class KeySpaceMetadata:
    def __init__(
        self,
        *,
        manifest_handle: ManifestHandle | None,
        last_modified_at: int,
    ): ...

    manifest_handle: ManifestHandle | None
    last_modified_at: int

    def asof(self, asof: int) -> KeySpaceMetadata:
        """Returns the metadata as of a given timestamp. Currently just filtering versioned schemas."""
        ...

    def apply_wal(self, wal: WriteAheadLog) -> KeySpaceMetadata:
        """Applies the given WAL to the metadata."""

class ColumnGroupMetadata:
    def __init__(
        self,
        *,
        column_group: ColumnGroup,
        manifest_handle: ManifestHandle | None,
        last_modified_at: int,
        schema_versions: list[VersionedSchema] | None,
        immutable_schema: bool,
        schema_salt: int,
    ): ...

    column_group: ColumnGroup
    manifest_handle: ManifestHandle | None
    last_modified_at: int
    schema_versions: list[VersionedSchema]
    immutable_schema: bool
    schema_salt: int

    def latest_schema(self) -> VersionedSchema:
        """Returns the latest schema of the column group."""
        ...

    def asof(self, asof: int) -> ColumnGroupMetadata:
        """Returns the metadata as of a given timestamp. Currently just filtering versioned schemas."""
        ...

    def apply_wal(self, wal: WriteAheadLog) -> ColumnGroupMetadata:
        """Applies the given WAL to the metadata."""

class LogEntry:
    ts: int
    operation: (
        KeySpaceWriteOp
        | ColumnGroupWriteOp
        | SchemaEvolutionOp
        | SchemaBreakOp
        | KeySpaceCompactOp
        | ColumnGroupCompactOp
    )

    def column_group(self) -> ColumnGroup | None:
        """Returns the column group of the entry if it is associated with one."""

class FileFormat:
    def __init__(self, value: int): ...

    Parquet: FileFormat
    Protobuf: FileFormat
    BinaryArray: FileFormat
    Vortex: FileFormat

    def __int__(self) -> int:
        """Returns the protobuf enum int value."""
        ...

    def __str__(self) -> str:
        """Returns the string representation of the file format."""
        ...

class FragmentLevel:
    L0: FragmentLevel
    L1: FragmentLevel

    def __int__(self) -> int:
        """Returns the protobuf enum int value."""
        ...

class Key:
    def __init__(self, key: bytes): ...
    def __bytes__(self): ...
    def step(self) -> Key:
        """Returns the next key in the key space."""

    @staticmethod
    def min() -> Key: ...
    @staticmethod
    def max() -> Key: ...
    def __reduce__(self) -> tuple[type[Key], tuple[bytes]]: ...

class KeyExtent:
    """An inclusive range of keys."""

    def __init__(self, *, min: Key, max: Key): ...

    min: Key
    max: Key

    def union(self, key_extent: KeyExtent) -> KeyExtent: ...
    def __or__(self, other: KeyExtent) -> KeyExtent: ...
    def intersection(self, key_extent: KeyExtent) -> KeyExtent | None: ...
    def __and__(self, other: KeyExtent) -> KeyExtent | None: ...
    def contains(self, item: Key) -> bool: ...
    def __contains__(self, item: Key) -> bool: ...

class KeySpan:
    """An exclusive range of keys as indexed by their position in a key space."""

    def __init__(self, *, begin: int, end: int): ...

    begin: int
    end: int

    def __len__(self) -> int: ...
    def shift(self, offset: int) -> KeySpan: ...
    def union(self, other: KeySpan) -> KeySpan: ...
    def __or__(self, other: KeySpan) -> KeySpan: ...

class ManifestHandle:
    id: str
    format: FileFormat
    file_size: int

class Schema:
    def to_arrow(self) -> pa.Schema:
        """Returns the Arrow schema."""
        ...
    @staticmethod
    def from_arrow(arrow: pa.Schema) -> Schema:
        """Creates a Schema from an Arrow schema."""
        ...
    def __len__(self):
        """Returns the number of columns in the schema."""
        ...
    @property
    def names(self) -> list[str]:
        """Returns the names of the columns in the schema."""
        ...

class VersionedSchema:
    ts: int
    schema: Schema
    column_ids: list[str]

class KeySpaceWriteOp:
    ks_id: str
    manifest_handle: ManifestHandle

class ColumnGroupWriteOp:
    column_group: ColumnGroup
    level: FragmentLevel
    manifest_handle: ManifestHandle
    key_span: KeySpan
    key_extent: KeyExtent
    column_ids: list[str]

class SchemaEvolutionOp:
    column_group: ColumnGroup

class SchemaBreakOp:
    column_group: ColumnGroup

class KeySpaceCompactOp:
    ks_ids: list[str]
    moved_ks_ids: list[str]

class ColumnGroupCompactOp:
    column_group: ColumnGroup
    fragment_ids: list[int]

class WriteAheadLog:
    def __init__(
        self,
        *,
        entries: list[LogEntry] | None = None,
        truncated_up_to: int = 0,
    ): ...

    entries: list[LogEntry]
    truncated_up_to: int

    @property
    def last_modified_at(self) -> int:
        """Returns the timestamp of the last modification of the log."""

    def filter(
        self, asof: int | None = None, since: int | None = None, column_group: ColumnGroup | None = None
    ) -> WriteAheadLog:
        """Filters the WAL to entries by the given parameters."""
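The stub above only declares value types and their operators, so the key-range classes compose directly. Below is a minimal illustrative sketch (not part of the package) of how Key and KeyExtent might be combined, assuming the module is importable under the stub's path spiral.core.table.spec; the byte keys are made-up example values.

# Illustrative sketch, not part of the published wheel.
from spiral.core.table.spec import Key, KeyExtent

# Two inclusive key ranges built from arbitrary example byte keys.
a = KeyExtent(min=Key(b"\x01"), max=Key(b"\x05"))
b = KeyExtent(min=Key(b"\x03"), max=Key(b"\x09"))

merged = a | b    # union of the two extents
overlap = a & b   # intersection, or None when the extents are disjoint
assert Key(b"\x04") in merged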
spiral/dataloader.py
ADDED
@@ -0,0 +1,285 @@
from __future__ import annotations

import os
import random
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
from typing import Any

import pyarrow as pa

from spiral.core.client import Shard
from spiral.scan import Scan


@dataclass(frozen=True)
class World:
    """Distributed training configuration.
    Attributes:
        rank: Process rank (0 to world_size-1).
        world_size: Total number of processes.
    """

    rank: int
    world_size: int

    def shards(
        self,
        shards: list[Shard],
        shuffle_seed: int | None = None,
    ) -> list[Shard]:
        """Partition shards for distributed training.

        Args:
            shards: List of Shard objects to partition.
            shuffle_seed: Optional seed to shuffle before partitioning.

        Returns:
            Subset of shards for this rank (round-robin partitioning).
        """
        if shuffle_seed is not None:
            shards = World._shuffle(shards, shuffle_seed)

        return shards[self.rank :: self.world_size]

    @classmethod
    def from_torch(cls) -> World:
        """Auto-detect world configuration from PyTorch distributed."""
        try:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                return cls(
                    rank=dist.get_rank(),
                    world_size=dist.get_world_size(),
                )
        except ImportError:
            pass

        return cls(
            rank=int(os.environ.get("RANK", 0)),
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
        )

    @classmethod
    def _shuffle(cls, shards: list[Shard], seed: int) -> list[Shard]:
        """Shuffle shards deterministically with given seed."""
        shuffled = list(shards)
        random.Random(seed).shuffle(shuffled)
        return shuffled


# Top level so we can pickle this function
def _len_and_transform(batch: pa.RecordBatch, transform_fn: Callable) -> tuple[int, Any]:
    return (len(batch), transform_fn(batch))


class SpiralDataLoader:
    """DataLoader optimized for Spiral's multi-threaded streaming architecture.

    Unlike PyTorch's DataLoader which uses multiprocessing for I/O (num_workers),
    SpiralDataLoader leverages Spiral's efficient Rust-based streaming and only
    uses multiprocessing for CPU-bound post-processing transforms.

    Key differences from PyTorch DataLoader:
    - No num_workers for I/O (Spiral's Rust layer is already multi-threaded)
    - map_workers for parallel post-processing (tokenization, decoding, etc.)
    - Built-in checkpoint support via skip_samples
    - Explicit shard-based architecture for distributed training
    """

    # Example usage:
    #
    # Simple usage:
    #     loader = SpiralDataLoader(scan, batch_size=32)
    #     for batch in loader:
    #         train_step(batch)
    #
    # With parallel transforms:
    #     loader = SpiralDataLoader(
    #         scan,
    #         batch_size=32,
    #         transform_fn=tokenize_batch,
    #         map_workers=4,
    #     )

    def __init__(
        self,
        scan: Scan,
        *,
        shards: list[Shard] | None = None,
        shuffle_shards: bool = True,
        seed: int = 42,
        skip_samples: int = 0,
        shuffle_buffer_size: int = 0,
        batch_size: int = 32,
        batch_readahead: int | None = None,
        # TODO(os): accept vortex arrays here instead of Arrow
        transform_fn: Callable[[pa.RecordBatch], Any] | None = None,
        map_workers: int = 0,
    ):
        """Initialize SpiralDataLoader.

        Args:
            scan: Spiral scan to load data from.
            shards: Optional list of Shard objects to read. If None, uses
                scan's natural sharding based on physical layout.
            shuffle_shards: Whether to shuffle the list of shards.
                Uses the provided seed.
            seed: Base random seed for deterministic shuffling and checkpointing.
            skip_samples: Number of samples to skip at the beginning (for resuming
                from checkpoint).
            shuffle_buffer_size: Size of shuffle buffer for within-shard shuffling.
                0 means no shuffling.
            batch_size: Number of rows per batch.
            batch_readahead: Number of batches to prefetch in background. If None,
                uses a sensible default based on whether transforms are applied.
            transform_fn: Optional function to transform each batch. Takes a PyArrow
                RecordBatch and returns any type. Users can call batch.to_pydict()
                inside the function if they need a dict. If map_workers > 0, this
                function must be picklable.
            map_workers: Number of worker processes for parallel transform_fn
                application. 0 means single-process (no parallelism). Use this for
                CPU-bound transforms like tokenization or audio decoding.
        """
        self.scan = scan
        self.shards = shards if shards is not None else scan.shards()
        if shuffle_shards:
            self.shards = World._shuffle(self.shards, seed)
        self.seed = seed
        self.skip_samples = skip_samples
        self.shuffle_buffer_size = shuffle_buffer_size
        self.batch_size = batch_size
        self.batch_readahead = batch_readahead
        self.transform_fn = transform_fn
        self.map_workers = map_workers

        self._samples_yielded = 0

    def __iter__(self) -> Iterator[Any]:
        """Iterate over batches."""
        from spiral.core.client import ShuffleConfig

        shuffle = None
        if self.shuffle_buffer_size > 0:
            shuffle = ShuffleConfig(
                buffer_size=self.shuffle_buffer_size,
                seed=self.seed,
            )

        stream = self.scan.core.to_shuffled_record_batches(
            shards=self.shards,
            shuffle=shuffle,
            max_batch_size=self.batch_size,
            batch_readahead=self.batch_readahead,
            infinite=False,
        )

        if self.skip_samples > 0:

            def skip(s: Iterator[pa.RecordBatch], skip_count: int) -> Iterator[pa.RecordBatch]:
                """Skip samples from stream, yielding remaining batches."""
                skipped = 0
                for batch in s:
                    batch_size = len(batch)
                    if skipped + batch_size <= skip_count:
                        # Skip entire batch
                        skipped += batch_size
                        continue
                    elif skipped < skip_count:
                        # Partial skip - discard first N samples, yield remainder
                        skip_in_batch = skip_count - skipped
                        skipped = skip_count
                        yield batch[skip_in_batch:]
                    else:
                        # take the entire batch
                        yield batch

            stream = skip(stream, self.skip_samples)

        if self.transform_fn is None:
            for batch in stream:
                self._samples_yielded += len(batch)
                yield batch
        elif self.map_workers == 0:
            # Single-process transform
            for batch in stream:
                result = self.transform_fn(batch)
                self._samples_yielded += len(batch)
                yield result
        else:
            with Pool(self.map_workers) as pool:
                for batch_len, result in pool.imap(partial(_len_and_transform, transform_fn=self.transform_fn), stream):
                    self._samples_yielded += batch_len
                    yield result

    def state_dict(self) -> dict[str, Any]:
        """Get checkpoint state for resuming.

        Returns:
            Dictionary containing samples_yielded, seed, and shards.
        """
        # Example usage:
        #     loader = SpiralDataLoader(scan, batch_size=32, seed=42)
        #     for i, batch in enumerate(loader):
        #         if i == 10:
        #             checkpoint = loader.state_dict()
        #             break
        #
        #     # Resume later with exact same shards
        #     loader = SpiralDataLoader.from_state_dict(scan, checkpoint, batch_size=32)
        return {
            "samples_yielded": self._samples_yielded,
            "seed": self.seed,
            "shards": self.shards,  # Will be pickled automatically
        }

    @classmethod
    def from_state_dict(
        cls,
        scan: Scan,
        state: dict[str, Any],
        **kwargs,
    ) -> SpiralDataLoader:
        """Create a DataLoader from checkpoint state, resuming from where it left off.

        This is the recommended way to resume training from a checkpoint. It extracts
        the seed, samples_yielded, and shards from the state dict and creates a new
        DataLoader that will skip the already-processed samples.

        Args:
            scan: Spiral scan to load data from.
            state: Checkpoint state from state_dict().
            **kwargs: Additional arguments to pass to SpiralDataLoader constructor.
                These will override values in the state dict where applicable.

        Returns:
            New SpiralDataLoader instance configured to resume from the checkpoint.
        """
        # Example usage:
        #
        # Save checkpoint during training:
        #     loader = scan.to_distributed_data_loader(scan, batch_size=32, seed=42)
        #     checkpoint = loader.state_dict()
        #
        # Resume later using the same shards from checkpoint:
        #     resumed_loader = SpiralDataLoader.from_state_dict(
        #         scan,
        #         checkpoint,
        #         batch_size=32,
        #         transform_fn=my_transform,
        #     )

        # Extract resume parameters from state
        seed = state.get("seed", 42)
        skip_samples = state.get("samples_yielded", 0)
        shards = state.get("shards")

        # Allow kwargs to override state dict values
        seed = kwargs.pop("seed", seed)
        skip_samples = kwargs.pop("skip_samples", skip_samples)
        shards = kwargs.pop("shards", shards)

        return cls(scan, seed=seed, skip_samples=skip_samples, shards=shards, **kwargs)
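Taken together, World and SpiralDataLoader support a shard-per-rank training loop with checkpoint resume. Below is a minimal illustrative sketch (not part of the package), assuming `scan` is a spiral.scan.Scan created elsewhere and that `tokenize_batch` and `train_step` are user-supplied functions; because map_workers > 0, tokenize_batch must be picklable.

# Illustrative sketch, not part of the published wheel.
from spiral.dataloader import SpiralDataLoader, World

world = World.from_torch()  # falls back to RANK / WORLD_SIZE env vars
my_shards = world.shards(scan.shards(), shuffle_seed=42)

loader = SpiralDataLoader(
    scan,
    shards=my_shards,
    batch_size=32,
    transform_fn=tokenize_batch,  # applied in worker processes
    map_workers=4,
)

checkpoint = None
for step, batch in enumerate(loader):
    train_step(batch)  # user-defined training step
    if step % 1000 == 0:
        checkpoint = loader.state_dict()  # records samples_yielded, seed, shards

# Later: rebuild the loader from the checkpoint; already-yielded samples are skipped.
resumed = SpiralDataLoader.from_state_dict(
    scan,
    checkpoint,
    batch_size=32,
    transform_fn=tokenize_batch,
    map_workers=4,
)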