pyspiral 0.7.18__cp312-abi3-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.7.18.dist-info/METADATA +52 -0
- pyspiral-0.7.18.dist-info/RECORD +110 -0
- pyspiral-0.7.18.dist-info/WHEEL +4 -0
- pyspiral-0.7.18.dist-info/entry_points.txt +3 -0
- spiral/__init__.py +55 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +411 -0
- spiral/api/__init__.py +78 -0
- spiral/api/admin.py +15 -0
- spiral/api/client.py +164 -0
- spiral/api/filesystems.py +134 -0
- spiral/api/key_space_indexes.py +23 -0
- spiral/api/organizations.py +77 -0
- spiral/api/projects.py +219 -0
- spiral/api/telemetry.py +19 -0
- spiral/api/text_indexes.py +56 -0
- spiral/api/types.py +23 -0
- spiral/api/workers.py +40 -0
- spiral/api/workloads.py +52 -0
- spiral/arrow_.py +216 -0
- spiral/cli/__init__.py +88 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +14 -0
- spiral/cli/app.py +108 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +76 -0
- spiral/cli/iceberg.py +97 -0
- spiral/cli/key_spaces.py +103 -0
- spiral/cli/login.py +25 -0
- spiral/cli/orgs.py +90 -0
- spiral/cli/printer.py +53 -0
- spiral/cli/projects.py +147 -0
- spiral/cli/state.py +7 -0
- spiral/cli/tables.py +197 -0
- spiral/cli/telemetry.py +17 -0
- spiral/cli/text.py +115 -0
- spiral/cli/types.py +50 -0
- spiral/cli/workloads.py +58 -0
- spiral/client.py +256 -0
- spiral/core/__init__.pyi +0 -0
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/authn/__init__.pyi +21 -0
- spiral/core/client/__init__.pyi +285 -0
- spiral/core/config/__init__.pyi +35 -0
- spiral/core/expr/__init__.pyi +15 -0
- spiral/core/expr/images/__init__.pyi +3 -0
- spiral/core/expr/list_/__init__.pyi +4 -0
- spiral/core/expr/refs/__init__.pyi +4 -0
- spiral/core/expr/str_/__init__.pyi +3 -0
- spiral/core/expr/struct_/__init__.pyi +6 -0
- spiral/core/expr/text/__init__.pyi +5 -0
- spiral/core/expr/udf/__init__.pyi +14 -0
- spiral/core/expr/video/__init__.pyi +3 -0
- spiral/core/table/__init__.pyi +141 -0
- spiral/core/table/manifests/__init__.pyi +35 -0
- spiral/core/table/metastore/__init__.pyi +58 -0
- spiral/core/table/spec/__init__.pyi +215 -0
- spiral/dataloader.py +299 -0
- spiral/dataset.py +264 -0
- spiral/datetime_.py +27 -0
- spiral/debug/__init__.py +0 -0
- spiral/debug/manifests.py +87 -0
- spiral/debug/metrics.py +56 -0
- spiral/debug/scan.py +266 -0
- spiral/enrichment.py +306 -0
- spiral/expressions/__init__.py +274 -0
- spiral/expressions/base.py +167 -0
- spiral/expressions/file.py +17 -0
- spiral/expressions/http.py +17 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/s3.py +16 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +59 -0
- spiral/expressions/text.py +62 -0
- spiral/expressions/tiff.py +222 -0
- spiral/expressions/udf.py +60 -0
- spiral/grpc_.py +32 -0
- spiral/iceberg.py +31 -0
- spiral/iterable_dataset.py +106 -0
- spiral/key_space_index.py +44 -0
- spiral/project.py +227 -0
- spiral/protogen/_/__init__.py +0 -0
- spiral/protogen/_/arrow/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
- spiral/protogen/_/google/__init__.py +0 -0
- spiral/protogen/_/google/protobuf/__init__.py +2310 -0
- spiral/protogen/_/message_pool.py +3 -0
- spiral/protogen/_/py.typed +0 -0
- spiral/protogen/_/scandal/__init__.py +190 -0
- spiral/protogen/_/spfs/__init__.py +72 -0
- spiral/protogen/_/spql/__init__.py +61 -0
- spiral/protogen/_/substrait/__init__.py +6196 -0
- spiral/protogen/_/substrait/extensions/__init__.py +169 -0
- spiral/protogen/__init__.py +0 -0
- spiral/protogen/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan.py +363 -0
- spiral/server.py +17 -0
- spiral/settings.py +36 -0
- spiral/snapshot.py +56 -0
- spiral/streaming_/__init__.py +3 -0
- spiral/streaming_/reader.py +133 -0
- spiral/streaming_/stream.py +157 -0
- spiral/substrait_.py +274 -0
- spiral/table.py +224 -0
- spiral/text_index.py +17 -0
- spiral/transaction.py +155 -0
- spiral/types_.py +6 -0
spiral/dataloader.py
ADDED
@@ -0,0 +1,299 @@
from __future__ import annotations

import os
import random
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
from typing import Any

import pyarrow as pa

from spiral.core.client import Shard
from spiral.scan import Scan


@dataclass(frozen=True)
class World:
    """Distributed training configuration.

    Attributes:
        rank: Process rank (0 to world_size-1).
        world_size: Total number of processes.
    """

    rank: int
    world_size: int

    def shards(
        self,
        shards: list[Shard],
        shuffle_seed: int | None = None,
    ) -> list[Shard]:
        """Partition shards for distributed training.

        Args:
            shards: List of Shard objects to partition.
            shuffle_seed: Optional seed to shuffle before partitioning.

        Returns:
            Subset of shards for this rank (round-robin partitioning).
        """
        if shuffle_seed is not None:
            shards = World._shuffle(shards, shuffle_seed)

        return shards[self.rank :: self.world_size]

    @classmethod
    def from_torch(cls) -> World:
        """Auto-detect world configuration from PyTorch distributed."""
        try:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                return cls(
                    rank=dist.get_rank(),
                    world_size=dist.get_world_size(),
                )
        except ImportError:
            pass

        return cls(
            rank=int(os.environ.get("RANK", 0)),
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
        )

    @classmethod
    def _shuffle(cls, shards: list[Shard], seed: int) -> list[Shard]:
        """Shuffle shards deterministically with given seed."""
        shuffled = list(shards)
        random.Random(seed).shuffle(shuffled)
        return shuffled


# Top level so we can pickle this function
def _len_and_transform(batch: pa.RecordBatch, transform_fn: Callable) -> tuple[int, Any]:
    return (len(batch), transform_fn(batch))


class SpiralDataLoader:
    """DataLoader optimized for Spiral's multi-threaded streaming architecture.

    Unlike PyTorch's DataLoader which uses multiprocessing for I/O (num_workers),
    SpiralDataLoader leverages Spiral's efficient Rust-based streaming and only
    uses multiprocessing for CPU-bound post-processing transforms.

    Key differences from PyTorch DataLoader:
    - No num_workers for I/O (Spiral's Rust layer is already multi-threaded)
    - map_workers for parallel post-processing (tokenization, decoding, etc.)
    - Built-in checkpoint support via skip_samples
    - Explicit shard-based architecture for distributed training

    Simple usage:
    ```python
    loader = SpiralDataLoader(scan, batch_size=32)
    for batch in loader:
        train_step(batch)
    ```

    With parallel transforms:
    ```python
    loader = SpiralDataLoader(
        scan,
        batch_size=32,
        transform_fn=tokenize_batch,
        map_workers=4,
    )
    ```
    """

    def __init__(
        self,
        scan: Scan,
        *,
        shards: list[Shard] | None = None,
        shuffle_shards: bool = True,
        seed: int = 42,
        skip_samples: int = 0,
        shuffle_buffer_size: int = 0,
        batch_size: int = 32,
        batch_readahead: int | None = None,
        # TODO(os): accept vortex arrays here instead of Arrow
        transform_fn: Callable[[pa.RecordBatch], Any] | None = None,
        map_workers: int = 0,
        infinite: bool = False,
    ):
        """Initialize SpiralDataLoader.

        Args:
            scan: Spiral scan to load data from.
            shards: Optional list of Shard objects to read. If None, uses
                scan's natural sharding based on physical layout.
            shuffle_shards: Whether to shuffle the list of shards.
                Uses the provided seed.
            seed: Base random seed for deterministic shuffling and checkpointing.
            skip_samples: Number of samples to skip at the beginning (for resuming
                from checkpoint).
            shuffle_buffer_size: Size of shuffle buffer for within-shard shuffling.
                0 means no shuffling.
            batch_size: Number of rows per batch.
            batch_readahead: Number of batches to prefetch in background. If None,
                uses a sensible default based on whether transforms are applied.
            transform_fn: Optional function to transform each batch. Takes a PyArrow
                RecordBatch and returns any type. Users can call batch.to_pydict()
                inside the function if they need a dict. If map_workers > 0, this
                function must be picklable.
            map_workers: Number of worker processes for parallel transform_fn
                application. 0 means single-process (no parallelism). Use this for
                CPU-bound transforms like tokenization or audio decoding.
            infinite: Whether to cycle through the dataset infinitely. If True,
                the dataloader will repeat the dataset indefinitely. If False,
                the dataloader will stop after going through the dataset once.
        """
        self.scan = scan
        self.shards = shards if shards is not None else scan.shards()
        if shuffle_shards:
            self.shards = World._shuffle(self.shards, seed)
        self.seed = seed
        self.skip_samples = skip_samples
        self.shuffle_buffer_size = shuffle_buffer_size
        self.batch_size = batch_size
        self.batch_readahead = batch_readahead
        self.transform_fn = transform_fn
        self.map_workers = map_workers
        self.infinite = infinite

        self._samples_yielded = 0

    def __iter__(self) -> Iterator[Any]:
        """Iterate over batches."""
        from spiral.core.client import ShuffleConfig

        shuffle = None
        if self.shuffle_buffer_size > 0:
            shuffle = ShuffleConfig(
                buffer_size=self.shuffle_buffer_size,
                seed=self.seed,
            )

        stream = self.scan.core.to_shuffled_record_batches(
            shards=self.shards,
            shuffle=shuffle,
            max_batch_size=self.batch_size,
            batch_readahead=self.batch_readahead,
            infinite=self.infinite,
        )

        if self.skip_samples > 0:

            def skip(s: Iterator[pa.RecordBatch], skip_count: int) -> Iterator[pa.RecordBatch]:
                """Skip samples from stream, yielding remaining batches."""
                skipped = 0
                for batch in s:
                    batch_size = len(batch)
                    if skipped + batch_size <= skip_count:
                        # Skip entire batch
                        skipped += batch_size
                        continue
                    elif skipped < skip_count:
                        # Partial skip - discard first N samples, yield remainder
                        skip_in_batch = skip_count - skipped
                        skipped = skip_count
                        yield batch[skip_in_batch:]
                    else:
                        # Take the entire batch
                        yield batch

            stream = skip(stream, self.skip_samples)

        if self.transform_fn is None:
            for batch in stream:
                self._samples_yielded += len(batch)
                yield batch
        elif self.map_workers == 0:
            # Single-process transform
            for batch in stream:
                result = self.transform_fn(batch)
                self._samples_yielded += len(batch)
                yield result
        else:
            with Pool(self.map_workers) as pool:
                for batch_len, result in pool.imap(partial(_len_and_transform, transform_fn=self.transform_fn), stream):
                    self._samples_yielded += batch_len
                    yield result

    def state_dict(self) -> dict[str, Any]:
        """Get checkpoint state for resuming.

        Returns:
            Dictionary containing samples_yielded, seed, and shards.

        Example checkpoint:
        ```python
        loader = SpiralDataLoader(scan, batch_size=32, seed=42)
        for i, batch in enumerate(loader):
            if i == 10:
                checkpoint = loader.state_dict()
                break
        ```

        Example resume:
        ```python
        loader = SpiralDataLoader.from_state_dict(scan, checkpoint, batch_size=32)
        ```
        """
        return {
            "samples_yielded": self._samples_yielded,
            "seed": self.seed,
            "shards": self.shards,  # Will be pickled automatically
        }

    @classmethod
    def from_state_dict(
        cls,
        scan: Scan,
        state: dict[str, Any],
        **kwargs,
    ) -> SpiralDataLoader:
        """Create a DataLoader from checkpoint state, resuming from where it left off.

        This is the recommended way to resume training from a checkpoint. It extracts
        the seed, samples_yielded, and shards from the state dict and creates a new
        DataLoader that will skip the already-processed samples.

        Args:
            scan: Spiral scan to load data from.
            state: Checkpoint state from state_dict().
            **kwargs: Additional arguments to pass to SpiralDataLoader constructor.
                These will override values in the state dict where applicable.

        Returns:
            New SpiralDataLoader instance configured to resume from the checkpoint.

        Save checkpoint during training:
        ```python
        loader = scan.to_distributed_data_loader(scan, batch_size=32, seed=42)
        checkpoint = loader.state_dict()
        ```

        Resume later using the same shards from checkpoint:
        ```python
        resumed_loader = SpiralDataLoader.from_state_dict(
            scan,
            checkpoint,
            batch_size=32,
            transform_fn=my_transform,
        )
        ```
        """
        # Extract resume parameters from state
        seed = state.get("seed", 42)
        skip_samples = state.get("samples_yielded", 0)
        shards = state.get("shards")

        # Allow kwargs to override state dict values
        seed = kwargs.pop("seed", seed)
        skip_samples = kwargs.pop("skip_samples", skip_samples)
        shards = kwargs.pop("shards", shards)

        return cls(scan, seed=seed, skip_samples=skip_samples, shards=shards, **kwargs)
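Note: the pieces above compose into a distributed, resumable training loop: `World` picks this rank's shards and the checkpoint hooks skip forward on resume. A minimal sketch, assuming `scan` is an existing `spiral.scan.Scan` and that persisting the checkpoint (here with `pickle` to a local file) is left to the caller; neither of those is prescribed by this file.

```python
# Sketch only: distributed sharding + checkpoint/resume with SpiralDataLoader.
# Assumes `scan` is a spiral.scan.Scan obtained elsewhere (not shown in this diff).
import pickle

from spiral.dataloader import SpiralDataLoader, World

world = World.from_torch()  # falls back to RANK/WORLD_SIZE env vars, then (0, 1)
my_shards = world.shards(scan.shards(), shuffle_seed=42)  # round-robin slice for this rank

loader = SpiralDataLoader(scan, shards=my_shards, batch_size=32, seed=42)

for step, batch in enumerate(loader):
    # train_step(batch)  # placeholder for the actual training step
    if step % 1000 == 0:
        with open(f"ckpt_rank{world.rank}.pkl", "wb") as f:
            pickle.dump(loader.state_dict(), f)  # samples_yielded, seed, shards

# Later: resume from the saved state instead of restarting the epoch.
with open(f"ckpt_rank{world.rank}.pkl", "rb") as f:
    resumed = SpiralDataLoader.from_state_dict(scan, pickle.load(f), batch_size=32)
```

Because `from_state_dict` feeds `samples_yielded` back in as `skip_samples`, the resumed loader keeps the same seed and shard list and skips the already-yielded rows rather than re-reading from the start.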
spiral/dataset.py
ADDED
@@ -0,0 +1,264 @@
from typing import Any

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds

from spiral.scan import Scan
from spiral.snapshot import Snapshot


class Dataset(ds.Dataset):
    def __init__(self, snapshot: Snapshot):
        self._snapshot = snapshot
        self._table = snapshot.table
        self._schema: pa.Schema = self._snapshot.schema().to_arrow()

        # We don't actually initialize a Dataset, we just implement enough of the API
        # to fool both DuckDB and Polars.
        # super().__init__()
        self._last_scan = None

    @property
    def schema(self) -> pa.Schema:
        return self._schema

    def count_rows(
        self,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: ds.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        return self.scanner(
            None,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).count_rows()

    def filter(self, expression: pc.Expression) -> "Dataset":
        raise NotImplementedError("filter not implemented")

    def get_fragments(self, filter: pc.Expression | None = None):
        """TODO(ngates): perhaps we should return ranges as per our split API?"""
        raise NotImplementedError("get_fragments not implemented")

    def head(
        self,
        num_rows: int,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: ds.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).head(num_rows)

    def join(
        self,
        right_dataset,
        keys,
        right_keys=None,
        join_type=None,
        left_suffix=None,
        right_suffix=None,
        coalesce_keys=True,
        use_threads=True,
    ):
        raise NotImplementedError("join not implemented")

    def join_asof(self, right_dataset, on, by, tolerance, right_on=None, right_by=None):
        raise NotImplementedError("join_asof not implemented")

    def replace_schema(self, schema: pa.Schema) -> "Dataset":
        raise NotImplementedError("replace_schema not implemented")

    def scanner(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: ds.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ) -> "TableScanner":
        from spiral.substrait_ import SubstraitConverter

        # Extract the substrait expression so we can convert it to a Spiral expression
        if filter is not None:
            filter = SubstraitConverter(self._table, self._schema, self._table.key_schema.to_arrow()).convert(
                filter.to_substrait(self._schema, allow_arrow_extensions=True),
            )

        scan = (
            self._table.spiral.scan(
                {c: self._table[c] for c in columns},
                where=filter,
                asof=self._snapshot.asof,
            )
            if columns
            else self._table.spiral.scan(
                self._table,
                where=filter,
                asof=self._snapshot.asof,
            )
        )
        self._last_scan = scan

        return TableScanner(scan)

    def sort_by(self, sorting, **kwargs):
        raise NotImplementedError("sort_by not implemented")

    def take(
        self,
        indices: pa.Array | Any,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: ds.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).take(indices)

    def to_batches(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: ds.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).to_batches()

    def to_table(
        self,
        columns=None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: ds.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).to_table()


class TableScanner(ds.Scanner):
    """A PyArrow Dataset Scanner that reads from a Spiral Table."""

    def __init__(
        self,
        scan: Scan,
        key_table: pa.Table | pa.RecordBatchReader | None = None,
    ):
        self._scan = scan
        self._schema = scan.schema
        self.key_table = key_table

        # We don't actually initialize a Dataset, we just implement enough of the API
        # to fool both DuckDB and Polars.
        # super().__init__()

    @property
    def schema(self):
        return self._schema

    def count_rows(self):
        # TODO(ngates): is there a faster way to count rows?
        return sum(len(batch) for batch in self.to_reader())

    def head(self, num_rows: int):
        """Return the first `num_rows` rows of the dataset."""

        kwargs = {}
        if num_rows <= 10_000:
            # We are unlikely to need more than a couple batches
            kwargs["batch_readahead"] = 1
            # The progress bar length is the total number of splits in this dataset. We will likely
            # stop streaming early. As a result, the progress bar is misleading.
            kwargs["hide_progress_bar"] = True

        reader = self._scan.to_unordered_record_batches(key_table=self.key_table, **kwargs)
        batches = []
        row_count = 0
        for batch in reader:
            if row_count + len(batch) > num_rows:
                batches.append(batch.slice(0, num_rows - row_count))
                break
            row_count += len(batch)
            batches.append(batch)
        return pa.Table.from_batches(batches, schema=reader.schema)

    def scan_batches(self):
        raise NotImplementedError("scan_batches not implemented")

    def take(self, indices):
        # TODO(ngates): can we defer take until after we've constructed the scan?
        # Or should we delay constructing the Spiral Table.scan?
        raise NotImplementedError("take not implemented")

    def to_batches(self):
        return self.to_reader()

    def to_reader(self):
        return self._scan.to_unordered_record_batches(key_table=self.key_table)

    def to_table(self):
        return self.to_reader().read_all()
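Note: because `Dataset` and `TableScanner` implement just enough of the `pyarrow.dataset` API "to fool both DuckDB and Polars" (per the inline comments), the shim can be handed to those engines directly. A hedged sketch, assuming `snapshot` is a `spiral.snapshot.Snapshot` obtained elsewhere (not shown in this diff):

```python
# Sketch only. `snapshot` is assumed to be a spiral.snapshot.Snapshot from elsewhere.
import duckdb
import polars as pl

from spiral.dataset import Dataset

dataset = Dataset(snapshot)

# Polars treats the shim as a PyArrow dataset; column selection and predicates are
# pushed down through scanner(columns=..., filter=...).
print(pl.scan_pyarrow_dataset(dataset).head(5).collect())

# DuckDB's replacement scan picks up the `dataset` variable by name.
print(duckdb.sql("SELECT COUNT(*) FROM dataset").fetchall())
```

Filters supplied by either engine arrive as `pyarrow.compute` expressions and are translated into Spiral expressions via `SubstraitConverter` inside `scanner()`.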
spiral/datetime_.py
ADDED
@@ -0,0 +1,27 @@
import warnings
from datetime import UTC, datetime, timedelta, tzinfo

_THE_EPOCH = datetime.fromtimestamp(0, tz=UTC)


def local_tz() -> tzinfo:
    """Determine this machine's local timezone."""
    tz = datetime.now().astimezone().tzinfo
    if tz is None:
        raise ValueError("Could not determine this machine's local timezone.")
    return tz


def timestamp_micros(instant: datetime) -> int:
    """The number of microseconds between the epoch and the given instant."""
    if instant.tzinfo is None:
        warnings.warn("assuming timezone-naive datetime is local time", stacklevel=2)
        instant = instant.replace(tzinfo=local_tz())
    return (instant - _THE_EPOCH) // timedelta(microseconds=1)


def from_timestamp_micros(ts: int) -> datetime:
    """Convert a timestamp in microseconds to a datetime."""
    if ts < 0:
        raise ValueError("Timestamp must be non-negative")
    return _THE_EPOCH + timedelta(microseconds=ts)
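Note: these helpers round-trip between timezone-aware datetimes and the epoch-microsecond integers used elsewhere in the package (for example, fragment `committed_at` values rendered in `spiral/debug/manifests.py`). A quick illustration:

```python
from datetime import UTC, datetime

from spiral.datetime_ import from_timestamp_micros, timestamp_micros

ts = timestamp_micros(datetime(2024, 1, 1, tzinfo=UTC))
print(ts)                         # 1704067200000000
print(from_timestamp_micros(ts))  # 2024-01-01 00:00:00+00:00
```

Naive datetimes are accepted too, but trigger a warning and are interpreted as local time via `local_tz()`.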
spiral/debug/__init__.py
ADDED
File without changes

spiral/debug/manifests.py
ADDED
@@ -0,0 +1,87 @@
from rich.console import Console
from rich.table import Table

from spiral import datetime_
from spiral.core.table import Scan
from spiral.core.table.manifests import FragmentManifest
from spiral.core.table.spec import ColumnGroup
from spiral.debug.metrics import _format_bytes


def display_scan_manifests(scan: Scan):
    """Display all manifests in a scan."""
    if len(scan.table_ids()) != 1:
        raise NotImplementedError("Multiple table scans are not supported.")
    table_id = scan.table_ids()[0]
    key_space_manifest = scan.key_space_state(table_id).manifest
    column_group_manifests = [
        (column_group, scan.column_group_state(column_group).manifest) for column_group in scan.column_groups()
    ]

    display_manifests(key_space_manifest, column_group_manifests)


def display_manifests(
    key_space_manifest: FragmentManifest, column_group_manifests: list[tuple[ColumnGroup, FragmentManifest]]
):
    _table_of_fragments(
        key_space_manifest,
        title="Key Space manifest",
    )

    for column_group, column_group_manifest in column_group_manifests:
        _table_of_fragments(
            column_group_manifest,
            title=f"Column Group manifest for {str(column_group)}",
        )


def _table_of_fragments(manifest: FragmentManifest, title: str):
    """Display fragments in a formatted table."""
    # Calculate summary statistics
    total_size = sum(fragment.size_bytes for fragment in manifest)
    total_metadata_size = sum(len(fragment.format_metadata or b"") for fragment in manifest)
    fragment_count = len(manifest)
    avg_size = total_size / fragment_count if fragment_count > 0 else 0

    # Print title and summary
    console = Console()
    console.print(f"\n\n{title}")
    console.print(
        f"{fragment_count} fragments, "
        f"total: {_format_bytes(total_size)}, "
        f"avg: {_format_bytes(int(avg_size))}, "
        f"metadata: {_format_bytes(total_metadata_size)}"
    )

    # Create rich table
    table = Table(title=None, show_header=True, header_style="bold")
    table.add_column("ID", style="cyan", no_wrap=True)
    table.add_column("Size (Metadata)", justify="right")
    table.add_column("Format", justify="center")
    table.add_column("Key Span", justify="center")
    table.add_column("Level", justify="center")
    table.add_column("Committed At", justify="center")
    table.add_column("Compacted At", justify="center")

    # Add each fragment as a row
    for fragment in manifest:
        committed_str = str(datetime_.from_timestamp_micros(fragment.committed_at)) if fragment.committed_at else "N/A"
        compacted_str = str(datetime_.from_timestamp_micros(fragment.compacted_at)) if fragment.compacted_at else "N/A"

        size_with_metadata = (
            f"{_format_bytes(fragment.size_bytes)} ({_format_bytes(len(fragment.format_metadata or b''))})"
        )
        key_span = f"{fragment.key_span.begin}..{fragment.key_span.end}"

        table.add_row(
            fragment.id,
            size_with_metadata,
            str(fragment.format),
            key_span,
            str(fragment.level),
            committed_str,
            compacted_str,
        )

    console.print(table)
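Note: a hedged sketch of driving these helpers from user code. `display_scan_manifests` expects the core `spiral.core.table.Scan`; the sketch assumes the `.core` attribute of a `spiral.scan.Scan` (the handle `dataloader.py` reads) is that object, which this diff does not confirm.

```python
# Sketch only. Assumes `scan` is a spiral.scan.Scan and that `scan.core` is the
# spiral.core.table.Scan these debug helpers expect; that mapping is an assumption,
# not something this diff confirms.
from spiral.debug.manifests import display_scan_manifests

display_scan_manifests(scan.core)  # one rich table per key-space / column-group manifest
```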