pyspiral 0.7.18__cp312-abi3-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. pyspiral-0.7.18.dist-info/METADATA +52 -0
  2. pyspiral-0.7.18.dist-info/RECORD +110 -0
  3. pyspiral-0.7.18.dist-info/WHEEL +4 -0
  4. pyspiral-0.7.18.dist-info/entry_points.txt +3 -0
  5. spiral/__init__.py +55 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +411 -0
  8. spiral/api/__init__.py +78 -0
  9. spiral/api/admin.py +15 -0
  10. spiral/api/client.py +164 -0
  11. spiral/api/filesystems.py +134 -0
  12. spiral/api/key_space_indexes.py +23 -0
  13. spiral/api/organizations.py +77 -0
  14. spiral/api/projects.py +219 -0
  15. spiral/api/telemetry.py +19 -0
  16. spiral/api/text_indexes.py +56 -0
  17. spiral/api/types.py +23 -0
  18. spiral/api/workers.py +40 -0
  19. spiral/api/workloads.py +52 -0
  20. spiral/arrow_.py +216 -0
  21. spiral/cli/__init__.py +88 -0
  22. spiral/cli/__main__.py +4 -0
  23. spiral/cli/admin.py +14 -0
  24. spiral/cli/app.py +108 -0
  25. spiral/cli/console.py +95 -0
  26. spiral/cli/fs.py +76 -0
  27. spiral/cli/iceberg.py +97 -0
  28. spiral/cli/key_spaces.py +103 -0
  29. spiral/cli/login.py +25 -0
  30. spiral/cli/orgs.py +90 -0
  31. spiral/cli/printer.py +53 -0
  32. spiral/cli/projects.py +147 -0
  33. spiral/cli/state.py +7 -0
  34. spiral/cli/tables.py +197 -0
  35. spiral/cli/telemetry.py +17 -0
  36. spiral/cli/text.py +115 -0
  37. spiral/cli/types.py +50 -0
  38. spiral/cli/workloads.py +58 -0
  39. spiral/client.py +256 -0
  40. spiral/core/__init__.pyi +0 -0
  41. spiral/core/_tools/__init__.pyi +5 -0
  42. spiral/core/authn/__init__.pyi +21 -0
  43. spiral/core/client/__init__.pyi +285 -0
  44. spiral/core/config/__init__.pyi +35 -0
  45. spiral/core/expr/__init__.pyi +15 -0
  46. spiral/core/expr/images/__init__.pyi +3 -0
  47. spiral/core/expr/list_/__init__.pyi +4 -0
  48. spiral/core/expr/refs/__init__.pyi +4 -0
  49. spiral/core/expr/str_/__init__.pyi +3 -0
  50. spiral/core/expr/struct_/__init__.pyi +6 -0
  51. spiral/core/expr/text/__init__.pyi +5 -0
  52. spiral/core/expr/udf/__init__.pyi +14 -0
  53. spiral/core/expr/video/__init__.pyi +3 -0
  54. spiral/core/table/__init__.pyi +141 -0
  55. spiral/core/table/manifests/__init__.pyi +35 -0
  56. spiral/core/table/metastore/__init__.pyi +58 -0
  57. spiral/core/table/spec/__init__.pyi +215 -0
  58. spiral/dataloader.py +299 -0
  59. spiral/dataset.py +264 -0
  60. spiral/datetime_.py +27 -0
  61. spiral/debug/__init__.py +0 -0
  62. spiral/debug/manifests.py +87 -0
  63. spiral/debug/metrics.py +56 -0
  64. spiral/debug/scan.py +266 -0
  65. spiral/enrichment.py +306 -0
  66. spiral/expressions/__init__.py +274 -0
  67. spiral/expressions/base.py +167 -0
  68. spiral/expressions/file.py +17 -0
  69. spiral/expressions/http.py +17 -0
  70. spiral/expressions/list_.py +68 -0
  71. spiral/expressions/s3.py +16 -0
  72. spiral/expressions/str_.py +39 -0
  73. spiral/expressions/struct.py +59 -0
  74. spiral/expressions/text.py +62 -0
  75. spiral/expressions/tiff.py +222 -0
  76. spiral/expressions/udf.py +60 -0
  77. spiral/grpc_.py +32 -0
  78. spiral/iceberg.py +31 -0
  79. spiral/iterable_dataset.py +106 -0
  80. spiral/key_space_index.py +44 -0
  81. spiral/project.py +227 -0
  82. spiral/protogen/_/__init__.py +0 -0
  83. spiral/protogen/_/arrow/__init__.py +0 -0
  84. spiral/protogen/_/arrow/flight/__init__.py +0 -0
  85. spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
  86. spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
  87. spiral/protogen/_/google/__init__.py +0 -0
  88. spiral/protogen/_/google/protobuf/__init__.py +2310 -0
  89. spiral/protogen/_/message_pool.py +3 -0
  90. spiral/protogen/_/py.typed +0 -0
  91. spiral/protogen/_/scandal/__init__.py +190 -0
  92. spiral/protogen/_/spfs/__init__.py +72 -0
  93. spiral/protogen/_/spql/__init__.py +61 -0
  94. spiral/protogen/_/substrait/__init__.py +6196 -0
  95. spiral/protogen/_/substrait/extensions/__init__.py +169 -0
  96. spiral/protogen/__init__.py +0 -0
  97. spiral/protogen/util.py +41 -0
  98. spiral/py.typed +0 -0
  99. spiral/scan.py +363 -0
  100. spiral/server.py +17 -0
  101. spiral/settings.py +36 -0
  102. spiral/snapshot.py +56 -0
  103. spiral/streaming_/__init__.py +3 -0
  104. spiral/streaming_/reader.py +133 -0
  105. spiral/streaming_/stream.py +157 -0
  106. spiral/substrait_.py +274 -0
  107. spiral/table.py +224 -0
  108. spiral/text_index.py +17 -0
  109. spiral/transaction.py +155 -0
  110. spiral/types_.py +6 -0
@@ -0,0 +1,169 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # sources: substrait/extensions/extensions.proto
3
+ # plugin: python-betterproto2
4
+ # This file has been @generated
5
+
6
+ __all__ = (
7
+ "AdvancedExtension",
8
+ "SimpleExtensionDeclaration",
9
+ "SimpleExtensionDeclarationExtensionFunction",
10
+ "SimpleExtensionDeclarationExtensionType",
11
+ "SimpleExtensionDeclarationExtensionTypeVariation",
12
+ "SimpleExtensionUri",
13
+ )
14
+
15
+ from dataclasses import dataclass
16
+
17
+ import betterproto2
18
+
19
+ from ...message_pool import default_message_pool
20
+
21
+ _COMPILER_VERSION = "0.9.0"
22
+ betterproto2.check_compiler_version(_COMPILER_VERSION)
23
+
24
+
25
+ @dataclass(eq=False, repr=False)
26
+ class AdvancedExtension(betterproto2.Message):
27
+ """
28
+ A generic object that can be used to embed additional extension information
29
+ into the serialized substrait plan.
30
+ """
31
+
32
+ optimization: "list[__google__protobuf__.Any]" = betterproto2.field(1, betterproto2.TYPE_MESSAGE, repeated=True)
33
+ """
34
+ An optimization is helpful information that doesn't influence semantics. May
35
+ be ignored by a consumer.
36
+ """
37
+
38
+ enhancement: "__google__protobuf__.Any | None" = betterproto2.field(2, betterproto2.TYPE_MESSAGE, optional=True)
39
+ """
40
+ An enhancement alters semantics. Cannot be ignored by a consumer.
41
+ """
42
+
43
+
44
+ default_message_pool.register_message("substrait.extensions", "AdvancedExtension", AdvancedExtension)
45
+
46
+
47
+ @dataclass(eq=False, repr=False)
48
+ class SimpleExtensionDeclaration(betterproto2.Message):
49
+ """
50
+ Describes a mapping between a specific extension entity and the uri where
51
+ that extension can be found.
52
+
53
+ Oneofs:
54
+ - mapping_type:
55
+ """
56
+
57
+ extension_type: "SimpleExtensionDeclarationExtensionType | None" = betterproto2.field(
58
+ 1, betterproto2.TYPE_MESSAGE, optional=True, group="mapping_type"
59
+ )
60
+
61
+ extension_type_variation: "SimpleExtensionDeclarationExtensionTypeVariation | None" = betterproto2.field(
62
+ 2, betterproto2.TYPE_MESSAGE, optional=True, group="mapping_type"
63
+ )
64
+
65
+ extension_function: "SimpleExtensionDeclarationExtensionFunction | None" = betterproto2.field(
66
+ 3, betterproto2.TYPE_MESSAGE, optional=True, group="mapping_type"
67
+ )
68
+
69
+
70
+ default_message_pool.register_message("substrait.extensions", "SimpleExtensionDeclaration", SimpleExtensionDeclaration)
71
+
72
+
73
+ @dataclass(eq=False, repr=False)
74
+ class SimpleExtensionDeclarationExtensionFunction(betterproto2.Message):
75
+ extension_uri_reference: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
76
+ """
77
+ references the extension_uri_anchor defined for a specific extension URI.
78
+ """
79
+
80
+ function_anchor: "int" = betterproto2.field(2, betterproto2.TYPE_UINT32)
81
+ """
82
+ A surrogate key used in the context of a single plan to reference a
83
+ specific function
84
+ """
85
+
86
+ name: "str" = betterproto2.field(3, betterproto2.TYPE_STRING)
87
+ """
88
+ A function signature compound name
89
+ """
90
+
91
+
92
+ default_message_pool.register_message(
93
+ "substrait.extensions", "SimpleExtensionDeclaration.ExtensionFunction", SimpleExtensionDeclarationExtensionFunction
94
+ )
95
+
96
+
97
+ @dataclass(eq=False, repr=False)
98
+ class SimpleExtensionDeclarationExtensionType(betterproto2.Message):
99
+ """
100
+ Describes a Type
101
+ """
102
+
103
+ extension_uri_reference: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
104
+ """
105
+ references the extension_uri_anchor defined for a specific extension URI.
106
+ """
107
+
108
+ type_anchor: "int" = betterproto2.field(2, betterproto2.TYPE_UINT32)
109
+ """
110
+ A surrogate key used in the context of a single plan to reference a
111
+ specific extension type
112
+ """
113
+
114
+ name: "str" = betterproto2.field(3, betterproto2.TYPE_STRING)
115
+ """
116
+ the name of the type in the defined extension YAML.
117
+ """
118
+
119
+
120
+ default_message_pool.register_message(
121
+ "substrait.extensions", "SimpleExtensionDeclaration.ExtensionType", SimpleExtensionDeclarationExtensionType
122
+ )
123
+
124
+
125
+ @dataclass(eq=False, repr=False)
126
+ class SimpleExtensionDeclarationExtensionTypeVariation(betterproto2.Message):
127
+ extension_uri_reference: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
128
+ """
129
+ references the extension_uri_anchor defined for a specific extension URI.
130
+ """
131
+
132
+ type_variation_anchor: "int" = betterproto2.field(2, betterproto2.TYPE_UINT32)
133
+ """
134
+ A surrogate key used in the context of a single plan to reference a
135
+ specific type variation
136
+ """
137
+
138
+ name: "str" = betterproto2.field(3, betterproto2.TYPE_STRING)
139
+ """
140
+ the name of the type in the defined extension YAML.
141
+ """
142
+
143
+
144
+ default_message_pool.register_message(
145
+ "substrait.extensions",
146
+ "SimpleExtensionDeclaration.ExtensionTypeVariation",
147
+ SimpleExtensionDeclarationExtensionTypeVariation,
148
+ )
149
+
150
+
151
+ @dataclass(eq=False, repr=False)
152
+ class SimpleExtensionUri(betterproto2.Message):
153
+ extension_uri_anchor: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
154
+ """
155
+ A surrogate key used in the context of a single plan used to reference the
156
+ URI associated with an extension.
157
+ """
158
+
159
+ uri: "str" = betterproto2.field(2, betterproto2.TYPE_STRING)
160
+ """
161
+ The URI where this extension YAML can be retrieved. This is the "namespace"
162
+ of this extension.
163
+ """
164
+
165
+
166
+ default_message_pool.register_message("substrait.extensions", "SimpleExtensionURI", SimpleExtensionUri)
167
+
168
+
169
+ from ...google import protobuf as __google__protobuf__
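The generated messages above are plain betterproto2 dataclasses, so they can be constructed directly with keyword arguments. A minimal sketch, not taken from the package docs; the URI and function name below are illustrative:

```python
# Minimal sketch: the generated classes are ordinary dataclasses, so an
# extension declaration can be built with keyword arguments.
from spiral.protogen._.substrait.extensions import (
    SimpleExtensionDeclaration,
    SimpleExtensionDeclarationExtensionFunction,
    SimpleExtensionUri,
)

# Anchor an (illustrative) extension YAML, then reference it from a function declaration.
uri = SimpleExtensionUri(extension_uri_anchor=1, uri="https://example.com/functions_math.yaml")
decl = SimpleExtensionDeclaration(
    extension_function=SimpleExtensionDeclarationExtensionFunction(
        extension_uri_reference=1,
        function_anchor=10,
        name="add:i64_i64",
    )
)
```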
File without changes
@@ -0,0 +1,41 @@
1
+ import betterproto
2
+ from betterproto.grpc.grpclib_server import ServiceBase
3
+
4
+
5
+ def patch_protos(proto_module, our_module_globals):
6
+ """Calculate __all__ to re-export protos from a module."""
7
+
8
+ betterproto_types = (betterproto.Message, betterproto.Enum, betterproto.ServiceStub, ServiceBase)
9
+
10
+ proto_overrides = {}
11
+ missing = set()
12
+ for ident in dir(proto_module):
13
+ var = getattr(proto_module, ident)
14
+ if isinstance(var, type) and issubclass(var, betterproto_types):
15
+ if ident in our_module_globals:
16
+ override = id(our_module_globals.get(ident)) != id(var)
17
+ else:
18
+ override = False
19
+ missing.add(ident)
20
+ proto_overrides[ident] = override
21
+
22
+ if missing:
23
+ print(f"from {proto_module.__name__} import (")
24
+ for ident, override in proto_overrides.items():
25
+ if override:
26
+ print(f" {ident} as {ident}_,")
27
+ else:
28
+ print(f" {ident},")
29
+ print(")")
30
+ print("\n")
31
+ print("__all__ = [")
32
+ for ident in proto_overrides:
33
+ print(f' "{ident}",')
34
+ print("]")
35
+
36
+ raise ValueError(f"Missing types that need to be re-exported: {missing}")
37
+
38
+ # Patch any local subclasses back into the original module so the gRPC client will use them
39
+ for ident, override in proto_overrides.items():
40
+ if override:
41
+ setattr(proto_module, ident, our_module_globals[ident])
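`patch_protos` is meant to be called from a hand-written wrapper module that re-exports the generated protos; if a generated type is missing from the wrapper it prints the import block to copy and raises `ValueError`, otherwise it patches any local subclasses back into the generated module. A hedged sketch of that calling pattern (the wrapper module itself is hypothetical):

```python
# Hedged sketch of the intended calling pattern; this wrapper module is hypothetical.
# It re-exports a generated module's types and then calls patch_protos so any local
# subclasses defined here replace the generated ones for gRPC use.
from spiral.protogen import util
from spiral.protogen._ import scandal as _scandal  # a generated module shipped in this wheel

from spiral.protogen._.scandal import *  # noqa: F403 - re-export the generated types

# ... local subclasses of selected messages could be defined here ...

util.patch_protos(_scandal, globals())
```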
spiral/py.typed ADDED
File without changes
spiral/scan.py ADDED
@@ -0,0 +1,363 @@
1
+ from functools import partial
2
+ from typing import TYPE_CHECKING, Any, Optional
3
+
4
+ import pyarrow as pa
5
+
6
+ from spiral.core.client import Shard, ShuffleConfig
7
+ from spiral.core.table import KeyRange
8
+ from spiral.core.table import Scan as CoreScan
9
+ from spiral.core.table.spec import Schema
10
+ from spiral.settings import CI, DEV
11
+
12
+ if TYPE_CHECKING:
13
+ import dask.dataframe as dd
14
+ import datasets.iterable_dataset as hf # noqa
15
+ import pandas as pd
16
+ import polars as pl
17
+ import streaming # noqa
18
+ import torch.utils.data as torchdata # noqa
19
+
20
+ from spiral.client import Spiral
21
+ from spiral.dataloader import SpiralDataLoader, World # noqa
22
+
23
+
24
+ class Scan:
25
+ """Scan object."""
26
+
27
+ def __init__(self, spiral: "Spiral", core: CoreScan):
28
+ self.spiral = spiral
29
+ self.core = core
30
+
31
+ @property
32
+ def metrics(self) -> dict[str, Any]:
33
+ """Returns metrics about the scan."""
34
+ return self.core.metrics()
35
+
36
+ @property
37
+ def schema(self) -> Schema:
38
+ """Returns the schema of the scan."""
39
+ return self.core.schema()
40
+
41
+ @property
42
+ def key_schema(self) -> Schema:
43
+ """Returns the key schema of the scan."""
44
+ return self.core.key_schema()
45
+
46
+ def is_empty(self) -> bool:
47
+ """Check if the Spiral is empty for the given key range.
48
+
49
+ False negatives are possible, but false positives are not,
50
+ i.e. is_empty can return False and scan can return zero rows.
51
+ """
52
+ return self.core.is_empty()
53
+
54
+ def to_record_batches(
55
+ self,
56
+ *,
57
+ key_range: KeyRange | None = None,
58
+ key_table: pa.Table | pa.RecordBatchReader | None = None,
59
+ batch_size: int | None = None,
60
+ batch_readahead: int | None = None,
61
+ hide_progress_bar: bool = False,
62
+ ) -> pa.RecordBatchReader:
63
+ """Read as a stream of RecordBatches.
64
+
65
+ Args:
66
+ key_range: Optional key range to filter the scan.
67
+ If provided, the scan will only return rows within the key range.
68
+ Only one of key_range or key_table can be provided.
69
+ key_table: a table of keys to "take" (including aux columns for cell-push-down).
70
+ If None, the scan will be executed without a key table.
71
+ batch_size: the maximum number of rows per returned batch.
72
+ This is currently only respected when the key_table is used. If the key_table is a
73
+ RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
74
+ batch_readahead: the number of batches to prefetch in the background.
75
+ hide_progress_bar: If True, disables the progress bar during reading.
76
+ """
77
+ if key_range is not None and key_table is not None:
78
+ raise ValueError("Only one of key_range or key_table can be provided.")
79
+
80
+ if isinstance(key_table, pa.RecordBatchReader):
81
+ if batch_size is not None:
82
+ raise ValueError(
83
+ "batch_size must be None when key_table is a RecordBatchReader, the existing batching is respected."
84
+ )
85
+ elif isinstance(key_table, pa.Table):
86
+ key_table = key_table.to_reader(max_chunksize=batch_size)
87
+
88
+ return self.core.to_record_batches(
89
+ key_range=key_range, key_table=key_table, batch_readahead=batch_readahead, progress=(not hide_progress_bar)
90
+ )
91
+
92
+ def to_unordered_record_batches(
93
+ self,
94
+ *,
95
+ key_table: pa.Table | pa.RecordBatchReader | None = None,
96
+ batch_size: int | None = None,
97
+ batch_readahead: int | None = None,
98
+ hide_progress_bar: bool = False,
99
+ ) -> pa.RecordBatchReader:
100
+ """Read as a stream of RecordBatches, NOT ordered by key.
101
+
102
+ Args:
103
+ key_table: a table of keys to "take" (including aux columns for cell-push-down).
104
+ If None, the scan will be executed without a key table.
105
+ batch_size: the maximum number of rows per returned batch.
106
+ This is currently only respected when the key_table is used. If the key_table is a
107
+ RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
108
+ batch_readahead: the number of batches to prefetch in the background.
109
+ hide_progress_bar: If True, disables the progress bar during reading.
110
+ """
111
+ if isinstance(key_table, pa.RecordBatchReader):
112
+ if batch_size is not None:
113
+ raise ValueError(
114
+ "batch_size must be None when key_table is a RecordBatchReader, the existing batching is respected."
115
+ )
116
+ elif isinstance(key_table, pa.Table):
117
+ key_table = key_table.to_reader(max_chunksize=batch_size)
118
+
119
+ return self.core.to_unordered_record_batches(
120
+ key_table=key_table, batch_readahead=batch_readahead, progress=(not hide_progress_bar)
121
+ )
122
+
123
+ def to_table(
124
+ self,
125
+ *,
126
+ key_range: KeyRange | None = None,
127
+ key_table: pa.Table | pa.RecordBatchReader | None = None,
128
+ ) -> pa.Table:
129
+ """Read into a single PyArrow Table.
130
+
131
+ Args:
132
+ key_range: Optional key range to filter the scan.
133
+ If provided, the scan will only return rows within the key range.
134
+ Only one of key_range or key_table can be provided.
135
+ key_table: a table of keys to "take" (including aux columns for cell-push-down).
136
+ If None, the scan will be executed without a key table.
137
+ """
138
+ # NOTE: Evaluates fully on the Rust side, which improves debuggability.
139
+ if DEV and not CI and key_table is None and key_range is None:
140
+ rb = self.core.to_record_batch()
141
+ return pa.Table.from_batches([rb])
142
+
143
+ return self.to_record_batches(key_range=key_range, key_table=key_table).read_all()
144
+
145
+ def to_dask(self) -> "dd.DataFrame":
146
+ """Read into a Dask DataFrame.
147
+
148
+ Requires the `dask` package to be installed.
149
+
150
+ Dask execution has some limitations, e.g. UDFs are not currently supported. These limitations
151
+ usually manifest as serialization errors when Dask workers attempt to serialize the state. If you are
152
+ encountering such issues, please reach out to support for assistance.
153
+ """
154
+ import dask.dataframe as dd
155
+
156
+ _read_shard = partial(
157
+ _read_shard_task,
158
+ settings_json=self.spiral.config.to_json(),
159
+ state_json=self.core.plan_state().to_json(),
160
+ )
161
+ return dd.from_map(_read_shard, self.shards())
162
+
163
+ def to_pandas(self, *, key_range: KeyRange | None = None) -> "pd.DataFrame":
164
+ """Read into a Pandas DataFrame.
165
+
166
+ Requires the `pandas` package to be installed.
167
+ """
168
+ return self.to_table(key_range=key_range).to_pandas()
169
+
170
+ def to_polars(self) -> "pl.DataFrame":
171
+ """Read into a Polars DataFrame.
172
+
173
+ Requires the `polars` package to be installed.
174
+ """
175
+ import polars as pl
176
+
177
+ return pl.from_arrow(self.to_record_batches())
178
+
179
+ def to_data_loader(
180
+ self, seed: int = 42, shuffle_buffer_size: int = 8192, batch_size: int = 32, **kwargs
181
+ ) -> "SpiralDataLoader":
182
+ """Read into a Torch-compatible DataLoader for single-node training.
183
+
184
+ Args:
185
+ seed: Random seed for reproducibility.
186
+ shuffle_buffer_size: Size of shuffle buffer.
187
+ batch_size: Batch size.
188
+ **kwargs: Additional arguments passed to SpiralDataLoader constructor.
189
+
190
+ Returns:
191
+ SpiralDataLoader with shuffled shards.
192
+ """
193
+ from spiral.dataloader import SpiralDataLoader
194
+
195
+ return SpiralDataLoader(
196
+ self, seed=seed, shuffle_buffer_size=shuffle_buffer_size, batch_size=batch_size, **kwargs
197
+ )
198
+
199
+ def to_distributed_data_loader(
200
+ self,
201
+ world: Optional["World"] = None,
202
+ shards: list[Shard] | None = None,
203
+ seed: int = 42,
204
+ shuffle_buffer_size: int = 8192,
205
+ batch_size: int = 32,
206
+ **kwargs,
207
+ ) -> "SpiralDataLoader":
208
+ """Read into a Torch-compatible DataLoader for distributed training.
209
+
210
+ Args:
211
+ world: World configuration with rank and world_size.
212
+ If None, auto-detects from torch.distributed.
213
+ shards: Optional sharding. Sharding is global, i.e. the world will be used to select
214
+ the shards for this rank. If None, uses scan's natural sharding.
215
+ seed: Random seed for reproducibility.
216
+ shuffle_buffer_size: Size of shuffle buffer.
217
+ Set to zero to disable the shuffle buffer.
218
+ batch_size: Batch size.
219
+ **kwargs: Additional arguments passed to SpiralDataLoader constructor.
220
+
221
+ Returns:
222
+ SpiralDataLoader with shards partitioned for this rank.
223
+
224
+ Auto-detect from PyTorch distributed:
225
+ ```python
226
+ loader: SpiralDataLoader = scan.to_distributed_data_loader(batch_size=32)
227
+ ```
228
+
229
+ Explicit world configuration:
230
+ ```python
231
+ world = World(rank=0, world_size=4)
232
+ loader: SpiralDataLoader = scan.to_distributed_data_loader(world=world, batch_size=32)
233
+ ```
234
+ """
235
+ from spiral.dataloader import SpiralDataLoader, World
236
+
237
+ if world is None:
238
+ world = World.from_torch()
239
+
240
+ shards = shards or self.shards()
241
+ # Apply world partitioning to shards.
242
+ shards = world.shards(shards, seed)
243
+
244
+ return SpiralDataLoader(
245
+ self,
246
+ shards=shards,
247
+ shuffle_shards=False,  # Shards are shuffled before being selected for the world.
248
+ seed=seed,
249
+ shuffle_buffer_size=shuffle_buffer_size,
250
+ batch_size=batch_size,
251
+ **kwargs,
252
+ )
253
+
254
+ def resume_data_loader(self, state: dict[str, Any], **kwargs) -> "SpiralDataLoader":
255
+ """Create a DataLoader from checkpoint state, resuming from where it left off.
256
+
257
+ This is the recommended way to resume training from a checkpoint. It extracts
258
+ the seed, samples_yielded, and shards from the state dict and creates a new
259
+ DataLoader that will skip the already-processed samples.
260
+
261
+ Args:
262
+ state: Checkpoint state from state_dict().
263
+ **kwargs: Additional arguments to pass to SpiralDataLoader constructor.
264
+ These will override values in the state dict where applicable.
265
+
266
+ Returns:
267
+ New SpiralDataLoader instance configured to resume from the checkpoint.
268
+
269
+ Save checkpoint during training:
270
+ ```python
271
+ loader = scan.to_distributed_data_loader(batch_size=32, seed=42)
272
+ checkpoint = loader.state_dict()
273
+ ```
274
+
275
+ Resume later - uses same shards from checkpoint:
276
+ ```python
277
+ resumed_loader = scan.resume_data_loader(
278
+ checkpoint,
279
+ batch_size=32,
280
+ transform_fn=my_transform,
281
+ )
+ ```
282
+ """
283
+ from spiral.dataloader import SpiralDataLoader
284
+
285
+ return SpiralDataLoader.from_state_dict(self, state, **kwargs)
286
+
287
+ def to_iterable_dataset(
288
+ self,
289
+ shards: list[Shard] | None = None,
290
+ shuffle: ShuffleConfig | None = None,
291
+ batch_readahead: int | None = None,
292
+ infinite: bool = False,
293
+ ) -> "hf.IterableDataset":
294
+ """Returns a Huggingface's IterableDataset.
295
+
296
+ Requires `datasets` package to be installed.
297
+
298
+ Note: For new code, consider using SpiralDataLoader instead.
299
+
300
+ Args:
301
+ shards: Optional list of shards to read. If None, uses scan's natural sharding.
302
+ shuffle: Optional ShuffleConfig for configuring within-shard sample shuffling.
303
+ If None, no shuffling is performed.
304
+ batch_readahead: Controls how many batches to read ahead concurrently.
305
+ If the pipeline includes work after reading (e.g. decoding, transforming, ...), this can be set higher.
306
+ Otherwise, it should be kept low to reduce next batch latency. Defaults to 2.
307
+ infinite: If True, the returned IterableDataset will loop infinitely over the data,
308
+ re-shuffling ranges after exhausting all data.
309
+ """
310
+ stream = self.core.to_shuffled_record_batches(
311
+ shards=shards,
312
+ shuffle=shuffle,
313
+ batch_readahead=batch_readahead,
314
+ infinite=infinite,
315
+ )
316
+
317
+ from spiral.iterable_dataset import to_iterable_dataset
318
+
319
+ return to_iterable_dataset(stream)
320
+
321
+ def shards(self) -> list[Shard]:
322
+ """Get list of shards for this scan.
323
+
324
+ The shards are based on the scan's physical data layout (file fragments).
325
+ Each shard contains a key range and cardinality (set to None when unknown).
326
+
327
+ Returns:
328
+ List of Shard objects with key range and cardinality (if known).
329
+
330
+ """
331
+ return self.core.shards()
332
+
333
+ def _debug(self):
334
+ # Visualizes the scan, mainly for debugging purposes.
335
+ from spiral.debug.scan import show_scan
336
+
337
+ show_scan(self.core)
338
+
339
+ def _dump_manifests(self):
340
+ # Print manifests in a human-readable format.
341
+ from spiral.debug.manifests import display_scan_manifests
342
+
343
+ display_scan_manifests(self.core)
344
+
345
+ def _dump_metrics(self):
346
+ # Print metrics in a human-readable format.
347
+ from spiral.debug.metrics import display_metrics
348
+
349
+ display_metrics(self.metrics)
350
+
351
+
352
+ # NOTE(marko): This function must be picklable!
353
+ def _read_shard_task(shard: Shard, *, settings_json: str, state_json: str) -> "pd.DataFrame":
354
+ from spiral import Spiral
355
+ from spiral.core.table import ScanState
356
+ from spiral.settings import ClientSettings
357
+
358
+ settings = ClientSettings.from_json(settings_json)
359
+ sp = Spiral(config=settings)
360
+ state = ScanState.from_json(state_json)
361
+ task_scan = Scan(sp, sp.core.load_scan(state))
362
+
363
+ return task_scan.to_record_batches(key_range=shard.key_range, hide_progress_bar=True).read_all().to_pandas()
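Taken together, `Scan` exposes three read paths: full materialization via `to_table`, streaming via `to_record_batches`, and shard-level parallelism via `shards()` plus per-shard key ranges (the same pattern `_read_shard_task` uses for Dask). A hedged sketch, assuming `scan` is a `Scan` instance obtained from the Spiral client elsewhere:

```python
# Hedged sketch of the read paths on Scan; `scan` is assumed to be a Scan
# obtained from the Spiral client (its construction is outside scan.py).
import pyarrow as pa

# 1. Full materialization into a single Arrow table.
table: pa.Table = scan.to_table()

# 2. Streaming: iterate record batches without holding the full result in memory.
rows = 0
for batch in scan.to_record_batches(batch_readahead=2, hide_progress_bar=True):
    rows += batch.num_rows

# 3. Shard-level parallelism: each Shard carries a key range that can be read
#    independently, e.g. one task per shard.
parts = [scan.to_record_batches(key_range=s.key_range).read_all() for s in scan.shards()]
```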
spiral/server.py ADDED
@@ -0,0 +1,17 @@
1
+ import socket
2
+ import time
3
+
4
+
5
+ def wait_for_port(port: int, host: str = "localhost", timeout: float = 5.0):
6
+ """Wait until a port starts accepting TCP connections."""
7
+ start_time = time.time()
8
+ while True:
9
+ try:
10
+ with socket.create_connection((host, port), timeout=timeout):
11
+ break
12
+ except OSError as ex:
13
+ time.sleep(0.01)
14
+ if time.time() - start_time >= timeout:
15
+ raise TimeoutError(
16
+ f"Waited too long for the port {port} on host {host} to start accepting connections."
17
+ ) from ex
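A hedged usage example for `wait_for_port`, assuming some other process (e.g. a locally launched subprocess) is expected to start listening on the port:

```python
# Hedged example: block until the port starts accepting TCP connections.
from spiral.server import wait_for_port

wait_for_port(8080, host="localhost", timeout=10.0)  # raises TimeoutError if nothing accepts in time
```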
spiral/settings.py ADDED
@@ -0,0 +1,36 @@
1
+ """Configuration module using Rust ClientSettings via PyO3.
2
+
3
+ This module provides a simple settings() function that returns a cached
4
+ ClientSettings instance loaded from ~/.spiral.toml and environment variables.
5
+ """
6
+
7
+ import functools
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import typer
12
+
13
+ from spiral.core.config import ClientSettings
14
+
15
+ DEV = "PYTEST_VERSION" in os.environ or bool(os.environ.get("SPIRAL_DEV", None))
16
+ CI = "GITHUB_ACTIONS" in os.environ
17
+
18
+ APP_DIR = Path(typer.get_app_dir("pyspiral"))
19
+ LOG_DIR = APP_DIR / "logs"
20
+
21
+ PACKAGE_NAME = "pyspiral"
22
+
23
+
24
+ @functools.cache
25
+ def settings() -> ClientSettings:
26
+ """Get the global ClientSettings instance.
27
+
28
+ Configuration is loaded with the following priority (highest to lowest):
29
+ 1. Environment variables (SPIRAL__*)
30
+ 2. Config file (~/.spiral.toml)
31
+ 3. Default values
32
+
33
+ Returns:
34
+ ClientSettings: The global configuration instance
35
+ """
36
+ return ClientSettings.load()
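Because `settings()` is wrapped in `functools.cache`, configuration is read once per process, so environment overrides must be in place before the first call. A hedged sketch; the `SPIRAL__*` variable name below is hypothetical:

```python
# Hedged sketch: settings() loads ~/.spiral.toml with SPIRAL__* environment
# overrides taking precedence, and the result is cached for the process lifetime.
import os

os.environ["SPIRAL__SOME_OPTION"] = "value"  # hypothetical key; set before the first settings() call

from spiral.settings import settings

cfg = settings()
assert settings() is cfg  # cached by functools.cache, so repeated calls return the same object
```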
spiral/snapshot.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from spiral.core.table import Snapshot as CoreSnapshot
4
+ from spiral.core.table.spec import Schema
5
+ from spiral.types_ import Timestamp
6
+
7
+ if TYPE_CHECKING:
8
+ import duckdb
9
+ import polars as pl
10
+ import pyarrow.dataset as ds
11
+ import torch.utils.data as torchdata # noqa
12
+
13
+ from spiral.table import Table
14
+
15
+
16
+ class Snapshot:
17
+ """Spiral table snapshot.
18
+
19
+ A snapshot represents a point-in-time view of a table.
20
+ """
21
+
22
+ def __init__(self, table: "Table", core: CoreSnapshot):
23
+ self.core = core
24
+ self._table = table
25
+
26
+ @property
27
+ def asof(self) -> Timestamp:
28
+ """Returns the asof timestamp of the snapshot."""
29
+ return self.core.asof
30
+
31
+ def schema(self) -> Schema:
32
+ """Returns the schema of the snapshot."""
33
+ return self.core.table.get_schema(asof=self.asof)
34
+
35
+ @property
36
+ def table(self) -> "Table":
37
+ """Returns the table associated with the snapshot."""
38
+ return self._table
39
+
40
+ def to_dataset(self) -> "ds.Dataset":
41
+ """Returns a PyArrow Dataset representing the table."""
42
+ from spiral.dataset import Dataset
43
+
44
+ return Dataset(self)
45
+
46
+ def to_polars(self) -> "pl.LazyFrame":
47
+ """Returns a Polars LazyFrame for the Spiral table."""
48
+ import polars as pl
49
+
50
+ return pl.scan_pyarrow_dataset(self.to_dataset())
51
+
52
+ def to_duckdb(self) -> "duckdb.DuckDBPyRelation":
53
+ """Returns a DuckDB relation for the Spiral table."""
54
+ import duckdb
55
+
56
+ return duckdb.from_arrow(self.to_dataset())
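A hedged sketch of the engine bridges `Snapshot` exposes, assuming `snap` is a `Snapshot` obtained from a `Table` elsewhere:

```python
# Hedged sketch of Snapshot's engine bridges; `snap` is assumed to be a Snapshot
# obtained from a Table (its construction is outside snapshot.py).
lazy = snap.to_polars()   # Polars LazyFrame over the snapshot's PyArrow dataset
rel = snap.to_duckdb()    # DuckDB relation over the same dataset

preview = rel.limit(10).fetchall()  # standard DuckDB relational API
frame = lazy.limit(10).collect()    # lazy query executed by Polars
```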
@@ -0,0 +1,3 @@
1
+ from .stream import SpiralStream
2
+
3
+ __all__ = ["SpiralStream"]