pyspiral 0.7.18__cp312-abi3-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. pyspiral-0.7.18.dist-info/METADATA +52 -0
  2. pyspiral-0.7.18.dist-info/RECORD +110 -0
  3. pyspiral-0.7.18.dist-info/WHEEL +4 -0
  4. pyspiral-0.7.18.dist-info/entry_points.txt +3 -0
  5. spiral/__init__.py +55 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +411 -0
  8. spiral/api/__init__.py +78 -0
  9. spiral/api/admin.py +15 -0
  10. spiral/api/client.py +164 -0
  11. spiral/api/filesystems.py +134 -0
  12. spiral/api/key_space_indexes.py +23 -0
  13. spiral/api/organizations.py +77 -0
  14. spiral/api/projects.py +219 -0
  15. spiral/api/telemetry.py +19 -0
  16. spiral/api/text_indexes.py +56 -0
  17. spiral/api/types.py +23 -0
  18. spiral/api/workers.py +40 -0
  19. spiral/api/workloads.py +52 -0
  20. spiral/arrow_.py +216 -0
  21. spiral/cli/__init__.py +88 -0
  22. spiral/cli/__main__.py +4 -0
  23. spiral/cli/admin.py +14 -0
  24. spiral/cli/app.py +108 -0
  25. spiral/cli/console.py +95 -0
  26. spiral/cli/fs.py +76 -0
  27. spiral/cli/iceberg.py +97 -0
  28. spiral/cli/key_spaces.py +103 -0
  29. spiral/cli/login.py +25 -0
  30. spiral/cli/orgs.py +90 -0
  31. spiral/cli/printer.py +53 -0
  32. spiral/cli/projects.py +147 -0
  33. spiral/cli/state.py +7 -0
  34. spiral/cli/tables.py +197 -0
  35. spiral/cli/telemetry.py +17 -0
  36. spiral/cli/text.py +115 -0
  37. spiral/cli/types.py +50 -0
  38. spiral/cli/workloads.py +58 -0
  39. spiral/client.py +256 -0
  40. spiral/core/__init__.pyi +0 -0
  41. spiral/core/_tools/__init__.pyi +5 -0
  42. spiral/core/authn/__init__.pyi +21 -0
  43. spiral/core/client/__init__.pyi +285 -0
  44. spiral/core/config/__init__.pyi +35 -0
  45. spiral/core/expr/__init__.pyi +15 -0
  46. spiral/core/expr/images/__init__.pyi +3 -0
  47. spiral/core/expr/list_/__init__.pyi +4 -0
  48. spiral/core/expr/refs/__init__.pyi +4 -0
  49. spiral/core/expr/str_/__init__.pyi +3 -0
  50. spiral/core/expr/struct_/__init__.pyi +6 -0
  51. spiral/core/expr/text/__init__.pyi +5 -0
  52. spiral/core/expr/udf/__init__.pyi +14 -0
  53. spiral/core/expr/video/__init__.pyi +3 -0
  54. spiral/core/table/__init__.pyi +141 -0
  55. spiral/core/table/manifests/__init__.pyi +35 -0
  56. spiral/core/table/metastore/__init__.pyi +58 -0
  57. spiral/core/table/spec/__init__.pyi +215 -0
  58. spiral/dataloader.py +299 -0
  59. spiral/dataset.py +264 -0
  60. spiral/datetime_.py +27 -0
  61. spiral/debug/__init__.py +0 -0
  62. spiral/debug/manifests.py +87 -0
  63. spiral/debug/metrics.py +56 -0
  64. spiral/debug/scan.py +266 -0
  65. spiral/enrichment.py +306 -0
  66. spiral/expressions/__init__.py +274 -0
  67. spiral/expressions/base.py +167 -0
  68. spiral/expressions/file.py +17 -0
  69. spiral/expressions/http.py +17 -0
  70. spiral/expressions/list_.py +68 -0
  71. spiral/expressions/s3.py +16 -0
  72. spiral/expressions/str_.py +39 -0
  73. spiral/expressions/struct.py +59 -0
  74. spiral/expressions/text.py +62 -0
  75. spiral/expressions/tiff.py +222 -0
  76. spiral/expressions/udf.py +60 -0
  77. spiral/grpc_.py +32 -0
  78. spiral/iceberg.py +31 -0
  79. spiral/iterable_dataset.py +106 -0
  80. spiral/key_space_index.py +44 -0
  81. spiral/project.py +227 -0
  82. spiral/protogen/_/__init__.py +0 -0
  83. spiral/protogen/_/arrow/__init__.py +0 -0
  84. spiral/protogen/_/arrow/flight/__init__.py +0 -0
  85. spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
  86. spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
  87. spiral/protogen/_/google/__init__.py +0 -0
  88. spiral/protogen/_/google/protobuf/__init__.py +2310 -0
  89. spiral/protogen/_/message_pool.py +3 -0
  90. spiral/protogen/_/py.typed +0 -0
  91. spiral/protogen/_/scandal/__init__.py +190 -0
  92. spiral/protogen/_/spfs/__init__.py +72 -0
  93. spiral/protogen/_/spql/__init__.py +61 -0
  94. spiral/protogen/_/substrait/__init__.py +6196 -0
  95. spiral/protogen/_/substrait/extensions/__init__.py +169 -0
  96. spiral/protogen/__init__.py +0 -0
  97. spiral/protogen/util.py +41 -0
  98. spiral/py.typed +0 -0
  99. spiral/scan.py +363 -0
  100. spiral/server.py +17 -0
  101. spiral/settings.py +36 -0
  102. spiral/snapshot.py +56 -0
  103. spiral/streaming_/__init__.py +3 -0
  104. spiral/streaming_/reader.py +133 -0
  105. spiral/streaming_/stream.py +157 -0
  106. spiral/substrait_.py +274 -0
  107. spiral/table.py +224 -0
  108. spiral/text_index.py +17 -0
  109. spiral/transaction.py +155 -0
  110. spiral/types_.py +6 -0
spiral/expressions/text.py ADDED
@@ -0,0 +1,62 @@
+ from spiral.expressions.base import Expr, ExprLike
+
+
+ def field(expr: ExprLike, field_name: str | None = None, tokenizer: str | None = None) -> Expr:
+     """Configure a column for text indexing.
+
+     Args:
+         expr: An input column. The expression must evaluate to a UTF-8 string
+             or, if a `field_name` is provided, to a struct with a field of that name.
+         field_name: If provided, the expression must evaluate to a struct with a field of that name.
+             The given field will be indexed.
+         tokenizer: If provided, the text will be tokenized using the given tokenizer.
+
+     Returns:
+         An expression that can be used to construct a text index.
+     """
+     from spiral import _lib
+     from spiral.expressions import getitem, lift, merge, pack
+
+     expr = lift(expr)
+     if field_name is None:
+         return Expr(_lib.expr.text.field(expr.__expr__, tokenizer))
+
+     child = _lib.expr.text.field(getitem(expr, field_name).__expr__, tokenizer)
+     return merge(
+         expr,
+         pack({field_name: child}),
+     )
+
+
+ def find(expr: ExprLike, term: str) -> Expr:
+     """Search for a term in the text.
+
+     Args:
+         expr: An index field.
+         term: The term to search for.
+
+     Returns:
+         An expression that can be used in ranking for text search.
+     """
+     from spiral import _lib
+     from spiral.expressions import lift
+
+     expr = lift(expr)
+     return Expr(_lib.expr.text.find(expr.__expr__, term))
+
+
+ def boost(expr: ExprLike, factor: float) -> Expr:
+     """Boost the relevance of a ranking expression.
+
+     Args:
+         expr: The ranking expression to boost.
+         factor: The factor by which to boost the relevance.
+
+     Returns:
+         An expression that can be used in ranking for text search.
+     """
+     from spiral import _lib
+     from spiral.expressions import lift
+
+     expr = lift(expr)
+     return Expr(_lib.expr.text.boost(expr.__expr__, factor))
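A minimal usage sketch for the helpers above. The `table["docs"]` column access, the `"default"` tokenizer name, and the final index-registration step are assumptions; this diff only shows the expression builders.

```python
from spiral.expressions import text

# Index the "body" field of the struct column "docs" (tokenizer name assumed).
indexed = text.field(table["docs"], field_name="body", tokenizer="default")

# Rank matches for the term "spiral", weighting this field 2x in the score.
rank = text.boost(text.find(indexed, "spiral"), 2.0)
```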
spiral/expressions/tiff.py ADDED
@@ -0,0 +1,222 @@
+ import numpy as np
+ import pyarrow as pa
+
+ from spiral.expressions.base import Expr, ExprLike
+
+ _TIFF_RES_DTYPE: pa.DataType = pa.struct(
+     [
+         pa.field("pixels", pa.large_binary()),
+         pa.field("height", pa.uint32()),
+         pa.field("width", pa.uint32()),
+         pa.field("channels", pa.uint8()),
+         pa.field("channel_bit_depth", pa.uint8()),
+     ]
+ )
+
+
+ def read(
+     expr: ExprLike,
+     indexes: ExprLike | int | None = None,
+     window: ExprLike | tuple[tuple[int, int], tuple[int, int]] | None = None,
+     boundless: ExprLike | bool | None = None,
+ ) -> Expr:
+     """
+     Read the referenced cell as a `TIFF` image. Requires `rasterio` to be installed.
+
+     Args:
+         expr: The referenced `TIFF` bytes.
+         indexes: The band indexes to read. Defaults to all.
+         window: The window to read, in the format (row_range_tuple, col_range_tuple). Defaults to the full window.
+         boundless: If `True`, windows that extend beyond the dataset's extent
+             are permitted, and partially or completely filled arrays are returned as appropriate.
+
+     Returns:
+         An array where each element is a decoded image with fields:
+             pixels: Bytes of shape (channels, height, width).
+             height: Height of the image with type `pa.uint32()`.
+             width: Width of the image with type `pa.uint32()`.
+             channels: Number of channels of the image with type `pa.uint8()`.
+                 If `indexes` is not None, this is the length of `indexes`, or 1 if `indexes` is an int.
+             channel_bit_depth: Bit depth of each channel with type `pa.uint8()`.
+     """
+     try:
+         import rasterio  # noqa: F401
+     except ImportError:
+         raise ImportError("`rasterio` is required for tiff.read")
+
+     return TiffReadUDF()(expr, indexes, window, boundless)
+
+
+ def select(
+     expr: ExprLike,
+     shape: ExprLike | dict,
+     indexes: ExprLike | int | None = None,
+ ) -> Expr:
+     """
+     Select a shape out of the referenced `TIFF` cell. Requires `rasterio` to be installed.
+
+     Args:
+         expr: The referenced `TIFF` bytes.
+         shape: A [GeoJSON-like](https://geojson.org/) shape.
+         indexes: The band indexes to read. Defaults to all.
+
+     Returns:
+         An array where each element is a decoded image with fields:
+             pixels: Bytes of shape (len(indexes) or 1, height, width).
+             height: Height of the image with type `pa.uint32()`.
+             width: Width of the image with type `pa.uint32()`.
+             channels: Number of channels of the image with type `pa.uint8()`.
+                 If `indexes` is not None, this is the length of `indexes`, or 1 if `indexes` is an int.
+             channel_bit_depth: Bit depth of each channel with type `pa.uint8()`.
+     """
+     try:
+         import rasterio  # noqa: F401
+     except ImportError:
+         raise ImportError("`rasterio` is required for tiff.select")
+
+     return TiffSelectUDF()(expr, shape, indexes)
+
+
+ class TiffReadUDF:
+     # NOTE: The ref-aware UDF base class (whose __init__ takes the op name and whose
+     # invoke receives the opened file handle `fp`) is not shown in this diff.
+     def __init__(self):
+         super().__init__("tiff.read")
+
+     def return_type(self, *input_types: pa.DataType) -> pa.DataType:
+         return _TIFF_RES_DTYPE
+
+     def invoke(self, fp, *input_args: pa.Array) -> pa.Array:
+         try:
+             import rasterio
+         except ImportError:
+             raise ImportError("`rasterio` is required for tiff.read")
+
+         from rasterio.windows import Window
+
+         if len(input_args) != 4:
+             raise ValueError("tiff.read expects exactly 4 arguments: expr, indexes, window, boundless")
+
+         _, indexes, window, boundless = input_args
+
+         indexes = indexes[0].as_py()
+         if indexes is not None and not isinstance(indexes, (int, list)):
+             raise ValueError(f"tiff.read expects indexes to be None, an int, or a list, got {indexes}")
+
+         boundless = boundless[0].as_py()
+         if boundless is not None and not isinstance(boundless, bool):
+             raise ValueError(f"tiff.read expects boundless to be None or a bool, got {boundless}")
+
+         window = window[0].as_py()
+         if window is not None:
+             if len(window) != 2:
+                 raise ValueError(f"tiff.read window invalid, got {window}")
+             window = Window.from_slices(slice(*window[0]), slice(*window[1]), boundless=boundless or False)
+
+         opener = _VsiOpener(fp)
+         with rasterio.open("ref", opener=opener) as src:
+             src: rasterio.DatasetReader
+             # TODO(marko): We know the size and dtype so we should be able to preallocate the result and read
+             # into it. This matters more if we want to rewrite this function to work with multiple inputs at
+             # once, in which case we should first consider using Rust GDAL bindings - I believe rasterio uses
+             # GDAL under the hood.
+             result: np.ndarray = src.read(indexes=indexes, window=window)
+             return _return_result(result, indexes)
+
+
+ class TiffSelectUDF:
+     # NOTE: See TiffReadUDF - the ref-aware UDF base class is not shown in this diff.
+     def __init__(self):
+         super().__init__("tiff.select")
+
+     def return_type(self, *input_types: pa.DataType) -> pa.DataType:
+         return _TIFF_RES_DTYPE
+
+     def invoke(self, fp, *input_args: pa.Array) -> pa.Array:
+         try:
+             import rasterio
+         except ImportError:
+             raise ImportError("`rasterio` is required for tiff.select")
+
+         from rasterio.mask import raster_geometry_mask
+
+         if len(input_args) != 3:
+             raise ValueError("tiff.select expects exactly 3 arguments: expr, shape, indexes")
+
+         _, shape, indexes = input_args
+
+         shape = shape[0].as_py()
+         if shape is None:
+             raise ValueError("tiff.select expects shape to be a GeoJSON-like shape")
+
+         indexes = indexes[0].as_py()
+         if indexes is not None and not isinstance(indexes, (int, list)):
+             raise ValueError(f"tiff.select expects indexes to be None, an int, or a list, got {indexes}")
+
+         opener = _VsiOpener(fp)
+         with rasterio.open("ref", opener=opener) as src:
+             src: rasterio.DatasetReader
+
+             shape_mask, _, window = raster_geometry_mask(src, [shape], crop=True)
+             out_shape = (src.count,) + shape_mask.shape
+
+             result: np.ndarray = src.read(window=window, indexes=indexes, out_shape=out_shape, masked=True)
+             return _return_result(result, indexes)
+
+
+ def _return_result(result: np.ndarray, indexes) -> pa.Array:
+     channels = result.shape[0]
+     if indexes is None:
+         pass
+     elif isinstance(indexes, int):
+         assert channels == 1, f"Expected 1 channel, got {channels}"
+     else:
+         assert channels == len(indexes), f"Expected {len(indexes)} channels, got {channels}"
+
+     if result.dtype == np.uint8:
+         channel_bit_depth = 8
+     elif result.dtype == np.uint16:
+         channel_bit_depth = 16
+     else:
+         raise ValueError(f"Unsupported bit width: {result.dtype}")
+
+     return pa.array(
+         [
+             {
+                 "pixels": result.tobytes(),
+                 "height": result.shape[1],
+                 "width": result.shape[2],
+                 "channels": channels,
+                 "channel_bit_depth": channel_bit_depth,
+             }
+         ],
+         type=_TIFF_RES_DTYPE,
+     )
+
+
+ class _VsiOpener:
+     """
+     VSI file opener which returns a constant file-like on open.
+
+     Must match the https://rasterio.readthedocs.io/en/stable/topics/vsi.html#python-file-and-filesystem-openers
+     spec, but only `open` is needed when going through rasterio.
+     """
+
+     def __init__(self, file_like):
+         self._file_like = file_like
+
+     def open(self, _path, mode):
+         if mode not in {"r", "rb"}:
+             raise ValueError(f"Unsupported mode: {mode}")
+         return self._file_like
+
+     def isdir(self, _):
+         return False
+
+     def isfile(self, _):
+         return False
+
+     def mtime(self, _):
+         return 0
+
+     def size(self, _):
+         return self._file_like.size()
+
+     def modified(self, _):
+         raise NotImplementedError
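A hedged sketch of `tiff.read`: `tiff_refs` stands in for an expression over a column of TIFF file references, which this diff does not show how to construct.

```python
from spiral.expressions import tiff

# Read band 1 of a 256x256 window anchored at the top-left corner.
patch = tiff.read(
    tiff_refs,                    # assumed: an expression over TIFF refs
    indexes=1,                    # single band -> resulting channels == 1
    window=((0, 256), (0, 256)),  # (row_range, col_range)
    boundless=True,               # pad windows that run past the dataset edge
)
```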
spiral/expressions/udf.py ADDED
@@ -0,0 +1,60 @@
+ import abc
+
+ import pyarrow as pa
+
+ from spiral import _lib
+ from spiral.expressions.base import Expr, ExprLike
+
+
+ class UDF(abc.ABC):
+     """A User-Defined Function (UDF). This class should be subclassed to define custom UDFs.
+
+     Example:
+
+     ```python
+     from spiral import expressions as se
+     import pyarrow as pa
+
+     class MyAdd(se.UDF):
+         def __init__(self):
+             super().__init__("my_add")
+
+         def return_type(self, scope: pa.DataType):
+             if not isinstance(scope, pa.StructType):
+                 raise ValueError("Expected struct type as input")
+             return scope.field(0).type
+
+         def invoke(self, scope: pa.Array):
+             if not isinstance(scope, pa.StructArray):
+                 raise ValueError("Expected struct array as input")
+             return pa.compute.add(scope.field(0), scope.field(1))
+
+     my_add = MyAdd()
+
+     expr = my_add(table.select("first_arg", "second_arg"))
+     ```
+     """
+
+     def __init__(self, name: str):
+         self._udf = _lib.expr.udf.create(name, return_type=self.return_type, invoke=self.invoke)
+
+     def __call__(self, scope: ExprLike) -> Expr:
+         """Create an expression that calls this UDF with the given arguments."""
+         from spiral import expressions as se
+
+         return Expr(self._udf(se.lift(scope).__expr__))
+
+     @abc.abstractmethod
+     def return_type(self, scope: pa.DataType) -> pa.DataType:
+         """Must return the return type of the UDF given the input scope type.
+
+         All expressions in Spiral must return nullable (Arrow default) types,
+         including nested structs, meaning that all fields in structs must also be nullable,
+         and if those fields are structs, their fields must also be nullable, and so on.
+         """
+         ...
+
+     @abc.abstractmethod
+     def invoke(self, scope: pa.Array) -> pa.Array:
+         """Must implement the UDF logic given the input scope array."""
+         ...
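To make the scope convention concrete, a standalone check of what the `MyAdd` example above would receive and return when invoked, using plain Arrow data (no Spiral evaluation involved):

```python
import pyarrow as pa
import pyarrow.compute as pc

# The scope arrives as a struct array; invoke returns one array of the declared return type.
scope = pa.StructArray.from_arrays(
    [pa.array([1, 2]), pa.array([10, 20])],
    names=["first_arg", "second_arg"],
)
assert pc.add(scope.field(0), scope.field(1)).to_pylist() == [11, 22]
```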
spiral/grpc_.py ADDED
@@ -0,0 +1,32 @@
+ from collections.abc import AsyncIterator, Awaitable, Callable
+ from typing import TypeVar
+
+ R = TypeVar("R")
+ T = TypeVar("T")
+
+
+ async def paged(stub_fn: Callable[[R], Awaitable[T]], request: R, page_size: int | None = None) -> AsyncIterator[T]:
+     """Page through a gRPC paged API.
+
+     Assumes fields exist as per https://cloud.google.com/apis/design/design_patterns#list_pagination
+     """
+     next_page_token: str | None = None
+     while True:
+         request.page_size = page_size
+         request.page_token = next_page_token
+         res = await stub_fn(request)
+         if not res.next_page_token:
+             # No more items.
+             yield res
+             break
+
+         next_page_token = res.next_page_token
+         yield res
+
+
+ async def paged_items(
+     stub_fn: Callable[[R], Awaitable[T]], request: R, collection_name: str, page_size: int | None = None
+ ) -> AsyncIterator:
+     async for page in paged(stub_fn, request, page_size=page_size):
+         for item in getattr(page, collection_name):
+             yield item
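A sketch of consuming `paged_items`. The stub method, request message, and `"projects"` collection field are hypothetical, following the Google list-pagination shape the docstring cites:

```python
async def list_all_projects(stub) -> list:
    # ListProjectsRequest is a hypothetical paged request message.
    request = ListProjectsRequest()
    return [p async for p in paged_items(stub.ListProjects, request, "projects", page_size=100)]
```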
spiral/iceberg.py ADDED
@@ -0,0 +1,31 @@
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from pyiceberg.catalog import Catalog
+
+     from spiral.client import Spiral
+
+
+ class Iceberg:
+     """
+     Apache Iceberg is a powerful open-source table format designed for high-performance data lakes.
+     Iceberg brings reliability, scalability, and advanced features like time travel, schema evolution,
+     and ACID transactions to your warehouse.
+     """
+
+     def __init__(self, spiral: "Spiral"):
+         self._spiral = spiral
+         self._api = self._spiral.api
+
+     def catalog(self) -> "Catalog":
+         """Open the Iceberg catalog."""
+         from pyiceberg.catalog import load_catalog
+
+         return load_catalog(
+             "default",
+             **{
+                 "type": "rest",
+                 "uri": self._spiral.config.server_url + "/iceberg",
+                 "token": self._spiral.authn.token().expose_secret(),
+             },
+         )
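A minimal sketch, assuming an authenticated client exposes this helper as `sp.iceberg`; the returned object is a standard pyiceberg `Catalog` from there on:

```python
catalog = sp.iceberg.catalog()    # REST catalog served by the Spiral server
print(catalog.list_namespaces())  # standard pyiceberg Catalog API
```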
spiral/iterable_dataset.py ADDED
@@ -0,0 +1,106 @@
+ from collections.abc import Callable, Iterator
+ from typing import TYPE_CHECKING
+
+ import pyarrow as pa
+
+ if TYPE_CHECKING:
+     import datasets.iterable_dataset as hf  # noqa
+     import streaming  # noqa
+     import torch.utils.data as torchdata  # noqa
+
+
+ def _hf_compatible_schema(schema: pa.Schema) -> pa.Schema:
+     """
+     Replace string-view and binary-view columns in the schema with string/binary columns.
+     Recursively handles nested types (struct, list, etc.).
+     We use this converted schema as Features in the returned Dataset.
+     Remove this method once we have https://github.com/huggingface/datasets/pull/7718
+     """
+
+     def _convert_type(dtype: pa.DataType) -> pa.DataType:
+         if dtype == pa.string_view():
+             return pa.string()
+         elif dtype == pa.binary_view():
+             return pa.binary()
+         elif pa.types.is_struct(dtype):
+             new_fields = [
+                 pa.field(field.name, _convert_type(field.type), nullable=field.nullable, metadata=field.metadata)
+                 for field in dtype
+             ]
+             return pa.struct(new_fields)
+         elif pa.types.is_list(dtype):
+             return pa.list_(_convert_type(dtype.value_type))
+         elif pa.types.is_large_list(dtype):
+             return pa.large_list(_convert_type(dtype.value_type))
+         elif pa.types.is_fixed_size_list(dtype):
+             return pa.list_(_convert_type(dtype.value_type), dtype.list_size)
+         elif pa.types.is_map(dtype):
+             return pa.map_(_convert_type(dtype.key_type), _convert_type(dtype.item_type))
+         else:
+             return dtype
+
+     new_fields = []
+     for field in schema:
+         new_type = _convert_type(field.type)
+         new_fields.append(pa.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata))
+
+     return pa.schema(new_fields)
+
+
+ def to_iterable_dataset(stream: pa.RecordBatchReader) -> "hf.IterableDataset":
+     from datasets import DatasetInfo, Features
+     from datasets.builder import ArrowExamplesIterable
+     from datasets.iterable_dataset import IterableDataset
+
+     def _generate_tables(**kwargs) -> Iterator[tuple[int, pa.Table]]:
+         # The key is unused when training with IterableDataset. The default
+         # implementation returns the shard id, e.g. the Parquet row group id.
+         for i, rb in enumerate(stream):
+             yield i, pa.Table.from_batches([rb], stream.schema)
+
+     # TODO(marko): This is temporary until we stop returning IterableDataset from this function.
+     class _IterableDataset(IterableDataset):
+         # Diff with datasets.iterable_dataset.IterableDataset:
+         # - Removes torch handling which attempts to handle worker processes.
+         # - Assumes an arrow iterator.
+         def __iter__(self):
+             from datasets.formatting import get_formatter
+
+             prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
+             if self._formatting and (prepared_ex_iterable.iter_arrow or self._formatting.is_table):
+                 formatter = get_formatter(self._formatting.format_type, features=self.features)
+                 iterator = prepared_ex_iterable.iter_arrow()
+                 for key, pa_table in iterator:
+                     yield formatter.format_row(pa_table)
+                 return
+
+             for key, example in prepared_ex_iterable:
+                 # No need to format thanks to FormattedExamplesIterable.
+                 yield example
+
+         def map(self, *args, **kwargs):
+             # Map constructs a new IterableDataset, so we need to "patch" it.
+             base = super().map(*args, **kwargs)
+             if isinstance(base, IterableDataset):
+                 # Patch __iter__ to avoid torch handling.
+                 base.__class__ = _IterableDataset  # type: ignore
+             return base
+
+     class _ArrowExamplesIterable(ArrowExamplesIterable):
+         def __init__(self, generate_tables_fn: Callable[..., Iterator[tuple[int, pa.Table]]], features: Features):
+             # NOTE: generate_tables_fn type annotations are wrong; the return type must be an iterable of tuples.
+             super().__init__(generate_tables_fn, kwargs={})  # type: ignore
+             self._features = features
+
+         @property
+         def is_typed(self) -> bool:
+             return True
+
+         @property
+         def features(self) -> Features:
+             return self._features
+
+     target_features = Features.from_arrow_schema(_hf_compatible_schema(stream.schema))
+     ex_iterable = _ArrowExamplesIterable(_generate_tables, target_features)
+     info = DatasetInfo(features=target_features)
+     return _IterableDataset(ex_iterable=ex_iterable, info=info)
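A runnable sketch of `to_iterable_dataset` fed from an in-memory reader (in practice the `RecordBatchReader` comes from a Spiral scan); requires the `datasets` package:

```python
import pyarrow as pa

from spiral.iterable_dataset import to_iterable_dataset

batch = pa.RecordBatch.from_pydict({"text": ["a", "b"]})
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

ds = to_iterable_dataset(reader)
for example in ds:
    print(example)  # {'text': 'a'}, then {'text': 'b'}
```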
spiral/key_space_index.py ADDED
@@ -0,0 +1,44 @@
+ from spiral.core.client import KeySpaceIndex as CoreKeySpaceIndex
+ from spiral.expressions import Expr
+ from spiral.types_ import Timestamp
+
+
+ class KeySpaceIndex:
+     """
+     KeySpaceIndex represents an optionally materialized key space, defined by a projection and a filter over a table.
+     It can be used to efficiently and precisely shard the table for parallel processing or distributed training.
+
+     An index is defined by:
+     - A granularity that defines the target size of key ranges in the index.
+       IMPORTANT: Actual key ranges may be smaller, but will not exceed twice the granularity.
+     - A projection expression that defines which columns are included in the resulting key space.
+     - An optional filter expression that defines which rows are included in the index.
+     """
+
+     def __init__(self, core: CoreKeySpaceIndex, *, name: str | None = None):
+         self.core = core
+         self._name = name
+
+     @property
+     def index_id(self) -> str:
+         return self.core.id
+
+     @property
+     def table_id(self) -> str:
+         return self.core.table_id
+
+     @property
+     def name(self) -> str:
+         return self._name or self.index_id
+
+     @property
+     def asof(self) -> Timestamp:
+         return self.core.asof
+
+     @property
+     def projection(self) -> Expr:
+         return Expr(self.core.projection)
+
+     @property
+     def filter(self) -> Expr | None:
+         return Expr(self.core.filter) if self.core.filter is not None else None