vgi-python 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. vgi/__init__.py +152 -0
  2. vgi/_duckdb.py +62 -0
  3. vgi/_storage_profile.py +132 -0
  4. vgi/_test_fixtures/__init__.py +20 -0
  5. vgi/_test_fixtures/accumulate/__init__.py +19 -0
  6. vgi/_test_fixtures/accumulate/worker.py +762 -0
  7. vgi/_test_fixtures/aggregate/__init__.py +62 -0
  8. vgi/_test_fixtures/aggregate/_common.py +21 -0
  9. vgi/_test_fixtures/aggregate/basic.py +232 -0
  10. vgi/_test_fixtures/aggregate/dynamic.py +409 -0
  11. vgi/_test_fixtures/aggregate/generic.py +86 -0
  12. vgi/_test_fixtures/aggregate/listagg.py +71 -0
  13. vgi/_test_fixtures/aggregate/percentile.py +107 -0
  14. vgi/_test_fixtures/aggregate/streaming.py +192 -0
  15. vgi/_test_fixtures/aggregate/varargs.py +75 -0
  16. vgi/_test_fixtures/aggregate/window.py +380 -0
  17. vgi/_test_fixtures/attach_options.py +308 -0
  18. vgi/_test_fixtures/bad_protocol.py +62 -0
  19. vgi/_test_fixtures/cancellable.py +336 -0
  20. vgi/_test_fixtures/catalog.py +813 -0
  21. vgi/_test_fixtures/http_server.py +394 -0
  22. vgi/_test_fixtures/nest_tensor.py +614 -0
  23. vgi/_test_fixtures/orchard_catalog.py +47 -0
  24. vgi/_test_fixtures/projection_repro/__init__.py +6 -0
  25. vgi/_test_fixtures/projection_repro/worker.py +454 -0
  26. vgi/_test_fixtures/scalar/__init__.py +116 -0
  27. vgi/_test_fixtures/scalar/_common.py +69 -0
  28. vgi/_test_fixtures/scalar/arithmetic.py +321 -0
  29. vgi/_test_fixtures/scalar/binary.py +120 -0
  30. vgi/_test_fixtures/scalar/formatting.py +176 -0
  31. vgi/_test_fixtures/scalar/geo.py +300 -0
  32. vgi/_test_fixtures/scalar/null_handling.py +107 -0
  33. vgi/_test_fixtures/scalar/random_demo.py +171 -0
  34. vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
  35. vgi/_test_fixtures/scalar/type_info.py +219 -0
  36. vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
  37. vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
  38. vgi/_test_fixtures/simple_writable.py +793 -0
  39. vgi/_test_fixtures/table/__init__.py +221 -0
  40. vgi/_test_fixtures/table/_common.py +162 -0
  41. vgi/_test_fixtures/table/batch_index.py +283 -0
  42. vgi/_test_fixtures/table/batch_index_broken.py +200 -0
  43. vgi/_test_fixtures/table/catalog_scans.py +162 -0
  44. vgi/_test_fixtures/table/filters.py +1005 -0
  45. vgi/_test_fixtures/table/late_materialization.py +249 -0
  46. vgi/_test_fixtures/table/make_series.py +273 -0
  47. vgi/_test_fixtures/table/misc.py +499 -0
  48. vgi/_test_fixtures/table/order_modes.py +164 -0
  49. vgi/_test_fixtures/table/pairs.py +437 -0
  50. vgi/_test_fixtures/table/partition_columns.py +472 -0
  51. vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
  52. vgi/_test_fixtures/table/profiling_example.py +195 -0
  53. vgi/_test_fixtures/table/required_filters.py +234 -0
  54. vgi/_test_fixtures/table/sequence.py +710 -0
  55. vgi/_test_fixtures/table/settings.py +426 -0
  56. vgi/_test_fixtures/table/transaction_storage.py +162 -0
  57. vgi/_test_fixtures/table/tt_pushdown.py +191 -0
  58. vgi/_test_fixtures/table/versioned.py +230 -0
  59. vgi/_test_fixtures/table_in_out.py +1392 -0
  60. vgi/_test_fixtures/versioned.py +155 -0
  61. vgi/_test_fixtures/versioned_tables.py +595 -0
  62. vgi/_test_fixtures/worker.py +1631 -0
  63. vgi/_test_fixtures/writable/__init__.py +8 -0
  64. vgi/_test_fixtures/writable/generic.py +236 -0
  65. vgi/_test_fixtures/writable/table.py +149 -0
  66. vgi/_test_fixtures/writable/worker.py +1148 -0
  67. vgi/aggregate_function.py +607 -0
  68. vgi/argument_spec.py +472 -0
  69. vgi/arguments.py +1747 -0
  70. vgi/auth.py +55 -0
  71. vgi/catalog/__init__.py +88 -0
  72. vgi/catalog/attach_option.py +206 -0
  73. vgi/catalog/catalog_interface.py +2767 -0
  74. vgi/catalog/descriptors.py +870 -0
  75. vgi/catalog/duckdb_statistics.py +377 -0
  76. vgi/catalog/secret_type.py +96 -0
  77. vgi/catalog/setting.py +253 -0
  78. vgi/catalog/storage.py +372 -0
  79. vgi/client/__init__.py +67 -0
  80. vgi/client/catalog_mixin.py +1251 -0
  81. vgi/client/cli.py +582 -0
  82. vgi/client/cli_catalog.py +182 -0
  83. vgi/client/cli_schema.py +270 -0
  84. vgi/client/cli_table.py +907 -0
  85. vgi/client/cli_transaction.py +97 -0
  86. vgi/client/cli_utils.py +441 -0
  87. vgi/client/cli_view.py +303 -0
  88. vgi/client/client.py +2183 -0
  89. vgi/exceptions.py +205 -0
  90. vgi/function.py +245 -0
  91. vgi/function_storage.py +1636 -0
  92. vgi/function_storage_azure_sql.py +922 -0
  93. vgi/function_storage_cf_do.py +740 -0
  94. vgi/http/__init__.py +25 -0
  95. vgi/http/demo_storage.py +212 -0
  96. vgi/http/worker_page.py +1252 -0
  97. vgi/invocation.py +154 -0
  98. vgi/logging_config.py +93 -0
  99. vgi/meta_worker.py +661 -0
  100. vgi/metadata.py +1403 -0
  101. vgi/otel.py +406 -0
  102. vgi/protocol.py +2418 -0
  103. vgi/protocol_version.txt +1 -0
  104. vgi/py.typed +0 -0
  105. vgi/scalar_function.py +1211 -0
  106. vgi/schema_utils.py +234 -0
  107. vgi/secret_protocol.py +124 -0
  108. vgi/secret_service.py +238 -0
  109. vgi/serve.py +769 -0
  110. vgi/table_buffering_function.py +443 -0
  111. vgi/table_filter_pushdown.py +1528 -0
  112. vgi/table_function.py +1130 -0
  113. vgi/table_in_out_function.py +383 -0
  114. vgi/transactor/__init__.py +24 -0
  115. vgi/transactor/_duckdb_compat.py +27 -0
  116. vgi/transactor/client.py +137 -0
  117. vgi/transactor/protocol.py +149 -0
  118. vgi/transactor/server.py +740 -0
  119. vgi/worker.py +4761 -0
  120. vgi_python-0.8.0.dist-info/METADATA +735 -0
  121. vgi_python-0.8.0.dist-info/RECORD +124 -0
  122. vgi_python-0.8.0.dist-info/WHEEL +4 -0
  123. vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
  124. vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
@@ -0,0 +1,1392 @@
1
+ # Copyright 2025, 2026 Query Farm LLC - https://query.farm
2
+
3
+ """Example table-in/table-out function implementations for testing VGI.
4
+
5
+ WARNING: EXAMPLE/TEST FUNCTIONS ONLY
6
+ -------------------------------------
7
+ These functions are reference implementations for testing and validating the VGI
8
+ protocol. They are NOT intended for production use. The VGI protocol will have
9
+ multiple implementations (Python, Go, JavaScript), and these examples serve as:
10
+
11
+ 1. Protocol conformance tests - Verify implementations correctly handle the
12
+ VGI streaming protocol.
13
+ 2. Pattern demonstrations - Show how to implement common function patterns
14
+ 3. Cross-implementation test cases - Ensure consistent behavior across languages
15
+
16
+ Production considerations like memory limits, error recovery, and performance
17
+ optimizations are intentionally omitted to keep the examples simple and focused
18
+ on protocol correctness.
19
+
20
+ AVAILABLE FUNCTIONS
21
+ -------------------
22
+ EchoFunction - Passthrough, no transformation
23
+ BufferInputFunction - Collects all input, emits on finalize
24
+ FilterBySettingFunction - Filters rows by threshold setting
25
+ RepeatInputsFunction - Duplicates each input batch N times
26
+ SumAllColumnsFunction - Aggregates numeric columns into sums
27
+ ExceptionProcessFunction - Raises exception during process (test)
28
+ ExceptionFinalizeFunction - Raises exception during finalize (test)
29
+ SumAllColumnsSimpleDistributed - Distributed aggregation via callback API
30
+ CrashOnProcessFunction - SIGKILLs the worker mid-process (test)
31
+ CrashOnCombineFunction - Raises during combine() (test)
32
+ CrashOnFinalizeFunction - Raises during finalize() (test)
33
+ HangOnProcessFunction - Sleeps forever in process() (manual cancel test)
34
+ LargeStateFunction - Buffers ~N MB per state_id (IPC chunking test)
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import os
40
+ import signal
41
+ import sys
42
+ import time
43
+ from dataclasses import dataclass
44
+ from typing import Annotated, Any
45
+
46
+ import pyarrow as pa
47
+ import pyarrow.compute as pc
48
+ from vgi_rpc import ArrowSerializableDataclass, ArrowType
49
+ from vgi_rpc.log import Level
50
+ from vgi_rpc.rpc import OutputCollector
51
+ from vgi_rpc.utils import empty_batch
52
+
53
+ from vgi.arguments import Arg, Setting, TableInput
54
+ from vgi.invocation import BindResponse
55
+ from vgi.metadata import FunctionExample
56
+ from vgi.schema_utils import schema
57
+ from vgi.table_buffering_function import (
58
+ TableBufferingFunction,
59
+ TableBufferingParams,
60
+ )
61
+ from vgi.table_function import BindParams, ProcessParams, TableCardinality
62
+ from vgi.table_in_out_function import (
63
+ TableInOutFunction,
64
+ TableInOutGenerator,
65
+ )
66
+
67
+
68
+ # Per-tick cursor state for finalize streams that drain a state_log via
69
+ # state_log_scan. Wire-serializable so the producer-mode stream survives
70
+ # HTTP tick boundaries.
71
@dataclass
class _LogDrainState(ArrowSerializableDataclass):
    """Resume point for a finalize stream that drains a state_log.

    ``after_id`` is the id of the last log entry already emitted; the
    initial value of -1 means "nothing emitted yet".
    """

    # Namespace finalize() reads from. b"buf" is the namespace the fixtures
    # in this module append to via state_append.
    ns: bytes = b"buf"
    # Id of the most recently emitted log entry (-1 == before the first).
    after_id: int = -1
79
+
80
+
81
# Public fixtures exported by this module. Names listed here that are not
# defined above are defined further down in the file.
__all__ = [
    "EchoFunction",
    "EchoWitnessFunction",
    "BufferInputFunction",
    "EchoBufferingFunction",
    "FilterBySettingFunction",
    "RepeatInputsFunction",
    "SumAllColumnsFunction",
    "SumAllColumnsSimpleDistributed",
    "ExceptionProcessFunction",
    "ExceptionFinalizeFunction",
    "CrashOnProcessFunction",
    "CrashOnCombineFunction",
    "CrashOnFinalizeFunction",
    "HangOnProcessFunction",
    "LargeStateFunction",
    "OrderedBufferInputFunction",
    "BatchIndexBufferInputFunction",
    "OrderedSourceFunction",
    "BufferEmitWideFunction",
]
101
+
102
+
103
@dataclass(slots=True, frozen=True, kw_only=True)
class SingleTableArguments:
    """Argument set for fixtures whose only input is one table."""

    # Positional argument 0: the table streamed into the function.
    data: Annotated[TableInput, Arg(0, doc="Input table")]
108
+
109
+
110
class EchoFunction(TableInOutGenerator[SingleTableArguments]):
    """Passthrough function that emits each input batch unchanged.

    USE CASE
    --------
    Testing, debugging, or as a no-op placeholder in a pipeline.

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema
    Output: same schema (passthrough), with optional projection and filtering

    PUSHDOWN SUPPORT
    ----------------
    - projection_pushdown: Only returns requested columns
    - filter_pushdown: Filters rows based on pushed-down predicates
    - auto_apply_filters: Automatically applies filters to output batches

    Example:
    -------
    Input:  [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    Output: [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

    """

    class Meta:
        """Metadata for EchoFunction."""

        name = "echo"
        description = "Passthrough function that emits each input batch unchanged"
        categories = ["utility", "debug"]
        tags = {"category": "debug", "type": "passthrough"}
        projection_pushdown = True
        filter_pushdown = True
        auto_apply_filters = True
        examples = [
            FunctionExample(
                sql="SELECT * FROM echo((SELECT * FROM input_table))",
                description="Pass through all rows unchanged",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Produce the output schema (identical to the input schema).

        Raises:
            ValueError: If the bind call carries no input schema.
        """
        input_schema = params.bind_call.input_schema
        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and matches RepeatInputsFunction.on_bind.
        if input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=input_schema)
157
+
158
+
159
class EchoWitnessFunction(TableInOutGenerator[SingleTableArguments]):
    """Integer-output fixture that encodes the post-projection column count.

    Designed to verify that projection pushdown ACTUALLY narrows the
    schema reaching the worker (rather than just relying on DuckDB to
    narrow above the operator). Each emitted row has every column set
    to ``len(params.output_schema)`` — i.e., the worker's observed
    column count after framework projection narrowing.

    With pushdown working:
        ``SELECT a FROM echo_witness((SELECT 1 AS a, 2 AS b, 3 AS c))`` → 1

    Without pushdown (DuckDB requests all columns, narrows above):
        ``SELECT a FROM echo_witness((SELECT 1 AS a, 2 AS b, 3 AS c))`` → 3

    Output schema mirrors input (must be all integer columns for the
    encoding to work). Filter pushdown is intentionally off — this
    fixture only probes projection.
    """

    class Meta:
        name = "echo_witness"
        description = "Emits len(observed_output_schema) per column — projection probe"
        categories = ["test", "pushdown"]
        projection_pushdown = True

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Declare an output schema identical to the input schema.

        Raises:
            ValueError: If the bind call carries no input schema.
        """
        input_schema = params.bind_call.input_schema
        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and matches RepeatInputsFunction.on_bind.
        if input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        params: ProcessParams[SingleTableArguments],
        state: None,
        batch: pa.RecordBatch,
        out: OutputCollector,
    ) -> None:
        """Emit one row per input row, every cell set to the observed column count."""
        target = params.output_schema
        observed = len(target)
        # The input values are ignored; only the row count is carried over.
        cols = {field.name: pa.array([observed] * batch.num_rows, type=field.type) for field in target}
        out.emit(pa.RecordBatch.from_pydict(cols, schema=target))
201
+
202
+
203
class BufferInputFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffering function — collects all input, emits during finalize.

    One bucket per execution: ``process()`` returns ``params.execution_id``
    for every call and appends to a single shared state_log under
    ``(b"buf", b"")``. ``combine()`` collapses every state_id (all are
    identical) to a single finalize_state_id. ``finalize()`` cursor-drains
    one batch per tick.

    Schema:
        Input:  any schema
        Output: same schema (passthrough)
    """

    class Meta:
        name = "buffer_input"
        description = "Collects all input batches and emits during finalization"
        categories = ["utility", "buffer"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM buffer_input((SELECT * FROM input_table))",
                description="Buffer all input and emit on finalize",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Declare an output schema identical to the input schema.

        Raises:
            ValueError: If the bind call carries no input schema.
        """
        input_schema = params.bind_call.input_schema
        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and matches RepeatInputsFunction.on_bind.
        if input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Append the batch to the shared state_log; return execution_id."""
        # Each log entry is a self-contained Arrow IPC stream holding one batch.
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, batch.schema) as writer:
            writer.write_batch(batch)
        params.storage.state_append(b"buf", b"", sink.getvalue().to_pybytes())
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Collapse all state_ids into a single finalize stream."""
        # Every state_id is params.execution_id; collapse to one stream.
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _LogDrainState:
        """Return a cursor positioned before the first entry of the b"buf" log."""
        return _LogDrainState(ns=b"buf", after_id=-1)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit one buffered batch per tick; finish at end-of-log."""
        rows = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if not rows:
            out.finish()
            return
        log_id, value = rows[0]
        out.emit(pa.ipc.open_stream(value).read_next_batch())
        # Advance the cursor so the next call resumes after this entry.
        state.after_id = log_id
284
+
285
+
286
class EchoBufferingFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffered passthrough with projection + filter pushdown enabled.

    Same shape as :class:`BufferInputFunction` (process buffers input,
    finalize drains one batch per tick) but declares all three pushdown
    flags so DuckDB sends ``projection_ids`` / ``pushdown_filters`` on the
    InitRequest. The framework:

    * Narrows ``params.output_schema`` to the projected columns; the
      ``OutputCollector.emit`` call's ``batch.select(target_names)`` then
      drops non-projected columns from the buffered full-width batch.
    * Wraps ``out`` in ``_FilteringOutputCollector`` (because
      ``auto_apply_filters=True``) so emitted batches are filter-applied
      automatically.

    User code stays the streaming-style passthrough — no awareness of
    projection or filters needed. The fixture verifies that buffered
    TableBufferingFunction pushdown actually plumbs through end-to-end.
    """

    class Meta:
        """Metadata for EchoBufferingFunction."""

        name = "echo_buffering"
        description = "Buffered passthrough with projection + filter pushdown"
        categories = ["test", "buffer", "pushdown"]
        projection_pushdown = True
        filter_pushdown = True
        auto_apply_filters = True

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Output schema = input schema (passthrough).

        Raises:
            ValueError: If the bind call carries no input schema.
        """
        input_schema = params.bind_call.input_schema
        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and matches RepeatInputsFunction.on_bind.
        if input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Buffer the full input batch (no projection at storage time)."""
        # Each log entry is a self-contained Arrow IPC stream holding one batch.
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, batch.schema) as writer:
            writer.write_batch(batch)
        params.storage.state_append(b"buf", b"", sink.getvalue().to_pybytes())
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],  # noqa: ARG003 - collapse to one finalize stream
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Collapse all state_ids into a single finalize stream."""
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,  # noqa: ARG003 - one bucket per execution
        params: TableBufferingParams[SingleTableArguments],  # noqa: ARG003
    ) -> _LogDrainState:
        """Return a cursor positioned before the first entry of the b"buf" log."""
        return _LogDrainState(ns=b"buf", after_id=-1)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,  # noqa: ARG003
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit one buffered batch per tick — framework narrows + filters."""
        rows = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if not rows:
            out.finish()
            return
        log_id, value = rows[0]
        out.emit(pa.ipc.open_stream(value).read_next_batch())
        # Advance the cursor so the next call resumes after this entry.
        state.after_id = log_id
372
+
373
+
374
class FilterBySettingFunction(TableInOutGenerator[SingleTableArguments]):
    """Filters input rows where the value column meets a threshold setting.

    USE CASE
    --------
    Demonstrates how table-in-out functions can use DuckDB settings to control
    behavior. The threshold setting determines which rows pass through: only
    rows where the "value" column >= threshold are emitted.

    The Setting() on on_bind() serves solely to register ``threshold`` in
    required_settings metadata. The actual filtering uses params.settings
    in process().

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema (must contain a "value" column)
    Output: same schema (rows filtered by threshold)

    Example:
    -------
    With threshold=5 and input [{"value": 3}, {"value": 7}]:
    Output: [{"value": 7}]

    """

    class Meta:
        """Metadata for FilterBySettingFunction."""

        name = "filter_by_setting"
        description = "Filter rows where value column >= threshold setting"
        categories = ["transform", "settings"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM filter_by_setting((SELECT * FROM input_table))",
                description="Filter rows using the threshold setting",
            )
        ]

    @classmethod
    def on_bind(
        cls,
        params: BindParams[SingleTableArguments],
        *,
        threshold: Annotated[pa.Scalar[Any] | None, Setting()] = None,
    ) -> BindResponse:
        """Return input schema unchanged. Threshold declared for required_settings.

        Raises:
            ValueError: If the bind call carries no input schema.
        """
        input_schema = params.bind_call.input_schema
        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and matches RepeatInputsFunction.on_bind.
        if input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        params: ProcessParams[SingleTableArguments],
        state: None,
        batch: pa.RecordBatch,
        out: OutputCollector,
    ) -> None:
        """Filter rows where value >= threshold."""
        raw_threshold = params.settings["threshold"]
        # Cast to column type for compatibility (C++ extension may send as string)
        col = batch.column("value")
        # NOTE(review): int() truncates a fractional threshold and rejects a
        # string such as "5.5"; acceptable for integer "value" columns only.
        threshold = pa.scalar(int(raw_threshold.as_py()), type=col.type)
        mask = pc.greater_equal(col, threshold)
        out.emit(batch.filter(mask))
438
+
439
+
440
@dataclass(slots=True, frozen=True)
class RepeatsInputsFunctionArguments:
    """Argument set for RepeatInputsFunction: a repeat count followed by a table."""

    # Positional argument 0: how many copies of each batch to produce.
    repeat_count: Annotated[int, Arg(0, doc="Number of times to repeat each input batch")]
    # Positional argument 1: the table whose batches are repeated.
    data: Annotated[TableInput, Arg(1, doc="Input table to repeat")]
446
+
447
+
448
class RepeatInputsFunction(TableInOutGenerator[RepeatsInputsFunctionArguments]):
    """Explosion function that duplicates each input batch N times.

    USE CASE
    --------
    Data augmentation, testing with larger datasets, or any scenario where
    you need multiple copies of each input record.

    Arguments:
    ---------
    repeat_count: Annotated[int, Arg(0)] (required)
        Number of times to repeat each input batch.

    BEHAVIOR
    --------
    - output_schema: Returns input schema unchanged
    - process(): For each input, concatenates it N times into one output
    - A zero-row input batch is emitted unchanged

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema
    Output: same schema (passthrough)

    Example:
    -------
    With repeat_count=3:
    Input:  [{"a": 1}]
    Output: [{"a": 1}, {"a": 1}, {"a": 1}]

    """

    class Meta:
        """Metadata for RepeatInputsFunction."""

        name = "repeat_inputs"
        description = "Duplicates each input batch N times"
        categories = ["transform", "augmentation"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM repeat_inputs(3, (SELECT * FROM input_table))",
                description="Repeat each row 3 times",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[RepeatsInputsFunctionArguments]) -> BindResponse:
        """Validate repeat count argument.

        Raises:
            ValueError: If ``repeat_count`` is below 1 or no input schema
                was supplied.
        """
        if params.args.repeat_count < 1:
            raise ValueError("Repeat count must be at least 1")
        if params.bind_call.input_schema is None:
            raise ValueError("input_schema is required but was None")
        return BindResponse(output_schema=params.bind_call.input_schema)

    @classmethod
    def process(
        cls,
        params: ProcessParams[RepeatsInputsFunctionArguments],
        state: None,
        batch: pa.RecordBatch,
        out: OutputCollector,
    ) -> None:
        """Emit input batch concatenated repeat_count times."""
        # N copies of zero rows is still zero rows. Short-circuit because
        # Table.to_batches() can return an empty list for a zero-row table,
        # which would make the [0] index below raise IndexError.
        if batch.num_rows == 0:
            out.emit(batch)
            return
        combined = pa.Table.from_batches([batch] * params.args.repeat_count).combine_chunks()
        out.emit(combined.to_batches()[0])
512
+
513
+
514
@dataclass(slots=True, frozen=True, kw_only=True)
class SumAllColumnsFunctionArguments:
    """Argument set for SumAllColumnsFunction: one table plus a logging switch."""

    # Positional argument 0: the table whose numeric columns are summed.
    data: Annotated[TableInput, Arg(0, doc="Input table")]
    # Named argument "logging": when true, process() and combine() send
    # client log messages.
    logging: Annotated[bool, Arg("logging", doc="Whether to log during processing", default=False)] = False
520
+
521
+
522
@dataclass(kw_only=True)
class SumAllColumnsState(ArrowSerializableDataclass):
    """Serializable wrapper around one record batch of per-column partial sums.

    SumAllColumnsFunction serializes an instance per processed batch and
    appends it to its state log; combine() deserializes them again.
    """

    # Record batch of sums, carried on the wire as a binary value.
    partial_sums: Annotated[pa.RecordBatch, ArrowType(pa.binary())]
527
+
528
+
529
class SumAllColumnsFunction(TableBufferingFunction[SumAllColumnsFunctionArguments, _LogDrainState]):
    """Aggregation function that computes column-wise sums across all batches.

    USE CASE
    --------
    Computing totals, aggregating metrics, or any full-stream reduction
    that produces a single summary row.

    BEHAVIOR
    --------
    - process():  Sums each output column of one batch and appends that
                  partial row to the b"partial" state log
    - combine():  Adds every logged partial into one merged row and
                  appends it to the b"buf" state log
    - finalize(): Drains b"buf" one batch per call, then finishes

    SCHEMA TRANSFORMATION
    ---------------------
    Input:  any schema with numeric columns
    Output: only numeric columns, declared as int64/float64

    For each input column:
    - Integer types  -> int64
    - Floating types -> float64
    - Decimal types  -> float64
    - Non-numeric types -> excluded from output

    Binding fails with ValueError when no numeric column is present.

    KEY PATTERN: APPEND IN process(), REDUCE IN combine(), EMIT IN finalize()
    -------------------------------------------------------------------------
    process() never emits rows; it only records partial sums.
    combine() reduces them, and finalize() returns the aggregated result.

    Example:
    -------
    Input schema:  {"a": int32, "b": float32, "name": string}
    Output schema: {"a": int64, "b": float64}  (string column excluded)

    Input batches:
        [{"a": 1, "b": 1.5, "name": "x"}, {"a": 2, "b": 2.5, "name": "y"}]
        [{"a": 3, "b": 3.0, "name": "z"}]

    Output (single row):
        [{"a": 6, "b": 7.0}]

    """

    class Meta:
        name = "sum_all_columns"
        description = "Computes column-wise sums across all batches"
        categories = ["aggregation", "numeric"]
        examples = [
            FunctionExample(
                sql="SELECT * FROM sum_all_columns((SELECT * FROM input_table))",
                description="Sum all numeric columns",
            )
        ]

    @classmethod
    def cardinality(cls, params: BindParams[SumAllColumnsFunctionArguments]) -> TableCardinality:
        """Return cardinality estimate of exactly 1 row."""
        return TableCardinality(estimate=1, max=1)

    @classmethod
    def on_bind(cls, params: BindParams[SumAllColumnsFunctionArguments]) -> BindResponse:
        """Produce the output schema with only numeric columns.

        Numeric here means integer, floating-point, or fixed-precision
        decimal. DECIMAL inputs are promoted to float64 in the output
        (matching the float path) — DuckDB users routinely expect a sum
        of DECIMAL values to be summable, so silently dropping them would
        be surprising. Non-numeric inputs (strings, lists, timestamps,
        booleans) are filtered out; if NO numeric columns remain we raise
        ValueError at bind time rather than producing an empty output
        schema (which would crash downstream with an internal assertion).

        Raises:
            ValueError: If no input schema was supplied, or if the input
                has no integer, floating-point, or decimal column.
        """
        input_schema = params.bind_call.input_schema
        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and matches RepeatInputsFunction.on_bind.
        if input_schema is None:
            raise ValueError("input_schema is required but was None")

        output_fields: dict[str, pa.DataType] = {}
        for field in input_schema:
            out_type: pa.DataType
            if pa.types.is_integer(field.type):
                out_type = pa.int64()
            elif pa.types.is_floating(field.type):
                out_type = pa.float64()
            elif pa.types.is_decimal(field.type):
                # Promote DECIMAL to float64 for the summed output.
                # (A more precise implementation would widen the decimal
                # type to absorb sum overflow, but for a test fixture
                # this is sufficient.)
                out_type = pa.float64()
            else:
                continue
            output_fields[field.name] = out_type

        if not output_fields:
            input_summary = ", ".join(f"{f.name}: {f.type}" for f in input_schema)
            raise ValueError(
                "sum_all_columns requires at least one numeric (integer, "
                "floating-point, or decimal) input column, got [" + input_summary + "]"
            )

        return BindResponse(output_schema=schema(output_fields))

    @staticmethod
    def _scalars_to_single_row_batch(values: dict[str, pa.Scalar]) -> pa.RecordBatch:  # type: ignore[type-arg]
        """Build a one-row record batch with one column per entry of ``values``.

        Each column takes its type from the scalar that populates it.
        """
        arrays = [pa.array([scalar], type=scalar.type) for scalar in values.values()]
        return pa.RecordBatch.from_arrays(arrays, names=list(values.keys()))

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
    ) -> bytes:
        """Append this batch's partial sums to the per-execution log.

        Race-safe append (state_append is atomic). combine() reduces.
        """
        if params.args.logging:
            # Goes through the same wire mechanism as the streaming
            # ``out.client_log()`` path — emits a 0-row log batch on the
            # ``table_buffering_process`` response stream that DuckDB
            # surfaces in ``duckdb_logs()`` with type='VGI'. The Python
            # stdlib ``logging.getLogger(...).info(...)`` we used before
            # didn't reach the wire and never showed up in duckdb_logs
            # (the framework provides no stdlib-logging-to-wire bridge).
            params.client_log(
                Level.INFO,
                f"Processing batch with {batch.num_rows} rows",
            )

        # Compute partial sums for this batch only.
        sums: dict[str, pa.Scalar[Any]] = {}
        for name in params.output_schema.names:
            col_sum = pc.sum(batch.column(name))
            if col_sum.is_valid:
                sums[name] = col_sum
            else:
                # A null sum is recorded as a typed zero so combine() can
                # add it like any other partial.
                sums[name] = pa.scalar(
                    0,
                    type=params.output_schema.field(name).type,
                )
        partial = cls._scalars_to_single_row_batch(sums)
        params.storage.state_append(
            b"partial",
            b"",
            SumAllColumnsState(partial_sums=partial).serialize_to_bytes(),
        )
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
    ) -> list[bytes]:
        """Reduce all per-batch partials into one merged batch.

        combine() runs once on the coordinator after every process()
        completes — no race here. Drains the append-only log, sums, and
        writes the merged row to b"buf"/b"" for finalize to drain.

        Empty-input guard: even with no state_ids, writes the zeros row
        so ``SELECT ... FROM sum_all_columns((SELECT 1 WHERE 1=0))``
        produces one row of the expected shape.
        """
        if params.args.logging:
            # Symmetric with process() — fires through ``params.client_log``
            # (the unary-RPC analogue of ``out.client_log``) so the message
            # lands in DuckDB's ``duckdb_logs()`` with type='VGI'. Used by
            # ``logging.test`` to verify the in-band log path works from
            # ``combine()`` too, not just from ``process()``.
            params.client_log(
                Level.INFO,
                f"Combining {len(state_ids)} state_ids",
            )

        # Start from a typed zero per output column.
        merged: dict[str, pa.Scalar[Any]] = {
            field.name: pa.scalar(0, type=field.type) for field in params.output_schema
        }
        for _log_id, blob in params.storage.state_log_scan(b"partial", b""):
            partial = SumAllColumnsState.deserialize_from_bytes(blob).partial_sums
            for name in params.output_schema.names:
                merged[name] = pc.add(merged[name], partial.column(name)[0])
        merged_batch = cls._scalars_to_single_row_batch(merged)
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, merged_batch.schema) as w:
            w.write_batch(merged_batch)
        params.storage.state_append(b"buf", b"", sink.getvalue().to_pybytes())
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
    ) -> _LogDrainState:
        """Return a cursor positioned before the first entry of the b"buf" log."""
        return _LogDrainState(ns=b"buf", after_id=-1)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit the next batch from the b"buf" log, or finish at end-of-log."""
        rows = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if not rows:
            out.finish()
            return
        log_id, value = rows[0]
        out.emit(pa.ipc.open_stream(value).read_next_batch())
        # Advance the cursor so the next call resumes after this entry.
        state.after_id = log_id
748
+
749
+
750
@dataclass(kw_only=True)
class ExceptionProcessState(ArrowSerializableDataclass):
    """Serializable state record associated with ExceptionProcessFunction.

    NOTE(review): ExceptionProcessFunction.process() as written counts
    batches through the storage log and does not read this field — confirm
    whether this dataclass is still used anywhere.
    """

    # Number of batches seen; starts at zero.
    batch_count: int = 0
755
+
756
+
757
class ExceptionProcessFunction(SumAllColumnsFunction):
    """Buffered table function whose process() fails on even-numbered batches.

    In practice the first failure happens on the second batch.
    """

    class Meta(SumAllColumnsFunction.Meta):
        name = "exception_process"
        description = "Test function that raises exception during process"
        categories = ["test", "error"]

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
    ) -> bytes:
        """Record one call in the b"count" log and fail when the count is even.

        Each call appends an empty entry to b"count"/b"" and then takes the
        number of entries in that log as the batch count. The original
        author describes the append as race-safe because concurrent calls
        serialize through state_append's id minting.

        Returns:
            ``params.execution_id`` when the observed count is odd.

        Raises:
            ValueError: when the observed count is even (2, 4, ...).
        """
        storage = params.storage
        storage.state_append(b"count", b"", b"")
        count = len(storage.state_log_scan(b"count", b""))
        if count % 2:
            # Odd-numbered batch: succeed.
            return params.execution_id
        raise ValueError(f"Intentional exception on batch {count}")
782
+
783
+
784
class ExceptionFinalizeFunction(SumAllColumnsFunction):
    """Variant of SumAllColumnsFunction whose finalize() always fails."""

    class Meta(SumAllColumnsFunction.Meta):
        name = "exception_finalize"
        description = "Test function that raises exception during finalize"
        categories = ["test", "error"]

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SumAllColumnsFunctionArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Unconditionally raise; no argument is inspected.

        Raises:
            ValueError: always.
        """
        failure = ValueError("Intentional exception during finalize()")
        raise failure
801
+
802
+
803
@dataclass(slots=True, kw_only=True)
class SumAllColumnsSimpleDistributedState(ArrowSerializableDataclass):
    """Per-worker running totals for SumAllColumnsSimpleDistributed."""

    # Record batch holding the running sums; annotated to serialize as an
    # Arrow binary value.
    partial_sum: Annotated[pa.RecordBatch, ArrowType(pa.binary())]
808
+
809
+
810
class SumAllColumnsSimpleDistributed(TableInOutFunction[SingleTableArguments, SumAllColumnsSimpleDistributedState]):
    """Distributed aggregation using the simple callback API.

    This function demonstrates TableInOutFunction with distributed
    state management.

    It's equivalent to SumAllColumnsFunctionDistributed but uses
    the simpler callback API.

    Example:
    -------
        Input batches (split across workers):
            Worker 1: [{a: 1, b: 1.0}, {a: 2, b: 2.0}]
            Worker 2: [{a: 3, b: 3.0}]

        Each worker computes partial sums:
            Worker 1 state: {a: 3, b: 3.0}
            Worker 2 state: {a: 3, b: 3.0}

        Primary worker merges states in finish():
            Combined: {a: 6, b: 6.0}

        Output (single row):
            [{a: 6, b: 6.0}]

    """

    class Meta:
        """Metadata for SumAllColumnsSimpleDistributed."""

        name = "sum_all_columns_simple_distributed"
        description = "Distributed sum using simple callback API"
        categories = ["aggregation", "numeric", "distributed"]
        examples = [
            FunctionExample(
                sql=("SELECT * FROM sum_all_columns_simple_distributed((SELECT * FROM input_table))"),
                description="Sum columns using distributed workers with callback API",
            )
        ]

    @classmethod
    def cardinality(cls, params: BindParams[SingleTableArguments]) -> TableCardinality:
        """Return cardinality estimate of exactly 1 row."""
        return TableCardinality(estimate=1, max=1)

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Produce the output schema with only numeric columns.

        Integer input columns map to int64, floating-point columns map to
        float64, and every other column is dropped.
        """
        assert params.bind_call.input_schema is not None
        output_fields: dict[str, pa.DataType] = {}
        for field in params.bind_call.input_schema:
            out_type: pa.DataType
            if pa.types.is_integer(field.type):
                out_type = pa.int64()
            elif pa.types.is_floating(field.type):
                out_type = pa.float64()
            else:
                continue
            output_fields[field.name] = out_type

        return BindResponse(output_schema=schema(output_fields))

    @classmethod
    def initial_state(cls, params: ProcessParams[SingleTableArguments]) -> SumAllColumnsSimpleDistributedState | None:
        """Create the initial state: a single row of zeros in the output schema."""
        zero_row = dict.fromkeys(params.output_schema.names, 0)
        return SumAllColumnsSimpleDistributedState(
            partial_sum=pa.RecordBatch.from_pylist([zero_row], schema=params.output_schema)
        )

    @classmethod
    def transform(
        cls,
        batch: pa.RecordBatch,
        params: ProcessParams[SingleTableArguments],
        state: SumAllColumnsSimpleDistributedState | None,
    ) -> pa.RecordBatch:
        """Accumulate column sums. Emit nothing during processing.

        Raises:
            ValueError: if ``state`` is None.
        """
        if state is None:
            raise ValueError("State must not be None in transform()")
        # Add this batch's values to running sums
        sums: dict[str, pa.Scalar[Any]] = {}
        for name in params.output_schema.names:
            running = state.partial_sum.column(name)[0]
            col_sum = pc.sum(batch.column(name))
            # A null sum (empty or all-null column) leaves the total unchanged.
            sums[name] = pc.add(running, col_sum) if col_sum.is_valid else running

        state.partial_sum = pa.RecordBatch.from_pylist(
            [sums],
            schema=params.output_schema,
        )

        return empty_batch(params.output_schema)

    @classmethod
    def finish(
        cls,
        params: ProcessParams[SingleTableArguments],
        states: list[SumAllColumnsSimpleDistributedState],
    ) -> list[pa.RecordBatch]:
        """Emit single row containing the column sums.

        Every worker's ``partial_sum`` row is stacked into one table and each
        output column is summed. A column whose sum is null (no states at
        all, or only null partials) is reported as a typed zero so the
        result always has exactly one fully populated row.
        """
        partials = [state.partial_sum for state in states]
        # Table.from_batches cannot infer a schema from an empty list, so
        # fall back to an empty table of the output schema in that case.
        table = pa.Table.from_batches(partials) if partials else params.output_schema.empty_table()

        sums: dict[str, pa.Scalar[Any]] = {}
        for field in params.output_schema:
            total = pc.sum(table.column(field.name))
            sums[field.name] = total if total.is_valid else pa.scalar(0, type=field.type)

        return [pa.RecordBatch.from_pylist([sums], schema=params.output_schema)]
924
+
925
+
926
+ # ============================================================================
927
+ # Failure-injection fixtures
928
+ # ============================================================================
929
+ # These are table_buffering functions designed to exercise crash/error paths in
930
+ # the C++ Sink+Source operator. Each one fails in a specific phase so we can
931
+ # test that the operator: throws cleanly, drains the worker pool, doesn't leak
932
+ # in-flight workers, and recovers on the next query.
933
+
934
+
935
class CrashOnProcessFunction(BufferInputFunction):
    """Kills its own worker process as soon as process() is called."""

    class Meta(BufferInputFunction.Meta):
        name = "crash_on_process"
        description = "Worker SIGKILLs itself during process (test)"
        categories = ["test", "crash"]

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Send a fatal signal to the current process.

        SIGKILL is used everywhere except on Windows, where SIGABRT is sent
        instead as the hard-crash equivalent.
        """
        own_pid = os.getpid()
        if sys.platform != "win32":
            os.kill(own_pid, signal.SIGKILL)
        else:  # pragma: no cover - hard-crash equivalent
            os.kill(own_pid, signal.SIGABRT)
        return params.execution_id  # unreachable
954
+
955
+
956
class CrashOnCombineFunction(BufferInputFunction):
    """Inherits buffering from BufferInputFunction; combine() always raises."""

    class Meta(BufferInputFunction.Meta):
        name = "crash_on_combine"
        description = "Worker raises during combine (test)"
        categories = ["test", "crash"]

    @classmethod
    def combine(cls, state_ids: list[bytes], params: TableBufferingParams[SingleTableArguments]) -> list[bytes]:
        """Unconditionally raise; no argument is inspected.

        Raises:
            RuntimeError: always.
        """
        failure = RuntimeError("Intentional exception during combine()")
        raise failure
967
+
968
+
969
class CrashOnFinalizeFunction(BufferInputFunction):
    """Inherits process/combine from BufferInputFunction; finalize() always raises."""

    class Meta(BufferInputFunction.Meta):
        name = "crash_on_finalize"
        description = "Worker raises during finalize (test)"
        categories = ["test", "crash"]

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Unconditionally raise on the very first call.

        Raises:
            ValueError: always.
        """
        failure = ValueError("Intentional exception during finalize()")
        raise failure
986
+
987
+
988
class HangOnProcessFunction(BufferInputFunction):
    """Blocks in process() for one hour; meant for the manual cancellation smoke test."""

    class Meta(BufferInputFunction.Meta):
        name = "hang_on_process"
        description = "Worker sleeps in process (manual cancel test)"
        categories = ["test", "hang"]

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Sleep for 3600 seconds, ignoring the input batch."""
        one_hour_seconds = 3600
        time.sleep(one_hour_seconds)
        return params.execution_id  # unreachable
1004
+
1005
+
1006
@dataclass(kw_only=True)
class LargeStateState(ArrowSerializableDataclass):
    """Serializable wrapper around one large bytes buffer.

    According to the original author, the C++ side carries this through
    combine/finalize inside the RPC's outer envelope, which exercises IPC
    chunking on the response path.

    NOTE(review): LargeStateFunction as written appends raw bytes to the
    storage log and does not construct this class — confirm it is still used.
    """

    # The buffered payload; empty by default.
    payload: bytes = b""
1015
+
1016
+
1017
class LargeStateFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffers 1 MB per input batch and reports the total size at the end.

    process() appends a 1 MiB zero-filled blob to the b"large" log for each
    call. combine() adds up the lengths of all blobs in that log and writes
    one row (every output column set to that total) to the b"buf" log;
    finalize() drains b"buf" one entry per call using a cursor.
    """

    class Meta:
        name = "large_state"
        description = "Buffers ~1 MB per input batch into state (IPC test)"
        categories = ["test", "memory"]

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Use the input schema unchanged as the output schema."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Append one 1 MiB zero-filled blob to the b"large" log.

        The batch contents are ignored. The original author describes the
        append as race-safe; combine() later totals the payload sizes by
        scanning the log.
        """
        one_mib_of_zeros = bytes(1024 * 1024)
        params.storage.state_append(b"large", b"", one_mib_of_zeros)
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Write one output row in which every column holds the total payload size."""
        total = 0
        for _entry_id, payload in params.storage.state_log_scan(b"large", b""):
            total += len(payload)

        columns = {}
        for name in params.output_schema.names:
            columns[name] = [total]
        out_batch = pa.RecordBatch.from_pydict(columns, schema=params.output_schema)

        # Serialize the row as an IPC stream for finalize() to decode.
        buffer = pa.BufferOutputStream()
        with pa.ipc.new_stream(buffer, out_batch.schema) as ipc_writer:
            ipc_writer.write_batch(out_batch)
        serialized = buffer.getvalue().to_pybytes()
        params.storage.state_append(b"buf", b"", serialized)
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _LogDrainState:
        """Return the starting cursor (``after_id=-1``) for the b"buf" log."""
        cursor = _LogDrainState(after_id=-1, ns=b"buf")
        return cursor

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit at most one buffered batch per call, or finish when none is left."""
        pending = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if pending:
            entry_id, payload = pending[0]
            reader = pa.ipc.open_stream(payload)
            out.emit(reader.read_next_batch())
            # Advance the cursor past the entry just emitted.
            state.after_id = entry_id
        else:
            out.finish()
            return
1099
+
1100
+
1101
+ # ============================================================================
1102
+ # Ordering knobs (sink_order_dependent, requires_input_batch_index)
1103
+ # ============================================================================
1104
+
1105
+
1106
class OrderedBufferInputFunction(BufferInputFunction):
    """BufferInputFunction variant that requests single-threaded ingest.

    The only difference from the parent is ``Meta.sink_order_dependent=True``.
    Per the original author's description, this forces ``ParallelSink=false``
    on the C++ operator, so every ``process()`` call reaches the same worker
    in source order. Checking that is left to the integration test, which
    asserts that the distinct ``conn=`` count is exactly 1.

    All behaviour is inherited, so the output is the same passthrough of
    buffered rows as ``BufferInputFunction``. The original note adds that a
    single Sink thread means a single state_id, with combine returning
    ``[0]`` and finalize yielding the buffer.
    """

    class Meta(BufferInputFunction.Meta):
        name = "ordered_buffer_input"
        description = "buffer_input variant with sink_order_dependent=True"
        categories = ["test", "ordering"]
        # Request ordered, non-parallel Sink ingestion.
        sink_order_dependent = True
1124
+
1125
+
1126
+ def _pack_indexed_batch(batch_index: int, batch_bytes: bytes) -> bytes:
1127
+ """Pack (batch_index, batch_bytes) into a single appendable blob.
1128
+
1129
+ Layout: 8 bytes little-endian signed batch_index || raw IPC stream bytes.
1130
+ Used by BatchIndexBufferInputFunction to thread per-batch ordering keys
1131
+ through the append-only state_log without an extra ArrowSerializableDataclass
1132
+ round-trip.
1133
+ """
1134
+ return batch_index.to_bytes(8, "little", signed=True) + batch_bytes
1135
+
1136
+
1137
+ def _unpack_indexed_batch(blob: bytes) -> tuple[int, bytes]:
1138
+ """Inverse of _pack_indexed_batch."""
1139
+ return int.from_bytes(blob[:8], "little", signed=True), blob[8:]
1140
+
1141
+
1142
class BatchIndexBufferInputFunction(TableBufferingFunction[SingleTableArguments, _LogDrainState]):
    """Buffered passthrough that restores input order from ``batch_index``.

    ``Meta.requires_input_batch_index=True`` is set so that, per the original
    author's description, the C++ operator declares
    ``RequiredPartitionInfo()=BatchIndex()`` and passes DuckDB's per-chunk
    batch_index to every ``process()`` call. process() stores
    (batch_index, ipc_bytes) blobs in the b"unsorted" log; combine() sorts
    all of them by batch_index and writes the payloads, in order, to the
    b"buf" log; finalize() drains b"buf" one entry per call.
    """

    class Meta:
        name = "batch_index_buffer_input"
        description = "buffer_input variant using batch_index to reconstruct order"
        categories = ["test", "ordering"]
        requires_input_batch_index = True

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Use the input schema unchanged as the output schema."""
        input_schema = params.bind_call.input_schema
        assert input_schema is not None
        return BindResponse(output_schema=input_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Serialize ``batch`` and log it together with its batch index.

        Raises:
            RuntimeError: if ``params.batch_index`` is None.
        """
        if params.batch_index is None:
            raise RuntimeError(
                "batch_index_buffer_input.process() received batch_index=None "
                "— Meta.requires_input_batch_index plumbing is broken"
            )
        buffer = pa.BufferOutputStream()
        with pa.ipc.new_stream(buffer, batch.schema) as ipc_writer:
            ipc_writer.write_batch(batch)
        ipc_bytes = buffer.getvalue().to_pybytes()
        # Append-only, which the original author notes is race-safe under
        # concurrent Sink threads; global ordering is restored in combine().
        indexed_blob = _pack_indexed_batch(params.batch_index, ipc_bytes)
        params.storage.state_append(b"unsorted", b"", indexed_blob)
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Order all logged batches by batch_index and rewrite them as one log."""
        logged = params.storage.state_log_scan(b"unsorted", b"")
        decoded: list[tuple[int, bytes]] = [_unpack_indexed_batch(payload) for _entry_id, payload in logged]
        # sorted() is stable, so entries with equal indices keep log order.
        for _batch_index, ipc_bytes in sorted(decoded, key=lambda pair: pair[0]):
            params.storage.state_append(b"buf", b"", ipc_bytes)
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _LogDrainState:
        """Return the starting cursor (``after_id=-1``) for the b"buf" log."""
        cursor = _LogDrainState(after_id=-1, ns=b"buf")
        return cursor

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _LogDrainState,
        out: OutputCollector,
    ) -> None:
        """Emit at most one buffered batch per call, or finish when none is left."""
        pending = params.storage.state_log_scan(
            state.ns,
            b"",
            after_id=state.after_id,
            limit=1,
        )
        if pending:
            entry_id, payload = pending[0]
            reader = pa.ipc.open_stream(payload)
            out.emit(reader.read_next_batch())
            # Advance the cursor past the entry just emitted.
            state.after_id = entry_id
        else:
            out.finish()
            return
1230
+
1231
+
1232
@dataclass
class _OneShotState(ArrowSerializableDataclass):
    """Cursor for ``OrderedSourceFunction.finalize`` that allows one emit only."""

    # Integer to emit.
    value: int = 0
    # Set once the row has been emitted.
    emitted: bool = False
1238
+
1239
+
1240
class OrderedSourceFunction(TableBufferingFunction[SingleTableArguments, _OneShotState]):
    """Buffered table function declaring ``source_order_dependent=True``.

    Per the original author's description, the flag forces
    ``ParallelSource()=false`` and ``SourceOrder()=FIXED_ORDER`` on the C++
    ``PhysicalVgiTableBufferingFunction``, so the Source phase drains
    ``finalize_queue`` serially in the order ``combine()`` filled it. Without
    the flag, parallel Source drains would race and emit rows in arbitrary
    order.

    The fixture ignores its input on purpose and produces a fixed 0..15
    sequence, which keeps the assertion deterministic whatever the Sink
    parallelism or input partitioning. ``combine()`` returns sixteen
    finalize_state_ids, each a 4-byte big-endian integer, in ascending
    order. ``finalize()`` emits exactly one row per state holding the
    integer decoded from its state_id. With ``source_order_dependent`` the
    rows are expected back in that same 0..15 order.

    Output schema: single ``v`` column (BIGINT).
    """

    class Meta:
        name = "ordered_source"
        description = "Emits a fixed 0..15 sequence via source_order_dependent=True; input is ignored"
        categories = ["test", "ordering"]
        source_order_dependent = True

    # Number of finalize states (and therefore output rows) produced.
    _N_ROWS = 16

    @classmethod
    def on_bind(cls, params: BindParams[SingleTableArguments]) -> BindResponse:
        """Declare the fixed single-column output schema ``v: int64``."""
        output_schema = schema(v=pa.int64())
        return BindResponse(output_schema=output_schema)

    @classmethod
    def process(
        cls,
        batch: pa.RecordBatch,
        params: TableBufferingParams[SingleTableArguments],
    ) -> bytes:
        """Discard the batch; the test checks source ordering, not data."""
        return params.execution_id

    @classmethod
    def combine(
        cls,
        state_ids: list[bytes],
        params: TableBufferingParams[SingleTableArguments],
    ) -> list[bytes]:
        """Return 0.._N_ROWS-1 as ascending 4-byte big-endian state ids.

        A FIXED_ORDER Source is expected to drain finalize_queue in exactly
        this order.
        """
        ordered_ids: list[bytes] = []
        for i in range(cls._N_ROWS):
            ordered_ids.append(i.to_bytes(4, "big"))
        return ordered_ids

    @classmethod
    def initial_finalize_state(
        cls,
        finalize_state_id: bytes,
        params: TableBufferingParams[SingleTableArguments],
    ) -> _OneShotState:
        """Decode the big-endian state id into the value to emit."""
        decoded = int.from_bytes(finalize_state_id, "big")
        return _OneShotState(value=decoded)

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[SingleTableArguments],
        finalize_state_id: bytes,
        state: _OneShotState,
        out: OutputCollector,
    ) -> None:
        """Emit the state's single row on the first call; finish on the next."""
        if not state.emitted:
            row = {"v": state.value}
            out.emit(pa.RecordBatch.from_pylist([row], schema=params.output_schema))
            state.emitted = True
        else:
            out.finish()
            return
1312
+
1313
+
1314
+ # ---------------------------------------------------------------------------
1315
+ # Repro fixture: emit a single large finalize batch from a buffering function.
1316
+ # ---------------------------------------------------------------------------
1317
+
1318
+
1319
@dataclass(slots=True, frozen=True, kw_only=True)
class BufferEmitWideArguments:
    """Immutable argument record for BufferEmitWideFunction."""

    # Positional argument 0: non-negative row count for the one finalize batch.
    rows: Annotated[int, Arg(0, doc="Number of rows to emit in one finalize batch", ge=0)]
    # Positional argument 1: input table whose content is not used.
    data: Annotated[TableInput, Arg(1, doc="Input table (content ignored)")]
1325
+
1326
+
1327
@dataclass
class _EmitOnceState(ArrowSerializableDataclass):
    """Finalize cursor recording whether the one output batch was already sent."""

    # Flipped to True after the batch has been emitted.
    emitted: bool = False
1332
+
1333
+
1334
# Fixed output schema of BufferEmitWideFunction: a single int64 column "n".
_BUFFER_EMIT_WIDE_SCHEMA = schema(n=pa.int64())
1335
+
1336
+
1337
class BufferEmitWideFunction(TableBufferingFunction[BufferEmitWideArguments, _EmitOnceState]):
    """Buffering function that emits exactly one ``rows``-row batch from finalize.

    BufferInputFunction echoes its input batches, each of which is, per the
    original note, already capped at DuckDB's standard vector size. This
    fixture instead produces a single output batch of arbitrary size from
    ``finalize``. It serves as a minimal repro for whether the buffering
    Source path copes with output batches larger than the standard vector
    size (2048 rows) — something the original author states a regular
    TableFunctionGenerator (e.g. sequence) does support.
    """

    class Meta:
        """Metadata for BufferEmitWideFunction."""

        name = "buffer_emit_wide"
        description = "Emit a single finalize batch of N rows (vector-size repro)"
        categories = ["test", "buffer"]
        # NOTE(review): the example passes the table first and the row count
        # second, whereas BufferEmitWideArguments declares rows at Arg(0) and
        # data at Arg(1) — confirm the SQL argument order.
        examples = [
            FunctionExample(
                sql="SELECT count(*) FROM buffer_emit_wide((SELECT 1), 10000)",
                description="Emit a single 10000-row batch from the Source phase",
            )
        ]

    @classmethod
    def on_bind(cls, params: BindParams[BufferEmitWideArguments]) -> BindResponse:
        """Declare the fixed single-column output schema."""
        response = BindResponse(output_schema=_BUFFER_EMIT_WIDE_SCHEMA)
        return response

    @classmethod
    def process(cls, batch: pa.RecordBatch, params: TableBufferingParams[BufferEmitWideArguments]) -> bytes:
        """Ignore the batch and return the execution id."""
        return params.execution_id

    @classmethod
    def combine(cls, state_ids: list[bytes], params: TableBufferingParams[BufferEmitWideArguments]) -> list[bytes]:
        """Return a single finalize state id: the execution id."""
        return [params.execution_id]

    @classmethod
    def initial_finalize_state(
        cls, finalize_state_id: bytes, params: TableBufferingParams[BufferEmitWideArguments]
    ) -> _EmitOnceState:
        """Start with nothing emitted."""
        fresh = _EmitOnceState(emitted=False)
        return fresh

    @classmethod
    def finalize(
        cls,
        params: TableBufferingParams[BufferEmitWideArguments],
        finalize_state_id: bytes,
        state: _EmitOnceState,
        out: OutputCollector,
    ) -> None:
        """Emit one batch with ``n`` = 0..rows-1 on the first call; finish afterwards."""
        if not state.emitted:
            n = params.args.rows
            values = list(range(n))
            out.emit(pa.RecordBatch.from_pydict({"n": values}, schema=_BUFFER_EMIT_WIDE_SCHEMA))
            state.emitted = True
        else:
            out.finish()
            return