vgi-python 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. vgi/__init__.py +152 -0
  2. vgi/_duckdb.py +62 -0
  3. vgi/_storage_profile.py +132 -0
  4. vgi/_test_fixtures/__init__.py +20 -0
  5. vgi/_test_fixtures/accumulate/__init__.py +19 -0
  6. vgi/_test_fixtures/accumulate/worker.py +762 -0
  7. vgi/_test_fixtures/aggregate/__init__.py +62 -0
  8. vgi/_test_fixtures/aggregate/_common.py +21 -0
  9. vgi/_test_fixtures/aggregate/basic.py +232 -0
  10. vgi/_test_fixtures/aggregate/dynamic.py +409 -0
  11. vgi/_test_fixtures/aggregate/generic.py +86 -0
  12. vgi/_test_fixtures/aggregate/listagg.py +71 -0
  13. vgi/_test_fixtures/aggregate/percentile.py +107 -0
  14. vgi/_test_fixtures/aggregate/streaming.py +192 -0
  15. vgi/_test_fixtures/aggregate/varargs.py +75 -0
  16. vgi/_test_fixtures/aggregate/window.py +380 -0
  17. vgi/_test_fixtures/attach_options.py +308 -0
  18. vgi/_test_fixtures/bad_protocol.py +62 -0
  19. vgi/_test_fixtures/cancellable.py +336 -0
  20. vgi/_test_fixtures/catalog.py +813 -0
  21. vgi/_test_fixtures/http_server.py +394 -0
  22. vgi/_test_fixtures/nest_tensor.py +614 -0
  23. vgi/_test_fixtures/orchard_catalog.py +47 -0
  24. vgi/_test_fixtures/projection_repro/__init__.py +6 -0
  25. vgi/_test_fixtures/projection_repro/worker.py +454 -0
  26. vgi/_test_fixtures/scalar/__init__.py +116 -0
  27. vgi/_test_fixtures/scalar/_common.py +69 -0
  28. vgi/_test_fixtures/scalar/arithmetic.py +321 -0
  29. vgi/_test_fixtures/scalar/binary.py +120 -0
  30. vgi/_test_fixtures/scalar/formatting.py +176 -0
  31. vgi/_test_fixtures/scalar/geo.py +300 -0
  32. vgi/_test_fixtures/scalar/null_handling.py +107 -0
  33. vgi/_test_fixtures/scalar/random_demo.py +171 -0
  34. vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
  35. vgi/_test_fixtures/scalar/type_info.py +219 -0
  36. vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
  37. vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
  38. vgi/_test_fixtures/simple_writable.py +793 -0
  39. vgi/_test_fixtures/table/__init__.py +221 -0
  40. vgi/_test_fixtures/table/_common.py +162 -0
  41. vgi/_test_fixtures/table/batch_index.py +283 -0
  42. vgi/_test_fixtures/table/batch_index_broken.py +200 -0
  43. vgi/_test_fixtures/table/catalog_scans.py +162 -0
  44. vgi/_test_fixtures/table/filters.py +1005 -0
  45. vgi/_test_fixtures/table/late_materialization.py +249 -0
  46. vgi/_test_fixtures/table/make_series.py +273 -0
  47. vgi/_test_fixtures/table/misc.py +499 -0
  48. vgi/_test_fixtures/table/order_modes.py +164 -0
  49. vgi/_test_fixtures/table/pairs.py +437 -0
  50. vgi/_test_fixtures/table/partition_columns.py +472 -0
  51. vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
  52. vgi/_test_fixtures/table/profiling_example.py +195 -0
  53. vgi/_test_fixtures/table/required_filters.py +234 -0
  54. vgi/_test_fixtures/table/sequence.py +710 -0
  55. vgi/_test_fixtures/table/settings.py +426 -0
  56. vgi/_test_fixtures/table/transaction_storage.py +162 -0
  57. vgi/_test_fixtures/table/tt_pushdown.py +191 -0
  58. vgi/_test_fixtures/table/versioned.py +230 -0
  59. vgi/_test_fixtures/table_in_out.py +1392 -0
  60. vgi/_test_fixtures/versioned.py +155 -0
  61. vgi/_test_fixtures/versioned_tables.py +595 -0
  62. vgi/_test_fixtures/worker.py +1631 -0
  63. vgi/_test_fixtures/writable/__init__.py +8 -0
  64. vgi/_test_fixtures/writable/generic.py +236 -0
  65. vgi/_test_fixtures/writable/table.py +149 -0
  66. vgi/_test_fixtures/writable/worker.py +1148 -0
  67. vgi/aggregate_function.py +607 -0
  68. vgi/argument_spec.py +472 -0
  69. vgi/arguments.py +1747 -0
  70. vgi/auth.py +55 -0
  71. vgi/catalog/__init__.py +88 -0
  72. vgi/catalog/attach_option.py +206 -0
  73. vgi/catalog/catalog_interface.py +2767 -0
  74. vgi/catalog/descriptors.py +870 -0
  75. vgi/catalog/duckdb_statistics.py +377 -0
  76. vgi/catalog/secret_type.py +96 -0
  77. vgi/catalog/setting.py +253 -0
  78. vgi/catalog/storage.py +372 -0
  79. vgi/client/__init__.py +67 -0
  80. vgi/client/catalog_mixin.py +1251 -0
  81. vgi/client/cli.py +582 -0
  82. vgi/client/cli_catalog.py +182 -0
  83. vgi/client/cli_schema.py +270 -0
  84. vgi/client/cli_table.py +907 -0
  85. vgi/client/cli_transaction.py +97 -0
  86. vgi/client/cli_utils.py +441 -0
  87. vgi/client/cli_view.py +303 -0
  88. vgi/client/client.py +2183 -0
  89. vgi/exceptions.py +205 -0
  90. vgi/function.py +245 -0
  91. vgi/function_storage.py +1636 -0
  92. vgi/function_storage_azure_sql.py +922 -0
  93. vgi/function_storage_cf_do.py +740 -0
  94. vgi/http/__init__.py +25 -0
  95. vgi/http/demo_storage.py +212 -0
  96. vgi/http/worker_page.py +1252 -0
  97. vgi/invocation.py +154 -0
  98. vgi/logging_config.py +93 -0
  99. vgi/meta_worker.py +661 -0
  100. vgi/metadata.py +1403 -0
  101. vgi/otel.py +406 -0
  102. vgi/protocol.py +2418 -0
  103. vgi/protocol_version.txt +1 -0
  104. vgi/py.typed +0 -0
  105. vgi/scalar_function.py +1211 -0
  106. vgi/schema_utils.py +234 -0
  107. vgi/secret_protocol.py +124 -0
  108. vgi/secret_service.py +238 -0
  109. vgi/serve.py +769 -0
  110. vgi/table_buffering_function.py +443 -0
  111. vgi/table_filter_pushdown.py +1528 -0
  112. vgi/table_function.py +1130 -0
  113. vgi/table_in_out_function.py +383 -0
  114. vgi/transactor/__init__.py +24 -0
  115. vgi/transactor/_duckdb_compat.py +27 -0
  116. vgi/transactor/client.py +137 -0
  117. vgi/transactor/protocol.py +149 -0
  118. vgi/transactor/server.py +740 -0
  119. vgi/worker.py +4761 -0
  120. vgi_python-0.8.0.dist-info/METADATA +735 -0
  121. vgi_python-0.8.0.dist-info/RECORD +124 -0
  122. vgi_python-0.8.0.dist-info/WHEEL +4 -0
  123. vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
  124. vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
@@ -0,0 +1,1528 @@
1
+ # Copyright 2025, 2026 Query Farm LLC - https://query.farm
2
+
3
+ """Filter pushdown AST classes for table functions.
4
+
5
+ This module provides:
6
+ - Filter AST classes for representing pushdown filter predicates
7
+ - ColumnBounds for extracting numeric bounds from filters
8
+ - PushdownFilters container with evaluation and helper methods
9
+ - Deserialization from Arrow IPC format
10
+
11
+ Filter Types:
12
+ ConstantFilter: Comparison with a constant value (=, !=, >, >=, <, <=)
13
+ IsNullFilter: IS NULL check
14
+ IsNotNullFilter: IS NOT NULL check
15
+ InFilter: Set membership (IN clause)
16
+ AndFilter: Conjunction of child filters
17
+ OrFilter: Disjunction of child filters
18
+ StructFilter: Nested struct field filter
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import dataclasses
24
+ import json
25
+ import logging
26
+ import os
27
+ import threading
28
+ from dataclasses import dataclass
29
+ from enum import Enum
30
+ from typing import TYPE_CHECKING, Any, cast
31
+
32
+ import pyarrow as pa
33
+ import pyarrow.compute as pc
34
+
35
+ if TYPE_CHECKING:
36
+ from collections.abc import Callable, Iterator
37
+
38
+ # =============================================================================
39
+ # Debug Logging
40
+ # =============================================================================
41
+
42
# Diagnostics switch: set VGI_FILTER_DEBUG to "1", "true" or "yes" (compared
# case-insensitively) for detailed filter pushdown diagnostics. The
# environment variable is read once, when this module is imported.
_FILTER_DEBUG = os.environ.get("VGI_FILTER_DEBUG", "").lower() in ("1", "true", "yes")
# Logger that receives the filter pushdown diagnostic messages.
_filter_logger = logging.getLogger("vgi.filter_pushdown")
45
+
46
+
47
def _log_debug(event: str, **kwargs: Any) -> None:
    """Emit a filter pushdown diagnostic when VGI_FILTER_DEBUG is enabled.

    Args:
        event: Short event label placed at the start of the log line.
        **kwargs: Extra context, rendered as space-separated ``key=value``
            pairs after the event label.

    """
    if not _FILTER_DEBUG:
        return
    details = " ".join(f"{key}={val}" for key, val in kwargs.items())
    if details:
        _filter_logger.debug("%s %s", event, details)
    else:
        _filter_logger.debug("%s", event)
52
+
53
+
54
# Supported filter protocol version. It is also the default ``version`` of a
# PushdownFilters container.
_SUPPORTED_VERSION = "1"
56
+
57
+
58
def _strip_extension(x: Any) -> Any:
    """Return the storage form of an extension-typed Arrow value.

    Filter literals serialised by the DuckDB VGI extension carry their
    canonical Arrow extension type (e.g. ``arrow.bool8`` for BOOLEAN, see
    ``vgi/duckdb/src/common/arrow/arrow_type_extension.cpp``). PyArrow's
    binary compute kernels are keyed on the type pair and have no entry
    for ``equal(bool, extension<arrow.bool8>)``, so the wrapper has to be
    removed before the value is handed to ``pc.equal`` and friends.

    The check is made against ``pa.BaseExtensionType`` rather than
    ``pa.ExtensionType`` because canonical extension types such as
    ``Bool8Type`` and ``UuidType`` derive from ``BaseExtensionType``
    directly, without going through the user-extension class.

    Args:
        x: An Arrow Array, ChunkedArray or Scalar (anything with ``.type``).

    Returns:
        ``x.storage`` for extension Arrays / ChunkedArrays, ``x.value`` for
        extension Scalars, and ``x`` itself for non-extension input.

    """
    if not isinstance(x.type, pa.BaseExtensionType):
        # Plain Arrow value: nothing to unwrap.
        return x
    if isinstance(x, (pa.Array, pa.ChunkedArray)):
        # isinstance() narrowed x to Array/ChunkedArray, whose static type
        # has no ``.storage``; the attribute exists at runtime for
        # extension arrays, so widen to Any to satisfy mypy.
        return cast(Any, x).storage
    # Remaining case: an extension scalar, whose storage scalar is ``.value``.
    return x.value
85
+
86
+
87
def _normalize_for_compare(col: Any, val: Any) -> tuple[Any, Any]:
    """Make a (column, literal) pair comparable by pyarrow.compute kernels.

    The pair goes through three steps, in this order:

    1. Canonical Arrow extension wrappers are removed from both sides.
    2. A dictionary-encoded column is decoded to its value type. The
       compute kernels (``is_in``, ``equal``, ...) accept a dictionary
       column with a plain literal but fail when the literal is *also*
       dictionary-encoded (``ArrowTypeError: Array type doesn't match
       type of values set``), so the column is decoded instead of the
       literal being cast up to a dictionary.
    3. If the types still differ (e.g. a plain ``bool_`` column and a
       literal that arrived as ``arrow.bool8`` and was stripped to
       ``int8``), the literal is cast to the column's type.

    Args:
        col: Column array taken from the batch being filtered.
        val: Literal to compare against (a scalar, or an array of values
            for IN filters).

    Returns:
        The normalized ``(column, literal)`` pair.

    """
    column = _strip_extension(col)
    literal = _strip_extension(val)

    if pa.types.is_dictionary(column.type):
        column = column.cast(column.type.value_type)

    target_type = column.type
    if literal.type != target_type:
        literal = literal.cast(target_type)

    return column, literal
111
+
112
+
113
def _make_bool_array(value: bool, length: int) -> pa.BooleanArray:
    """Build a boolean array holding ``value`` repeated ``length`` times.

    This is the result of an empty conjunction (every row True) and of an
    empty disjunction (every row False).

    Args:
        value: The constant to repeat.
        length: Number of elements in the returned array.

    """
    constant = pa.scalar(value)
    return pa.repeat(constant, length)
119
+
120
+
121
# Public API of this module; the names are grouped by kind.
__all__ = [
    # Exceptions
    "FilterError",
    "FilterDeserializationError",
    "FilterVersionError",
    # Enums
    "FilterType",
    "ComparisonOp",
    "ExpressionNodeType",
    # Filter classes
    "Filter",
    "ConstantFilter",
    "IsNullFilter",
    "IsNotNullFilter",
    "InFilter",
    "AndFilter",
    "OrFilter",
    "StructFilter",
    "ExpressionFilter",
    # Expression node classes
    "ExpressionNode",
    "ColumnRefNode",
    "ConstantNode",
    "FunctionNode",
    "ComparisonNode",
    "ConjunctionNode",
    # Helpers
    "ColumnBounds",
    "PushdownFilters",
    # Functions
    "deserialize_filters",
]
153
+
154
+
155
+ # =============================================================================
156
+ # Exceptions
157
+ # =============================================================================
158
+
159
+
160
class FilterError(Exception):
    """Root of the filter pushdown exception hierarchy.

    Catch this type to handle every error defined by this module.
    """


class FilterDeserializationError(FilterError):
    """Raised when filter IPC bytes cannot be parsed."""


class FilterVersionError(FilterError):
    """Raised when the filter protocol version is not supported."""
170
+
171
+
172
+ # =============================================================================
173
+ # Enums
174
+ # =============================================================================
175
+
176
+
177
class FilterType(Enum):
    """Kinds of pushdown filter; each value is the JSON protocol identifier."""

    # Single-column predicates
    CONSTANT = "constant"
    IS_NULL = "is_null"
    IS_NOT_NULL = "is_not_null"
    IN = "in"
    JOIN_KEYS = "join_keys"
    # Compound filters
    AND = "and"
    OR = "or"
    # Nested / expression-tree filters
    STRUCT = "struct"
    EXPRESSION = "expression"
189
+
190
+
191
class ExpressionNodeType(Enum):
    """Kinds of expression tree node; each value is the JSON protocol identifier."""

    # Leaf nodes
    COLUMN_REF = "column_ref"
    CONSTANT = "constant"
    # Nodes that carry child expressions
    FUNCTION = "function"
    COMPARISON = "comparison"
    CONJUNCTION = "conjunction"
199
+
200
+
201
+ class ComparisonOp(Enum):
202
+ """Comparison operators for constant filters."""
203
+
204
+ EQ = "eq" # =
205
+ NE = "ne" # !=
206
+ GT = "gt" # >
207
+ GE = "ge" # >=
208
+ LT = "lt" # <
209
+ LE = "le" # <=
210
+
211
+ @property
212
+ def symbol(self) -> str:
213
+ """Return the SQL symbol for this operator."""
214
+ symbols = {"eq": "=", "ne": "!=", "gt": ">", "ge": ">=", "lt": "<", "le": "<="}
215
+ return symbols[self.value]
216
+
217
+
218
+ # =============================================================================
219
+ # Filter Base Class
220
+ # =============================================================================
221
+
222
+
223
+ @dataclass(frozen=True, slots=True)
224
+ class Filter:
225
+ """Base class for all filter types.
226
+
227
+ Attributes:
228
+ column_name: Name of the column this filter applies to.
229
+ column_index: Index of the column in the output schema.
230
+
231
+ """
232
+
233
+ column_name: str
234
+ column_index: int
235
+
236
+ def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
237
+ """Evaluate filter against batch using PyArrow compute.
238
+
239
+ Args:
240
+ batch: RecordBatch to evaluate filter against.
241
+
242
+ Returns:
243
+ Boolean array with True for rows that pass the filter.
244
+
245
+ Raises:
246
+ NotImplementedError: Base class does not implement evaluation.
247
+
248
+ """
249
+ raise NotImplementedError
250
+
251
+
252
+ # =============================================================================
253
+ # Filter Type Classes
254
+ # =============================================================================
255
+
256
+
257
@dataclass(frozen=True, slots=True)
class ConstantFilter(Filter):
    """Filter of the form ``column <op> value``.

    Examples:
        age >= 18
        status = 'active'
        price < 100.0

    Attributes:
        op: Comparison operator to apply.
        value: Arrow scalar the column is compared against.

    """

    op: ComparisonOp
    value: pa.Scalar[Any]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Compare the filter column of ``batch`` with the constant.

        Raises:
            ValueError: If ``op`` is not one of the known operators.

        """
        column, literal = _normalize_for_compare(batch.column(self.column_index), self.value)
        # One compute kernel per operator.
        kernels: dict[ComparisonOp, Any] = {
            ComparisonOp.EQ: pc.equal,
            ComparisonOp.NE: pc.not_equal,
            ComparisonOp.GT: pc.greater,
            ComparisonOp.GE: pc.greater_equal,
            ComparisonOp.LT: pc.less,
            ComparisonOp.LE: pc.less_equal,
        }
        kernel = kernels.get(self.op)
        if kernel is None:
            raise ValueError(f"Unknown comparison operator: {self.op}")
        # The normalized inputs are typed Any; the kernels produce a
        # BooleanArray at runtime, so cast once on the way out.
        return cast("pa.BooleanArray", kernel(column, literal))

    def __repr__(self) -> str:
        """Return a compact ``column op value`` description for debugging."""
        return f"ConstantFilter({self.column_name} {self.op.symbol} {self.value})"
297
+
298
+
299
@dataclass(frozen=True, slots=True)
class IsNullFilter(Filter):
    """Filter that passes rows whose column value is NULL."""

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Return a mask that is True where the filter column is null."""
        column = batch.column(self.column_index)
        return pc.is_null(column)

    def __repr__(self) -> str:
        """Return a short description for debugging."""
        return f"IsNullFilter({self.column_name} IS NULL)"
310
+
311
+
312
@dataclass(frozen=True, slots=True)
class IsNotNullFilter(Filter):
    """Filter that passes rows whose column value is not NULL."""

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Return a mask that is True where the filter column is valid."""
        column = batch.column(self.column_index)
        return pc.is_valid(column)

    def __repr__(self) -> str:
        """Return a short description for debugging."""
        return f"IsNotNullFilter({self.column_name} IS NOT NULL)"
323
+
324
+
325
@dataclass(frozen=True, slots=True)
class InFilter(Filter):
    """IN (v1, v2, ...) set membership filter.

    Attributes:
        values: The candidate values, stored as an Arrow array (the
            contents of the list column).

    """

    values: pa.Array[Any]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Return a mask that is True where the filter column is in ``values``."""
        col, vals = _normalize_for_compare(batch.column(self.column_index), self.values)
        return cast("pa.BooleanArray", pc.is_in(col, vals))

    def __repr__(self) -> str:
        """Return a description for debugging, abbreviating long value lists.

        Lists with more than five values show only the first three plus the
        total count. Only that preview slice is converted to Python objects,
        so the repr stays cheap for large value sets.
        """
        total = len(self.values)
        if total > 5:
            head = self.values[:3].to_pylist()
            preview = f"{head!r}...({total} total)"
        else:
            preview = repr(self.values.to_pylist())
        return f"InFilter({self.column_name} IN {preview})"
344
+
345
+
346
@dataclass(frozen=True, slots=True)
class AndFilter(Filter):
    """Conjunction of child filters.

    All child filters must pass for a row to pass.

    Attributes:
        children: The filters to combine.

    """

    children: tuple[Filter, ...]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate AND of all child filters.

        Children are combined with Kleene (SQL three-valued) logic, so
        ``FALSE AND NULL`` is ``FALSE`` rather than null. An empty
        conjunction yields True for every row.
        """
        if not self.children:
            return _make_bool_array(True, batch.num_rows)
        result = self.children[0].evaluate(batch)
        for child in self.children[1:]:
            # and_kleene follows SQL semantics; plain and_ would return
            # null whenever either side is null.
            result = pc.and_kleene(result, child.evaluate(batch))
        return result

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        children_repr = " AND ".join(repr(c) for c in self.children)
        return f"AndFilter({children_repr})"
368
+
369
+
370
@dataclass(frozen=True, slots=True)
class OrFilter(Filter):
    """Disjunction of child filters.

    At least one child filter must pass for a row to pass.

    Attributes:
        children: The filters to combine.

    """

    children: tuple[Filter, ...]

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate OR of all child filters.

        Children are combined with Kleene (SQL three-valued) logic, so
        ``TRUE OR NULL`` is ``TRUE``. This matters for predicates such as
        ``x IS NULL OR x = 5``: on a null row the comparison yields null,
        and a non-Kleene OR would turn the whole row's result into null
        instead of True. An empty disjunction yields False for every row.
        """
        if not self.children:
            return _make_bool_array(False, batch.num_rows)
        result = self.children[0].evaluate(batch)
        for child in self.children[1:]:
            # or_kleene follows SQL semantics; plain or_ would return null
            # whenever either side is null, even if the other side is True.
            result = pc.or_kleene(result, child.evaluate(batch))
        return result

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        children_repr = " OR ".join(repr(c) for c in self.children)
        return f"OrFilter({children_repr})"
392
+
393
+
394
+ class _SingleColumnBatch:
395
+ """Lightweight wrapper providing batch-like interface for a single array.
396
+
397
+ Used by StructFilter to avoid creating a full RecordBatch when evaluating
398
+ child filters on nested struct fields.
399
+ """
400
+
401
+ __slots__ = ("_array",)
402
+
403
+ def __init__(self, array: pa.Array[Any]) -> None:
404
+ self._array = array
405
+
406
+ def column(self, _index: int) -> pa.Array[Any]:
407
+ """Return the wrapped array (index is ignored)."""
408
+ return self._array
409
+
410
+ @property
411
+ def num_rows(self) -> int:
412
+ """Return the number of rows in the array."""
413
+ return len(self._array)
414
+
415
+
416
@dataclass(frozen=True, slots=True)
class StructFilter(Filter):
    """Filter on a field nested inside a struct column.

    Example: address.city = 'Seattle'

    Attributes:
        child_index: Index of the nested field.
        child_name: Name of the nested field; used to extract it.
        child_filter: Filter applied to the extracted field.

    """

    child_index: int
    child_name: str
    child_filter: Filter

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Extract the nested field and evaluate ``child_filter`` on it."""
        field_values = pc.struct_field(batch.column(self.column_index), self.child_name)
        # Rebase the child onto column 0 of the single-column wrapper.
        rebased_child = dataclasses.replace(self.child_filter, column_index=0)
        # The wrapper stands in for a RecordBatch, avoiding a real one.
        single_column = _SingleColumnBatch(field_values)
        return rebased_child.evaluate(single_column)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        """Return a ``column.field: child`` description for debugging."""
        dotted_path = f"{self.column_name}.{self.child_name}"
        return f"StructFilter({dotted_path}: {self.child_filter!r})"
442
+
443
+
444
+ # =============================================================================
445
+ # Expression Filter (recursive expression tree from DuckDB)
446
+ # =============================================================================
447
+
448
+
449
+ @dataclass(frozen=True, slots=True)
450
+ class ExpressionNode:
451
+ """Base class for expression tree nodes.
452
+
453
+ Subclasses must set ``expr_type`` to match their class. This field
454
+ is used for serialization round-tripping (JSON ``expr_type`` key).
455
+ """
456
+
457
+ expr_type: ExpressionNodeType
458
+
459
+ def to_sql(self, column_name: str) -> str:
460
+ """Convert node to SQL string. Override in subclasses."""
461
+ raise NotImplementedError
462
+
463
+
464
@dataclass(frozen=True, slots=True)
class ColumnRefNode(ExpressionNode):
    """Reference to a column.

    Note: In v1, every column reference inside an expression filter points
    at the same column (the filter column). ``index`` is kept for future
    multi-column support, but ``to_sql()`` always renders the filter's
    column name.

    Attributes:
        index: Column index carried by the reference.

    """

    index: int

    def to_sql(self, column_name: str) -> str:
        """Return ``column_name`` as a double-quoted SQL identifier.

        Embedded double quotes are escaped by doubling them.
        """
        return '"' + column_name.replace('"', '""') + '"'
479
+
480
+
481
@dataclass(frozen=True, slots=True)
class ConstantNode(ExpressionNode):
    """Literal value inside an expression tree.

    Attributes:
        value: The literal, as an Arrow scalar.
        field: Arrow field carrying extension metadata for the literal,
            when available.

    """

    value: pa.Scalar[Any]
    field: pa.Field[Any] | None = None

    def to_sql(self, column_name: str) -> str:
        """Render the scalar as a SQL literal.

        The field metadata, when present, selects the rendering for
        extension types. ``column_name`` is not used.
        """
        literal_sql = _arrow_scalar_to_sql(self.value, self.field)
        return literal_sql
491
+
492
+
493
+ def _is_operator_name(name: str) -> bool:
494
+ """Check if a function name is an infix operator (all non-alphanumeric/underscore chars)."""
495
+ return len(name) > 0 and all(not (c.isalnum() or c == "_") for c in name)
496
+
497
+
498
@dataclass(frozen=True, slots=True)
class FunctionNode(ExpressionNode):
    """Function call inside an expression tree.

    Attributes:
        function_name: Name of the function, or an operator such as ``&&``.
        children: Argument expressions, in call order.

    """

    function_name: str
    children: tuple[ExpressionNode, ...]

    def to_sql(self, column_name: str) -> str:
        """Render as ``name(arg, ...)``, or infix for binary operators.

        A call whose name is an operator and that has exactly two
        arguments is rendered as ``(left op right)``.
        """
        rendered = [child.to_sql(column_name) for child in self.children]
        if len(rendered) == 2 and _is_operator_name(self.function_name):
            lhs, rhs = rendered
            return f"({lhs} {self.function_name} {rhs})"
        arg_list = ", ".join(rendered)
        return f"{self.function_name}({arg_list})"
513
+
514
+
515
@dataclass(frozen=True, slots=True)
class ComparisonNode(ExpressionNode):
    """Binary comparison inside an expression tree.

    Attributes:
        op: Comparison operator.
        left: Left operand.
        right: Right operand.

    """

    op: ComparisonOp
    left: ExpressionNode
    right: ExpressionNode

    def to_sql(self, column_name: str) -> str:
        """Render as a parenthesized ``(left op right)``."""
        lhs = self.left.to_sql(column_name)
        operator_sql = self.op.symbol
        rhs = self.right.to_sql(column_name)
        return f"({lhs} {operator_sql} {rhs})"
526
+
527
+
528
@dataclass(frozen=True, slots=True)
class ConjunctionNode(ExpressionNode):
    """AND/OR conjunction node.

    Attributes:
        conjunction_type: ``"and"`` or ``"or"``. Any value other than
            ``"and"`` is rendered as OR.
        children: Operand expressions.

    """

    conjunction_type: str  # "and" or "or"
    children: tuple[ExpressionNode, ...]

    def to_sql(self, column_name: str) -> str:
        """Format as (child1 AND/OR child2 AND/OR ...).

        With no children the identity of the connective is returned
        (``TRUE`` for AND, ``FALSE`` for OR) instead of ``()``, which is
        not valid SQL. This matches what AndFilter / OrFilter produce for
        an empty child list.
        """
        is_and = self.conjunction_type == "and"
        if not self.children:
            return "TRUE" if is_and else "FALSE"
        joiner = " AND " if is_and else " OR "
        parts = [c.to_sql(column_name) for c in self.children]
        return f"({joiner.join(parts)})"
540
+
541
+
542
# Arrow extension type name → SQL wrapper for binary values.
# Keys are the values of the ARROW:extension:name field metadata entry; each
# callable receives the value's hex string and returns the SQL expression that
# rebuilds it.
_ARROW_EXTENSION_SQL: dict[str, Callable[[str], str]] = {
    # WKB geometry: rebuilt from its hex encoding.
    "geoarrow.wkb": lambda hex_str: f"ST_GeomFromHEXWKB('{hex_str}')",
}
547
+
548
+
549
def _arrow_scalar_to_sql(scalar: pa.Scalar[Any], field: pa.Field[Any] | None = None) -> str:
    """Convert an Arrow scalar to a SQL literal string.

    Uses Arrow field metadata to handle extension types (e.g., geoarrow.wkb
    for geometry). Falls back to generic representations for standard types.

    Rendering rules:
        - null             -> ``NULL``
        - boolean          -> ``TRUE`` / ``FALSE``
        - integer          -> decimal digits
        - float            -> ``str(value)``; NaN and the infinities are
          rendered as quoted strings cast to DOUBLE, because bare
          ``nan`` / ``inf`` are not SQL literals
        - string           -> single-quoted, with ``'`` doubled
        - binary           -> extension-specific wrapper when the field
          metadata names a known extension, otherwise a BLOB literal with
          one ``\\xNN`` escape per byte
        - anything else    -> ``str(value)`` as a single-quoted string

    Args:
        scalar: Arrow scalar value.
        field: Arrow field with extension metadata (optional). Used to detect
            extension types like geoarrow.wkb for proper SQL rendering.

    Returns:
        SQL text for the literal.

    """
    import math  # local import: only the float branch needs it

    if not scalar.is_valid:
        return "NULL"

    val = scalar.as_py()
    typ = scalar.type

    if pa.types.is_boolean(typ):
        return "TRUE" if val else "FALSE"

    if pa.types.is_integer(typ):
        return str(val)

    if pa.types.is_floating(typ):
        # str(float) gives "nan" / "inf" / "-inf", which SQL would read as
        # identifiers; use the string-cast form for the special values.
        if math.isnan(val):
            return "'NaN'::DOUBLE"
        if math.isinf(val):
            return "'Infinity'::DOUBLE" if val > 0 else "'-Infinity'::DOUBLE"
        return str(val)

    if pa.types.is_string(typ) or pa.types.is_large_string(typ):
        escaped = str(val).replace("'", "''")
        return f"'{escaped}'"

    if pa.types.is_binary(typ) or pa.types.is_large_binary(typ) or pa.types.is_fixed_size_binary(typ):
        # Check Arrow extension metadata for type-specific rendering
        if field is not None and field.metadata is not None:
            ext_name = field.metadata.get(b"ARROW:extension:name", b"").decode()
            if ext_name in _ARROW_EXTENSION_SQL:
                return _ARROW_EXTENSION_SQL[ext_name](val.hex())
        # A BLOB literal needs a separate \xNN escape for every byte; one
        # leading "\x" followed by the whole hex string would decode only
        # the first byte and keep the remaining hex digits as plain text.
        blob_escapes = "".join(f"\\x{byte:02X}" for byte in val)
        return f"'{blob_escapes}'::BLOB"

    # Fallback: use Python str() of the value as a string literal
    escaped = str(val).replace("'", "''")
    return f"'{escaped}'"
586
+
587
+
588
# Thread-local storage holding the cached expression-evaluation connection
# (its ``conn`` attribute), so each thread gets its own connection.
_expression_eval_local = threading.local()
589
+
590
+
591
def _get_expression_eval_connection() -> Any:
    """Get or create a thread-local DuckDB connection for expression filter evaluation.

    Each thread gets its own connection (DuckDB connections are not thread-safe).
    Spatial extension is loaded if available (needed for geometry functions like &&).
    Loading it is best-effort: if it can be neither loaded nor installed, the
    failure is logged at debug level and the connection is still returned.

    Returns:
        The connection cached for the calling thread.

    """
    conn = getattr(_expression_eval_local, "conn", None)
    if conn is not None:
        return conn

    # Imported lazily so the engine is only needed when an expression
    # filter is actually evaluated.
    from vgi._duckdb import connect as engine_connect

    conn = engine_connect()
    try:
        conn.load_extension("spatial")
    except Exception:
        # Loading failed (possibly not installed yet): try installing first.
        try:
            conn.install_extension("spatial")
            conn.load_extension("spatial")
        except Exception:
            # spatial not available — non-spatial expressions still work.
            # Leave a trace for diagnostics instead of failing silently.
            _filter_logger.debug(
                "spatial extension unavailable; spatial expression filters will fail",
                exc_info=True,
            )
    _expression_eval_local.conn = conn
    return conn
614
+
615
+
616
@dataclass(frozen=True, slots=True)
class ExpressionFilter(Filter):
    """Filter carrying an expression tree pushed from DuckDB.

    The worker evaluates the recursive expression tree with DuckDB.
    Typical use: spatial predicates like ``geom && box``.

    Attributes:
        expr: Root node of the expression tree.

    """

    expr: ExpressionNode

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Run the expression over ``batch`` in DuckDB and return the mask.

        The connection is cached per thread and has the spatial extension
        loaded when available. The engine is imported lazily via
        :mod:`vgi._duckdb` (haybarn preferred, duckdb fallback), so workers
        that never evaluate expression filters need neither installed.
        """
        conn = _get_expression_eval_connection()
        # The SQL text below refers to this relation by its variable name,
        # so the local must stay named ``tbl``.
        tbl = conn.from_arrow(batch)  # noqa: F841 (used in SQL below)
        predicate_sql = self.expr.to_sql(self.column_name)
        relation = conn.sql(f"SELECT ({predicate_sql})::BOOLEAN AS _r FROM tbl")
        # Prefer to_arrow_table (duckdb >= 1.5.1); otherwise fetch_arrow_table.
        fetch = getattr(relation, "to_arrow_table", None)
        if not fetch:
            fetch = relation.fetch_arrow_table
        arrow_table = fetch()
        return arrow_table.column("_r").combine_chunks()  # type: ignore[no-any-return]

    def __repr__(self) -> str:
        """Return the column name and rendered SQL for debugging."""
        rendered = self.expr.to_sql(self.column_name)
        return f"ExpressionFilter({self.column_name}: {rendered})"
646
+
647
+
648
+ # =============================================================================
649
+ # Column Bounds Helper
650
+ # =============================================================================
651
+
652
+
653
+ @dataclass(frozen=True, slots=True)
654
+ class ColumnBounds:
655
+ """Numeric/comparable bounds for a column extracted from filters.
656
+
657
+ Use case: Partition pruning, index range scans, bounded data fetches.
658
+
659
+ Attributes:
660
+ min_value: Minimum bound value, or None if unbounded below.
661
+ min_inclusive: True if min_value is inclusive (>=), False if exclusive (>).
662
+ max_value: Maximum bound value, or None if unbounded above.
663
+ max_inclusive: True if max_value is inclusive (<=), False if exclusive (<).
664
+
665
+ """
666
+
667
+ min_value: pa.Scalar[Any] | None = None
668
+ min_inclusive: bool = True
669
+ max_value: pa.Scalar[Any] | None = None
670
+ max_inclusive: bool = True
671
+
672
+ def contains(self, value: Any) -> bool:
673
+ """Check if a value satisfies these bounds.
674
+
675
+ Args:
676
+ value: Value to check against bounds.
677
+
678
+ Returns:
679
+ True if value is within bounds, False otherwise.
680
+
681
+ """
682
+ if self.min_value is not None:
683
+ min_val = self.min_value.as_py()
684
+ below_min = value < min_val if self.min_inclusive else value <= min_val
685
+ if below_min:
686
+ return False
687
+
688
+ if self.max_value is not None:
689
+ max_val = self.max_value.as_py()
690
+ above_max = value > max_val if self.max_inclusive else value >= max_val
691
+ if above_max:
692
+ return False
693
+
694
+ return True
695
+
696
+
697
+ # =============================================================================
698
+ # PushdownFilters Container
699
+ # =============================================================================
700
+
701
+
702
@dataclass(frozen=True, slots=True)
class PushdownFilters:
    """Container for pushdown filters with evaluation and query helpers.

    The top-level filters tuple represents a conjunction (AND). Each filter in
    the tuple must be satisfied for a row to pass. Individual filters may
    themselves be AND/OR compound filters for more complex expressions.

    Provides:
    - evaluate(batch) / apply(batch) - Apply filters using PyArrow compute
    - get_column_bounds(name) - Extract range bounds for partition pruning
    - get_column_constant(name) - Get equality constant for a column
    - get_column_in_values(name) - Get IN list values
    - get_column_values(name) - Get discrete values from =, IN, or an OR of those
    - get_column_filters(name) - Get all top-level filters for a column
    - to_sql() - Generate SQL WHERE clause

    """

    # Top-level filters, implicitly AND-ed together.
    filters: tuple[Filter, ...]
    # Filter wire-format version these filters were produced with.
    version: str = _SUPPORTED_VERSION
    # Join key batches pushed alongside the filters, or None when absent.
    join_keys_batches: list[pa.RecordBatch] | None = None

    def get_join_keys_batch(self) -> pa.RecordBatch | None:
        """Return a merged join keys batch for temp table registration.

        When all join key batches have the same row count (the semi-join
        case), returns a single RecordBatch with all columns merged.
        When batches have different row counts (independent IN filters),
        they cannot be merged, so this returns ``None``.

        For individual column access, use :meth:`get_join_keys_batches`
        or :meth:`get_column_in_values`.

        Example::

            keys = params.current_pushdown_filters.get_join_keys_batch()
            if keys is not None:
                conn.register("join_keys", keys)
                result = conn.sql(
                    "SELECT d.* FROM my_data d JOIN join_keys USING (id)"
                )

        Returns:
            Merged RecordBatch when all batches have equal row counts,
            or None.

        """
        batches = self.join_keys_batches
        if not batches:
            return None
        # Batches of differing cardinality cannot share one RecordBatch.
        if len({b.num_rows for b in batches}) != 1:
            return None
        # All same cardinality — merge into one multi-column batch.
        columns: list[pa.Array[Any]] = []
        fields: list[pa.Field[Any]] = []
        for b in batches:
            for i in range(b.num_columns):
                columns.append(b.column(i))
                fields.append(b.schema.field(i))
        return pa.RecordBatch.from_arrays(columns, schema=pa.schema(fields))

    def get_join_keys_batches(self) -> list[pa.RecordBatch] | None:
        """Return all join key batches (one per IN filter column).

        Each batch is a single-column RecordBatch. Different batches may have
        different row counts. Returns ``None`` if no join keys were pushed.

        """
        return self.join_keys_batches if self.join_keys_batches else None

    def evaluate(self, batch: pa.RecordBatch) -> pa.BooleanArray:
        """Evaluate all filters, returning boolean mask.

        Filters are combined with AND at the top level - a row passes only
        if ALL filters evaluate to true for that row.

        Args:
            batch: RecordBatch to evaluate filters against.

        Returns:
            Boolean array with True for rows that pass all filters.

        """
        _log_debug(
            "evaluate_start",
            num_filters=len(self.filters),
            input_rows=batch.num_rows,
            columns=[f.column_name for f in self.filters],
        )

        if not self.filters:
            # No constraints: every row passes.
            _log_debug("evaluate_no_filters", input_rows=batch.num_rows)
            return _make_bool_array(True, batch.num_rows)

        result = self.filters[0].evaluate(batch)
        # pc.sum works on BooleanArray (counts True values) but stubs don't reflect this
        true_count: int | None = pc.sum(result).as_py()  # type: ignore[type-var]
        _log_debug(
            "evaluate_filter",
            filter_index=0,
            filter_type=type(self.filters[0]).__name__,
            filter_repr=repr(self.filters[0]),
            rows_passing=true_count,
        )

        # Fold the remaining filters into the running mask.
        for i, f in enumerate(self.filters[1:], start=1):
            result = pc.and_(result, f.evaluate(batch))
            true_count = pc.sum(result).as_py()  # type: ignore[type-var]
            _log_debug(
                "evaluate_filter",
                filter_index=i,
                filter_type=type(f).__name__,
                filter_repr=repr(f),
                rows_passing=true_count,
            )

        final_count: int | None = pc.sum(result).as_py()  # type: ignore[type-var]
        _log_debug(
            "evaluate_complete",
            input_rows=batch.num_rows,
            rows_passing=final_count,
            rows_filtered=batch.num_rows - (final_count or 0),
        )
        return result

    def apply(self, batch: pa.RecordBatch) -> pa.RecordBatch:
        """Apply all filters to batch, returning filtered batch.

        Args:
            batch: RecordBatch to filter.

        Returns:
            Filtered RecordBatch containing only rows that pass all filters.

        """
        _log_debug("apply_start", input_rows=batch.num_rows)
        mask = self.evaluate(batch)
        # pc.filter supports RecordBatch but pyarrow-stubs don't have the overload
        filtered: pa.RecordBatch = pc.filter(batch, mask)  # type: ignore[call-overload]
        _log_debug(
            "apply_complete",
            input_rows=batch.num_rows,
            output_rows=filtered.num_rows,
            rows_removed=batch.num_rows - filtered.num_rows,
        )
        return filtered

    # =========================================================================
    # Column Query Helpers
    # =========================================================================

    @property
    def filtered_columns(self) -> frozenset[str]:
        """Set of column names that have top-level filters applied.

        Use case: Quick check of which columns are constrained.
        """
        return frozenset(f.column_name for f in self.filters)

    def get_column_filters(self, column_name: str) -> list[Filter]:
        """Get all top-level filters for a specific column.

        Use case: Inspect what constraints apply to a column.

        Args:
            column_name: Name of the column to get filters for.

        Returns:
            List of filters that apply to the column.

        """
        return [f for f in self.filters if f.column_name == column_name]

    def has_filter_for_column(self, column_name: str) -> bool:
        """Check if any top-level filter constrains the given column.

        Args:
            column_name: Name of the column to check.

        Returns:
            True if at least one filter applies to the column.

        """
        return any(f.column_name == column_name for f in self.filters)

    def get_column_constant(self, column_name: str) -> pa.Scalar[Any] | None:
        """Get constant value if column has an equality filter.

        Use case: Partition key lookup, exact match optimization.

        Args:
            column_name: Name of the column to check.

        Returns:
            The constant value if an equality filter exists, None otherwise.

        """
        # Descend one level into AndFilter children (consistent with
        # get_column_bounds): DuckDB commonly pushes `col = v` / `col IN (...)`
        # conjoined with derived range bounds as a single AndFilter.
        for f in self._collect_column_filters(column_name):
            if isinstance(f, ConstantFilter) and f.op == ComparisonOp.EQ:
                return f.value
        return None

    def get_column_in_values(self, column_name: str) -> pa.Array[Any] | None:
        """Get IN list values if column has an IN filter.

        Use case: Multi-key lookup, batch fetching.

        Args:
            column_name: Name of the column to check.

        Returns:
            Arrow array of IN values if an IN filter exists, None otherwise.

        """
        # Descend one level into AndFilter children (consistent with
        # get_column_bounds): an `IN (...)` predicate is frequently pushed as
        # AndFilter(InFilter, >=min, <=max).
        for f in self._collect_column_filters(column_name):
            if isinstance(f, InFilter):
                return f.values
        return None

    def get_column_values(self, column_name: str) -> pa.Array[Any] | None:
        """Get all distinct values a column could have based on filters.

        Returns values from equality (=) or IN filters as an Arrow array.
        Useful for partition pruning when partitions are keyed by specific values.

        Use case: Partition key lookup, directory-based partitioning.

        Args:
            column_name: Name of the column to check.

        Returns:
            Arrow array of discrete values if available, None otherwise.

        """
        # Descend one level into AndFilter children (consistent with
        # get_column_bounds). DuckDB pushes `col = v` / `col IN (...)` conjoined
        # with derived range bounds as a single AndFilter; without this descent
        # the discrete-value fast path silently misses those and callers fall
        # back to scanning every partition.
        for f in self._collect_column_filters(column_name):
            if isinstance(f, ConstantFilter) and f.op == ComparisonOp.EQ:
                # Wrap single value in array for consistent return type
                arr: pa.Array[Any] = pa.array([f.value.as_py()], type=f.value.type)
                return arr
            elif isinstance(f, InFilter):
                return f.values
            elif isinstance(f, OrFilter):
                # `col = a OR col = b`, `col IN (...) OR col = c`, etc. The
                # column's possible values are the UNION of the branches — but
                # only when *every* branch pins this column to discrete values.
                # If any branch is a range/IS NULL, or constrains a different
                # column (so this column is unbounded in that branch), the set
                # is not enumerable and we must fall through to None. Unlike the
                # AND case, returning one branch's values would be an unsafe
                # subset (a pruning caller would skip the other branches' rows).
                union = self._or_discrete_values(f, column_name)
                if union is not None:
                    return union
        return None

    def _or_discrete_values(self, or_filter: OrFilter, column_name: str) -> pa.Array[Any] | None:
        """Union of discrete values for ``column_name`` across all OR branches.

        Returns the deduplicated union iff every child constrains ``column_name``
        to discrete values (``=`` or ``IN``); otherwise None (the column could
        take any value through a non-discrete branch, so it cannot be enumerated).
        Only direct children of the ``OrFilter`` are inspected; deeper nesting
        yields None.

        When every branch carries the same Arrow type, the result keeps that
        type (matching what the ``=`` path of :meth:`get_column_values`
        returns). Otherwise the type is inferred from the Python values.
        """
        values: list[Any] = []
        arrow_types: list[pa.DataType] = []
        for child in or_filter.children:
            if child.column_name != column_name:
                # A branch that constrains some other column (or none) leaves
                # this column unbounded within that branch -> not enumerable.
                return None
            if isinstance(child, ConstantFilter) and child.op == ComparisonOp.EQ:
                values.append(child.value.as_py())
                arrow_types.append(child.value.type)
            elif isinstance(child, InFilter):
                values.extend(child.values.to_pylist())
                arrow_types.append(child.values.type)
            else:
                return None
        if not values:
            return None
        # Deduplicate, preserving first-seen order for stable output.
        deduped: list[Any] = list(dict.fromkeys(values))

        # Preserve the declared Arrow type when the branches agree on it, so an
        # int32/decimal/timestamp column is not silently widened or re-inferred.
        declared = arrow_types[0]
        if all(t == declared for t in arrow_types):
            try:
                typed: pa.Array[Any] = pa.array(deduped, type=declared)
                return typed
            except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
                # Type cannot be rebuilt from Python values; fall back to inference.
                pass
        result: pa.Array[Any] = pa.array(deduped)
        return result

    def get_column_bounds(self, column_name: str) -> ColumnBounds | None:
        """Extract range bounds from comparison filters.

        Analyzes gt/ge/lt/le filters to determine the tightest value range.
        When a strict and an inclusive bound share the same value, the strict
        one wins (``x > 5 AND x >= 5`` is ``x > 5``), regardless of the order
        in which the filters appear. An equality filter short-circuits to an
        exact ``[v, v]`` range.

        Use case: Range scans, partition pruning, bounded iteration.

        Args:
            column_name: Name of the column to extract bounds for.

        Returns:
            ColumnBounds with min/max values if bounds exist, None otherwise.

        """
        min_val: pa.Scalar[Any] | None = None
        min_inc = True
        max_val: pa.Scalar[Any] | None = None
        max_inc = True

        for f in self._collect_column_filters(column_name):
            if not isinstance(f, ConstantFilter):
                continue
            if f.op == ComparisonOp.GT:
                # `>=` so that a strict bound replaces an equal inclusive one.
                if min_val is None or f.value.as_py() >= min_val.as_py():
                    min_val, min_inc = f.value, False
            elif f.op == ComparisonOp.GE:
                # `>` so that an equal strict bound already recorded is kept.
                if min_val is None or f.value.as_py() > min_val.as_py():
                    min_val, min_inc = f.value, True
            elif f.op == ComparisonOp.LT:
                # `<=` so that a strict bound replaces an equal inclusive one.
                if max_val is None or f.value.as_py() <= max_val.as_py():
                    max_val, max_inc = f.value, False
            elif f.op == ComparisonOp.LE:
                # `<` so that an equal strict bound already recorded is kept.
                if max_val is None or f.value.as_py() < max_val.as_py():
                    max_val, max_inc = f.value, True
            elif f.op == ComparisonOp.EQ:
                # Equality implies exact bounds
                return ColumnBounds(f.value, True, f.value, True)

        if min_val is None and max_val is None:
            return None
        return ColumnBounds(min_val, min_inc, max_val, max_inc)

    def _collect_column_filters(self, column_name: str) -> list[Filter]:
        """Collect filters for a column from top-level and direct AND children.

        Note: Only descends one level into AndFilter children. Deeply nested
        AND filters (AND within AND) are not traversed. This is sufficient
        for most query patterns where bounds filters are either at top level
        or grouped in a single AND.
        """
        result: list[Filter] = []
        for f in self.filters:
            if f.column_name != column_name:
                continue
            if isinstance(f, AndFilter):
                # Flatten one level: keep only children on the same column.
                result.extend(c for c in f.children if c.column_name == column_name)
            else:
                result.append(f)
        return result

    # =========================================================================
    # SQL Generation
    # =========================================================================

    def to_sql(
        self,
        quote_identifier: Callable[[str], str] | None = None,
        placeholder: str = "?",
    ) -> tuple[str, list[Any]]:
        """Convert filters to SQL WHERE clause with parameters.

        Args:
            quote_identifier: Function to quote column names. Defaults to
                wrapping in double quotes, with embedded double quotes doubled.
            placeholder: Token emitted verbatim for every bound parameter
                (e.g. "?" or "%s"). The same token is used for all
                parameters, so only positional styles are meaningful.

        Returns:
            Tuple of (where_clause, params) - clause excludes "WHERE" keyword.
            Returns ``("", [])`` when there are no filters.

        """
        if not self.filters:
            return "", []

        def default_quote(name: str) -> str:
            # Double embedded quotes so the identifier cannot terminate early.
            return '"' + name.replace('"', '""') + '"'

        quote = quote_identifier or default_quote
        conditions: list[str] = []
        params: list[Any] = []

        for f in self.filters:
            sql, ps = _filter_to_sql(f, quote, placeholder, len(params))
            conditions.append(sql)
            params.extend(ps)

        return " AND ".join(conditions), params

    # =========================================================================
    # Dunder Methods
    # =========================================================================

    @classmethod
    def empty(cls) -> PushdownFilters:
        """Create an empty PushdownFilters instance (no filters)."""
        return cls(filters=())

    def __bool__(self) -> bool:
        """Return True if there are any filters."""
        return len(self.filters) > 0

    def __len__(self) -> int:
        """Return the number of top-level filters."""
        return len(self.filters)

    def __iter__(self) -> Iterator[Filter]:
        """Iterate over top-level filters."""
        return iter(self.filters)

    def __contains__(self, column_name: str) -> bool:
        """Check if any filter constrains the given column.

        Allows 'column_name in filters' syntax.
        """
        return self.has_filter_for_column(column_name)

    def __repr__(self) -> str:
        """Return string representation for debugging."""
        if not self.filters:
            return "PushdownFilters([])"
        filters_repr = ", ".join(repr(f) for f in self.filters)
        return f"PushdownFilters([{filters_repr}])"
1126
+
1127
+
1128
+ # =============================================================================
1129
+ # SQL Generation Helper
1130
+ # =============================================================================
1131
+
1132
+
1133
+ def _filter_to_sql(
1134
+ f: Filter,
1135
+ quote: Callable[[str], str],
1136
+ placeholder: str,
1137
+ param_offset: int,
1138
+ ) -> tuple[str, list[Any]]:
1139
+ """Convert a single filter to SQL fragment.
1140
+
1141
+ Args:
1142
+ f: Filter to convert.
1143
+ quote: Function to quote identifiers.
1144
+ placeholder: Parameter placeholder style.
1145
+ param_offset: Current parameter offset (for recursive calls).
1146
+
1147
+ Returns:
1148
+ Tuple of (sql_fragment, params).
1149
+
1150
+ """
1151
+ col = quote(f.column_name)
1152
+
1153
+ match f:
1154
+ case ConstantFilter(op=op, value=value):
1155
+ return f"{col} {op.symbol} {placeholder}", [value.as_py()]
1156
+
1157
+ case IsNullFilter():
1158
+ return f"{col} IS NULL", []
1159
+
1160
+ case IsNotNullFilter():
1161
+ return f"{col} IS NOT NULL", []
1162
+
1163
+ case InFilter(values=values):
1164
+ placeholders = ", ".join([placeholder] * len(values))
1165
+ return f"{col} IN ({placeholders})", values.to_pylist()
1166
+
1167
+ case AndFilter(children=children):
1168
+ parts: list[str] = []
1169
+ params: list[Any] = []
1170
+ for child in children:
1171
+ offset = param_offset + len(params)
1172
+ sql, ps = _filter_to_sql(child, quote, placeholder, offset)
1173
+ parts.append(sql)
1174
+ params.extend(ps)
1175
+ return f"({' AND '.join(parts)})", params
1176
+
1177
+ case OrFilter(children=children):
1178
+ parts = []
1179
+ params = []
1180
+ for child in children:
1181
+ offset = param_offset + len(params)
1182
+ sql, ps = _filter_to_sql(child, quote, placeholder, offset)
1183
+ parts.append(sql)
1184
+ params.extend(ps)
1185
+ return f"({' OR '.join(parts)})", params
1186
+
1187
+ case StructFilter(child_name=child_name, child_filter=child_filter):
1188
+ # Struct access varies by database - use dot notation as default
1189
+ nested_col = f"{f.column_name}.{child_name}"
1190
+ return _filter_to_sql(child_filter, lambda _: quote(nested_col), placeholder, param_offset)
1191
+
1192
+ case ExpressionFilter(expr=expr):
1193
+ # Expression filters use inline constants in to_sql() because
1194
+ # they may contain geometry literals that can't be parameterized.
1195
+ # Constants are embedded directly in the SQL string.
1196
+ return expr.to_sql(f.column_name), []
1197
+
1198
+ case _:
1199
+ raise ValueError(f"Unknown filter type: {type(f)}")
1200
+
1201
+
1202
+ # =============================================================================
1203
+ # Deserialization
1204
+ # =============================================================================
1205
+
1206
+
1207
def deserialize_filters(
    batch: pa.RecordBatch,
    join_keys: list[pa.RecordBatch] | None = None,
) -> PushdownFilters:
    """Deserialize a filter RecordBatch into a typed filter AST.

    Layout read from ``batch``: column 0, row 0 holds the JSON filter spec and
    its field metadata carries ``vgi_filter_version``; a ``value_ref`` of N in
    the spec refers to row 0 of column N+1.

    Args:
        batch: Arrow RecordBatch containing the serialized filters.
        join_keys: Optional list of single-column Arrow RecordBatches, one per
            IN filter column. Each batch may have a different row count.
            Referenced by ``join_keys`` filter type entries in the filter spec.

    Returns:
        PushdownFilters container with parsed filter AST.

    Raises:
        FilterDeserializationError: If the batch has no columns, or the spec
            is not valid JSON, is not a JSON array, or cannot be parsed.
        FilterVersionError: If the version is missing or unsupported.

    """
    if batch.num_columns == 0:
        # Without column 0 there is neither version metadata nor a spec.
        raise FilterDeserializationError("Filter batch has no columns")

    # Validate version
    metadata = batch.schema.field(0).metadata
    if metadata is None:
        raise FilterVersionError("Missing vgi_filter_version metadata")
    version = metadata.get(b"vgi_filter_version", b"").decode()
    if version != _SUPPORTED_VERSION:
        raise FilterVersionError(f"Unsupported filter version: {version!r}")

    _log_debug("deserialize_version", version=version)

    # Parse JSON spec
    try:
        filter_specs = json.loads(batch.column(0)[0].as_py())
    except Exception as e:
        _log_debug("deserialize_json_error", error=str(e))
        raise FilterDeserializationError(f"Failed to parse filter JSON: {e}") from e

    if not isinstance(filter_specs, list):
        # e.g. JSON `null` or an object: report it as a deserialization error
        # instead of failing later with an unrelated TypeError.
        raise FilterDeserializationError(
            f"Filter spec must be a JSON array, got {type(filter_specs).__name__}"
        )

    _log_debug("deserialize_specs", num_filters=len(filter_specs), specs=filter_specs)

    # Value resolver - returns scalar for value_ref N from column N+1
    def get_value(ref: int) -> pa.Scalar[Any]:
        value = batch.column(ref + 1)[0]
        _log_debug(
            "deserialize_value_ref",
            ref=ref,
            column_index=ref + 1,
            value_type=str(value.type),
            value=str(value),
        )
        return value  # type: ignore[no-any-return]

    def get_field(ref: int) -> pa.Field[Any]:
        """Get the Arrow field for a value_ref (column ref+1 in the batch)."""
        return batch.schema.field(ref + 1)

    def get_join_keys_column(column_name: str) -> pa.Array[Any] | None:
        """Resolve a column from the join keys batches by name."""
        if not join_keys:
            return None
        # First batch that has a column with this name wins.
        for keys_batch in join_keys:
            try:
                return keys_batch.column(column_name)
            except KeyError:
                continue
        return None

    # Parse filters
    try:
        parsed: list[Filter] = []
        for spec in filter_specs:
            f = _parse_filter(spec, get_value, get_field, get_join_keys_column)
            # None means the filter was dropped (e.g. missing join keys).
            if f is not None:
                parsed.append(f)
        filters = tuple(parsed)
    except Exception as e:
        _log_debug("deserialize_parse_error", error=str(e))
        raise FilterDeserializationError(f"Failed to parse filters: {e}") from e

    _log_debug(
        "deserialize_complete",
        num_filters=len(filters),
        filter_types=[type(f).__name__ for f in filters],
        columns=[f.column_name for f in filters],
    )

    return PushdownFilters(filters=filters, version=version, join_keys_batches=join_keys)
1293
+
1294
+
1295
def _parse_filter(
    spec: dict[str, Any],
    get_value: Callable[[int], pa.Scalar[Any]],
    get_field: Callable[[int], pa.Field[Any]],
    get_join_keys_column: Callable[[str], pa.Array[Any] | None] | None = None,
) -> Filter | None:
    """Parse a single filter spec into a typed Filter object.

    A filter that references join keys which were not supplied is dropped
    (returns None). Dropping propagates so the result never becomes stricter
    than the spec: an AND loses only the affected children (and is dropped
    itself if none remain), while an OR or STRUCT containing a dropped child
    is dropped as a whole.

    Args:
        spec: Filter specification dict from JSON.
        get_value: Function to get Arrow scalar by value_ref index.
        get_field: Function to get Arrow field by value_ref index (for extension metadata).
        get_join_keys_column: Function to resolve a column from the join keys batch by name.
            Returns None if no join keys batch or column not found.

    Returns:
        Typed Filter object, or None if the filter was dropped.

    Raises:
        FilterDeserializationError: If filter type is unknown.

    """
    column_name = spec["column_name"]
    column_index = spec["column_index"]
    filter_type = spec["type"]

    _log_debug(
        "parse_filter_start",
        filter_type=filter_type,
        column_name=column_name,
        column_index=column_index,
    )

    if filter_type == FilterType.CONSTANT.value:
        op = ComparisonOp(spec["op"])
        value = get_value(spec["value_ref"])
        result = ConstantFilter(
            column_name=column_name,
            column_index=column_index,
            op=op,
            value=value,
        )
        _log_debug(
            "parse_filter_constant",
            column=column_name,
            op=op.value,
            value=str(value),
            value_type=str(value.type),
        )
        return result

    if filter_type == FilterType.IS_NULL.value:
        _log_debug("parse_filter_is_null", column=column_name)
        return IsNullFilter(column_name=column_name, column_index=column_index)

    if filter_type == FilterType.IS_NOT_NULL.value:
        _log_debug("parse_filter_is_not_null", column=column_name)
        return IsNotNullFilter(column_name=column_name, column_index=column_index)

    if filter_type == FilterType.IN.value:
        # value_ref points to a list column; extract the list's values as an array
        list_scalar = get_value(spec["value_ref"])
        # ListScalar.values gives us the underlying array
        # pyarrow-stubs doesn't type ListScalar.values correctly
        values_array: pa.Array[Any] = list_scalar.values  # type: ignore[attr-defined]
        _log_debug(
            "parse_filter_in",
            column=column_name,
            num_values=len(values_array),
            values=values_array.to_pylist(),
            value_type=str(values_array.type),
        )
        return InFilter(
            column_name=column_name,
            column_index=column_index,
            values=values_array,
        )

    if filter_type == FilterType.JOIN_KEYS.value:
        keys_column_name = spec["keys_column"]
        join_values: pa.Array[Any] | None = get_join_keys_column(keys_column_name) if get_join_keys_column else None
        if join_values is None:
            _log_debug(
                "parse_filter_join_keys_missing",
                column=column_name,
                keys_column=keys_column_name,
            )
            return None  # graceful degradation — DuckDB filters client-side
        _log_debug(
            "parse_filter_join_keys",
            column=column_name,
            keys_column=keys_column_name,
            num_values=len(join_values),
            value_type=str(join_values.type),
        )
        # Join keys behave like an IN list over the referenced key column.
        return InFilter(
            column_name=column_name,
            column_index=column_index,
            values=join_values,
        )

    if filter_type == FilterType.AND.value:
        _log_debug(
            "parse_filter_and_start",
            column=column_name,
            num_children=len(spec["children"]),
        )
        # Dropping a child from AND only weakens the filter, which is safe.
        children = tuple(
            f
            for c in spec["children"]
            if (f := _parse_filter(c, get_value, get_field, get_join_keys_column)) is not None
        )
        if not children:
            # Nothing left to enforce. An AndFilter with no children carries no
            # constraint and would render as "()" in SQL, so drop it entirely.
            _log_debug("parse_filter_and_empty", column=column_name)
            return None
        _log_debug("parse_filter_and_complete", column=column_name)
        return AndFilter(
            column_name=column_name,
            column_index=column_index,
            children=children,
        )

    if filter_type == FilterType.OR.value:
        _log_debug(
            "parse_filter_or_start",
            column=column_name,
            num_children=len(spec["children"]),
        )
        parsed_children: list[Filter] = []
        for c in spec["children"]:
            child = _parse_filter(c, get_value, get_field, get_join_keys_column)
            if child is None:
                # Dropping a child from OR would strengthen the filter (fewer rows pass),
                # which is wrong for graceful degradation. Drop the entire OR instead.
                _log_debug("parse_filter_or_child_missing", column=column_name)
                return None
            parsed_children.append(child)
        children = tuple(parsed_children)
        _log_debug("parse_filter_or_complete", column=column_name)
        return OrFilter(
            column_name=column_name,
            column_index=column_index,
            children=children,
        )

    if filter_type == FilterType.STRUCT.value:
        child_name = spec["child_name"]
        _log_debug(
            "parse_filter_struct_start",
            column=column_name,
            child_name=child_name,
        )
        child_filter = _parse_filter(spec["child_filter"], get_value, get_field, get_join_keys_column)
        if child_filter is None:
            # A struct filter is meaningless without its child constraint.
            return None
        _log_debug("parse_filter_struct_complete", column=column_name)
        return StructFilter(
            column_name=column_name,
            column_index=column_index,
            child_index=spec["child_index"],
            child_name=child_name,
            child_filter=child_filter,
        )

    if filter_type == FilterType.EXPRESSION.value:
        _log_debug("parse_filter_expression_start", column=column_name)
        expr = _parse_expression_node(spec["expr"], get_value, get_field)
        _log_debug("parse_filter_expression_complete", column=column_name)
        return ExpressionFilter(
            column_name=column_name,
            column_index=column_index,
            expr=expr,
        )

    _log_debug("parse_filter_unknown", filter_type=filter_type)
    raise FilterDeserializationError(f"Unknown filter type: {filter_type}")
1469
+
1470
+
1471
def _parse_expression_node(
    spec: dict[str, Any],
    get_value: Callable[[int], pa.Scalar[Any]],
    get_field: Callable[[int], pa.Field[Any]],
) -> ExpressionNode:
    """Build a typed expression node (and its subtree) from a JSON spec.

    Dispatches on ``spec["expr_type"]`` and recurses into child specs for
    function, comparison and conjunction nodes.

    Args:
        spec: Expression node specification dict from JSON.
        get_value: Function to get Arrow scalar by value_ref index.
        get_field: Function to get Arrow field by value_ref index (for extension metadata).

    Returns:
        Typed ExpressionNode.

    Raises:
        FilterDeserializationError: If expression node type is unknown.

    """
    node_kind = spec["expr_type"]

    def parse_many(child_specs: Any) -> tuple[ExpressionNode, ...]:
        # Recursively parse an ordered sequence of child node specs.
        return tuple(_parse_expression_node(item, get_value, get_field) for item in child_specs)

    if node_kind == ExpressionNodeType.COLUMN_REF.value:
        return ColumnRefNode(expr_type=ExpressionNodeType.COLUMN_REF, index=spec["index"])

    if node_kind == ExpressionNodeType.CONSTANT.value:
        value_ref = spec["value_ref"]
        # Resolve the scalar first, then the field describing it.
        scalar = get_value(value_ref)
        arrow_field = get_field(value_ref)
        return ConstantNode(
            expr_type=ExpressionNodeType.CONSTANT,
            value=scalar,
            field=arrow_field,
        )

    if node_kind == ExpressionNodeType.FUNCTION.value:
        arguments = parse_many(spec["children"])
        return FunctionNode(
            expr_type=ExpressionNodeType.FUNCTION,
            function_name=spec["function_name"],
            children=arguments,
        )

    if node_kind == ExpressionNodeType.COMPARISON.value:
        # Operator is validated before either operand is parsed.
        comparison_op = ComparisonOp(spec["op"])
        lhs = _parse_expression_node(spec["left"], get_value, get_field)
        rhs = _parse_expression_node(spec["right"], get_value, get_field)
        return ComparisonNode(
            expr_type=ExpressionNodeType.COMPARISON,
            op=comparison_op,
            left=lhs,
            right=rhs,
        )

    if node_kind == ExpressionNodeType.CONJUNCTION.value:
        operands = parse_many(spec["children"])
        return ConjunctionNode(
            expr_type=ExpressionNodeType.CONJUNCTION,
            conjunction_type=spec["conjunction_type"],
            children=operands,
        )

    raise FilterDeserializationError(f"Unknown expression node type: {node_kind}")