pyspiral-0.6.8-cp312-abi3-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.6.8.dist-info/METADATA +51 -0
- pyspiral-0.6.8.dist-info/RECORD +102 -0
- pyspiral-0.6.8.dist-info/WHEEL +4 -0
- pyspiral-0.6.8.dist-info/entry_points.txt +2 -0
- spiral/__init__.py +35 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +411 -0
- spiral/api/__init__.py +78 -0
- spiral/api/admin.py +15 -0
- spiral/api/client.py +164 -0
- spiral/api/filesystems.py +134 -0
- spiral/api/key_space_indexes.py +23 -0
- spiral/api/organizations.py +77 -0
- spiral/api/projects.py +219 -0
- spiral/api/telemetry.py +19 -0
- spiral/api/text_indexes.py +56 -0
- spiral/api/types.py +22 -0
- spiral/api/workers.py +40 -0
- spiral/api/workloads.py +52 -0
- spiral/arrow_.py +216 -0
- spiral/cli/__init__.py +88 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +14 -0
- spiral/cli/app.py +104 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +76 -0
- spiral/cli/iceberg.py +97 -0
- spiral/cli/key_spaces.py +89 -0
- spiral/cli/login.py +24 -0
- spiral/cli/orgs.py +89 -0
- spiral/cli/printer.py +53 -0
- spiral/cli/projects.py +147 -0
- spiral/cli/state.py +5 -0
- spiral/cli/tables.py +174 -0
- spiral/cli/telemetry.py +17 -0
- spiral/cli/text.py +115 -0
- spiral/cli/types.py +50 -0
- spiral/cli/workloads.py +58 -0
- spiral/client.py +178 -0
- spiral/core/__init__.pyi +0 -0
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/authn/__init__.pyi +27 -0
- spiral/core/client/__init__.pyi +237 -0
- spiral/core/table/__init__.pyi +101 -0
- spiral/core/table/manifests/__init__.pyi +35 -0
- spiral/core/table/metastore/__init__.pyi +58 -0
- spiral/core/table/spec/__init__.pyi +213 -0
- spiral/dataloader.py +285 -0
- spiral/dataset.py +255 -0
- spiral/datetime_.py +27 -0
- spiral/debug/__init__.py +0 -0
- spiral/debug/manifests.py +87 -0
- spiral/debug/metrics.py +56 -0
- spiral/debug/scan.py +266 -0
- spiral/expressions/__init__.py +276 -0
- spiral/expressions/base.py +157 -0
- spiral/expressions/http.py +86 -0
- spiral/expressions/io.py +100 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/mp4.py +62 -0
- spiral/expressions/png.py +18 -0
- spiral/expressions/qoi.py +18 -0
- spiral/expressions/refs.py +58 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +59 -0
- spiral/expressions/text.py +62 -0
- spiral/expressions/tiff.py +223 -0
- spiral/expressions/udf.py +46 -0
- spiral/grpc_.py +32 -0
- spiral/iceberg.py +31 -0
- spiral/iterable_dataset.py +106 -0
- spiral/key_space_index.py +44 -0
- spiral/project.py +199 -0
- spiral/protogen/_/__init__.py +0 -0
- spiral/protogen/_/arrow/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
- spiral/protogen/_/google/__init__.py +0 -0
- spiral/protogen/_/google/protobuf/__init__.py +2310 -0
- spiral/protogen/_/message_pool.py +3 -0
- spiral/protogen/_/py.typed +0 -0
- spiral/protogen/_/scandal/__init__.py +190 -0
- spiral/protogen/_/spfs/__init__.py +72 -0
- spiral/protogen/_/spql/__init__.py +61 -0
- spiral/protogen/_/substrait/__init__.py +6196 -0
- spiral/protogen/_/substrait/extensions/__init__.py +169 -0
- spiral/protogen/__init__.py +0 -0
- spiral/protogen/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan.py +285 -0
- spiral/server.py +17 -0
- spiral/settings.py +114 -0
- spiral/snapshot.py +56 -0
- spiral/streaming_/__init__.py +3 -0
- spiral/streaming_/reader.py +133 -0
- spiral/streaming_/stream.py +157 -0
- spiral/substrait_.py +274 -0
- spiral/table.py +293 -0
- spiral/text_index.py +17 -0
- spiral/transaction.py +58 -0
- spiral/types_.py +6 -0
spiral/protogen/_/substrait/extensions/__init__.py
ADDED
@@ -0,0 +1,169 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: substrait/extensions/extensions.proto
# plugin: python-betterproto2
# This file has been @generated

__all__ = (
    "AdvancedExtension",
    "SimpleExtensionDeclaration",
    "SimpleExtensionDeclarationExtensionFunction",
    "SimpleExtensionDeclarationExtensionType",
    "SimpleExtensionDeclarationExtensionTypeVariation",
    "SimpleExtensionUri",
)

from dataclasses import dataclass

import betterproto2

from ...message_pool import default_message_pool

_COMPILER_VERSION = "0.9.0"
betterproto2.check_compiler_version(_COMPILER_VERSION)


@dataclass(eq=False, repr=False)
class AdvancedExtension(betterproto2.Message):
    """
    A generic object that can be used to embed additional extension information
    into the serialized substrait plan.
    """

    optimization: "list[__google__protobuf__.Any]" = betterproto2.field(1, betterproto2.TYPE_MESSAGE, repeated=True)
    """
    An optimization is helpful information that don't influence semantics. May
    be ignored by a consumer.
    """

    enhancement: "__google__protobuf__.Any | None" = betterproto2.field(2, betterproto2.TYPE_MESSAGE, optional=True)
    """
    An enhancement alter semantics. Cannot be ignored by a consumer.
    """


default_message_pool.register_message("substrait.extensions", "AdvancedExtension", AdvancedExtension)


@dataclass(eq=False, repr=False)
class SimpleExtensionDeclaration(betterproto2.Message):
    """
    Describes a mapping between a specific extension entity and the uri where
    that extension can be found.

    Oneofs:
        - mapping_type:
    """

    extension_type: "SimpleExtensionDeclarationExtensionType | None" = betterproto2.field(
        1, betterproto2.TYPE_MESSAGE, optional=True, group="mapping_type"
    )

    extension_type_variation: "SimpleExtensionDeclarationExtensionTypeVariation | None" = betterproto2.field(
        2, betterproto2.TYPE_MESSAGE, optional=True, group="mapping_type"
    )

    extension_function: "SimpleExtensionDeclarationExtensionFunction | None" = betterproto2.field(
        3, betterproto2.TYPE_MESSAGE, optional=True, group="mapping_type"
    )


default_message_pool.register_message("substrait.extensions", "SimpleExtensionDeclaration", SimpleExtensionDeclaration)


@dataclass(eq=False, repr=False)
class SimpleExtensionDeclarationExtensionFunction(betterproto2.Message):
    extension_uri_reference: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
    """
    references the extension_uri_anchor defined for a specific extension URI.
    """

    function_anchor: "int" = betterproto2.field(2, betterproto2.TYPE_UINT32)
    """
    A surrogate key used in the context of a single plan to reference a
    specific function
    """

    name: "str" = betterproto2.field(3, betterproto2.TYPE_STRING)
    """
    A function signature compound name
    """


default_message_pool.register_message(
    "substrait.extensions", "SimpleExtensionDeclaration.ExtensionFunction", SimpleExtensionDeclarationExtensionFunction
)


@dataclass(eq=False, repr=False)
class SimpleExtensionDeclarationExtensionType(betterproto2.Message):
    """
    Describes a Type
    """

    extension_uri_reference: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
    """
    references the extension_uri_anchor defined for a specific extension URI.
    """

    type_anchor: "int" = betterproto2.field(2, betterproto2.TYPE_UINT32)
    """
    A surrogate key used in the context of a single plan to reference a
    specific extension type
    """

    name: "str" = betterproto2.field(3, betterproto2.TYPE_STRING)
    """
    the name of the type in the defined extension YAML.
    """


default_message_pool.register_message(
    "substrait.extensions", "SimpleExtensionDeclaration.ExtensionType", SimpleExtensionDeclarationExtensionType
)


@dataclass(eq=False, repr=False)
class SimpleExtensionDeclarationExtensionTypeVariation(betterproto2.Message):
    extension_uri_reference: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
    """
    references the extension_uri_anchor defined for a specific extension URI.
    """

    type_variation_anchor: "int" = betterproto2.field(2, betterproto2.TYPE_UINT32)
    """
    A surrogate key used in the context of a single plan to reference a
    specific type variation
    """

    name: "str" = betterproto2.field(3, betterproto2.TYPE_STRING)
    """
    the name of the type in the defined extension YAML.
    """


default_message_pool.register_message(
    "substrait.extensions",
    "SimpleExtensionDeclaration.ExtensionTypeVariation",
    SimpleExtensionDeclarationExtensionTypeVariation,
)


@dataclass(eq=False, repr=False)
class SimpleExtensionUri(betterproto2.Message):
    extension_uri_anchor: "int" = betterproto2.field(1, betterproto2.TYPE_UINT32)
    """
    A surrogate key used in the context of a single plan used to reference the
    URI associated with an extension.
    """

    uri: "str" = betterproto2.field(2, betterproto2.TYPE_STRING)
    """
    The URI where this extension YAML can be retrieved. This is the "namespace"
    of this extension.
    """


default_message_pool.register_message("substrait.extensions", "SimpleExtensionURI", SimpleExtensionUri)


from ...google import protobuf as __google__protobuf__
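Below is a minimal sketch (not part of the wheel) of how these generated betterproto2 messages can be constructed. The dataclass keyword constructors follow the field definitions above; the URI and function name values are illustrative, and the bytes() serialization follows the usual betterproto convention.

from spiral.protogen._.substrait.extensions import (
    SimpleExtensionDeclaration,
    SimpleExtensionDeclarationExtensionFunction,
    SimpleExtensionUri,
)

# Anchor an (illustrative) extension YAML, then declare a function from it.
uri = SimpleExtensionUri(extension_uri_anchor=1, uri="https://example.com/functions.yaml")
decl = SimpleExtensionDeclaration(
    extension_function=SimpleExtensionDeclarationExtensionFunction(
        extension_uri_reference=1,
        function_anchor=10,
        name="add:i64_i64",
    )
)

payload = bytes(decl)  # serialize to protobuf wire format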
spiral/protogen/__init__.py
ADDED
File without changes
spiral/protogen/util.py
ADDED
@@ -0,0 +1,41 @@
import betterproto
from betterproto.grpc.grpclib_server import ServiceBase


def patch_protos(proto_module, our_module_globals):
    """Calculate __all__ to re-export protos from a module."""

    betterproto_types = (betterproto.Message, betterproto.Enum, betterproto.ServiceStub, ServiceBase)

    proto_overrides = {}
    missing = set()
    for ident in dir(proto_module):
        var = getattr(proto_module, ident)
        if isinstance(var, type) and issubclass(var, betterproto_types):
            if ident in our_module_globals:
                override = id(our_module_globals.get(ident)) != id(var)
            else:
                override = False
                missing.add(ident)
            proto_overrides[ident] = override

    if missing:
        print(f"from {proto_module.__name__} import (")
        for ident, override in proto_overrides.items():
            if override:
                print(f"    {ident} as {ident}_,")
            else:
                print(f"    {ident},")
        print(")")
        print("\n")
        print("__all__ = [")
        for ident in proto_overrides:
            print(f'    "{ident}",')
        print("]")

        raise ValueError(f"Missing types that need to be re-exported: {missing}")

    # Patch any local subclasses back into the original module so the gRPC client will use them
    for ident, override in proto_overrides.items():
        if override:
            setattr(proto_module, ident, our_module_globals[ident])
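A sketch of the call pattern this function implies, assuming a hand-written wrapper around a betterproto-generated module; `generated_protos` is a hypothetical stand-in, not a module in this wheel.

# Hypothetical wrapper module (not in the wheel).
import generated_protos as _protos
from generated_protos import *  # noqa: F401,F403 -- re-export every generated type

from spiral.protogen.util import patch_protos

# Verifies every generated Message/Enum/stub/service type is re-exported here,
# printing a ready-to-paste import block and raising ValueError otherwise, then
# patches any local subclasses back into `_protos`.
patch_protos(_protos, globals())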
spiral/py.typed
ADDED
File without changes
spiral/scan.py
ADDED
@@ -0,0 +1,285 @@
from typing import TYPE_CHECKING, Any, Optional

import pyarrow as pa

from spiral.core.client import Shard, ShuffleConfig
from spiral.core.table import Scan as CoreScan
from spiral.core.table.spec import Schema
from spiral.settings import CI, DEV

if TYPE_CHECKING:
    import dask.dataframe as dd
    import datasets.iterable_dataset as hf  # noqa
    import pandas as pd
    import polars as pl
    import streaming  # noqa
    import torch.utils.data as torchdata  # noqa

    from spiral.dataloader import SpiralDataLoader, World  # noqa


class Scan:
    """Scan object."""

    def __init__(self, core: CoreScan):
        self.core = core

    @property
    def metrics(self) -> dict[str, Any]:
        """Returns metrics about the scan."""
        return self.core.metrics()

    @property
    def schema(self) -> Schema:
        """Returns the schema of the scan."""
        return self.core.schema()

    def is_empty(self) -> bool:
        """Check if the Spiral is empty for the given key range.

        **IMPORTANT**: False negatives are possible, but false positives are not,
        i.e. is_empty can return False and scan can return zero rows.
        """
        return self.core.is_empty()

    def to_record_batches(
        self,
        key_table: pa.Table | pa.RecordBatchReader | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
    ) -> pa.RecordBatchReader:
        """Read as a stream of RecordBatches.

        Args:
            key_table: a table of keys to "take" (including aux columns for cell-push-down).
                If None, the scan will be executed without a key table.
            batch_size: the maximum number of rows per returned batch.
                IMPORTANT: This is currently only respected when the key_table is used. If key table is a
                RecordBatchReader, the batch_size argument must be None, and the existing batching is respected.
            batch_readahead: the number of batches to prefetch in the background.
        """
        if isinstance(key_table, pa.RecordBatchReader):
            if batch_size is not None:
                raise ValueError(
                    "batch_size must be None when key_table is a RecordBatchReader, the existing batching is respected."
                )
        elif isinstance(key_table, pa.Table):
            key_table = key_table.to_reader(max_chunksize=batch_size)

        return self.core.to_record_batches(key_table=key_table, batch_readahead=batch_readahead)

    def to_table(
        self,
        key_table: pa.Table | pa.RecordBatchReader | None = None,
    ) -> pa.Table:
        """Read into a single PyArrow Table.

        Args:
            key_table: a table of keys to "take" (including aux columns for cell-push-down).
                If None, the scan will be executed without a key table.
        """
        # NOTE: Evaluates fully on the Rust side, which improves debuggability.
        if DEV and not CI and key_table is None:
            rb = self.core.to_record_batch()
            return pa.Table.from_batches([rb])

        return self.to_record_batches(key_table=key_table).read_all()

    def to_dask(self) -> "dd.DataFrame":
        """Read into a Dask DataFrame.

        Requires the `dask` package to be installed.
        """
        import dask.dataframe as dd
        import pandas as pd

        def _read_shard(shard: Shard) -> pd.DataFrame:
            # TODO(ngates): we need a way to preserve the existing asofs?
            raise NotImplementedError()

        # Fetch a set of partition ranges
        return dd.from_map(_read_shard, self.shards())

    def to_pandas(self) -> "pd.DataFrame":
        """Read into a Pandas DataFrame.

        Requires the `pandas` package to be installed.
        """
        return self.to_table().to_pandas()

    def to_polars(self) -> "pl.DataFrame":
        """Read into a Polars DataFrame.

        Requires the `polars` package to be installed.
        """
        import polars as pl

        return pl.from_arrow(self.to_record_batches())

    def to_data_loader(
        self, seed: int = 42, shuffle_buffer_size: int = 8192, batch_size: int = 32, **kwargs
    ) -> "SpiralDataLoader":
        """Read into a Torch-compatible DataLoader for single-node training.

        Args:
            seed: Random seed for reproducibility.
            shuffle_buffer_size: Size of shuffle buffer.
            batch_size: Batch size.
            **kwargs: Additional arguments passed to SpiralDataLoader constructor.

        Returns:
            SpiralDataLoader with shuffled shards.
        """
        from spiral.dataloader import SpiralDataLoader

        return SpiralDataLoader(
            self, seed=seed, shuffle_buffer_size=shuffle_buffer_size, batch_size=batch_size, **kwargs
        )

    def to_distributed_data_loader(
        self,
        world: Optional["World"] = None,
        shards: list[Shard] | None = None,
        seed: int = 42,
        shuffle_buffer_size: int = 8192,
        batch_size: int = 32,
        **kwargs,
    ) -> "SpiralDataLoader":
        """Read into a Torch-compatible DataLoader for distributed training.

        Args:
            world: World configuration with rank and world_size.
                If None, auto-detects from torch.distributed.
            shards: Optional sharding. Sharding is global, i.e. the world will be used to select
                the shards for this rank. If None, uses scan's natural sharding.
            seed: Random seed for reproducibility.
            shuffle_buffer_size: Size of shuffle buffer.
                Use zero to skip shuffling with shuffle buffer.
            batch_size: Batch size.
            **kwargs: Additional arguments passed to SpiralDataLoader constructor.

        Returns:
            SpiralDataLoader with shards partitioned for this rank.
        """
        # Example usage:
        #
        # Auto-detect from PyTorch distributed:
        #   loader: SpiralDataLoader = scan.to_distributed_data_loader(batch_size=32)
        #
        # Explicit world configuration:
        #   world = World(rank=0, world_size=4)
        #   loader: SpiralDataLoader = scan.to_distributed_data_loader(world=world, batch_size=32)

        from spiral.dataloader import SpiralDataLoader, World

        if world is None:
            world = World.from_torch()

        shards = shards or self.shards()
        # Apply world partitioning to shards.
        shards = world.shards(shards, seed)

        return SpiralDataLoader(
            self,
            shards=shards,
            shuffle_shards=False,  # Shards are shuffled before selected for the world.
            seed=seed,
            shuffle_buffer_size=shuffle_buffer_size,
            batch_size=batch_size,
            **kwargs,
        )

    def resume_data_loader(self, state: dict[str, Any], **kwargs) -> "SpiralDataLoader":
        """Create a DataLoader from checkpoint state, resuming from where it left off.

        This is the recommended way to resume training from a checkpoint. It extracts
        the seed, samples_yielded, and shards from the state dict and creates a new
        DataLoader that will skip the already-processed samples.

        Args:
            state: Checkpoint state from state_dict().
            **kwargs: Additional arguments to pass to SpiralDataLoader constructor.
                These will override values in the state dict where applicable.

        Returns:
            New SpiralDataLoader instance configured to resume from the checkpoint.
        """
        # Example usage:
        #
        # Save checkpoint during training:
        #   loader = scan.to_distributed_data_loader(batch_size=32, seed=42)
        #   checkpoint = loader.state_dict()
        #
        # Resume later - uses same shards from checkpoint:
        #   resumed_loader = scan.resume_data_loader(
        #       checkpoint,
        #       batch_size=32,
        #       transform_fn=my_transform,
        #   )
        from spiral.dataloader import SpiralDataLoader

        return SpiralDataLoader.from_state_dict(self, state, **kwargs)

    def to_iterable_dataset(
        self,
        shards: list[Shard] | None = None,
        shuffle: ShuffleConfig | None = None,
        batch_readahead: int | None = None,
        infinite: bool = False,
    ) -> "hf.IterableDataset":
        """Returns a Hugging Face IterableDataset.

        Requires the `datasets` package to be installed.

        Note: For new code, consider using SpiralDataLoader instead.

        Args:
            shards: Optional list of shards to read. If None, uses scan's natural sharding.
            shuffle: Optional ShuffleConfig for configuring within-shard sample shuffling.
                If None, no shuffling is performed.
            batch_readahead: Controls how many batches to read ahead concurrently.
                If the pipeline includes work after reading (e.g. decoding, transforming, ...) this can be set higher.
                Otherwise, it should be kept low to reduce next-batch latency. Defaults to 2.
            infinite: If True, the returned IterableDataset will loop infinitely over the data,
                re-shuffling ranges after exhausting all data.
        """
        stream = self.core.to_shuffled_record_batches(
            shards=shards,
            shuffle=shuffle,
            batch_readahead=batch_readahead,
            infinite=infinite,
        )

        from spiral.iterable_dataset import to_iterable_dataset

        return to_iterable_dataset(stream)

    def shards(self) -> list[Shard]:
        """Get list of shards for this scan.

        The shards are based on the scan's physical data layout (file fragments).
        Each shard contains a key range and cardinality (set to None when unknown).

        Returns:
            List of Shard objects with key range and cardinality (if known).
        """
        return self.core.shards()

    def _debug(self):
        # Visualizes the scan, mainly for debugging purposes.
        from spiral.debug.scan import show_scan

        show_scan(self.core)

    def _dump_manifests(self):
        # Print manifests in a human-readable format.
        from spiral.debug.manifests import display_scan_manifests

        display_scan_manifests(self.core)

    def _dump_metrics(self):
        # Print metrics in a human-readable format.
        from spiral.debug.metrics import display_metrics

        display_metrics(self.metrics)
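A short usage sketch, assuming `scan` is a Scan instance obtained elsewhere in the library (scan construction is not part of this file). It streams record batches instead of materializing the full scan with to_table():

import pyarrow as pa

def head(scan, n: int = 5) -> pa.Table:
    # Stream batches rather than reading everything, stopping once n rows are buffered.
    reader = scan.to_record_batches()
    schema = reader.schema
    batches, rows = [], 0
    for batch in reader:
        batches.append(batch)
        rows += batch.num_rows
        if rows >= n:
            break
    return pa.Table.from_batches(batches, schema=schema).slice(0, n)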
spiral/server.py
ADDED
@@ -0,0 +1,17 @@
import socket
import time


def wait_for_port(port: int, host: str = "localhost", timeout: float = 5.0):
    """Wait until a port starts accepting TCP connections."""
    start_time = time.time()
    while True:
        try:
            with socket.create_connection((host, port), timeout=timeout):
                break
        except OSError as ex:
            time.sleep(0.01)
            if time.time() - start_time >= timeout:
                raise TimeoutError(
                    f"Waited too long for the port {port} on host {host} to start accepting connections."
                ) from ex
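A self-contained sketch of wait_for_port against a throwaway local listener; port 0 asks the OS for a free port:

import socket
import threading

from spiral.server import wait_for_port

listener = socket.create_server(("localhost", 0))  # port 0: OS picks a free port
port = listener.getsockname()[1]
threading.Thread(target=listener.accept, daemon=True).start()

wait_for_port(port, timeout=2.0)  # returns as soon as a connection succeeds
listener.close()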
spiral/settings.py
ADDED
@@ -0,0 +1,114 @@
import functools
import os
from pathlib import Path
from typing import TYPE_CHECKING, Annotated

import typer
from pydantic import Field, ValidatorFunctionWrapHandler, WrapValidator
from pydantic_settings import (
    BaseSettings,
    InitSettingsSource,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
)

from spiral.core.authn import Authn, DeviceCodeAuth, Token
from spiral.core.client import Spiral

if TYPE_CHECKING:
    from spiral.api import SpiralAPI

DEV = "PYTEST_VERSION" in os.environ or bool(os.environ.get("SPIRAL_DEV", None))
CI = "GITHUB_ACTIONS" in os.environ

APP_DIR = Path(typer.get_app_dir("pyspiral"))
LOG_DIR = APP_DIR / "logs"

PACKAGE_NAME = "pyspiral"


def validate_token(v, handler: ValidatorFunctionWrapHandler):
    if isinstance(v, str):
        return Token(v)
    else:
        raise ValueError("Token value must be a string")


TokenType = Annotated[Token, WrapValidator(validate_token)]


class SpiralDBSettings(BaseSettings):
    model_config = SettingsConfigDict(frozen=True)

    host: str = "localhost" if DEV else "api.spiraldb.com"
    port: int = 4279 if DEV else 443
    ssl: bool = not DEV
    token: TokenType | None = None

    @property
    def uri(self) -> str:
        return f"{'https' if self.ssl else 'http'}://{self.host}:{self.port}"


class SpfsSettings(BaseSettings):
    model_config = SettingsConfigDict(frozen=True)

    host: str = "localhost" if DEV else "spfs.spiraldb.dev"
    port: int = 4295 if DEV else 443
    ssl: bool = not DEV

    @property
    def uri(self) -> str:
        return f"{'https' if self.ssl else 'http'}://{self.host}:{self.port}"


class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_nested_delimiter="__",
        env_prefix="SPIRAL__",
        frozen=True,
    )

    spiraldb: SpiralDBSettings = Field(default_factory=SpiralDBSettings)
    spfs: SpfsSettings = Field(default_factory=SpfsSettings)
    file_format: str = Field(default="vortex")

    @functools.cached_property
    def api(self) -> "SpiralAPI":
        from spiral.api import SpiralAPI

        return SpiralAPI(self.authn, base_url=self.spiraldb.uri)

    @functools.cached_property
    def core(self) -> Spiral:
        return Spiral(
            api_url=self.spiraldb.uri,
            spfs_url=self.spfs.uri,
            authn=self.authn,
        )

    @functools.cached_property
    def authn(self):
        if self.spiraldb.token:
            return Authn.from_token(self.spiraldb.token)
        return Authn.from_fallback(self.spiraldb.uri)

    @functools.cached_property
    def device_code_auth(self) -> DeviceCodeAuth:
        return DeviceCodeAuth.default()

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        init_settings: InitSettingsSource,
        **kwargs,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return env_settings, dotenv_settings, init_settings


@functools.cache
def settings() -> Settings:
    return Settings()
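Given env_prefix "SPIRAL__" and env_nested_delimiter "__", nested settings resolve from variables like SPIRAL__SPIRALDB__HOST. A minimal sketch (values illustrative):

import os

# Nested fields map to SPIRAL__<SECTION>__<FIELD>; set before instantiation.
os.environ["SPIRAL__SPIRALDB__HOST"] = "api.example.internal"
os.environ["SPIRAL__SPIRALDB__PORT"] = "8443"

from spiral.settings import Settings

s = Settings()
assert s.spiraldb.host == "api.example.internal"
# Assuming a non-DEV environment, ssl defaults to True, so the URI is https.
assert s.spiraldb.uri == "https://api.example.internal:8443"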
spiral/snapshot.py
ADDED
@@ -0,0 +1,56 @@
from typing import TYPE_CHECKING

from spiral.core.table import Snapshot as CoreSnapshot
from spiral.core.table.spec import Schema
from spiral.types_ import Timestamp

if TYPE_CHECKING:
    import duckdb
    import polars as pl
    import pyarrow.dataset as ds
    import torch.utils.data as torchdata  # noqa

    from spiral.table import Table


class Snapshot:
    """Spiral table snapshot.

    A snapshot represents a point-in-time view of a table.
    """

    def __init__(self, table: "Table", core: CoreSnapshot):
        self.core = core
        self._table = table

    @property
    def asof(self) -> Timestamp:
        """Returns the asof timestamp of the snapshot."""
        return self.core.asof

    def schema(self) -> Schema:
        """Returns the schema of the snapshot."""
        return self.core.table.get_schema(asof=self.asof)

    @property
    def table(self) -> "Table":
        """Returns the table associated with the snapshot."""
        return self._table

    def to_dataset(self) -> "ds.Dataset":
        """Returns a PyArrow Dataset representing the table."""
        from spiral.dataset import Dataset

        return Dataset(self)

    def to_polars(self) -> "pl.LazyFrame":
        """Returns a Polars LazyFrame for the Spiral table."""
        import polars as pl

        return pl.scan_pyarrow_dataset(self.to_dataset())

    def to_duckdb(self) -> "duckdb.DuckDBPyRelation":
        """Returns a DuckDB relation for the Spiral table."""
        import duckdb

        return duckdb.from_arrow(self.to_dataset())
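A usage sketch, assuming `snapshot` is a Snapshot obtained from a Table elsewhere in the library; the column name `value` is illustrative:

# Query the point-in-time view through DuckDB; the relation wraps the
# PyArrow dataset returned by to_dataset().
rel = snapshot.to_duckdb()
n = rel.filter("value > 0").aggregate("count(*) AS n").fetchone()[0]

# Or lazily through Polars; nothing is read until collect().
df = snapshot.to_polars().select("value").head(10).collect()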