vgi-python 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vgi/__init__.py +152 -0
- vgi/_duckdb.py +62 -0
- vgi/_storage_profile.py +132 -0
- vgi/_test_fixtures/__init__.py +20 -0
- vgi/_test_fixtures/accumulate/__init__.py +19 -0
- vgi/_test_fixtures/accumulate/worker.py +762 -0
- vgi/_test_fixtures/aggregate/__init__.py +62 -0
- vgi/_test_fixtures/aggregate/_common.py +21 -0
- vgi/_test_fixtures/aggregate/basic.py +232 -0
- vgi/_test_fixtures/aggregate/dynamic.py +409 -0
- vgi/_test_fixtures/aggregate/generic.py +86 -0
- vgi/_test_fixtures/aggregate/listagg.py +71 -0
- vgi/_test_fixtures/aggregate/percentile.py +107 -0
- vgi/_test_fixtures/aggregate/streaming.py +192 -0
- vgi/_test_fixtures/aggregate/varargs.py +75 -0
- vgi/_test_fixtures/aggregate/window.py +380 -0
- vgi/_test_fixtures/attach_options.py +308 -0
- vgi/_test_fixtures/bad_protocol.py +62 -0
- vgi/_test_fixtures/cancellable.py +336 -0
- vgi/_test_fixtures/catalog.py +813 -0
- vgi/_test_fixtures/http_server.py +394 -0
- vgi/_test_fixtures/nest_tensor.py +614 -0
- vgi/_test_fixtures/orchard_catalog.py +47 -0
- vgi/_test_fixtures/projection_repro/__init__.py +6 -0
- vgi/_test_fixtures/projection_repro/worker.py +454 -0
- vgi/_test_fixtures/scalar/__init__.py +116 -0
- vgi/_test_fixtures/scalar/_common.py +69 -0
- vgi/_test_fixtures/scalar/arithmetic.py +321 -0
- vgi/_test_fixtures/scalar/binary.py +120 -0
- vgi/_test_fixtures/scalar/formatting.py +176 -0
- vgi/_test_fixtures/scalar/geo.py +300 -0
- vgi/_test_fixtures/scalar/null_handling.py +107 -0
- vgi/_test_fixtures/scalar/random_demo.py +171 -0
- vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
- vgi/_test_fixtures/scalar/type_info.py +219 -0
- vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
- vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
- vgi/_test_fixtures/simple_writable.py +793 -0
- vgi/_test_fixtures/table/__init__.py +221 -0
- vgi/_test_fixtures/table/_common.py +162 -0
- vgi/_test_fixtures/table/batch_index.py +283 -0
- vgi/_test_fixtures/table/batch_index_broken.py +200 -0
- vgi/_test_fixtures/table/catalog_scans.py +162 -0
- vgi/_test_fixtures/table/filters.py +1005 -0
- vgi/_test_fixtures/table/late_materialization.py +249 -0
- vgi/_test_fixtures/table/make_series.py +273 -0
- vgi/_test_fixtures/table/misc.py +499 -0
- vgi/_test_fixtures/table/order_modes.py +164 -0
- vgi/_test_fixtures/table/pairs.py +437 -0
- vgi/_test_fixtures/table/partition_columns.py +472 -0
- vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
- vgi/_test_fixtures/table/profiling_example.py +195 -0
- vgi/_test_fixtures/table/required_filters.py +234 -0
- vgi/_test_fixtures/table/sequence.py +710 -0
- vgi/_test_fixtures/table/settings.py +426 -0
- vgi/_test_fixtures/table/transaction_storage.py +162 -0
- vgi/_test_fixtures/table/tt_pushdown.py +191 -0
- vgi/_test_fixtures/table/versioned.py +230 -0
- vgi/_test_fixtures/table_in_out.py +1392 -0
- vgi/_test_fixtures/versioned.py +155 -0
- vgi/_test_fixtures/versioned_tables.py +595 -0
- vgi/_test_fixtures/worker.py +1631 -0
- vgi/_test_fixtures/writable/__init__.py +8 -0
- vgi/_test_fixtures/writable/generic.py +236 -0
- vgi/_test_fixtures/writable/table.py +149 -0
- vgi/_test_fixtures/writable/worker.py +1148 -0
- vgi/aggregate_function.py +607 -0
- vgi/argument_spec.py +472 -0
- vgi/arguments.py +1747 -0
- vgi/auth.py +55 -0
- vgi/catalog/__init__.py +88 -0
- vgi/catalog/attach_option.py +206 -0
- vgi/catalog/catalog_interface.py +2767 -0
- vgi/catalog/descriptors.py +870 -0
- vgi/catalog/duckdb_statistics.py +377 -0
- vgi/catalog/secret_type.py +96 -0
- vgi/catalog/setting.py +253 -0
- vgi/catalog/storage.py +372 -0
- vgi/client/__init__.py +67 -0
- vgi/client/catalog_mixin.py +1251 -0
- vgi/client/cli.py +582 -0
- vgi/client/cli_catalog.py +182 -0
- vgi/client/cli_schema.py +270 -0
- vgi/client/cli_table.py +907 -0
- vgi/client/cli_transaction.py +97 -0
- vgi/client/cli_utils.py +441 -0
- vgi/client/cli_view.py +303 -0
- vgi/client/client.py +2183 -0
- vgi/exceptions.py +205 -0
- vgi/function.py +245 -0
- vgi/function_storage.py +1636 -0
- vgi/function_storage_azure_sql.py +922 -0
- vgi/function_storage_cf_do.py +740 -0
- vgi/http/__init__.py +25 -0
- vgi/http/demo_storage.py +212 -0
- vgi/http/worker_page.py +1252 -0
- vgi/invocation.py +154 -0
- vgi/logging_config.py +93 -0
- vgi/meta_worker.py +661 -0
- vgi/metadata.py +1403 -0
- vgi/otel.py +406 -0
- vgi/protocol.py +2418 -0
- vgi/protocol_version.txt +1 -0
- vgi/py.typed +0 -0
- vgi/scalar_function.py +1211 -0
- vgi/schema_utils.py +234 -0
- vgi/secret_protocol.py +124 -0
- vgi/secret_service.py +238 -0
- vgi/serve.py +769 -0
- vgi/table_buffering_function.py +443 -0
- vgi/table_filter_pushdown.py +1528 -0
- vgi/table_function.py +1130 -0
- vgi/table_in_out_function.py +383 -0
- vgi/transactor/__init__.py +24 -0
- vgi/transactor/_duckdb_compat.py +27 -0
- vgi/transactor/client.py +137 -0
- vgi/transactor/protocol.py +149 -0
- vgi/transactor/server.py +740 -0
- vgi/worker.py +4761 -0
- vgi_python-0.8.0.dist-info/METADATA +735 -0
- vgi_python-0.8.0.dist-info/RECORD +124 -0
- vgi_python-0.8.0.dist-info/WHEEL +4 -0
- vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
- vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
|
@@ -0,0 +1,1528 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""Filter pushdown AST classes for table functions.
|
|
4
|
+
|
|
5
|
+
This module provides:
|
|
6
|
+
- Filter AST classes for representing pushdown filter predicates
|
|
7
|
+
- ColumnBounds for extracting numeric bounds from filters
|
|
8
|
+
- PushdownFilters container with evaluation and helper methods
|
|
9
|
+
- Deserialization from Arrow IPC format
|
|
10
|
+
|
|
11
|
+
Filter Types:
|
|
12
|
+
ConstantFilter: Comparison with a constant value (=, !=, >, >=, <, <=)
|
|
13
|
+
IsNullFilter: IS NULL check
|
|
14
|
+
IsNotNullFilter: IS NOT NULL check
|
|
15
|
+
InFilter: Set membership (IN clause)
|
|
16
|
+
AndFilter: Conjunction of child filters
|
|
17
|
+
OrFilter: Disjunction of child filters
|
|
18
|
+
StructFilter: Nested struct field filter
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import dataclasses
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
import os
|
|
27
|
+
import threading
|
|
28
|
+
from dataclasses import dataclass
|
|
29
|
+
from enum import Enum
|
|
30
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
31
|
+
|
|
32
|
+
import pyarrow as pa
|
|
33
|
+
import pyarrow.compute as pc
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from collections.abc import Callable, Iterator
|
|
37
|
+
|
|
38
|
+
# =============================================================================
|
|
39
|
+
# Debug Logging
|
|
40
|
+
# =============================================================================
|
|
41
|
+
|
|
42
|
+
# Enable with VGI_FILTER_DEBUG=1 for detailed filter pushdown diagnostics
|
|
43
|
+
_FILTER_DEBUG = os.environ.get("VGI_FILTER_DEBUG", "").lower() in ("1", "true", "yes")
|
|
44
|
+
_filter_logger = logging.getLogger("vgi.filter_pushdown")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _log_debug(event: str, **kwargs: Any) -> None:
    """Emit a filter-pushdown diagnostic line when VGI_FILTER_DEBUG is enabled.

    Args:
        event: Short event label placed first in the log line.
        **kwargs: Extra context, rendered as space-separated ``key=value`` pairs.
    """
    if not _FILTER_DEBUG:
        return
    details = " ".join(f"{key}={value}" for key, value in kwargs.items())
    if details:
        _filter_logger.debug("%s %s", event, details)
    else:
        _filter_logger.debug("%s", event)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Supported filter protocol version
|
|
55
|
+
_SUPPORTED_VERSION = "1"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _strip_extension(x: Any) -> Any:
    """Unwrap a canonical Arrow extension wrapper, returning the storage form.

    Filter literals serialised by the DuckDB VGI extension carry their
    canonical Arrow extension type (e.g. ``arrow.bool8`` for BOOLEAN, see
    ``vgi/duckdb/src/common/arrow/arrow_type_extension.cpp``). PyArrow's
    binary compute kernels are type-pair-keyed and have no entry for
    ``equal(bool, extension<arrow.bool8>)`` — so the wrapper must be
    stripped before comparing. Pass the result of this helper to
    ``pc.equal`` and friends.

    Args:
        x: An Arrow Array, ChunkedArray or Scalar (anything exposing ``.type``).

    Returns:
        * Array: its ``.storage`` array.
        * ChunkedArray: a new ChunkedArray built from each chunk's
          ``.storage`` (``.storage`` is read per chunk rather than on the
          ChunkedArray itself).
        * Scalar: its ``.value`` storage scalar.
        * Any plain (non-extension) input: returned unchanged.

    Note: the type-check uses ``pa.BaseExtensionType`` rather than
    ``pa.ExtensionType`` because canonical Arrow extension types like
    ``Bool8Type`` and ``UuidType`` inherit from ``BaseExtensionType``
    directly without going through the user-extension class.
    """
    if not isinstance(x.type, pa.BaseExtensionType):
        return x
    if isinstance(x, pa.ChunkedArray):
        # Unwrap chunk by chunk; passing the storage type explicitly keeps
        # a zero-chunk input well-typed.
        return pa.chunked_array(
            [cast(Any, chunk).storage for chunk in x.chunks],
            type=x.type.storage_type,
        )
    if isinstance(x, pa.Array):
        # Narrowed to Array, but an extension-typed one — ``.storage``
        # exists at runtime; re-widen so mypy allows it.
        return cast(Any, x).storage
    return x.value
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _normalize_for_compare(col: Any, val: Any) -> tuple[Any, Any]:
    """Coerce a (column, literal) pair into types pyarrow.compute can compare.

    The steps run in this order:

    1. Both sides have canonical Arrow extension wrappers removed.
    2. A dictionary-encoded column is decoded to its value type. The
       compute kernels (``is_in``, ``equal``, …) accept a dictionary
       column paired with a plain literal, but throw if the literal is
       *also* dictionary-encoded — so the column is decoded instead of
       casting the literal up to a dictionary. See ``ArrowTypeError:
       Array type doesn't match type of values set``.
    3. If the types still differ (e.g. a plain ``bool_`` column whose
       literal arrived as ``arrow.bool8`` and stripped to ``int8``), the
       literal is cast to the column's type.

    Args:
        col: Column data to compare.
        val: Literal scalar (or array of literals) to compare against.

    Returns:
        ``(column, literal)`` after normalisation.
    """
    column = _strip_extension(col)
    literal = _strip_extension(val)

    column_type = column.type
    if pa.types.is_dictionary(column_type):
        column = column.cast(column_type.value_type)

    target_type = column.type
    if literal.type != target_type:
        literal = literal.cast(target_type)
    return column, literal
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _make_bool_array(value: bool, length: int) -> pa.BooleanArray:
    """Build an array holding ``value`` repeated ``length`` times.

    Supplies the result for an empty AND (all True) and an empty OR
    (all False).
    """
    fill = pa.scalar(value)
    return pa.repeat(fill, length)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
__all__ = [
|
|
122
|
+
# Exceptions
|
|
123
|
+
"FilterError",
|
|
124
|
+
"FilterDeserializationError",
|
|
125
|
+
"FilterVersionError",
|
|
126
|
+
# Enums
|
|
127
|
+
"FilterType",
|
|
128
|
+
"ComparisonOp",
|
|
129
|
+
"ExpressionNodeType",
|
|
130
|
+
# Filter classes
|
|
131
|
+
"Filter",
|
|
132
|
+
"ConstantFilter",
|
|
133
|
+
"IsNullFilter",
|
|
134
|
+
"IsNotNullFilter",
|
|
135
|
+
"InFilter",
|
|
136
|
+
"AndFilter",
|
|
137
|
+
"OrFilter",
|
|
138
|
+
"StructFilter",
|
|
139
|
+
"ExpressionFilter",
|
|
140
|
+
# Expression node classes
|
|
141
|
+
"ExpressionNode",
|
|
142
|
+
"ColumnRefNode",
|
|
143
|
+
"ConstantNode",
|
|
144
|
+
"FunctionNode",
|
|
145
|
+
"ComparisonNode",
|
|
146
|
+
"ConjunctionNode",
|
|
147
|
+
# Helpers
|
|
148
|
+
"ColumnBounds",
|
|
149
|
+
"PushdownFilters",
|
|
150
|
+
# Functions
|
|
151
|
+
"deserialize_filters",
|
|
152
|
+
]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# =============================================================================
|
|
156
|
+
# Exceptions
|
|
157
|
+
# =============================================================================
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class FilterError(Exception):
    """Base exception for filter pushdown errors.

    ``FilterDeserializationError`` and ``FilterVersionError`` derive from
    this class, so catching ``FilterError`` covers both.
    """
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class FilterDeserializationError(FilterError):
    """Failed to parse filter IPC bytes.

    Subclass of ``FilterError``.
    """
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class FilterVersionError(FilterError):
    """Unsupported filter protocol version.

    Subclass of ``FilterError``.
    """
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# =============================================================================
|
|
173
|
+
# Enums
|
|
174
|
+
# =============================================================================
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class FilterType(Enum):
    """Filter type identifiers matching the JSON protocol.

    Each member's value is the string used for that filter kind in the
    JSON protocol.
    """

    CONSTANT = "constant"  # comparison with a constant (ConstantFilter)
    IS_NULL = "is_null"  # IS NULL check (IsNullFilter)
    IS_NOT_NULL = "is_not_null"  # IS NOT NULL check (IsNotNullFilter)
    IN = "in"  # set membership (InFilter)
    JOIN_KEYS = "join_keys"  # NOTE(review): no matching filter class is visible nearby — confirm handling
    AND = "and"  # conjunction of child filters (AndFilter)
    OR = "or"  # disjunction of child filters (OrFilter)
    STRUCT = "struct"  # nested struct field filter (StructFilter)
    EXPRESSION = "expression"  # recursive expression tree (ExpressionFilter)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class ExpressionNodeType(Enum):
    """Expression node type identifiers matching the JSON protocol.

    Each member's value is the string used for that node kind in the
    JSON protocol.
    """

    COLUMN_REF = "column_ref"  # column reference (ColumnRefNode)
    CONSTANT = "constant"  # literal value (ConstantNode)
    FUNCTION = "function"  # function or operator call (FunctionNode)
    COMPARISON = "comparison"  # binary comparison (ComparisonNode)
    CONJUNCTION = "conjunction"  # AND/OR over children (ConjunctionNode)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class ComparisonOp(Enum):
|
|
202
|
+
"""Comparison operators for constant filters."""
|
|
203
|
+
|
|
204
|
+
EQ = "eq" # =
|
|
205
|
+
NE = "ne" # !=
|
|
206
|
+
GT = "gt" # >
|
|
207
|
+
GE = "ge" # >=
|
|
208
|
+
LT = "lt" # <
|
|
209
|
+
LE = "le" # <=
|
|
210
|
+
|
|
211
|
+
@property
|
|
212
|
+
def symbol(self) -> str:
|
|
213
|
+
"""Return the SQL symbol for this operator."""
|
|
214
|
+
symbols = {"eq": "=", "ne": "!=", "gt": ">", "ge": ">=", "lt": "<", "le": "<="}
|
|
215
|
+
return symbols[self.value]
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# =============================================================================
|
|
219
|
+
# Filter Base Class
|
|
220
|
+
# =============================================================================
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@dataclass(frozen=True, slots=True)
class Filter:
    """Base class for all filter types.

    Immutable (frozen) dataclass. Subclasses add their own fields and
    override ``evaluate``.

    Attributes:
        column_name: Name of the column this filter applies to.
        column_index: Index of the column in the output schema.

    """

    column_name: str
    column_index: int

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate filter against batch using PyArrow compute.

        Args:
            batch: RecordBatch to evaluate filter against.

        Returns:
            Boolean array with True for rows that pass the filter.

        Raises:
            NotImplementedError: Base class does not implement evaluation.

        """
        # Abstract hook: the base class has no predicate of its own.
        raise NotImplementedError
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
# =============================================================================
|
|
253
|
+
# Filter Type Classes
|
|
254
|
+
# =============================================================================
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@dataclass(frozen=True, slots=True)
class ConstantFilter(Filter):
    """Filter of the form ``column <op> value``.

    Examples:
        age >= 18
        status = 'active'
        price < 100.0

    Attributes:
        op: Comparison operator to apply.
        value: Arrow scalar on the right-hand side of the comparison.

    """

    op: ComparisonOp
    value: pa.Scalar[Any]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Compare the filter column with the constant, row by row.

        Raises:
            ValueError: If ``op`` is not a known ``ComparisonOp``.
        """
        column, literal = _normalize_for_compare(batch.column(self.column_index), self.value)
        # Operator -> pyarrow.compute kernel.
        kernels: dict[ComparisonOp, Any] = {
            ComparisonOp.EQ: pc.equal,
            ComparisonOp.NE: pc.not_equal,
            ComparisonOp.GT: pc.greater,
            ComparisonOp.GE: pc.greater_equal,
            ComparisonOp.LT: pc.less,
            ComparisonOp.LE: pc.less_equal,
        }
        kernel = kernels.get(self.op)
        if kernel is None:
            raise ValueError(f"Unknown comparison operator: {self.op}")
        # The normalised operands are typed Any; the kernels resolve to
        # BooleanArray at runtime, so cast once on the way out.
        return cast("pa.BooleanArray", kernel(column, literal))

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        return f"ConstantFilter({self.column_name} {self.op.symbol} {self.value})"
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@dataclass(frozen=True, slots=True)
class IsNullFilter(Filter):
    """Filter that passes rows where the column IS NULL."""

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Return ``pc.is_null`` of the filter column."""
        target = batch.column(self.column_index)
        return pc.is_null(target)

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        column = self.column_name
        return f"IsNullFilter({column} IS NULL)"
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
@dataclass(frozen=True, slots=True)
class IsNotNullFilter(Filter):
    """Filter that passes rows where the column IS NOT NULL."""

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Return ``pc.is_valid`` of the filter column."""
        target = batch.column(self.column_index)
        return pc.is_valid(target)

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        column = self.column_name
        return f"IsNotNullFilter({column} IS NOT NULL)"
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@dataclass(frozen=True, slots=True)
class InFilter(Filter):
    """Set-membership filter: ``column IN (v1, v2, ...)``.

    Attributes:
        values: Arrow array of candidate values (the contents of the
            list column).
    """

    values: pa.Array[Any]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Test each row of the filter column for membership in ``values``."""
        column, candidates = _normalize_for_compare(batch.column(self.column_index), self.values)
        membership = pc.is_in(column, candidates)
        return cast("pa.BooleanArray", membership)

    def __repr__(self) -> str:
        """Return string representation for debugging.

        Lists longer than five values show only the first three plus the
        total count.
        """
        items = self.values.to_pylist()
        count = len(items)
        if count > 5:
            preview = f"{items[:3]!r}...({count} total)"
        else:
            preview = repr(items)
        return f"InFilter({self.column_name} IN {preview})"
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@dataclass(frozen=True, slots=True)
class AndFilter(Filter):
    """Conjunction of child filters.

    All child filters must pass for a row to pass.

    Attributes:
        children: Child filters combined with AND.
    """

    children: tuple[Filter, ...]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate AND of all child filters using SQL three-valued logic.

        Children are combined with ``pc.and_kleene`` so that
        ``NULL AND FALSE`` is ``FALSE``, as in SQL. The non-Kleene
        ``pc.and_`` yields NULL there, which gives the wrong answer once
        this filter is nested under an OR.

        Args:
            batch: RecordBatch to evaluate against.

        Returns:
            Boolean mask; all True when there are no children.
        """
        if not self.children:
            return _make_bool_array(True, batch.num_rows)
        first, *rest = self.children
        result = first.evaluate(batch)
        for child in rest:
            result = pc.and_kleene(result, child.evaluate(batch))
        return result

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        children_repr = " AND ".join(repr(c) for c in self.children)
        return f"AndFilter({children_repr})"
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@dataclass(frozen=True, slots=True)
class OrFilter(Filter):
    """Disjunction of child filters.

    At least one child filter must pass for a row to pass.

    Attributes:
        children: Child filters combined with OR.
    """

    children: tuple[Filter, ...]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate OR of all child filters using SQL three-valued logic.

        Children are combined with ``pc.or_kleene`` so that
        ``TRUE OR NULL`` is ``TRUE``, as in SQL. With the non-Kleene
        ``pc.or_`` a predicate such as ``col IS NULL OR col = 5`` yields
        NULL instead of TRUE for rows where ``col`` is NULL.

        Args:
            batch: RecordBatch to evaluate against.

        Returns:
            Boolean mask; all False when there are no children.
        """
        if not self.children:
            return _make_bool_array(False, batch.num_rows)
        first, *rest = self.children
        result = first.evaluate(batch)
        for child in rest:
            result = pc.or_kleene(result, child.evaluate(batch))
        return result

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        children_repr = " OR ".join(repr(c) for c in self.children)
        return f"OrFilter({children_repr})"
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class _SingleColumnBatch:
    """Lightweight wrapper providing batch-like interface for a single array.

    Used by StructFilter to avoid creating a full RecordBatch when evaluating
    child filters on nested struct fields. Only ``column()`` and
    ``num_rows`` are provided.
    """

    __slots__ = ("_array",)

    def __init__(self, array: pa.Array[Any]) -> None:
        """Store the array that every ``column()`` call will return."""
        self._array = array

    def column(self, _index: int) -> pa.Array[Any]:
        """Return the wrapped array (index is ignored)."""
        return self._array

    @property
    def num_rows(self) -> int:
        """Return the number of rows in the array."""
        return len(self._array)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
@dataclass(frozen=True, slots=True)
class StructFilter(Filter):
    """Filter applied to one field nested inside a struct column.

    Example: address.city = 'Seattle'

    Attributes:
        child_index: Index of the nested field (not read by ``evaluate``,
            which selects the field by name).
        child_name: Name of the nested field.
        child_filter: Filter to apply to the nested field's values.
    """

    child_index: int
    child_name: str
    child_filter: Filter

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Extract the nested field and run the child filter over it."""
        field_values = pc.struct_field(batch.column(self.column_index), self.child_name)
        # Point the child at column 0 of the single-array wrapper, which
        # stands in for a full RecordBatch.
        rebased_child = dataclasses.replace(self.child_filter, column_index=0)
        single = _SingleColumnBatch(field_values)
        return rebased_child.evaluate(single)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        path = f"{self.column_name}.{self.child_name}"
        return f"StructFilter({path}: {self.child_filter!r})"
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
# =============================================================================
|
|
445
|
+
# Expression Filter (recursive expression tree from DuckDB)
|
|
446
|
+
# =============================================================================
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
@dataclass(frozen=True, slots=True)
class ExpressionNode:
    """Base class for expression tree nodes.

    Subclasses must set ``expr_type`` to match their class. This field
    is used for serialization round-tripping (JSON ``expr_type`` key).

    Attributes:
        expr_type: Node kind identifier.
    """

    expr_type: ExpressionNodeType

    def to_sql(self, column_name: str) -> str:
        """Convert node to SQL string. Override in subclasses.

        Args:
            column_name: Name of the column the enclosing filter applies to.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass(frozen=True, slots=True)
class ColumnRefNode(ExpressionNode):
    """Reference to a column inside an expression tree.

    Note: In v1, all column refs in an expression filter refer to the same
    column (the filter column). The index is stored for future multi-column
    support but to_sql() always uses the filter's column_name.

    Attributes:
        index: Column index carried by the node (unused by ``to_sql``).
    """

    index: int

    def to_sql(self, column_name: str) -> str:
        """Return ``column_name`` as a double-quoted SQL identifier.

        Embedded double quotes are escaped by doubling them.
        """
        doubled = column_name.replace('"', '""')
        return '"' + doubled + '"'
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
@dataclass(frozen=True, slots=True)
class ConstantNode(ExpressionNode):
    """Constant value node.

    Attributes:
        value: Arrow scalar holding the literal.
        field: Optional Arrow field whose metadata may identify an
            extension type for the literal.
    """

    value: pa.Scalar[Any]
    field: pa.Field[Any] | None = None  # Arrow field with extension metadata (if available)

    def to_sql(self, column_name: str) -> str:
        """Format Arrow scalar as SQL literal, using field metadata for extension types.

        ``column_name`` is not used for constants.
        """
        return _arrow_scalar_to_sql(self.value, self.field)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def _is_operator_name(name: str) -> bool:
|
|
494
|
+
"""Check if a function name is an infix operator (all non-alphanumeric/underscore chars)."""
|
|
495
|
+
return len(name) > 0 and all(not (c.isalnum() or c == "_") for c in name)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
@dataclass(frozen=True, slots=True)
class FunctionNode(ExpressionNode):
    """Function (or operator) call inside an expression tree.

    Attributes:
        function_name: Function or operator name, emitted verbatim.
        children: Argument nodes, in call order.
    """

    function_name: str
    children: tuple[ExpressionNode, ...]

    def to_sql(self, column_name: str) -> str:
        """Render as ``(left op right)`` for binary operators, else ``name(args...)``.

        The infix form is used only when the name is an operator (see
        ``_is_operator_name``) and there are exactly two children.
        """
        rendered = [child.to_sql(column_name) for child in self.children]
        if len(rendered) == 2 and _is_operator_name(self.function_name):
            lhs, rhs = rendered
            return f"({lhs} {self.function_name} {rhs})"
        arg_list = ", ".join(rendered)
        return f"{self.function_name}({arg_list})"
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
@dataclass(frozen=True, slots=True)
class ComparisonNode(ExpressionNode):
    """Binary comparison between two sub-expressions.

    Attributes:
        op: Comparison operator.
        left: Left operand.
        right: Right operand.
    """

    op: ComparisonOp
    left: ExpressionNode
    right: ExpressionNode

    def to_sql(self, column_name: str) -> str:
        """Render as ``(left <symbol> right)``."""
        lhs = self.left.to_sql(column_name)
        symbol = self.op.symbol
        rhs = self.right.to_sql(column_name)
        return f"({lhs} {symbol} {rhs})"
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@dataclass(frozen=True, slots=True)
class ConjunctionNode(ExpressionNode):
    """AND/OR combination of sub-expressions.

    Attributes:
        conjunction_type: ``"and"`` selects AND; any other value renders as OR.
        children: Sub-expressions to combine.
    """

    conjunction_type: str  # "and" or "or"
    children: tuple[ExpressionNode, ...]

    def to_sql(self, column_name: str) -> str:
        """Render the children joined by AND/OR inside one pair of parentheses."""
        if self.conjunction_type == "and":
            keyword = " AND "
        else:
            keyword = " OR "
        body = keyword.join(child.to_sql(column_name) for child in self.children)
        return "(" + body + ")"
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
# Arrow extension type name → SQL wrapper for binary values.
# Maps ARROW:extension:name metadata to a function that converts hex to SQL.
# Each callable receives the hex encoding of the binary value and returns
# the SQL expression text to embed in the query.
_ARROW_EXTENSION_SQL: dict[str, Callable[[str], str]] = {
    "geoarrow.wkb": lambda hex_str: f"ST_GeomFromHEXWKB('{hex_str}')",
}
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def _arrow_scalar_to_sql(scalar: pa.Scalar[Any], field: pa.Field[Any] | None = None) -> str:
    """Convert an Arrow scalar to a SQL literal string.

    Uses Arrow field metadata to handle extension types (e.g., geoarrow.wkb
    for geometry). Falls back to generic representations for standard types.

    Rendering rules:
        * null                → ``NULL``
        * boolean             → ``TRUE`` / ``FALSE``
        * integer             → decimal text
        * float               → decimal text; NaN and ±infinity are emitted
          as quoted strings cast to DOUBLE, because the bare words ``nan``
          / ``inf`` are not numeric literals
        * string              → single-quoted, with ``'`` doubled
        * binary              → extension-specific wrapper when the field
          metadata names a known extension, otherwise a BLOB literal with
          every byte escaped as ``\\xNN``
        * anything else       → ``str(value)`` as a single-quoted string

    Args:
        scalar: Arrow scalar value.
        field: Arrow field with extension metadata (optional). Used to detect
            extension types like geoarrow.wkb for proper SQL rendering.

    Returns:
        SQL literal (or expression) text.

    """
    if not scalar.is_valid:
        return "NULL"

    val = scalar.as_py()
    typ = scalar.type

    if pa.types.is_boolean(typ):
        return "TRUE" if val else "FALSE"

    if pa.types.is_integer(typ):
        return str(val)

    if pa.types.is_floating(typ):
        if val != val:  # NaN is the only value not equal to itself
            return "'NaN'::DOUBLE"
        if val == float("inf"):
            return "'Infinity'::DOUBLE"
        if val == float("-inf"):
            return "'-Infinity'::DOUBLE"
        return str(val)

    if pa.types.is_string(typ) or pa.types.is_large_string(typ):
        escaped = str(val).replace("'", "''")
        return f"'{escaped}'"

    if pa.types.is_binary(typ) or pa.types.is_large_binary(typ):
        # Check Arrow extension metadata for type-specific rendering
        if field is not None and field.metadata is not None:
            ext_name = field.metadata.get(b"ARROW:extension:name", b"").decode()
            renderer = _ARROW_EXTENSION_SQL.get(ext_name)
            if renderer is not None:
                return renderer(val.hex())
        # A ``\x`` escape in a BLOB literal covers a single byte, so every
        # byte needs its own escape; a lone ``\x`` prefix before the whole
        # hex string would only escape the first byte.
        blob_body = "".join(f"\\x{byte:02X}" for byte in val)
        return f"'{blob_body}'::BLOB"

    # Fallback: use the Python string form as a quoted string literal
    escaped = str(val).replace("'", "''")
    return f"'{escaped}'"
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
_expression_eval_local = threading.local()
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def _get_expression_eval_connection() -> Any:
    """Return this thread's DuckDB connection for expression filters, creating it on first use.

    Each thread gets its own connection (DuckDB connections are not thread-safe).
    The spatial extension is loaded when possible (needed for geometry
    functions like &&): a plain load is tried first, then install-and-load.
    If both fail the connection is still returned.
    """
    cached = getattr(_expression_eval_local, "conn", None)
    if cached is None:
        from vgi._duckdb import connect as engine_connect

        cached = engine_connect()
        try:
            cached.load_extension("spatial")
        except Exception:
            # Not loadable as-is: try installing it first.
            try:
                cached.install_extension("spatial")
                cached.load_extension("spatial")
            except Exception:
                pass  # spatial not available — non-spatial expressions still work
        _expression_eval_local.conn = cached
    return cached
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
@dataclass(frozen=True, slots=True)
class ExpressionFilter(Filter):
    """Expression-tree filter pushed down from DuckDB.

    Holds a recursive expression tree that the worker evaluates with
    DuckDB. Typical use: spatial predicates like ``geom && box``.

    Attributes:
        expr: Root node of the expression tree.
    """

    expr: ExpressionNode

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate the expression tree against ``batch`` using DuckDB.

        Uses the cached thread-local connection from
        ``_get_expression_eval_connection`` (spatial extension pre-loaded
        if available). The engine is imported lazily via
        :mod:`vgi._duckdb` (haybarn preferred, duckdb fallback) — workers
        that don't use expression filters don't need either installed.
        """
        engine = _get_expression_eval_connection()
        # The local must stay named ``tbl``: the SQL text below refers to it
        # by that name (presumably resolved by DuckDB from this frame's
        # variables — confirm before renaming or moving this code).
        tbl = engine.from_arrow(batch)  # noqa: F841 (used in SQL below)
        predicate_sql = self.expr.to_sql(self.column_name)
        relation = engine.sql(f"SELECT ({predicate_sql})::BOOLEAN AS _r FROM tbl")
        # Use to_arrow_table if available (duckdb >= 1.5.1), fall back to fetch_arrow_table
        to_table = getattr(relation, "to_arrow_table", None)
        if not to_table:
            to_table = relation.fetch_arrow_table
        mask_table = to_table()
        return mask_table.column("_r").combine_chunks()  # type: ignore[no-any-return]

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        rendered = self.expr.to_sql(self.column_name)
        return f"ExpressionFilter({self.column_name}: {rendered})"
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
# =============================================================================
|
|
649
|
+
# Column Bounds Helper
|
|
650
|
+
# =============================================================================
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
@dataclass(frozen=True, slots=True)
|
|
654
|
+
class ColumnBounds:
|
|
655
|
+
"""Numeric/comparable bounds for a column extracted from filters.
|
|
656
|
+
|
|
657
|
+
Use case: Partition pruning, index range scans, bounded data fetches.
|
|
658
|
+
|
|
659
|
+
Attributes:
|
|
660
|
+
min_value: Minimum bound value, or None if unbounded below.
|
|
661
|
+
min_inclusive: True if min_value is inclusive (>=), False if exclusive (>).
|
|
662
|
+
max_value: Maximum bound value, or None if unbounded above.
|
|
663
|
+
max_inclusive: True if max_value is inclusive (<=), False if exclusive (<).
|
|
664
|
+
|
|
665
|
+
"""
|
|
666
|
+
|
|
667
|
+
min_value: pa.Scalar[Any] | None = None
|
|
668
|
+
min_inclusive: bool = True
|
|
669
|
+
max_value: pa.Scalar[Any] | None = None
|
|
670
|
+
max_inclusive: bool = True
|
|
671
|
+
|
|
672
|
+
def contains(self, value: Any) -> bool:
|
|
673
|
+
"""Check if a value satisfies these bounds.
|
|
674
|
+
|
|
675
|
+
Args:
|
|
676
|
+
value: Value to check against bounds.
|
|
677
|
+
|
|
678
|
+
Returns:
|
|
679
|
+
True if value is within bounds, False otherwise.
|
|
680
|
+
|
|
681
|
+
"""
|
|
682
|
+
if self.min_value is not None:
|
|
683
|
+
min_val = self.min_value.as_py()
|
|
684
|
+
below_min = value < min_val if self.min_inclusive else value <= min_val
|
|
685
|
+
if below_min:
|
|
686
|
+
return False
|
|
687
|
+
|
|
688
|
+
if self.max_value is not None:
|
|
689
|
+
max_val = self.max_value.as_py()
|
|
690
|
+
above_max = value > max_val if self.max_inclusive else value >= max_val
|
|
691
|
+
if above_max:
|
|
692
|
+
return False
|
|
693
|
+
|
|
694
|
+
return True
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
# =============================================================================
|
|
698
|
+
# PushdownFilters Container
|
|
699
|
+
# =============================================================================
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
@dataclass(frozen=True, slots=True)
class PushdownFilters:
    """Container for pushdown filters with evaluation and query helpers.

    The top-level filters array represents a conjunction (AND). Each filter in
    the array must be satisfied for a row to pass. Individual filters may
    themselves be AND/OR compound filters for more complex expressions.

    Provides:
        - evaluate(batch) / apply(batch) - Apply filters using PyArrow compute
        - get_column_bounds(name) - Extract numeric bounds for partition pruning
        - get_column_constant(name) - Get equality constant for a column
        - get_column_in_values(name) - Get IN list values
        - get_column_values(name) - Get discrete values from =, IN or OR-of-those
        - get_column_filters(name) - Get all filters for a column
        - to_sql() - Generate SQL WHERE clause

    """

    filters: tuple[Filter, ...]
    version: str = _SUPPORTED_VERSION
    join_keys_batches: list[pa.RecordBatch] | None = None

    def get_join_keys_batch(self) -> pa.RecordBatch | None:
        """Return a merged join keys batch for temp table registration.

        When all join key batches have the same row count (the semi-join
        case), returns a single RecordBatch with all columns merged.
        When batches have different row counts (independent IN filters),
        they cannot be merged, so this returns ``None``.

        For individual column access, use :meth:`get_join_keys_batches`
        or :meth:`get_column_in_values`.

        Example::

            keys = params.current_pushdown_filters.get_join_keys_batch()
            if keys is not None:
                conn.register("join_keys", keys)
                result = conn.sql(
                    "SELECT d.* FROM my_data d JOIN join_keys USING (id)"
                )

        Returns:
            Merged RecordBatch when all batches have equal row counts,
            or None.

        """
        if not self.join_keys_batches:
            return None
        row_counts = {b.num_rows for b in self.join_keys_batches}
        if len(row_counts) != 1:
            return None
        # All same cardinality — merge into one multi-column batch
        columns: list[pa.Array[Any]] = []
        fields: list[pa.Field[Any]] = []
        for b in self.join_keys_batches:
            for i in range(b.num_columns):
                columns.append(b.column(i))
                fields.append(b.schema.field(i))
        return pa.RecordBatch.from_arrays(columns, schema=pa.schema(fields))

    def get_join_keys_batches(self) -> list[pa.RecordBatch] | None:
        """Return all join key batches (one per IN filter column).

        Each batch is a single-column RecordBatch. Different batches may have
        different row counts. Returns ``None`` if no join keys were pushed.

        """
        return self.join_keys_batches if self.join_keys_batches else None

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate all filters, returning boolean mask.

        Filters are combined with AND at the top level - a row passes only
        if ALL filters evaluate to true for that row.

        Args:
            batch: RecordBatch to evaluate filters against.

        Returns:
            Boolean array with True for rows that pass all filters.

        """
        _log_debug(
            "evaluate_start",
            num_filters=len(self.filters),
            input_rows=batch.num_rows,
            columns=[f.column_name for f in self.filters],
        )

        if not self.filters:
            _log_debug("evaluate_no_filters", input_rows=batch.num_rows)
            return _make_bool_array(True, batch.num_rows)

        result = self.filters[0].evaluate(batch)
        # pc.sum works on BooleanArray (counts True values) but stubs don't reflect this
        true_count: int | None = pc.sum(result).as_py()  # type: ignore[type-var]
        _log_debug(
            "evaluate_filter",
            filter_index=0,
            filter_type=type(self.filters[0]).__name__,
            filter_repr=repr(self.filters[0]),
            rows_passing=true_count,
        )

        for i, f in enumerate(self.filters[1:], start=1):
            result = pc.and_(result, f.evaluate(batch))
            true_count = pc.sum(result).as_py()  # type: ignore[type-var]
            _log_debug(
                "evaluate_filter",
                filter_index=i,
                filter_type=type(f).__name__,
                filter_repr=repr(f),
                rows_passing=true_count,
            )

        final_count: int | None = pc.sum(result).as_py()  # type: ignore[type-var]
        _log_debug(
            "evaluate_complete",
            input_rows=batch.num_rows,
            rows_passing=final_count,
            rows_filtered=batch.num_rows - (final_count or 0),
        )
        return result

    def apply(self, batch: pa.RecordBatch) -> pa.RecordBatch:
        """Apply all filters to batch, returning filtered batch.

        Args:
            batch: RecordBatch to filter.

        Returns:
            Filtered RecordBatch containing only rows that pass all filters.

        """
        _log_debug("apply_start", input_rows=batch.num_rows)
        mask = self.evaluate(batch)
        # pc.filter supports RecordBatch but pyarrow-stubs don't have the overload
        filtered: pa.RecordBatch = pc.filter(batch, mask)  # type: ignore[call-overload]
        _log_debug(
            "apply_complete",
            input_rows=batch.num_rows,
            output_rows=filtered.num_rows,
            rows_removed=batch.num_rows - filtered.num_rows,
        )
        return filtered

    # =========================================================================
    # Column Query Helpers
    # =========================================================================

    @property
    def filtered_columns(self) -> frozenset[str]:
        """Set of column names that have filters applied.

        Use case: Quick check of which columns are constrained.
        """
        return frozenset(f.column_name for f in self.filters)

    def get_column_filters(self, column_name: str) -> list[Filter]:
        """Get all top-level filters for a specific column.

        Use case: Inspect what constraints apply to a column.

        Args:
            column_name: Name of the column to get filters for.

        Returns:
            List of filters that apply to the column.

        """
        return [f for f in self.filters if f.column_name == column_name]

    def has_filter_for_column(self, column_name: str) -> bool:
        """Check if any filter constrains the given column.

        Args:
            column_name: Name of the column to check.

        Returns:
            True if at least one filter applies to the column.

        """
        return any(f.column_name == column_name for f in self.filters)

    def get_column_constant(self, column_name: str) -> pa.Scalar[Any] | None:
        """Get constant value if column has an equality filter.

        Use case: Partition key lookup, exact match optimization.

        Args:
            column_name: Name of the column to check.

        Returns:
            The constant value if an equality filter exists, None otherwise.

        """
        # Descend one level into AndFilter children (consistent with
        # get_column_bounds): DuckDB commonly pushes `col = v` / `col IN (...)`
        # conjoined with derived range bounds as a single AndFilter.
        for f in self._collect_column_filters(column_name):
            if isinstance(f, ConstantFilter) and f.op == ComparisonOp.EQ:
                return f.value
        return None

    def get_column_in_values(self, column_name: str) -> pa.Array[Any] | None:
        """Get IN list values if column has an IN filter.

        Use case: Multi-key lookup, batch fetching.

        Args:
            column_name: Name of the column to check.

        Returns:
            Arrow array of IN values if an IN filter exists, None otherwise.

        """
        # Descend one level into AndFilter children (consistent with
        # get_column_bounds): an `IN (...)` predicate is frequently pushed as
        # AndFilter(InFilter, >=min, <=max).
        for f in self._collect_column_filters(column_name):
            if isinstance(f, InFilter):
                return f.values
        return None

    def get_column_values(self, column_name: str) -> pa.Array[Any] | None:
        """Get all distinct values a column could have based on filters.

        Returns values from equality (=) or IN filters as an Arrow array.
        Useful for partition pruning when partitions are keyed by specific values.

        Use case: Partition key lookup, directory-based partitioning.

        Args:
            column_name: Name of the column to check.

        Returns:
            Arrow array of discrete values if available, None otherwise.

        """
        # Descend one level into AndFilter children (consistent with
        # get_column_bounds). DuckDB pushes `col = v` / `col IN (...)` conjoined
        # with derived range bounds as a single AndFilter; without this descent
        # the discrete-value fast path silently misses those and callers fall
        # back to scanning every partition.
        for f in self._collect_column_filters(column_name):
            if isinstance(f, ConstantFilter) and f.op == ComparisonOp.EQ:
                # Wrap single value in array for consistent return type
                arr: pa.Array[Any] = pa.array([f.value.as_py()], type=f.value.type)
                return arr
            elif isinstance(f, InFilter):
                return f.values
            elif isinstance(f, OrFilter):
                # `col = a OR col = b`, `col IN (...) OR col = c`, etc. The
                # column's possible values are the UNION of the branches — but
                # only when *every* branch pins this column to discrete values.
                # If any branch is a range/IS NULL, or constrains a different
                # column (so this column is unbounded in that branch), the set
                # is not enumerable and we must fall through to None. Unlike the
                # AND case, returning one branch's values would be an unsafe
                # subset (a pruning caller would skip the other branches' rows).
                union = self._or_discrete_values(f, column_name)
                if union is not None:
                    return union
        return None

    def _or_discrete_values(self, or_filter: OrFilter, column_name: str) -> pa.Array[Any] | None:
        """Union of discrete values for ``column_name`` across all OR branches.

        Returns the deduplicated union iff every child constrains ``column_name``
        to discrete values (``=`` or ``IN``); otherwise None (the column could
        take any value through a non-discrete branch, so it cannot be enumerated).
        Descends one level into ``OrFilter`` children, consistent with the
        single-level descent used elsewhere; deeper nesting yields None.

        When all branches carry the same Arrow type, the result keeps that
        type (matching the ``=`` path of :meth:`get_column_values`) instead of
        whatever type inference would pick for the Python values.
        """
        values: list[Any] = []
        branch_types: list[Any] = []
        for child in or_filter.children:
            if child.column_name != column_name:
                # A branch that constrains some other column (or none) leaves
                # this column unbounded within that branch -> not enumerable.
                return None
            if isinstance(child, ConstantFilter) and child.op == ComparisonOp.EQ:
                values.append(child.value.as_py())
                branch_types.append(child.value.type)
            elif isinstance(child, InFilter):
                values.extend(child.values.to_pylist())
                branch_types.append(child.values.type)
            else:
                return None
        if not values:
            return None

        # Deduplicate, preserving first-seen order for stable output.
        seen: set[Any] = set()
        deduped: list[Any] = []
        for v in values:
            try:
                is_new = v not in seen
                if is_new:
                    seen.add(v)
            except TypeError:
                # Unhashable Python value (e.g. list/struct columns): fall back
                # to an equality scan over what has been kept so far.
                is_new = v not in deduped
            if is_new:
                deduped.append(v)

        # Keep the column's Arrow type when the branches agree on it, so e.g.
        # an int32 column does not come back as inferred int64.
        first_type = branch_types[0]
        if all(t == first_type for t in branch_types):
            try:
                typed: pa.Array[Any] = pa.array(deduped, type=first_type)
                return typed
            except pa.ArrowException:
                # Type cannot be rebuilt from Python values; use inference below.
                pass
        result: pa.Array[Any] = pa.array(deduped)
        return result

    def get_column_bounds(self, column_name: str) -> ColumnBounds | None:
        """Extract numeric bounds from comparison filters.

        Analyzes gt/ge/lt/le filters to determine value range. All collected
        filters are conjoined, so the tightest bound on each side wins; when an
        inclusive and an exclusive bound share the same value the exclusive
        one wins, regardless of the order the filters arrive in.

        Use case: Range scans, partition pruning, bounded iteration.

        Args:
            column_name: Name of the column to extract bounds for.

        Returns:
            ColumnBounds with min/max values if bounds exist, None otherwise.

        """
        min_val: pa.Scalar[Any] | None = None
        min_inc = True
        max_val: pa.Scalar[Any] | None = None
        max_inc = True

        for f in self._collect_column_filters(column_name):
            if not isinstance(f, ConstantFilter):
                continue
            op = f.op
            if op == ComparisonOp.EQ:
                # Equality implies exact bounds
                return ColumnBounds(f.value, True, f.value, True)

            if op == ComparisonOp.GT or op == ComparisonOp.GE:
                inclusive = op == ComparisonOp.GE
                if min_val is None:
                    min_val, min_inc = f.value, inclusive
                else:
                    candidate, current = f.value.as_py(), min_val.as_py()
                    if candidate > current:
                        min_val, min_inc = f.value, inclusive
                    elif candidate == current and not inclusive:
                        # `x > v AND x >= v` is just `x > v`.
                        min_inc = False
            elif op == ComparisonOp.LT or op == ComparisonOp.LE:
                inclusive = op == ComparisonOp.LE
                if max_val is None:
                    max_val, max_inc = f.value, inclusive
                else:
                    candidate, current = f.value.as_py(), max_val.as_py()
                    if candidate < current:
                        max_val, max_inc = f.value, inclusive
                    elif candidate == current and not inclusive:
                        # `x < v AND x <= v` is just `x < v`.
                        max_inc = False

        if min_val is None and max_val is None:
            return None
        return ColumnBounds(min_val, min_inc, max_val, max_inc)

    def _collect_column_filters(self, column_name: str) -> list[Filter]:
        """Collect filters for a column from top-level and direct AND children.

        Note: Only descends one level into AndFilter children. Deeply nested
        AND filters (AND within AND) are not traversed. This is sufficient
        for most query patterns where bounds filters are either at top level
        or grouped in a single AND.
        """
        result: list[Filter] = []
        for f in self.filters:
            if f.column_name != column_name:
                continue
            if isinstance(f, AndFilter):
                result.extend(c for c in f.children if c.column_name == column_name)
            else:
                result.append(f)
        return result

    # =========================================================================
    # SQL Generation
    # =========================================================================

    def to_sql(
        self,
        quote_identifier: Callable[[str], str] | None = None,
        placeholder: str = "?",
    ) -> tuple[str, list[Any]]:
        """Convert filters to SQL WHERE clause with parameters.

        Args:
            quote_identifier: Function to quote column names. The default wraps
                the name in double quotes and doubles any embedded double quote.
            placeholder: Parameter placeholder text, emitted verbatim once per
                parameter (e.g. "?" or "%s").

        Returns:
            Tuple of (where_clause, params) - clause excludes "WHERE" keyword.
            Both are empty when there are no filters.

        """
        if not self.filters:
            return "", []

        def default_quote(name: str) -> str:
            # Standard SQL identifier quoting: embedded quotes are doubled so a
            # column name containing `"` cannot terminate the identifier early.
            return '"' + name.replace('"', '""') + '"'

        quote = quote_identifier or default_quote
        conditions: list[str] = []
        params: list[Any] = []

        for f in self.filters:
            sql, ps = _filter_to_sql(f, quote, placeholder, len(params))
            conditions.append(sql)
            params.extend(ps)

        return " AND ".join(conditions), params

    # =========================================================================
    # Dunder Methods
    # =========================================================================

    @classmethod
    def empty(cls) -> PushdownFilters:
        """Create an empty PushdownFilters instance (no filters)."""
        return cls(filters=())

    def __bool__(self) -> bool:
        """Return True if there are any filters."""
        return len(self.filters) > 0

    def __len__(self) -> int:
        """Return the number of top-level filters."""
        return len(self.filters)

    def __iter__(self) -> Iterator[Filter]:
        """Iterate over top-level filters."""
        return iter(self.filters)

    def __contains__(self, column_name: str) -> bool:
        """Check if any filter constrains the given column.

        Allows 'column_name in filters' syntax.
        """
        return self.has_filter_for_column(column_name)

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        if not self.filters:
            return "PushdownFilters([])"
        filters_repr = ", ".join(repr(f) for f in self.filters)
        return f"PushdownFilters([{filters_repr}])"
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
# =============================================================================
|
|
1129
|
+
# SQL Generation Helper
|
|
1130
|
+
# =============================================================================
|
|
1131
|
+
|
|
1132
|
+
|
|
1133
|
+
def _filter_to_sql(
|
|
1134
|
+
f: Filter,
|
|
1135
|
+
quote: Callable[[str], str],
|
|
1136
|
+
placeholder: str,
|
|
1137
|
+
param_offset: int,
|
|
1138
|
+
) -> tuple[str, list[Any]]:
|
|
1139
|
+
"""Convert a single filter to SQL fragment.
|
|
1140
|
+
|
|
1141
|
+
Args:
|
|
1142
|
+
f: Filter to convert.
|
|
1143
|
+
quote: Function to quote identifiers.
|
|
1144
|
+
placeholder: Parameter placeholder style.
|
|
1145
|
+
param_offset: Current parameter offset (for recursive calls).
|
|
1146
|
+
|
|
1147
|
+
Returns:
|
|
1148
|
+
Tuple of (sql_fragment, params).
|
|
1149
|
+
|
|
1150
|
+
"""
|
|
1151
|
+
col = quote(f.column_name)
|
|
1152
|
+
|
|
1153
|
+
match f:
|
|
1154
|
+
case ConstantFilter(op=op, value=value):
|
|
1155
|
+
return f"{col} {op.symbol} {placeholder}", [value.as_py()]
|
|
1156
|
+
|
|
1157
|
+
case IsNullFilter():
|
|
1158
|
+
return f"{col} IS NULL", []
|
|
1159
|
+
|
|
1160
|
+
case IsNotNullFilter():
|
|
1161
|
+
return f"{col} IS NOT NULL", []
|
|
1162
|
+
|
|
1163
|
+
case InFilter(values=values):
|
|
1164
|
+
placeholders = ", ".join([placeholder] * len(values))
|
|
1165
|
+
return f"{col} IN ({placeholders})", values.to_pylist()
|
|
1166
|
+
|
|
1167
|
+
case AndFilter(children=children):
|
|
1168
|
+
parts: list[str] = []
|
|
1169
|
+
params: list[Any] = []
|
|
1170
|
+
for child in children:
|
|
1171
|
+
offset = param_offset + len(params)
|
|
1172
|
+
sql, ps = _filter_to_sql(child, quote, placeholder, offset)
|
|
1173
|
+
parts.append(sql)
|
|
1174
|
+
params.extend(ps)
|
|
1175
|
+
return f"({' AND '.join(parts)})", params
|
|
1176
|
+
|
|
1177
|
+
case OrFilter(children=children):
|
|
1178
|
+
parts = []
|
|
1179
|
+
params = []
|
|
1180
|
+
for child in children:
|
|
1181
|
+
offset = param_offset + len(params)
|
|
1182
|
+
sql, ps = _filter_to_sql(child, quote, placeholder, offset)
|
|
1183
|
+
parts.append(sql)
|
|
1184
|
+
params.extend(ps)
|
|
1185
|
+
return f"({' OR '.join(parts)})", params
|
|
1186
|
+
|
|
1187
|
+
case StructFilter(child_name=child_name, child_filter=child_filter):
|
|
1188
|
+
# Struct access varies by database - use dot notation as default
|
|
1189
|
+
nested_col = f"{f.column_name}.{child_name}"
|
|
1190
|
+
return _filter_to_sql(child_filter, lambda _: quote(nested_col), placeholder, param_offset)
|
|
1191
|
+
|
|
1192
|
+
case ExpressionFilter(expr=expr):
|
|
1193
|
+
# Expression filters use inline constants in to_sql() because
|
|
1194
|
+
# they may contain geometry literals that can't be parameterized.
|
|
1195
|
+
# Constants are embedded directly in the SQL string.
|
|
1196
|
+
return expr.to_sql(f.column_name), []
|
|
1197
|
+
|
|
1198
|
+
case _:
|
|
1199
|
+
raise ValueError(f"Unknown filter type: {type(f)}")
|
|
1200
|
+
|
|
1201
|
+
|
|
1202
|
+
# =============================================================================
|
|
1203
|
+
# Deserialization
|
|
1204
|
+
# =============================================================================
|
|
1205
|
+
|
|
1206
|
+
|
|
1207
|
+
def deserialize_filters(
    batch: pa.RecordBatch,
    join_keys: list[pa.RecordBatch] | None = None,
) -> PushdownFilters:
    """Deserialize a serialized filter batch to a typed AST.

    Column 0 of ``batch`` holds the JSON filter spec in its first row and
    carries the ``vgi_filter_version`` field metadata. A ``value_ref`` of N in
    the spec resolves to the first row of column N + 1.

    Args:
        batch: Arrow RecordBatch containing the serialized filters.
        join_keys: Optional list of single-column Arrow RecordBatches, one per
            IN filter column. Each batch may have a different row count.
            Referenced by ``join_keys`` filter type entries in the filter spec.

    Returns:
        PushdownFilters container with parsed filter AST.

    Raises:
        FilterDeserializationError: If the batch has no columns, the JSON
            cannot be parsed or is not an array, or a filter spec is invalid.
        FilterVersionError: If version is missing or unsupported.

    """
    if batch.num_columns == 0:
        # Without column 0 there is neither version metadata nor a spec to read.
        raise FilterDeserializationError("Filter batch has no columns")

    # Validate version
    metadata = batch.schema.field(0).metadata
    if metadata is None:
        raise FilterVersionError("Missing vgi_filter_version metadata")
    version = metadata.get(b"vgi_filter_version", b"").decode()
    if version != _SUPPORTED_VERSION:
        raise FilterVersionError(f"Unsupported filter version: {version!r}")

    _log_debug("deserialize_version", version=version)

    # Parse JSON spec
    try:
        filter_specs = json.loads(batch.column(0)[0].as_py())
    except Exception as e:
        _log_debug("deserialize_json_error", error=str(e))
        raise FilterDeserializationError(f"Failed to parse filter JSON: {e}") from e

    if not isinstance(filter_specs, list):
        # e.g. JSON `null` or a number: report it as a deserialization failure
        # instead of letting len()/iteration raise a bare TypeError below.
        _log_debug("deserialize_json_error", error="spec is not a JSON array")
        raise FilterDeserializationError(
            f"Filter spec must be a JSON array, got {type(filter_specs).__name__}"
        )

    _log_debug("deserialize_specs", num_filters=len(filter_specs), specs=filter_specs)

    # Value resolver - returns scalar for value_ref N from column N+1
    def get_value(ref: int) -> pa.Scalar[Any]:
        value = batch.column(ref + 1)[0]
        _log_debug(
            "deserialize_value_ref",
            ref=ref,
            column_index=ref + 1,
            value_type=str(value.type),
            value=str(value),
        )
        return value  # type: ignore[no-any-return]

    def get_field(ref: int) -> pa.Field[Any]:
        """Get the Arrow field for a value_ref (column ref+1 in the batch)."""
        return batch.schema.field(ref + 1)

    def get_join_keys_column(column_name: str) -> pa.Array[Any] | None:
        """Resolve a column from the join keys batches by name."""
        if not join_keys:
            return None
        for keys_batch in join_keys:
            try:
                return keys_batch.column(column_name)
            except KeyError:
                # Not in this batch; try the next one.
                continue
        return None

    # Parse filters
    try:
        parsed: list[Filter] = []
        for spec in filter_specs:
            f = _parse_filter(spec, get_value, get_field, get_join_keys_column)
            # None means the filter was dropped on purpose (missing join keys).
            if f is not None:
                parsed.append(f)
        filters = tuple(parsed)
    except Exception as e:
        _log_debug("deserialize_parse_error", error=str(e))
        raise FilterDeserializationError(f"Failed to parse filters: {e}") from e

    _log_debug(
        "deserialize_complete",
        num_filters=len(filters),
        filter_types=[type(f).__name__ for f in filters],
        columns=[f.column_name for f in filters],
    )

    return PushdownFilters(filters=filters, version=version, join_keys_batches=join_keys)
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
def _parse_filter(
    spec: dict[str, Any],
    get_value: Callable[[int], pa.Scalar[Any]],
    get_field: Callable[[int], pa.Field[Any]],
    get_join_keys_column: Callable[[str], pa.Array[Any] | None] | None = None,
) -> Filter | None:
    """Parse a single filter spec into a typed Filter object.

    A ``None`` result means "no constraint": the filter is dropped and the
    rows are left for the client side to filter. Dropping only ever weakens
    the overall predicate:

    - a ``join_keys`` filter whose keys column is unavailable is dropped;
    - an AND drops just the missing children, and is dropped itself when no
      child survives;
    - an OR or struct filter is dropped entirely when any child is missing.

    Args:
        spec: Filter specification dict from JSON.
        get_value: Function to get Arrow scalar by value_ref index.
        get_field: Function to get Arrow field by value_ref index (for extension metadata).
        get_join_keys_column: Function to resolve a column from the join keys batch by name.
            Returns None if no join keys batch or column not found.

    Returns:
        Typed Filter object, or None if the filter references missing join keys.

    Raises:
        FilterDeserializationError: If filter type is unknown.

    """
    column_name = spec["column_name"]
    column_index = spec["column_index"]
    filter_type = spec["type"]

    _log_debug(
        "parse_filter_start",
        filter_type=filter_type,
        column_name=column_name,
        column_index=column_index,
    )

    if filter_type == FilterType.CONSTANT.value:
        op = ComparisonOp(spec["op"])
        value = get_value(spec["value_ref"])
        result = ConstantFilter(
            column_name=column_name,
            column_index=column_index,
            op=op,
            value=value,
        )
        _log_debug(
            "parse_filter_constant",
            column=column_name,
            op=op.value,
            value=str(value),
            value_type=str(value.type),
        )
        return result

    elif filter_type == FilterType.IS_NULL.value:
        _log_debug("parse_filter_is_null", column=column_name)
        return IsNullFilter(column_name=column_name, column_index=column_index)

    elif filter_type == FilterType.IS_NOT_NULL.value:
        _log_debug("parse_filter_is_not_null", column=column_name)
        return IsNotNullFilter(column_name=column_name, column_index=column_index)

    elif filter_type == FilterType.IN.value:
        # value_ref points to a list column; extract the list's values as an array
        list_scalar = get_value(spec["value_ref"])
        # ListScalar.values gives us the underlying array
        # pyarrow-stubs doesn't type ListScalar.values correctly
        values_array: pa.Array[Any] = list_scalar.values  # type: ignore[attr-defined]
        _log_debug(
            "parse_filter_in",
            column=column_name,
            num_values=len(values_array),
            values=values_array.to_pylist(),
            value_type=str(values_array.type),
        )
        return InFilter(
            column_name=column_name,
            column_index=column_index,
            values=values_array,
        )

    elif filter_type == FilterType.JOIN_KEYS.value:
        keys_column_name = spec["keys_column"]
        join_values: pa.Array[Any] | None = get_join_keys_column(keys_column_name) if get_join_keys_column else None
        if join_values is None:
            _log_debug(
                "parse_filter_join_keys_missing",
                column=column_name,
                keys_column=keys_column_name,
            )
            return None  # graceful degradation — DuckDB filters client-side
        _log_debug(
            "parse_filter_join_keys",
            column=column_name,
            keys_column=keys_column_name,
            num_values=len(join_values),
            value_type=str(join_values.type),
        )
        # Join keys behave like an IN list over the resolved key column.
        return InFilter(
            column_name=column_name,
            column_index=column_index,
            values=join_values,
        )

    elif filter_type == FilterType.AND.value:
        _log_debug(
            "parse_filter_and_start",
            column=column_name,
            num_children=len(spec["children"]),
        )
        # Dropping a missing child from an AND only weakens it, which is safe.
        children = tuple(
            f
            for c in spec["children"]
            if (f := _parse_filter(c, get_value, get_field, get_join_keys_column)) is not None
        )
        if not children:
            # Every child was dropped: the AND constrains nothing, so drop it
            # too rather than hand back an AndFilter with no children.
            _log_debug("parse_filter_and_empty", column=column_name)
            return None
        _log_debug("parse_filter_and_complete", column=column_name)
        return AndFilter(
            column_name=column_name,
            column_index=column_index,
            children=children,
        )

    elif filter_type == FilterType.OR.value:
        _log_debug(
            "parse_filter_or_start",
            column=column_name,
            num_children=len(spec["children"]),
        )
        parsed_children: list[Filter] = []
        for c in spec["children"]:
            child = _parse_filter(c, get_value, get_field, get_join_keys_column)
            if child is None:
                # Dropping a child from OR would strengthen the filter (fewer rows pass),
                # which is wrong for graceful degradation. Drop the entire OR instead.
                _log_debug("parse_filter_or_child_missing", column=column_name)
                return None
            parsed_children.append(child)
        children = tuple(parsed_children)
        _log_debug("parse_filter_or_complete", column=column_name)
        return OrFilter(
            column_name=column_name,
            column_index=column_index,
            children=children,
        )

    elif filter_type == FilterType.STRUCT.value:
        child_name = spec["child_name"]
        _log_debug(
            "parse_filter_struct_start",
            column=column_name,
            child_name=child_name,
        )
        child_filter = _parse_filter(spec["child_filter"], get_value, get_field, get_join_keys_column)
        if child_filter is None:
            # Nothing left to apply to the struct field.
            return None
        _log_debug("parse_filter_struct_complete", column=column_name)
        return StructFilter(
            column_name=column_name,
            column_index=column_index,
            child_index=spec["child_index"],
            child_name=child_name,
            child_filter=child_filter,
        )

    elif filter_type == FilterType.EXPRESSION.value:
        _log_debug("parse_filter_expression_start", column=column_name)
        expr = _parse_expression_node(spec["expr"], get_value, get_field)
        _log_debug("parse_filter_expression_complete", column=column_name)
        return ExpressionFilter(
            column_name=column_name,
            column_index=column_index,
            expr=expr,
        )

    else:
        _log_debug("parse_filter_unknown", filter_type=filter_type)
        raise FilterDeserializationError(f"Unknown filter type: {filter_type}")
|
|
1469
|
+
|
|
1470
|
+
|
|
1471
|
+
def _parse_expression_node(
    spec: dict[str, Any],
    get_value: Callable[[int], pa.Scalar[Any]],
    get_field: Callable[[int], pa.Field[Any]],
) -> ExpressionNode:
    """Build a typed expression node from its JSON specification.

    The ``expr_type`` entry of ``spec`` selects which node class is built.
    Nested specs (function arguments, comparison operands, conjunction
    members) are parsed recursively with the same lookup callables.

    Args:
        spec: Expression node specification dict from JSON.
        get_value: Lookup returning the Arrow scalar for a ``value_ref`` index.
        get_field: Lookup returning the Arrow field for a ``value_ref`` index
            (for extension metadata).

    Returns:
        The typed ExpressionNode described by ``spec``.

    Raises:
        FilterDeserializationError: If ``expr_type`` is not a known node type.

    """
    kind = spec["expr_type"]

    # Leaf: reference to a column by index.
    if kind == ExpressionNodeType.COLUMN_REF.value:
        return ColumnRefNode(expr_type=ExpressionNodeType.COLUMN_REF, index=spec["index"])

    # Leaf: constant whose scalar and field are both resolved from one value_ref.
    if kind == ExpressionNodeType.CONSTANT.value:
        value_ref = spec["value_ref"]
        scalar = get_value(value_ref)
        arrow_field = get_field(value_ref)
        return ConstantNode(
            expr_type=ExpressionNodeType.CONSTANT,
            value=scalar,
            field=arrow_field,
        )

    # Function call: arguments are parsed before the function name is read.
    if kind == ExpressionNodeType.FUNCTION.value:
        arguments = tuple(
            [_parse_expression_node(arg_spec, get_value, get_field) for arg_spec in spec["children"]]
        )
        return FunctionNode(
            expr_type=ExpressionNodeType.FUNCTION,
            function_name=spec["function_name"],
            children=arguments,
        )

    # Binary comparison: operator first, then left operand, then right operand.
    if kind == ExpressionNodeType.COMPARISON.value:
        operator = ComparisonOp(spec["op"])
        lhs = _parse_expression_node(spec["left"], get_value, get_field)
        rhs = _parse_expression_node(spec["right"], get_value, get_field)
        return ComparisonNode(
            expr_type=ExpressionNodeType.COMPARISON,
            op=operator,
            left=lhs,
            right=rhs,
        )

    # Conjunction: members are parsed before the conjunction type is read.
    if kind == ExpressionNodeType.CONJUNCTION.value:
        members = tuple(
            [_parse_expression_node(member_spec, get_value, get_field) for member_spec in spec["children"]]
        )
        return ConjunctionNode(
            expr_type=ExpressionNodeType.CONJUNCTION,
            conjunction_type=spec["conjunction_type"],
            children=members,
        )

    raise FilterDeserializationError(f"Unknown expression node type: {kind}")
|