vgi-python 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vgi/__init__.py +152 -0
- vgi/_duckdb.py +62 -0
- vgi/_storage_profile.py +132 -0
- vgi/_test_fixtures/__init__.py +20 -0
- vgi/_test_fixtures/accumulate/__init__.py +19 -0
- vgi/_test_fixtures/accumulate/worker.py +762 -0
- vgi/_test_fixtures/aggregate/__init__.py +62 -0
- vgi/_test_fixtures/aggregate/_common.py +21 -0
- vgi/_test_fixtures/aggregate/basic.py +232 -0
- vgi/_test_fixtures/aggregate/dynamic.py +409 -0
- vgi/_test_fixtures/aggregate/generic.py +86 -0
- vgi/_test_fixtures/aggregate/listagg.py +71 -0
- vgi/_test_fixtures/aggregate/percentile.py +107 -0
- vgi/_test_fixtures/aggregate/streaming.py +192 -0
- vgi/_test_fixtures/aggregate/varargs.py +75 -0
- vgi/_test_fixtures/aggregate/window.py +380 -0
- vgi/_test_fixtures/attach_options.py +308 -0
- vgi/_test_fixtures/bad_protocol.py +62 -0
- vgi/_test_fixtures/cancellable.py +336 -0
- vgi/_test_fixtures/catalog.py +813 -0
- vgi/_test_fixtures/http_server.py +394 -0
- vgi/_test_fixtures/nest_tensor.py +614 -0
- vgi/_test_fixtures/orchard_catalog.py +47 -0
- vgi/_test_fixtures/projection_repro/__init__.py +6 -0
- vgi/_test_fixtures/projection_repro/worker.py +454 -0
- vgi/_test_fixtures/scalar/__init__.py +116 -0
- vgi/_test_fixtures/scalar/_common.py +69 -0
- vgi/_test_fixtures/scalar/arithmetic.py +321 -0
- vgi/_test_fixtures/scalar/binary.py +120 -0
- vgi/_test_fixtures/scalar/formatting.py +176 -0
- vgi/_test_fixtures/scalar/geo.py +300 -0
- vgi/_test_fixtures/scalar/null_handling.py +107 -0
- vgi/_test_fixtures/scalar/random_demo.py +171 -0
- vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
- vgi/_test_fixtures/scalar/type_info.py +219 -0
- vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
- vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
- vgi/_test_fixtures/simple_writable.py +793 -0
- vgi/_test_fixtures/table/__init__.py +221 -0
- vgi/_test_fixtures/table/_common.py +162 -0
- vgi/_test_fixtures/table/batch_index.py +283 -0
- vgi/_test_fixtures/table/batch_index_broken.py +200 -0
- vgi/_test_fixtures/table/catalog_scans.py +162 -0
- vgi/_test_fixtures/table/filters.py +1005 -0
- vgi/_test_fixtures/table/late_materialization.py +249 -0
- vgi/_test_fixtures/table/make_series.py +273 -0
- vgi/_test_fixtures/table/misc.py +499 -0
- vgi/_test_fixtures/table/order_modes.py +164 -0
- vgi/_test_fixtures/table/pairs.py +437 -0
- vgi/_test_fixtures/table/partition_columns.py +472 -0
- vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
- vgi/_test_fixtures/table/profiling_example.py +195 -0
- vgi/_test_fixtures/table/required_filters.py +234 -0
- vgi/_test_fixtures/table/sequence.py +710 -0
- vgi/_test_fixtures/table/settings.py +426 -0
- vgi/_test_fixtures/table/transaction_storage.py +162 -0
- vgi/_test_fixtures/table/tt_pushdown.py +191 -0
- vgi/_test_fixtures/table/versioned.py +230 -0
- vgi/_test_fixtures/table_in_out.py +1392 -0
- vgi/_test_fixtures/versioned.py +155 -0
- vgi/_test_fixtures/versioned_tables.py +595 -0
- vgi/_test_fixtures/worker.py +1631 -0
- vgi/_test_fixtures/writable/__init__.py +8 -0
- vgi/_test_fixtures/writable/generic.py +236 -0
- vgi/_test_fixtures/writable/table.py +149 -0
- vgi/_test_fixtures/writable/worker.py +1148 -0
- vgi/aggregate_function.py +607 -0
- vgi/argument_spec.py +472 -0
- vgi/arguments.py +1747 -0
- vgi/auth.py +55 -0
- vgi/catalog/__init__.py +88 -0
- vgi/catalog/attach_option.py +206 -0
- vgi/catalog/catalog_interface.py +2767 -0
- vgi/catalog/descriptors.py +870 -0
- vgi/catalog/duckdb_statistics.py +377 -0
- vgi/catalog/secret_type.py +96 -0
- vgi/catalog/setting.py +253 -0
- vgi/catalog/storage.py +372 -0
- vgi/client/__init__.py +67 -0
- vgi/client/catalog_mixin.py +1251 -0
- vgi/client/cli.py +582 -0
- vgi/client/cli_catalog.py +182 -0
- vgi/client/cli_schema.py +270 -0
- vgi/client/cli_table.py +907 -0
- vgi/client/cli_transaction.py +97 -0
- vgi/client/cli_utils.py +441 -0
- vgi/client/cli_view.py +303 -0
- vgi/client/client.py +2183 -0
- vgi/exceptions.py +205 -0
- vgi/function.py +245 -0
- vgi/function_storage.py +1636 -0
- vgi/function_storage_azure_sql.py +922 -0
- vgi/function_storage_cf_do.py +740 -0
- vgi/http/__init__.py +25 -0
- vgi/http/demo_storage.py +212 -0
- vgi/http/worker_page.py +1252 -0
- vgi/invocation.py +154 -0
- vgi/logging_config.py +93 -0
- vgi/meta_worker.py +661 -0
- vgi/metadata.py +1403 -0
- vgi/otel.py +406 -0
- vgi/protocol.py +2418 -0
- vgi/protocol_version.txt +1 -0
- vgi/py.typed +0 -0
- vgi/scalar_function.py +1211 -0
- vgi/schema_utils.py +234 -0
- vgi/secret_protocol.py +124 -0
- vgi/secret_service.py +238 -0
- vgi/serve.py +769 -0
- vgi/table_buffering_function.py +443 -0
- vgi/table_filter_pushdown.py +1528 -0
- vgi/table_function.py +1130 -0
- vgi/table_in_out_function.py +383 -0
- vgi/transactor/__init__.py +24 -0
- vgi/transactor/_duckdb_compat.py +27 -0
- vgi/transactor/client.py +137 -0
- vgi/transactor/protocol.py +149 -0
- vgi/transactor/server.py +740 -0
- vgi/worker.py +4761 -0
- vgi_python-0.8.0.dist-info/METADATA +735 -0
- vgi_python-0.8.0.dist-info/RECORD +124 -0
- vgi_python-0.8.0.dist-info/WHEEL +4 -0
- vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
- vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
|
@@ -0,0 +1,1392 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""Example table-in/table-out function implementations for testing VGI.
|
|
4
|
+
|
|
5
|
+
WARNING: EXAMPLE/TEST FUNCTIONS ONLY
|
|
6
|
+
-------------------------------------
|
|
7
|
+
These functions are reference implementations for testing and validating the VGI
|
|
8
|
+
protocol. They are NOT intended for production use. The VGI protocol will have
|
|
9
|
+
multiple implementations (Python, Go, JavaScript), and these examples serve as:
|
|
10
|
+
|
|
11
|
+
1. Protocol conformance tests - Verify implementations correctly handle the
|
|
12
|
+
VGI streaming protocol.
|
|
13
|
+
2. Pattern demonstrations - Show how to implement common function patterns
|
|
14
|
+
3. Cross-implementation test cases - Ensure consistent behavior across languages
|
|
15
|
+
|
|
16
|
+
Production considerations like memory limits, error recovery, and performance
|
|
17
|
+
optimizations are intentionally omitted to keep the examples simple and focused
|
|
18
|
+
on protocol correctness.
|
|
19
|
+
|
|
20
|
+
AVAILABLE FUNCTIONS
|
|
21
|
+
-------------------
|
|
22
|
+
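EchoFunction - Passthrough, no transformation
EchoWitnessFunction - Emits observed column count (projection pushdown probe)
EchoBufferingFunction - Buffered passthrough with projection + filter pushdown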
|
|
23
|
+
BufferInputFunction - Collects all input, emits on finalize
|
|
24
|
+
FilterBySettingFunction - Filters rows by threshold setting
|
|
25
|
+
RepeatInputsFunction - Duplicates each input batch N times
|
|
26
|
+
SumAllColumnsFunction - Aggregates numeric columns into sums
|
|
27
|
+
ExceptionProcessFunction - Raises exception during process (test)
|
|
28
|
+
ExceptionFinalizeFunction - Raises exception during finalize (test)
|
|
29
|
+
SumAllColumnsSimpleDistributed - Distributed aggregation via callback API
|
|
30
|
+
CrashOnProcessFunction - SIGKILLs the worker mid-process (test)
|
|
31
|
+
CrashOnCombineFunction - Raises during combine() (test)
|
|
32
|
+
CrashOnFinalizeFunction - Raises during finalize() (test)
|
|
33
|
+
HangOnProcessFunction - Sleeps forever in process() (manual cancel test)
|
|
34
|
+
LargeStateFunction - Buffers ~N MB per state_id (IPC chunking test)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from __future__ import annotations
|
|
38
|
+
|
|
39
|
+
import os
|
|
40
|
+
import signal
|
|
41
|
+
import sys
|
|
42
|
+
import time
|
|
43
|
+
from dataclasses import dataclass
|
|
44
|
+
from typing import Annotated, Any
|
|
45
|
+
|
|
46
|
+
import pyarrow as pa
|
|
47
|
+
import pyarrow.compute as pc
|
|
48
|
+
from vgi_rpc import ArrowSerializableDataclass, ArrowType
|
|
49
|
+
from vgi_rpc.log import Level
|
|
50
|
+
from vgi_rpc.rpc import OutputCollector
|
|
51
|
+
from vgi_rpc.utils import empty_batch
|
|
52
|
+
|
|
53
|
+
from vgi.arguments import Arg, Setting, TableInput
|
|
54
|
+
from vgi.invocation import BindResponse
|
|
55
|
+
from vgi.metadata import FunctionExample
|
|
56
|
+
from vgi.schema_utils import schema
|
|
57
|
+
from vgi.table_buffering_function import (
|
|
58
|
+
TableBufferingFunction,
|
|
59
|
+
TableBufferingParams,
|
|
60
|
+
)
|
|
61
|
+
from vgi.table_function import BindParams, ProcessParams, TableCardinality
|
|
62
|
+
from vgi.table_in_out_function import (
|
|
63
|
+
TableInOutFunction,
|
|
64
|
+
TableInOutGenerator,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Finalize streams drain a per-state state_log one entry per tick using
# state_log_scan. The cursor below is Arrow-serializable so that a
# producer-mode stream can be suspended and resumed across HTTP tick
# boundaries.
@dataclass
class _LogDrainState(ArrowSerializableDataclass):
    """Resumable read cursor over a state_log.

    ``after_id == -1`` marks "nothing consumed yet", so the next scan starts
    at the first log entry.
    """

    # state_log namespace that finalize reads from. b"buf" is the namespace
    # the buffering fixtures' process() methods append to via state_append.
    ns: bytes = b"buf"
    # Id of the last entry already emitted; -1 means before the first entry.
    after_id: int = -1
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Public fixture classes exported by ``from ... import *``.
# EchoBufferingFunction is defined in this module alongside its siblings and
# is now listed here too (it was previously missing from the export list).
__all__ = [
    "EchoFunction",
    "EchoWitnessFunction",
    "EchoBufferingFunction",
    "BufferInputFunction",
    "FilterBySettingFunction",
    "RepeatInputsFunction",
    "SumAllColumnsFunction",
    "SumAllColumnsSimpleDistributed",
    "ExceptionProcessFunction",
    "ExceptionFinalizeFunction",
    "CrashOnProcessFunction",
    "CrashOnCombineFunction",
    "CrashOnFinalizeFunction",
    "HangOnProcessFunction",
    "LargeStateFunction",
    "OrderedBufferInputFunction",
    "BatchIndexBufferInputFunction",
    "OrderedSourceFunction",
    "BufferEmitWideFunction",
]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass(slots=True, frozen=True, kw_only=True)
|
|
104
|
+
class SingleTableArguments:
|
|
105
|
+
"""Arguments for a table in/out function that just takes a table."""
|
|
106
|
+
|
|
107
|
+
data: Annotated[TableInput, Arg(0, doc="Input table")]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class EchoFunction(TableInOutGenerator[SingleTableArguments]):
    """Identity table in/out function: every input batch is re-emitted as-is.

    USE CASE
    --------
    A no-op stage for testing, debugging, or placeholder pipelines.

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema
    Output: the input schema unchanged, optionally narrowed by projection
            and filtering

    PUSHDOWN SUPPORT
    ----------------
    - projection_pushdown: only the requested columns are returned
    - filter_pushdown: rows are filtered by pushed-down predicates
    - auto_apply_filters: those predicates are applied to output batches
      automatically

    Example:
    -------
    Input:  [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    Output: [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

    """

    class Meta:
        """Metadata for EchoFunction."""

        name = "echo"
        description = "Passthrough function that emits each input batch unchanged"
        categories = ["utility", "debug"]
        tags = {"category": "debug", "type": "passthrough"}
        projection_pushdown = True
        filter_pushdown = True
        auto_apply_filters = True
        examples = [
            FunctionExample(
                sql="SELECT * FROM echo((SELECT * FROM input_table))",
                description="Pass through all rows unchanged",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Declare the output schema, which is exactly the input table's schema."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class EchoWitnessFunction(TableInOutGenerator[SingleTableArguments]):
    """Projection-pushdown probe: every output cell holds the observed column count.

    The aim is to prove that projection pushdown really narrows the schema
    the worker sees, rather than DuckDB quietly trimming columns above the
    operator. process() fills each cell with ``len(params.output_schema)``,
    i.e. the number of columns left after the framework's projection
    narrowing.

    With pushdown working:
        ``SELECT a FROM echo_witness((SELECT 1 AS a, 2 AS b, 3 AS c))`` → 1

    Without pushdown (DuckDB requests all columns, narrows above):
        ``SELECT a FROM echo_witness((SELECT 1 AS a, 2 AS b, 3 AS c))`` → 3

    The output schema mirrors the input, and every column must be an integer
    type for the encoding to work. Only projection is probed; filter
    pushdown is deliberately left disabled.
    """

    class Meta:
        """Metadata for EchoWitnessFunction."""

        name = "echo_witness"
        description = "Emits len(observed_output_schema) per column — projection probe"
        categories = ["test", "pushdown"]
        projection_pushdown = True

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Mirror the input schema as the output schema."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        params: ProcessParams[SingleTableArguments],
        state: None,
        batch: pa.RecordBatch,
        out: OutputCollector,
    ) -> None:
        """Emit a batch with ``batch.num_rows`` rows whose cells all hold the column count."""
        target_schema = params.output_schema
        width = len(target_schema)
        row_count = batch.num_rows
        # One constant-valued column per (possibly projection-narrowed) field.
        columns = [pa.array([width] * row_count, type=fld.type) for fld in target_schema]
        out.emit(pa.RecordBatch.from_arrays(columns, schema=target_schema))
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class BufferInputFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffering fixture: stash every input batch and replay the batches at finalize.

    Each execution uses a single bucket. Every ``process()`` call appends an
    IPC-encoded batch to the state_log keyed ``(b"buf", b"")`` and reports
    ``params.execution_id`` as its state id. Because those ids are all
    identical, ``combine()`` reduces them to one finalize_state_id.
    ``finalize()`` then walks the log with a cursor, emitting one batch per
    tick.

    Schema:
        Input:  any schema
        Output: same schema (passthrough)
    """

    class Meta:
        """Metadata for BufferInputFunction."""

        name = "buffer_input"
        description = "Collects all input batches and emits during finalization"
        categories = ["utility", "buffer"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM buffer_input((SELECT * FROM input_table))",
                description="Buffer all input and emit on finalize",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Use the input schema, unchanged, as the output schema."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Serialize ``batch`` into the shared log; the execution id is the state id."""
        buffer = pa.BufferOutputStream()
        with pa.ipc.new_stream(buffer, batch.schema) as stream_writer:
            stream_writer.write_batch(batch)
        # The stream is closed (end-of-stream marker written) before reading it back.
        encoded = buffer.getvalue().to_pybytes()
        params.storage.state_append(b"buf", b"", encoded)
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Collapse the (identical) state ids into one finalize stream."""
        # process() always reports params.execution_id, so state_ids holds
        # only copies of it.
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _LogDrainState:
        """Start draining the b"buf" log from before its first entry."""
        return _LogDrainState(after_id=-1, ns=b"buf")

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit the next buffered batch, or finish once the log is exhausted."""
        pending = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if pending:
            entry_id, payload = pending[0]
            next_batch = pa.ipc.open_stream(payload).read_next_batch()
            out.emit(next_batch)
            # Advance the cursor only after a successful emit.
            state.after_id = entry_id
        else:
            out.finish()
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class EchoBufferingFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffered passthrough that turns on projection and filter pushdown.

    Structurally identical to :class:`BufferInputFunction`: process() stores
    each full-width input batch and finalize() replays one batch per tick.
    The difference is that all three pushdown flags are declared, so DuckDB
    sends ``projection_ids`` / ``pushdown_filters`` on the InitRequest. The
    framework then:

    * Narrows ``params.output_schema`` to the projected columns. The
      ``batch.select(target_names)`` done by ``OutputCollector.emit`` drops
      the non-projected columns from the buffered full-width batch.
    * Wraps ``out`` in ``_FilteringOutputCollector`` (since
      ``auto_apply_filters=True``), so emitted batches are filtered
      automatically.

    The user code remains a plain streaming-style passthrough, unaware of
    projection or filters. This fixture checks that pushdown for buffered
    TableBufferingFunction implementations is wired through end-to-end.
    """

    class Meta:
        """Metadata for EchoBufferingFunction."""

        name = "echo_buffering"
        description = "Buffered passthrough with projection + filter pushdown"
        categories = ["test", "buffer", "pushdown"]
        projection_pushdown = True
        filter_pushdown = True
        auto_apply_filters = True

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Passthrough: output schema equals input schema."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Store the whole input batch; projection is not applied at storage time."""
        stream_sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(stream_sink, batch.schema) as ipc_writer:
            ipc_writer.write_batch(batch)
        serialized = stream_sink.getvalue().to_pybytes()
        params.storage.state_append(b"buf", b"", serialized)
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],  # noqa: ARG003 - collapse to one finalize stream
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Return a single finalize stream id (the execution id)."""
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,  # noqa: ARG003 - one bucket per execution
        params: TableBufferingParams[SingleTableArguments],  # noqa: ARG003
    ) -> _LogDrainState:
        """Point the cursor before the first entry of the b"buf" log."""
        return _LogDrainState(after_id=-1, ns=b"buf")

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,  # noqa: ARG003
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Replay one stored batch per tick; the framework narrows and filters it."""
        next_entries = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if next_entries:
            entry_id, payload = next_entries[0]
            out.emit(pa.ipc.open_stream(payload).read_next_batch())
            state.after_id = entry_id
        else:
            out.finish()
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class FilterBySettingFunction(TableInOutGenerator[SingleTableArguments]):
    """Filters input rows where the value column meets a threshold setting.

    USE CASE
    --------
    Demonstrates how table-in-out functions can use DuckDB settings to control
    behavior. The threshold setting determines which rows pass through: only
    rows where the "value" column >= threshold are emitted.

    The Setting() on on_bind() serves solely to register ``threshold`` in
    required_settings metadata. The actual filtering uses params.settings
    in process().

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema (must contain a "value" column)
    Output: same schema (rows filtered by threshold)

    Example:
    -------
    With threshold=5 and input [{"value": 3}, {"value": 7}]:
    Output: [{"value": 7}]

    """

    class Meta:
        """Metadata for FilterBySettingFunction."""

        name = "filter_by_setting"
        description = "Filter rows where value column >= threshold setting"
        categories = ["transform", "settings"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM filter_by_setting((SELECT * FROM input_table))",
                description="Filter rows using the threshold setting",
            )
        ]

    @classmethod
    def on_bind(
        cls,
        params: BindParams[SingleTableArguments],
        *,
        threshold: Annotated[pa.Scalar[Any] | None, Setting()] = None,
    ) -> BindResponse:
        """Return input schema unchanged. Threshold declared for required_settings."""
        assert params.bind_call.input_schema is not None
        return BindResponse(output_schema=params.bind_call.input_schema)

    @classmethod
    def process(
        cls,
        params: ProcessParams[SingleTableArguments],
        state: None,
        batch: pa.RecordBatch,
        out: OutputCollector,
    ) -> None:
        """Filter rows where value >= threshold.

        The ``threshold`` setting may arrive as a number or, from the C++
        extension, as a string. It is normalised to the ``value`` column's
        type before comparing:

        * floating columns compare against the threshold as a float, so a
          fractional threshold such as ``2.5`` is honoured exactly (plain
          ``int()`` would have truncated it to ``2``);
        * integer columns round a fractional threshold *up*, because for
          integers ``v >= 2.5`` is equivalent to ``v >= 3``;
        * any other column type keeps the plain ``int()`` coercion.

        Raises:
            KeyError: if the ``threshold`` setting is absent.
            ValueError: if a string threshold is not a number.
        """
        import math

        col = batch.column("value")
        raw = params.settings["threshold"].as_py()
        if isinstance(raw, str):
            # String-typed setting values: accept integers and decimals alike.
            text = raw.strip()
            try:
                raw = int(text)
            except ValueError:
                raw = float(text)

        coerced: Any
        if pa.types.is_floating(col.type):
            coerced = float(raw)
        elif pa.types.is_integer(col.type):
            coerced = math.ceil(raw)
        else:
            coerced = int(raw)

        threshold = pa.scalar(coerced, type=col.type)
        mask = pc.greater_equal(col, threshold)
        out.emit(batch.filter(mask))
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
@dataclass(slots=True, frozen=True)
|
|
441
|
+
class RepeatsInputsFunctionArguments:
|
|
442
|
+
"""Arguments for RepeatInputsFunction."""
|
|
443
|
+
|
|
444
|
+
repeat_count: Annotated[int, Arg(0, doc="Number of times to repeat each input batch")]
|
|
445
|
+
data: Annotated[TableInput, Arg(1, doc="Input table to repeat")]
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
class RepeatInputsFunction(TableInOutGenerator[RepeatsInputsFunctionArguments]):
    """Explosion function that duplicates each input batch N times.

    USE CASE
    --------
    Data augmentation, testing with larger datasets, or any scenario where
    you need multiple copies of each input record.

    Arguments:
    ---------
    repeat_count: Annotated[int, Arg(0)] (required)
        Number of times to repeat each input batch (must be >= 1).

    BEHAVIOR
    --------
    - output_schema: Returns input schema unchanged
    - process(): For each input, concatenates it N times into one output;
      a zero-row input is re-emitted unchanged

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema
    Output: same schema (passthrough)

    Example:
    -------
    With repeat_count=3:
    Input:  [{"a": 1}]
    Output: [{"a": 1}, {"a": 1}, {"a": 1}]

    """

    class Meta:
        """Metadata for RepeatInputsFunction."""

        name = "repeat_inputs"
        description = "Duplicates each input batch N times"
        categories = ["transform", "augmentation"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM repeat_inputs(3, (SELECT * FROM input_table))",
                description="Repeat each row 3 times",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[RepeatsInputsFunctionArguments]) -> BindResponse:
        """Validate repeat count argument.

        Raises:
            ValueError: if ``repeat_count`` < 1 or no input schema was supplied.
        """
        if params.args.repeat_count < 1:
            raise ValueError("Repeat count must be at least 1")
        if params.bind_call.input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=params.bind_call.input_schema)

    @classmethod
    def process(
        cls,
        params: ProcessParams[RepeatsInputsFunctionArguments],
        state: None,
        batch: pa.RecordBatch,
        out: OutputCollector,
    ) -> None:
        """Emit input batch concatenated repeat_count times.

        A zero-row input is re-emitted as-is. Repeating nothing is still
        nothing, and ``Table.to_batches()`` returns an empty list for a
        zero-row table, so indexing ``[0]`` would raise ``IndexError``.
        """
        if batch.num_rows == 0:
            out.emit(batch)
            return
        combined = pa.Table.from_batches([batch] * params.args.repeat_count).combine_chunks()
        out.emit(combined.to_batches()[0])
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
@dataclass(slots=True, frozen=True, kw_only=True)
|
|
515
|
+
class SumAllColumnsFunctionArguments:
|
|
516
|
+
"""Arguments for SumAllColumnsFunction."""
|
|
517
|
+
|
|
518
|
+
data: Annotated[TableInput, Arg(0, doc="Input table")]
|
|
519
|
+
logging: Annotated[bool, Arg("logging", doc="Whether to log during processing", default=False)] = False
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
@dataclass(kw_only=True)
class SumAllColumnsState(ArrowSerializableDataclass):
    """Mutable running-sum state for SumAllColumnsFunction."""

    # Partial column sums held as a RecordBatch. The ArrowType(pa.binary())
    # annotation declares its serialized wire type as binary.
    partial_sums: Annotated[pa.RecordBatch, ArrowType(pa.binary())]
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
class SumAllColumnsFunction(TableBufferingFunction[SumAllColumnsFunctionArguments, _LogDrainState]):
|
|
530
|
+
"""Aggregation function that computes column-wise sums across all batches.
|
|
531
|
+
|
|
532
|
+
USE CASE
|
|
533
|
+
--------
|
|
534
|
+
Computing totals, aggregating metrics, or any full-stream reduction
|
|
535
|
+
that produces a single summary row.
|
|
536
|
+
|
|
537
|
+
BEHAVIOR
|
|
538
|
+
--------
|
|
539
|
+
- process(): Accumulates sums and emits empty results
|
|
540
|
+
- finalize(): Returns single row with final sums
|
|
541
|
+
|
|
542
|
+
SCHEMA TRANSFORMATION
|
|
543
|
+
---------------------
|
|
544
|
+
Input: any schema with numeric columns
|
|
545
|
+
Output: only numeric columns, promoted to int64/float64
|
|
546
|
+
|
|
547
|
+
For each input column:
|
|
548
|
+
- Integer types -> int64
|
|
549
|
+
- Floating types -> float64
|
|
550
|
+
- Non-numeric types -> excluded from output
|
|
551
|
+
|
|
552
|
+
KEY PATTERN: ACCUMULATE IN process(), EMIT IN finalize()
|
|
553
|
+
--------------------------------------------------------
|
|
554
|
+
In process(), accumulate state but emit empty results.
|
|
555
|
+
In finalize(), return the final aggregated result.
|
|
556
|
+
|
|
557
|
+
Example:
|
|
558
|
+
-------
|
|
559
|
+
Input schema: {"a": int32, "b": float32, "name": string}
|
|
560
|
+
Output schema: {"a": int64, "b": float64} (string column excluded)
|
|
561
|
+
|
|
562
|
+
Input batches:
|
|
563
|
+
[{"a": 1, "b": 1.5, "name": "x"}, {"a": 2, "b": 2.5, "name": "y"}]
|
|
564
|
+
[{"a": 3, "b": 3.0, "name": "z"}]
|
|
565
|
+
|
|
566
|
+
Output (single row):
|
|
567
|
+
[{"a": 6, "b": 7.0}]
|
|
568
|
+
|
|
569
|
+
"""
|
|
570
|
+
|
|
571
|
+
class Meta:
|
|
572
|
+
name = "sum_all_columns"
|
|
573
|
+
description = "Computes column-wise sums across all batches"
|
|
574
|
+
categories = ["aggregation", "numeric"]
|
|
575
|
+
examples = [
|
|
576
|
+
FunctionExample(
|
|
577
|
+
sql="SELECT * FROM sum_all_columns((SELECT * FROM input_table))",
|
|
578
|
+
description="Sum all numeric columns",
|
|
579
|
+
)
|
|
580
|
+
]
|
|
581
|
+
|
|
582
|
+
@classmethod
|
|
583
|
+
def cardinality(cls, params: BindParams[SumAllColumnsFunctionArguments]) -> TableCardinality:
|
|
584
|
+
"""Return cardinality estimate of exactly 1 row."""
|
|
585
|
+
return TableCardinality(estimate=1, max=1)
|
|
586
|
+
|
|
587
|
+
@classmethod
|
|
588
|
+
def on_bind(cls, params: BindParams[SumAllColumnsFunctionArguments]) -> BindResponse:
|
|
589
|
+
"""Produce the output schema with only numeric columns.
|
|
590
|
+
|
|
591
|
+
Numeric here means integer, floating-point, or fixed-precision
|
|
592
|
+
decimal. DECIMAL inputs are promoted to float64 in the output
|
|
593
|
+
(matching the float path) — DuckDB users routinely expect a sum
|
|
594
|
+
of DECIMAL values to be summable, so silently dropping them would
|
|
595
|
+
be surprising. Non-numeric inputs (strings, lists, timestamps,
|
|
596
|
+
booleans) are filtered out; if NO numeric columns remain we raise
|
|
597
|
+
ValueError at bind time rather than producing an empty output
|
|
598
|
+
schema (which would crash downstream with an internal assertion).
|
|
599
|
+
"""
|
|
600
|
+
assert params.bind_call.input_schema is not None
|
|
601
|
+
output_fields: dict[str, pa.DataType] = {}
|
|
602
|
+
for field in params.bind_call.input_schema:
|
|
603
|
+
out_type: pa.DataType
|
|
604
|
+
if pa.types.is_integer(field.type):
|
|
605
|
+
out_type = pa.int64()
|
|
606
|
+
elif pa.types.is_floating(field.type):
|
|
607
|
+
out_type = pa.float64()
|
|
608
|
+
elif pa.types.is_decimal(field.type):
|
|
609
|
+
# Promote DECIMAL to float64 for the summed output.
|
|
610
|
+
# (A more precise implementation would widen the decimal
|
|
611
|
+
# type to absorb sum overflow, but for a test fixture
|
|
612
|
+
# this is sufficient.)
|
|
613
|
+
out_type = pa.float64()
|
|
614
|
+
else:
|
|
615
|
+
continue
|
|
616
|
+
output_fields[field.name] = out_type
|
|
617
|
+
|
|
618
|
+
if not output_fields:
|
|
619
|
+
input_summary = ", ".join(f"{f.name}: {f.type}" for f in params.bind_call.input_schema)
|
|
620
|
+
raise ValueError(
|
|
621
|
+
"sum_all_columns requires at least one numeric (integer, "
|
|
622
|
+
"floating-point, or decimal) input column, got [" + input_summary + "]"
|
|
623
|
+
)
|
|
624
|
+
|
|
625
|
+
return BindResponse(output_schema=schema(output_fields))
|
|
626
|
+
|
|
627
|
+
@staticmethod
def _scalars_to_single_row_batch(values: dict[str, pa.Scalar]) -> pa.RecordBatch:  # type: ignore[type-arg]
    """Turn a ``{column name: scalar}`` mapping into a one-row RecordBatch.

    Each column takes the type of its scalar. Columns appear in the
    mapping's insertion order.
    """
    names = list(values)
    columns = [pa.array([values[name]], type=values[name].type) for name in names]
    return pa.RecordBatch.from_arrays(columns, names=names)
|
|
631
|
+
|
|
632
|
+
@classmethod
def process(
    cls,
    batch: pa.RecordBatch,
    params: TableBufferingParams[SumAllColumnsFunctionArguments],
) -> bytes:
    """Record this batch's per-column partial sums in the execution's log.

    The partial row is written with ``state_append``, which is an atomic
    append, so concurrent process() calls do not race. combine() performs
    the final reduction.
    """
    if params.args.logging:
        # Routed through ``params.client_log``, the same wire mechanism as the
        # streaming ``out.client_log()`` path. It emits a 0-row log batch on
        # the ``table_buffering_process`` response stream, which DuckDB
        # surfaces in ``duckdb_logs()`` with type='VGI'. Stdlib
        # ``logging.getLogger(...).info(...)`` never reached the wire, because
        # the framework provides no stdlib-logging-to-wire bridge.
        params.client_log(Level.INFO, f"Processing batch with {batch.num_rows} rows")

    output_schema = params.output_schema
    partial_sums: dict[str, pa.Scalar[Any]] = {}
    for column_name in output_schema.names:
        total = pc.sum(batch.column(column_name))
        # A null sum (e.g. an all-null column) is replaced with a typed zero
        # so combine() can add every partial unconditionally.
        partial_sums[column_name] = (
            total if total.is_valid else pa.scalar(0, type=output_schema.field(column_name).type)
        )

    state = SumAllColumnsState(partial_sums=cls._scalars_to_single_row_batch(partial_sums))
    params.storage.state_append(b"partial", b"", state.serialize_to_bytes())
    return params.execution_id
|
|
673
|
+
|
|
674
|
+
@classmethod
def combine(
    cls,
    state_ids: list[bytes],
    params: TableBufferingParams[SumAllColumnsFunctionArguments],
) -> list[bytes]:
    """Fold every logged partial row into a single merged row.

    combine() runs once on the coordinator after every process() call has
    completed, so nothing races here. It scans the ``b"partial"`` log written
    by process() and adds each partial into running totals that start at
    typed zeros. It then appends the merged row, encoded as an Arrow IPC
    stream, to ``b"buf"``/``b""`` for finalize() to drain.

    Because the totals start at zero, an empty input still yields exactly one
    row of the expected shape, e.g. for
    ``SELECT ... FROM sum_all_columns((SELECT 1 WHERE 1=0))``.

    Returns:
        A single-element list holding this execution's id.
    """
    if params.args.logging:
        # Mirrors process(): sent through ``params.client_log`` (the unary-RPC
        # analogue of ``out.client_log``) so the message lands in DuckDB's
        # ``duckdb_logs()`` with type='VGI'. ``logging.test`` uses it to check
        # that the in-band log path also works from combine().
        params.client_log(Level.INFO, f"Combining {len(state_ids)} state_ids")

    output_schema = params.output_schema
    totals: dict[str, pa.Scalar[Any]] = {}
    for field in output_schema:
        totals[field.name] = pa.scalar(0, type=field.type)

    for _log_id, blob in params.storage.state_log_scan(b"partial", b""):
        partial_row = SumAllColumnsState.deserialize_from_bytes(blob).partial_sums
        for column_name in output_schema.names:
            totals[column_name] = pc.add(totals[column_name], partial_row.column(column_name)[0])

    merged = cls._scalars_to_single_row_batch(totals)
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, merged.schema) as writer:
        writer.write_batch(merged)
    params.storage.state_append(b"buf", b"", sink.getvalue().to_pybytes())
    return [params.execution_id]
|
|
719
|
+
|
|
720
|
+
@classmethod
def initial_finalize_state(
    cls,
    finalize_state_id: bytes,
    params: TableBufferingParams[SumAllColumnsFunctionArguments],
) -> _LogDrainState:
    """Create the finalize cursor over the ``b"buf"`` log, starting at after_id=-1."""
    cursor = _LogDrainState(ns=b"buf", after_id=-1)
    return cursor
|
|
727
|
+
|
|
728
|
+
@classmethod
def finalize(
    cls,
    params: TableBufferingParams[SumAllColumnsFunctionArguments],
    finalize_state_id: bytes,
    state: _LogDrainState,
    out: OutputCollector,
) -> None:
    """Emit the next buffered batch, or finish once the log is exhausted.

    Each call reads at most one entry after ``state.after_id`` from the
    ``state.ns`` log. It decodes that entry as an Arrow IPC stream, emits the
    stream's first batch, and advances the cursor to the entry's id.
    """
    entries = params.storage.state_log_scan(state.ns, b"", after_id=state.after_id, limit=1)
    if entries:
        entry_id, payload = entries[0]
        out.emit(pa.ipc.open_stream(payload).read_next_batch())
        state.after_id = entry_id
    else:
        out.finish()
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
@dataclass(kw_only=True)
class ExceptionProcessState(ArrowSerializableDataclass):
    """Serializable, mutable state record associated with ExceptionProcessFunction."""

    # Batch counter, defaulting to zero. NOTE(review): ExceptionProcessFunction
    # counts batches through the b"count" state log and does not appear to use
    # this field; confirm whether this class is still needed.
    batch_count: int = 0
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
class ExceptionProcessFunction(SumAllColumnsFunction):
    """sum_all_columns variant whose process() fails on the second batch.

    Everything except process() is inherited from SumAllColumnsFunction.
    """

    class Meta(SumAllColumnsFunction.Meta):
        name = "exception_process"
        description = "Test function that raises exception during process"
        categories = ["test", "error"]

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
    ) -> bytes:
        """Count this call and raise on every even-numbered one, starting with the 2nd.

        The counter is an append-only log under ``b"count"``/``b""``. Each call
        appends one empty entry and then counts all entries seen so far.
        Because ``state_append`` mints ids atomically, concurrent process()
        calls (e.g. over HTTP) each contribute exactly one entry.
        """
        params.storage.state_append(b"count", b"", b"")
        seen = len(params.storage.state_log_scan(b"count", b""))
        if seen % 2 != 0:
            return params.execution_id
        raise ValueError(f"Intentional exception on batch {seen}")
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
class ExceptionFinalizeFunction(SumAllColumnsFunction):
    """sum_all_columns variant whose finalize() always raises.

    process() and combine() are inherited unchanged. The failure therefore
    surfaces only when finalize() is invoked to emit results.
    """

    class Meta(SumAllColumnsFunction.Meta):
        name = "exception_finalize"
        description = "Test function that raises exception during finalize"
        categories = ["test", "error"]

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Unconditionally raise ValueError; nothing is ever emitted."""
        failure = ValueError("Intentional exception during finalize()")
        raise failure
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
@dataclass(slots=True, kw_only=True)
class SumAllColumnsSimpleDistributedState(ArrowSerializableDataclass):
    """Per-worker running column sums for SumAllColumnsSimpleDistributed.

    ``partial_sum`` is a one-row RecordBatch laid out like the function's
    output schema. Its ``ArrowType(pa.binary())`` annotation presumably tells
    the serializer to store the batch as a binary blob.
    """

    partial_sum: Annotated[pa.RecordBatch, ArrowType(pa.binary())]
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
class SumAllColumnsSimpleDistributed(TableInOutFunction[SingleTableArguments, SumAllColumnsSimpleDistributedState]):
    """Distributed aggregation using the simple callback API.

    This function demonstrates TableInOutFunction with distributed
    state management.

    It's equivalent to SumAllColumnsFunctionDistributed but uses
    the simpler callback API.

    Numeric columns follow the same rules as ``sum_all_columns``:
    - integers are summed into int64;
    - floating-point and DECIMAL columns are summed into float64;
    - every other column is dropped.
    Binding fails with ValueError if no numeric column remains.

    Example:
    -------
        Input batches (split across workers):
            Worker 1: [{a: 1, b: 1.0}, {a: 2, b: 2.0}]
            Worker 2: [{a: 3, b: 3.0}]

        Each worker computes partial sums:
            Worker 1 state: {a: 3, b: 3.0}
            Worker 2 state: {a: 3, b: 3.0}

        Primary worker merges states in finish():
            Combined: {a: 6, b: 6.0}

        Output (single row):
            [{a: 6, b: 6.0}]

    """

    class Meta:
        """Metadata for SumAllColumnsSimpleDistributed."""

        name = "sum_all_columns_simple_distributed"
        description = "Distributed sum using simple callback API"
        categories = ["aggregation", "numeric", "distributed"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM sum_all_columns_simple_distributed((SELECT * FROM input_table))",
                description="Sum columns using distributed workers with callback API",
            )
        ]

    @classmethod
    def cardinality(cls, params: BindParams[SingleTableArguments]) -> TableCardinality:
        """Return cardinality estimate of exactly 1 row."""
        return TableCardinality(estimate=1, max=1)

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Produce the output schema with only numeric columns.

        Raises:
            ValueError: if the input has no integer, floating-point, or
                decimal column. An empty output schema would otherwise crash
                downstream instead of failing with a clear bind-time message.
        """
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        output_fields: dict[str, pa.DataType] = {}
        for field in input_schema:
            if pa.types.is_integer(field.type):
                output_fields[field.name] = pa.int64()
            elif pa.types.is_floating(field.type) or pa.types.is_decimal(field.type):
                # DECIMAL is promoted to float64, matching sum_all_columns.
                output_fields[field.name] = pa.float64()

        if not output_fields:
            input_summary = ", ".join(f"{f.name}: {f.type}" for f in input_schema)
            raise ValueError(
                "sum_all_columns_simple_distributed requires at least one numeric (integer, "
                "floating-point, or decimal) input column, got [" + input_summary + "]"
            )

        return BindResponse(output_schema=schema(output_fields))

    @classmethod
    def initial_state(cls, params: ProcessParams[SingleTableArguments]) -> SumAllColumnsSimpleDistributedState | None:
        """Create the initial state: a single row of zeros."""
        return SumAllColumnsSimpleDistributedState(
            partial_sum=pa.RecordBatch.from_pylist(
                [{name: 0 for name in params.output_schema.names}], schema=params.output_schema
            )
        )

    @classmethod
    def transform(
        cls,
        batch: pa.RecordBatch,
        params: ProcessParams[SingleTableArguments],
        state: SumAllColumnsSimpleDistributedState | None,
    ) -> pa.RecordBatch:
        """Accumulate column sums into ``state``. Emit nothing during processing.

        Raises:
            ValueError: if ``state`` is None.
        """
        if state is None:
            raise ValueError("State must not be None in transform()")
        sums: dict[str, pa.Scalar[Any]] = {}
        for name in params.output_schema.names:
            running = state.partial_sum.column(name)[0]
            col_sum = pc.sum(batch.column(name))
            # A null batch sum (e.g. an all-null column) leaves the running
            # total unchanged.
            sums[name] = pc.add(running, col_sum) if col_sum.is_valid else running

        state.partial_sum = pa.RecordBatch.from_pylist([sums], schema=params.output_schema)
        return empty_batch(params.output_schema)

    @classmethod
    def finish(
        cls,
        params: ProcessParams[SingleTableArguments],
        states: list[SumAllColumnsSimpleDistributedState],
    ) -> list[pa.RecordBatch]:
        """Emit a single row containing the column sums across all states.

        With no states at all, a row of typed zeros is emitted.
        ``pa.Table.from_batches`` cannot infer a schema from an empty list,
        so the previous code raised in that case.
        """
        output_schema = params.output_schema
        partials = [state.partial_sum for state in states]
        if not partials:
            zeros = {field.name: pa.scalar(0, type=field.type) for field in output_schema}
            return [pa.RecordBatch.from_pylist([zeros], schema=output_schema)]

        table = pa.Table.from_batches(partials)
        sums: dict[str, pa.Scalar[Any]] = {
            field.name: pc.sum(table.column(field.name)) for field in output_schema
        }
        return [pa.RecordBatch.from_pylist([sums], schema=output_schema)]
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
# ============================================================================
|
|
927
|
+
# Failure-injection fixtures
|
|
928
|
+
# ============================================================================
|
|
929
|
+
# These are table_buffering functions designed to exercise crash/error paths in
|
|
930
|
+
# the C++ Sink+Source operator. Each one fails in a specific phase so we can
|
|
931
|
+
# test that the operator: throws cleanly, drains the worker pool, doesn't leak
|
|
932
|
+
# in-flight workers, and recovers on the next query.
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
class CrashOnProcessFunction(BufferInputFunction):
    """Kills its own worker process abruptly on the first process() call.

    On POSIX the worker sends itself SIGKILL. Windows has no SIGKILL, so
    SIGABRT is used there as the closest hard-crash equivalent.
    """

    class Meta(BufferInputFunction.Meta):
        name = "crash_on_process"
        description = "Worker SIGKILLs itself during process (test)"
        categories = ["test", "crash"]

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        my_pid = os.getpid()
        if sys.platform != "win32":
            os.kill(my_pid, signal.SIGKILL)
        else:  # pragma: no cover - hard-crash equivalent
            os.kill(my_pid, signal.SIGABRT)
        return params.execution_id  # unreachable
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
class CrashOnCombineFunction(BufferInputFunction):
    """Buffers input exactly like BufferInputFunction, then fails in combine()."""

    class Meta(BufferInputFunction.Meta):
        name = "crash_on_combine"
        description = "Worker raises during combine (test)"
        categories = ["test", "crash"]

    @classmethod
    def combine(cls, state_ids: list[bytes], params: TableBufferingParams[SingleTableArguments]) -> list[bytes]:
        """Always raise RuntimeError, whatever was buffered."""
        message = "Intentional exception during combine()"
        raise RuntimeError(message)
|
|
967
|
+
|
|
968
|
+
|
|
969
|
+
class CrashOnFinalizeFunction(BufferInputFunction):
    """Lets combine() succeed, then raises on the very first finalize() call."""

    class Meta(BufferInputFunction.Meta):
        name = "crash_on_finalize"
        description = "Worker raises during finalize (test)"
        categories = ["test", "crash"]

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Raise ValueError before emitting anything."""
        failure = ValueError("Intentional exception during finalize()")
        raise failure
|
|
986
|
+
|
|
987
|
+
|
|
988
|
+
class HangOnProcessFunction(BufferInputFunction):
    """Blocks for one hour inside process(); backs the manual cancellation smoke test."""

    class Meta(BufferInputFunction.Meta):
        name = "hang_on_process"
        description = "Worker sleeps in process (manual cancel test)"
        categories = ["test", "hang"]

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        one_hour_seconds = 60 * 60
        time.sleep(one_hour_seconds)
        # Only reached if nothing cancels the worker within the hour.
        return params.execution_id
|
|
1004
|
+
|
|
1005
|
+
|
|
1006
|
+
@dataclass(kw_only=True)
class LargeStateState(ArrowSerializableDataclass):
    """Serializable record carrying one (potentially large) bytes payload.

    Intended to be shipped by the C++ side through combine/finalize inside
    the RPC's outer envelope, exercising IPC chunking on the response path.
    NOTE(review): LargeStateFunction appends raw bytes via ``state_append``
    and does not appear to use this class; confirm it is still needed.
    """

    # Raw payload bytes; empty unless explicitly provided.
    payload: bytes = b""
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
class LargeStateFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffers ~1 MB of storage per input batch and reports the total size.

    Pipeline:
    - process() appends a 1 MB zero-filled blob to the ``b"large"`` state log
      with ``state_append``. That append is race-safe; there is no
      read-modify-write.
    - combine() scans that log and sums the blob sizes. It then writes ONE
      output row, with every column set to that byte total, as an Arrow IPC
      stream into the ``b"buf"`` log.
    - finalize() drains ``b"buf"`` cursor-style.

    The output schema is the input schema, so every input column must accept
    an integer value. NOTE(review): a non-numeric input column (e.g. VARCHAR)
    would presumably make the ``from_pydict`` call in combine() fail.
    Confirm that callers only pass integer columns.
    """

    class Meta:
        name = "large_state"
        description = "Buffers ~1 MB per input batch into state (IPC test)"
        categories = ["test", "memory"]

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Echo the input schema as the output schema."""
        assert params.bind_call.input_schema is not None
        return BindResponse(output_schema=params.bind_call.input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Append 1 MB per call to the per-execution log.

        Race-safe append; combine() sums the total payload size by
        scanning the log.
        """
        params.storage.state_append(
            b"large",
            b"",
            b"\x00" * (1024 * 1024),
        )
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Materialize one output row carrying the total payload size in every column."""
        total = sum(len(blob) for _log_id, blob in params.storage.state_log_scan(b"large", b""))
        out_batch = pa.RecordBatch.from_pydict(
            {name: [total] for name in params.output_schema.names},
            schema=params.output_schema,
        )
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, out_batch.schema) as w:
            w.write_batch(out_batch)
        params.storage.state_append(b"buf", b"", sink.getvalue().to_pybytes())
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _LogDrainState:
        """Position the drain cursor at after_id=-1 on the ``b"buf"`` log."""
        return _LogDrainState(ns=b"buf", after_id=-1)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit the next buffered batch after the cursor, or finish when none remain."""
        rows = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if not rows:
            out.finish()
            return
        log_id, value = rows[0]
        out.emit(pa.ipc.open_stream(value).read_next_batch())
        state.after_id = log_id
|
|
1099
|
+
|
|
1100
|
+
|
|
1101
|
+
# ============================================================================
|
|
1102
|
+
# Ordering knobs (sink_order_dependent, requires_input_batch_index)
|
|
1103
|
+
# ============================================================================
|
|
1104
|
+
|
|
1105
|
+
|
|
1106
|
+
class OrderedBufferInputFunction(BufferInputFunction):
    """buffer_input variant whose ingest runs on a single thread.

    Setting ``Meta.sink_order_dependent=True`` forces ``ParallelSink=false``
    on the C++ operator, so every ``process()`` call arrives on the same
    worker, in source order. The integration test checks this by asserting
    that exactly one distinct ``conn=`` value is seen.

    Output matches ``BufferInputFunction``: all buffered rows are passed
    through unchanged. Combine/finalize behaviour is inherited as-is from
    ``BufferInputFunction``.
    """

    class Meta(BufferInputFunction.Meta):
        name = "ordered_buffer_input"
        description = "buffer_input variant with sink_order_dependent=True"
        categories = ["test", "ordering"]
        sink_order_dependent = True
|
|
1124
|
+
|
|
1125
|
+
|
|
1126
|
+
def _pack_indexed_batch(batch_index: int, batch_bytes: bytes) -> bytes:
|
|
1127
|
+
"""Pack (batch_index, batch_bytes) into a single appendable blob.
|
|
1128
|
+
|
|
1129
|
+
Layout: 8 bytes little-endian signed batch_index || raw IPC stream bytes.
|
|
1130
|
+
Used by BatchIndexBufferInputFunction to thread per-batch ordering keys
|
|
1131
|
+
through the append-only state_log without an extra ArrowSerializableDataclass
|
|
1132
|
+
round-trip.
|
|
1133
|
+
"""
|
|
1134
|
+
return batch_index.to_bytes(8, "little", signed=True) + batch_bytes
|
|
1135
|
+
|
|
1136
|
+
|
|
1137
|
+
def _unpack_indexed_batch(blob: bytes) -> tuple[int, bytes]:
|
|
1138
|
+
"""Inverse of _pack_indexed_batch."""
|
|
1139
|
+
return int.from_bytes(blob[:8], "little", signed=True), blob[8:]
|
|
1140
|
+
|
|
1141
|
+
|
|
1142
|
+
class BatchIndexBufferInputFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """buffer_input variant that relies on a ``batch_index`` for every ``process()`` call.

    ``Meta.requires_input_batch_index=True`` makes the C++ operator declare
    ``RequiredPartitionInfo()=BatchIndex()`` and pass DuckDB's per-chunk
    batch_index into each ``process()`` call. The phases are:
    - process() tags each batch's IPC bytes with that index and appends them
      to the ``b"unsorted"`` log;
    - combine() sorts every entry globally by index and rewrites them, in
      order, into ``b"buf"``;
    - finalize() drains ``b"buf"`` one entry at a time.
    """

    class Meta:
        name = "batch_index_buffer_input"
        description = "buffer_input variant using batch_index to reconstruct order"
        categories = ["test", "ordering"]
        requires_input_batch_index = True

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Pass the input schema through unchanged."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Serialize ``batch`` and log it together with its batch_index.

        Raises:
            RuntimeError: if no batch_index was supplied.
        """
        batch_index = params.batch_index
        if batch_index is None:
            raise RuntimeError(
                "batch_index_buffer_input.process() received batch_index=None "
                "— Meta.requires_input_batch_index plumbing is broken"
            )
        stream = pa.BufferOutputStream()
        with pa.ipc.new_stream(stream, batch.schema) as ipc_writer:
            ipc_writer.write_batch(batch)
        # Append-only, so concurrent Sink threads cannot race; ordering is
        # restored later in combine().
        tagged = _pack_indexed_batch(batch_index, stream.getvalue().to_pybytes())
        params.storage.state_append(b"unsorted", b"", tagged)
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Sort every logged batch by batch_index and re-append them in that order."""
        decoded = (_unpack_indexed_batch(blob) for _, blob in params.storage.state_log_scan(b"unsorted", b""))
        for _batch_index, ipc_bytes in sorted(decoded, key=lambda pair: pair[0]):
            params.storage.state_append(b"buf", b"", ipc_bytes)
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _LogDrainState:
        """Place the drain cursor at after_id=-1 on the sorted ``b"buf"`` log."""
        cursor = _LogDrainState(ns=b"buf", after_id=-1)
        return cursor

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit the next sorted batch after the cursor, or finish once drained."""
        entries = params.storage.state_log_scan(state.ns, b"", after_id=state.after_id, limit=1)
        if entries:
            entry_id, payload = entries[0]
            out.emit(pa.ipc.open_stream(payload).read_next_batch())
            state.after_id = entry_id
        else:
            out.finish()
|
|
1230
|
+
|
|
1231
|
+
|
|
1232
|
+
@dataclass
class _OneShotState(ArrowSerializableDataclass):
    """Finalize cursor for ``OrderedSourceFunction`` that emits exactly once."""

    # Integer to emit, decoded from the finalize_state_id.
    value: int = 0
    # Becomes True once the row has been emitted; the next call then finishes.
    emitted: bool = False
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
class OrderedSourceFunction(TableBufferingFunction[SingleTableArguments, _OneShotState]):
    """Buffered table function that declares ``source_order_dependent=True``.

    That flag makes the C++ ``PhysicalVgiTableBufferingFunction`` report
    ``ParallelSource()=false`` and ``SourceOrder()=FIXED_ORDER``. The Source
    phase therefore drains ``finalize_queue`` serially, in the order
    ``combine()`` filled it. Without the flag, parallel Source drains would
    race and rows could come out in any order.

    The input is ignored on purpose, so the expected output is deterministic
    however the Sink is parallelised or the input is partitioned:
    - ``combine()`` returns sixteen finalize_state_ids: the integers 0..15,
      each as 4 bytes big-endian, in ascending order;
    - each ``finalize()`` call decodes its id and emits exactly one row
      holding that integer.
    With the flag honoured, the rows arrive as 0, 1, ..., 15.

    Output schema: a single BIGINT column ``v``.
    """

    class Meta:
        name = "ordered_source"
        description = "Emits a fixed 0..15 sequence via source_order_dependent=True; input is ignored"
        categories = ["test", "ordering"]
        source_order_dependent = True

    _N_ROWS = 16

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Declare the single int64 output column ``v``."""
        output = schema(v=pa.int64())
        return BindResponse(output_schema=output)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        # Batch contents are deliberately ignored: this fixture checks Source
        # ordering, not data handling.
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        # Ascending 4-byte big-endian ids. A FIXED_ORDER Source must drain
        # finalize_queue in exactly this order.
        return [position.to_bytes(4, "big") for position in range(cls._N_ROWS)]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _OneShotState:
        """Decode the big-endian id into the value this cursor will emit."""
        decoded = int.from_bytes(finalize_state_id, "big")
        return _OneShotState(value=decoded)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _OneShotState,
        out: OutputCollector,
    ) -> None:
        """Emit the single ``v`` row on the first call; finish on every later call."""
        if not state.emitted:
            out.emit(pa.RecordBatch.from_pylist([{"v": state.value}], schema=params.output_schema))
            state.emitted = True
            return
        out.finish()
|
|
1312
|
+
|
|
1313
|
+
|
|
1314
|
+
# ---------------------------------------------------------------------------
|
|
1315
|
+
# Repro fixture: emit a single large finalize batch from a buffering function.
|
|
1316
|
+
# ---------------------------------------------------------------------------
|
|
1317
|
+
|
|
1318
|
+
|
|
1319
|
+
@dataclass(slots=True, frozen=True, kw_only=True)
class BufferEmitWideArguments:
    """Arguments accepted by ``buffer_emit_wide`` (BufferEmitWideFunction)."""

    # Row count of the single batch emitted by finalize; must be >= 0.
    # NOTE(review): the example SQL passes the table first and the count
    # second, while this field is Arg(0). Confirm how the framework maps
    # table inputs to Arg indices.
    rows: Annotated[int, Arg(0, doc="Number of rows to emit in one finalize batch", ge=0)]
    # Input table; its contents are never read.
    data: Annotated[TableInput, Arg(1, doc="Input table (content ignored)")]
|
|
1325
|
+
|
|
1326
|
+
|
|
1327
|
+
@dataclass
class _EmitOnceState(ArrowSerializableDataclass):
    """Finalize cursor for BufferEmitWideFunction that tracks its single emission."""

    # False until finalize() has emitted its one batch.
    emitted: bool = False
|
|
1332
|
+
|
|
1333
|
+
|
|
1334
|
+
# Output schema shared by BufferEmitWideFunction's on_bind and finalize: a
# single int64 column ``n`` that finalize fills with 0..rows-1.
_BUFFER_EMIT_WIDE_SCHEMA = schema(n=pa.int64())
|
|
1335
|
+
|
|
1336
|
+
|
|
1337
|
+
class BufferEmitWideFunction(TableBufferingFunction[BufferEmitWideArguments, _EmitOnceState]):
    """Buffering function whose Source phase emits ONE batch of ``rows`` rows.

    BufferInputFunction echoes its input batches, each already capped at
    DuckDB's standard vector size. This function instead builds a single,
    arbitrarily large batch in ``finalize``. It is a minimal repro for
    whether the buffering Source path accepts output batches larger than the
    standard vector size (2048 rows); a regular TableFunctionGenerator (e.g.
    ``sequence``) does accept them.
    """

    class Meta:
        """Metadata for BufferEmitWideFunction."""

        name = "buffer_emit_wide"
        description = "Emit a single finalize batch of N rows (vector-size repro)"
        categories = ["test", "buffer"]
        examples = [
            FunctionExample(
                sql="SELECT count(*) FROM buffer_emit_wide((SELECT 1), 10000)",
                description="Emit a single 10000-row batch from the Source phase",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[BufferEmitWideArguments]) -> BindResponse:
        """Always produce the fixed single-column ``n`` schema."""
        return BindResponse(output_schema=_BUFFER_EMIT_WIDE_SCHEMA)

    @classmethod
    def process(cls, batch: pa.RecordBatch, params: TableBufferingParams[BufferEmitWideArguments]) -> bytes:
        """Discard the input; nothing is buffered."""
        return params.execution_id

    @classmethod
    def combine(cls, state_ids: list[bytes], params: TableBufferingParams[BufferEmitWideArguments]) -> list[bytes]:
        """Schedule exactly one finalize stream, keyed by the execution id."""
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls, finalize_state_id: bytes, params: TableBufferingParams[BufferEmitWideArguments]
    ) -> _EmitOnceState:
        """Start in the not-yet-emitted state."""
        return _EmitOnceState(emitted=False)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[BufferEmitWideArguments],
        finalize_state_id: bytes,
        state: _EmitOnceState,
        out: OutputCollector,
    ) -> None:
        """Emit the single 0..rows-1 batch on the first call, then finish."""
        if not state.emitted:
            row_count = params.args.rows
            wide = pa.RecordBatch.from_pydict({"n": [*range(row_count)]}, schema=_BUFFER_EMIT_WIDE_SCHEMA)
            out.emit(wide)
            state.emitted = True
        else:
            out.finish()
|