pyspiral 0.1.0__cp310-abi3-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (81) hide show
  1. pyspiral-0.1.0.dist-info/METADATA +48 -0
  2. pyspiral-0.1.0.dist-info/RECORD +81 -0
  3. pyspiral-0.1.0.dist-info/WHEEL +4 -0
  4. pyspiral-0.1.0.dist-info/entry_points.txt +2 -0
  5. spiral/__init__.py +11 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +386 -0
  8. spiral/api/__init__.py +221 -0
  9. spiral/api/admin.py +29 -0
  10. spiral/api/filesystems.py +125 -0
  11. spiral/api/organizations.py +90 -0
  12. spiral/api/projects.py +160 -0
  13. spiral/api/tables.py +94 -0
  14. spiral/api/tokens.py +56 -0
  15. spiral/api/workloads.py +45 -0
  16. spiral/arrow.py +209 -0
  17. spiral/authn/__init__.py +0 -0
  18. spiral/authn/authn.py +89 -0
  19. spiral/authn/device.py +206 -0
  20. spiral/authn/github_.py +33 -0
  21. spiral/authn/modal_.py +18 -0
  22. spiral/catalog.py +78 -0
  23. spiral/cli/__init__.py +82 -0
  24. spiral/cli/__main__.py +4 -0
  25. spiral/cli/admin.py +21 -0
  26. spiral/cli/app.py +48 -0
  27. spiral/cli/console.py +95 -0
  28. spiral/cli/fs.py +47 -0
  29. spiral/cli/login.py +13 -0
  30. spiral/cli/org.py +90 -0
  31. spiral/cli/printer.py +45 -0
  32. spiral/cli/project.py +107 -0
  33. spiral/cli/state.py +3 -0
  34. spiral/cli/table.py +20 -0
  35. spiral/cli/token.py +27 -0
  36. spiral/cli/types.py +53 -0
  37. spiral/cli/workload.py +59 -0
  38. spiral/config.py +26 -0
  39. spiral/core/__init__.py +0 -0
  40. spiral/core/core/__init__.pyi +53 -0
  41. spiral/core/manifests/__init__.pyi +53 -0
  42. spiral/core/metastore/__init__.pyi +91 -0
  43. spiral/core/spec/__init__.pyi +257 -0
  44. spiral/dataset.py +239 -0
  45. spiral/debug.py +251 -0
  46. spiral/expressions/__init__.py +222 -0
  47. spiral/expressions/base.py +149 -0
  48. spiral/expressions/http.py +86 -0
  49. spiral/expressions/io.py +100 -0
  50. spiral/expressions/list_.py +68 -0
  51. spiral/expressions/refs.py +44 -0
  52. spiral/expressions/str_.py +39 -0
  53. spiral/expressions/struct.py +57 -0
  54. spiral/expressions/tiff.py +223 -0
  55. spiral/expressions/udf.py +46 -0
  56. spiral/grpc_.py +32 -0
  57. spiral/project.py +137 -0
  58. spiral/proto/_/__init__.py +0 -0
  59. spiral/proto/_/arrow/__init__.py +0 -0
  60. spiral/proto/_/arrow/flight/__init__.py +0 -0
  61. spiral/proto/_/arrow/flight/protocol/__init__.py +0 -0
  62. spiral/proto/_/arrow/flight/protocol/sql/__init__.py +1990 -0
  63. spiral/proto/_/scandal/__init__.py +223 -0
  64. spiral/proto/_/spfs/__init__.py +36 -0
  65. spiral/proto/_/spiral/__init__.py +0 -0
  66. spiral/proto/_/spiral/table/__init__.py +225 -0
  67. spiral/proto/_/spiraldb/__init__.py +0 -0
  68. spiral/proto/_/spiraldb/metastore/__init__.py +499 -0
  69. spiral/proto/__init__.py +0 -0
  70. spiral/proto/scandal/__init__.py +45 -0
  71. spiral/proto/spiral/__init__.py +0 -0
  72. spiral/proto/spiral/table/__init__.py +96 -0
  73. spiral/proto/substrait/__init__.py +3399 -0
  74. spiral/proto/substrait/extensions/__init__.py +115 -0
  75. spiral/proto/util.py +41 -0
  76. spiral/py.typed +0 -0
  77. spiral/scan_.py +168 -0
  78. spiral/settings.py +157 -0
  79. spiral/substrait_.py +275 -0
  80. spiral/table.py +157 -0
  81. spiral/types_.py +6 -0
spiral/dataset.py ADDED
@@ -0,0 +1,239 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ import pyarrow as pa
4
+ import pyarrow.compute as pc
5
+
6
+ if TYPE_CHECKING:
7
+ import pyarrow.dataset
8
+
9
+ from spiral import Scan, Table
10
+
11
+
12
class TableDataset(pa.dataset.Dataset):
    """A PyArrow ``Dataset`` facade over a Spiral :class:`Table`.

    Implements just enough of the ``pyarrow.dataset.Dataset`` API for
    consumers such as DuckDB and Polars to read a Spiral table as a
    dataset. Operations that have no Spiral equivalent raise
    ``NotImplementedError``.
    """

    def __init__(self, table: Table):
        self._table = table
        # Arrow schema of a full scan; also used when serializing filter
        # expressions to Substrait in `scanner`.
        self._schema: pa.Schema = table.scan().schema.to_arrow()

        # We don't actually initialize a Dataset, we just implement enough of the API
        # to fool both DuckDB and Polars.
        # super().__init__()

    @property
    def schema(self) -> pa.Schema:
        """Arrow schema of the underlying table scan."""
        return self._schema

    def count_rows(
        self,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Count the rows matching the (optional) filter."""
        return self.scanner(
            None,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).count_rows()

    def filter(self, expression: pc.Expression) -> "TableDataset":
        raise NotImplementedError("filter not implemented")

    def get_fragments(self, filter: pc.Expression | None = None):
        """TODO(ngates): perhaps we should return ranges as per our split API?"""
        raise NotImplementedError("get_fragments not implemented")

    def head(
        self,
        num_rows: int,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Return the first `num_rows` rows as a pyarrow Table."""
        # Fix: the result was previously computed but never returned,
        # so head() always yielded None.
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).head(num_rows)

    def join(
        self,
        right_dataset,
        keys,
        right_keys=None,
        join_type=None,
        left_suffix=None,
        right_suffix=None,
        coalesce_keys=True,
        use_threads=True,
    ):
        raise NotImplementedError("join not implemented")

    def join_asof(self, right_dataset, on, by, tolerance, right_on=None, right_by=None):
        raise NotImplementedError("join_asof not implemented")

    def replace_schema(self, schema: pa.Schema) -> "TableDataset":
        raise NotImplementedError("replace_schema not implemented")

    def scanner(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ) -> "TableScanner":
        """Build a :class:`TableScanner` over the requested columns/filter.

        Only `columns` and `filter` are honored; the batching/threading
        arguments exist to satisfy the pyarrow Scanner-factory signature.
        """
        from .substrait_ import SubstraitConverter

        # Extract the substrait expression so we can convert it to a Spiral expression
        if filter is not None:
            filter = SubstraitConverter(self._table, self._schema, self._table.key_schema).convert(
                filter.to_substrait(self._schema, allow_arrow_extensions=True),
            )

        scan = self._table.scan(
            {c: self._table[c] for c in columns} if columns else self._table,
            where=filter,
            exclude_keys=True,
        )
        return TableScanner(scan)

    def sort_by(self, sorting, **kwargs):
        raise NotImplementedError("sort_by not implemented")

    def take(
        self,
        indices: pa.Array | Any,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Take rows by index (delegates to the scanner)."""
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).take(indices)

    def to_batches(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Stream the scan as record batches."""
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).to_batches()

    def to_table(
        self,
        columns=None,
        filter: pc.Expression | None = None,
        batch_size: int | None = None,
        batch_readahead: int | None = None,
        fragment_readahead: int | None = None,
        fragment_scan_options: pa.dataset.FragmentScanOptions | None = None,
        use_threads: bool = True,
        memory_pool: pa.MemoryPool = None,
    ):
        """Materialize the scan as a single pyarrow Table."""
        return self.scanner(
            columns,
            filter,
            batch_size,
            batch_readahead,
            fragment_readahead,
            fragment_scan_options,
            use_threads,
            memory_pool,
        ).to_table()
190
+
191
+
192
class TableScanner(pa.dataset.Scanner):
    """A PyArrow Dataset Scanner that reads from a Spiral Table."""

    def __init__(self, scan: Scan):
        self._scan = scan
        self._schema = scan.schema

        # We don't actually initialize a Dataset, we just implement enough of the API
        # to fool both DuckDB and Polars.
        # super().__init__()

    @property
    def schema(self):
        return self._schema

    def count_rows(self):
        # TODO(ngates): is there a faster way to count rows?
        total = 0
        for batch in self.to_reader():
            total += len(batch)
        return total

    def head(self, num_rows: int):
        """Return the first `num_rows` rows of the dataset."""
        reader = self.to_reader()
        collected = []
        remaining = num_rows
        for batch in reader:
            if len(batch) > remaining:
                # Final batch: keep only the rows still needed.
                collected.append(batch.slice(0, remaining))
                break
            remaining -= len(batch)
            collected.append(batch)
        return pa.Table.from_batches(collected, schema=reader.schema)

    def scan_batches(self):
        raise NotImplementedError("scan_batches not implemented")

    def take(self, indices):
        # TODO(ngates): can we defer take until after we've constructed the scan?
        # Or should this we delay constructing the Spiral Table.scan?
        raise NotImplementedError("take not implemented")

    def to_batches(self):
        return self.to_reader()

    def to_reader(self):
        return self._scan.to_record_batches()

    def to_table(self):
        return self.to_reader().read_all()
spiral/debug.py ADDED
@@ -0,0 +1,251 @@
1
+ from datetime import datetime
2
+
3
+ from spiral.core.core import TableScan
4
+ from spiral.core.manifests import FragmentFile, FragmentManifest
5
+ from spiral.core.spec import Key, KeyRange
6
+ from spiral.types_ import Timestamp
7
+
8
+
9
def show_scan(scan: TableScan):
    """Displays a scan in a way that is useful for debugging."""
    table_ids = scan.table_ids()
    if len(table_ids) > 1:
        raise NotImplementedError("Multiple table scan is not supported.")
    column_groups = scan.column_groups()

    splits = scan.split()
    key_space_scan = scan.key_space_scan(table_ids[0])

    def _collect_extents(manifest, into):
        # Fold every fragment's key-extent endpoints into the accumulator.
        for idx in range(len(manifest)):
            fragment = manifest[idx]
            into.add(fragment.key_extent.min)
            into.add(fragment.key_extent.max)

    # Collect all key bounds from all manifests so every visualization is aligned.
    key_points = set()
    key_space_manifest = key_space_scan.manifest
    _collect_extents(key_space_manifest, key_points)
    for cg in column_groups:
        _collect_extents(scan.column_group_scan(cg).manifest, key_points)

    # Make sure split points exist in all key points (don't take the last end).
    for s in splits[:-1]:
        key_points.add(s.end)
    key_points = sorted(key_points)

    show_manifest(key_space_manifest, scope="Key space", key_points=key_points, splits=splits)
    for cg in scan.column_groups():
        # Skip table id from the start of the column group.
        scope_name = ".".join(cg.path[1:])
        show_manifest(scan.column_group_scan(cg).manifest, scope=scope_name, key_points=key_points, splits=splits)
45
+
46
+
47
def show_manifest(
    manifest: FragmentManifest, scope: str | None = None, key_points: list[Key] | None = None, splits: list[KeyRange] | None = None
):
    """Plot a fragment manifest with matplotlib for debugging.

    Each fragment is drawn as a rectangle: x-extent is the fragment's key
    range (as indices into the sorted `key_points`), height is the rank of
    its size among all fragment sizes. Hovering shows fragment details.

    Args:
        manifest: the fragment manifest to visualize.
        scope: optional label used in the plot title.
        key_points: pre-computed, sorted key bounds (shared across plots so
            multiple manifests align); computed from `manifest` when None.
        splits: optional split ranges; their ends are marked on the x-axis.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.patches as patches
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required for debug")

    total_fragments = len(manifest)

    # Rank all distinct fragment sizes; rectangle height is the size's rank.
    size_points = set()
    for i in range(total_fragments):
        manifest_file: FragmentFile = manifest[i]
        size_points.add(manifest_file.size_bytes)
    size_points = list(sorted(size_points))

    if key_points is None:
        # No shared key bounds given: derive them from this manifest alone.
        key_points = set()

        for i in range(total_fragments):
            manifest_file: FragmentFile = manifest[i]

            key_points.add(manifest_file.key_extent.min)
            key_points.add(manifest_file.key_extent.max)

        if splits is not None:
            for split in splits[:-1]:
                key_points.add(split.end)

        key_points = list(sorted(key_points))

    # Create figure and axis with specified size
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot each rectangle
    for i in range(total_fragments):
        manifest_file: FragmentFile = manifest[i]

        left = key_points.index(manifest_file.key_extent.min)
        right = key_points.index(manifest_file.key_extent.max)
        height = size_points.index(manifest_file.size_bytes) + 1

        color = _get_fragment_color(manifest_file, i, total_fragments)

        # Create rectangle patch
        rect = patches.Rectangle(
            (left, 0),  # (x, y)
            right - left,  # width
            height,  # height
            facecolor=color,  # fill color
            edgecolor="black",  # border color
            alpha=0.5,  # transparency
            linewidth=1,  # border width
            label=manifest_file.id,  # label for legend
        )

        ax.add_patch(rect)

    # Set axis limits with some padding
    ax.set_xlim(-0.5, len(key_points) - 1 + 0.5)
    ax.set_ylim(-0.5, len(size_points) + 0.5)

    # Create split markers on x-axis
    if splits is not None:
        split_positions = [key_points.index(split.end) for split in splits[:-1]]

        # Add split markers at the bottom
        for pos in split_positions:
            ax.annotate("▲", xy=(pos, 0), ha="center", va="top", color="red", annotation_clip=False)

    # Add grid
    ax.grid(True, linestyle="--", alpha=0.7, zorder=0)

    # Add labels and title
    ax.set_title("Fragment Distribution" if scope is None else f"{scope} Fragment Distribution")
    ax.set_xlabel("Key Index")
    ax.set_ylabel("Size Index")

    # Add legend
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left", fontsize="small")

    # Adjust layout to prevent label cutoff
    plt.tight_layout()

    # Wire up interactive hover tooltips; keep `plot` alive via the closure.
    plot = FragmentManifestPlot(fig, ax, manifest)
    fig.canvas.mpl_connect("motion_notify_event", plot.hover)

    plt.show()
136
+
137
+
138
def _get_fragment_color(manifest_file: FragmentFile, color_index, total_colors):
    """Pick a fill color for a fragment rectangle.

    Compacted fragments are drawn in gray (shade varies per index so they
    stay distinguishable); non-compacted fragments use the viridis colormap.
    """
    import matplotlib.cm as cm

    fraction = color_index / total_colors
    if manifest_file.compacted_at is None:
        return cm.viridis(fraction)
    gray = 0.3 + 0.5 * fraction
    return (gray, gray, gray)
149
+
150
+
151
def _get_fragment_legend(manifest_file: FragmentFile):
    """Build the multi-line tooltip text describing a fragment."""
    extent = manifest_file.key_extent
    lines = [
        f"id: {manifest_file.id}",
        f"size: {manifest_file.size_bytes:,} bytes",
        f"key_span: {manifest_file.key_span}",
        f"key_min: {extent.min}",
        f"key_max: {extent.max}",
        f"format: {manifest_file.format}",
        f"level: {manifest_file.fs_level}",
        f"committed_at: {_format_timestamp(manifest_file.committed_at)}",
        f"compacted_at: {_format_timestamp(manifest_file.compacted_at)}",
        f"fs_id: {manifest_file.fs_id}",
        f"ks_id: {manifest_file.ks_id}",
    ]
    return "\n".join(lines)
167
+
168
+
169
def _format_timestamp(ts: Timestamp | None) -> str:
    """Format a timestamp as 'YYYY-MM-DD HH:MM:SS', or 'None' when absent.

    Falls back to the raw value when it cannot be interpreted as an epoch
    timestamp. The `/ 1e6` assumes `ts` is microseconds since the epoch —
    NOTE(review): confirm against the Timestamp definition.
    """
    if ts is None:
        return "None"
    try:
        return datetime.fromtimestamp(ts / 1e6).strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OverflowError, OSError):
        # fromtimestamp raises OverflowError/OSError (platform-dependent) for
        # out-of-range values, not only ValueError — catch them all so the
        # fallback actually triggers in debug output.
        return str(ts)
177
+
178
+
179
class FragmentManifestPlot:
    """Interactive hover behaviour for a fragment-manifest plot.

    Connected to matplotlib's "motion_notify_event": hovering over a
    fragment rectangle (or its legend entry) shows a tooltip with the
    fragment's details and highlights the matching rectangle.
    """

    def __init__(self, fig, ax, manifest: FragmentManifest):
        self.fig = fig
        self.ax = ax
        self.manifest = manifest

        # Position the annotation in the bottom right corner
        self.annotation = ax.annotate(
            "",
            xy=(0.98, 0.02),  # Position in axes coordinates
            xycoords="axes fraction",
            bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="gray", alpha=0.8),
            ha="right",  # Right-align text
            va="bottom",  # Bottom-align text
            visible=False,
        )
        # Rectangle currently emphasized (alpha raised from 0.5 to 0.8), if any.
        self.highlighted_rect = None
        self.highlighted_legend = None

    def hover(self, event):
        """Mouse-motion callback: show/hide the tooltip and highlight."""
        if event.inaxes != self.ax:
            # Check if we're hovering over the legend
            legend = self.ax.get_legend()
            if legend and legend.contains(event)[0]:
                # Find which legend item we're hovering over
                for i, legend_text in enumerate(legend.get_texts()):
                    if legend_text.contains(event)[0]:
                        # Legend entries are in insertion order, so index i
                        # maps back to manifest[i].
                        manifest_file = self.manifest[i]
                        self._show_legend(manifest_file, i, legend_text)
                        return
            self._hide_legend()
            return

        # Check rectangles in the main plot
        for i, rect in enumerate(self.ax.patches):
            if rect.contains(event)[0]:
                manifest_file = self.manifest[i]
                self._show_legend(manifest_file, i, rect)
                return

        self._hide_legend()

    def _show_legend(self, manifest_file, index, highlight_obj):
        """Show the tooltip for `manifest_file` and highlight its rectangle."""
        import matplotlib.patches as patches

        # Update tooltip text
        self.annotation.set_text(_get_fragment_legend(manifest_file))
        self.annotation.set_visible(True)

        # Handle highlighting
        if isinstance(highlight_obj, patches.Rectangle):
            # Highlighting rectangle in main plot
            if self.highlighted_rect and self.highlighted_rect != highlight_obj:
                # Restore the previously highlighted rectangle's alpha.
                self.highlighted_rect.set_alpha(0.5)
            highlight_obj.set_alpha(0.8)
            self.highlighted_rect = highlight_obj
        else:
            # Highlighting legend text
            if self.highlighted_rect:
                self.highlighted_rect.set_alpha(0.5)
            # Find and highlight corresponding rectangle
            rect = self.ax.patches[index]
            rect.set_alpha(0.8)
            self.highlighted_rect = rect

        self.fig.canvas.draw_idle()

    def _hide_legend(self):
        """Hide the tooltip and restore any highlighted rectangle."""
        if self.annotation.get_visible():
            self.annotation.set_visible(False)
            if self.highlighted_rect:
                self.highlighted_rect.set_alpha(0.5)
            self.fig.canvas.draw_idle()
@@ -0,0 +1,222 @@
1
+ import builtins
2
+ import functools
3
+ import operator
4
+ from typing import Any
5
+
6
+ import pyarrow as pa
7
+
8
+ from spiral import _lib, arrow
9
+
10
+ from . import http as http
11
+ from . import io as io
12
+ from . import list_ as list
13
+ from . import refs as refs
14
+ from . import str_ as str
15
+ from . import struct as struct
16
+ from . import tiff as tiff
17
+ from .base import Expr, ExprLike
18
+
19
# Public, star-importable API of the expressions package (kept sorted).
# NOTE(review): `cast` and `keyed` are defined in this module but not listed
# here — confirm whether that omission is intentional.
__all__ = [
    "Expr",
    "add",
    "and_",
    "deref",
    "divide",
    "eq",
    "getitem",
    "gt",
    "gte",
    "http",
    "io",
    "is_not_null",
    "is_null",
    "lift",
    "list",
    "lt",
    "lte",
    "merge",
    "modulo",
    "multiply",
    "negate",
    "neq",
    "not_",
    "or_",
    "pack",
    "ref",
    "refs",
    "scalar",
    "select",
    "str",
    "struct",
    "subtract",
    "tiff",
    "var",
    "xor",
]

# Inline some of the struct expressions since they're so common
getitem = struct.getitem
merge = struct.merge
pack = struct.pack
select = struct.select
ref = refs.ref
deref = refs.deref
64
+
65
+
66
def lift(expr: ExprLike) -> Expr:
    """Convert an ExprLike into an Expr.

    Dispatches on the runtime type of `expr`; the order of checks matters
    (dicts become struct expressions, lists become Arrow arrays, tables and
    chunked arrays are normalized to arrays first, and anything unmatched
    falls through to a scalar literal).
    """
    if isinstance(expr, Expr):
        return expr

    if isinstance(expr, dict):
        # NOTE: we assume this is a struct expression. We could be smarter and be context aware to determine if
        # this is in fact a struct scalar, but the user can always create one of those manually.

        # First we un-nest any dot-separated field names
        expr: dict = arrow.nest_structs(expr)

        return pack({k: lift(v) for k, v in expr.items()})

    # `builtins.list` because `list` is shadowed by the submodule import above.
    if isinstance(expr, builtins.list):
        return lift(pa.array(expr))

    # Unpack tables and chunked arrays
    if isinstance(expr, pa.Table):
        expr = expr.to_struct_array()
    if isinstance(expr, pa.ChunkedArray):
        expr = expr.combine_chunks()

    # If the value is struct-like, we un-nest any dot-separated field names
    if isinstance(expr, pa.StructArray | pa.StructScalar):
        return lift(arrow.nest_structs(expr))

    if isinstance(expr, pa.Array):
        return Expr(_lib.spql.expr.array_lit(expr))

    # Otherwise, assume it's a scalar.
    return scalar(expr)
98
+
99
+
100
def var(name: builtins.str) -> Expr:
    """Create a variable expression."""
    raw = _lib.spql.expr.var(name)
    return Expr(raw)
103
+
104
+
105
def keyed(name: builtins.str, dtype: pa.DataType) -> Expr:
    """Create a variable expression referencing a column in the key table.

    Key table is optionally given to `Scan#to_record_batches` function when reading only specific keys
    or doing cell pushdown.

    Args:
        name: variable name
        dtype: must match dtype of the column in the key table.
    """
    # Key-table variables are distinguished by a "#" prefix.
    return Expr(_lib.spql.expr.keyed("#" + name, dtype))
116
+
117
+
118
def scalar(value: Any) -> Expr:
    """Create a scalar expression."""
    # Coerce plain Python values into a pyarrow Scalar first.
    pa_value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
    return Expr(_lib.spql.expr.scalar(pa_value))
123
+
124
+
125
def cast(expr: ExprLike, dtype: pa.DataType) -> Expr:
    """Cast an expression into another PyArrow DataType."""
    lifted = lift(expr)
    return Expr(_lib.spql.expr.cast(lifted.__expr__, dtype))
129
+
130
+
131
def and_(expr: ExprLike, *exprs: ExprLike) -> Expr:
    """Create a conjunction of one or more expressions."""
    result = lift(expr)
    for other in exprs:
        result = result & lift(other)
    return result
135
+
136
+
137
def or_(expr: ExprLike, *exprs: ExprLike) -> Expr:
    """Create a disjunction of one or more expressions."""
    result = lift(expr)
    for other in exprs:
        result = result | lift(other)
    return result
140
+
141
+
142
def eq(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create an equality comparison."""
    return lift(lhs) == rhs
145
+
146
+
147
def neq(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create a not-equal comparison."""
    return lift(lhs) != rhs
150
+
151
+
152
def xor(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create a XOR comparison."""
    return lift(lhs) ^ rhs
155
+
156
+
157
def lt(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create a less-than comparison."""
    return lift(lhs) < rhs
160
+
161
+
162
def lte(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create a less-than-or-equal comparison."""
    return lift(lhs) <= rhs
165
+
166
+
167
def gt(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create a greater-than comparison."""
    return lift(lhs) > rhs
170
+
171
+
172
def gte(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Create a greater-than-or-equal comparison."""
    return lift(lhs) >= rhs
175
+
176
+
177
def negate(expr: ExprLike) -> Expr:
    """Negate the given expression."""
    return -lift(expr)
180
+
181
+
182
def not_(expr: ExprLike) -> Expr:
    """Logical NOT of the given expression.

    Distinct from `negate`, which is arithmetic negation; the previous
    docstring was copied from `negate` and was misleading.
    """
    expr = lift(expr)
    return Expr(_lib.spql.expr.unary("not", expr.__expr__))
186
+
187
+
188
def is_null(expr: ExprLike) -> Expr:
    """Check if the given expression is null."""
    return Expr(_lib.spql.expr.unary("is_null", lift(expr).__expr__))
192
+
193
+
194
def is_not_null(expr: ExprLike) -> Expr:
    """Check if the given expression is not null.

    The previous docstring incorrectly said "is null" (copy-paste from
    `is_null`); the operation applied is "is_not_null".
    """
    expr = lift(expr)
    return Expr(_lib.spql.expr.unary("is_not_null", expr.__expr__))
198
+
199
+
200
def add(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Add two expressions."""
    return lift(lhs) + rhs
203
+
204
+
205
def subtract(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Subtract two expressions."""
    return lift(lhs) - rhs
208
+
209
+
210
def multiply(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Multiply two expressions."""
    return lift(lhs) * rhs
213
+
214
+
215
def divide(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Divide two expressions."""
    return lift(lhs) / rhs
218
+
219
+
220
def modulo(lhs: ExprLike, rhs: ExprLike) -> Expr:
    """Modulo two expressions."""
    return lift(lhs) % rhs