pyspiral 0.6.8__cp312-abi3-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.6.8.dist-info/METADATA +51 -0
- pyspiral-0.6.8.dist-info/RECORD +102 -0
- pyspiral-0.6.8.dist-info/WHEEL +4 -0
- pyspiral-0.6.8.dist-info/entry_points.txt +2 -0
- spiral/__init__.py +35 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +411 -0
- spiral/api/__init__.py +78 -0
- spiral/api/admin.py +15 -0
- spiral/api/client.py +164 -0
- spiral/api/filesystems.py +134 -0
- spiral/api/key_space_indexes.py +23 -0
- spiral/api/organizations.py +77 -0
- spiral/api/projects.py +219 -0
- spiral/api/telemetry.py +19 -0
- spiral/api/text_indexes.py +56 -0
- spiral/api/types.py +22 -0
- spiral/api/workers.py +40 -0
- spiral/api/workloads.py +52 -0
- spiral/arrow_.py +216 -0
- spiral/cli/__init__.py +88 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +14 -0
- spiral/cli/app.py +104 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +76 -0
- spiral/cli/iceberg.py +97 -0
- spiral/cli/key_spaces.py +89 -0
- spiral/cli/login.py +24 -0
- spiral/cli/orgs.py +89 -0
- spiral/cli/printer.py +53 -0
- spiral/cli/projects.py +147 -0
- spiral/cli/state.py +5 -0
- spiral/cli/tables.py +174 -0
- spiral/cli/telemetry.py +17 -0
- spiral/cli/text.py +115 -0
- spiral/cli/types.py +50 -0
- spiral/cli/workloads.py +58 -0
- spiral/client.py +178 -0
- spiral/core/__init__.pyi +0 -0
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/authn/__init__.pyi +27 -0
- spiral/core/client/__init__.pyi +237 -0
- spiral/core/table/__init__.pyi +101 -0
- spiral/core/table/manifests/__init__.pyi +35 -0
- spiral/core/table/metastore/__init__.pyi +58 -0
- spiral/core/table/spec/__init__.pyi +213 -0
- spiral/dataloader.py +285 -0
- spiral/dataset.py +255 -0
- spiral/datetime_.py +27 -0
- spiral/debug/__init__.py +0 -0
- spiral/debug/manifests.py +87 -0
- spiral/debug/metrics.py +56 -0
- spiral/debug/scan.py +266 -0
- spiral/expressions/__init__.py +276 -0
- spiral/expressions/base.py +157 -0
- spiral/expressions/http.py +86 -0
- spiral/expressions/io.py +100 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/mp4.py +62 -0
- spiral/expressions/png.py +18 -0
- spiral/expressions/qoi.py +18 -0
- spiral/expressions/refs.py +58 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +59 -0
- spiral/expressions/text.py +62 -0
- spiral/expressions/tiff.py +223 -0
- spiral/expressions/udf.py +46 -0
- spiral/grpc_.py +32 -0
- spiral/iceberg.py +31 -0
- spiral/iterable_dataset.py +106 -0
- spiral/key_space_index.py +44 -0
- spiral/project.py +199 -0
- spiral/protogen/_/__init__.py +0 -0
- spiral/protogen/_/arrow/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
- spiral/protogen/_/google/__init__.py +0 -0
- spiral/protogen/_/google/protobuf/__init__.py +2310 -0
- spiral/protogen/_/message_pool.py +3 -0
- spiral/protogen/_/py.typed +0 -0
- spiral/protogen/_/scandal/__init__.py +190 -0
- spiral/protogen/_/spfs/__init__.py +72 -0
- spiral/protogen/_/spql/__init__.py +61 -0
- spiral/protogen/_/substrait/__init__.py +6196 -0
- spiral/protogen/_/substrait/extensions/__init__.py +169 -0
- spiral/protogen/__init__.py +0 -0
- spiral/protogen/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan.py +285 -0
- spiral/server.py +17 -0
- spiral/settings.py +114 -0
- spiral/snapshot.py +56 -0
- spiral/streaming_/__init__.py +3 -0
- spiral/streaming_/reader.py +133 -0
- spiral/streaming_/stream.py +157 -0
- spiral/substrait_.py +274 -0
- spiral/table.py +293 -0
- spiral/text_index.py +17 -0
- spiral/transaction.py +58 -0
- spiral/types_.py +6 -0
@@ -0,0 +1,133 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import functools
|
3
|
+
import os
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import vortex as vx
|
7
|
+
|
8
|
+
from spiral.core.client import Shard
|
9
|
+
|
10
|
+
|
11
|
+
# Stand-in for streaming.base.format.base.reader.FileInfo.
@dataclasses.dataclass
class FileInfo:
    """Fake FileInfo handed to MDS in place of the real one.

    In MDS the Dataset (not the Stream) manages decompression, so the
    compressed-file slot is always None and only the raw ``basename`` matters.
    """

    # Shard file name on disk.
    basename: str
    # Optional content hashes; Spiral shards carry none by default.
    hashes: dict[str, str] = dataclasses.field(default_factory=dict)

    @property
    def bytes(self):
        # Deliberately unimplemented: MDS must never ask for raw bytes here.
        raise NotImplementedError("FileInfo.bytes should NOT be called.")
|
22
|
+
|
23
|
+
|
24
|
+
class SpiralReader:
    """An MDS (streaming) compatible Reader over a single Spiral shard."""

    def __init__(self, shard: Shard, basepath):
        self._shard = shard
        if shard.cardinality is None:
            raise ValueError("Shard cardinality must be known for `streaming`.")
        self._cardinality = shard.cardinality
        self._basepath = basepath
        # Lazily opened scan handle; created on first `get_item` call.
        self._scan: vx.RepeatedScan | None = None

    @property
    def shard(self) -> Shard:
        return self._shard

    @property
    def size(self):
        """Number of samples in this shard."""
        return self._cardinality

    @property
    def samples(self):
        """Number of samples in this shard (alias of ``size``)."""
        return self._cardinality

    def __len__(self) -> int:
        """Number of samples in this shard."""
        return self._cardinality

    @property
    def file_pairs(self) -> list[tuple[FileInfo, FileInfo | None]]:
        """Raw/compressed file-info pairs for this shard.

        MDS datasets manage shard decompression themselves, so the
        compressed entry is always None.
        """
        return [(FileInfo(basename=self.filename), None)]

    def get_max_size(self) -> int:
        """Maximum disk usage this shard can reach, in bytes.

        "Max" means both the raw (decompressed) and zip (compressed) forms
        resident at once, which can happen transiently during decompression.
        """
        # TODO(marko): This is used to check cache limit is possible...
        return 0

    @functools.cached_property
    def filename(self) -> str:
        """Used by SpiralStream to identify shard's file-on-disk, if it exists."""
        # TODO(marko): This might be too long...
        begin_hex = bytes(self._shard.key_range.begin).hex()
        end_hex = bytes(self._shard.key_range.end).hex()
        return "_".join((begin_hex, end_hex, str(self._shard.cardinality))) + ".vortex"

    @functools.cached_property
    def filepath(self) -> str:
        """Full path to the shard's file-on-disk, if it exists."""
        return os.path.join(self._basepath, self.filename)

    def evict(self) -> int:
        """Remove all files belonging to this shard; returns bytes freed."""
        # Drop the scan handle first so its memory is released before the
        # backing file disappears.
        self._scan = None

        try:
            nbytes = os.stat(self.filepath).st_size
            os.remove(self.filepath)
            return nbytes
        except FileNotFoundError:
            # Nothing on disk to evict.
            return 0

    def __getitem__(self, item):
        return self.get_item(item)

    def get_item(self, idx: int) -> dict[str, Any]:
        """Read the sample at ``idx``, opening the on-disk scan on first use."""
        if self._scan is None:
            # TODO(marko): vx.open should throw FileNotFoundError instead of
            # ValueError: No such file or directory (os error 2)
            # Callers rely on FileNotFoundError when the shard is missing,
            # so check explicitly before opening.
            if not os.path.exists(self.filepath):
                raise FileNotFoundError(f"Shard not found: {self.filepath}")
            self._scan = vx.open(self.filepath, without_segment_cache=True).to_repeated_scan()
        return self._scan.scalar_at(idx).as_py()
|
@@ -0,0 +1,157 @@
|
|
1
|
+
import os
|
2
|
+
import tempfile
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from spiral import Scan, Spiral
|
8
|
+
from spiral.core.client import Shard
|
9
|
+
from spiral.streaming_.reader import SpiralReader
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from streaming.base.array import NDArray
|
13
|
+
from streaming.base.format import Reader
|
14
|
+
from streaming.base.world import World
|
15
|
+
|
16
|
+
|
17
|
+
class SpiralStream:
    """An MDS (streaming) compatible Stream.

    Does not subclass the default Stream, but is compatible with its API.

    The stream is never registered with MDS; the only way to construct it is
    through the Spiral client. Instances can be handed to MDS's
    StreamingDataset via the `streams` argument.
    """

    def __init__(
        self,
        sp: Spiral,
        scan: Scan,
        shards: list[Shard],
        cache_dir: str | None = None,
        shard_row_block_size: int | None = None,
    ):
        self._sp = sp
        self._scan = scan
        # TODO(marko): Read shards only on world.is_local_leader in `get_shards` and materialize on disk.
        self._shards = shards
        self._shard_row_block_size = shard_row_block_size or 8192

        if cache_dir is None:
            # No explicit cache dir: fall back to a shared temp location.
            cache_dir = os.path.join(tempfile.gettempdir(), "spiral-streaming")
        else:
            if not os.path.exists(cache_dir):
                os.makedirs(cache_dir, exist_ok=True)
            if not os.path.isdir(cache_dir):
                raise ValueError(f"Cache dir {cache_dir} is not a directory.")
        self._cache_dir = cache_dir

        # Ensure the split directory exists.
        os.makedirs(os.path.join(self._cache_dir, self.split), exist_ok=True)

    @property
    def local(self) -> str:
        # Dataset: Register/lookup our shared memory prefix and filelock root directory.
        return self._cache_dir

    @property
    def remote(self) -> str | None:
        # Dataset: Register/lookup our shared memory prefix and filelock root directory.
        return None

    @property
    def split(self) -> str:
        # Dataset: Register/lookup our shared memory prefix and filelock root directory.
        return "default"

    @classmethod
    def validate_weights(cls, streams) -> tuple[bool, bool]:
        # Delegate to the MDS implementation; imported lazily to avoid a hard dependency.
        from streaming.base.stream import Stream

        return Stream.validate_weights(streams)

    @classmethod
    def apply_weights(cls, streams, samples_per_stream, choose_per_epoch, seed) -> int:
        # Delegate to the MDS implementation; imported lazily to avoid a hard dependency.
        from streaming.base.stream import Stream

        return Stream.apply_weights(streams, samples_per_stream, choose_per_epoch, seed)

    def apply_default(self, default: dict):
        # Applies defaults from the StreamingDataset.
        # 'remote', 'local', 'split', 'download_retry', 'download_timeout', 'validate_hash', 'keep_zip'
        if default["split"] is not None:
            raise ValueError("SpiralStream does not support split, as the split is defined in the Scan.")

    def prepare_shard(self, shard: "Reader") -> int:
        """Ensure (download, validate, extract, etc.) that we have the given shard.

        Args:
            shard (Reader): Which shard.

        Returns:
            int: Change in cache usage.
        """
        if not isinstance(shard, SpiralReader):
            raise ValueError("Only SpiralReader is supported in SpiralStream")

        shard_path = os.path.join(self._cache_dir, self.split, shard.filename)
        if os.path.exists(shard_path):
            # Already materialized; no change in cache usage.
            return 0

        # Prepare the shard, writing it to disk.
        self._sp._ops().prepare_shard(
            shard_path, self._scan.core, shard.shard, row_block_size=self._shard_row_block_size
        )

        # Report the size of the newly written file.
        return os.stat(shard_path).st_size

    def get_shards(self, world: "World", allow_unsafe_types: bool) -> list["Reader"]:
        """Load this Stream's index, retrieving its shard readers.

        Args:
            world (World): Distributed context.
            allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
                execution during deserialization, whether to keep going if ``True`` or raise an error.
                This argument is ignored as SpiralStream does not support Pickle.

        Returns:
            List[Reader]: Shard readers.
        """
        basepath = os.path.join(self._cache_dir, self.split)
        return [SpiralReader(shard, basepath) for shard in self._shards]  # type: ignore[return-value]

    def set_up_local(self, shards: list["Reader"], cache_usage_per_shard: "NDArray[np.int64]") -> None:
        """Bring a local directory into a consistent state, getting which shards are present.

        Args:
            shards (List[Reader]): List of this stream's shards.
            cache_usage_per_shard (NDArray[np.int64]): Cache usage per shard of this stream.
        """
        split_dir = os.path.join(self._cache_dir, self.split)
        on_disk = {
            name
            for name in os.listdir(split_dir)
            if name.endswith(".vortex") and os.path.isfile(os.path.join(split_dir, name))
        }

        # Determine which shards are present, making local dir consistent.
        for i, shard in enumerate(shards):
            if not isinstance(shard, SpiralReader):
                raise ValueError("Only SpiralReader is supported in SpiralStream")
            if shard.filename in on_disk:
                # Record the on-disk size of the materialized shard.
                cache_usage_per_shard[i] = os.stat(os.path.join(split_dir, shard.filename)).st_size
            else:
                cache_usage_per_shard[i] = 0

    def get_index_size(self) -> int:
        """Get the size of the index file in bytes.

        Returns:
            int: Size in bytes.
        """
        # There is no index file stored on disk.
        return 0
|
spiral/substrait_.py
ADDED
@@ -0,0 +1,274 @@
|
|
1
|
+
import betterproto2
|
2
|
+
import pyarrow as pa
|
3
|
+
|
4
|
+
import spiral.expressions as se
|
5
|
+
from spiral.expressions.base import Expr
|
6
|
+
from spiral.protogen._.substrait import (
|
7
|
+
Expression,
|
8
|
+
ExpressionFieldReference,
|
9
|
+
ExpressionLiteral,
|
10
|
+
ExpressionLiteralList,
|
11
|
+
ExpressionLiteralStruct,
|
12
|
+
ExpressionLiteralUserDefined,
|
13
|
+
ExpressionMaskExpression,
|
14
|
+
ExpressionReferenceSegment,
|
15
|
+
ExpressionReferenceSegmentListElement,
|
16
|
+
ExpressionReferenceSegmentStructField,
|
17
|
+
ExpressionScalarFunction,
|
18
|
+
ExtendedExpression,
|
19
|
+
)
|
20
|
+
from spiral.protogen._.substrait.extensions import (
|
21
|
+
SimpleExtensionDeclaration,
|
22
|
+
SimpleExtensionDeclarationExtensionFunction,
|
23
|
+
SimpleExtensionDeclarationExtensionType,
|
24
|
+
SimpleExtensionDeclarationExtensionTypeVariation,
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
class SubstraitConverter:
|
29
|
+
def __init__(self, scope: Expr, schema: pa.Schema, key_schema: pa.Schema):
|
30
|
+
self.scope = scope
|
31
|
+
self.schema = schema
|
32
|
+
self.key_names = set(key_schema.names)
|
33
|
+
|
34
|
+
# Extension URIs, keyed by extension URI anchor
|
35
|
+
self.extension_uris = {}
|
36
|
+
|
37
|
+
# Functions, keyed by function_anchor
|
38
|
+
self.functions = {}
|
39
|
+
|
40
|
+
# Types, keyed by type anchor
|
41
|
+
self.type_factories = {}
|
42
|
+
|
43
|
+
def convert(self, buffer: pa.Buffer) -> Expr:
|
44
|
+
"""Convert a Substrait Extended Expression into a Spiral expression."""
|
45
|
+
|
46
|
+
expr: ExtendedExpression = ExtendedExpression().parse(buffer)
|
47
|
+
assert len(expr.referred_expr) == 1, "Only one expression is supported"
|
48
|
+
|
49
|
+
# Parse the extension URIs from the plan.
|
50
|
+
for ext_uri in expr.extension_uris:
|
51
|
+
self.extension_uris[ext_uri.extension_uri_anchor] = ext_uri.uri
|
52
|
+
|
53
|
+
# Parse the extensions from the plan.
|
54
|
+
for ext in expr.extensions:
|
55
|
+
self._extension_declaration(ext)
|
56
|
+
|
57
|
+
# Convert the expression
|
58
|
+
return self._expr(expr.referred_expr[0].expression)
|
59
|
+
|
60
|
+
def _extension_declaration(self, ext: SimpleExtensionDeclaration):
|
61
|
+
match betterproto2.which_one_of(ext, "mapping_type"):
|
62
|
+
case "extension_function", ext_func:
|
63
|
+
self._extension_function(ext_func)
|
64
|
+
case "extension_type", ext_type:
|
65
|
+
self._extension_type(ext_type)
|
66
|
+
case "extension_type_variation", ext_type_variation:
|
67
|
+
self._extension_type_variation(ext_type_variation)
|
68
|
+
case _:
|
69
|
+
raise AssertionError("Invalid substrait plan")
|
70
|
+
|
71
|
+
def _extension_function(self, ext: SimpleExtensionDeclarationExtensionFunction):
|
72
|
+
ext_uri: str = self.extension_uris[ext.extension_uri_reference]
|
73
|
+
match ext_uri:
|
74
|
+
case "https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml":
|
75
|
+
match ext.name:
|
76
|
+
case "or":
|
77
|
+
self.functions[ext.function_anchor] = se.or_
|
78
|
+
case "and":
|
79
|
+
self.functions[ext.function_anchor] = se.and_
|
80
|
+
case "xor":
|
81
|
+
self.functions[ext.function_anchor] = se.xor
|
82
|
+
case "not":
|
83
|
+
self.functions[ext.function_anchor] = se.not_
|
84
|
+
case _:
|
85
|
+
raise NotImplementedError(f"Function name {ext.name} not supported")
|
86
|
+
case "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml":
|
87
|
+
match ext.name:
|
88
|
+
case "equal":
|
89
|
+
self.functions[ext.function_anchor] = se.eq
|
90
|
+
case "not_equal":
|
91
|
+
self.functions[ext.function_anchor] = se.neq
|
92
|
+
case "lt":
|
93
|
+
self.functions[ext.function_anchor] = se.lt
|
94
|
+
case "lte":
|
95
|
+
self.functions[ext.function_anchor] = se.lte
|
96
|
+
case "gt":
|
97
|
+
self.functions[ext.function_anchor] = se.gt
|
98
|
+
case "gte":
|
99
|
+
self.functions[ext.function_anchor] = se.gte
|
100
|
+
case "is_null":
|
101
|
+
self.functions[ext.function_anchor] = se.is_null
|
102
|
+
case "is_not_null":
|
103
|
+
self.functions[ext.function_anchor] = se.is_not_null
|
104
|
+
case _:
|
105
|
+
raise NotImplementedError(f"Function name {ext.name} not supported")
|
106
|
+
case uri:
|
107
|
+
raise NotImplementedError(f"Function extension URI {uri} not supported")
|
108
|
+
|
109
|
+
def _extension_type(self, ext: SimpleExtensionDeclarationExtensionType):
|
110
|
+
ext_uri: str = self.extension_uris[ext.extension_uri_reference]
|
111
|
+
match ext_uri:
|
112
|
+
case "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml":
|
113
|
+
match ext.name:
|
114
|
+
case "null":
|
115
|
+
self.type_factories[ext.type_anchor] = pa.null
|
116
|
+
case "interval_month_day_nano":
|
117
|
+
self.type_factories[ext.type_anchor] = pa.month_day_nano_interval
|
118
|
+
case "u8":
|
119
|
+
self.type_factories[ext.type_anchor] = pa.uint8
|
120
|
+
case "u16":
|
121
|
+
self.type_factories[ext.type_anchor] = pa.uint16
|
122
|
+
case "u32":
|
123
|
+
self.type_factories[ext.type_anchor] = pa.uint32
|
124
|
+
case "u64":
|
125
|
+
self.type_factories[ext.type_anchor] = pa.uint64
|
126
|
+
case "fp16":
|
127
|
+
self.type_factories[ext.type_anchor] = pa.float16
|
128
|
+
case "date_millis":
|
129
|
+
self.type_factories[ext.type_anchor] = pa.date64
|
130
|
+
case "time_seconds":
|
131
|
+
self.type_factories[ext.type_anchor] = lambda: pa.time32("s")
|
132
|
+
case "time_millis":
|
133
|
+
self.type_factories[ext.type_anchor] = lambda: pa.time32("ms")
|
134
|
+
case "time_nanos":
|
135
|
+
self.type_factories[ext.type_anchor] = lambda: pa.time64("ns")
|
136
|
+
case "large_string":
|
137
|
+
self.type_factories[ext.type_anchor] = pa.large_string
|
138
|
+
case "large_binary":
|
139
|
+
self.type_factories[ext.type_anchor] = pa.large_binary
|
140
|
+
case "decimal256":
|
141
|
+
self.type_factories[ext.type_anchor] = pa.decimal256
|
142
|
+
case "large_list":
|
143
|
+
self.type_factories[ext.type_anchor] = pa.large_list
|
144
|
+
case "fixed_size_list":
|
145
|
+
self.type_factories[ext.type_anchor] = pa.list_
|
146
|
+
case "duration":
|
147
|
+
self.type_factories[ext.type_anchor] = pa.duration
|
148
|
+
case uri:
|
149
|
+
raise NotImplementedError(f"Type extension URI {uri} not support")
|
150
|
+
|
151
|
+
def _extension_type_variation(self, ext: SimpleExtensionDeclarationExtensionTypeVariation):
|
152
|
+
raise NotImplementedError()
|
153
|
+
|
154
|
+
def _expr(self, expr: Expression) -> Expr:
|
155
|
+
match betterproto2.which_one_of(expr, "rex_type"):
|
156
|
+
case "literal", e:
|
157
|
+
return self._expr_literal(e)
|
158
|
+
case "selection", e:
|
159
|
+
return self._expr_selection(e)
|
160
|
+
case "scalar_function", e:
|
161
|
+
return self._expr_scalar_function(e)
|
162
|
+
case "window_function", _:
|
163
|
+
raise ValueError("Window functions are not supported in Spiral push-down")
|
164
|
+
case "if_then", e:
|
165
|
+
return self._expr_if_then(e)
|
166
|
+
case "switch", e:
|
167
|
+
return self._expr_switch(e)
|
168
|
+
case "singular_or_list", _:
|
169
|
+
raise ValueError("singular_or_list is not supported in Spiral push-down")
|
170
|
+
case "multi_or_list", _:
|
171
|
+
raise ValueError("multi_or_list is not supported in Spiral push-down")
|
172
|
+
case "cast", e:
|
173
|
+
return self._expr_cast(e)
|
174
|
+
case "subquery", _:
|
175
|
+
raise ValueError("Subqueries are not supported in Spiral push-down")
|
176
|
+
case "nested", e:
|
177
|
+
return self._expr_nested(e)
|
178
|
+
case _:
|
179
|
+
raise NotImplementedError(f"Expression type {expr.rex_type} not implemented")
|
180
|
+
|
181
|
+
def _expr_literal(self, expr: ExpressionLiteral):
|
182
|
+
# TODO(ngates): the Spiral literal expression is quite weakly typed...
|
183
|
+
# Maybe we can switch to Vortex?
|
184
|
+
simple = {
|
185
|
+
"boolean",
|
186
|
+
"i8",
|
187
|
+
"i16",
|
188
|
+
"i32",
|
189
|
+
"i64",
|
190
|
+
"fp32",
|
191
|
+
"fp64",
|
192
|
+
"string",
|
193
|
+
"binary",
|
194
|
+
"fixed_char",
|
195
|
+
"var_char",
|
196
|
+
"fixed_binary",
|
197
|
+
}
|
198
|
+
|
199
|
+
match betterproto2.which_one_of(expr, "literal_type"):
|
200
|
+
case type_, v if type_ in simple:
|
201
|
+
return se.scalar(pa.scalar(v))
|
202
|
+
case "timestamp", v:
|
203
|
+
return se.scalar(pa.scalar(v, type=pa.timestamp("us")))
|
204
|
+
case "date", v:
|
205
|
+
return se.scalar(pa.scalar(v, type=pa.date32()))
|
206
|
+
case "time", v:
|
207
|
+
# Substrait time is us since midnight. PyArrow only supports ms.
|
208
|
+
v: int
|
209
|
+
v = int(v / 1000)
|
210
|
+
return se.scalar(pa.scalar(v, type=pa.time32("ms")))
|
211
|
+
case "null", _null_type:
|
212
|
+
# We need a typed null value
|
213
|
+
raise NotImplementedError()
|
214
|
+
case "struct", v:
|
215
|
+
v: ExpressionLiteralStruct
|
216
|
+
# Hmm, v has fields, but no field names. I guess we return a list and the type is applied later?
|
217
|
+
raise NotImplementedError()
|
218
|
+
case "list", v:
|
219
|
+
v: ExpressionLiteralList
|
220
|
+
return pa.scalar([self._expr_literal(e) for e in v.values])
|
221
|
+
case "user_defined", v:
|
222
|
+
v: ExpressionLiteralUserDefined
|
223
|
+
raise NotImplementedError()
|
224
|
+
case literal_type, _:
|
225
|
+
raise NotImplementedError(f"Literal type not supported: {literal_type}")
|
226
|
+
|
227
|
+
def _expr_selection(self, expr: ExpressionFieldReference):
|
228
|
+
match betterproto2.which_one_of(expr, "root_type"):
|
229
|
+
case "root_reference", _:
|
230
|
+
# The reference is relative to the root
|
231
|
+
base_expr = self.scope
|
232
|
+
base_type = pa.struct(self.schema)
|
233
|
+
case _:
|
234
|
+
raise NotImplementedError("Only root_reference expressions are supported")
|
235
|
+
|
236
|
+
match betterproto2.which_one_of(expr, "reference_type"):
|
237
|
+
case "direct_reference", direct_ref:
|
238
|
+
return self._expr_direct_reference(base_expr, base_type, direct_ref)
|
239
|
+
case "masked_reference", masked_ref:
|
240
|
+
return self._expr_masked_reference(base_expr, base_type, masked_ref)
|
241
|
+
case _:
|
242
|
+
raise NotImplementedError()
|
243
|
+
|
244
|
+
def _expr_direct_reference(self, scope: Expr, scope_type: pa.StructType, expr: ExpressionReferenceSegment):
|
245
|
+
match betterproto2.which_one_of(expr, "reference_type"):
|
246
|
+
case "map_key", ref:
|
247
|
+
raise NotImplementedError("Map types not yet supported in Spiral")
|
248
|
+
case "struct_field", ref:
|
249
|
+
ref: ExpressionReferenceSegmentStructField
|
250
|
+
field_name = scope_type.field(ref.field).name
|
251
|
+
scope = se.getitem(scope, field_name)
|
252
|
+
scope_type = scope_type.field(ref.field).type
|
253
|
+
if ref.is_set("child"):
|
254
|
+
return self._expr_direct_reference(scope, scope_type, ref.child)
|
255
|
+
return scope
|
256
|
+
case "list_element", ref:
|
257
|
+
ref: ExpressionReferenceSegmentListElement
|
258
|
+
scope = se.getitem(scope, ref.offset)
|
259
|
+
scope_type = scope_type.field(ref.field).type
|
260
|
+
if ref.is_set("child"):
|
261
|
+
return self._expr_direct_reference(scope, scope_type, ref.child)
|
262
|
+
return scope
|
263
|
+
case "", ref:
|
264
|
+
# Because Proto... we hit this case when we recurse into a child node and it's actually "None".
|
265
|
+
return scope
|
266
|
+
case _:
|
267
|
+
raise NotImplementedError()
|
268
|
+
|
269
|
+
def _expr_masked_reference(self, scope: Expr, scope_type: pa.StructType, expr: ExpressionMaskExpression):
|
270
|
+
raise NotImplementedError("Masked references are not yet supported in Spiral push-down")
|
271
|
+
|
272
|
+
def _expr_scalar_function(self, expr: ExpressionScalarFunction):
|
273
|
+
args = [self._expr(arg.value) for arg in expr.arguments]
|
274
|
+
return self.functions[expr.function_reference](*args)
|