pyspiral 0.1.0__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. pyspiral-0.1.0.dist-info/METADATA +48 -0
  2. pyspiral-0.1.0.dist-info/RECORD +81 -0
  3. pyspiral-0.1.0.dist-info/WHEEL +4 -0
  4. pyspiral-0.1.0.dist-info/entry_points.txt +2 -0
  5. spiral/__init__.py +11 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +386 -0
  8. spiral/api/__init__.py +221 -0
  9. spiral/api/admin.py +29 -0
  10. spiral/api/filesystems.py +125 -0
  11. spiral/api/organizations.py +90 -0
  12. spiral/api/projects.py +160 -0
  13. spiral/api/tables.py +94 -0
  14. spiral/api/tokens.py +56 -0
  15. spiral/api/workloads.py +45 -0
  16. spiral/arrow.py +209 -0
  17. spiral/authn/__init__.py +0 -0
  18. spiral/authn/authn.py +89 -0
  19. spiral/authn/device.py +206 -0
  20. spiral/authn/github_.py +33 -0
  21. spiral/authn/modal_.py +18 -0
  22. spiral/catalog.py +78 -0
  23. spiral/cli/__init__.py +82 -0
  24. spiral/cli/__main__.py +4 -0
  25. spiral/cli/admin.py +21 -0
  26. spiral/cli/app.py +48 -0
  27. spiral/cli/console.py +95 -0
  28. spiral/cli/fs.py +47 -0
  29. spiral/cli/login.py +13 -0
  30. spiral/cli/org.py +90 -0
  31. spiral/cli/printer.py +45 -0
  32. spiral/cli/project.py +107 -0
  33. spiral/cli/state.py +3 -0
  34. spiral/cli/table.py +20 -0
  35. spiral/cli/token.py +27 -0
  36. spiral/cli/types.py +53 -0
  37. spiral/cli/workload.py +59 -0
  38. spiral/config.py +26 -0
  39. spiral/core/__init__.py +0 -0
  40. spiral/core/core/__init__.pyi +53 -0
  41. spiral/core/manifests/__init__.pyi +53 -0
  42. spiral/core/metastore/__init__.pyi +91 -0
  43. spiral/core/spec/__init__.pyi +257 -0
  44. spiral/dataset.py +239 -0
  45. spiral/debug.py +251 -0
  46. spiral/expressions/__init__.py +222 -0
  47. spiral/expressions/base.py +149 -0
  48. spiral/expressions/http.py +86 -0
  49. spiral/expressions/io.py +100 -0
  50. spiral/expressions/list_.py +68 -0
  51. spiral/expressions/refs.py +44 -0
  52. spiral/expressions/str_.py +39 -0
  53. spiral/expressions/struct.py +57 -0
  54. spiral/expressions/tiff.py +223 -0
  55. spiral/expressions/udf.py +46 -0
  56. spiral/grpc_.py +32 -0
  57. spiral/project.py +137 -0
  58. spiral/proto/_/__init__.py +0 -0
  59. spiral/proto/_/arrow/__init__.py +0 -0
  60. spiral/proto/_/arrow/flight/__init__.py +0 -0
  61. spiral/proto/_/arrow/flight/protocol/__init__.py +0 -0
  62. spiral/proto/_/arrow/flight/protocol/sql/__init__.py +1990 -0
  63. spiral/proto/_/scandal/__init__.py +223 -0
  64. spiral/proto/_/spfs/__init__.py +36 -0
  65. spiral/proto/_/spiral/__init__.py +0 -0
  66. spiral/proto/_/spiral/table/__init__.py +225 -0
  67. spiral/proto/_/spiraldb/__init__.py +0 -0
  68. spiral/proto/_/spiraldb/metastore/__init__.py +499 -0
  69. spiral/proto/__init__.py +0 -0
  70. spiral/proto/scandal/__init__.py +45 -0
  71. spiral/proto/spiral/__init__.py +0 -0
  72. spiral/proto/spiral/table/__init__.py +96 -0
  73. spiral/proto/substrait/__init__.py +3399 -0
  74. spiral/proto/substrait/extensions/__init__.py +115 -0
  75. spiral/proto/util.py +41 -0
  76. spiral/py.typed +0 -0
  77. spiral/scan_.py +168 -0
  78. spiral/settings.py +157 -0
  79. spiral/substrait_.py +275 -0
  80. spiral/table.py +157 -0
  81. spiral/types_.py +6 -0
spiral/dataset.py ADDED
@@ -0,0 +1,239 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ import pyarrow as pa
4
+ import pyarrow.compute as pc
5
+
6
+ if TYPE_CHECKING:
7
+ import pyarrow.dataset
8
+
9
+ from spiral import Scan, Table
10
+
11
+
12
class TableDataset(pa.dataset.Dataset):
    """A pyarrow.dataset.Dataset facade over a Spiral Table.

    Implements just enough of the Dataset API to be consumed by DuckDB and
    Polars; unsupported operations raise NotImplementedError. The batching,
    threading and memory-pool arguments accepted by several methods exist
    only for signature compatibility and are currently ignored.
    """

    def __init__(self, table: Table):
        self._table = table
        self._schema: pa.Schema = table.scan().schema.to_arrow()

        # We don't actually initialize a Dataset, we just implement enough of the API
        # to fool both DuckDB and Polars.
        # super().__init__()

    @property
    def schema(self) -> pa.Schema:
        """Arrow schema of a full scan of the underlying table."""
        return self._schema

    def count_rows(
        self,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Count the rows matching the optional filter."""
        return self.scanner(
            None,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).count_rows()

    def filter(self, expression: pc.Expression) -> "TableDataset":
        raise NotImplementedError("filter not implemented")

    def get_fragments(self, filter: pc.Expression | None = None):
        """TODO(ngates): perhaps we should return ranges as per our split API?"""
        raise NotImplementedError("get_fragments not implemented")

    def head(
        self,
        num_rows: int,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Return the first `num_rows` rows as a pyarrow Table."""
        # BUG FIX: the scanned result was previously computed and discarded
        # (missing `return`), so head() always returned None.
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).head(num_rows)

    def join(
        self,
        right_dataset,
        keys,
        right_keys=None,
        join_type=None,
        left_suffix=None,
        right_suffix=None,
        coalesce_keys=True,
        use_threads=True,
    ):
        raise NotImplementedError("join not implemented")

    def join_asof(self, right_dataset, on, by, tolerance, right_on=None, right_by=None):
        raise NotImplementedError("join_asof not implemented")

    def replace_schema(self, schema: pa.Schema) -> "TableDataset":
        raise NotImplementedError("replace_schema not implemented")

    def scanner(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ) -> "TableScanner":
        """Build a TableScanner; only `columns` and `filter` are honored."""
        from .substrait_ import SubstraitConverter

        # Extract the substrait expression so we can convert it to a Spiral expression
        if filter is not None:
            filter = SubstraitConverter(self._table, self._schema, self._table.key_schema).convert(
                filter.to_substrait(self._schema, allow_arrow_extensions=True),
            )

        scan = self._table.scan(
            {c: self._table[c] for c in columns} if columns else self._table,
            where=filter,
            exclude_keys=True,
        )
        return TableScanner(scan)

    def sort_by(self, sorting, **kwargs):
        raise NotImplementedError("sort_by not implemented")

    def take(
        self,
        indices: pa.Array | Any,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Take rows at the given indices."""
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).take(indices)

    def to_batches(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Stream the (optionally projected/filtered) scan as record batches."""
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).to_batches()

    def to_table(
        self,
        columns=None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Materialize the (optionally projected/filtered) scan as a Table."""
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).to_table()
190
+
191
+
192
class TableScanner(pa.dataset.Scanner):
    """A PyArrow Dataset Scanner that reads from a Spiral Table."""

    def __init__(self, scan: Scan):
        self._scan = scan
        self._schema = scan.schema

        # We don't actually initialize a Dataset, we just implement enough of the API
        # to fool both DuckDB and Polars.
        # super().__init__()

    @property
    def schema(self):
        """Schema of the scan results."""
        return self._schema

    def count_rows(self):
        """Count rows by streaming every batch and summing lengths."""
        # TODO(ngates): is there a faster way to count rows?
        return sum(len(batch) for batch in self.to_reader())

    def head(self, num_rows: int):
        """Return the first `num_rows` rows of the dataset."""
        reader = self.to_reader()
        batches = []
        row_count = 0
        for batch in reader:
            if row_count + len(batch) >= num_rows:
                # Take only what is still needed, then stop. Using >= (was >)
                # avoids reading one extra batch and appending an empty slice
                # when a batch boundary lands exactly on num_rows.
                batches.append(batch.slice(0, num_rows - row_count))
                break
            row_count += len(batch)
            batches.append(batch)
        return pa.Table.from_batches(batches, schema=reader.schema)

    def scan_batches(self):
        raise NotImplementedError("scan_batches not implemented")

    def take(self, indices):
        # TODO(ngates): can we defer take until after we've constructed the scan?
        # Or should we delay constructing the Spiral Table.scan?
        raise NotImplementedError("take not implemented")

    def to_batches(self):
        return self.to_reader()

    def to_reader(self):
        return self._scan.to_record_batches()

    def to_table(self):
        return self.to_reader().read_all()
spiral/debug.py ADDED
@@ -0,0 +1,251 @@
1
+ from datetime import datetime
2
+
3
+ from spiral.core.core import TableScan
4
+ from spiral.core.manifests import FragmentFile, FragmentManifest
5
+ from spiral.core.spec import Key, KeyRange
6
+ from spiral.types_ import Timestamp
7
+
8
+
9
def show_scan(scan: TableScan):
    """Displays a scan in a way that is useful for debugging.

    Renders one fragment-distribution plot for the key space and one per
    column group. All plots share the same set of x-axis key points so the
    visualizations line up vertically. Only single-table scans are supported.
    """
    table_ids = scan.table_ids()
    if len(table_ids) > 1:
        raise NotImplementedError("Multiple table scan is not supported.")
    table_id = table_ids[0]
    column_groups = scan.column_groups()

    splits = scan.split()
    key_space_scan = scan.key_space_scan(table_id)

    # Collect all key bounds from all manifests. This makes sure all visualizations are aligned.
    key_points = set()
    key_space_manifest = key_space_scan.manifest
    for i in range(len(key_space_manifest)):
        fragment_file = key_space_manifest[i]
        key_points.add(fragment_file.key_extent.min)
        key_points.add(fragment_file.key_extent.max)
    for cg in column_groups:
        cg_scan = scan.column_group_scan(cg)
        cg_manifest = cg_scan.manifest
        for i in range(len(cg_manifest)):
            fragment_file = cg_manifest[i]
            key_points.add(fragment_file.key_extent.min)
            key_points.add(fragment_file.key_extent.max)

    # Make sure split points exist in all key points.
    for s in splits[:-1]:  # Don't take the last end.
        key_points.add(s.end)
    # Sorted list so show_manifest can use index positions as x coordinates.
    key_points = list(sorted(key_points))

    show_manifest(key_space_manifest, scope="Key space", key_points=key_points, splits=splits)
    for cg in scan.column_groups():
        cg_scan = scan.column_group_scan(cg)
        # Skip table id from the start of the column group.
        show_manifest(cg_scan.manifest, scope=".".join(cg.path[1:]), key_points=key_points, splits=splits)
45
+
46
+
47
def show_manifest(
    manifest: FragmentManifest, scope: str = None, key_points: list[Key] = None, splits: list[KeyRange] = None
):
    """Render a manifest's fragment files as rectangles on a matplotlib plot.

    Each fragment is drawn as a rectangle spanning its key extent on the
    x-axis, with height given by the rank of its size among all fragment
    sizes. Hovering a rectangle (or its legend entry) shows fragment details.

    Args:
        manifest: the fragment manifest to visualize.
        scope: optional label used in the plot title.
        key_points: pre-computed, sorted list of key bounds to use as x-axis
            positions; when None it is derived from this manifest alone.
            NOTE(review): when provided it must already contain every key
            extent bound of `manifest` (as show_scan guarantees), otherwise
            the .index() lookups below raise ValueError.
        splits: optional split ranges; split ends are marked on the x-axis.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.patches as patches
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required for debug")

    total_fragments = len(manifest)

    # Rank fragment sizes; a fragment's height is its size's rank + 1.
    size_points = set()
    for i in range(total_fragments):
        manifest_file: FragmentFile = manifest[i]
        size_points.add(manifest_file.size_bytes)
    size_points = list(sorted(size_points))

    if key_points is None:
        # No shared key points were supplied: derive them from this manifest
        # (and split ends) only.
        key_points = set()

        for i in range(total_fragments):
            manifest_file: FragmentFile = manifest[i]

            key_points.add(manifest_file.key_extent.min)
            key_points.add(manifest_file.key_extent.max)

        if splits is not None:
            for split in splits[:-1]:
                key_points.add(split.end)

        key_points = list(sorted(key_points))

    # Create figure and axis with specified size
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot each rectangle
    for i in range(total_fragments):
        manifest_file: FragmentFile = manifest[i]

        # X positions/height are index ranks, not raw values, so the plot is
        # readable regardless of the key/size distribution.
        left = key_points.index(manifest_file.key_extent.min)
        right = key_points.index(manifest_file.key_extent.max)
        height = size_points.index(manifest_file.size_bytes) + 1

        color = _get_fragment_color(manifest_file, i, total_fragments)

        # Create rectangle patch
        rect = patches.Rectangle(
            (left, 0),  # (x, y)
            right - left,  # width
            height,  # height
            facecolor=color,  # fill color
            edgecolor="black",  # border color
            alpha=0.5,  # transparency
            linewidth=1,  # border width
            label=manifest_file.id,  # label for legend
        )

        ax.add_patch(rect)

    # Set axis limits with some padding
    ax.set_xlim(-0.5, len(key_points) - 1 + 0.5)
    ax.set_ylim(-0.5, len(size_points) + 0.5)

    # Create split markers on x-axis
    if splits is not None:
        split_positions = [key_points.index(split.end) for split in splits[:-1]]

        # Add split markers at the bottom
        for pos in split_positions:
            ax.annotate("▲", xy=(pos, 0), ha="center", va="top", color="red", annotation_clip=False)

    # Add grid
    ax.grid(True, linestyle="--", alpha=0.7, zorder=0)

    # Add labels and title
    ax.set_title("Fragment Distribution" if scope is None else f"{scope} Fragment Distribution")
    ax.set_xlabel("Key Index")
    ax.set_ylabel("Size Index")

    # Add legend
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left", fontsize="small")

    # Adjust layout to prevent label cutoff
    plt.tight_layout()

    # Wire up hover tooltips; the plot object must stay referenced via the
    # closure registered on the canvas for the lifetime of the figure.
    plot = FragmentManifestPlot(fig, ax, manifest)
    fig.canvas.mpl_connect("motion_notify_event", plot.hover)

    plt.show()
136
+
137
+
138
def _get_fragment_color(manifest_file: FragmentFile, color_index, total_colors):
    """Choose a display color for one fragment rectangle.

    Compacted fragments are drawn in gray (shade varies with index so they
    stay distinguishable); everything else gets a viridis colormap entry.
    """
    import matplotlib.cm as cm

    fraction = color_index / total_colors
    if manifest_file.compacted_at is None:
        # Use viridis colormap for non-compacted fragments
        return cm.viridis(fraction)

    # Use a shade of gray for compacted fragments
    # Vary the shade based on the index to distinguish different compacted fragments
    shade = 0.3 + (0.5 * fraction)
    return (shade, shade, shade)
149
+
150
+
151
def _get_fragment_legend(manifest_file: FragmentFile):
    """Build the multi-line tooltip text describing one fragment file."""
    details = [
        f"id: {manifest_file.id}",
        f"size: {manifest_file.size_bytes:,} bytes",
        f"key_span: {manifest_file.key_span}",
        f"key_min: {manifest_file.key_extent.min}",
        f"key_max: {manifest_file.key_extent.max}",
        f"format: {manifest_file.format}",
        f"level: {manifest_file.fs_level}",
        f"committed_at: {_format_timestamp(manifest_file.committed_at)}",
        f"compacted_at: {_format_timestamp(manifest_file.compacted_at)}",
        f"fs_id: {manifest_file.fs_id}",
        f"ks_id: {manifest_file.ks_id}",
    ]
    return "\n".join(details)
167
+
168
+
169
+ def _format_timestamp(ts: Timestamp | None) -> str:
170
+ # Format timestamp or show None
171
+ if ts is None:
172
+ return "None"
173
+ try:
174
+ return datetime.fromtimestamp(ts / 1e6).strftime("%Y-%m-%d %H:%M:%S")
175
+ except ValueError:
176
+ return str(ts)
177
+
178
+
179
class FragmentManifestPlot:
    """Interactive hover behavior for a show_manifest() figure.

    Tracks the mouse via matplotlib's motion_notify_event: hovering either a
    fragment rectangle or its legend entry shows a detail tooltip in the
    bottom-right corner and highlights the corresponding rectangle. Relies on
    ax.patches[i] and the legend texts being in the same order as `manifest`.
    """

    def __init__(self, fig, ax, manifest: FragmentManifest):
        self.fig = fig
        self.ax = ax
        self.manifest = manifest

        # Position the annotation in the bottom right corner
        self.annotation = ax.annotate(
            "",
            xy=(0.98, 0.02),  # Position in axes coordinates
            xycoords="axes fraction",
            bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="gray", alpha=0.8),
            ha="right",  # Right-align text
            va="bottom",  # Bottom-align text
            visible=False,
        )
        # Currently highlighted rectangle (alpha raised to 0.8), if any.
        self.highlighted_rect = None
        self.highlighted_legend = None

    def hover(self, event):
        """Mouse-motion callback: show/hide the tooltip and highlight."""
        if event.inaxes != self.ax:
            # Check if we're hovering over the legend
            legend = self.ax.get_legend()
            if legend and legend.contains(event)[0]:
                # Find which legend item we're hovering over
                for i, legend_text in enumerate(legend.get_texts()):
                    if legend_text.contains(event)[0]:
                        manifest_file = self.manifest[i]
                        self._show_legend(manifest_file, i, legend_text)
                        return
                self._hide_legend()
            return

        # Check rectangles in the main plot
        for i, rect in enumerate(self.ax.patches):
            if rect.contains(event)[0]:
                manifest_file = self.manifest[i]
                self._show_legend(manifest_file, i, rect)
                return

        self._hide_legend()

    def _show_legend(self, manifest_file, index, highlight_obj):
        """Show the tooltip for `manifest_file` and highlight its rectangle.

        `highlight_obj` is either the hovered Rectangle itself or the hovered
        legend Text; in the latter case the rectangle is looked up by index.
        """
        import matplotlib.patches as patches

        # Update tooltip text
        self.annotation.set_text(_get_fragment_legend(manifest_file))
        self.annotation.set_visible(True)

        # Handle highlighting
        if isinstance(highlight_obj, patches.Rectangle):
            # Highlighting rectangle in main plot
            if self.highlighted_rect and self.highlighted_rect != highlight_obj:
                self.highlighted_rect.set_alpha(0.5)
            highlight_obj.set_alpha(0.8)
            self.highlighted_rect = highlight_obj
        else:
            # Highlighting legend text
            if self.highlighted_rect:
                self.highlighted_rect.set_alpha(0.5)
            # Find and highlight corresponding rectangle
            rect = self.ax.patches[index]
            rect.set_alpha(0.8)
            self.highlighted_rect = rect

        self.fig.canvas.draw_idle()

    def _hide_legend(self):
        """Hide the tooltip and restore the highlighted rectangle's alpha."""
        if self.annotation.get_visible():
            self.annotation.set_visible(False)
            if self.highlighted_rect:
                self.highlighted_rect.set_alpha(0.5)
            self.fig.canvas.draw_idle()
@@ -0,0 +1,222 @@
1
+ import builtins
2
+ import functools
3
+ import operator
4
+ from typing import Any
5
+
6
+ import pyarrow as pa
7
+
8
+ from spiral import _lib, arrow
9
+
10
+ from . import http as http
11
+ from . import io as io
12
+ from . import list_ as list
13
+ from . import refs as refs
14
+ from . import str_ as str
15
+ from . import struct as struct
16
+ from . import tiff as tiff
17
+ from .base import Expr, ExprLike
18
+
19
# Explicit public API, kept sorted. Note that "list" and "str" here are the
# re-exported submodules imported above as aliases (list_ as list, str_ as
# str); they intentionally shadow the builtins within this namespace.
__all__ = [
    "Expr",
    "add",
    "and_",
    "deref",
    "divide",
    "eq",
    "getitem",
    "gt",
    "gte",
    "http",
    "io",
    "is_not_null",
    "is_null",
    "lift",
    "list",
    "lt",
    "lte",
    "merge",
    "modulo",
    "multiply",
    "negate",
    "neq",
    "not_",
    "or_",
    "pack",
    "ref",
    "refs",
    "scalar",
    "select",
    "str",
    "struct",
    "subtract",
    "tiff",
    "var",
    "xor",
]

# Inline some of the struct expressions since they're so common
getitem = struct.getitem
merge = struct.merge
pack = struct.pack
select = struct.select
ref = refs.ref
deref = refs.deref
64
+
65
+
66
def lift(expr: ExprLike) -> Expr:
    """Convert an ExprLike into an Expr.

    Dispatches on the value's type, in order: Expr (returned as-is), dict
    (packed as a struct expression), list (via pa.array), pa.Table /
    pa.ChunkedArray (unwrapped then re-dispatched), struct array/scalar
    (un-nested then re-dispatched), pa.Array (array literal), and finally
    anything else as a scalar. The check order is significant: e.g. Table
    must be unwrapped before the Array check can apply.
    """
    if isinstance(expr, Expr):
        return expr

    if isinstance(expr, dict):
        # NOTE: we assume this is a struct expression. We could be smarter and be context aware to determine if
        # this is in fact a struct scalar, but the user can always create one of those manually.

        # First we un-nest any dot-separated field names
        expr: dict = arrow.nest_structs(expr)

        return pack({k: lift(v) for k, v in expr.items()})

    if isinstance(expr, builtins.list):
        return lift(pa.array(expr))

    # Unpack tables and chunked arrays
    if isinstance(expr, pa.Table):
        expr = expr.to_struct_array()
    if isinstance(expr, pa.ChunkedArray):
        expr = expr.combine_chunks()

    # If the value is struct-like, we un-nest any dot-separated field names
    if isinstance(expr, pa.StructArray | pa.StructScalar):
        return lift(arrow.nest_structs(expr))

    if isinstance(expr, pa.Array):
        return Expr(_lib.spql.expr.array_lit(expr))

    # Otherwise, assume it's a scalar.
    return scalar(expr)
98
+
99
+
100
def var(name: builtins.str) -> Expr:
    """Build an expression that references the variable `name`."""
    inner = _lib.spql.expr.var(name)
    return Expr(inner)
103
+
104
+
105
def keyed(name: builtins.str, dtype: pa.DataType) -> Expr:
    """Create a variable expression referencing a column in the key table.

    Key table is optionally given to `Scan#to_record_batches` function when reading only specific keys
    or doing cell pushdown.

    Args:
        name: variable name
        dtype: must match dtype of the column in the key table.
    """
    # Key-table variables are distinguished by a leading '#'.
    prefixed = f"#{name}"
    return Expr(_lib.spql.expr.keyed(prefixed, dtype))
116
+
117
+
118
def scalar(value: Any) -> Expr:
    """Wrap a Python or PyArrow scalar value as a scalar expression."""
    wrapped = value if isinstance(value, pa.Scalar) else pa.scalar(value)
    return Expr(_lib.spql.expr.scalar(wrapped))
123
+
124
+
125
def cast(expr: ExprLike, dtype: pa.DataType) -> Expr:
    """Cast an expression into another PyArrow DataType."""
    lifted = lift(expr)
    return Expr(_lib.spql.expr.cast(lifted.__expr__, dtype))
129
+
130
+
131
def and_(expr: ExprLike, *exprs: ExprLike) -> Expr:
    """Combine one or more expressions with logical AND, left to right."""
    result = lift(expr)
    for other in exprs:
        result = result & lift(other)
    return result
135
+
136
+
137
def or_(expr: ExprLike, *exprs: ExprLike) -> Expr:
    """Combine one or more expressions with logical OR, left to right."""
    result = lift(expr)
    for other in exprs:
        result = result | lift(other)
    return result
140
+
141
+
142
def eq(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build an equality (==) comparison expression."""
    return lift(lhs) == rhs
145
+
146
+
147
def neq(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a not-equal (!=) comparison expression."""
    return lift(lhs) != rhs
150
+
151
+
152
def xor(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build an exclusive-or (^) expression."""
    return lift(lhs) ^ rhs
155
+
156
+
157
def lt(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a less-than (<) comparison expression."""
    return lift(lhs) < rhs
160
+
161
+
162
def lte(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a less-than-or-equal (<=) comparison expression."""
    return lift(lhs) <= rhs
165
+
166
+
167
def gt(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a greater-than (>) comparison expression."""
    return lift(lhs) > rhs
170
+
171
+
172
def gte(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a greater-than-or-equal (>=) comparison expression."""
    return lift(lhs) >= rhs
175
+
176
+
177
def negate(expr: ExprLike) -> Expr:
    """Build the arithmetic negation (unary minus) of an expression."""
    return -lift(expr)
180
+
181
+
182
def not_(expr: ExprLike) -> Expr:
    """Build the logical NOT of an expression."""
    lifted = lift(expr)
    return Expr(_lib.spql.expr.unary("not", lifted.__expr__))
186
+
187
+
188
def is_null(expr: ExprLike) -> Expr:
    """Build an expression testing whether the value is null."""
    lifted = lift(expr)
    return Expr(_lib.spql.expr.unary("is_null", lifted.__expr__))
192
+
193
+
194
def is_not_null(expr: ExprLike) -> Expr:
    """Check if the given expression is not null."""
    expr = lift(expr)
    return Expr(_lib.spql.expr.unary("is_not_null", expr.__expr__))
198
+
199
+
200
def add(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build an addition (+) expression."""
    return lift(lhs) + rhs
203
+
204
+
205
def subtract(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a subtraction (-) expression."""
    return lift(lhs) - rhs
208
+
209
+
210
def multiply(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a multiplication (*) expression."""
    return lift(lhs) * rhs
213
+
214
+
215
def divide(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a division (/) expression."""
    return lift(lhs) / rhs
218
+
219
+
220
def modulo(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Build a modulo (%) expression."""
    return lift(lhs) % rhs