pyspiral 0.6.6__cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. pyspiral-0.6.6.dist-info/METADATA +51 -0
  2. pyspiral-0.6.6.dist-info/RECORD +102 -0
  3. pyspiral-0.6.6.dist-info/WHEEL +4 -0
  4. pyspiral-0.6.6.dist-info/entry_points.txt +2 -0
  5. spiral/__init__.py +35 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +411 -0
  8. spiral/api/__init__.py +78 -0
  9. spiral/api/admin.py +15 -0
  10. spiral/api/client.py +164 -0
  11. spiral/api/filesystems.py +134 -0
  12. spiral/api/key_space_indexes.py +23 -0
  13. spiral/api/organizations.py +77 -0
  14. spiral/api/projects.py +219 -0
  15. spiral/api/telemetry.py +19 -0
  16. spiral/api/text_indexes.py +56 -0
  17. spiral/api/types.py +22 -0
  18. spiral/api/workers.py +40 -0
  19. spiral/api/workloads.py +52 -0
  20. spiral/arrow_.py +216 -0
  21. spiral/cli/__init__.py +88 -0
  22. spiral/cli/__main__.py +4 -0
  23. spiral/cli/admin.py +14 -0
  24. spiral/cli/app.py +104 -0
  25. spiral/cli/console.py +95 -0
  26. spiral/cli/fs.py +76 -0
  27. spiral/cli/iceberg.py +97 -0
  28. spiral/cli/key_spaces.py +89 -0
  29. spiral/cli/login.py +24 -0
  30. spiral/cli/orgs.py +89 -0
  31. spiral/cli/printer.py +53 -0
  32. spiral/cli/projects.py +147 -0
  33. spiral/cli/state.py +5 -0
  34. spiral/cli/tables.py +174 -0
  35. spiral/cli/telemetry.py +17 -0
  36. spiral/cli/text.py +115 -0
  37. spiral/cli/types.py +50 -0
  38. spiral/cli/workloads.py +58 -0
  39. spiral/client.py +178 -0
  40. spiral/core/__init__.pyi +0 -0
  41. spiral/core/_tools/__init__.pyi +5 -0
  42. spiral/core/authn/__init__.pyi +27 -0
  43. spiral/core/client/__init__.pyi +237 -0
  44. spiral/core/table/__init__.pyi +101 -0
  45. spiral/core/table/manifests/__init__.pyi +35 -0
  46. spiral/core/table/metastore/__init__.pyi +58 -0
  47. spiral/core/table/spec/__init__.pyi +213 -0
  48. spiral/dataloader.py +285 -0
  49. spiral/dataset.py +255 -0
  50. spiral/datetime_.py +27 -0
  51. spiral/debug/__init__.py +0 -0
  52. spiral/debug/manifests.py +87 -0
  53. spiral/debug/metrics.py +56 -0
  54. spiral/debug/scan.py +266 -0
  55. spiral/expressions/__init__.py +276 -0
  56. spiral/expressions/base.py +157 -0
  57. spiral/expressions/http.py +86 -0
  58. spiral/expressions/io.py +100 -0
  59. spiral/expressions/list_.py +68 -0
  60. spiral/expressions/mp4.py +62 -0
  61. spiral/expressions/png.py +18 -0
  62. spiral/expressions/qoi.py +18 -0
  63. spiral/expressions/refs.py +58 -0
  64. spiral/expressions/str_.py +39 -0
  65. spiral/expressions/struct.py +59 -0
  66. spiral/expressions/text.py +62 -0
  67. spiral/expressions/tiff.py +223 -0
  68. spiral/expressions/udf.py +46 -0
  69. spiral/grpc_.py +32 -0
  70. spiral/iceberg.py +31 -0
  71. spiral/iterable_dataset.py +106 -0
  72. spiral/key_space_index.py +44 -0
  73. spiral/project.py +199 -0
  74. spiral/protogen/_/__init__.py +0 -0
  75. spiral/protogen/_/arrow/__init__.py +0 -0
  76. spiral/protogen/_/arrow/flight/__init__.py +0 -0
  77. spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
  78. spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
  79. spiral/protogen/_/google/__init__.py +0 -0
  80. spiral/protogen/_/google/protobuf/__init__.py +2310 -0
  81. spiral/protogen/_/message_pool.py +3 -0
  82. spiral/protogen/_/py.typed +0 -0
  83. spiral/protogen/_/scandal/__init__.py +190 -0
  84. spiral/protogen/_/spfs/__init__.py +72 -0
  85. spiral/protogen/_/spql/__init__.py +61 -0
  86. spiral/protogen/_/substrait/__init__.py +6196 -0
  87. spiral/protogen/_/substrait/extensions/__init__.py +169 -0
  88. spiral/protogen/__init__.py +0 -0
  89. spiral/protogen/util.py +41 -0
  90. spiral/py.typed +0 -0
  91. spiral/scan.py +285 -0
  92. spiral/server.py +17 -0
  93. spiral/settings.py +114 -0
  94. spiral/snapshot.py +56 -0
  95. spiral/streaming_/__init__.py +3 -0
  96. spiral/streaming_/reader.py +133 -0
  97. spiral/streaming_/stream.py +157 -0
  98. spiral/substrait_.py +274 -0
  99. spiral/table.py +293 -0
  100. spiral/text_index.py +17 -0
  101. spiral/transaction.py +58 -0
  102. spiral/types_.py +6 -0
@@ -0,0 +1,133 @@
1
+ import dataclasses
2
+ import functools
3
+ import os
4
+ from typing import Any
5
+
6
+ import vortex as vx
7
+
8
+ from spiral.core.client import Shard
9
+
10
+
11
# Drop-in stand-in for streaming.base.format.base.reader.FileInfo.
# In MDS it is the Dataset (not the Stream) that handles decompression, so the
# reader hands out this fake FileInfo paired with None for the compressed file.
@dataclasses.dataclass
class FileInfo:
    """Minimal file descriptor consumed by MDS shard bookkeeping."""

    # Name of the shard file on disk, relative to the stream's split directory.
    basename: str
    # Optional content hashes keyed by algorithm name; unused by Spiral shards.
    hashes: dict[str, str] = dataclasses.field(default_factory=dict)

    @property
    def bytes(self):
        # Spiral never exposes a raw byte size here; fail loudly if requested.
        raise NotImplementedError("FileInfo.bytes should NOT be called.")
22
+
23
+
24
class SpiralReader:
    """
    An MDS (streaming) compatible Reader.

    Serves samples for a single Spiral shard materialized as a Vortex file
    under ``basepath``. Implements the subset of the MDS ``Reader`` surface
    that ``SpiralStream`` and ``StreamingDataset`` call: sample counts,
    file bookkeeping (``file_pairs`` / ``evict``), and per-sample access
    (``get_item``).
    """

    def __init__(self, shard: Shard, basepath: str):
        self._shard = shard
        # MDS partitions samples globally up front, so the shard's sample
        # count must be known at construction time.
        if shard.cardinality is None:
            raise ValueError("Shard cardinality must be known for `streaming`.")
        self._cardinality = shard.cardinality
        self._basepath = basepath
        # Lazily-opened scan over the shard file; created on first get_item()
        # and released in evict() so memory can be reclaimed.
        self._scan: vx.RepeatedScan | None = None

    @property
    def shard(self) -> Shard:
        """The Spiral shard backing this reader."""
        return self._shard

    @property
    def size(self):
        """Get the number of samples in this shard.

        Returns:
            int: Sample count.
        """
        return self._cardinality

    @property
    def samples(self):
        """Get the number of samples in this shard.

        Returns:
            int: Sample count.
        """
        return self._cardinality

    def __len__(self) -> int:
        """Get the number of samples in this shard.

        Returns:
            int: Sample count.
        """
        return self._cardinality

    @property
    def file_pairs(self) -> list[tuple[FileInfo, FileInfo | None]]:
        """Get the infos from raw and compressed file.

        MDS uses this because dataset manages decompression of the shards, not stream...
        """
        # Spiral shards have no compressed form, hence the None placeholder.
        return [(FileInfo(basename=self.filename), None)]

    def get_max_size(self) -> int:
        """Get the full size of this shard.

        "Max" in this case means both the raw (decompressed) and zip (compressed) versions are
        resident (assuming it has a zip form). This is the maximum disk usage the shard can reach.
        When compressed was used, even if keep_zip is ``False``, the zip form must still be
        resident at the same time as the raw form during shard decompression.

        Returns:
            int: Size in bytes.
        """
        # TODO(marko): This is used to check cache limit is possible...
        return 0

    @functools.cached_property
    def filename(self) -> str:
        """Used by SpiralStream to identify shard's file-on-disk, if it exists."""
        # The name encodes the shard's key range (hex) plus its cardinality so
        # distinct shards never collide on disk.
        # TODO(marko): This might be too long...
        return (
            bytes(self._shard.key_range.begin).hex()
            + "_"
            + bytes(self._shard.key_range.end).hex()
            + "_"
            + str(self._shard.cardinality)
            + ".vortex"
        )

    @functools.cached_property
    def filepath(self) -> str:
        """Full path to the shard's file-on-disk, if it exists."""
        return os.path.join(self._basepath, self.filename)

    def evict(self) -> int:
        """Remove all files belonging to this shard.

        Returns:
            int: Bytes freed, or 0 if the shard file was not on disk.
        """

        # Clean up the scan handle first. This will make sure memory is freed.
        self._scan = None

        # Try to evict file.
        try:
            stat = os.stat(self.filepath)
            os.remove(self.filepath)
            return stat.st_size
        except FileNotFoundError:
            # Nothing to evict.
            return 0

    def __getitem__(self, item):
        # Alias for get_item so the reader supports subscripting.
        return self.get_item(item)

    def get_item(self, idx: int) -> dict[str, Any]:
        """Read the sample at ``idx`` from the shard file, as a plain dict.

        Raises:
            FileNotFoundError: If the shard has not been materialized on disk.
        """
        if self._scan is None:
            # TODO(marko): vx.open should throw FileNotFoundError instead of
            # ValueError: No such file or directory (os error 2)
            # Check if shard is ready on disk. This must throw FileNotFoundError.
            if not os.path.exists(self.filepath):
                raise FileNotFoundError(f"Shard not found: {self.filepath}")
            self._scan = vx.open(self.filepath, without_segment_cache=True).to_repeated_scan()
        return self._scan.scalar_at(idx).as_py()
@@ -0,0 +1,157 @@
1
+ import os
2
+ import tempfile
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+
7
+ from spiral import Scan, Spiral
8
+ from spiral.core.client import Shard
9
+ from spiral.streaming_.reader import SpiralReader
10
+
11
+ if TYPE_CHECKING:
12
+ from streaming.base.array import NDArray
13
+ from streaming.base.format import Reader
14
+ from streaming.base.world import World
15
+
16
+
17
+ class SpiralStream:
18
+ """
19
+ An MDS (streaming) compatible Stream.
20
+
21
+ The stream does not extend the default Stream class, but it is compactible with its API.
22
+
23
+ The stream is not registered with MDS, as the only way to construct the stream is through Spiral client.
24
+ Stream can be passed to MDS's StreamingDataset in `streams` argument.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ sp: Spiral,
30
+ scan: Scan,
31
+ shards: list[Shard],
32
+ cache_dir: str | None = None,
33
+ shard_row_block_size: int | None = None,
34
+ ):
35
+ self._sp = sp
36
+ self._scan = scan
37
+ # TODO(marko): Read shards only on world.is_local_leader in `get_shards` and materialize on disk.
38
+ self._shards = shards
39
+ self._shard_row_block_size = shard_row_block_size or 8192
40
+
41
+ if cache_dir is not None:
42
+ if not os.path.exists(cache_dir):
43
+ os.makedirs(cache_dir, exist_ok=True)
44
+ if not os.path.isdir(cache_dir):
45
+ raise ValueError(f"Cache dir {cache_dir} is not a directory.")
46
+ else:
47
+ cache_dir = os.path.join(tempfile.gettempdir(), "spiral-streaming")
48
+ self._cache_dir = cache_dir
49
+
50
+ # Enure split directory exists.
51
+ os.makedirs(os.path.join(self._cache_dir, self.split), exist_ok=True)
52
+
53
+ @property
54
+ def local(self) -> str:
55
+ # Dataset: Register/lookup our shared memory prefix and filelock root directory.
56
+ return self._cache_dir
57
+
58
+ @property
59
+ def remote(self) -> str | None:
60
+ # Dataset: Register/lookup our shared memory prefix and filelock root directory.
61
+ return None
62
+
63
+ @property
64
+ def split(self) -> str:
65
+ # Dataset: Register/lookup our shared memory prefix and filelock root directory.
66
+ return "default"
67
+
68
+ @classmethod
69
+ def validate_weights(cls, streams) -> tuple[bool, bool]:
70
+ from streaming.base.stream import Stream
71
+
72
+ return Stream.validate_weights(streams)
73
+
74
+ @classmethod
75
+ def apply_weights(cls, streams, samples_per_stream, choose_per_epoch, seed) -> int:
76
+ from streaming.base.stream import Stream
77
+
78
+ return Stream.apply_weights(streams, samples_per_stream, choose_per_epoch, seed)
79
+
80
+ def apply_default(self, default: dict):
81
+ # Applies defaults from the StreamingDataset.
82
+ # 'remote', 'local', 'split', 'download_retry', 'download_timeout', 'validate_hash', 'keep_zip'
83
+ if default["split"] is not None:
84
+ raise ValueError("SpiralStream does not support split, as the split is defined in the Scan.")
85
+
86
+ def prepare_shard(self, shard: "Reader") -> int:
87
+ """Ensure (download, validate, extract, etc.) that we have the given shard.
88
+
89
+ Args:
90
+ shard (Reader): Which shard.
91
+
92
+ Returns:
93
+ int: Change in cache usage.
94
+ """
95
+ if not isinstance(shard, SpiralReader):
96
+ raise ValueError("Only SpiralReader is supported in SpiralStream")
97
+
98
+ shard_path = os.path.join(self._cache_dir, self.split, shard.filename)
99
+ if os.path.exists(shard_path):
100
+ # Already exists.
101
+ return 0
102
+
103
+ # Prepare the shard, writing it to disk.
104
+ self._sp._ops().prepare_shard(
105
+ shard_path, self._scan.core, shard.shard, row_block_size=self._shard_row_block_size
106
+ )
107
+
108
+ # Get the size of the file on disk.
109
+ stat = os.stat(shard_path)
110
+ return stat.st_size
111
+
112
+ def get_shards(self, world: "World", allow_unsafe_types: bool) -> list["Reader"]:
113
+ """Load this Stream's index, retrieving its shard readers.
114
+
115
+ Args:
116
+ world (World): Distributed context.
117
+ allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
118
+ execution during deserialization, whether to keep going if ``True`` or raise an error.
119
+ This argument is ignored as SpiralStream does not support Pickle.
120
+
121
+ Returns:
122
+ `List[Reader]: Shard readers.
123
+ """
124
+ basepath = os.path.join(self._cache_dir, self.split)
125
+ return [SpiralReader(shard, basepath) for shard in self._shards] # type: ignore[return-value]
126
+
127
+ def set_up_local(self, shards: list["Reader"], cache_usage_per_shard: "NDArray[np.int64]") -> None:
128
+ """Bring a local directory into a consistent state, getting which shards are present.
129
+
130
+ Args:
131
+ shards (List[Reader]): List of this stream's shards.
132
+ cache_usage_per_shard (NDArray[np.int64]): Cache usage per shard of this stream.
133
+ """
134
+ listing = set()
135
+ for file in os.listdir(os.path.join(self._cache_dir, self.split)):
136
+ if os.path.isfile(os.path.join(self._cache_dir, self.split, file)) and file.endswith(".vortex"):
137
+ listing.add(file)
138
+
139
+ # Determine which shards are present, making local dir consistent.
140
+ for i, shard in enumerate(shards):
141
+ if not isinstance(shard, SpiralReader):
142
+ raise ValueError("Only SpiralReader is supported in SpiralStream")
143
+ if shard.filename in listing:
144
+ # Get the size of the file on disk.
145
+ stat = os.stat(os.path.join(self._cache_dir, self.split, shard.filename))
146
+ cache_usage_per_shard[i] = stat.st_size
147
+ else:
148
+ cache_usage_per_shard[i] = 0
149
+
150
+ def get_index_size(self) -> int:
151
+ """Get the size of the index file in bytes.
152
+
153
+ Returns:
154
+ int: Size in bytes.
155
+ """
156
+ # There is no index file stored on disk.
157
+ return 0
spiral/substrait_.py ADDED
@@ -0,0 +1,274 @@
1
+ import betterproto2
2
+ import pyarrow as pa
3
+
4
+ import spiral.expressions as se
5
+ from spiral.expressions.base import Expr
6
+ from spiral.protogen._.substrait import (
7
+ Expression,
8
+ ExpressionFieldReference,
9
+ ExpressionLiteral,
10
+ ExpressionLiteralList,
11
+ ExpressionLiteralStruct,
12
+ ExpressionLiteralUserDefined,
13
+ ExpressionMaskExpression,
14
+ ExpressionReferenceSegment,
15
+ ExpressionReferenceSegmentListElement,
16
+ ExpressionReferenceSegmentStructField,
17
+ ExpressionScalarFunction,
18
+ ExtendedExpression,
19
+ )
20
+ from spiral.protogen._.substrait.extensions import (
21
+ SimpleExtensionDeclaration,
22
+ SimpleExtensionDeclarationExtensionFunction,
23
+ SimpleExtensionDeclarationExtensionType,
24
+ SimpleExtensionDeclarationExtensionTypeVariation,
25
+ )
26
+
27
+
28
+ class SubstraitConverter:
29
+ def __init__(self, scope: Expr, schema: pa.Schema, key_schema: pa.Schema):
30
+ self.scope = scope
31
+ self.schema = schema
32
+ self.key_names = set(key_schema.names)
33
+
34
+ # Extension URIs, keyed by extension URI anchor
35
+ self.extension_uris = {}
36
+
37
+ # Functions, keyed by function_anchor
38
+ self.functions = {}
39
+
40
+ # Types, keyed by type anchor
41
+ self.type_factories = {}
42
+
43
+ def convert(self, buffer: pa.Buffer) -> Expr:
44
+ """Convert a Substrait Extended Expression into a Spiral expression."""
45
+
46
+ expr: ExtendedExpression = ExtendedExpression().parse(buffer)
47
+ assert len(expr.referred_expr) == 1, "Only one expression is supported"
48
+
49
+ # Parse the extension URIs from the plan.
50
+ for ext_uri in expr.extension_uris:
51
+ self.extension_uris[ext_uri.extension_uri_anchor] = ext_uri.uri
52
+
53
+ # Parse the extensions from the plan.
54
+ for ext in expr.extensions:
55
+ self._extension_declaration(ext)
56
+
57
+ # Convert the expression
58
+ return self._expr(expr.referred_expr[0].expression)
59
+
60
+ def _extension_declaration(self, ext: SimpleExtensionDeclaration):
61
+ match betterproto2.which_one_of(ext, "mapping_type"):
62
+ case "extension_function", ext_func:
63
+ self._extension_function(ext_func)
64
+ case "extension_type", ext_type:
65
+ self._extension_type(ext_type)
66
+ case "extension_type_variation", ext_type_variation:
67
+ self._extension_type_variation(ext_type_variation)
68
+ case _:
69
+ raise AssertionError("Invalid substrait plan")
70
+
71
+ def _extension_function(self, ext: SimpleExtensionDeclarationExtensionFunction):
72
+ ext_uri: str = self.extension_uris[ext.extension_uri_reference]
73
+ match ext_uri:
74
+ case "https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml":
75
+ match ext.name:
76
+ case "or":
77
+ self.functions[ext.function_anchor] = se.or_
78
+ case "and":
79
+ self.functions[ext.function_anchor] = se.and_
80
+ case "xor":
81
+ self.functions[ext.function_anchor] = se.xor
82
+ case "not":
83
+ self.functions[ext.function_anchor] = se.not_
84
+ case _:
85
+ raise NotImplementedError(f"Function name {ext.name} not supported")
86
+ case "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml":
87
+ match ext.name:
88
+ case "equal":
89
+ self.functions[ext.function_anchor] = se.eq
90
+ case "not_equal":
91
+ self.functions[ext.function_anchor] = se.neq
92
+ case "lt":
93
+ self.functions[ext.function_anchor] = se.lt
94
+ case "lte":
95
+ self.functions[ext.function_anchor] = se.lte
96
+ case "gt":
97
+ self.functions[ext.function_anchor] = se.gt
98
+ case "gte":
99
+ self.functions[ext.function_anchor] = se.gte
100
+ case "is_null":
101
+ self.functions[ext.function_anchor] = se.is_null
102
+ case "is_not_null":
103
+ self.functions[ext.function_anchor] = se.is_not_null
104
+ case _:
105
+ raise NotImplementedError(f"Function name {ext.name} not supported")
106
+ case uri:
107
+ raise NotImplementedError(f"Function extension URI {uri} not supported")
108
+
109
+ def _extension_type(self, ext: SimpleExtensionDeclarationExtensionType):
110
+ ext_uri: str = self.extension_uris[ext.extension_uri_reference]
111
+ match ext_uri:
112
+ case "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml":
113
+ match ext.name:
114
+ case "null":
115
+ self.type_factories[ext.type_anchor] = pa.null
116
+ case "interval_month_day_nano":
117
+ self.type_factories[ext.type_anchor] = pa.month_day_nano_interval
118
+ case "u8":
119
+ self.type_factories[ext.type_anchor] = pa.uint8
120
+ case "u16":
121
+ self.type_factories[ext.type_anchor] = pa.uint16
122
+ case "u32":
123
+ self.type_factories[ext.type_anchor] = pa.uint32
124
+ case "u64":
125
+ self.type_factories[ext.type_anchor] = pa.uint64
126
+ case "fp16":
127
+ self.type_factories[ext.type_anchor] = pa.float16
128
+ case "date_millis":
129
+ self.type_factories[ext.type_anchor] = pa.date64
130
+ case "time_seconds":
131
+ self.type_factories[ext.type_anchor] = lambda: pa.time32("s")
132
+ case "time_millis":
133
+ self.type_factories[ext.type_anchor] = lambda: pa.time32("ms")
134
+ case "time_nanos":
135
+ self.type_factories[ext.type_anchor] = lambda: pa.time64("ns")
136
+ case "large_string":
137
+ self.type_factories[ext.type_anchor] = pa.large_string
138
+ case "large_binary":
139
+ self.type_factories[ext.type_anchor] = pa.large_binary
140
+ case "decimal256":
141
+ self.type_factories[ext.type_anchor] = pa.decimal256
142
+ case "large_list":
143
+ self.type_factories[ext.type_anchor] = pa.large_list
144
+ case "fixed_size_list":
145
+ self.type_factories[ext.type_anchor] = pa.list_
146
+ case "duration":
147
+ self.type_factories[ext.type_anchor] = pa.duration
148
+ case uri:
149
+ raise NotImplementedError(f"Type extension URI {uri} not support")
150
+
151
+ def _extension_type_variation(self, ext: SimpleExtensionDeclarationExtensionTypeVariation):
152
+ raise NotImplementedError()
153
+
154
+ def _expr(self, expr: Expression) -> Expr:
155
+ match betterproto2.which_one_of(expr, "rex_type"):
156
+ case "literal", e:
157
+ return self._expr_literal(e)
158
+ case "selection", e:
159
+ return self._expr_selection(e)
160
+ case "scalar_function", e:
161
+ return self._expr_scalar_function(e)
162
+ case "window_function", _:
163
+ raise ValueError("Window functions are not supported in Spiral push-down")
164
+ case "if_then", e:
165
+ return self._expr_if_then(e)
166
+ case "switch", e:
167
+ return self._expr_switch(e)
168
+ case "singular_or_list", _:
169
+ raise ValueError("singular_or_list is not supported in Spiral push-down")
170
+ case "multi_or_list", _:
171
+ raise ValueError("multi_or_list is not supported in Spiral push-down")
172
+ case "cast", e:
173
+ return self._expr_cast(e)
174
+ case "subquery", _:
175
+ raise ValueError("Subqueries are not supported in Spiral push-down")
176
+ case "nested", e:
177
+ return self._expr_nested(e)
178
+ case _:
179
+ raise NotImplementedError(f"Expression type {expr.rex_type} not implemented")
180
+
181
+ def _expr_literal(self, expr: ExpressionLiteral):
182
+ # TODO(ngates): the Spiral literal expression is quite weakly typed...
183
+ # Maybe we can switch to Vortex?
184
+ simple = {
185
+ "boolean",
186
+ "i8",
187
+ "i16",
188
+ "i32",
189
+ "i64",
190
+ "fp32",
191
+ "fp64",
192
+ "string",
193
+ "binary",
194
+ "fixed_char",
195
+ "var_char",
196
+ "fixed_binary",
197
+ }
198
+
199
+ match betterproto2.which_one_of(expr, "literal_type"):
200
+ case type_, v if type_ in simple:
201
+ return se.scalar(pa.scalar(v))
202
+ case "timestamp", v:
203
+ return se.scalar(pa.scalar(v, type=pa.timestamp("us")))
204
+ case "date", v:
205
+ return se.scalar(pa.scalar(v, type=pa.date32()))
206
+ case "time", v:
207
+ # Substrait time is us since midnight. PyArrow only supports ms.
208
+ v: int
209
+ v = int(v / 1000)
210
+ return se.scalar(pa.scalar(v, type=pa.time32("ms")))
211
+ case "null", _null_type:
212
+ # We need a typed null value
213
+ raise NotImplementedError()
214
+ case "struct", v:
215
+ v: ExpressionLiteralStruct
216
+ # Hmm, v has fields, but no field names. I guess we return a list and the type is applied later?
217
+ raise NotImplementedError()
218
+ case "list", v:
219
+ v: ExpressionLiteralList
220
+ return pa.scalar([self._expr_literal(e) for e in v.values])
221
+ case "user_defined", v:
222
+ v: ExpressionLiteralUserDefined
223
+ raise NotImplementedError()
224
+ case literal_type, _:
225
+ raise NotImplementedError(f"Literal type not supported: {literal_type}")
226
+
227
+ def _expr_selection(self, expr: ExpressionFieldReference):
228
+ match betterproto2.which_one_of(expr, "root_type"):
229
+ case "root_reference", _:
230
+ # The reference is relative to the root
231
+ base_expr = self.scope
232
+ base_type = pa.struct(self.schema)
233
+ case _:
234
+ raise NotImplementedError("Only root_reference expressions are supported")
235
+
236
+ match betterproto2.which_one_of(expr, "reference_type"):
237
+ case "direct_reference", direct_ref:
238
+ return self._expr_direct_reference(base_expr, base_type, direct_ref)
239
+ case "masked_reference", masked_ref:
240
+ return self._expr_masked_reference(base_expr, base_type, masked_ref)
241
+ case _:
242
+ raise NotImplementedError()
243
+
244
+ def _expr_direct_reference(self, scope: Expr, scope_type: pa.StructType, expr: ExpressionReferenceSegment):
245
+ match betterproto2.which_one_of(expr, "reference_type"):
246
+ case "map_key", ref:
247
+ raise NotImplementedError("Map types not yet supported in Spiral")
248
+ case "struct_field", ref:
249
+ ref: ExpressionReferenceSegmentStructField
250
+ field_name = scope_type.field(ref.field).name
251
+ scope = se.getitem(scope, field_name)
252
+ scope_type = scope_type.field(ref.field).type
253
+ if ref.is_set("child"):
254
+ return self._expr_direct_reference(scope, scope_type, ref.child)
255
+ return scope
256
+ case "list_element", ref:
257
+ ref: ExpressionReferenceSegmentListElement
258
+ scope = se.getitem(scope, ref.offset)
259
+ scope_type = scope_type.field(ref.field).type
260
+ if ref.is_set("child"):
261
+ return self._expr_direct_reference(scope, scope_type, ref.child)
262
+ return scope
263
+ case "", ref:
264
+ # Because Proto... we hit this case when we recurse into a child node and it's actually "None".
265
+ return scope
266
+ case _:
267
+ raise NotImplementedError()
268
+
269
+ def _expr_masked_reference(self, scope: Expr, scope_type: pa.StructType, expr: ExpressionMaskExpression):
270
+ raise NotImplementedError("Masked references are not yet supported in Spiral push-down")
271
+
272
+ def _expr_scalar_function(self, expr: ExpressionScalarFunction):
273
+ args = [self._expr(arg.value) for arg in expr.arguments]
274
+ return self.functions[expr.function_reference](*args)