pyspiral 0.6.9__cp312-abi3-macosx_11_0_arm64.whl → 0.7.12__cp312-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/METADATA +9 -8
- {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/RECORD +53 -45
- {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/entry_points.txt +1 -0
- spiral/__init__.py +20 -0
- spiral/_lib.abi3.so +0 -0
- spiral/api/__init__.py +1 -1
- spiral/api/client.py +1 -1
- spiral/api/types.py +1 -0
- spiral/cli/admin.py +2 -2
- spiral/cli/app.py +8 -4
- spiral/cli/fs.py +4 -4
- spiral/cli/iceberg.py +1 -1
- spiral/cli/key_spaces.py +15 -1
- spiral/cli/login.py +4 -3
- spiral/cli/orgs.py +8 -7
- spiral/cli/projects.py +4 -4
- spiral/cli/state.py +5 -3
- spiral/cli/tables.py +59 -36
- spiral/cli/telemetry.py +1 -1
- spiral/cli/types.py +2 -2
- spiral/cli/workloads.py +3 -3
- spiral/client.py +69 -22
- spiral/core/client/__init__.pyi +48 -13
- spiral/core/config/__init__.pyi +47 -0
- spiral/core/expr/__init__.pyi +15 -0
- spiral/core/expr/images/__init__.pyi +3 -0
- spiral/core/expr/list_/__init__.pyi +4 -0
- spiral/core/expr/refs/__init__.pyi +4 -0
- spiral/core/expr/str_/__init__.pyi +3 -0
- spiral/core/expr/struct_/__init__.pyi +6 -0
- spiral/core/expr/text/__init__.pyi +5 -0
- spiral/core/expr/udf/__init__.pyi +14 -0
- spiral/core/expr/video/__init__.pyi +3 -0
- spiral/core/table/__init__.pyi +37 -2
- spiral/core/table/spec/__init__.pyi +6 -4
- spiral/dataloader.py +52 -38
- spiral/dataset.py +10 -1
- spiral/enrichment.py +304 -0
- spiral/expressions/__init__.py +21 -23
- spiral/expressions/base.py +9 -4
- spiral/expressions/file.py +17 -0
- spiral/expressions/http.py +11 -80
- spiral/expressions/s3.py +16 -0
- spiral/expressions/tiff.py +2 -3
- spiral/expressions/udf.py +38 -24
- spiral/iceberg.py +3 -3
- spiral/project.py +34 -6
- spiral/scan.py +80 -33
- spiral/settings.py +19 -97
- spiral/streaming_/stream.py +1 -1
- spiral/table.py +40 -10
- spiral/transaction.py +99 -2
- spiral/expressions/io.py +0 -100
- spiral/expressions/mp4.py +0 -62
- spiral/expressions/png.py +0 -18
- spiral/expressions/qoi.py +0 -18
- spiral/expressions/refs.py +0 -58
- {pyspiral-0.6.9.dist-info → pyspiral-0.7.12.dist-info}/WHEEL +0 -0
spiral/dataloader.py
CHANGED
|
@@ -88,22 +88,24 @@ class SpiralDataLoader:
|
|
|
88
88
|
- map_workers for parallel post-processing (tokenization, decoding, etc.)
|
|
89
89
|
- Built-in checkpoint support via skip_samples
|
|
90
90
|
- Explicit shard-based architecture for distributed training
|
|
91
|
-
"""
|
|
92
91
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
92
|
+
Simple usage:
|
|
93
|
+
```python
|
|
94
|
+
loader = SpiralDataLoader(scan, batch_size=32)
|
|
95
|
+
for batch in loader:
|
|
96
|
+
train_step(batch)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
With parallel transforms:
|
|
100
|
+
```python
|
|
101
|
+
loader = SpiralDataLoader(
|
|
102
|
+
scan,
|
|
103
|
+
batch_size=32,
|
|
104
|
+
transform_fn=tokenize_batch,
|
|
105
|
+
map_workers=4,
|
|
106
|
+
)
|
|
107
|
+
```
|
|
108
|
+
"""
|
|
107
109
|
|
|
108
110
|
def __init__(
|
|
109
111
|
self,
|
|
@@ -119,6 +121,7 @@ class SpiralDataLoader:
|
|
|
119
121
|
# TODO(os): accept vortex arrays here instead of Arrow
|
|
120
122
|
transform_fn: Callable[[pa.RecordBatch], Any] | None = None,
|
|
121
123
|
map_workers: int = 0,
|
|
124
|
+
infinite: bool = False,
|
|
122
125
|
):
|
|
123
126
|
"""Initialize SpiralDataLoader.
|
|
124
127
|
|
|
@@ -143,6 +146,9 @@ class SpiralDataLoader:
|
|
|
143
146
|
map_workers: Number of worker processes for parallel transform_fn
|
|
144
147
|
application. 0 means single-process (no parallelism). Use this for
|
|
145
148
|
CPU-bound transforms like tokenization or audio decoding.
|
|
149
|
+
infinite: Whether to cycle through the dataset infinitely. If True,
|
|
150
|
+
the dataloader will repeat the dataset indefinitely. If False,
|
|
151
|
+
the dataloader will stop after going through the dataset once.
|
|
146
152
|
"""
|
|
147
153
|
self.scan = scan
|
|
148
154
|
self.shards = shards if shards is not None else scan.shards()
|
|
@@ -155,6 +161,7 @@ class SpiralDataLoader:
|
|
|
155
161
|
self.batch_readahead = batch_readahead
|
|
156
162
|
self.transform_fn = transform_fn
|
|
157
163
|
self.map_workers = map_workers
|
|
164
|
+
self.infinite = infinite
|
|
158
165
|
|
|
159
166
|
self._samples_yielded = 0
|
|
160
167
|
|
|
@@ -174,7 +181,7 @@ class SpiralDataLoader:
|
|
|
174
181
|
shuffle=shuffle,
|
|
175
182
|
max_batch_size=self.batch_size,
|
|
176
183
|
batch_readahead=self.batch_readahead,
|
|
177
|
-
infinite=
|
|
184
|
+
infinite=self.infinite,
|
|
178
185
|
)
|
|
179
186
|
|
|
180
187
|
if self.skip_samples > 0:
|
|
@@ -220,16 +227,21 @@ class SpiralDataLoader:
|
|
|
220
227
|
|
|
221
228
|
Returns:
|
|
222
229
|
Dictionary containing samples_yielded, seed, and shards.
|
|
230
|
+
|
|
231
|
+
Example checkpoint:
|
|
232
|
+
```python
|
|
233
|
+
loader = SpiralDataLoader(scan, batch_size=32, seed=42)
|
|
234
|
+
for i, batch in enumerate(loader):
|
|
235
|
+
if i == 10:
|
|
236
|
+
checkpoint = loader.state_dict()
|
|
237
|
+
break
|
|
238
|
+
```
|
|
239
|
+
|
|
240
|
+
Example resume:
|
|
241
|
+
```python
|
|
242
|
+
loader = SpiralDataLoader.from_state_dict(scan, checkpoint, batch_size=32)
|
|
243
|
+
```
|
|
223
244
|
"""
|
|
224
|
-
# Example usage:
|
|
225
|
-
# loader = SpiralDataLoader(scan, batch_size=32, seed=42)
|
|
226
|
-
# for i, batch in enumerate(loader):
|
|
227
|
-
# if i == 10:
|
|
228
|
-
# checkpoint = loader.state_dict()
|
|
229
|
-
# break
|
|
230
|
-
#
|
|
231
|
-
# # Resume later with exact same shards
|
|
232
|
-
# loader = SpiralDataLoader.from_state_dict(scan, checkpoint, batch_size=32)
|
|
233
245
|
return {
|
|
234
246
|
"samples_yielded": self._samples_yielded,
|
|
235
247
|
"seed": self.seed,
|
|
@@ -257,20 +269,22 @@ class SpiralDataLoader:
|
|
|
257
269
|
|
|
258
270
|
Returns:
|
|
259
271
|
New SpiralDataLoader instance configured to resume from the checkpoint.
|
|
272
|
+
|
|
273
|
+
Save checkpoint during training:
|
|
274
|
+
```python
|
|
275
|
+
loader = scan.to_distributed_data_loader(scan, batch_size=32, seed=42)
|
|
276
|
+
checkpoint = loader.state_dict()
|
|
277
|
+
```
|
|
278
|
+
|
|
279
|
+
Resume later using the same shards from checkpoint:
|
|
280
|
+
```python
|
|
281
|
+
resumed_loader = SpiralDataLoader.from_state_dict(
|
|
282
|
+
scan,
|
|
283
|
+
checkpoint,
|
|
284
|
+
batch_size=32,
|
|
285
|
+
transform_fn=my_transform,
|
|
286
|
+
)
|
|
260
287
|
"""
|
|
261
|
-
# Example usage:
|
|
262
|
-
#
|
|
263
|
-
# Save checkpoint during training:
|
|
264
|
-
# loader = scan.to_distributed_data_loader(scan, batch_size=32, seed=42)
|
|
265
|
-
# checkpoint = loader.state_dict()
|
|
266
|
-
#
|
|
267
|
-
# Resume later using the same shards from checkpoint:
|
|
268
|
-
# resumed_loader = SpiralDataLoader.from_state_dict(
|
|
269
|
-
# scan,
|
|
270
|
-
# checkpoint,
|
|
271
|
-
# batch_size=32,
|
|
272
|
-
# transform_fn=my_transform,
|
|
273
|
-
# )
|
|
274
288
|
|
|
275
289
|
# Extract resume parameters from state
|
|
276
290
|
seed = state.get("seed", 42)
|
spiral/dataset.py
CHANGED
|
@@ -226,7 +226,16 @@ class TableScanner(ds.Scanner):
|
|
|
226
226
|
|
|
227
227
|
def head(self, num_rows: int):
|
|
228
228
|
"""Return the first `num_rows` rows of the dataset."""
|
|
229
|
-
|
|
229
|
+
|
|
230
|
+
kwargs = {}
|
|
231
|
+
if num_rows <= 10_000:
|
|
232
|
+
# We are unlikely to need more than a couple batches
|
|
233
|
+
kwargs["batch_readahead"] = 1
|
|
234
|
+
# The progress bar length is the total number of splits in this dataset. We will likely
|
|
235
|
+
# stop streaming early. As a result, the progress bar is misleading.
|
|
236
|
+
kwargs["hide_progress_bar"] = True
|
|
237
|
+
|
|
238
|
+
reader = self._scan.to_record_batches(key_table=self.key_table, **kwargs)
|
|
230
239
|
batches = []
|
|
231
240
|
row_count = 0
|
|
232
241
|
for batch in reader:
|
spiral/enrichment.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import logging
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from spiral.core.client import KeyColumns, Shard
|
|
9
|
+
from spiral.core.table import KeyRange
|
|
10
|
+
from spiral.core.table.spec import Key, Operation
|
|
11
|
+
from spiral.expressions import Expr
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
import dask.distributed
|
|
15
|
+
|
|
16
|
+
from spiral import KeySpaceIndex, Scan, Table
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Enrichment:
|
|
22
|
+
"""
|
|
23
|
+
An enrichment is used to derive new columns from the existing ones, such as fetching data from object storage
|
|
24
|
+
with `se.s3.get` or computing embeddings. With the column group design supporting hundreds of thousands of columns,
|
|
25
|
+
horizontally expanding tables are a powerful primitive.
|
|
26
|
+
|
|
27
|
+
NOTE: Spiral aims to optimize enrichments where the source and destination tables are the same.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
table: Table,
|
|
33
|
+
projection: Expr,
|
|
34
|
+
where: Expr | None,
|
|
35
|
+
):
|
|
36
|
+
self._table = table
|
|
37
|
+
self._projection = projection
|
|
38
|
+
self._where = where
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def table(self) -> Table:
|
|
42
|
+
"""The table to write back into."""
|
|
43
|
+
return self._table
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def projection(self) -> Expr:
|
|
47
|
+
"""The projection expression."""
|
|
48
|
+
return self._projection
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def where(self) -> Expr | None:
|
|
52
|
+
"""The filter expression."""
|
|
53
|
+
return self._where
|
|
54
|
+
|
|
55
|
+
def _scan(self) -> Scan:
|
|
56
|
+
return self._table.spiral.scan(self._projection, where=self._where, _key_columns=KeyColumns.Included)
|
|
57
|
+
|
|
58
|
+
def apply(
|
|
59
|
+
self, *, batch_readahead: int | None = None, partition_size_bytes: int | None = None, tx_dump: str | None = None
|
|
60
|
+
) -> None:
|
|
61
|
+
"""Apply the enrichment onto the table in a streaming fashion.
|
|
62
|
+
|
|
63
|
+
For large tables, consider using `apply_dask` for distributed execution.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
batch_readahead: Optional batch readahead setting, forwarded to the transaction's writeback.
|
|
67
|
+
If not provided, the writeback's default is used.
|
|
68
|
+
partition_size_bytes: The maximum partition size in bytes.
|
|
69
|
+
If not provided, the default partition size is used.
|
|
70
|
+
tx_dump: Optional path to dump the transaction JSON for debugging.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
txn = self._table.txn()
|
|
74
|
+
|
|
75
|
+
txn.writeback(
|
|
76
|
+
self._scan(),
|
|
77
|
+
partition_size_bytes=partition_size_bytes,
|
|
78
|
+
batch_readahead=batch_readahead,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
if txn.is_empty():
|
|
82
|
+
logger.warning("Transaction not committed. No rows were read for enrichment.")
|
|
83
|
+
return
|
|
84
|
+
|
|
85
|
+
txn.commit(tx_dump=tx_dump)
|
|
86
|
+
|
|
87
|
+
# TODO(marko): Need to figure out this sharding with key space index in places.
|
|
88
|
+
# We could compute on-demand instead of requiring a resource.
|
|
89
|
+
def apply_dask(
|
|
90
|
+
self,
|
|
91
|
+
*,
|
|
92
|
+
index: KeySpaceIndex | None = None,
|
|
93
|
+
partition_size_bytes: int | None = None,
|
|
94
|
+
tx_dump: str | None = None,
|
|
95
|
+
checkpoint_dump: str | None = None,
|
|
96
|
+
client: dask.distributed.Client | None = None,
|
|
97
|
+
**kwargs,
|
|
98
|
+
) -> None:
|
|
99
|
+
"""Use distributed Dask to apply the enrichment. Requires `dask[distributed]` to be installed.
|
|
100
|
+
|
|
101
|
+
If "address" of an existing Dask cluster is not provided in `kwargs`, a local cluster will be created.
|
|
102
|
+
|
|
103
|
+
IMPORTANT: Dask execution has some limitations, e.g. UDFs are not currently supported. These limitations
|
|
104
|
+
usually manifest as serialization errors when Dask workers attempt to serialize the state. If you are
|
|
105
|
+
encountering such issues, consider splitting the enrichment into UDF-only derivation that will be
|
|
106
|
+
executed in a streaming fashion, followed by a Dask enrichment for the rest of the computation.
|
|
107
|
+
If that is not possible, please reach out to support for assistance.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
index: Optional key space index to use for sharding the enrichment.
|
|
111
|
+
If not provided, the table's default sharding will be used.
|
|
112
|
+
partition_size_bytes: The maximum partition size in bytes.
|
|
113
|
+
If not provided, the default partition size is used.
|
|
114
|
+
tx_dump: Optional path to dump the transaction JSON for debugging.
|
|
115
|
+
checkpoint_dump: Optional path to dump intermediate checkpoints for incremental progress.
|
|
116
|
+
client: Optional Dask distributed client. If not provided, a new client will be created
|
|
117
|
+
**kwargs: Additional keyword arguments to pass to `dask.distributed.Client`
|
|
118
|
+
such as `address` to connect to an existing cluster.
|
|
119
|
+
"""
|
|
120
|
+
if client is None:
|
|
121
|
+
try:
|
|
122
|
+
from dask.distributed import Client
|
|
123
|
+
except ImportError:
|
|
124
|
+
raise ImportError("dask is not installed, please install dask[distributed] to use this feature.")
|
|
125
|
+
|
|
126
|
+
# Connect before doing any work.
|
|
127
|
+
client = Client(**kwargs)
|
|
128
|
+
|
|
129
|
+
# Start a transaction BEFORE the planning scan.
|
|
130
|
+
tx = self._table.txn()
|
|
131
|
+
plan_scan = self._scan()
|
|
132
|
+
|
|
133
|
+
# Determine the "tasks".
|
|
134
|
+
shards = None
|
|
135
|
+
# Use checkpoint, if provided.
|
|
136
|
+
if checkpoint_dump is not None:
|
|
137
|
+
checkpoint: list[KeyRange] | None = _checkpoint_load_key_ranges(checkpoint_dump)
|
|
138
|
+
if checkpoint is None:
|
|
139
|
+
logger.info(f"No existing checkpoint found at {checkpoint_dump}. Starting from scratch.")
|
|
140
|
+
else:
|
|
141
|
+
logger.info(f"Resuming enrichment from checkpoint at {checkpoint_dump} with {len(checkpoint)} ranges.")
|
|
142
|
+
shards = [Shard(kr, None) for kr in checkpoint]
|
|
143
|
+
# Fallback to index-based sharding.
|
|
144
|
+
if shards is None and index is not None:
|
|
145
|
+
# TODO(marko): This will use index's asof automatically.
|
|
146
|
+
shards = self._table.spiral.internal.compute_shards(index.core)
|
|
147
|
+
# Fallback to default sharding.
|
|
148
|
+
if shards is None:
|
|
149
|
+
shards = plan_scan.shards()
|
|
150
|
+
|
|
151
|
+
# TODO(marko): This is temporary workaround. Passing token is a bad idea.
|
|
152
|
+
# Token can expire during long-running enrichments.
|
|
153
|
+
# Maybe if device code is used, we can pass something.
|
|
154
|
+
token = self._table.spiral.authn.token()
|
|
155
|
+
if token is None:
|
|
156
|
+
raise ValueError("Spiral client is not authenticated.")
|
|
157
|
+
config = self._table.spiral.config
|
|
158
|
+
config.token = token
|
|
159
|
+
|
|
160
|
+
# Partially bind the enrichment function.
|
|
161
|
+
_compute = partial(
|
|
162
|
+
_enrichment_task,
|
|
163
|
+
settings_json=config.to_json(),
|
|
164
|
+
state_json=plan_scan.core.plan_state().to_json(),
|
|
165
|
+
output_table_id=self._table.table_id,
|
|
166
|
+
partition_size_bytes=partition_size_bytes,
|
|
167
|
+
incremental=checkpoint_dump is not None,
|
|
168
|
+
)
|
|
169
|
+
enrichments = client.map(_compute, shards)
|
|
170
|
+
|
|
171
|
+
logger.info(f"Applying enrichment with {len(shards)} shards. Follow progress at {client.dashboard_link}")
|
|
172
|
+
|
|
173
|
+
failed_ranges = []
|
|
174
|
+
try:
|
|
175
|
+
for result, shard in zip(client.gather(enrichments), shards):
|
|
176
|
+
result: EnrichmentTaskResult
|
|
177
|
+
|
|
178
|
+
if result.error is not None:
|
|
179
|
+
logger.error(f"Enrichment task failed for range {shard.key_range}: {result.error}")
|
|
180
|
+
failed_ranges.append(shard.key_range)
|
|
181
|
+
continue
|
|
182
|
+
|
|
183
|
+
tx.include(result.ops)
|
|
184
|
+
except Exception as e:
|
|
185
|
+
# If not incremental, re-raise the exception.
|
|
186
|
+
if checkpoint_dump is None:
|
|
187
|
+
raise e
|
|
188
|
+
|
|
189
|
+
# Handle worker failures (e.g., KilledWorker from Dask)
|
|
190
|
+
from dask.distributed import KilledWorker
|
|
191
|
+
|
|
192
|
+
if isinstance(e, KilledWorker):
|
|
193
|
+
logger.error(f"Dask worker was killed during enrichment: {e}")
|
|
194
|
+
|
|
195
|
+
# Try to gather partial results and mark remaining tasks as failed
|
|
196
|
+
for future, shard in zip(enrichments, shards):
|
|
197
|
+
if future.done() and not future.exception():
|
|
198
|
+
try:
|
|
199
|
+
result = future.result()
|
|
200
|
+
|
|
201
|
+
if result.error is not None:
|
|
202
|
+
logger.error(f"Enrichment task failed for range {shard.key_range}: {result.error}")
|
|
203
|
+
failed_ranges.append(shard.key_range)
|
|
204
|
+
continue
|
|
205
|
+
|
|
206
|
+
tx.include(result.ops)
|
|
207
|
+
except Exception:
|
|
208
|
+
# Task failed or incomplete, add to failed ranges
|
|
209
|
+
failed_ranges.append(shard.key_range)
|
|
210
|
+
else:
|
|
211
|
+
# Task didn't complete, add to failed ranges
|
|
212
|
+
failed_ranges.append(shard.key_range)
|
|
213
|
+
|
|
214
|
+
# Dump checkpoint of failed ranges, if any.
|
|
215
|
+
if checkpoint_dump is not None:
|
|
216
|
+
logger.info(
|
|
217
|
+
f"Dumping checkpoint with failed {len(failed_ranges)}/{len(shards)} ranges to {checkpoint_dump}."
|
|
218
|
+
)
|
|
219
|
+
_checkpoint_dump_key_ranges(checkpoint_dump, failed_ranges)
|
|
220
|
+
|
|
221
|
+
if tx.is_empty():
|
|
222
|
+
logger.warning("Transaction not committed. No rows were read for enrichment.")
|
|
223
|
+
return
|
|
224
|
+
|
|
225
|
+
# Always compact in distributed enrichment.
|
|
226
|
+
tx.commit(compact=True, tx_dump=tx_dump)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _checkpoint_load_key_ranges(checkpoint_dump: str) -> list[KeyRange] | None:
|
|
230
|
+
import json
|
|
231
|
+
import os
|
|
232
|
+
|
|
233
|
+
if not os.path.exists(checkpoint_dump):
|
|
234
|
+
return None
|
|
235
|
+
|
|
236
|
+
with open(checkpoint_dump) as f:
|
|
237
|
+
data = json.load(f)
|
|
238
|
+
return [
|
|
239
|
+
KeyRange(begin=Key(bytes.fromhex(r["begin"])), end=Key(bytes.fromhex(r["end"])))
|
|
240
|
+
for r in data.get("key_ranges", [])
|
|
241
|
+
]
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _checkpoint_dump_key_ranges(checkpoint_dump: str, ranges: list[KeyRange]):
|
|
245
|
+
import json
|
|
246
|
+
import os
|
|
247
|
+
|
|
248
|
+
os.makedirs(os.path.dirname(checkpoint_dump), exist_ok=True)
|
|
249
|
+
with open(checkpoint_dump, "w") as f:
|
|
250
|
+
json.dump(
|
|
251
|
+
{"key_ranges": [{"begin": bytes(r.begin).hex(), "end": bytes(r.end).hex()} for r in ranges]},
|
|
252
|
+
f,
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@dataclasses.dataclass
|
|
257
|
+
class EnrichmentTaskResult:
|
|
258
|
+
ops: list[Operation]
|
|
259
|
+
error: str | None = None
|
|
260
|
+
|
|
261
|
+
def __getstate__(self):
|
|
262
|
+
return {
|
|
263
|
+
"ops": [op.to_json() for op in self.ops],
|
|
264
|
+
"error": self.error,
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
def __setstate__(self, state):
|
|
268
|
+
self.ops = [Operation.from_json(op_json) for op_json in state["ops"]]
|
|
269
|
+
self.error = state["error"]
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# NOTE(marko): This function must be picklable!
|
|
273
|
+
def _enrichment_task(
|
|
274
|
+
shard: Shard,
|
|
275
|
+
*,
|
|
276
|
+
settings_json: str,
|
|
277
|
+
state_json: str,
|
|
278
|
+
output_table_id,
|
|
279
|
+
partition_size_bytes: int | None,
|
|
280
|
+
incremental: bool,
|
|
281
|
+
) -> EnrichmentTaskResult:
|
|
282
|
+
# Returns operations that can be included in a transaction.
|
|
283
|
+
from spiral import Scan, Spiral
|
|
284
|
+
from spiral.core.table import ScanState
|
|
285
|
+
from spiral.settings import ClientSettings
|
|
286
|
+
|
|
287
|
+
settings = ClientSettings.from_json(settings_json)
|
|
288
|
+
sp = Spiral(config=settings)
|
|
289
|
+
state = ScanState.from_json(state_json)
|
|
290
|
+
task_scan = Scan(sp, sp.core.load_scan(state))
|
|
291
|
+
table = sp.table(output_table_id)
|
|
292
|
+
task_tx = table.txn()
|
|
293
|
+
|
|
294
|
+
try:
|
|
295
|
+
task_tx.writeback(task_scan, key_range=shard.key_range, partition_size_bytes=partition_size_bytes)
|
|
296
|
+
return EnrichmentTaskResult(ops=task_tx.take())
|
|
297
|
+
except Exception as e:
|
|
298
|
+
task_tx.abort()
|
|
299
|
+
|
|
300
|
+
if incremental:
|
|
301
|
+
return EnrichmentTaskResult(ops=[], error=str(e))
|
|
302
|
+
|
|
303
|
+
logger.error(f"Enrichment task failed for shard {shard}: {e}")
|
|
304
|
+
raise e
|
spiral/expressions/__init__.py
CHANGED
|
@@ -8,31 +8,25 @@ import pyarrow as pa
|
|
|
8
8
|
|
|
9
9
|
from spiral import _lib, arrow_
|
|
10
10
|
|
|
11
|
+
from . import file as file
|
|
11
12
|
from . import http as http
|
|
12
|
-
from . import io as io
|
|
13
13
|
from . import list_ as list
|
|
14
|
-
from . import
|
|
15
|
-
from . import png as png
|
|
16
|
-
from . import qoi as qoi
|
|
17
|
-
from . import refs as refs
|
|
14
|
+
from . import s3 as s3
|
|
18
15
|
from . import str_ as str
|
|
19
16
|
from . import struct as struct
|
|
20
17
|
from . import text as text
|
|
21
|
-
from . import tiff as tiff
|
|
22
18
|
from .base import Expr, ExprLike, NativeExpr
|
|
19
|
+
from .udf import UDF
|
|
23
20
|
|
|
24
21
|
__all__ = [
|
|
25
22
|
"Expr",
|
|
26
23
|
"add",
|
|
27
24
|
"and_",
|
|
28
|
-
"deref",
|
|
29
25
|
"divide",
|
|
30
26
|
"eq",
|
|
31
27
|
"getitem",
|
|
32
28
|
"gt",
|
|
33
29
|
"gte",
|
|
34
|
-
"http",
|
|
35
|
-
"io",
|
|
36
30
|
"is_not_null",
|
|
37
31
|
"is_null",
|
|
38
32
|
"lift",
|
|
@@ -48,19 +42,17 @@ __all__ = [
|
|
|
48
42
|
"or_",
|
|
49
43
|
"pack",
|
|
50
44
|
"aux",
|
|
51
|
-
"ref",
|
|
52
|
-
"refs",
|
|
53
45
|
"scalar",
|
|
54
46
|
"select",
|
|
55
47
|
"str",
|
|
56
48
|
"struct",
|
|
57
49
|
"subtract",
|
|
58
|
-
"tiff",
|
|
59
50
|
"xor",
|
|
60
|
-
"png",
|
|
61
|
-
"qoi",
|
|
62
|
-
"mp4",
|
|
63
51
|
"text",
|
|
52
|
+
"s3",
|
|
53
|
+
"http",
|
|
54
|
+
"file",
|
|
55
|
+
"UDF",
|
|
64
56
|
]
|
|
65
57
|
|
|
66
58
|
# Inline some of the struct expressions since they're so common
|
|
@@ -68,8 +60,6 @@ getitem = struct.getitem
|
|
|
68
60
|
merge = struct.merge
|
|
69
61
|
pack = struct.pack
|
|
70
62
|
select = struct.select
|
|
71
|
-
ref = refs.ref
|
|
72
|
-
deref = refs.deref
|
|
73
63
|
|
|
74
64
|
|
|
75
65
|
def lift(expr: ExprLike) -> Expr:
|
|
@@ -127,16 +117,24 @@ def evaluate(expr: ExprLike) -> pa.RecordBatchReader:
|
|
|
127
117
|
return pa.RecordBatchReader.from_batches(expr.schema, [expr])
|
|
128
118
|
if isinstance(expr, pa.StructArray):
|
|
129
119
|
return pa.Table.from_struct_array(expr).to_reader()
|
|
120
|
+
|
|
130
121
|
if isinstance(expr, pa.ChunkedArray):
|
|
131
|
-
|
|
132
|
-
|
|
122
|
+
if not pa.types.is_struct(expr.type):
|
|
123
|
+
raise ValueError("Arrow chunked array must be a struct type.")
|
|
124
|
+
|
|
125
|
+
def _iter_batches():
|
|
126
|
+
for chunk in expr.chunks:
|
|
127
|
+
yield pa.RecordBatch.from_struct_array(chunk)
|
|
128
|
+
|
|
129
|
+
return pa.RecordBatchReader.from_batches(pa.schema(expr.type.fields), _iter_batches())
|
|
130
|
+
|
|
133
131
|
if isinstance(expr, pa.Array):
|
|
134
132
|
raise ValueError("Arrow array must be a struct array.")
|
|
135
133
|
|
|
136
|
-
if isinstance(expr, Expr):
|
|
137
|
-
raise NotImplementedError(
|
|
138
|
-
|
|
139
|
-
|
|
134
|
+
if isinstance(expr, Expr) or isinstance(expr, NativeExpr):
|
|
135
|
+
raise NotImplementedError(
|
|
136
|
+
"Expr evaluation not supported yet. Use Arrow to write instead. Reach out if you require this feature."
|
|
137
|
+
)
|
|
140
138
|
|
|
141
139
|
if isinstance(expr, dict):
|
|
142
140
|
# NOTE: we assume this is a struct expression. We could be smarter and be context aware to determine if
|
spiral/expressions/base.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
|
-
import builtins
|
|
2
1
|
import datetime
|
|
3
|
-
from typing import TypeAlias
|
|
2
|
+
from typing import TypeAlias, Union
|
|
4
3
|
|
|
5
4
|
import pyarrow as pa
|
|
6
5
|
|
|
@@ -153,5 +152,11 @@ class Expr:
|
|
|
153
152
|
|
|
154
153
|
|
|
155
154
|
ScalarLike: TypeAlias = bool | int | float | str | list["ScalarLike"] | datetime.datetime | None
|
|
156
|
-
ArrowLike: TypeAlias =
|
|
157
|
-
|
|
155
|
+
ArrowLike: TypeAlias = Union[
|
|
156
|
+
pa.RecordBatch,
|
|
157
|
+
"pa.Array[pa.Scalar[pa.DataType]]",
|
|
158
|
+
"pa.ChunkedArray[pa.Scalar[pa.DataType]]",
|
|
159
|
+
"pa.Scalar[pa.DataType]",
|
|
160
|
+
pa.Table,
|
|
161
|
+
]
|
|
162
|
+
ExprLike: TypeAlias = Expr | dict[str, "ExprLike"] | list["ExprLike"] | ArrowLike | ScalarLike
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from spiral import _lib
|
|
2
|
+
from spiral.expressions.base import Expr, ExprLike
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get(expr: ExprLike, abort_on_error: bool = False) -> Expr:
|
|
6
|
+
"""Read data from the local filesystem via file:// URLs.
|
|
7
|
+
|
|
8
|
+
Args:
|
|
9
|
+
expr: URLs of the data that needs to be read.
|
|
10
|
+
abort_on_error: Should the expression abort on errors or just collect them.
|
|
11
|
+
"""
|
|
12
|
+
from spiral import expressions as se
|
|
13
|
+
|
|
14
|
+
expr = se.lift(expr)
|
|
15
|
+
|
|
16
|
+
# NOTE: deliberately reuses the s3 `get` expression; presumably it also resolves file:// URLs (verify).
|
|
17
|
+
return Expr(_lib.expr.s3.get(expr.__expr__, abort_on_error))
|