pyspiral-0.8.9-cp311-abi3-macosx_11_0_arm64.whl → pyspiral-0.9.9-cp311-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spiral/input.py ADDED
@@ -0,0 +1,131 @@
+ import builtins
+ from typing import TYPE_CHECKING, TypeAlias, Union
+
+ import numpy as np
+ import pyarrow as pa
+
+ from spiral import arrow_
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+ ArrayLike: TypeAlias = Union[pa.Array, pa.ChunkedArray, builtins.list, np.ndarray, "pd.Series"]
+ TableLike: TypeAlias = Union[
+     pa.Table,
+     pa.RecordBatch,
+     pa.RecordBatchReader,
+     pa.StructArray,
+     pa.ChunkedArray,  # must be of struct type
+     builtins.list[dict],  # list of objects, each element is a row
+     dict[str, ArrayLike],  # dot-separated field names are nested
+     "pd.DataFrame",
+ ]
+
+
+ def evaluate(table: TableLike) -> pa.RecordBatchReader:
+     if isinstance(table, pa.RecordBatchReader):
+         return table
+
+     if isinstance(table, pa.Table):
+         return table.to_reader()
+     if isinstance(table, pa.RecordBatch):
+         return pa.RecordBatchReader.from_batches(table.schema, [table])
+
+     if isinstance(table, pa.StructArray):
+         return pa.Table.from_struct_array(table).to_reader()
+     if isinstance(table, pa.ChunkedArray):
+         if not pa.types.is_struct(table.type):
+             raise ValueError(f"Arrow ChunkedArray must have a struct type, got {table.type}.")
+         struct_type: pa.StructType = table.type  # type: ignore[assignment]
+
+         def _iter_batches():
+             for chunk in table.chunks:
+                 chunk: pa.StructArray
+                 yield pa.RecordBatch.from_struct_array(chunk)
+
+         return pa.RecordBatchReader.from_batches(pa.schema(struct_type.fields), _iter_batches())
+     if isinstance(table, pa.Array):
+         raise ValueError(f"Arrow Array must be a struct array, got {type(table)}.")
+
+     if isinstance(table, builtins.list):
+         # Handle empty array case
+         if len(table) == 0:
+             return pa.RecordBatchReader.from_batches(pa.schema([]), [])
+         return evaluate(pa.array(table))
+
+     if isinstance(table, dict):
+         table: dict = dot_separated_dict_to_nested(table)
+
+         return evaluate(_evaluate_dict(table))
+
+     try:
+         import pandas as pd
+
+         if isinstance(table, pd.DataFrame):
+             return evaluate(pa.Table.from_pandas(table))
+     except ImportError:
+         pass
+
+     raise TypeError(f"Unsupported table-like type: {type(table)}")
+
+
+ def _evaluate_dict(table: dict) -> pa.StructArray:
+     """Handle dot-separated field names as nested dictionaries."""
+     table = dot_separated_dict_to_nested(table)
+     return _dict_to_struct_array(table)
+
+
+ def _dict_to_struct_array(table) -> pa.StructArray:
+     data = {}
+     for key, value in table.items():
+         data[key] = _evaluate_array_like(value) if not isinstance(value, dict) else _dict_to_struct_array(value)
+     return arrow_.dict_to_struct_array(data)
+
+
+ def _evaluate_array_like(array: ArrayLike) -> pa.Array:
+     if isinstance(array, pa.Array):
+         return array
+     if isinstance(array, pa.ChunkedArray):
+         return array.combine_chunks()
+
+     if isinstance(array, np.ndarray):
+         return _evaluate_array_like(pa.array(array))
+     if isinstance(array, builtins.list):
+         return _evaluate_array_like(pa.array(array))
+
+     try:
+         import pandas as pd
+
+         if isinstance(array, pd.Series):
+             return _evaluate_array_like(pa.Array.from_pandas(array))
+     except ImportError:
+         pass
+
+     raise TypeError(f"Unsupported array-like type: {type(array)}")
+
+
+ def dot_separated_dict_to_nested(expr: dict) -> dict:
+     """Handle dot-separated field names as nested dictionaries."""
+     data = {}
+
+     for name in expr.keys():
+         if "." not in name:
+             if name in data:
+                 raise KeyError(f"Conflicting field name: {name}")
+             data[name] = expr[name]
+             continue
+
+         parts = name.split(".")
+         child_data = data
+         for part in parts[:-1]:
+             if part not in child_data:
+                 child_data[part] = {}
+             if not isinstance(child_data[part], dict):
+                 raise KeyError(f"Conflicting field name: {name}")
+             child_data = child_data[part]
+
+         if parts[-1] in child_data:
+             raise KeyError(f"Conflicting field name: {name}")
+         child_data[parts[-1]] = expr[name]
+
+     return data
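
The new spiral/input.py normalizes a variety of table-like inputs (Arrow tables, record batches, struct arrays, lists of row dicts, dicts of columns, pandas DataFrames) into a pyarrow.RecordBatchReader, expanding dot-separated field names into nested structs. A minimal usage sketch follows, assuming the module ships in the wheel as spiral.input; it is an illustration, not code from the package:

# Illustrative usage sketch (not from the package source). Requires pyarrow.
import pyarrow as pa

from spiral.input import dot_separated_dict_to_nested, evaluate

# Dot-separated keys are nested into struct fields before conversion.
nested = dot_separated_dict_to_nested({"id": [1, 2], "user.name": ["a", "b"]})
assert nested == {"id": [1, 2], "user": {"name": ["a", "b"]}}

# evaluate() accepts any TableLike and always returns a RecordBatchReader.
reader = evaluate({"id": [1, 2], "user.name": ["a", "b"]})
assert isinstance(reader, pa.RecordBatchReader)
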
spiral/ray_.py ADDED
@@ -0,0 +1,75 @@
+ from __future__ import annotations
+
+ from collections.abc import Iterable
+ from typing import TYPE_CHECKING
+
+ import pyarrow as pa
+ import ray
+ from ray.data.block import Block
+ from ray.data.datasource.datasink import WriteResult
+
+ from spiral import Spiral, Transaction
+ from spiral.core.config import ClientSettings
+ from spiral.transaction import TransactionOps
+ from spiral.types_ import Timestamp
+
+ if TYPE_CHECKING:
+     from ray.data._internal.execution.interfaces import TaskContext
+
+
+ # TODO(DK): we should just ship the serde bytes not JSON-serialized strings.
+ class Datasink(ray.data.Datasink[tuple[Timestamp, list[str]]]):
+     def __init__(self, txn: Transaction):
+         super().__init__()
+         self._table_id: str = txn.table.table_id
+         self._spiral_config_json = txn.table.spiral.config.to_json()
+         self._txn = txn
+
+     def __getstate__(self):
+         state = dict(self.__dict__)
+         state["_txn"] = None  # do not serialize the transaction
+         return state
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+
+     def on_write_complete(self, write_result: WriteResult[TransactionOps]):
+         assert self._txn is not None  # on_write_complete happens on the driver
+
+         for tx_ops in write_result.write_returns:
+             self._txn.include(tx_ops)
+
+     def on_write_failed(self, error: Exception):
+         pass
+
+     def on_write_start(self, schema: pa.Schema | None = None):
+         pass
+
+     def write(
+         self,
+         blocks: Iterable[Block],
+         ctx: TaskContext,
+     ) -> TransactionOps:
+         assert self._txn is None  # writes happen on workers
+
+         import pyarrow
+
+         sp = Spiral(config=ClientSettings.from_json(self._spiral_config_json))
+
+         # Do *not* use a context manager and do *not* call commit/abort.
+         # We instead `take` and send the operations to the driver node.
+         txn = sp.table(self._table_id).txn()
+
+         for block in blocks:
+             if not isinstance(block, pyarrow.Table):
+                 try:
+                     import pandas as pd
+
+                     assert isinstance(block, pd.DataFrame)
+                     block = pyarrow.Table.from_pandas(block)
+                 except ImportError:
+                     raise TypeError(f"Expected block to be a pyarrow.Table or pandas.DataFrame, got {type(block)}")
+
+             txn.write(block)
+
+         return txn.take()
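
The new Datasink lets Ray Data workers write blocks into a Spiral table without committing on the workers: each write task returns TransactionOps via txn.take(), and on_write_complete folds them into the driver-side transaction. A hedged sketch of how it might be driven; the default-configured Spiral() client, the table identifier, and Ray's Dataset.write_datasink call are assumptions not shown in this diff:

# Illustrative usage sketch (not from the package source). Requires `ray`.
import ray

from spiral import Spiral
from spiral.ray_ import Datasink

sp = Spiral()  # assumes a default-configured client
table = sp.table("my-project.my-table")  # hypothetical table identifier

ds = ray.data.from_items([{"id": i, "value": i * i} for i in range(1000)])

# Workers write blocks and return TransactionOps; the driver-side transaction
# collects them via on_write_complete and commits when the context manager exits.
with table.txn() as txn:
    ds.write_datasink(Datasink(txn))
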
spiral/scan.py CHANGED
@@ -1,18 +1,19 @@
- from functools import partial
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any, TypedDict, cast
 
  import pyarrow as pa
+ from typing_extensions import Unpack
 
  from spiral.core.client import Shard, ShuffleConfig
  from spiral.core.table import Scan as CoreScan
  from spiral.core.table.spec import Schema
- from spiral.settings import CI, TEST
+ from spiral.input import TableLike, evaluate
 
  if TYPE_CHECKING:
      import dask.dataframe as dd
      import datasets.iterable_dataset as hf  # noqa
      import pandas as pd
      import polars as pl
+     import ray.data
      import streaming  # noqa
      import torch.utils.data as torchdata  # noqa
 
@@ -20,6 +21,25 @@ if TYPE_CHECKING:
      from spiral.dataloader import SpiralDataLoader, World  # noqa
 
 
+ class ExecuteKwargs(TypedDict):
+     shards: list[Shard] | None
+     key_table: pa.Table | pa.RecordBatchReader | None
+     batch_readahead: int | None
+     batch_aligned: bool | None
+     hide_progress_bar: bool | None
+
+
+ class DistributedExecuteKwargs(TypedDict):
+     shards: list[Shard] | None
+     batch_readahead: int | None
+     hide_progress_bar: bool | None
+
+
+ class _PoppedDistributedExecuteKwargs(TypedDict):
+     batch_readahead: int | None
+     hide_progress_bar: bool | None
+
+
  class Scan:
      """Scan object."""
 
@@ -54,10 +74,10 @@ class Scan:
          self,
          *,
          shards: list[Shard] | None = None,
-         key_table: pa.Table | pa.RecordBatchReader | None = None,
-         batch_size: int | None = None,
+         key_table: TableLike | None = None,
          batch_readahead: int | None = None,
-         hide_progress_bar: bool = False,
+         batch_aligned: bool | None = None,
+         hide_progress_bar: bool | None = None,
      ) -> pa.RecordBatchReader:
          """Read as a stream of RecordBatches.
 
@@ -67,32 +87,38 @@ class Scan:
                  Must not be provided together with key_table.
              key_table: a table of keys to "take" (including aux columns for cell-push-down).
                  If None, the scan will be executed without a key table.
-             batch_size: the maximum number of rows per returned batch.
-                 This is currently only respected when the key_table is used. If key table is a
-                 RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
              batch_readahead: the number of batches to prefetch in the background.
+             batch_aligned: if True, ensures that batches are aligned with key_table batches.
+                 The stream will yield batches that correspond exactly to the batches in key_table,
+                 but may be less efficient and use more memory (aligning batches requires buffering and maybe a copy).
+                 Must only be used when key_table is provided.
              hide_progress_bar: If True, disables the progress bar during reading.
          """
-         if isinstance(key_table, pa.RecordBatchReader):
-             if batch_size is not None:
-                 raise ValueError(
-                     "batch_size must be None when key_table is a RecordBatchReader, the existing batching is respected."
-                 )
-         elif isinstance(key_table, pa.Table):
-             key_table = key_table.to_reader(max_chunksize=batch_size)
+         batch_aligned = False if batch_aligned is None else batch_aligned
+         hide_progress_bar = False if hide_progress_bar is None else hide_progress_bar
+
+         if key_table is not None:
+             key_table = evaluate(key_table)
+
+         # NOTE(marko): Uncomment for better debuggability.
+         # rb: pa.RecordBatch = self.core.to_record_batch(shards=shards, key_table=key_table)
+         # return pa.RecordBatchReader.from_batches(rb.schema, [rb])
 
          return self.core.to_record_batches(
-             shards=shards, key_table=key_table, batch_readahead=batch_readahead, progress=(not hide_progress_bar)
+             shards=shards,
+             key_table=key_table,
+             batch_readahead=batch_readahead,
+             batch_aligned=batch_aligned,
+             hide_progress_bar=hide_progress_bar,
          )
 
      def to_unordered_record_batches(
          self,
          *,
          shards: list[Shard] | None = None,
-         key_table: pa.Table | pa.RecordBatchReader | None = None,
-         batch_size: int | None = None,
+         key_table: TableLike | None = None,
          batch_readahead: int | None = None,
-         hide_progress_bar: bool = False,
+         hide_progress_bar: bool | None = None,
      ) -> pa.RecordBatchReader:
          """Read as a stream of RecordBatches, NOT ordered by key.
 
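
With this change, key_table accepts any TableLike (it is normalized through spiral.input.evaluate), and batch_aligned keeps the output batches aligned with the key_table batches. A hedged sketch; the Scan is taken as given and the "id" key column is a hypothetical name:

# Illustrative usage sketch (not from the package source).
import pyarrow as pa

from spiral.scan import Scan


def take_rows(scan: Scan, ids: list[int]) -> pa.Table:
    # A plain dict of columns is a valid TableLike key_table; batch_aligned is
    # only meaningful when key_table is provided.
    reader = scan.to_record_batches(key_table={"id": ids}, batch_aligned=True)
    return reader.read_all()
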
@@ -102,34 +128,43 @@ class Scan:
                  Must not be provided together with key_table.
              key_table: a table of keys to "take" (including aux columns for cell-push-down).
                  If None, the scan will be executed without a key table.
-             batch_size: the maximum number of rows per returned batch.
-                 This is currently only respected when the key_table is used. If key table is a
-                 RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
              batch_readahead: the number of batches to prefetch in the background.
              hide_progress_bar: If True, disables the progress bar during reading.
          """
-         if isinstance(key_table, pa.RecordBatchReader):
-             if batch_size is not None:
-                 raise ValueError(
-                     "batch_size must be None when key_table is a RecordBatchReader, the existing batching is respected."
-                 )
-         elif isinstance(key_table, pa.Table):
-             key_table = key_table.to_reader(max_chunksize=batch_size)
+         hide_progress_bar = False if hide_progress_bar is None else hide_progress_bar
+
+         if key_table is not None:
+             key_table = evaluate(key_table)
 
          return self.core.to_unordered_record_batches(
-             shards=shards, key_table=key_table, batch_readahead=batch_readahead, progress=(not hide_progress_bar)
+             shards=shards,
+             key_table=key_table,
+             batch_readahead=batch_readahead,
+             hide_progress_bar=hide_progress_bar,
          )
 
-     def to_table(self, **kwargs) -> pa.Table:
-         """Read into a single PyArrow Table."""
-         # NOTE: Evaluates fully on Rust side which improved debuggability.
-         if TEST and not CI:
-             rb = self.core.to_record_batch(**kwargs)
-             return pa.Table.from_batches([rb])
+     def to_table(self, **kwargs: Unpack[ExecuteKwargs]) -> pa.Table:
+         """Read into a single PyArrow Table.
+
+         Warnings:
+             This downloads the entire Spiral Table into memory on this machine.
+
+         Args:
+             shards: Optional list of shards to evaluate.
+                 If provided, only the specified shards will be read.
+                 Must not be provided together with key_table.
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+             batch_readahead: the number of batches to prefetch in the background.
+             hide_progress_bar: If True, disables the progress bar during reading.
+
+         Returns:
+             pyarrow.Table
 
+         """
          return self.to_record_batches(**kwargs).read_all()
 
-     def to_dask(self) -> "dd.DataFrame":
+     def to_dask(self, **kwargs: Unpack[DistributedExecuteKwargs]) -> "dd.DataFrame":
          """Read into a Dask DataFrame.
 
          Requires the `dask` package to be installed.
@@ -137,31 +172,126 @@ class Scan:
          Dask execution has some limitations, e.g. UDFs are not currently supported. These limitations
          usually manifest as serialization errors when Dask workers attempt to serialize the state. If you are
          encountering such issues, please reach out to the support for assistance.
+
+         Args:
+             shards: Optional list of shards to evaluate.
+                 If provided, only the specified shards will be read.
+             batch_readahead: the number of batches to prefetch in the background.
+             hide_progress_bar: If True, disables the progress bar during reading.
+
+         Returns:
+             dask.dataframe.DataFrame
+
          """
          import dask.dataframe as dd
 
-         _read_shard = partial(
-             _read_shard_task,
-             config_json=self.spiral.config.to_json(),
-             state_json=self.core.plan_state().to_json(),
-         )
-         return dd.from_map(_read_shard, self.shards())
+         config_json = self.spiral.config.to_json()
+         state_bytes = self.core.plan_context().to_bytes_compressed()
+
+         shards = kwargs.pop("shards", None) or self.shards()
+         task_kwargs = cast(_PoppedDistributedExecuteKwargs, kwargs)
+
+         def _read_shard(shard: Shard) -> "pd.DataFrame":
+             arrow_table = _read_shard_task(
+                 shard,
+                 config_json=config_json,
+                 state_bytes=state_bytes,
+                 **task_kwargs,
+             )
+             return arrow_table.to_pandas()
+
+         return dd.from_map(_read_shard, shards)
 
-     def to_pandas(self, **kwargs) -> "pd.DataFrame":
+     def to_ray_dataset(self, **kwargs: Unpack[DistributedExecuteKwargs]) -> "ray.data.Dataset":
+         """Read into a Ray Dataset.
+
+         Requires the `ray` package to be installed.
+
+         Warnings:
+             If the Scan returns zero rows, the resulting Ray Dataset will have [an empty
+             schema](https://github.com/ray-project/ray/issues/59946).
+
+         Args:
+             shards: Optional list of shards to evaluate.
+                 If provided, only the specified shards will be read.
+             batch_readahead: the number of batches to prefetch in the background.
+             hide_progress_bar: If True, disables the progress bar during reading.
+
+         Returns:
+             ray.data.Dataset: A Ray Dataset distributed across shards.
+
+         """
+         import ray
+
+         config_json = self.spiral.config.to_json()
+         state_bytes = self.core.plan_context().to_bytes_compressed()
+
+         shards = kwargs.pop("shards", None) or self.shards()
+         task_kwargs = cast(_PoppedDistributedExecuteKwargs, kwargs)
+
+         read_shard_remote = ray.remote(_read_shard_task)
+         refs = [
+             read_shard_remote.remote(
+                 shard,
+                 config_json=config_json,
+                 state_bytes=state_bytes,
+                 **task_kwargs,
+             )
+             for shard in shards
+         ]
+
+         return ray.data.from_arrow_refs(refs)
+
+     def to_pandas(self, **kwargs: Unpack[ExecuteKwargs]) -> "pd.DataFrame":
          """Read into a Pandas DataFrame.
 
          Requires the `pandas` package to be installed.
+
+         Warnings:
+             This downloads the entire Spiral Table into memory on this machine.
+
+         Args:
+             shards: Optional list of shards to evaluate.
+                 If provided, only the specified shards will be read.
+                 Must not be provided together with key_table.
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+             batch_readahead: the number of batches to prefetch in the background.
+             hide_progress_bar: If True, disables the progress bar during reading.
+
+         Returns:
+             pandas.DataFrame
+
          """
          return self.to_record_batches(**kwargs).read_all().to_pandas()
 
-     def to_polars(self, **kwargs) -> "pl.DataFrame":
+     def to_polars(self, **kwargs: Unpack[ExecuteKwargs]) -> "pl.DataFrame":
          """Read into a Polars DataFrame.
 
          Requires the `polars` package to be installed.
+
+         Warnings:
+             This downloads the entire Spiral Table into memory on this machine. To lazily interact
+             with a Spiral Table try Table.to_polars_lazy_frame.
+
+         Args:
+             shards: Optional list of shards to evaluate.
+                 If provided, only the specified shards will be read.
+                 Must not be provided together with key_table.
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+             batch_readahead: the number of batches to prefetch in the background.
+             hide_progress_bar: If True, disables the progress bar during reading.
+
+         Returns:
+             polars.DataFrame
+
          """
          import polars as pl
 
-         return pl.from_arrow(self.to_record_batches(**kwargs))
+         df = pl.from_arrow(self.to_record_batches(**kwargs))
+         assert isinstance(df, pl.DataFrame)
+         return df
 
      def to_data_loader(
          self, seed: int = 42, shuffle_buffer_size: int = 0, batch_size: int = 32, **kwargs
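
Both distributed readers now ship the client config plus the compressed plan state to workers and fan out _read_shard_task per shard, accepting the same DistributedExecuteKwargs. A hedged sketch of the call sites; how the Scan is obtained is outside this diff:

# Illustrative usage sketch (not from the package source). Requires `ray` and/or `dask`.
from spiral.scan import Scan


def materialize(scan: Scan):
    # One Ray task per shard; each task resumes the scan from the compressed
    # plan state and returns a pyarrow.Table.
    ray_ds = scan.to_ray_dataset(batch_readahead=4, hide_progress_bar=True)

    # Same shard fan-out exposed as a Dask DataFrame built with dd.from_map.
    dask_df = scan.to_dask(hide_progress_bar=True)

    return ray_ds, dask_df
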
@@ -186,7 +316,7 @@ class Scan:
 
      def to_distributed_data_loader(
          self,
-         world: Optional["World"] = None,
+         world: "World | None" = None,
          shards: list[Shard] | None = None,
          seed: int = 42,
          shuffle_buffer_size: int = 0,
@@ -315,7 +445,8 @@ class Scan:
                  If None, no shuffling is performed.
              batch_readahead: Controls how many batches to read ahead concurrently.
                  If pipeline includes work after reading (e.g. decoding, transforming, ...) this can be set higher.
-                 Otherwise, it should be kept low to reduce next batch latency. Defaults to 2.
+                 Otherwise, it should be kept low to reduce next batch latency.
+                 Defaults to min(number of CPU cores, 64) or to shuffle.buffer_size/16 if shuffle is not None.
              infinite: If True, the returned IterableDataset will loop infinitely over the data,
                  re-shuffling ranges after exhausting all data.
          """
@@ -326,7 +457,7 @@ class Scan:
              infinite=infinite,
          )
 
-         from spiral.iterable_dataset import to_iterable_dataset
+         from spiral.huggingface import to_iterable_dataset
 
          return to_iterable_dataset(stream)
 
 
@@ -342,15 +473,15 @@ class Scan:
          """
          return self.core.shards()
 
-     def state_json(self) -> str:
-         """Get the scan state as a JSON string.
+     def state_bytes(self) -> bytes:
+         """Get the scan state as bytes.
 
          This state can be used to resume the scan later using Spiral.resume_scan().
 
          Returns:
-             JSON string representing the internal scan state.
+             Compressed bytes representing the internal scan state.
          """
-         return self.core.plan_state().to_json()
+         return self.core.plan_context().to_bytes_compressed()
 
      def _debug(self):
          # Visualizes the scan, mainly for debugging purposes.
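
Scan state now round-trips as compressed bytes (plan_context().to_bytes_compressed()) instead of a JSON string, and Spiral.resume_scan() accepts those bytes, as the updated _read_shard_task later in this diff shows. A hedged sketch of persisting and resuming a scan; the file-based checkpointing and the Scan return annotation are assumptions:

# Illustrative usage sketch (not from the package source).
from spiral import Spiral
from spiral.scan import Scan


def checkpoint(scan: Scan, path: str) -> None:
    with open(path, "wb") as f:
        f.write(scan.state_bytes())  # compressed scan state


def resume(sp: Spiral, path: str) -> Scan:
    with open(path, "rb") as f:
        return sp.resume_scan(f.read())  # accepts the bytes written above
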
@@ -358,12 +489,6 @@ class Scan:
 
          show_scan(self.core)
 
-     def _dump_manifests(self):
-         # Print manifests in a human-readable format.
-         from spiral.debug.manifests import display_scan_manifests
-
-         display_scan_manifests(self.core)
-
      def _dump_metrics(self):
          # Print metrics in a human-readable format.
          from spiral.debug.metrics import display_metrics
@@ -372,12 +497,41 @@ class Scan:
 
 
  # NOTE(marko): This function must be picklable!
- def _read_shard_task(shard: Shard, *, config_json: str, state_json: str) -> "pd.DataFrame":
+
+
+ def _read_shard_task(
+     shard: Shard,
+     *,
+     config_json: str,
+     state_bytes: bytes,
+     key_table: pa.Table | pa.RecordBatchReader | None = None,
+     batch_readahead: int | None = None,
+     hide_progress_bar: bool | None = None,
+ ) -> pa.Table:
+     """Ray worker function to read a single shard as Arrow table.
+
+     Args:
+         shard: The shard to read
+         config_json: Serialized ClientSettings
+         state_bytes: Serialized scan state
+         key_table: a table of keys to "take" (including aux columns for cell-push-down).
+             If None, the scan will be executed without a key table.
+         batch_readahead: the number of batches to prefetch in the background.
+         hide_progress_bar: If True, disables the progress bar during reading.
+
+     Returns:
+         PyArrow Table containing the shard data
+     """
      from spiral import Spiral
      from spiral.settings import ClientSettings
 
      config = ClientSettings.from_json(config_json)
      sp = Spiral(config=config)
-     task_scan = sp.resume_scan(state_json)
-
-     return task_scan.to_pandas(shards=[shard], hide_progress_bar=True)
+     task_scan = sp.resume_scan(state_bytes)
+
+     return task_scan.to_table(
+         shards=[shard],
+         key_table=key_table,
+         batch_readahead=batch_readahead,
+         hide_progress_bar=hide_progress_bar,
+     )
spiral/table.py CHANGED
@@ -5,6 +5,7 @@ from spiral.core.table import Table as CoreTable
  from spiral.core.table.spec import Schema
  from spiral.enrichment import Enrichment
  from spiral.expressions.base import Expr, ExprLike
+ from spiral.input import TableLike
  from spiral.snapshot import Snapshot
  from spiral.transaction import Transaction
 
@@ -99,17 +100,17 @@ class Table(Expr):
          """
          return self.core.get_schema(asof=None)
 
-     def write(self, expr: ExprLike, push_down_nulls: bool = False, **kwargs) -> None:
+     def write(self, table: TableLike, push_down_nulls: bool = False, **kwargs) -> None:
          """Write an item to the table inside a single transaction.
 
          :param push_down_nulls: Whether to push down nullable structs down its children. E.g. `[{"a": 1}, null]` would
          become `[{"a": 1}, {"a": null}]`. SpiralDB doesn't allow struct-level nullability, so use this option if your
          data contains nullable structs.
 
-         :param expr: The expression to write. Must evaluate to a struct array.
+         :param table: The table to write.
          """
          with self.txn(**kwargs) as txn:
-             txn.write(expr, push_down_nulls=push_down_nulls)
+             txn.write(table, push_down_nulls=push_down_nulls)
 
      def enrich(
          self,
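
Table.write now takes a TableLike rather than an expression, so plain column dicts, Arrow tables, or pandas DataFrames can be written directly. A hedged sketch; the default-configured Spiral() client, the table identifier, and the column names are assumptions:

# Illustrative usage sketch (not from the package source).
from spiral import Spiral

sp = Spiral()  # assumes a default-configured client
table = sp.table("my-project.my-table")  # hypothetical table identifier

# A dict of columns is a valid TableLike; dot-separated names become nested structs.
table.write({"id": [1, 2], "user.name": ["alice", "bob"]})
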
@@ -157,7 +158,7 @@ class Table(Expr):
          it is important that the primary key columns are unique within the transaction.
          The behavior is undefined if this is not the case.
          """
-         return Transaction(self.spiral.core.transaction(self.core, **kwargs))
+         return Transaction(self, self.spiral.core.transaction(self.core, **kwargs))
 
      def to_arrow_dataset(self) -> "ds.Dataset":
          """Returns a PyArrow Dataset representing the table."""