pyspiral-0.2.4-cp310-abi3-macosx_11_0_arm64.whl → pyspiral-0.3.1-cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,42 +1,44 @@
  import numpy as np
  import pyarrow as pa

- from spiral.expressions.base import ExprLike
+ from spiral.expressions.base import Expr, ExprLike
  from spiral.expressions.udf import RefUDF

+ _TIFF_RES_DTYPE: pa.DataType = pa.struct(
+     [
+         pa.field("pixels", pa.large_binary()),
+         pa.field("height", pa.uint32()),
+         pa.field("width", pa.uint32()),
+         pa.field("channels", pa.uint8()),
+         pa.field("channel_bit_depth", pa.uint8()),
+     ]
+ )
+

  def read(
      expr: ExprLike,
-     indexes: ExprLike | int | list[int] | None = None,
+     indexes: ExprLike | int | None = None,
      window: ExprLike | tuple[tuple[int, int], tuple[int, int]] | None = None,
      boundless: ExprLike | bool | None = None,
- ):
+ ) -> Expr:
      """
      Read referenced cell in a `TIFF` format. Requires `rasterio` to be installed.

      Args:
          expr: The referenced `TIFF` bytes.
-         indexes: The band indexes to read. Defaults to first band. The first dimension of the result's `shape` field
-             is either 1 or the number of indexes.
+         indexes: The band indexes to read. Defaults to all.
          window: The window to read. In format (row_range_tuple, col_range_tuple). Defaults to full window.
          boundless: If `True`, windows that extend beyond the dataset's extent
              are permitted and partially or completely filled arrays will be returned as appropriate.

      Returns:
-         An array where each element is a NumPy array represented as a struct with fields:
-             bytes: Array bytes with type `pa.large_binary()`.
-             shape: Array shape with type `pa.list_(pa.uint32(), 3)`.
-             dtype: String representation of NumPy dtype with type `pa.string()`.
-
-     Example:
-         A way to get the i-th element in the result as NumPy array:
-
-         ```
-         array: np.ndarray = np.frombuffer(
-             result["bytes"][i].as_py(),
-             dtype=np.dtype(result["dtype"][i].as_py()),
-         ).reshape(tuple(result["shape"][i].as_py()))
-         ```
+         An array where each element is a decoded image with fields:
+             pixels: bytes of shape (channels, width, height).
+             width: Width of the image with type `pa.uint32()`.
+             height: Height of the image with type `pa.uint32()`.
+             channels: Number of channels of the image with type `pa.uint8()`.
+                 If `indexes` is not None, this is the length of `indexes` or 1 if `indexes` is an int.
+             channel_bit_depth: Bit depth of the channel with type `pa.uint8()`.
      """
      try:
          import rasterio  # noqa: F401
@@ -46,55 +48,42 @@ def read(
      return TiffReadUDF()(expr, indexes, window, boundless)


- def crop(
+ def select(
      expr: ExprLike,
-     shape: ExprLike,
- ):
+     shape: ExprLike | dict,
+     indexes: ExprLike | int | None = None,
+ ) -> Expr:
      """
-     Crop shapes out of the referenced cell in a `TIFF` format. Requires `rasterio` to be installed.
+     Select the shape out of the referenced cell in a `TIFF` format. Requires `rasterio` to be installed.

      Args:
          expr: The referenced `TIFF` bytes.
          shape: [GeoJSON-like](https://geojson.org/) shape.
+         indexes: The band indexes to read. Defaults to all.

      Returns:
-         An array where each element is a NumPy array represented as a struct with fields:
-             bytes: Array bytes with type `pa.large_binary()`.
-             shape: Array shape with type `pa.list_(pa.uint32(), 3)`.
-             dtype: String representation of NumPy dtype with type `pa.string()`.
-
-     Example:
-         A way to get the i-th element in the result as NumPy array:
-
-         ```
-         array: np.ndarray = np.frombuffer(
-             result["bytes"][i].as_py(),
-             dtype=np.dtype(result["dtype"][i].as_py()),
-         ).reshape(tuple(result["shape"][i].as_py()))
-         ```
+         An array where each element is a decoded image with fields:
+             pixels: bytes of shape (len(indexes) or 1, width, height).
+             width: Width of the image with type `pa.uint32()`.
+             height: Height of the image with type `pa.uint32()`.
+             channels: Number of channels of the image with type `pa.uint8()`.
+                 If `indexes` is not None, this is the length of `indexes` or 1 if `indexes` is an int.
+             channel_bit_depth: Bit depth of the channel with type `pa.uint8()`.
      """
      try:
          import rasterio  # noqa: F401
      except ImportError:
-         raise ImportError("`rasterio` is required for tiff.crop")
+         raise ImportError("`rasterio` is required for tiff.select")

-     return TiffCropUDF()(expr, shape)
+     return TiffSelectUDF()(expr, shape, indexes)


  class TiffReadUDF(RefUDF):
-     RES_DTYPE: pa.DataType = pa.struct(
-         [
-             pa.field("bytes", pa.large_binary()),
-             pa.field("shape", pa.list_(pa.uint32(), 3)),
-             pa.field("dtype", pa.string()),
-         ]
-     )
-
      def __init__(self):
          super().__init__("tiff.read")

      def return_type(self, *input_types: pa.DataType) -> pa.DataType:
-         return TiffReadUDF.RES_DTYPE
+         return _TIFF_RES_DTYPE

      def invoke(self, fp, *input_args: pa.Array) -> pa.Array:
          try:
@@ -130,65 +119,76 @@ class TiffReadUDF(RefUDF):
              # This matters more if we want to rewrite this function to work with multiple inputs at once, in which
              # case we should first consider using Rust GDAL bindings - I believe rasterio uses GDAL under the hood.
              result: np.ndarray = src.read(indexes=indexes, window=window)
-             return pa.array(
-                 [
-                     {
-                         "bytes": result.tobytes(),
-                         "shape": list(result.shape),
-                         "dtype": str(result.dtype),
-                     }
-                 ],
-                 type=TiffReadUDF.RES_DTYPE,
-             )
-
-
- class TiffCropUDF(RefUDF):
-     RES_DTYPE: pa.DataType = pa.struct(
-         [
-             pa.field("bytes", pa.large_binary()),
-             pa.field("shape", pa.list_(pa.uint32()), 3),
-             pa.field("dtype", pa.string()),
-         ]
-     )
+             return _return_result(result, indexes)
+

+ class TiffSelectUDF(RefUDF):
      def __init__(self):
-         super().__init__("tiff.crop")
+         super().__init__("tiff.select")

      def return_type(self, *input_types: pa.DataType) -> pa.DataType:
-         return TiffCropUDF.RES_DTYPE
+         return _TIFF_RES_DTYPE

      def invoke(self, fp, *input_args: pa.Array) -> pa.Array:
          try:
              import rasterio
          except ImportError:
-             raise ImportError("`rasterio` is required for tiff.crop")
+             raise ImportError("`rasterio` is required for tiff.select")

-         from rasterio.mask import mask as rio_mask
+         from rasterio.mask import raster_geometry_mask

-         if len(input_args) != 2:
-             raise ValueError("tiff.crop expects exactly 2 arguments: expr, shape")
+         if len(input_args) != 3:
+             raise ValueError("tiff.select expects exactly 3 arguments: expr, shape, indexes")

-         _, shape = input_args
+         _, shape, indexes = input_args

          shape = shape[0].as_py()
          if shape is None:
-             raise ValueError("tiff.crop expects shape to be a GeoJSON-like shape")
+             raise ValueError("tiff.select expects shape to be a GeoJSON-like shape")
+
+         indexes = indexes[0].as_py()
+         if indexes is not None and not isinstance(indexes, int) and not isinstance(indexes, list):
+             raise ValueError(f"tiff.select expects indexes to be None or an int or a list, got {indexes}")

          opener = _VsiOpener(fp)
          with rasterio.open("ref", opener=opener) as src:
              src: rasterio.DatasetReader
-             result, _ = rio_mask(src, shapes=[shape], crop=True)
-             result: np.ndarray
-             return pa.array(
-                 [
-                     {
-                         "bytes": result.tobytes(),
-                         "shape": list(result.shape),
-                         "dtype": str(result.dtype),
-                     }
-                 ],
-                 type=TiffCropUDF.RES_DTYPE,
-             )
+
+             shape_mask, _, window = raster_geometry_mask(src, [shape], crop=True)
+             out_shape = (src.count,) + shape_mask.shape
+
+             result: np.ndarray = src.read(window=window, indexes=indexes, out_shape=out_shape, masked=True)
+             return _return_result(result, indexes)
+
+
+ def _return_result(result: np.ndarray, indexes) -> pa.Array:
+     channels = result.shape[0]
+     if indexes is None:
+         pass
+     elif isinstance(indexes, int):
+         assert channels == 1, f"Expected 1 channel, got {channels}"
+     else:
+         assert channels == len(indexes), f"Expected {len(indexes)} channels, got {channels}"
+
+     if result.dtype == np.uint8:
+         channel_bit_depth = 8
+     elif result.dtype == np.uint16:
+         channel_bit_depth = 16
+     else:
+         raise ValueError(f"Unsupported bit width: {result.dtype}")
+
+     return pa.array(
+         [
+             {
+                 "pixels": result.tobytes(),
+                 "height": result.shape[1],
+                 "width": result.shape[2],
+                 "channels": channels,
+                 "channel_bit_depth": channel_bit_depth,
+             }
+         ],
+         type=_TIFF_RES_DTYPE,
+     )


  class _VsiOpener:
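The old docstrings carried an example for rebuilding a NumPy array from the `bytes`/`shape`/`dtype` struct; the new struct has no equivalent. Below is a sketch of the same round trip against the new fields, following `_return_result` above (which stores `height = result.shape[1]`, `width = result.shape[2]`, and a `channel_bit_depth` of 8 or 16); the helper name is illustrative, not part of the package.

```
import numpy as np

def to_numpy(result, i: int) -> np.ndarray:
    # result is the pa.Array returned by tiff.read() / tiff.select().
    row = result[i].as_py()
    dtype = np.uint8 if row["channel_bit_depth"] == 8 else np.uint16
    return np.frombuffer(row["pixels"], dtype=dtype).reshape(
        (row["channels"], row["height"], row["width"])
    )
```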
spiral/maintenance.py ADDED
@@ -0,0 +1,12 @@
+ from spiral.core.core import TableMaintenance
+
+
+ class Maintenance:
+     """Spiral table maintenance."""
+
+     def __init__(self, maintenance: TableMaintenance):
+         self._maintenance = maintenance
+
+     def flush_wal(self):
+         """Flush the write-ahead log."""
+         self._maintenance.flush_wal()
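A one-line usage sketch, assuming a `TableMaintenance` handle is available from the core bindings (how it is obtained is not shown in this diff; the variable name is hypothetical):

```
maintenance = Maintenance(core_table_maintenance)  # core_table_maintenance is hypothetical
maintenance.flush_wal()  # flush the write-ahead log
```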
@@ -30,6 +30,11 @@ class Source(betterproto.Message):
      parquet: "MetadataParquet" = betterproto.message_field(10, group="metadata")


+ @dataclass(eq=False, repr=False)
+ class Sink(betterproto.Message):
+     url: str = betterproto.string_field(1)
+
+
  @dataclass(eq=False, repr=False)
  class Fetch(betterproto.Message):
      """Let's make "fetch" happen."""
@@ -39,15 +44,24 @@ class Fetch(betterproto.Message):


  @dataclass(eq=False, repr=False)
  class FetchRequest(betterproto.Message):
+     """TODO(ngates): include projection expression."""
+
      uri: str = betterproto.string_field(1)
      """
-     A signed request to read an spfs://<fsid>/path?token=<jwt> URI.
-     * Declares the MIME types the client can read directly.
-     * Declares whether the client has connectivity to the FileSystem.
+     A signed request to read an
+     spfs://<fsid>/path?token=<jwt> URI.
      """

-     connectivity: "Connectivity" = betterproto.message_field(2)
-     accepts: List[str] = betterproto.string_field(3)
+     headers: Dict[str, str] = betterproto.map_field(
+         2, betterproto.TYPE_STRING, betterproto.TYPE_STRING
+     )
+     """Custom headers to sign into the request."""
+
+     connectivity: "Connectivity" = betterproto.message_field(3)
+     """Declares whether the client has connectivity to the FileSystem."""
+
+     accepts: List[str] = betterproto.string_field(4)
+     """Declares the MIME types the client can read directly."""


  @dataclass(eq=False, repr=False)
@@ -59,11 +73,6 @@ class FetchResponse(betterproto.Message):
      """


- @dataclass(eq=False, repr=False)
- class Sink(betterproto.Message):
-     url: str = betterproto.string_field(1)
-
-
  @dataclass(eq=False, repr=False)
  class Put(betterproto.Message):
      pass
@@ -72,7 +81,10 @@ class Put(betterproto.Message):
  @dataclass(eq=False, repr=False)
  class PutRequest(betterproto.Message):
      uri: str = betterproto.string_field(1)
-     connectivity: "Connectivity" = betterproto.message_field(2)
+     headers: Dict[str, str] = betterproto.map_field(
+         2, betterproto.TYPE_STRING, betterproto.TYPE_STRING
+     )
+     connectivity: "Connectivity" = betterproto.message_field(3)


  @dataclass(eq=False, repr=False)
@@ -80,6 +92,25 @@ class PutResponse(betterproto.Message):
      sinks: List["Sink"] = betterproto.message_field(1)


+ @dataclass(eq=False, repr=False)
+ class Head(betterproto.Message):
+     pass
+
+
+ @dataclass(eq=False, repr=False)
+ class HeadRequest(betterproto.Message):
+     uri: str = betterproto.string_field(1)
+     headers: Dict[str, str] = betterproto.map_field(
+         2, betterproto.TYPE_STRING, betterproto.TYPE_STRING
+     )
+
+
+ @dataclass(eq=False, repr=False)
+ class HeadResponse(betterproto.Message):
+     url: str = betterproto.string_field(1)
+     """Returns signed URL to head the resource."""
+
+
  @dataclass(eq=False, repr=False)
  class Delete(betterproto.Message):
      pass
@@ -88,6 +119,9 @@ class Delete(betterproto.Message):
  @dataclass(eq=False, repr=False)
  class DeleteRequest(betterproto.Message):
      uri: str = betterproto.string_field(1)
+     headers: Dict[str, str] = betterproto.map_field(
+         2, betterproto.TYPE_STRING, betterproto.TYPE_STRING
+     )


  @dataclass(eq=False, repr=False)
@@ -151,6 +185,23 @@ class ScandalServiceStub(betterproto.ServiceStub):
              metadata=metadata,
          )

+     async def head(
+         self,
+         head_request: "HeadRequest",
+         *,
+         timeout: Optional[float] = None,
+         deadline: Optional["Deadline"] = None,
+         metadata: Optional["MetadataLike"] = None
+     ) -> "HeadResponse":
+         return await self._unary_unary(
+             "/scandal.ScandalService/Head",
+             head_request,
+             HeadResponse,
+             timeout=timeout,
+             deadline=deadline,
+             metadata=metadata,
+         )
+
      async def delete(
          self,
          delete_request: "DeleteRequest",
@@ -176,6 +227,9 @@ class ScandalServiceBase(ServiceBase):
      async def put(self, put_request: "PutRequest") -> "PutResponse":
          raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)

+     async def head(self, head_request: "HeadRequest") -> "HeadResponse":
+         raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
+
      async def delete(self, delete_request: "DeleteRequest") -> "DeleteResponse":
          raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)

@@ -193,6 +247,13 @@ class ScandalServiceBase(ServiceBase):
          response = await self.put(request)
          await stream.send_message(response)

+     async def __rpc_head(
+         self, stream: "grpclib.server.Stream[HeadRequest, HeadResponse]"
+     ) -> None:
+         request = await stream.recv_message()
+         response = await self.head(request)
+         await stream.send_message(response)
+
      async def __rpc_delete(
          self, stream: "grpclib.server.Stream[DeleteRequest, DeleteResponse]"
      ) -> None:
@@ -214,6 +275,12 @@ class ScandalServiceBase(ServiceBase):
                  PutRequest,
                  PutResponse,
              ),
+             "/scandal.ScandalService/Head": grpclib.const.Handler(
+                 self.__rpc_head,
+                 grpclib.const.Cardinality.UNARY_UNARY,
+                 HeadRequest,
+                 HeadResponse,
+             ),
              "/scandal.ScandalService/Delete": grpclib.const.Handler(
                  self.__rpc_delete,
                  grpclib.const.Cardinality.UNARY_UNARY,
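For context, the generated stub gains a unary `Head` RPC alongside `Put` and `Delete`. A minimal client sketch using grpclib, assuming the generated classes above are importable and using placeholder host, port, and URI values:

```
import asyncio

from grpclib.client import Channel

async def head_example() -> None:
    # Placeholder endpoint; the real address comes from SpiralDBSettings.uri.
    channel = Channel(host="scandal.example.com", port=443, ssl=True)
    try:
        stub = ScandalServiceStub(channel)
        response = await stub.head(
            HeadRequest(uri="spfs://<fsid>/path?token=<jwt>", headers={"x-example": "1"})
        )
        print(response.url)  # signed URL to HEAD the resource
    finally:
        channel.close()

asyncio.run(head_example())
```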
@@ -152,6 +152,12 @@ class FragmentSetWriteOp(betterproto.Message):
      key_span: "KeySpan" = betterproto.message_field(5)
      key_extent: "KeyExtent" = betterproto.message_field(6)
      column_ids: List[str] = betterproto.string_field(7)
+     data_ts: Optional[int] = betterproto.uint64_field(8, optional=True)
+     """
+     Timestamp of the data in the fragments.
+     Used as committed_ts for files in the manifest.
+     If not present, timestamp of the operation is used.
+     """


  @dataclass(eq=False, repr=False)
@@ -175,8 +181,53 @@ class SchemaBreakOp(betterproto.Message):

  @dataclass(eq=False, repr=False)
  class CompactKeySpaceOp(betterproto.Message):
-     from_ks_ids: List[str] = betterproto.string_field(1)
-     into_ks_ids: List[str] = betterproto.string_field(2)
+     results: List["CompactKeySpaceResult"] = betterproto.message_field(1)
+
+
+ @dataclass(eq=False, repr=False)
+ class CompactKeySpaceResult(betterproto.Message):
+     """
+     TODO(marko): Do we really need to know all of this? UpdateKeySpaceOp?
+     """
+
+     ks_id: str = betterproto.string_field(1)
+     compacted: "CompactKeySpaceResultCompacted" = betterproto.message_field(
+         2, group="action"
+     )
+     """Key space has been compacted."""
+
+     created: "CompactKeySpaceResultCreated" = betterproto.message_field(
+         3, group="action"
+     )
+     """New output key space has been created."""
+
+     moved: "CompactKeySpaceResultMoved" = betterproto.message_field(4, group="action")
+     """Key space has been promoted to L1."""
+
+     extended: "CompactKeySpaceResultExtended" = betterproto.message_field(
+         5, group="action"
+     )
+     """Key space has been extended with new key files."""
+
+
+ @dataclass(eq=False, repr=False)
+ class CompactKeySpaceResultCompacted(betterproto.Message):
+     pass
+
+
+ @dataclass(eq=False, repr=False)
+ class CompactKeySpaceResultCreated(betterproto.Message):
+     pass
+
+
+ @dataclass(eq=False, repr=False)
+ class CompactKeySpaceResultMoved(betterproto.Message):
+     pass
+
+
+ @dataclass(eq=False, repr=False)
+ class CompactKeySpaceResultExtended(betterproto.Message):
+     pass


  @dataclass(eq=False, repr=False)
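The reworked `CompactKeySpaceOp` now carries a list of per-key-space results, each tagged with one member of the `action` oneof group. A small sketch of building and inspecting one with betterproto, assuming the generated classes above are imported from this module (the `ks_id` values are placeholders):

```
import betterproto

op = CompactKeySpaceOp(
    results=[
        CompactKeySpaceResult(ks_id="ks-123", moved=CompactKeySpaceResultMoved()),
        CompactKeySpaceResult(ks_id="ks-456", extended=CompactKeySpaceResultExtended()),
    ]
)

# betterproto tracks which member of the "action" oneof group is set.
for result in op.results:
    action_name, _ = betterproto.which_one_of(result, "action")
    print(result.ks_id, action_name)  # e.g. "ks-123 moved"
```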
spiral/scan_.py CHANGED
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
      import dask.dataframe as dd
      import pandas as pd
      import polars as pl
+     import pyarrow
+     import pyarrow.dataset
      from datasets import iterable_dataset

  tracer = trace.get_tracer("pyspiral.client.scan")
@@ -23,8 +25,6 @@ def scan(
      where: ExprLike | None = None,
      asof: datetime | int | str = None,
      exclude_keys: bool = False,
-     # TODO(marko): Support config.
-     # config: Config | None = None,
  ) -> "Scan":
      """Starts a read transaction on the spiral.

@@ -33,6 +33,7 @@
          where: a query expression to apply to the data.
          asof: only data written before the given timestamp will be returned, caveats around compaction.
          exclude_keys: whether to exclude the key columns in the scan result, defaults to False.
+             Note that if a projection includes a key column, it will be included in the result.
      """
      from spiral import expressions as se

@@ -58,8 +59,6 @@ class Scan:
      def __init__(
          self,
          scan: TableScan,
-         # TODO(marko): Support config.
-         # config: Config | None = None,
      ):
          # NOTE(ngates): this API is a little weird. e.g. if the query doesn't define an asof, it is resolved
          # when we wrap it into a core.Scan. Should we expose a Query object in the Python API that's reusable
@@ -84,27 +83,57 @@ class Scan:
          """
          return self._scan.is_empty()

-     def to_record_batches(self, key_table: pa.Table | pa.RecordBatchReader | None = None) -> pa.RecordBatchReader:
+     def to_dataset(
+         self,
+         key_table: pa.Table | pa.RecordBatchReader | None = None,
+     ) -> "pyarrow.dataset.Dataset":
+         """Returns a PyArrow Dataset representing the scan.
+
+         Args:
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+         """
+         from .dataset import ScanDataset
+
+         return ScanDataset(self, key_table=key_table)
+
+     def to_record_batches(
+         self,
+         key_table: pa.Table | pa.RecordBatchReader | None = None,
+         batch_size: int | None = None,
+         batch_readahead: int | None = None,
+     ) -> pa.RecordBatchReader:
          """Read as a stream of RecordBatches.

          Args:
              key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+             batch_size: the maximum number of rows per returned batch.
+                 IMPORTANT: This is currently only respected when the key_table is used. If key table is a
+                 RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
+             batch_readahead: the number of batches to prefetch in the background.
          """
          if isinstance(key_table, pa.RecordBatchReader):
-             raise NotImplementedError("RecordBatchReader is not supported as key_table")
+             if batch_size is not None:
+                 raise ValueError(
+                     "batch_size must be None when key_table is a RecordBatchReader, the existing batching is respected."
+                 )
+         elif isinstance(key_table, pa.Table):
+             key_table = key_table.to_reader(max_chunksize=batch_size)

-         # Prefix non-key columns in the key table with # (auxiliary) to avoid conflicts with the scan schema.
-         if key_table is not None:
-             key_columns = list(self._scan.key_schema().to_arrow().names)
-             key_table = key_table.rename_columns(
-                 {name: f"#{name}" if name not in key_columns else name for name in key_table.schema.names}
-             )
+         return self._scan.to_record_batches(key_table=key_table, batch_readahead=batch_readahead)

-         return self._scan.to_record_batches(aux_table=key_table)
+     def to_table(
+         self,
+         key_table: pa.Table | pa.RecordBatchReader | None = None,
+     ) -> pa.Table:
+         """Read into a single PyArrow Table.

-     def to_table(self) -> pa.Table:
-         """Read into a single PyArrow Table."""
-         return self.to_record_batches().read_all()
+         Args:
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+         """
+         return self.to_record_batches(key_table=key_table).read_all()

      def to_dask(self) -> "dd.DataFrame":
          """Read into a Dask DataFrame.
@@ -121,32 +150,54 @@ class Scan:
          # Fetch a set of partition ranges
          return dd.from_map(_read_key_range, self.split())

-     def to_pandas(self) -> "pd.DataFrame":
+     def to_pandas(
+         self,
+         key_table: pa.Table | pa.RecordBatchReader | None = None,
+     ) -> "pd.DataFrame":
          """Read into a Pandas DataFrame.

          Requires the `pandas` package to be installed.
+
+         Args:
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
          """
-         return self.to_table().to_pandas()
+         return self.to_table(key_table=key_table).to_pandas()

-     def to_polars(self) -> "pl.DataFrame":
-         """Read into a Polars DataFrame.
+     def to_polars(self, key_table: pa.Table | pa.RecordBatchReader | None = None) -> "pl.LazyFrame":
+         """Read into a Polars LazyFrame.

          Requires the `polars` package to be installed.
+
+         Args:
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
          """
          import polars as pl

-         # TODO(ngates): PR PyArrow to support lazy datasets
-         return pl.from_arrow(self.to_record_batches())
+         return pl.scan_pyarrow_dataset(self.to_dataset(key_table=key_table))

-     def to_pytorch(self) -> "iterable_dataset.IterableDataset":
+     def to_pytorch(
+         self,
+         key_table: pa.Table | pa.RecordBatchReader | None = None,
+         batch_readahead: int | None = None,
+     ) -> "iterable_dataset.IterableDataset":
          """Returns an iterable dataset that can be used to build a `pytorch.DataLoader`.

          Requires the `datasets` package to be installed.
+
+         Args:
+             key_table: a table of keys to "take" (including aux columns for cell-push-down).
+                 If None, the scan will be executed without a key table.
+             batch_readahead: the number of batches to prefetch in the background.
          """
          from datasets.iterable_dataset import ArrowExamplesIterable, IterableDataset

          def _generate_tables(**kwargs) -> Iterator[tuple[int, pa.Table]]:
-             stream = self.to_record_batches()
+             # Use batch size 1 when iterating samples, unless batch reader is already used.
+             stream = self.to_record_batches(
+                 key_table, batch_size=1 if isinstance(key_table, pa.Table) else None, batch_readahead=batch_readahead
+             )

              # This key is unused when training with IterableDataset.
              # Default implementation returns shard id, e.g. parquet row group id.
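For reference, the reworked `Scan` API threads `key_table` through every reader and routes Polars through the new dataset wrapper. A minimal usage sketch, assuming `sp_scan` is a `Scan` returned by `scan(...)` and that `"key"` matches the table's key schema (both placeholders):

```
import pyarrow as pa

# Hypothetical key table; extra non-key columns act as auxiliary
# (cell-push-down) columns.
keys = pa.table({"key": [1, 2, 3]})

# Stream record batches, re-chunking the key table into 512-row batches.
reader = sp_scan.to_record_batches(key_table=keys, batch_size=512, batch_readahead=2)

# to_polars() now builds on the PyArrow Dataset wrapper and returns a
# LazyFrame, so collect() is needed to materialize the result.
df = sp_scan.to_polars(key_table=keys).collect()
```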
spiral/settings.py CHANGED
@@ -29,6 +29,8 @@ from spiral.authn.github_ import GitHubActionsProvider
  from spiral.authn.modal_ import ModalProvider

  DEV = "PYTEST_VERSION" in os.environ or bool(os.environ.get("SPIRAL_DEV", None))
+ FILE_FORMAT = os.environ.get("SPIRAL_FILE_FORMAT", "parquet")
+
  APP_DIR = Path(typer.get_app_dir("pyspiral"))
  LOG_DIR = APP_DIR / "logs"
  CONFIG_FILE = APP_DIR / "config.toml"
@@ -67,6 +69,10 @@ class SpiralDBSettings(BaseSettings):
          # TODO(marko): Scandal will be a different service. For now, gRPC API is hosted on the SpiralDB service.
          return f"{'grpc+tls' if self.ssl else 'grpc'}://{self.host}:{self.port}"

+     @property
+     def uri_iceberg(self) -> str:
+         return self.uri + "/iceberg"
+
      def device_auth(self) -> DeviceAuth:
          auth_file = (
              APP_DIR / hashlib.md5(f"{self.auth.domain}/{self.auth.client_id}".encode()).hexdigest() / "auth.json"
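A small sketch of the new settings surface: `FILE_FORMAT` is read from `SPIRAL_FILE_FORMAT` once at import time (defaulting to "parquet"), and `uri_iceberg` derives from the existing `uri` property. This assumes `SpiralDBSettings` accepts `host`, `port`, and `ssl` as constructor fields; the values below are placeholders.

```
import os

# Unset, SPIRAL_FILE_FORMAT falls back to "parquet" at import time.
print(os.environ.get("SPIRAL_FILE_FORMAT", "parquet"))

settings = SpiralDBSettings(host="spiraldb.example.com", port=443, ssl=True)
print(settings.uri)          # grpc+tls://spiraldb.example.com:443
print(settings.uri_iceberg)  # grpc+tls://spiraldb.example.com:443/iceberg
```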