vgi-python 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vgi/__init__.py +152 -0
- vgi/_duckdb.py +62 -0
- vgi/_storage_profile.py +132 -0
- vgi/_test_fixtures/__init__.py +20 -0
- vgi/_test_fixtures/accumulate/__init__.py +19 -0
- vgi/_test_fixtures/accumulate/worker.py +762 -0
- vgi/_test_fixtures/aggregate/__init__.py +62 -0
- vgi/_test_fixtures/aggregate/_common.py +21 -0
- vgi/_test_fixtures/aggregate/basic.py +232 -0
- vgi/_test_fixtures/aggregate/dynamic.py +409 -0
- vgi/_test_fixtures/aggregate/generic.py +86 -0
- vgi/_test_fixtures/aggregate/listagg.py +71 -0
- vgi/_test_fixtures/aggregate/percentile.py +107 -0
- vgi/_test_fixtures/aggregate/streaming.py +192 -0
- vgi/_test_fixtures/aggregate/varargs.py +75 -0
- vgi/_test_fixtures/aggregate/window.py +380 -0
- vgi/_test_fixtures/attach_options.py +308 -0
- vgi/_test_fixtures/bad_protocol.py +62 -0
- vgi/_test_fixtures/cancellable.py +336 -0
- vgi/_test_fixtures/catalog.py +813 -0
- vgi/_test_fixtures/http_server.py +394 -0
- vgi/_test_fixtures/nest_tensor.py +614 -0
- vgi/_test_fixtures/orchard_catalog.py +47 -0
- vgi/_test_fixtures/projection_repro/__init__.py +6 -0
- vgi/_test_fixtures/projection_repro/worker.py +454 -0
- vgi/_test_fixtures/scalar/__init__.py +116 -0
- vgi/_test_fixtures/scalar/_common.py +69 -0
- vgi/_test_fixtures/scalar/arithmetic.py +321 -0
- vgi/_test_fixtures/scalar/binary.py +120 -0
- vgi/_test_fixtures/scalar/formatting.py +176 -0
- vgi/_test_fixtures/scalar/geo.py +300 -0
- vgi/_test_fixtures/scalar/null_handling.py +107 -0
- vgi/_test_fixtures/scalar/random_demo.py +171 -0
- vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
- vgi/_test_fixtures/scalar/type_info.py +219 -0
- vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
- vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
- vgi/_test_fixtures/simple_writable.py +793 -0
- vgi/_test_fixtures/table/__init__.py +221 -0
- vgi/_test_fixtures/table/_common.py +162 -0
- vgi/_test_fixtures/table/batch_index.py +283 -0
- vgi/_test_fixtures/table/batch_index_broken.py +200 -0
- vgi/_test_fixtures/table/catalog_scans.py +162 -0
- vgi/_test_fixtures/table/filters.py +1005 -0
- vgi/_test_fixtures/table/late_materialization.py +249 -0
- vgi/_test_fixtures/table/make_series.py +273 -0
- vgi/_test_fixtures/table/misc.py +499 -0
- vgi/_test_fixtures/table/order_modes.py +164 -0
- vgi/_test_fixtures/table/pairs.py +437 -0
- vgi/_test_fixtures/table/partition_columns.py +472 -0
- vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
- vgi/_test_fixtures/table/profiling_example.py +195 -0
- vgi/_test_fixtures/table/required_filters.py +234 -0
- vgi/_test_fixtures/table/sequence.py +710 -0
- vgi/_test_fixtures/table/settings.py +426 -0
- vgi/_test_fixtures/table/transaction_storage.py +162 -0
- vgi/_test_fixtures/table/tt_pushdown.py +191 -0
- vgi/_test_fixtures/table/versioned.py +230 -0
- vgi/_test_fixtures/table_in_out.py +1392 -0
- vgi/_test_fixtures/versioned.py +155 -0
- vgi/_test_fixtures/versioned_tables.py +595 -0
- vgi/_test_fixtures/worker.py +1631 -0
- vgi/_test_fixtures/writable/__init__.py +8 -0
- vgi/_test_fixtures/writable/generic.py +236 -0
- vgi/_test_fixtures/writable/table.py +149 -0
- vgi/_test_fixtures/writable/worker.py +1148 -0
- vgi/aggregate_function.py +607 -0
- vgi/argument_spec.py +472 -0
- vgi/arguments.py +1747 -0
- vgi/auth.py +55 -0
- vgi/catalog/__init__.py +88 -0
- vgi/catalog/attach_option.py +206 -0
- vgi/catalog/catalog_interface.py +2767 -0
- vgi/catalog/descriptors.py +870 -0
- vgi/catalog/duckdb_statistics.py +377 -0
- vgi/catalog/secret_type.py +96 -0
- vgi/catalog/setting.py +253 -0
- vgi/catalog/storage.py +372 -0
- vgi/client/__init__.py +67 -0
- vgi/client/catalog_mixin.py +1251 -0
- vgi/client/cli.py +582 -0
- vgi/client/cli_catalog.py +182 -0
- vgi/client/cli_schema.py +270 -0
- vgi/client/cli_table.py +907 -0
- vgi/client/cli_transaction.py +97 -0
- vgi/client/cli_utils.py +441 -0
- vgi/client/cli_view.py +303 -0
- vgi/client/client.py +2183 -0
- vgi/exceptions.py +205 -0
- vgi/function.py +245 -0
- vgi/function_storage.py +1636 -0
- vgi/function_storage_azure_sql.py +922 -0
- vgi/function_storage_cf_do.py +740 -0
- vgi/http/__init__.py +25 -0
- vgi/http/demo_storage.py +212 -0
- vgi/http/worker_page.py +1252 -0
- vgi/invocation.py +154 -0
- vgi/logging_config.py +93 -0
- vgi/meta_worker.py +661 -0
- vgi/metadata.py +1403 -0
- vgi/otel.py +406 -0
- vgi/protocol.py +2418 -0
- vgi/protocol_version.txt +1 -0
- vgi/py.typed +0 -0
- vgi/scalar_function.py +1211 -0
- vgi/schema_utils.py +234 -0
- vgi/secret_protocol.py +124 -0
- vgi/secret_service.py +238 -0
- vgi/serve.py +769 -0
- vgi/table_buffering_function.py +443 -0
- vgi/table_filter_pushdown.py +1528 -0
- vgi/table_function.py +1130 -0
- vgi/table_in_out_function.py +383 -0
- vgi/transactor/__init__.py +24 -0
- vgi/transactor/_duckdb_compat.py +27 -0
- vgi/transactor/client.py +137 -0
- vgi/transactor/protocol.py +149 -0
- vgi/transactor/server.py +740 -0
- vgi/worker.py +4761 -0
- vgi_python-0.8.0.dist-info/METADATA +735 -0
- vgi_python-0.8.0.dist-info/RECORD +124 -0
- vgi_python-0.8.0.dist-info/WHEEL +4 -0
- vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
- vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
vgi/worker.py
ADDED
|
@@ -0,0 +1,4761 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""VGI Worker base class for hosting user-defined functions and catalogs.
|
|
4
|
+
|
|
5
|
+
A worker is a subprocess that communicates via stdin/stdout using Arrow IPC.
|
|
6
|
+
Workers are spawned by a client as needed and terminate once they detect their
|
|
7
|
+
input stream has been closed.
|
|
8
|
+
|
|
9
|
+
SUPPORTED FUNCTION TYPES
|
|
10
|
+
------------------------
|
|
11
|
+
The worker supports three function types, dispatched based on class inheritance:
|
|
12
|
+
|
|
13
|
+
1. ScalarFunction / ScalarFunctionGenerator: Transforms input batches to
|
|
14
|
+
single-column output with 1:1 row mapping. Use for per-row computations.
|
|
15
|
+
|
|
16
|
+
2. TableInOutFunction / TableInOutGenerator: Reads input batches, produces
|
|
17
|
+
output batches. Use for transforming, filtering, or aggregating input.
|
|
18
|
+
|
|
19
|
+
3. TableFunctionGenerator: Generates output batches without reading input.
|
|
20
|
+
Use for data generation functions like sequence(), range(), etc.
|
|
21
|
+
|
|
22
|
+
QUICK START
|
|
23
|
+
-----------
|
|
24
|
+
Create a worker by subclassing Worker and listing your functions:
|
|
25
|
+
|
|
26
|
+
from vgi.worker import Worker
|
|
27
|
+
from vgi.scalar_function import ScalarFunction
|
|
28
|
+
from vgi.table_in_out_function import TableInOutGenerator
|
|
29
|
+
from vgi.table_function import TableFunctionGenerator
|
|
30
|
+
|
|
31
|
+
class DoubleColumn(ScalarFunction):
|
|
32
|
+
# Single-column output with 1:1 row mapping
|
|
33
|
+
...
|
|
34
|
+
|
|
35
|
+
class EchoFunction(TableInOutGenerator):
|
|
36
|
+
# Transforms input batches
|
|
37
|
+
...
|
|
38
|
+
|
|
39
|
+
class SequenceFunction(TableFunctionGenerator):
|
|
40
|
+
# Generates output without input
|
|
41
|
+
...
|
|
42
|
+
|
|
43
|
+
class MyWorker(Worker):
|
|
44
|
+
functions = [DoubleColumn, EchoFunction, SequenceFunction]
|
|
45
|
+
|
|
46
|
+
if __name__ == "__main__":
|
|
47
|
+
MyWorker().run()
|
|
48
|
+
|
|
49
|
+
Function names are derived from metadata (Meta.name or class name converted to
|
|
50
|
+
snake_case). No manual name mapping required.
|
|
51
|
+
|
|
52
|
+
KEY CLASSES
|
|
53
|
+
-----------
|
|
54
|
+
Worker - Base class to subclass (set functions attribute)
|
|
55
|
+
|
|
56
|
+
See Also
|
|
57
|
+
--------
|
|
58
|
+
vgi.client.Client : Spawns workers and sends data to them
|
|
59
|
+
vgi.function.Function : Base class for all functions
|
|
60
|
+
vgi._test_fixtures.worker : Example worker with built-in functions
|
|
61
|
+
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
from __future__ import annotations
|
|
65
|
+
|
|
66
|
+
import contextlib
|
|
67
|
+
import hashlib
|
|
68
|
+
import importlib.metadata
|
|
69
|
+
import logging
|
|
70
|
+
import os
|
|
71
|
+
import pickle
|
|
72
|
+
import sys
|
|
73
|
+
import uuid
|
|
74
|
+
from collections import OrderedDict
|
|
75
|
+
from collections.abc import Sequence
|
|
76
|
+
from dataclasses import dataclass
|
|
77
|
+
from dataclasses import replace as _dataclass_replace
|
|
78
|
+
from threading import Lock
|
|
79
|
+
from typing import TYPE_CHECKING, Any, ClassVar, cast, final, overload
|
|
80
|
+
|
|
81
|
+
import pyarrow as pa
|
|
82
|
+
from vgi_rpc.rpc import AuthContext, CallContext, RpcServer, Stream, current_auth, serve_stdio
|
|
83
|
+
|
|
84
|
+
from vgi.aggregate_function import AggregateBindParams, AggregateFunction
|
|
85
|
+
from vgi.argument_spec import ArgumentSpec, extract_argument_specs
|
|
86
|
+
from vgi.arguments import Arguments
|
|
87
|
+
from vgi.catalog import CatalogInterface
|
|
88
|
+
from vgi.catalog.attach_option import AttachOptionSpec, extract_attach_option_specs
|
|
89
|
+
from vgi.catalog.catalog_interface import (
|
|
90
|
+
AttachOpaqueData,
|
|
91
|
+
CatalogAttachResult,
|
|
92
|
+
OnConflict,
|
|
93
|
+
SchemaObjectType,
|
|
94
|
+
SerializedSchema,
|
|
95
|
+
SqlExpression,
|
|
96
|
+
TransactionOpaqueData,
|
|
97
|
+
_validate_at_params,
|
|
98
|
+
serialize_column_statistics,
|
|
99
|
+
)
|
|
100
|
+
from vgi.catalog.secret_type import SecretTypeSpec
|
|
101
|
+
from vgi.catalog.setting import SettingSpec, extract_setting_specs
|
|
102
|
+
from vgi.function import (
|
|
103
|
+
Function,
|
|
104
|
+
)
|
|
105
|
+
from vgi.function_storage import BoundStorage, FrameworkNS, attach_catalog_bytes
|
|
106
|
+
from vgi.invocation import (
|
|
107
|
+
BindResponse,
|
|
108
|
+
GlobalInitResponse,
|
|
109
|
+
)
|
|
110
|
+
from vgi.logging_config import LogFormat, LogLevel
|
|
111
|
+
from vgi.otel import VgiTracer, get_noop_tracer
|
|
112
|
+
from vgi.protocol import (
|
|
113
|
+
BindRequest,
|
|
114
|
+
BufferedFinalizeState,
|
|
115
|
+
CatalogAttachRequest,
|
|
116
|
+
CatalogCreateRequest,
|
|
117
|
+
CatalogsResponse,
|
|
118
|
+
CatalogVersionResponse,
|
|
119
|
+
FunctionsResponse,
|
|
120
|
+
IndexCreateRequest,
|
|
121
|
+
IndexesResponse,
|
|
122
|
+
InitRequest,
|
|
123
|
+
MacroCreateRequest,
|
|
124
|
+
MacrosResponse,
|
|
125
|
+
ProcessState,
|
|
126
|
+
ScalarExchangeState,
|
|
127
|
+
SchemasResponse,
|
|
128
|
+
TableBufferingFinalizeState,
|
|
129
|
+
TableCreateRequest,
|
|
130
|
+
TableFunctionCardinalityRequest,
|
|
131
|
+
TableFunctionDynamicToStringRequest,
|
|
132
|
+
TableFunctionDynamicToStringResponse,
|
|
133
|
+
TableFunctionStatisticsRequest,
|
|
134
|
+
TableInOutExchangeState,
|
|
135
|
+
TableProducerState,
|
|
136
|
+
TablesResponse,
|
|
137
|
+
TransactionBeginResponse,
|
|
138
|
+
VgiProtocol,
|
|
139
|
+
ViewsResponse,
|
|
140
|
+
)
|
|
141
|
+
from vgi.scalar_function import ScalarFunctionGenerator
|
|
142
|
+
from vgi.table_buffering_function import (
|
|
143
|
+
TableBufferingFunction,
|
|
144
|
+
TableBufferingParams,
|
|
145
|
+
)
|
|
146
|
+
from vgi.table_function import (
|
|
147
|
+
ProcessParams,
|
|
148
|
+
SecretsAccessor,
|
|
149
|
+
TableCardinality,
|
|
150
|
+
TableFunctionBase,
|
|
151
|
+
TableFunctionGenerator,
|
|
152
|
+
TableInOutFunctionInitPhase,
|
|
153
|
+
_batch_to_scalar_dict,
|
|
154
|
+
_effective_projection_ids,
|
|
155
|
+
project_schema,
|
|
156
|
+
)
|
|
157
|
+
from vgi.table_in_out_function import (
|
|
158
|
+
TableInOutGenerator,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
if TYPE_CHECKING:
|
|
162
|
+
from vgi.catalog.descriptors import Catalog
|
|
163
|
+
from vgi.protocol import (
|
|
164
|
+
AggregateBindRequest,
|
|
165
|
+
AggregateBindResponse,
|
|
166
|
+
AggregateCombineRequest,
|
|
167
|
+
AggregateCombineResponse,
|
|
168
|
+
AggregateDestructorRequest,
|
|
169
|
+
AggregateDestructorResponse,
|
|
170
|
+
AggregateFinalizeRequest,
|
|
171
|
+
AggregateFinalizeResponse,
|
|
172
|
+
AggregateStreamingChunkRequest,
|
|
173
|
+
AggregateStreamingChunkResponse,
|
|
174
|
+
AggregateStreamingCloseRequest,
|
|
175
|
+
AggregateStreamingCloseResponse,
|
|
176
|
+
AggregateStreamingOpenRequest,
|
|
177
|
+
AggregateStreamingOpenResponse,
|
|
178
|
+
AggregateUpdateRequest,
|
|
179
|
+
AggregateUpdateResponse,
|
|
180
|
+
AggregateWindowBatchRequest,
|
|
181
|
+
AggregateWindowBatchResponse,
|
|
182
|
+
AggregateWindowDestructorRequest,
|
|
183
|
+
AggregateWindowDestructorResponse,
|
|
184
|
+
AggregateWindowInitRequest,
|
|
185
|
+
AggregateWindowInitResponse,
|
|
186
|
+
AggregateWindowRequest,
|
|
187
|
+
AggregateWindowResponse,
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# Module-level logger for this file, registered under the name "vgi.worker".
_logger = logging.getLogger("vgi.worker")

# Memoized result of ``_get_vgi_version()``; stays ``None`` until the first
# lookup stores either the installed version string or "unknown".
_vgi_version_cache: str | None = None
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _write_port_file(path: str, port: int) -> None:
|
|
196
|
+
"""Write `port` to `path` atomically (tmp + rename).
|
|
197
|
+
|
|
198
|
+
Callers (typically test harnesses) watch for the target path to appear,
|
|
199
|
+
so a partially-written file would race the reader. Using rename means
|
|
200
|
+
the file either doesn't exist or has the full port number.
|
|
201
|
+
"""
|
|
202
|
+
import os
|
|
203
|
+
import tempfile
|
|
204
|
+
|
|
205
|
+
parent = os.path.dirname(os.path.abspath(path)) or "."
|
|
206
|
+
fd, tmp = tempfile.mkstemp(prefix=".port.", dir=parent)
|
|
207
|
+
try:
|
|
208
|
+
with os.fdopen(fd, "w") as fh:
|
|
209
|
+
fh.write(f"{port}\n")
|
|
210
|
+
os.replace(tmp, path)
|
|
211
|
+
except BaseException:
|
|
212
|
+
# Best-effort cleanup of the tmp on any failure before the rename.
|
|
213
|
+
with contextlib.suppress(OSError):
|
|
214
|
+
os.unlink(tmp)
|
|
215
|
+
raise
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _get_vgi_version() -> str:
    """Return the version of the installed ``vgi-python`` distribution.

    The answer is looked up once through ``importlib.metadata`` and stored in
    the module-level ``_vgi_version_cache``; later calls return the stored
    value. When the distribution metadata cannot be found the string
    ``"unknown"`` is cached and returned instead.
    """
    global _vgi_version_cache  # noqa: PLW0603
    if _vgi_version_cache is not None:
        return _vgi_version_cache
    try:
        resolved = importlib.metadata.version("vgi-python")
    except importlib.metadata.PackageNotFoundError:
        resolved = "unknown"
    _vgi_version_cache = resolved
    return resolved
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _format_arguments_for_error(args: Arguments) -> str:
|
|
230
|
+
"""Format Arguments for error messages, showing values and types.
|
|
231
|
+
|
|
232
|
+
Produces output like:
|
|
233
|
+
const_args=[3 (int64), "hello" (string)], named_args={sep: "," (string)}
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
args: The Arguments instance to format.
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
Human-readable string showing argument values and types.
|
|
240
|
+
|
|
241
|
+
"""
|
|
242
|
+
|
|
243
|
+
def format_scalar(scalar: Any) -> str:
|
|
244
|
+
"""Format a single scalar value with its type."""
|
|
245
|
+
if scalar is None:
|
|
246
|
+
return "null"
|
|
247
|
+
elif not scalar.is_valid:
|
|
248
|
+
return f"null ({scalar.type})"
|
|
249
|
+
else:
|
|
250
|
+
value = scalar.as_py()
|
|
251
|
+
type_name = str(scalar.type)
|
|
252
|
+
if isinstance(value, str):
|
|
253
|
+
return f"{value!r} ({type_name})"
|
|
254
|
+
elif isinstance(value, bytes):
|
|
255
|
+
if len(value) > 20:
|
|
256
|
+
return f"<{len(value)} bytes> ({type_name})"
|
|
257
|
+
else:
|
|
258
|
+
return f"{value!r} ({type_name})"
|
|
259
|
+
else:
|
|
260
|
+
return f"{value} ({type_name})"
|
|
261
|
+
|
|
262
|
+
parts = []
|
|
263
|
+
|
|
264
|
+
# Format positional constant arguments
|
|
265
|
+
if args.positional:
|
|
266
|
+
pos_strs = [format_scalar(s) for s in args.positional]
|
|
267
|
+
parts.append(f"const_args=[{', '.join(pos_strs)}]")
|
|
268
|
+
else:
|
|
269
|
+
parts.append("const_args=[]")
|
|
270
|
+
|
|
271
|
+
# Format named constant arguments
|
|
272
|
+
if args.named:
|
|
273
|
+
named_strs = [f"{name}: {format_scalar(scalar)}" for name, scalar in sorted(args.named.items())]
|
|
274
|
+
parts.append(f"named_args={{{', '.join(named_strs)}}}")
|
|
275
|
+
else:
|
|
276
|
+
parts.append("named_args={}")
|
|
277
|
+
|
|
278
|
+
return ", ".join(parts)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# ---------------------------------------------------------------------------
|
|
282
|
+
# Window partition in-process cache
|
|
283
|
+
# ---------------------------------------------------------------------------
|
|
284
|
+
# The storage layer (SQLite / Azure SQL / Cloudflare DO) is authoritative and
|
|
285
|
+
# makes the window path correct across multi-process deployments. But a
|
|
286
|
+
# single `aggregate_window` call does a BLOB read + Arrow IPC deserialize,
|
|
287
|
+
# and we make that call once per output row. For a 1000-row partition that's
|
|
288
|
+
# ~200ms of pure storage+deserialize overhead on top of the actual aggregate
|
|
289
|
+
# work — enough to make the window path slower than DuckDB's segment-tree
|
|
290
|
+
# fallback for many aggregates.
|
|
291
|
+
#
|
|
292
|
+
# Layer an in-memory cache on top of storage: populated on ``window_init``,
|
|
293
|
+
# read first on ``window``, invalidated on ``window_destructor`` and on the
|
|
294
|
+
# top-level ``aggregate_destructor`` safety sweep. Storage remains the
|
|
295
|
+
# authoritative source — if the cache misses (different worker process, LRU
|
|
296
|
+
# eviction, or a crashed-and-restarted worker) we fall through to storage.
|
|
297
|
+
|
|
298
|
+
# Cap the cache so a missed destructor in a long-running worker can't grow
|
|
299
|
+
# memory without bound. Eviction is correctness-safe because storage is
|
|
300
|
+
# authoritative.
|
|
301
|
+
_WINDOW_PARTITION_CACHE_MAX = 256
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class _UpdateTrackingDict[K, V](dict[K, V]):
|
|
305
|
+
"""A dict that records explicit ``__setitem__`` writes.
|
|
306
|
+
|
|
307
|
+
Used by ``aggregate_update`` to tell apart "user's update() reassigned
|
|
308
|
+
state for this group" from "we pre-populated the entry with initial
|
|
309
|
+
state and the user's update() chose not to touch it (e.g. saw only
|
|
310
|
+
NULL inputs)". The framework then persists only entries the user
|
|
311
|
+
actually wrote, so a no-op update on a freshly-seeded group does not
|
|
312
|
+
overwrite the absence of stored state — preserving SQL-standard
|
|
313
|
+
NULL semantics for ``SUM`` of all NULLs.
|
|
314
|
+
"""
|
|
315
|
+
|
|
316
|
+
def __init__(self) -> None:
|
|
317
|
+
super().__init__()
|
|
318
|
+
self.written: set[K] = set()
|
|
319
|
+
|
|
320
|
+
def __setitem__(self, key: K, value: V) -> None:
|
|
321
|
+
super().__setitem__(key, value)
|
|
322
|
+
self.written.add(key)
|
|
323
|
+
|
|
324
|
+
def clear_writes(self) -> None:
|
|
325
|
+
self.written.clear()
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@dataclass(slots=True)
class _CachedWindowPartition:
    """Fully-decoded partition ready to hand to the user's ``window()``.

    ``prepared_state`` holds the result of ``AggregateFunction.window_prepare``
    if the user defines that hook — typically the deserialized
    ``window_state`` plus any per-partition derived structures (NumPy views,
    dictionary lookups, etc.). It lives only in this in-memory cache and is
    regenerated whenever the partition is reloaded from FunctionStorage.
    Defaults to ``window_state`` when no hook is defined, so existing
    aggregates see the placeholder unchanged.

    NOTE(review): the dataclass default for ``prepared_state`` is ``None``;
    the "defaults to ``window_state``" behaviour described above must be
    applied by whoever constructs the entry — confirm at the call sites.
    """

    # Decoded partition data; typed ``Any`` only to avoid an import cycle.
    partition: Any  # vgi.aggregate_function.WindowPartition (avoid import cycle)
    # Arrow schema stored alongside the partition.
    output_schema: pa.Schema
    # Deserialized window state, or None when there is none.
    window_state: Any  # _WindowStatePlaceholder | None
    # Optional per-partition derived state; see the class docstring.
    prepared_state: Any = None
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class _WindowPartitionCache:
    """In-process LRU cache of decoded window partitions.

    Entries are keyed by ``(execution_id, partition_id)`` and every operation
    runs while holding an internal lock. The cache never holds more than
    ``max_size`` entries: inserting beyond that drops the least recently used
    ones, so a destructor that never arrives cannot grow memory past the cap.
    """

    def __init__(self, max_size: int = _WINDOW_PARTITION_CACHE_MAX) -> None:
        """Create an empty cache holding at most ``max_size`` entries."""
        self._entries: OrderedDict[tuple[bytes, int], _CachedWindowPartition] = OrderedDict()
        self._lock = Lock()
        self._max_size = max_size

    def get(self, execution_id: bytes, partition_id: int) -> _CachedWindowPartition | None:
        """Return the cached partition (marking it most recently used) or ``None``."""
        slot = (execution_id, partition_id)
        with self._lock:
            cached = self._entries.get(slot)
            if cached is None:
                return None
            self._entries.move_to_end(slot)
            return cached

    def put(self, execution_id: bytes, partition_id: int, entry: _CachedWindowPartition) -> None:
        """Insert or replace an entry, then evict oldest entries over the cap."""
        slot = (execution_id, partition_id)
        with self._lock:
            entries = self._entries
            entries[slot] = entry
            # Replacing an existing key keeps its old position, so bump it.
            entries.move_to_end(slot)
            while len(entries) > self._max_size:
                entries.popitem(last=False)

    def delete(self, execution_id: bytes, partition_id: int) -> None:
        """Drop a single entry; a missing key is ignored."""
        slot = (execution_id, partition_id)
        with self._lock:
            self._entries.pop(slot, None)

    def clear_execution(self, execution_id: bytes) -> None:
        """Drop every cached partition that belongs to ``execution_id``."""
        with self._lock:
            # Snapshot the keys first: the mapping is mutated while we walk it.
            for slot in tuple(self._entries):
                if slot[0] == execution_id:
                    del self._entries[slot]
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
_window_partition_cache = _WindowPartitionCache()
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
# Sentinel for "we already looked this execution_id up and the worker has no
|
|
391
|
+
# const args" — distinguishable from None (= "not yet looked up").
|
|
392
|
+
_ABSENT_SENTINEL: Any = object()
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
class _AggregateConstArgsCache:
|
|
396
|
+
"""Process-local LRU of aggregate const_args, keyed by (shard_key, execution_id).
|
|
397
|
+
|
|
398
|
+
Const args are written once at aggregate_bind (under group_id=-2) and
|
|
399
|
+
never change. Every aggregate_update / aggregate_window / aggregate_finalize
|
|
400
|
+
needs them, and ``_load_aggregate_const_args`` would otherwise issue one
|
|
401
|
+
storage read per call — pathological under remote storage backends.
|
|
402
|
+
|
|
403
|
+
Scoping by ``(shard_key, execution_id)`` mirrors the storage row key:
|
|
404
|
+
const args live under ``(shard_key) -> aggregate_state(execution_id,
|
|
405
|
+
group_id=-2)``. Including shard_key in the cache key prevents a stale
|
|
406
|
+
hit if two different attaches in the same worker process ever produce
|
|
407
|
+
colliding execution_id bytes (UUIDs make this vanishingly rare, but the
|
|
408
|
+
bound is free and matches storage semantics).
|
|
409
|
+
|
|
410
|
+
Cache holds either ``Arguments`` (positive hit), ``_ABSENT_SENTINEL``
|
|
411
|
+
(the worker has no const params; remember the negative result to skip
|
|
412
|
+
future reads), or absent (not yet looked up). Bounded LRU eviction caps
|
|
413
|
+
memory; aggregate_destructor proactively evicts.
|
|
414
|
+
"""
|
|
415
|
+
|
|
416
|
+
def __init__(self, max_size: int = 256) -> None:
|
|
417
|
+
self._entries: OrderedDict[tuple[str, bytes], Any] = OrderedDict()
|
|
418
|
+
self._lock = Lock()
|
|
419
|
+
self._max_size = max_size
|
|
420
|
+
|
|
421
|
+
def get(self, shard_key: str, execution_id: bytes) -> Any:
|
|
422
|
+
key = (shard_key, execution_id)
|
|
423
|
+
with self._lock:
|
|
424
|
+
entry = self._entries.get(key)
|
|
425
|
+
if entry is not None:
|
|
426
|
+
self._entries.move_to_end(key)
|
|
427
|
+
return entry
|
|
428
|
+
|
|
429
|
+
def put(self, shard_key: str, execution_id: bytes, value: Any) -> None:
|
|
430
|
+
key = (shard_key, execution_id)
|
|
431
|
+
with self._lock:
|
|
432
|
+
self._entries[key] = value
|
|
433
|
+
self._entries.move_to_end(key)
|
|
434
|
+
while len(self._entries) > self._max_size:
|
|
435
|
+
self._entries.popitem(last=False)
|
|
436
|
+
|
|
437
|
+
def clear_execution(self, shard_key: str, execution_id: bytes) -> None:
|
|
438
|
+
with self._lock:
|
|
439
|
+
self._entries.pop((shard_key, execution_id), None)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
_aggregate_const_args_cache = _AggregateConstArgsCache()
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
# TableBufferingFunction init metadata is persisted to FunctionStorage at
|
|
446
|
+
# init time (namespace FrameworkNS.BUFFERING_INIT, key = pack_int_key(_TABLE_BUFFERING_INIT_KEY))
|
|
447
|
+
# and cold-loaded by every subsequent table_buffering_{process,combine}
|
|
448
|
+
# RPC. There is intentionally NO in-process cache here — workers are
|
|
449
|
+
# stateless w.r.t. table_buffering executions, so any pool worker can
|
|
450
|
+
# serve any RPC for any execution_id.
|
|
451
|
+
#
|
|
452
|
+
# Cost: one extra storage round-trip per RPC. Negligible on subprocess
|
|
453
|
+
# (SQLite is microseconds). One HTTP RTT on CfDo deployments — if that
|
|
454
|
+
# surfaces as a hot-path bottleneck, the mitigation is to thread the
|
|
455
|
+
# init metadata bytes through the request envelope from C++ instead.
|
|
456
|
+
_TABLE_BUFFERING_INIT_KEY = -1
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# Streaming-shape FINALIZE phase materializes the user's
|
|
460
|
+
# ``finalize() -> list[batch]`` return into a single state_log keyed
|
|
461
|
+
# by this constant. One streaming finalize per execution; one key.
|
|
462
|
+
_STREAMING_FINALIZE_KEY = b""
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def _build_bound_storage_from_fields(
    execution_id: bytes,
    attach_plaintext: bytes | None,
    ctx: CallContext,
) -> BoundStorage:
    """Construct a ``BoundStorage`` from raw fields, with no cached state.

    According to the author's notes, ``BufferedFinalizeState.produce()``
    calls this on every tick. The backend is taken from ``Function.storage``
    on purpose: that attribute yields the process-wide cached backend,
    whereas ``_resolve_storage()`` would build a new backend on each call —
    and for the ``:memory:`` sqlite backend every new instance is a separate
    database, so a finalize batch appended earlier would be invisible to the
    read done in ``produce()``.

    Args:
        execution_id: Execution identifier the storage is bound to.
        attach_plaintext: The **full** unwrapped attach payload
            (``uuid(16) || catalog_bytes``) that the rehydrated state
            persisted; the auth-scoped seal cannot be reopened at this point.
            ``BoundStorage`` is described as sharding on the leading UUID so
            CfDo routes to the right Durable Object.
        ctx: Call context of the current RPC. Not read by this function.

    Returns:
        A ``BoundStorage`` over the shared backend for ``execution_id``.
    """
    from vgi.function import Function

    backend = Function.storage
    return BoundStorage(backend, execution_id, attach_plaintext=attach_plaintext)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def _decode_ipc_batch(value: bytes) -> pa.RecordBatch:
    """Return the first record batch of the Arrow IPC stream held in ``value``."""
    reader = pa.ipc.open_stream(value)
    return reader.read_next_batch()
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def _deserialize_finalize_state(func_cls: type, blob: bytes) -> Any:
|
|
499
|
+
"""Deserialize a TFinalizeState blob using the class's resolved type.
|
|
500
|
+
|
|
501
|
+
Reads the ``_finalize_state_class`` attribute that
|
|
502
|
+
``TableBufferingFunction.__init_subclass__`` records at class-definition
|
|
503
|
+
time. ``None`` means "no per-tick state declared" (user passed
|
|
504
|
+
``TFinalizeState=None`` or didn't parameterize) — we pass the raw blob
|
|
505
|
+
through, which user code in that path also won't read.
|
|
506
|
+
"""
|
|
507
|
+
state_type = getattr(func_cls, "_finalize_state_class", None)
|
|
508
|
+
if state_type is None:
|
|
509
|
+
return None
|
|
510
|
+
return state_type.deserialize_from_bytes(blob)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
@dataclass
class _FinalizeStubRequest:
    """Minimal request stand-in passed to ``_load_table_buffering_params``.

    Defined once at module level so a finalize tick does not have to build a
    new dataclass type on every call.
    """

    function_name: str
    execution_id: bytes
    attach_opaque_data: bytes | None
    transaction_id: bytes | None


def run_table_buffering_finalize_tick(state: Any, out: Any, ctx: Any) -> None:
    """Run one tick of ``cls.finalize(params, fid, state, out)``.

    Lazy-imported by ``TableBufferingFinalizeState.produce()`` to break the
    protocol→worker import cycle. The function class and its params are
    resolved from scratch on every call (no in-process cache), because
    different worker processes may handle different ticks under HTTP.

    The pushdown contract mirrors the streaming ``TableInOutExchangeState``
    path: ``params.output_schema`` is narrowed to the projected slots, and
    when filters are present and the function opts into auto-applying them,
    ``out`` is wrapped in a filtering collector so the user's ``finalize()``
    does not need to know about them.

    Args:
        state: The ``TableBufferingFinalizeState`` for this execution. Its
            ``state_initialized`` flag and ``state_blob`` are updated in place.
        out: Output collector that receives the batches emitted by
            ``finalize()``.
        ctx: Call context; ``ctx.implementation`` must be the worker.

    Raises:
        RuntimeError: If ``ctx.implementation`` is ``None``.
    """
    # The collector wrappers are private to vgi.protocol and not imported at
    # module level, so they are pulled in here.
    from vgi.protocol import (
        _FilteringOutputCollector,
        _TrackingOutputCollector,
    )

    assert isinstance(state, TableBufferingFinalizeState), type(state)

    stub = _FinalizeStubRequest(
        function_name=state.function_name,
        execution_id=state.execution_id,
        attach_opaque_data=state.attach_opaque_data,
        transaction_id=state.transaction_id,
    )

    worker = ctx.implementation
    if worker is None:
        raise RuntimeError(
            "table_buffering_finalize: ctx.implementation is None — "
            "vgi-rpc must be ≥0.16.2 (the release that surfaces the "
            "protocol implementation on CallContext)."
        )
    # ``state.attach_opaque_data`` carries the already-unwrapped attach
    # bytes that init() stashed on the producer state. The seal is bound
    # to the original sealer's auth — under HTTP, the rehydrated tick
    # runs in a different auth context and re-unwrap fails. Skip it.
    func_cls, params = worker._load_table_buffering_params(
        stub,
        ctx,
        attach_already_unwrapped=True,
    )

    # Narrow output_schema based on projection_ids carried on the state.
    # ``_effective_projection_ids`` returns None when the function didn't
    # opt into projection_pushdown — a safety net so a stale projection_ids
    # field on the wire (from a misbehaving C++ client) can't corrupt the
    # output schema for a function that never asked for projection.
    proj_ids = _effective_projection_ids(func_cls, state.projection_ids)
    if proj_ids is not None:
        params = _dataclass_replace(
            params,
            output_schema=project_schema(proj_ids, params.output_schema),
        )

    # Auto-apply filters: wrap ``out`` in ``_TrackingOutputCollector`` and
    # then in ``_FilteringOutputCollector`` so each batch emitted by the
    # user's finalize() is filtered through the PushdownFilters object before
    # reaching the caller. ``_FilteringOutputCollector.emit`` threads
    # ``batch_index=`` through to the inner collector; only
    # ``_TrackingOutputCollector`` accepts that kwarg, so wrapping a raw
    # OutputCollector directly would raise TypeError.
    auto_apply = state.pushdown_filters is not None and func_cls._should_auto_apply_filters()
    effective_out = out
    if auto_apply:
        pushdown_obj = func_cls.pushdown_filters(state.pushdown_filters)
        if pushdown_obj is not None:
            tracking_out = _TrackingOutputCollector(out)
            effective_out = _FilteringOutputCollector(tracking_out, func_cls, pushdown_obj)

    # First tick: ask the function for its initial state. Later ticks: rebuild
    # the state from the blob saved at the end of the previous tick.
    if not state.state_initialized:
        user_state: Any = func_cls.initial_finalize_state(
            state.finalize_state_id,
            params,
        )
        state.state_initialized = True
    else:
        user_state = _deserialize_finalize_state(func_cls, state.state_blob) if state.state_blob else None

    func_cls.finalize(params, state.finalize_state_id, user_state, effective_out)

    # If we wrapped ``out``, drain any buffered output before the tick returns.
    # ``propagate()`` is absent when no wrapper was installed, so guard the call.
    if auto_apply and hasattr(effective_out, "propagate"):
        effective_out.propagate()

    # Persist the (possibly mutated) user state for the next tick; an empty
    # blob stands for "no state".
    if user_state is None:
        state.state_blob = b""
    else:
        state.state_blob = user_state.serialize_to_bytes()
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
# Streaming-partitioned aggregate sessions: session state is held in an
|
|
625
|
+
# in-process LRU cache for the fast path, and persisted to FunctionStorage
|
|
626
|
+
# (under partition_id=0) so chunk RPCs that land on a different pool worker
|
|
627
|
+
# can rehydrate. Same pattern as aggregate_window_partition_put/_get.
|
|
628
|
+
#
|
|
629
|
+
# State persisted: pickled dict containing streaming_state plus the schema
|
|
630
|
+
# fields that streaming_chunk needs to reconstruct ProcessParams without
|
|
631
|
+
# the open request.
|
|
632
|
+
_STREAMING_SESSION_STORAGE_KEY = 0
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
@dataclass(slots=True)
|
|
636
|
+
class _StreamingSession:
|
|
637
|
+
"""Per-execution_id state for a streaming-partitioned aggregate session."""
|
|
638
|
+
|
|
639
|
+
func_cls: type[AggregateFunction[Any]]
|
|
640
|
+
streaming_state: Any
|
|
641
|
+
output_schema: pa.Schema
|
|
642
|
+
partition_key_count: int
|
|
643
|
+
order_key_count: int
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def _encode_streaming_session(session: _StreamingSession) -> bytes:
    """Pickle a streaming session for persistence in FunctionStorage.

    The payload is a pickled dict with four keys: ``streaming_state`` (the
    session's state object, pickled as-is), ``output_schema_bytes`` (the
    output schema as a schema-only Arrow IPC stream), ``partition_key_count``
    and ``order_key_count``.

    ``func_cls`` is *not* pickled — it's resolved on the fly from the
    function name on cold reload.

    Args:
        session: The in-memory session to persist.

    Returns:
        Bytes that ``_decode_streaming_session`` turns back into a session.

    Raises:
        pickle.PicklingError: If ``session.streaming_state`` cannot be pickled.
    """
    # Reuse the module's schema serializer rather than keeping a second inline
    # copy of the same schema-only IPC stream logic, so the two cannot drift.
    output_schema_bytes = _serialize_schema_bytes(session.output_schema)
    return pickle.dumps(
        {
            "streaming_state": session.streaming_state,
            "output_schema_bytes": output_schema_bytes,
            "partition_key_count": session.partition_key_count,
            "order_key_count": session.order_key_count,
        }
    )
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def _decode_streaming_session(payload: bytes, func_cls: type[AggregateFunction[Any]]) -> _StreamingSession:
    """Rebuild a ``_StreamingSession`` from bytes made by ``_encode_streaming_session``.

    Args:
        payload: Pickled dict holding ``streaming_state``,
            ``output_schema_bytes``, ``partition_key_count`` and
            ``order_key_count``.
        func_cls: Aggregate function class to attach to the session; it is
            not part of the payload, so the caller supplies it.

    Returns:
        A session carrying the persisted fields plus ``func_cls``.

    Raises:
        KeyError: If the unpickled dict lacks one of the expected keys.
    """
    # NOTE(review): pickle.loads executes whatever the payload describes; this
    # is only safe while the storage backend is trusted — confirm.
    fields = pickle.loads(payload)
    streaming_state = fields["streaming_state"]
    # The schema was stored as a schema-only IPC stream; opening the stream is
    # enough to read it back.
    schema_reader = pa.ipc.open_stream(fields["output_schema_bytes"])
    restored_schema = schema_reader.schema
    partition_keys = fields["partition_key_count"]
    order_keys = fields["order_key_count"]
    return _StreamingSession(
        func_cls=func_cls,
        streaming_state=streaming_state,
        output_schema=restored_schema,
        partition_key_count=partition_keys,
        order_key_count=order_keys,
    )
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
# --- Buffered-table init metadata persistence ----------------------------
|
|
678
|
+
#
|
|
679
|
+
# Persists just the two ArrowSerializableDataclass envelopes (the init
|
|
680
|
+
# request that the C++ side originally shipped, and the GlobalInitResponse
|
|
681
|
+
# the worker built from it). On cold-load, deserializing these is enough
|
|
682
|
+
# to rebuild the full ProcessParams via the same code path the init
|
|
683
|
+
# handler runs. func_cls is NOT persisted — re-resolved by name from the
|
|
684
|
+
# function registry on every load, both to avoid pickle-loading arbitrary
|
|
685
|
+
# classes and to handle version skew if the function definition changes.
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
def _encode_table_buffering_init(init_call: InitRequest, init_response: GlobalInitResponse) -> bytes:
|
|
689
|
+
return pickle.dumps(
|
|
690
|
+
{
|
|
691
|
+
"init_call_bytes": init_call.serialize_to_bytes(),
|
|
692
|
+
"init_response_bytes": init_response.serialize_to_bytes(),
|
|
693
|
+
}
|
|
694
|
+
)
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def _decode_table_buffering_init(payload: bytes) -> tuple[InitRequest, GlobalInitResponse]:
    """Inverse of ``_encode_table_buffering_init``.

    Args:
        payload: Pickled dict with ``init_call_bytes`` and
            ``init_response_bytes`` entries.

    Returns:
        ``(init_request, global_init_response)`` rebuilt through each class's
        ``deserialize_from_bytes``.

    Raises:
        KeyError: If either expected key is missing from the unpickled dict.
    """
    # NOTE(review): pickle.loads on stored bytes assumes the storage backend
    # is trusted — confirm.
    envelopes = pickle.loads(payload)
    init_call = InitRequest.deserialize_from_bytes(envelopes["init_call_bytes"])
    init_response = GlobalInitResponse.deserialize_from_bytes(envelopes["init_response_bytes"])
    return init_call, init_response
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
class _StreamingSessionCache:
|
|
706
|
+
"""Process-local map ``execution_id -> _StreamingSession``."""
|
|
707
|
+
|
|
708
|
+
def __init__(self) -> None:
|
|
709
|
+
self._entries: dict[bytes, _StreamingSession] = {}
|
|
710
|
+
self._lock = Lock()
|
|
711
|
+
|
|
712
|
+
def put(self, execution_id: bytes, session: _StreamingSession) -> None:
|
|
713
|
+
with self._lock:
|
|
714
|
+
self._entries[execution_id] = session
|
|
715
|
+
|
|
716
|
+
def get(self, execution_id: bytes) -> _StreamingSession | None:
|
|
717
|
+
with self._lock:
|
|
718
|
+
return self._entries.get(execution_id)
|
|
719
|
+
|
|
720
|
+
def pop(self, execution_id: bytes) -> _StreamingSession | None:
|
|
721
|
+
with self._lock:
|
|
722
|
+
return self._entries.pop(execution_id, None)
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
_streaming_session_cache = _StreamingSessionCache()
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
# Process-local instrumentation for the streaming-aggregate path. Phase
|
|
729
|
+
# timers accumulate across all chunks of a session; dumped on close.
|
|
730
|
+
_streaming_persist_lock = Lock()
|
|
731
|
+
_streaming_persist_stats = {
|
|
732
|
+
"encode_session_seconds": 0.0,
|
|
733
|
+
"storage_put_seconds": 0.0,
|
|
734
|
+
"storage_get_seconds": 0.0,
|
|
735
|
+
"rpc_chunk_total_seconds": 0.0,
|
|
736
|
+
"n_chunks": 0,
|
|
737
|
+
"n_persists": 0,
|
|
738
|
+
"n_cold_loads": 0,
|
|
739
|
+
"bytes_persisted": 0,
|
|
740
|
+
}
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
def _record_persist_timing(
    encode_seconds: float,
    put_seconds: float,
    payload_bytes: int,
) -> None:
    """Fold one persist operation into the module-level streaming stats.

    Args:
        encode_seconds: Time spent encoding the session; added to
            ``encode_session_seconds``.
        put_seconds: Time spent in the storage put; added to
            ``storage_put_seconds``.
        payload_bytes: Size of the persisted payload; added to
            ``bytes_persisted``.

    ``n_persists`` is incremented by one on every call. All four updates are
    applied under ``_streaming_persist_lock``.
    """
    increments = (
        ("encode_session_seconds", encode_seconds),
        ("storage_put_seconds", put_seconds),
        ("n_persists", 1),
        ("bytes_persisted", payload_bytes),
    )
    with _streaming_persist_lock:
        for stat_name, amount in increments:
            _streaming_persist_stats[stat_name] += amount
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
def _unpack_bool_mask(data: bytes, length: int) -> pa.BooleanArray:
    """Turn a packed-bit filter mask into a ``BooleanArray`` with ``length`` entries.

    Args:
        data: The mask bytes, handed to Arrow as the values buffer of a
            boolean array. Empty (or ``None``) means "no mask".
        length: Number of entries the resulting array must have.

    Returns:
        An all-``True`` array when ``data`` is empty; otherwise an array built
        over ``data`` with no validity bitmap (so no nulls).
    """
    boolean = pa.bool_()
    if data:
        bits = pa.py_buffer(data)
        # First buffer slot is the validity bitmap (None = no nulls); the
        # second is the bit-packed values buffer.
        unpacked = pa.Array.from_buffers(boolean, length, [None, bits])  # type: ignore[list-item]
        return cast(pa.BooleanArray, unpacked)
    return pa.array([True] * length, type=boolean)
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def _unpack_frame_stats(data: bytes) -> tuple[tuple[int, int], tuple[int, int]]:
|
|
764
|
+
"""Decode 4× little-endian int64 into FrameStats tuple-of-tuples."""
|
|
765
|
+
if not data or len(data) < 32:
|
|
766
|
+
return ((0, 0), (0, 0))
|
|
767
|
+
import struct
|
|
768
|
+
|
|
769
|
+
b0, e0, b1, e1 = struct.unpack("<qqqq", data[:32])
|
|
770
|
+
return ((b0, e0), (b1, e1))
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
def _unpack_all_valid(data: bytes, column_count: int) -> list[bool]:
|
|
774
|
+
"""Decode 1-byte-per-column validity bools."""
|
|
775
|
+
if not data:
|
|
776
|
+
return [True] * column_count
|
|
777
|
+
return [bool(b) for b in data[:column_count]]
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def _serialize_schema_bytes(schema: pa.Schema) -> bytes:
    """Encode an Arrow schema as an IPC stream that contains no record batches.

    Args:
        schema: The schema to serialize.

    Returns:
        The bytes of the schema-only IPC stream.
    """
    buffer_stream = pa.BufferOutputStream()
    # Opening a stream writer and closing it straight away, without writing
    # any batch, leaves a schema-only stream in the sink.
    writer = pa.ipc.new_stream(buffer_stream, schema)
    writer.close()
    return buffer_stream.getvalue().to_pybytes()
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
# Arrow schema for the serialized window-partition cache payload stored in
# FunctionStorage. Each payload is a single-row batch; columns are binary
# blobs plus an int64 row count and the window-state class name as a string.
# Only `window_state` may be null.
_WINDOW_PARTITION_CACHE_FIELDS: list[pa.Field[Any]] = [
    pa.field(field_name, field_type, nullable=is_nullable)
    for field_name, field_type, is_nullable in (
        ("partition_batch", pa.binary(), False),
        ("output_schema", pa.binary(), False),
        ("filter_mask", pa.binary(), False),
        ("frame_stats", pa.binary(), False),
        ("all_valid", pa.binary(), False),
        ("row_count", pa.int64(), False),
        ("window_state", pa.binary(), True),
        ("window_state_class_name", pa.string(), False),
    )
]
_WINDOW_PARTITION_CACHE_SCHEMA = pa.schema(_WINDOW_PARTITION_CACHE_FIELDS)
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
def _encode_window_partition_cache(
    *,
    partition_batch_bytes: bytes,
    output_schema_bytes: bytes,
    filter_mask_bytes: bytes,
    frame_stats_bytes: bytes,
    all_valid_bytes: bytes,
    row_count: int,
    window_state_bytes: bytes | None,
    window_state_class_name: str,
) -> bytes:
    """Serialize one window-partition cache entry as a single-row Arrow IPC stream.

    All arguments are keyword-only and map one-to-one onto the columns of
    ``_WINDOW_PARTITION_CACHE_SCHEMA``. ``window_state_bytes`` is the only
    value that may be ``None`` (its column is nullable).

    Returns:
        IPC stream bytes readable by ``_decode_window_partition_cache``.
    """
    row_values = (
        ("partition_batch", partition_batch_bytes),
        ("output_schema", output_schema_bytes),
        ("filter_mask", filter_mask_bytes),
        ("frame_stats", frame_stats_bytes),
        ("all_valid", all_valid_bytes),
        ("row_count", row_count),
        ("window_state", window_state_bytes),
        ("window_state_class_name", window_state_class_name),
    )
    # Wrap each value in a one-element list so the batch has exactly one row.
    single_row = {column: [value] for column, value in row_values}
    batch = pa.record_batch(single_row, schema=_WINDOW_PARTITION_CACHE_SCHEMA)

    out_stream = pa.BufferOutputStream()
    with pa.ipc.new_stream(out_stream, batch.schema) as ipc_writer:
        ipc_writer.write_batch(batch)
    return out_stream.getvalue().to_pybytes()
|
|
831
|
+
|
|
832
|
+
|
|
833
|
+
def _decode_window_partition_cache(data: bytes) -> dict[str, Any]:
    """Read back the single cache row written by ``_encode_window_partition_cache``.

    Args:
        data: Arrow IPC stream bytes whose first batch holds the cache row.

    Returns:
        The row as a ``column name -> Python value`` dict.

    Raises:
        ValueError: If the first batch does not contain exactly one row.
    """
    reader = pa.ipc.open_stream(data)
    batch = reader.read_next_batch()
    if batch.num_rows == 1:
        return batch.to_pylist()[0]
    raise ValueError(f"Expected 1 cache row, got {batch.num_rows}")
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
class _WindowStatePlaceholder:
|
|
842
|
+
"""Lazy window-state holder passed to user's ``window()``.
|
|
843
|
+
|
|
844
|
+
Carries the raw bytes and class name from ``window_init``'s return value.
|
|
845
|
+
The user's ``window()`` implementation typically calls ``.deserialize(cls)``
|
|
846
|
+
to rebuild a real dataclass instance, or inspects ``.raw_bytes`` directly.
|
|
847
|
+
"""
|
|
848
|
+
|
|
849
|
+
__slots__ = ("raw_bytes", "class_name")
|
|
850
|
+
|
|
851
|
+
def __init__(self, raw_bytes: bytes, class_name: str) -> None:
|
|
852
|
+
self.raw_bytes = raw_bytes
|
|
853
|
+
self.class_name = class_name
|
|
854
|
+
|
|
855
|
+
def deserialize(self, cls: type[Any]) -> Any:
|
|
856
|
+
"""Deserialize the stored bytes via ``cls.deserialize_from_bytes``."""
|
|
857
|
+
return cls.deserialize_from_bytes(self.raw_bytes)
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
def _build_scalar_result_batch(result_value: Any, output_schema: pa.Schema) -> pa.RecordBatch:
    """Wrap a single window result in a one-row ``RecordBatch``.

    Accepted shapes for ``result_value``:

    * a ``pa.RecordBatch`` — returned unchanged if it has exactly one row
      (``output_schema`` is not consulted in this case);
    * a ``pa.Array`` of length 1 — used as the output column;
    * anything else — treated as a scalar and wrapped in a one-element array
      of the output column's type.

    Raises:
        ValueError: If a RecordBatch or Array result does not have exactly one
            row/element, or if ``output_schema`` does not have exactly one
            field (checked only for non-RecordBatch results).
    """
    if isinstance(result_value, pa.RecordBatch):
        if result_value.num_rows == 1:
            return result_value
        raise ValueError(f"window() must return a scalar or a 1-row RecordBatch, got {result_value.num_rows} rows")

    if len(output_schema) != 1:
        raise ValueError(f"Window aggregate output_schema must have 1 field, got {len(output_schema)}")
    out_field = output_schema.field(0)

    if not isinstance(result_value, pa.Array):
        column = pa.array([result_value], type=out_field.type)
    elif len(result_value) == 1:
        column = result_value
    else:
        raise ValueError(f"window() array result must have length 1, got {len(result_value)}")
    return pa.record_batch({out_field.name: column}, schema=output_schema)
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
def _build_batch_result(
    results: list[Any] | pa.Array[Any],
    output_schema: pa.Schema,
    expected_count: int | None = None,
) -> pa.RecordBatch:
    """Assemble the batched window results into a single-column ``RecordBatch``.

    Args:
        results: Either a Python list, converted with ``pa.array(...)`` to the
            output column's type, or a ready-made ``pa.Array`` that must
            already have exactly that type (the path ``window_batch``
            overrides use to avoid per-row Python overhead).
        output_schema: Schema of the result; must have exactly one field.
        expected_count: When given, the number of rows the result must have.

    Raises:
        ValueError: If ``output_schema`` does not have exactly one field, or
            the row count differs from ``expected_count``.
        TypeError: If a ``pa.Array`` result has a type other than the output
            column's type.
    """
    if len(output_schema) != 1:
        raise ValueError(f"Window aggregate output_schema must have 1 field, got {len(output_schema)}")
    out_field = output_schema.field(0)
    output_type = out_field.type

    arr: pa.Array[Any]
    if not isinstance(results, pa.Array):
        arr = pa.array(results, type=output_type)
    elif results.type.equals(output_type):
        arr = results
    else:
        raise TypeError(f"window_batch returned pa.Array of type {results.type}, expected {output_type}")

    count_mismatch = expected_count is not None and len(arr) != expected_count
    if count_mismatch:
        raise ValueError(f"window_batch returned {len(arr)} rows, expected {expected_count}")
    return pa.record_batch({out_field.name: arr}, schema=output_schema)
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
# ---------------------------------------------------------------------------
|
|
913
|
+
# Catalog opaque-data AEAD envelopes
|
|
914
|
+
# ---------------------------------------------------------------------------
|
|
915
|
+
#
|
|
916
|
+
# ``attach_opaque_data`` and ``transaction_opaque_data`` are implementation-
|
|
917
|
+
# chosen byte strings the catalog round-trips through the client. They may
|
|
918
|
+
# carry credentials, and nothing stops principal A from presenting principal
|
|
919
|
+
# B's value. On HTTP transport (where one worker authenticates many
|
|
920
|
+
# principals) the worker therefore seals each value in an authenticated-
|
|
921
|
+
# encrypted envelope whose AAD binds the caller's ``(domain, principal)``;
|
|
922
|
+
# a request from a different principal produces different AAD and fails the
|
|
923
|
+
# tag check. The transaction envelope additionally binds its parent attach
|
|
924
|
+
# envelope, so a transaction value cannot be lifted onto a different attach.
|
|
925
|
+
#
|
|
926
|
+
# Subprocess / unix-socket transports have no signing key (``_signing_key``
|
|
927
|
+
# is ``None``): the helpers pass values through unchanged, since OS process
|
|
928
|
+
# ownership already enforces identity there.
|
|
929
|
+
|
|
930
|
+
_ATTACH_AAD_PREFIX = b"vgi.attach_opaque_data.v1\x00"
|
|
931
|
+
_TRANSACTION_AAD_PREFIX = b"vgi.transaction_opaque_data.v1\x00"
|
|
932
|
+
# v2: the inner attach plaintext is ``uuid(16) || catalog_bytes`` — catalog_attach
|
|
933
|
+
# prepends a framework-minted 16-byte UUID that storage shards on (see
|
|
934
|
+
# ``_AttachUnwrapper``). Bumped from 1 so a stale v1 token (no uuid prefix) is
|
|
935
|
+
# cleanly rejected at open rather than mis-parsed as ``uuid=catalog_bytes[:16]``.
|
|
936
|
+
_ATTACH_ENVELOPE_VERSION = 2
|
|
937
|
+
_TRANSACTION_ENVELOPE_VERSION = 2
|
|
938
|
+
# Width of the framework UUID prepended to every attach plaintext.
|
|
939
|
+
_ATTACH_UUID_LEN = 16
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
def _identity_tail(auth: AuthContext | None) -> bytes:
|
|
943
|
+
"""Build the identity portion of an opaque-data AAD from an auth context.
|
|
944
|
+
|
|
945
|
+
Mirrors the ``(domain, principal)`` convention vgi-rpc uses for HTTP
|
|
946
|
+
state tokens: unauthenticated requests get a fixed anonymous tail, so an
|
|
947
|
+
anonymous caller cannot open an envelope sealed for a real principal.
|
|
948
|
+
"""
|
|
949
|
+
if auth is None or not auth.authenticated:
|
|
950
|
+
return b"\x00anonymous"
|
|
951
|
+
domain = (auth.domain or "").encode()
|
|
952
|
+
principal = (auth.principal or "").encode()
|
|
953
|
+
return b"\x01" + domain + b"\x00" + principal
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def _short_hash(value: bytes | str | None, *, length: int = 12) -> str | None:
|
|
957
|
+
"""Return a stable hex prefix of ``sha256(value)`` — never the value itself.
|
|
958
|
+
|
|
959
|
+
Matches ``vgi_rpc.sentry.short_hash`` so a value redacted here hashes to
|
|
960
|
+
the same token vgi-rpc's dispatch hook uses for Sentry tags. Defined
|
|
961
|
+
locally (rather than imported) because ``vgi_rpc.sentry`` pulls in
|
|
962
|
+
``sentry_sdk``, which is an optional extra; opaque-data redaction must
|
|
963
|
+
work whether or not Sentry is installed.
|
|
964
|
+
"""
|
|
965
|
+
if value is None:
|
|
966
|
+
return None
|
|
967
|
+
if isinstance(value, bytes):
|
|
968
|
+
value = value.hex()
|
|
969
|
+
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:length]
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
def _attach_aad(auth: AuthContext | None) -> bytes:
    """Build the AAD for an ``attach_opaque_data`` envelope.

    The result is the attach AAD prefix followed by the caller's identity
    tail, so an envelope sealed for one identity fails to open for another.
    """
    identity = _identity_tail(auth)
    return _ATTACH_AAD_PREFIX + identity
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def _transaction_aad(auth: AuthContext | None, attach_envelope: bytes) -> bytes:
    """Build the AAD for a ``transaction_opaque_data`` envelope.

    Layout: transaction AAD prefix, the caller's identity tail, a NUL
    separator, then the parent attach envelope bytes. Binding the attach
    envelope as well as the identity means a transaction value minted under
    one attach cannot be replayed against a different attach, even by the
    same principal.
    """
    identity = _identity_tail(auth)
    bound_to_attach = b"\x00" + attach_envelope
    return _TRANSACTION_AAD_PREFIX + identity + bound_to_attach
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
class Worker:
|
|
988
|
+
"""Base class for VGI workers that host user-defined functions.
|
|
989
|
+
|
|
990
|
+
Subclass this and define a `functions` class attribute listing your function
|
|
991
|
+
classes. Function names are derived from metadata (Meta.name or snake_case
|
|
992
|
+
of class name). The worker handles the VGI protocol via vgi_rpc.RpcServer.
|
|
993
|
+
|
|
994
|
+
Multiple functions can share the same name if they have different argument
|
|
995
|
+
signatures (function overloading). The worker will select the appropriate
|
|
996
|
+
function based on the invocation's arguments.
|
|
997
|
+
|
|
998
|
+
Catalog Interface:
|
|
999
|
+
If `catalog_interface` is not set but `functions` is non-empty, a default
|
|
1000
|
+
read-only catalog interface is created automatically. This exposes the
|
|
1001
|
+
worker's functions via the catalog protocol, allowing clients to discover
|
|
1002
|
+
available functions.
|
|
1003
|
+
|
|
1004
|
+
To customize the catalog, set `catalog_interface` to a CatalogInterface
|
|
1005
|
+
subclass. To disable the catalog entirely, set `catalog_interface = None`
|
|
1006
|
+
and `catalog_name = None`.
|
|
1007
|
+
|
|
1008
|
+
"""
|
|
1009
|
+
|
|
1010
|
+
functions: Sequence[type[Function]] = []
|
|
1011
|
+
# Protocol class handed to RpcServer. Defaults to the real VgiProtocol;
|
|
1012
|
+
# test fixtures override this with a VgiProtocol subclass that redeclares
|
|
1013
|
+
# ``protocol_version`` to exercise the framework's version-mismatch
|
|
1014
|
+
# enforcement end-to-end. vgi-rpc reads the version via ``vars(protocol)``,
|
|
1015
|
+
# so the override must redeclare ``protocol_version`` on its own class body.
|
|
1016
|
+
protocol_class: ClassVar[type[VgiProtocol]] = VgiProtocol # type: ignore[type-abstract]
|
|
1017
|
+
catalog_interface: type[CatalogInterface] | None = None
|
|
1018
|
+
catalog_name: str | None = "functions" # Set to None to disable default catalog
|
|
1019
|
+
catalog: Catalog | None = None
|
|
1020
|
+
_registry: dict[str, list[type[Function]]] | None = None
|
|
1021
|
+
_default_catalog_interface: type[CatalogInterface] | None = None
|
|
1022
|
+
_setting_specs: list[SettingSpec] = [] # Extracted from Settings inner class
|
|
1023
|
+
_secret_type_specs: list[SecretTypeSpec] = [] # Secret types to register
|
|
1024
|
+
_attach_option_specs: list[AttachOptionSpec] = [] # Extracted from AttachOptions inner class
|
|
1025
|
+
|
|
1026
|
+
# AEAD key for sealing catalog opaque-data envelopes. Set per-instance by
|
|
1027
|
+
# the HTTP serving path (``vgi.serve.create_app`` / the test HTTP server);
|
|
1028
|
+
# stays ``None`` for subprocess / unix-socket workers, where the helpers
|
|
1029
|
+
# pass opaque values through unchanged. See the module-level
|
|
1030
|
+
# "Catalog opaque-data AEAD envelopes" section.
|
|
1031
|
+
_signing_key: bytes | None = None
|
|
1032
|
+
|
|
1033
|
+
@final
|
|
1034
|
+
@staticmethod
|
|
1035
|
+
def _validate_required_settings(func_cls: type[Function], request: BindRequest) -> None:
|
|
1036
|
+
"""Validate required settings for a bind request."""
|
|
1037
|
+
meta = func_cls.get_metadata()
|
|
1038
|
+
if not meta.required_settings:
|
|
1039
|
+
return
|
|
1040
|
+
|
|
1041
|
+
settings: set[str] = set()
|
|
1042
|
+
if request.settings is not None and request.settings.schema is not None:
|
|
1043
|
+
settings = set(list(request.settings.schema.names))
|
|
1044
|
+
|
|
1045
|
+
missing = [s for s in meta.required_settings if s not in settings]
|
|
1046
|
+
if missing:
|
|
1047
|
+
raise ValueError(f"Function '{request.function_name}' requires settings: {missing}")
|
|
1048
|
+
|
|
1049
|
+
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Collect declarative specs from a new Worker subclass.

    Runs once per subclass definition and:

    * extracts setting specs from an inner ``Settings`` class, if present;
    * extracts attach-option specs from an inner ``AttachOptions`` class;
    * copies a ``secret_types`` list attribute (only a real ``list`` is
      honoured);
    * resets each of the three spec lists to ``[]`` when its source is
      absent, so specs are never shared with the parent class by accident;
    * pushes the non-empty spec lists onto an explicitly configured
      ``catalog_interface`` so ``catalogs()``/``catalog_attach()`` can
      serialize them.
    """
    super().__init_subclass__(**kwargs)

    settings_cls = getattr(cls, "Settings", None)
    cls._setting_specs = extract_setting_specs(settings_cls) if isinstance(settings_cls, type) else []

    attach_options_cls = getattr(cls, "AttachOptions", None)
    cls._attach_option_specs = (
        extract_attach_option_specs(attach_options_cls) if isinstance(attach_options_cls, type) else []
    )

    declared_secret_types = getattr(cls, "secret_types", None)
    cls._secret_type_specs = list(declared_secret_types) if isinstance(declared_secret_types, list) else []

    interface = cls.catalog_interface
    if interface is None:
        return

    # Inject into the explicit catalog interface, once, at class definition.
    # Each list is copied, and only written when it is non-empty and the
    # interface already exposes the corresponding attribute.
    injections = (
        ("settings", cls._setting_specs),
        ("secret_types", cls._secret_type_specs),
        ("attach_option_specs", cls._attach_option_specs),
    )
    for attr_name, specs in injections:
        if specs and hasattr(interface, attr_name):
            setattr(interface, attr_name, list(specs))
|
|
1081
|
+
|
|
1082
|
+
@classmethod
|
|
1083
|
+
def _build_registry(cls) -> dict[str, list[type[Function]]]:
|
|
1084
|
+
"""Build function name -> list of classes mapping from functions list.
|
|
1085
|
+
|
|
1086
|
+
Multiple functions can share the same name if they have different
|
|
1087
|
+
argument signatures (overloading).
|
|
1088
|
+
|
|
1089
|
+
Supports both patterns:
|
|
1090
|
+
- Legacy: cls.functions list
|
|
1091
|
+
- Declarative: cls.catalog.schemas[*].functions
|
|
1092
|
+
"""
|
|
1093
|
+
if cls._registry is not None:
|
|
1094
|
+
return cls._registry
|
|
1095
|
+
|
|
1096
|
+
registry: dict[str, list[type[Function]]] = {}
|
|
1097
|
+
|
|
1098
|
+
seen: set[type[Function]] = set()
|
|
1099
|
+
|
|
1100
|
+
def add_function(func_cls: type[Function]) -> None:
|
|
1101
|
+
if func_cls in seen:
|
|
1102
|
+
return
|
|
1103
|
+
seen.add(func_cls)
|
|
1104
|
+
meta = func_cls.get_metadata()
|
|
1105
|
+
if meta.name not in registry:
|
|
1106
|
+
registry[meta.name] = []
|
|
1107
|
+
registry[meta.name].append(func_cls)
|
|
1108
|
+
|
|
1109
|
+
# Legacy pattern: functions list
|
|
1110
|
+
for func_cls in cls.functions:
|
|
1111
|
+
add_function(func_cls)
|
|
1112
|
+
|
|
1113
|
+
# Declarative pattern: functions in catalog schemas
|
|
1114
|
+
if cls.catalog is not None:
|
|
1115
|
+
for schema in cls.catalog.schemas:
|
|
1116
|
+
for func_cls in schema.functions:
|
|
1117
|
+
add_function(func_cls)
|
|
1118
|
+
|
|
1119
|
+
# Auto-register functions referenced by table descriptors
|
|
1120
|
+
for table in schema.tables:
|
|
1121
|
+
# Scan function (Table.function)
|
|
1122
|
+
if table.function is not None:
|
|
1123
|
+
add_function(table.function)
|
|
1124
|
+
# Write functions
|
|
1125
|
+
for attr in ("insert_function", "update_function", "delete_function"):
|
|
1126
|
+
write_func = getattr(table, attr, None)
|
|
1127
|
+
if write_func is not None:
|
|
1128
|
+
add_function(write_func)
|
|
1129
|
+
|
|
1130
|
+
cls._registry = registry
|
|
1131
|
+
return registry
|
|
1132
|
+
|
|
1133
|
+
@classmethod
|
|
1134
|
+
def _get_catalog_interface(cls) -> type[CatalogInterface] | None:
|
|
1135
|
+
"""Get the catalog interface to use for this worker.
|
|
1136
|
+
|
|
1137
|
+
Returns the explicitly set catalog_interface if present. Otherwise:
|
|
1138
|
+
- If `catalog` attribute is set (new pattern), creates a default
|
|
1139
|
+
ReadOnlyCatalogInterface using the Catalog object.
|
|
1140
|
+
- If `catalog_name` and `functions` are set (legacy pattern), creates
|
|
1141
|
+
a default ReadOnlyCatalogInterface exposing the functions.
|
|
1142
|
+
|
|
1143
|
+
Returns:
|
|
1144
|
+
CatalogInterface class to instantiate, or None if no catalog.
|
|
1145
|
+
|
|
1146
|
+
"""
|
|
1147
|
+
# Use explicit catalog_interface if set (settings injected in __init_subclass__)
|
|
1148
|
+
if cls.catalog_interface is not None:
|
|
1149
|
+
return cls.catalog_interface
|
|
1150
|
+
|
|
1151
|
+
# Check for new Catalog object or legacy patterns
|
|
1152
|
+
catalog_obj = cls.catalog
|
|
1153
|
+
has_catalog = catalog_obj is not None
|
|
1154
|
+
has_legacy = cls.catalog_name is not None and cls.functions
|
|
1155
|
+
|
|
1156
|
+
if not has_catalog and not has_legacy:
|
|
1157
|
+
return None
|
|
1158
|
+
|
|
1159
|
+
# Create default catalog interface if not already created
|
|
1160
|
+
if cls._default_catalog_interface is None:
|
|
1161
|
+
from vgi.catalog import ReadOnlyCatalogInterface
|
|
1162
|
+
|
|
1163
|
+
attrs: dict[str, Any] = {
|
|
1164
|
+
"settings": list(cls._setting_specs),
|
|
1165
|
+
"secret_types": list(cls._secret_type_specs),
|
|
1166
|
+
"attach_option_specs": list(cls._attach_option_specs),
|
|
1167
|
+
}
|
|
1168
|
+
|
|
1169
|
+
if has_catalog:
|
|
1170
|
+
# New pattern: use Catalog object
|
|
1171
|
+
assert catalog_obj is not None
|
|
1172
|
+
attrs["catalog"] = catalog_obj
|
|
1173
|
+
attrs["catalog_name"] = catalog_obj.name
|
|
1174
|
+
else:
|
|
1175
|
+
# Legacy pattern: use class attributes
|
|
1176
|
+
attrs["catalog_name"] = cls.catalog_name
|
|
1177
|
+
attrs["functions"] = list(cls.functions)
|
|
1178
|
+
|
|
1179
|
+
cls._default_catalog_interface = cast(
|
|
1180
|
+
type[CatalogInterface],
|
|
1181
|
+
type(
|
|
1182
|
+
f"{cls.__name__}Catalog",
|
|
1183
|
+
(ReadOnlyCatalogInterface,),
|
|
1184
|
+
attrs,
|
|
1185
|
+
),
|
|
1186
|
+
)
|
|
1187
|
+
|
|
1188
|
+
return cls._default_catalog_interface
|
|
1189
|
+
|
|
1190
|
+
@final
@classmethod
def main(cls) -> None:
    """Run this worker as a CLI application with logging options.

    Transport selection:

    * default — serve over stdin/stdout (pipe transport);
    * ``--http`` — serve over HTTP instead;
    * ``--unix PATH`` — bind an AF_UNIX socket, print ``UNIX:<abs-path>`` on
      stdout once bound, and shut down after ``--idle-timeout`` seconds idle
      (a non-positive timeout disables idle shutdown).

    ``--http`` and ``--unix`` are mutually exclusive.

    Supports ``--quiet``, ``--debug``, ``--log-level``,
    ``--log-logger``, and ``--log-format`` for logging control. Setting the
    ``VGI_WORKER_DEBUG`` environment variable to ``1``/``true``/``yes``
    (case-insensitive) has the same effect as ``--debug``.

    HTTP-specific options (only used with ``--http``):
    ``--host``, ``--port``, ``--prefix``, ``--cors-origins``,
    ``--describe/--no-describe``, ``--port-file``, ``--http-threads``.

    Requires the ``http`` extra for HTTP mode: ``pip install vgi[http]``
    """
    import typer

    from vgi.logging_config import configure_worker_logging

    app = typer.Typer(add_completion=False)

    @app.command()
    def _run(
        quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress startup warning"),
        debug: bool = typer.Option(False, "--debug", help="Enable DEBUG on all vgi + vgi_rpc loggers"),
        log_level: LogLevel = typer.Option(LogLevel.INFO, "--log-level", help="Set log level"),  # noqa: B008
        log_logger: list[str] | None = typer.Option(  # noqa: B008
            None, "--log-logger", help="Target specific logger(s)"
        ),
        log_format: LogFormat = typer.Option(  # noqa: B008
            LogFormat.text, "--log-format", help="Stderr log format"
        ),
        # HTTP transport options
        http: bool = typer.Option(False, "--http", help="Serve over HTTP instead of stdin/stdout"),
        host: str = typer.Option("127.0.0.1", "--host", help="HTTP bind address"),
        port: int = typer.Option(0, "--port", "-p", help="HTTP port (0 = auto-select)"),
        prefix: str = typer.Option("", "--prefix", help="URL prefix for RPC endpoints"),
        cors_origins: str = typer.Option("*", "--cors-origins", help="Allowed CORS origins"),
        describe: bool = typer.Option(  # noqa: B008
            True, "--describe/--no-describe", help="Enable description pages (worker + RPC API)"
        ),
        port_file: str | None = typer.Option(
            None,
            "--port-file",
            help=(
                "Write the bound port number (one line, no prefix) to this file before starting "
                "to serve. For test harnesses / process managers that need the port side-channel "
                "without parsing stdout."
            ),
        ),
        # AF_UNIX launcher contract — mutually exclusive with --http.
        # When --unix PATH is passed, the worker binds an AF_UNIX
        # socket, prints UNIX:<abs-path> to stdout, and self-shuts-down
        # after --idle-timeout seconds with zero connected clients.
        unix: str | None = typer.Option(
            None,
            "--unix",
            help="Bind to this AF_UNIX socket path instead of stdin/stdout (mutex with --http).",
        ),
        idle_timeout: float = typer.Option(
            300.0,
            "--idle-timeout",
            help="Self-shutdown after N seconds idle when serving --unix.",
        ),
        http_threads: int | None = typer.Option(  # noqa: B008
            None,
            "--http-threads",
            help=(
                "waitress worker thread count (only when --http). Default None "
                "uses waitress's default (4). Raise this when many concurrent "
                "process() ticks would otherwise queue on the WSGI threadpool."
            ),
        ),
    ) -> None:
        # The environment switch and the --debug flag are OR-ed together.
        debug_from_env = os.environ.get("VGI_WORKER_DEBUG", "").lower() in ("1", "true", "yes")
        effective_level = configure_worker_logging(
            debug=debug or debug_from_env,
            log_level=log_level,
            log_loggers=log_logger,
            log_format=log_format,
        )

        if http and unix is not None:
            raise typer.BadParameter("--http and --unix are mutually exclusive")

        if http:
            # HTTP transport.
            from vgi.serve import (
                _maybe_init_sentry,
                _resolve_authenticate,
                _resolve_oauth_resource_metadata,
                _resolve_otel_config,
            )

            _maybe_init_sentry()
            http_authenticate = _resolve_authenticate()
            http_oauth_metadata = _resolve_oauth_resource_metadata()
            http_otel = _resolve_otel_config()
            cls._run_http(
                effective_level=effective_level,
                host=host,
                port=port,
                prefix=prefix,
                cors_origins=cors_origins,
                describe=describe,
                authenticate=http_authenticate,
                oauth_resource_metadata=http_oauth_metadata,
                otel_config=http_otel,
                port_file=port_file,
                http_threads=http_threads,
            )
            return

        if unix is not None:
            # AF_UNIX launcher path. Bind to the requested socket,
            # print UNIX:<abs_path> on stdout, idle-shutdown after
            # idle_timeout seconds.
            from vgi_rpc.rpc import serve_unix

            from vgi.serve import _maybe_init_sentry, _resolve_otel_config

            _maybe_init_sentry()
            unix_otel = _resolve_otel_config()
            worker = cls(quiet=quiet, log_level=effective_level)
            server = RpcServer(cls.protocol_class, worker, server_version=_get_vgi_version())
            if unix_otel is not None:
                from vgi_rpc.otel import instrument_server

                instrument_server(server, unix_otel)
                worker._vgi_tracer = VgiTracer.create(unix_otel)
            socket_path = os.path.abspath(unix)
            # A non-positive timeout means "never idle-shutdown".
            idle_limit = idle_timeout if idle_timeout > 0 else None
            serve_unix(
                server,
                socket_path,
                threaded=True,
                idle_timeout=idle_limit,
                # Announce the bound path on stdout for the launching process.
                on_bound=lambda bound: print(f"UNIX:{bound}", flush=True),
            )
            return

        # Default: pipe transport over stdin/stdout.
        from vgi.serve import _maybe_init_sentry, _resolve_otel_config

        _maybe_init_sentry()
        pipe_otel = _resolve_otel_config()
        cls(quiet=quiet, log_level=effective_level).run(otel_config=pipe_otel)

    app()
|
|
1341
|
+
|
|
1342
|
+
@final
@classmethod
def main_http(cls) -> None:
    """Launch this worker as an HTTP-only server behind a small Typer CLI.

    ``main()`` with ``--http`` offers the same HTTP capabilities and also
    keeps pipe transport as its default, so prefer it for new entry points.
    This method remains for backward compatibility and for entry points
    that are always HTTP-only.

    The single CLI command configures logging, initialises the optional
    Sentry / authentication / OAuth-metadata / OpenTelemetry hooks from
    ``vgi.serve`` and then hands over to :meth:`_run_http`.

    Requires the ``http`` extra: ``pip install vgi[http]``
    """
    import typer

    from vgi.logging_config import configure_worker_logging

    app = typer.Typer(add_completion=False)

    @app.command()
    def _run(
        host: str = typer.Option("127.0.0.1", "--host", "-h", help="Bind address"),
        port: int = typer.Option(0, "--port", "-p", help="Bind port (0 = auto-select)"),
        prefix: str = typer.Option("", "--prefix", help="URL prefix for RPC endpoints"),
        cors_origins: str = typer.Option("*", "--cors-origins", help="Allowed CORS origins"),
        describe: bool = typer.Option(  # noqa: B008
            True, "--describe/--no-describe", help="Enable description pages (worker + RPC API)"
        ),
        debug: bool = typer.Option(False, "--debug", help="Enable DEBUG on all vgi + vgi_rpc loggers"),
        log_level: LogLevel = typer.Option(LogLevel.INFO, "--log-level", help="Set log level"),  # noqa: B008
        log_logger: list[str] | None = typer.Option(  # noqa: B008
            None, "--log-logger", help="Target specific logger(s)"
        ),
        log_format: LogFormat = typer.Option(  # noqa: B008
            LogFormat.text, "--log-format", help="Stderr log format"
        ),
        http_threads: int | None = typer.Option(  # noqa: B008
            None,
            "--http-threads",
            help=(
                "waitress worker thread count. Default None uses waitress's "
                "default (4). Raise this when many concurrent process() ticks "
                "would otherwise queue on the WSGI threadpool — typical sign "
                "is 'Task queue depth is N' messages from waitress."
            ),
        ),
    ) -> None:
        # VGI_WORKER_DEBUG in the environment has the same effect as --debug.
        debug_from_env = os.environ.get("VGI_WORKER_DEBUG", "").lower() in ("1", "true", "yes")
        resolved_level = configure_worker_logging(
            debug=debug or debug_from_env,
            log_level=log_level,
            log_loggers=log_logger,
            log_format=log_format,
        )

        # Imported only after logging is configured, as in the pipe/unix paths.
        from vgi.serve import (
            _maybe_init_sentry,
            _resolve_authenticate,
            _resolve_oauth_resource_metadata,
            _resolve_otel_config,
        )

        _maybe_init_sentry()
        http_options: dict[str, Any] = {
            "effective_level": resolved_level,
            "host": host,
            "port": port,
            "prefix": prefix,
            "cors_origins": cors_origins,
            "describe": describe,
            # Resolved in the same order as before: auth, OAuth metadata, OTel.
            "authenticate": _resolve_authenticate(),
            "oauth_resource_metadata": _resolve_oauth_resource_metadata(),
            "otel_config": _resolve_otel_config(),
            "http_threads": http_threads,
        }
        cls._run_http(**http_options)

    app()
|
|
1422
|
+
|
|
1423
|
+
@classmethod
def _run_http(
    cls,
    *,
    effective_level: int,
    host: str,
    port: int,
    prefix: str,
    cors_origins: str,
    describe: bool,
    authenticate: Any = None,
    oauth_resource_metadata: Any = None,
    otel_config: Any = None,
    port_file: str | None = None,
    http_threads: int | None = None,
) -> None:
    """Serve this worker over HTTP with waitress (used by ``main`` and ``main_http``).

    Args:
        effective_level: Log level forwarded to ``create_app``.
        host: Address to bind.
        port: Port to bind; ``0`` picks a free port by probing the OS.
        prefix: URL prefix for the RPC endpoints.
        cors_origins: Allowed CORS origins, forwarded to ``create_app``.
        describe: Requested describe-pages flag (passed through
            ``_resolve_describe`` before use).
        authenticate: Optional authentication hook for ``create_app``.
        oauth_resource_metadata: Optional OAuth resource metadata for ``create_app``.
        otel_config: Optional OpenTelemetry configuration for ``create_app``.
        port_file: When set, the chosen port is also written to this file.
        http_threads: waitress worker thread count; ``None`` keeps waitress's default.

    Exits the process with status 1 when waitress is not installed.
    """
    import socket

    try:
        import waitress  # type: ignore[import-untyped]
    except ImportError:
        sys.stderr.write(
            "Error: waitress not installed.\nInstall with: pip install vgi[http] (or: uv sync --extra http)\n"
        )
        sys.exit(1)

    bind_port = port
    if bind_port == 0:
        # Ask the OS for a free port by binding a throwaway socket.
        # NOTE(review): the probe socket is closed before waitress binds, so
        # another process could in principle grab the port in between.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind((host, 0))
            bind_port = int(probe.getsockname()[1])

    from vgi.serve import _resolve_describe, _resolve_signing_key, create_app

    app_options: dict[str, Any] = {
        "prefix": prefix,
        "cors_origins": cors_origins,
        "describe": _resolve_describe(describe),
        "signing_key": _resolve_signing_key(),
        "log_level": effective_level,
        "authenticate": authenticate,
        "oauth_resource_metadata": oauth_resource_metadata,
        "otel_config": otel_config,
    }
    wsgi_app = create_app(cls, **app_options)

    # Side-channel port publication for test harnesses (see _write_port_file).
    if port_file is not None:
        _write_port_file(port_file, bind_port)

    # Machine-readable port line on stdout for process managers and tests,
    # followed by a log record and a human-readable banner on stderr.
    print(f"PORT:{bind_port}", flush=True)
    _logger.info("http_server_starting host=%s port=%d prefix=%s", host, bind_port, prefix)
    sys.stderr.write(f"Serving {cls.__name__} on http://{host}:{bind_port}{prefix}\n")
    sys.stderr.flush()

    # ``asyncore_use_poll=True`` makes waitress's loop use ``select.poll()``
    # instead of ``select.select()``. select() is limited by FD_SETSIZE
    # (1024 on Darwin/Linux); a long-running worker holding many sockets can
    # exceed it, after which the loop fails with ``ValueError: filedescriptor
    # out of range in select()`` and stops accepting connections.
    serve_options: dict[str, Any] = dict(
        host=host,
        port=bind_port,
        _quiet=True,
        asyncore_use_poll=True,
    )
    # ``threads`` sizes waitress's worker pool. Callers issuing more
    # concurrent requests than the pool size queue on it (waitress reports
    # "Task queue depth is N"); ``--http-threads`` lets them size it.
    if http_threads is not None:
        serve_options["threads"] = http_threads
    waitress.serve(wsgi_app, **serve_options)
|
|
1512
|
+
|
|
1513
|
+
@staticmethod
def _match_function_arguments(
    *,
    function_name: str,
    arguments: Arguments,
    input_schema: pa.Schema | None,
    candidates: Sequence[type[Function]],
) -> type[Function]:
    """Pick the overload whose parameters fit the invocation's arguments.

    Each candidate's declared parameters are compared with the positional
    and named arguments of the call. This lets several function classes
    share one name while accepting different argument lists or types.

    Args:
        function_name: Name shared by the candidates (used in error messages).
        arguments: The arguments the function was called with.
        input_schema: The input schema passed to the function, if any.
        candidates: Function classes registered under ``function_name``.

    Returns:
        The single matching function class.

    Raises:
        ValueError: If no candidate matches, or more than one still matches
            after type-based filtering.

    """
    args = arguments
    supplied = len(args.positional)
    named_keys = set(args.named.keys()) if args.named else set()

    matches: list[type[Function]] = []
    for candidate in candidates:
        declared = candidate.get_metadata().parameters
        # Positional parameters, excluding the TableInput placeholder.
        by_index = [p for p in declared if isinstance(p.position, int) and not p.is_table_input]

        if issubclass(candidate, ScalarFunctionGenerator):
            # Scalar functions: column params are bound from the input
            # batch, so only ConstParams (is_const) are expected in
            # invocation.arguments. They are all required, and named
            # arguments are not supported.
            consts = [p for p in by_index if p.is_const]
            variadic = any(p.is_varargs for p in consts)
            count_ok = supplied >= len(consts) if variadic else supplied == len(consts)
            if not count_ok or named_keys:
                continue
        else:
            # Table functions: every parameter comes from invocation.arguments.
            minimum = sum(1 for p in by_index if p.required)
            variadic = any(p.is_varargs for p in by_index)
            if supplied < minimum:
                continue  # Too few positional arguments
            if not variadic and supplied > len(by_index):
                continue  # Too many positional arguments

            by_name = [p for p in declared if isinstance(p.position, str)]
            allowed = {p.position for p in by_name}
            mandatory = {p.position for p in by_name if p.required}
            if not named_keys <= allowed:
                continue  # Unknown named argument
            if not mandatory <= named_keys:
                continue  # Missing required named argument

        matches.append(candidate)

    # Several overloads agree on counts: narrow further by argument types.
    if len(matches) > 1:
        matches = Worker._filter_by_argument_types(
            matches, args, input_schema, is_scalar=issubclass(matches[0], ScalarFunctionGenerator)
        )

    if len(matches) == 1:
        return matches[0]

    if matches:
        match_names = [m.__name__ for m in matches]
        raise ValueError(f"Ambiguous function call '{function_name}': multiple overloads match: {match_names}")

    # Nothing matched: list every overload's signature in the error.
    param_summaries = []
    for candidate in candidates:
        visible = [p for p in candidate.get_metadata().parameters if not p.is_table_input]
        param_str = ", ".join(
            f"{p.name}: {p.type_name or '?'}" + ("" if p.required else f" = {p.default}") for p in visible
        )
        param_summaries.append(f" {candidate.__name__}({param_str})")

    # Input columns are only known (and shown) when a schema was supplied.
    input_schema_str = ""
    if input_schema is not None:
        cols = [f"{f.name}: {f.type}" for f in input_schema]
        input_schema_str = f"input_columns=[{', '.join(cols)}], "

    raise ValueError(
        f"No matching function '{function_name}' for arguments: "
        f"{input_schema_str}{_format_arguments_for_error(args)}. "
        f"Available overloads:\n" + "\n".join(param_summaries)
    )
|
|
1655
|
+
|
|
1656
|
+
@staticmethod
def _types_compatible(actual: pa.DataType, declared: pa.DataType) -> bool:
    """Report whether ``actual`` may be passed where ``declared`` is expected.

    Two types are compatible when they are equal or belong to the same type
    family (integers with integers, strings with strings, and so on). This
    accommodates DuckDB sending a narrower type than the function declares,
    e.g. int32 for a literal that fits, or decimal for a numeric literal.

    """
    if actual == declared:
        return True

    types = pa.types
    families = (
        # int8/16/32/64 and uint8/16/32/64
        types.is_integer,
        # float16/32/64 and decimal
        lambda t: types.is_floating(t) or types.is_decimal(t),
        # string and large_string
        lambda t: types.is_string(t) or types.is_large_string(t),
        # binary and large_binary
        lambda t: types.is_binary(t) or types.is_large_binary(t),
        types.is_boolean,
    )
    return any(in_family(actual) and in_family(declared) for in_family in families)
|
|
1688
|
+
|
|
1689
|
+
# Points awarded by _score_types when an actual type equals the declared type.
_EXACT_MATCH_SCORE = 2
# Points awarded when the types differ but _types_compatible accepts the pair.
_FAMILY_MATCH_SCORE = 1
|
|
1691
|
+
|
|
1692
|
+
@staticmethod
def _score_types(
    specs: list[ArgumentSpec],
    actual_types: Sequence[pa.DataType | None],
) -> tuple[int, bool]:
    """Rate how closely the actual argument types follow the declared specs.

    Each spec's ``arrow_type`` is compared with the actual type at the same
    index. Actual types beyond ``len(specs)`` are compared with the varargs
    spec, when one exists. Specs that accept any type (or declare the null
    type) and actual types that are ``None`` are skipped without scoring.

    Args:
        specs: Declared argument specs (ordered by position).
        actual_types: Actual types aligned 1-to-1 with *specs*, with any
            additional varargs tail elements appended.

    Returns:
        ``(score, matched)`` — the cumulative score and whether every
        compared type was compatible. On the first incompatible type the
        score accumulated so far is returned with ``matched=False``.

    """

    def _accepts_anything(spec):
        return spec.is_any_type or spec.arrow_type == pa.null()

    def _points(actual, declared):
        """Return the score for one pair, or ``None`` when incompatible."""
        if actual == declared:
            return Worker._EXACT_MATCH_SCORE
        if Worker._types_compatible(actual, declared):
            return Worker._FAMILY_MATCH_SCORE
        return None

    total = 0
    supplied = len(actual_types)
    variadic: ArgumentSpec | None = None

    for index, spec in enumerate(specs):
        if spec.is_varargs:
            variadic = spec
        if index >= supplied:
            break
        actual = actual_types[index]
        if _accepts_anything(spec) or actual is None:
            continue
        gained = _points(actual, spec.arrow_type)
        if gained is None:
            return total, False
        total += gained

    # Remaining actual types form the varargs tail beyond the declared specs.
    if variadic is not None and not _accepts_anything(variadic):
        for index in range(len(specs), supplied):
            actual = actual_types[index]
            if actual is None:
                continue
            gained = _points(actual, variadic.arrow_type)
            if gained is None:
                return total, False
            total += gained

    return total, True
|
|
1747
|
+
|
|
1748
|
+
@staticmethod
def _filter_by_argument_types(
    matches: list[type[Function]],
    arguments: Arguments,
    input_schema: pa.Schema | None,
    *,
    is_scalar: bool,
) -> list[type[Function]]:
    """Keep only the overloads whose declared types best fit the call.

    Used when count-based filtering leaves more than one candidate. The
    declared ``arrow_type`` of each parameter (from
    ``extract_argument_specs``) is scored against the actual argument
    types; candidates with an incompatible type are dropped and only the
    top-scoring ones are returned.

    Args:
        matches: Candidate function classes (same arg count).
        arguments: The invocation arguments.
        input_schema: Input schema for scalar functions (column types).
        is_scalar: Whether the candidates are scalar functions.

    Returns:
        The candidates sharing the highest score; empty if none is compatible.

    """

    def _positional_types():
        # A missing positional argument contributes ``None`` (not scored).
        return [None if arg is None else arg.type for arg in arguments.positional]

    def _column_types(col_specs):
        """Collect input-schema types for the column specs plus the varargs tail."""
        width = len(input_schema)
        found = []
        tail_owner = None
        for spec in col_specs:
            if spec.is_varargs:
                tail_owner = spec
            pos = spec.position
            assert isinstance(pos, int)
            found.append(input_schema.field(pos).type if pos < width else None)
        if tail_owner is not None:
            # Columns after the varargs parameter's position belong to its tail.
            assert isinstance(tail_owner.position, int)
            found.extend(input_schema.field(i).type for i in range(tail_owner.position + 1, width))
        return found

    survivors: list[tuple[int, type[Function]]] = []

    for candidate in matches:
        specs = extract_argument_specs(candidate)

        if is_scalar:
            # ConstParam specs are scored against arguments.positional,
            # column Param specs against the input schema's field types.
            const_specs = [s for s in specs if s.is_const]
            col_specs = [s for s in specs if not s.is_const and isinstance(s.position, int)]

            total, ok = Worker._score_types(const_specs, _positional_types())
            if ok and input_schema is not None:
                gained, ok = Worker._score_types(col_specs, _column_types(col_specs))
                total += gained
        else:
            # Table functions: every positional spec is scored against
            # arguments.positional, in position order.
            ordered = sorted(
                (s for s in specs if isinstance(s.position, int) and not s.is_table_input),
                key=lambda s: s.position,
            )
            total, ok = Worker._score_types(ordered, _positional_types())

        if ok:
            survivors.append((total, candidate))

    if not survivors:
        return []

    # Highest score wins (most exact type matches); ties are all returned.
    best = max(total for total, _ in survivors)
    return [candidate for total, candidate in survivors if total == best]
|
|
1835
|
+
|
|
1836
|
+
@staticmethod
def _suggest_similar_names(name: str, candidates: list[str]) -> list[str]:
    """Find function names similar to the given name.

    Candidates are compared case-insensitively and ranked in tiers:

    0. the candidate starts with ``name``;
    1. ``name`` starts with the candidate;
    2. one contains the other as a substring;
    3. the two share more than ``len(name) // 2`` distinct characters
       (a cheap typo heuristic) — more shared characters rank first.

    Candidates that fit no tier are omitted. Ties are broken alphabetically.

    Args:
        name: The unknown function name.
        candidates: List of valid function names.

    Returns:
        List of similar names, sorted by relevance.

    """
    if not candidates:
        return []

    needle = name.lower()
    needle_chars = set(needle)
    overlap_floor = len(needle) // 2

    # Entries are (tier, -overlap, candidate). The tier is a separate sort
    # key so a large character overlap can never outrank (or tie with) a
    # prefix/substring match, which a single ``10 - overlap`` score allowed
    # once the overlap reached 8 or more characters.
    ranked: list[tuple[int, int, str]] = []

    for candidate in candidates:
        hay = candidate.lower()

        if hay.startswith(needle):
            ranked.append((0, 0, candidate))
        elif needle.startswith(hay):
            ranked.append((1, 0, candidate))
        elif needle in hay or hay in needle:
            ranked.append((2, 0, candidate))
        else:
            overlap = len(needle_chars & set(hay))
            # Require more than half of the name's length in shared characters.
            if overlap > overlap_floor:
                ranked.append((3, -overlap, candidate))

    ranked.sort()
    return [candidate for _, _, candidate in ranked]
|
|
1879
|
+
|
|
1880
|
+
def _resolve_function(self, request: BindRequest) -> type[Function]:
    """Find the function class a bind request refers to.

    The name is looked up in the registry; when several overloads share the
    name, the request's arguments and input schema pick one of them.

    Args:
        request: The BindRequest containing function_name and arguments.

    Returns:
        The matching function class.

    Raises:
        ValueError: If the name is unknown (the message lists up to three
            similar names and all available functions), or if overload
            matching fails or is ambiguous.

    """
    registry = self._build_registry()
    wanted = request.function_name

    if wanted not in registry:
        known = sorted(registry.keys())
        hints = self._suggest_similar_names(wanted, known)
        lines = [f"Unknown function: '{wanted}'"]
        if hints:
            lines.append(" Did you mean:")
            lines.extend(f" - {hint}" for hint in hints[:3])
        lines.append(f" Available functions: {known}")
        raise ValueError("\n".join(lines))

    overloads = registry[wanted]
    if len(overloads) != 1:
        return self._match_function_arguments(
            function_name=wanted,
            arguments=request.arguments,
            input_schema=request.input_schema,
            candidates=overloads,
        )
    # A single registration needs no argument matching.
    return overloads[0]
|
|
1915
|
+
|
|
1916
|
+
def _resolve_function_by_name(
    self,
    function_name: str,
    attach_opaque_data: bytes | None = None,
    function_type: type[Function] | None = None,
) -> type[Function]:
    """Return a registered function class by name, without matching arguments.

    Args:
        function_name: The name of the function to look up.
        attach_opaque_data: Optional attach ID (reserved for future catalog
            use; not read here).
        function_type: Optional base class; only candidates that subclass
            it are considered.

    Returns:
        The first remaining candidate. When several overloads remain, the
        first is still returned (overload disambiguation happens at bind
        time on the C++ side).

    Raises:
        ValueError: If the name is not registered, or no candidate is a
            subclass of ``function_type``.

    """
    registry = self._build_registry()
    if function_name not in registry:
        raise ValueError(f"Unknown function: '{function_name}'. Available: {sorted(registry.keys())}")

    remaining = registry[function_name]
    if function_type is not None:
        remaining = [cls for cls in remaining if issubclass(cls, function_type)]
        if not remaining:
            kind = function_type.__name__
            raise ValueError(f"No {kind} named '{function_name}' found. Candidates exist but are not {kind}.")

    # One candidate or many, the first one is the answer.
    return remaining[0]
|
|
1947
|
+
|
|
1948
|
+
# ---------------------------------------------------------------------------
|
|
1949
|
+
# Catalog helpers
|
|
1950
|
+
# ---------------------------------------------------------------------------
|
|
1951
|
+
|
|
1952
|
+
# NOTE(review): presumably holds the catalog implementation once one has been
# created, with None meaning "not created yet" — confirm against the code that
# assigns it, which is outside this section.
_catalog_instance: CatalogInterface | None = None
|
|
1953
|
+
|
|
1954
|
+
# --- Catalog opaque-data AEAD envelopes -----------------------------------
|
|
1955
|
+
#
|
|
1956
|
+
# The catalog implementation always sees plaintext; the client (and C++)
|
|
1957
|
+
# only ever see sealed envelopes. ``seal_*`` runs on the way out of
|
|
1958
|
+
# ``catalog_attach`` / ``catalog_transaction_begin``; ``unwrap_*`` runs at
|
|
1959
|
+
# the top of every other catalog / function-dispatch path that carries an
|
|
1960
|
+
# opaque value. When ``_signing_key`` is ``None`` (subprocess / unix
|
|
1961
|
+
# transports) every helper is a transparent pass-through.
|
|
1962
|
+
|
|
1963
|
+
@staticmethod
def _opaque_data_rejected(field: str) -> ValueError:
    """Create (without raising) the one error used for any unopenable opaque value.

    Wrong principal, wrong parent attach, tampering, malformed input and
    plain unknown values all produce this same message, so a probing caller
    cannot tell the failure modes apart.

    Args:
        field: Name of the rejected field, used as the message prefix.

    Returns:
        A ``ValueError`` for the caller to raise.
    """
    message = f"{field} not recognized"
    return ValueError(message)
|
|
1972
|
+
|
|
1973
|
+
def _seal_attach(self, plaintext: bytes) -> AttachOpaqueData:
    """Wrap a plaintext ``attach_opaque_data`` value for the trip to the client.

    With a signing key, the plaintext is sealed into an AEAD envelope whose
    associated data is derived from the current auth context. Without a
    signing key the plaintext is passed through unchanged.
    """
    signing_key = self._signing_key
    if signing_key is not None:
        from vgi_rpc import crypto

        sealed = crypto.seal_bytes(
            plaintext,
            signing_key,
            aad=_attach_aad(current_auth()),
            version=_ATTACH_ENVELOPE_VERSION,
        )
        return AttachOpaqueData(sealed)
    return AttachOpaqueData(plaintext)
|
|
1983
|
+
|
|
1984
|
+
def _unwrap_attach_full_with_auth(self, envelope: bytes | None, auth: AuthContext | None) -> bytes | None:
    """Open an attach envelope under the given ``auth`` and return all of its plaintext.

    The full framework plaintext is ``uuid(16) || catalog_bytes``.

    Args:
        envelope: The sealed value, or ``None``.
        auth: Auth context used to derive the envelope's associated data.

    Returns:
        ``None`` for a ``None`` envelope; a ``bytes`` copy of the envelope
        when there is no signing key (pass-through); otherwise the opened
        plaintext.

    Raises:
        ValueError: The uniform :meth:`_opaque_data_rejected` error, on any
            failure to open the envelope.
    """
    if envelope is None:
        return None

    signing_key = self._signing_key
    if signing_key is None:
        return bytes(envelope)

    from vgi_rpc import crypto

    try:
        plaintext = crypto.open_bytes(
            envelope,
            signing_key,
            aad=_attach_aad(auth),
            version=_ATTACH_ENVELOPE_VERSION,
        )
    except crypto.SealError as exc:
        raise self._opaque_data_rejected("attach_opaque_data") from exc
    else:
        return plaintext
|
|
2002
|
+
|
|
2003
|
+
@overload
def _unwrap_attach(self, envelope: bytes) -> AttachOpaqueData: ...
@overload
def _unwrap_attach(self, envelope: None) -> None: ...
def _unwrap_attach(self, envelope: bytes | None) -> AttachOpaqueData | None:
    """Open an ``attach_opaque_data`` envelope and return the catalog's own bytes.

    The envelope is opened under the current auth context and the leading
    framework UUID is removed, leaving the bytes the catalog returned at
    ``catalog_attach`` — the form catalog_* methods and function resolvers
    want. Storage routing needs the UUID and uses
    :meth:`_unwrap_attach_full` instead. Without a signing key the envelope
    is passed through, but the prefix is still stripped.

    Returns ``None`` when ``envelope`` is ``None``.
    """
    if envelope is None:
        return None

    opened = self._unwrap_attach_full_with_auth(envelope, current_auth())
    assert opened is not None  # a non-None envelope always yields plaintext
    catalog_bytes = bytes(opened[_ATTACH_UUID_LEN:])
    return AttachOpaqueData(catalog_bytes)
|
|
2021
|
+
|
|
2022
|
+
def _unwrap_attach_full(self, envelope: bytes | None) -> bytes | None:
    """Open an attach envelope under the current auth, keeping the UUID prefix.

    Returns the full framework plaintext ``uuid(16) || catalog_bytes``.
    Storage shards on that leading UUID, so function-execution paths pass this
    unstripped form on to ``BoundStorage`` / params.
    """
    caller_auth = current_auth()
    return self._unwrap_attach_full_with_auth(envelope, caller_auth)
|
|
2031
|
+
|
|
2032
|
+
def _unwrap_attach_full_for(self, request: Any) -> bytes | None:
    """Locate the sealed attach on ``request`` and open it via :meth:`_unwrap_attach_full`.

    The top-level ``attach_opaque_data`` attribute is preferred. If it is
    missing or ``None``, ``bind_call.attach_opaque_data`` is used instead.
    When neither is present, ``None`` is handed to :meth:`_unwrap_attach_full`.
    """
    top_level = getattr(request, "attach_opaque_data", None)
    if top_level is not None:
        return self._unwrap_attach_full(top_level)

    nested_call = getattr(request, "bind_call", None)
    if nested_call is None:
        return self._unwrap_attach_full(None)
    return self._unwrap_attach_full(getattr(nested_call, "attach_opaque_data", None))
|
|
2044
|
+
|
|
2045
|
+
def _bound(self, storage: Any, execution_id: bytes, request: Any) -> BoundStorage:
    """Construct the ``BoundStorage`` for ``request``.

    The request's sealed attach is opened to its full plaintext
    (``uuid(16) || catalog_bytes``) and passed as ``attach_plaintext``; storage
    shards on the leading UUID. Keeping the construction here lets every
    handler build its storage with a single call.
    """
    attach_plaintext = self._unwrap_attach_full_for(request)
    return BoundStorage(storage, execution_id, request=request, attach_plaintext=attach_plaintext)
|
|
2058
|
+
|
|
2059
|
+
def _seal_transaction(self, plaintext: bytes, attach_envelope: bytes) -> bytes:
    """Seal a plaintext ``transaction_opaque_data`` value into an AEAD envelope.

    ``attach_envelope`` is the (sealed) ``attach_opaque_data`` carried by the
    same call. It is folded into the AAD together with the current auth
    context, which ties the transaction to its parent attach. With no signing
    key the plaintext is returned unchanged.
    """
    signing_key = self._signing_key
    if signing_key is not None:
        from vgi_rpc import crypto

        return crypto.seal_bytes(
            plaintext,
            signing_key,
            aad=_transaction_aad(current_auth(), attach_envelope),
            version=_TRANSACTION_ENVELOPE_VERSION,
        )
    return plaintext
|
|
2077
|
+
|
|
2078
|
+
@overload
def _unwrap_transaction(self, envelope: bytes, attach_envelope: bytes) -> TransactionOpaqueData: ...
@overload
def _unwrap_transaction(self, envelope: None, attach_envelope: bytes) -> None: ...
def _unwrap_transaction(self, envelope: bytes | None, attach_envelope: bytes) -> TransactionOpaqueData | None:
    """Open a ``transaction_opaque_data`` envelope and return its plaintext.

    ``attach_envelope`` is the (sealed) ``attach_opaque_data`` the same call
    carries. It is part of the AAD, so it must match the attach the
    transaction was minted under or the open fails with the uniform
    :meth:`_opaque_data_rejected` error. ``None`` in gives ``None`` out; with
    no signing key the envelope is wrapped unchanged.
    """
    if envelope is None:
        return None
    signing_key = self._signing_key
    if signing_key is None:
        return TransactionOpaqueData(envelope)
    from vgi_rpc import crypto

    try:
        opened = crypto.open_bytes(
            envelope,
            signing_key,
            aad=_transaction_aad(current_auth(), attach_envelope),
            version=_TRANSACTION_ENVELOPE_VERSION,
        )
    except crypto.SealError as seal_error:
        raise self._opaque_data_rejected("transaction_opaque_data") from seal_error
    else:
        return TransactionOpaqueData(opened)
|
|
2106
|
+
|
|
2107
|
+
def _get_catalog(self) -> CatalogInterface:
    """Return this worker's ``CatalogInterface`` instance, creating it on first use.

    The instance is stored on ``self._catalog_instance`` and reused on later
    calls, so catalog state persists across RPC calls for the lifetime of the
    worker.

    Returns:
        CatalogInterface instance.

    Raises:
        ValueError: If ``_get_catalog_interface()`` yields no catalog class.

    """
    instance = self._catalog_instance
    if instance is None:
        catalog_class = self._get_catalog_interface()
        if catalog_class is None:
            raise ValueError(
                "CatalogInterface invocation received but no catalog is available. "
                "Either set catalog_interface class attribute to a CatalogInterface "
                "subclass, or ensure functions are defined and catalog_name is set."
            )
        instance = catalog_class()
        self._catalog_instance = instance
    return instance
|
|
2132
|
+
|
|
2133
|
+
@staticmethod
def _options_batch_to_dict(batch: pa.RecordBatch | None) -> dict[str, Any]:
    """Return the first row of an options RecordBatch as a dict.

    A missing batch or a batch with zero rows yields an empty dict.
    """
    has_rows = batch is not None and batch.num_rows != 0
    if not has_rows:
        return {}
    rows = batch.to_pylist()
    return rows[0]
|
|
2139
|
+
|
|
2140
|
+
# ---------------------------------------------------------------------------
|
|
2141
|
+
# VgiProtocol implementation - bind/init
|
|
2142
|
+
# ---------------------------------------------------------------------------
|
|
2143
|
+
|
|
2144
|
+
def bind(self, request: BindRequest, ctx: CallContext) -> BindResponse:
    """Resolve the requested function and delegate binding to a fresh instance.

    Implements VgiProtocol.bind(). Tags the current tracing span with the
    function name/type and caller identity, opens the request's sealed attach,
    resolves the function class, validates its required settings, and returns
    the result of the instance's own ``bind()``.
    """
    auth = ctx.auth
    span_attributes = {
        "vgi.function.name": request.function_name,
        "vgi.function.type": request.function_type.value,
        "vgi.principal": auth.principal,
        "vgi.auth_domain": auth.domain,
        "vgi.authenticated": auth.authenticated,
    }
    self._vgi_tracer.set_current_span_attributes(span_attributes)
    # vgi.attach_opaque_data / vgi.transaction_opaque_data are not tagged here:
    # vgi-rpc's Sentry dispatch hook adds them (short-hash form) on every
    # method that carries them.
    #
    # The request holds the SEALED attach and the worker holds the signing key,
    # so it is opened once here to the full framework plaintext
    # (uuid||catalog_bytes) and threaded down. Storage shards on the leading
    # UUID; function bodies receive the catalog bytes (uuid stripped) via
    # params.attach_opaque_data.
    sealed_attach = getattr(request, "attach_opaque_data", None)
    attach_plaintext = self._unwrap_attach_full(sealed_attach)

    func_cls = self._resolve_function(request)
    self._validate_required_settings(func_cls, request)
    function = func_cls(logger=_logger)
    return function.bind(request, ctx=ctx, attach_plaintext=attach_plaintext)  # type: ignore[attr-defined, no-any-return]
|
|
2170
|
+
|
|
2171
|
+
def table_function_cardinality(
    self, request: TableFunctionCardinalityRequest, ctx: CallContext
) -> TableCardinality:
    """Estimate the cardinality of a table function's output.

    Implements VgiProtocol.table_function_cardinality(). The function class is
    resolved from ``request.bind_call`` and its ``cardinality()`` hook is
    called with bind params built from that call.

    Raises:
        ValueError: If the resolved class is not a ``TableFunctionGenerator``.
    """
    bind_call = request.bind_call
    attach_plaintext = self._unwrap_attach_full(getattr(bind_call, "attach_opaque_data", None))
    func_cls = self._resolve_function(bind_call)
    if issubclass(func_cls, TableFunctionGenerator):
        bind_params = func_cls._make_bind_params(
            bind_call,
            auth_context=ctx.auth,
            attach_plaintext=attach_plaintext,
        )
        return func_cls.cardinality(bind_params)
    raise ValueError(
        "Cardinality estimation is only supported for table"
        f" functions, but '{func_cls.__name__}' is not a TableFunctionGenerator."
    )
|
|
2188
|
+
|
|
2189
|
+
def table_function_statistics(self, request: TableFunctionStatisticsRequest, ctx: CallContext) -> bytes | None:
    """Return per-column statistics for a table function's output.

    Implements VgiProtocol.table_function_statistics(). Returns the
    ``serialize_column_statistics`` output for the function's ``statistics()``
    hook (same wire shape as catalog_table_column_statistics_get). Returns
    ``None`` when the resolved class is not a ``TableFunctionGenerator`` or the
    hook reports nothing.
    """
    bind_call = request.bind_call
    attach_plaintext = self._unwrap_attach_full(getattr(bind_call, "attach_opaque_data", None))
    func_cls = self._resolve_function(bind_call)
    if not issubclass(func_cls, TableFunctionGenerator):
        return None

    bind_params = func_cls._make_bind_params(
        bind_call,
        auth_context=ctx.auth,
        attach_plaintext=attach_plaintext,
    )
    column_stats = func_cls.statistics(bind_params)
    if column_stats:
        return serialize_column_statistics(column_stats, cache_max_age_seconds=None)
    return None
|
|
2206
|
+
|
|
2207
|
+
def table_function_dynamic_to_string(
    self, request: TableFunctionDynamicToStringRequest, ctx: CallContext
) -> TableFunctionDynamicToStringResponse:
    """Return user diagnostics for EXPLAIN ANALYZE Extra Info.

    Implements VgiProtocol.table_function_dynamic_to_string(). This is
    best-effort: a failure while opening the attach or resolving the function,
    or an exception from the user's ``dynamic_to_string`` hook, is logged and
    answered with an empty response instead of propagating. A class that is
    not a ``TableFunctionGenerator``, or an empty mapping, also yields the
    empty response. Keys and values are stringified with ``str()``.
    """
    nothing = TableFunctionDynamicToStringResponse(keys=[], values=[])

    # Deliberately broad: diagnostics must never abort the query.
    try:
        attach_plaintext = self._unwrap_attach_full(getattr(request.bind_call, "attach_opaque_data", None))
        func_cls = self._resolve_function(request.bind_call)
    except Exception:
        _logger.exception("dynamic_to_string: failed to resolve function class")
        return nothing

    if not issubclass(func_cls, TableFunctionGenerator):
        return nothing

    try:
        bind_params = func_cls._make_bind_params(
            request.bind_call,
            auth_context=ctx.auth,
            execution_id=request.global_execution_id,
            attach_plaintext=attach_plaintext,
        )
        diagnostics = func_cls.dynamic_to_string(bind_params, request.global_execution_id)
    except Exception:
        _logger.exception("dynamic_to_string: user hook raised on %s", func_cls.__name__)
        return nothing

    if not diagnostics:
        return nothing

    rendered = [(str(key), str(value)) for key, value in diagnostics.items()]
    return TableFunctionDynamicToStringResponse(
        keys=[key for key, _ in rendered],
        values=[value for _, value in rendered],
    )
|
|
2245
|
+
|
|
2246
|
+
# ========== Aggregate Function Methods ==========
|
|
2247
|
+
|
|
2248
|
+
def _load_aggregate_const_args(
    self,
    func_cls: type[AggregateFunction],  # type: ignore[type-arg]
    storage: BoundStorage,
) -> Arguments | None:
    """Return the const arguments stored at aggregate_bind time (group_id=-2), if any.

    Results are memoized in ``_aggregate_const_args_cache`` under the storage's
    shard key and execution id. The const args are written once at bind time
    and never change, so one storage read serves the whole execution. A miss
    in storage is cached as ``_ABSENT_SENTINEL`` so that later calls return
    ``None`` without another read. This matters for windowed aggregates, which
    would otherwise pay one storage round trip per output row.
    """
    from vgi.arguments import Arguments

    shard_key = storage._shard_key
    execution_id = storage._execution_id
    hit = _aggregate_const_args_cache.get(shard_key, execution_id)
    if hit is _ABSENT_SENTINEL:
        # Known-absent: storage was already checked for this execution.
        return None
    if hit is None:
        # Const args live at the synthetic group_id -2 in namespace
        # FrameworkNS.AGGREGATE_STATE, sharing the row-keyspace of the rest of
        # aggregate_state. The negative key cannot collide with
        # caller-supplied group_ids, which are non-negative.
        raw = storage.state_get(FrameworkNS.AGGREGATE_STATE, BoundStorage.pack_int_key(-2))
        if raw is None:
            # Most aggregates have no const params; remember the negative
            # result so later aggregate_* calls skip storage.
            _aggregate_const_args_cache.put(shard_key, execution_id, _ABSENT_SENTINEL)
            return None
        const_args = Arguments.deserialize_from_bytes(raw)
        _aggregate_const_args_cache.put(shard_key, execution_id, const_args)
        return const_args
    return cast(Arguments, hit)
|
|
2286
|
+
|
|
2287
|
+
def aggregate_bind(
    self,
    request: AggregateBindRequest,
    ctx: CallContext,
) -> AggregateBindResponse:
    """Bind an aggregate function and return its output schema plus a new execution_id.

    Raises:
        TypeError: If the resolved function is not an ``AggregateFunction``.
        ArgumentValidationError: If a varargs compute parameter would bind to
            zero input columns.
        NotImplementedError: If the function's bind step requested secret
            resolution.
    """
    from vgi.protocol import AggregateBindResponse

    catalog_attach = self._unwrap_attach(request.attach_opaque_data)
    func_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_attach,
        function_type=AggregateFunction,
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    # Mirror the scalar varargs guard: a Param(varargs=True) on update() must
    # bind to at least one input column. Otherwise a zero-column call fails
    # much later inside update() with an opaque "missing 1 required
    # positional argument".
    for name, arg in (getattr(func_cls, "_compute_params", {}) or {}).items():
        if getattr(arg, "varargs", False):
            first_col = 0 if arg._resolution_index is None else arg._resolution_index
            available = 0 if request.input_schema is None else len(request.input_schema)
            if first_col >= available:
                from vgi.arguments import ArgumentValidationError

                raise ArgumentValidationError(
                    f"Varargs parameter '{name}' requires at least 1 value.",
                    arg_name=name,
                    position=arg.position,
                    constraint="varargs requires at least 1 value",
                    doc=arg.doc or None,
                )

    execution_id = uuid.uuid4().bytes
    bind_params = AggregateBindParams(
        args=request.arguments,
        input_schema=request.input_schema,
        settings=_batch_to_scalar_dict(request.settings),
        secrets=SecretsAccessor(request.secrets),
        auth_context=ctx.auth,
    )
    bind_result = func_cls.on_bind(bind_params)

    if bind_params.secrets.needs_resolution:
        raise NotImplementedError(
            f"Aggregate function '{request.function_name}' requires secret resolution, "
            "which is not yet supported for aggregate functions."
        )

    # Stash const arguments in FunctionStorage for later callbacks (synthetic
    # group_id=-2 in namespace FrameworkNS.AGGREGATE_STATE).
    arguments = request.arguments
    if arguments and arguments.positional:
        bound = self._bound(func_cls.storage, execution_id, request)
        bound.state_put(
            FrameworkNS.AGGREGATE_STATE,
            BoundStorage.pack_int_key(-2),
            arguments.serialize_to_bytes(),
        )

    return AggregateBindResponse(output_schema=bind_result.output_schema, execution_id=execution_id)
|
|
2350
|
+
|
|
2351
|
+
def aggregate_update(
    self,
    request: AggregateUpdateRequest,
    ctx: CallContext,
) -> AggregateUpdateResponse:
    """Accumulate rows from a DataChunk into per-group state.

    The input batch is read from ``request.input_batch``; its group-id column
    is split off and the remaining columns are passed to the user's
    ``update()`` as keyword arguments, together with any const arguments
    stored at bind time. A group's state is persisted when it already existed
    in storage or when ``update()`` explicitly wrote it during this batch.

    Raises:
        TypeError: If the resolved function is not an ``AggregateFunction``.
        ValueError: If the input batch has no unique group-id column, or the
            function defines no ``state_class``.
    """
    import inspect

    from vgi.aggregate_function import GROUP_COLUMN_NAME
    from vgi.protocol import AggregateUpdateResponse

    func_cls = self._resolve_function_by_name(
        request.function_name, self._unwrap_attach(request.attach_opaque_data), function_type=AggregateFunction
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    batch = pa.ipc.open_stream(request.input_batch).read_next_batch()
    storage = self._bound(func_cls.storage, request.execution_id, request)

    # Strip __vgi_group_id and extract group_ids.
    gid_col_idx = batch.schema.get_field_index(GROUP_COLUMN_NAME)
    if gid_col_idx < 0:
        # get_field_index() reports -1 for a missing or duplicated name.
        # Passing that on would index from the end of the batch and treat a
        # data column as the group ids, so fail loudly instead.
        raise ValueError(
            f"Aggregate function '{request.function_name}' received an input batch without a unique "
            f"'{GROUP_COLUMN_NAME}' column"
        )
    group_ids: pa.Int64Array = batch.column(gid_col_idx).cast(pa.int64())  # type: ignore[assignment]
    clean_batch = batch.remove_column(gid_col_idx)

    # Load existing states, create initial_state for new groups.
    unique_gids: list[int] = [v.as_py() for v in group_ids.unique()]

    if func_cls.state_class is None:
        raise ValueError(f"Aggregate function '{request.function_name}' has no state_class defined")
    const_args = self._load_aggregate_const_args(func_cls, storage)
    params = ProcessParams(
        args=const_args,
        init_call=None,
        init_response=None,
        output_schema=pa.schema([]),
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )
    # ``states`` is a tracking dict that records every gid the user's
    # ``update()`` reassigns. A plain dict plus a "skip new groups whose
    # serialized state didn't change" heuristic conflated two cases:
    #
    #   1. ``update()`` saw rows but chose not to mutate state (e.g. a SUM
    #      skipping all-NULL input): finalize should return NULL because the
    #      group effectively had no rows.
    #   2. ``update()`` saw rows and assigned a state that happens to be
    #      byte-equal to the initial state (e.g. a sum of zeros): finalize
    #      should return that state.
    #
    # So a state is persisted iff the user explicitly wrote it during this
    # batch's ``update()``. Entries that already existed from prior batches
    # are persisted as well so multi-batch state survives.
    existing_gids: set[int] = set()
    states: _UpdateTrackingDict[int, Any] = _UpdateTrackingDict()
    gid_keys = [BoundStorage.pack_int_key(g) for g in unique_gids]
    stored = storage.state_get_many(FrameworkNS.AGGREGATE_STATE, gid_keys)
    for i, gid in enumerate(unique_gids):
        value = stored[i]
        if value is not None:
            states[gid] = func_cls.state_class.deserialize_from_bytes(value)
            existing_gids.add(gid)
        else:
            states[gid] = func_cls.initial_state(params)
    # Forget the writes made while seeding so they are not counted as
    # user-initiated mutations.
    states.clear_writes()

    # Call user's update() with column arrays and const scalars as kwargs.
    kwargs: dict[str, Any] = {"states": states, "group_ids": group_ids}
    # ``or {}`` matches the guard aggregate_bind applies to _compute_params.
    compute_params = getattr(func_cls, "_compute_params", {}) or {}
    for name, arg in compute_params.items():
        col_idx = getattr(arg, "_resolution_index", None)
        if col_idx is not None and col_idx < clean_batch.num_columns:
            if getattr(arg, "varargs", False):
                # Varargs: collect all columns from this index onward as a list.
                kwargs[name] = [clean_batch.column(i) for i in range(col_idx, clean_batch.num_columns)]
            else:
                kwargs[name] = clean_batch.column(col_idx)
    # Extract const values from stored arguments.
    const_params = getattr(func_cls, "_const_params", {}) or {}
    const_phases = getattr(func_cls, "_const_param_phases", {}) or {}
    if const_args and const_args.positional and const_params:
        for name, arg in const_params.items():
            phase = const_phases.get(name, "all")
            if phase not in ("all", "update"):
                continue  # Skip finalize-only params during update
            arg_idx = getattr(arg, "_resolution_index", None)
            if arg_idx is not None and arg_idx < len(const_args.positional):
                scalar = const_args.positional[arg_idx]
                kwargs[name] = scalar.as_py() if scalar is not None else None
    # Inject params only for functions whose update() declares it.
    update_sig = inspect.signature(func_cls.update)
    if "params" in update_sig.parameters:
        kwargs["params"] = params
    func_cls.update(**kwargs)

    # Persist (a) every gid that already had storage from a prior batch (its
    # state may have been mutated in place by the user) and (b) every gid the
    # user's ``update()`` explicitly wrote during this batch.
    gids_to_persist = existing_gids | states.written
    items: list[tuple[bytes, bytes]] = [
        (BoundStorage.pack_int_key(gid), states[gid].serialize_to_bytes()) for gid in gids_to_persist
    ]
    if items:
        storage.state_put_many(FrameworkNS.AGGREGATE_STATE, items)

    return AggregateUpdateResponse()
|
|
2462
|
+
|
|
2463
|
+
def aggregate_combine(
    self,
    request: AggregateCombineRequest,
    ctx: CallContext,
) -> AggregateCombineResponse:
    """Merge source group states into target group states.

    ``request.merge_batch`` supplies ``source_group_id`` / ``target_group_id``
    pairs. For each pair the user's ``combine()`` result replaces the target
    state, and all targets that end up with a state are written back.

    Raises:
        TypeError: If the resolved function is not an ``AggregateFunction``.
        ValueError: If the function defines no ``state_class``.
    """
    from vgi.protocol import AggregateCombineResponse

    catalog_attach = self._unwrap_attach(request.attach_opaque_data)
    func_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_attach,
        function_type=AggregateFunction,
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")
    merge_batch = pa.ipc.open_stream(request.merge_batch).read_next_batch()
    storage = self._bound(func_cls.storage, request.execution_id, request)

    if merge_batch.num_rows == 0:
        return AggregateCombineResponse()

    source_ids: list[int] = merge_batch.column("source_group_id").to_pylist()  # type: ignore[assignment]
    target_ids: list[int] = merge_batch.column("target_group_id").to_pylist()  # type: ignore[assignment]

    all_gids: list[int] = list(set(source_ids) | set(target_ids))

    state_cls = func_cls.state_class
    if state_cls is None:
        raise ValueError(f"Aggregate function '{request.function_name}' has no state_class defined")
    const_args = self._load_aggregate_const_args(func_cls, storage)
    params = ProcessParams(
        args=const_args,
        init_call=None,
        init_response=None,
        output_schema=pa.schema([]),
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    # Load only groups that have a stored state. A group that was never
    # updated stays absent from the dict.
    stored = storage.state_get_many(
        FrameworkNS.AGGREGATE_STATE,
        [BoundStorage.pack_int_key(gid) for gid in all_gids],
    )
    states: dict[int, Any] = {
        gid: state_cls.deserialize_from_bytes(stored[i])
        for i, gid in enumerate(all_gids)
        if stored[i] is not None
    }

    # Apply merges. A pair with neither side in storage is skipped. When only
    # one side exists, initial_state() stands in for the missing one so
    # combine() always receives two states.
    for src_gid, tgt_gid in zip(source_ids, target_ids, strict=True):
        src_state = states.get(src_gid)
        tgt_state = states.get(tgt_gid)
        if src_state is None:
            if tgt_state is None:
                continue  # Neither side was ever updated — nothing to merge
            src_state = func_cls.initial_state(params)
        elif tgt_state is None:
            tgt_state = func_cls.initial_state(params)
        states[tgt_gid] = func_cls.combine(src_state, tgt_state, params)

    # Write back every target that now has a state.
    updated_targets = set(target_ids)
    items = [
        (BoundStorage.pack_int_key(gid), states[gid].serialize_to_bytes())
        for gid in updated_targets
        if gid in states
    ]
    if items:
        storage.state_put_many(FrameworkNS.AGGREGATE_STATE, items)

    return AggregateCombineResponse()
|
|
2534
|
+
|
|
2535
|
+
def aggregate_finalize(
    self,
    request: AggregateFinalizeRequest,
    ctx: CallContext,
) -> AggregateFinalizeResponse:
    """Produce results for a chunk of group_ids.

    Stored states are loaded for the requested group ids, with ``None`` for a
    group that has no stored state, and handed to the user's ``finalize()``.
    The returned batch is serialized as an Arrow IPC stream.

    Raises:
        TypeError: If the resolved function is not an ``AggregateFunction``.
        ValueError: If the function defines no ``state_class``, or
            ``finalize()`` does not return exactly one row per group id.
    """
    from vgi.protocol import AggregateFinalizeResponse

    catalog_attach = self._unwrap_attach(request.attach_opaque_data)
    func_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_attach,
        function_type=AggregateFunction,
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")
    group_ids_batch = pa.ipc.open_stream(request.group_ids_batch).read_next_batch()
    group_ids: pa.Int64Array = group_ids_batch.column("group_id").cast(pa.int64())  # type: ignore[assignment]
    gid_list: list[int] = group_ids.to_pylist()  # type: ignore[assignment]

    storage = self._bound(func_cls.storage, request.execution_id, request)

    state_cls = func_cls.state_class
    if state_cls is None:
        raise ValueError(f"Aggregate function '{request.function_name}' has no state_class defined")
    const_args = self._load_aggregate_const_args(func_cls, storage)
    params = ProcessParams(
        args=const_args,
        init_call=None,
        init_response=None,
        output_schema=request.output_schema,
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    stored = storage.state_get_many(
        FrameworkNS.AGGREGATE_STATE,
        [BoundStorage.pack_int_key(gid) for gid in gid_list],
    )
    # A group with no entry in FunctionStorage was never updated. It maps to
    # None so finalize() can return NULL (SQL standard for SUM/AVG/MIN/MAX
    # over zero rows). COUNT handles None -> 0.
    states: dict[int, Any] = {
        gid: (None if stored[i] is None else state_cls.deserialize_from_bytes(stored[i]))
        for i, gid in enumerate(gid_list)
    }

    # Call user's finalize().
    result_batch = func_cls.finalize(group_ids, states, params)

    # Validate: exactly one output row per requested group id.
    expected_rows = len(gid_list)
    if result_batch.num_rows != expected_rows:
        raise ValueError(
            f"finalize() returned {result_batch.num_rows} rows but expected {len(gid_list)} (one per group_id)"
        )

    # Serialize result batch to IPC stream bytes.
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, result_batch.schema) as writer:
        writer.write_batch(result_batch)
    payload = sink.getvalue().to_pybytes()
    return AggregateFinalizeResponse(result_batch=payload)
|
|
2594
|
+
|
|
2595
|
+
def aggregate_destructor(
    self,
    request: AggregateDestructorRequest,
    ctx: CallContext,
) -> AggregateDestructorResponse:
    """Clear stored and cached state for an aggregate execution.

    Raises:
        TypeError: If the resolved function is not an ``AggregateFunction``.
    """
    from vgi.protocol import AggregateDestructorResponse

    catalog_attach = self._unwrap_attach(request.attach_opaque_data)
    func_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_attach,
        function_type=AggregateFunction,
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    # Called once when all states have been destroyed (C++ tracks with
    # destroy_counter == group_id_counter). execution_clear() sweeps every
    # namespace (every FrameworkNS member plus any user-chosen namespaces
    # under b"buf"/etc.) for this execution_id in one call, which subsumes
    # the per-family clears that used to be needed.
    execution_id = request.execution_id
    bound = self._bound(func_cls.storage, execution_id, request)
    bound.execution_clear()

    # Drop the in-process caches tied to this execution as well.
    _window_partition_cache.clear_execution(request.execution_id)
    _aggregate_const_args_cache.clear_execution(bound._shard_key, request.execution_id)

    return AggregateDestructorResponse()
|
|
2621
|
+
|
|
2622
|
+
# ========== Table Sink+Source Function Methods (new buffered API) ==========
|
|
2623
|
+
|
|
2624
|
+
def _load_table_buffering_params(
    self,
    request: Any,
    ctx: CallContext,
    *,
    attach_already_unwrapped: bool = False,
) -> tuple[type[TableBufferingFunction[Any, Any]], TableBufferingParams[Any]]:
    """Cold-load buffering-table init metadata; build ``TableBufferingParams``.

    Accepts either a unary table_buffering_* request (function_name,
    execution_id, attach_opaque_data, transaction_id at top level)
    or an ``InitRequest`` (fields under ``bind_call``).

    ``attach_already_unwrapped`` — pass True from callers that hold
    an already-unwrapped attach (e.g. ``run_table_buffering_finalize_tick``,
    which reads attach off a producer-stream state where init() stashed
    the unwrapped form). The seal is per-caller-auth; re-unwrapping on
    a different HTTP turn fails because ``current_auth()`` no longer
    matches the original sealer's identity.

    Returns:
        ``(func_cls, params)`` — the resolved ``TableBufferingFunction``
        subclass and the parameters rebuilt from the persisted init payload.

    Raises:
        TypeError: if the resolved function is not a ``TableBufferingFunction``.
        OSError: if no init payload is stored for ``request.execution_id``.
    """
    function_name = getattr(request, "function_name", None)
    attach = getattr(request, "attach_opaque_data", None)
    transaction_id = getattr(request, "transaction_id", None)
    # Unary requests carry no ``bind_call``. Only consult it when present,
    # so a legitimately-absent top-level value (e.g. no transaction) does
    # not blow up with AttributeError on a request that was never an
    # InitRequest in the first place.
    bind_call = getattr(request, "bind_call", None)
    if function_name is None:
        # A function name is mandatory; without a top-level one the request
        # must be an InitRequest, so let a missing bind_call surface loudly.
        function_name = request.bind_call.function_name
    if attach is None and bind_call is not None:
        attach = bind_call.attach_opaque_data
    if transaction_id is None and bind_call is not None:
        transaction_id = getattr(bind_call, "transaction_opaque_data", None)

    # ``attach_full`` is the framework plaintext ``uuid(16) || catalog_bytes``.
    # Already-unwrapped callers (buffering finalize/cancel) pass it directly —
    # the auth-scoped seal can't be reopened on their turn; everyone else
    # unwraps the sealed attach here. Storage shards on the UUID; the function
    # resolver and the user field see only the catalog bytes.
    attach_full = attach if attach_already_unwrapped else self._unwrap_attach_full(attach)
    catalog_bytes = attach_catalog_bytes(attach_full)
    func_cls = self._resolve_function_by_name(
        function_name,
        catalog_bytes,
        function_type=TableBufferingFunction,
    )
    if not issubclass(func_cls, TableBufferingFunction):
        raise TypeError(f"Function '{function_name}' is not a TableBufferingFunction (got {func_cls.__name__})")

    # Bind storage against the full plaintext attach so the lookup lands on
    # the same shard that init() wrote to.
    cold_storage = BoundStorage(
        func_cls.storage,
        request.execution_id,
        request=request,
        attach_plaintext=attach_full,
    )
    payload = cold_storage.state_get(
        FrameworkNS.BUFFERING_INIT,
        BoundStorage.pack_int_key(_TABLE_BUFFERING_INIT_KEY),
    )
    if payload is None:
        raise OSError(
            f"table_buffering: unknown execution_id "
            f"{request.execution_id.hex()} "
            f"(init never ran or destructor already fired)"
        )
    init_call, init_response = _decode_table_buffering_init(payload)
    # AttachOpaqueData is a NewType over bytes; just pass the bytes
    # through. Empty when the catalog doesn't seal attaches.
    attach_id = bytes(catalog_bytes or b"")
    params = TableBufferingParams(
        args=func_cls._parse_arguments(
            func_cls.FunctionArguments,
            init_call.bind_call.arguments,
        ),
        init_call=init_call,
        init_response=init_response,
        output_schema=init_call.output_schema,
        settings=_batch_to_scalar_dict(init_call.bind_call.settings),
        secrets=SecretsAccessor(init_call.bind_call.secrets).to_dict(),
        storage=cold_storage,
        auth_context=ctx.auth,
        attach_opaque_data=catalog_bytes,
        execution_id=request.execution_id,
        attach_id=attach_id,
        transaction_id=transaction_id,
        function_name=function_name,
        worker_path=None,
        # process() and combine() are unary RPCs and have no
        # OutputCollector; expose ctx.client_log on params so they can
        # emit log batches that surface in DuckDB's duckdb_logs() with
        # type='VGI'. Mirrors out.client_log() on the streaming
        # finalize() callback.
        client_log=ctx.client_log,
    )
    return func_cls, params
|
|
2717
|
+
|
|
2718
|
+
def table_buffering_process(
    self,
    request: Any,
    ctx: CallContext,
) -> Any:
    """Feed one input batch to the function's ``process()`` (unary RPC).

    Returns a ``TableBufferingProcessResponse`` carrying the opaque
    ``state_id`` chosen by the worker.

    Raises:
        TypeError: if ``process()`` returns something other than
            ``bytes``/``bytearray``.
    """
    from vgi.protocol import TableBufferingProcessResponse

    buffering_cls, call_params = self._load_table_buffering_params(request, ctx)

    # Only stamp batch_index onto the params when the request supplies one.
    batch_index = request.batch_index
    if batch_index is not None:
        call_params = _dataclass_replace(call_params, batch_index=batch_index)

    reader = pa.ipc.open_stream(request.input_batch)
    input_batch = reader.read_next_batch()

    returned_id = buffering_cls.process(input_batch, call_params)
    if isinstance(returned_id, (bytes, bytearray)):
        # Normalise bytearray -> bytes for the wire.
        return TableBufferingProcessResponse(state_id=bytes(returned_id))

    raise TypeError(
        f"{buffering_cls.__name__}.process() returned "
        f"{type(returned_id).__name__}; must return bytes "
        f"(the opaque state_id)"
    )
|
|
2738
|
+
|
|
2739
|
+
def table_buffering_combine(
    self,
    request: Any,
    ctx: CallContext,
) -> Any:
    """Pass every sink ``state_id`` to the user's ``combine()`` at end of input.

    Returns a ``TableBufferingCombineResponse`` holding the finalize state
    ids produced by ``combine()``, each normalised to ``bytes``.

    Raises:
        TypeError: if any element returned by ``combine()`` is not
            ``bytes``/``bytearray``.
    """
    from vgi.protocol import TableBufferingCombineResponse

    buffering_cls, call_params = self._load_table_buffering_params(request, ctx)
    produced = list(buffering_cls.combine(list(request.state_ids), call_params))

    # Validate and normalise in one pass; the first offending element aborts.
    normalised: list[bytes] = []
    for position, candidate in enumerate(produced):
        if not isinstance(candidate, (bytes, bytearray)):
            raise TypeError(
                f"{buffering_cls.__name__}.combine() returned non-bytes "
                f"finalize_state_id at index {position}: "
                f"{type(candidate).__name__}"
            )
        normalised.append(bytes(candidate))

    return TableBufferingCombineResponse(finalize_state_ids=normalised)
|
|
2761
|
+
|
|
2762
|
+
def table_buffering_destructor(
    self,
    request: Any,
    ctx: CallContext,
) -> Any:
    """Clear the execution's buffering storage at end of query.

    Cleanup is best-effort: any exception raised while resolving the
    function or clearing storage is logged and swallowed, and an empty
    ``TableBufferingDestructorResponse`` is returned regardless.
    """
    from vgi.protocol import TableBufferingDestructorResponse

    try:
        catalog_bytes = self._unwrap_attach(request.attach_opaque_data)
        buffering_cls = self._resolve_function_by_name(
            request.function_name,
            catalog_bytes,
            function_type=TableBufferingFunction,
        )
        self._bound(buffering_cls.storage, request.execution_id, request).execution_clear()
    except Exception:
        # Never fail the query because of a cleanup problem; just record it.
        _logger.exception(
            "table_buffering_destructor: storage cleanup failed (execution_id=%s)",
            request.execution_id.hex(),
        )
    return TableBufferingDestructorResponse()
|
|
2784
|
+
|
|
2785
|
+
# ========== Windowed Aggregate Methods ==========
|
|
2786
|
+
|
|
2787
|
+
def aggregate_window_init(
    self,
    request: AggregateWindowInitRequest,
    ctx: CallContext,
) -> AggregateWindowInitResponse:
    """Register one window partition with the worker.

    Decodes the partition from the request, runs the function's
    ``window_init`` hook, persists the encoded partition (plus the
    serialized window state, if any) to storage, and seeds the in-process
    partition cache so later ``aggregate_window`` calls can skip the
    storage read.

    Raises:
        TypeError: if the resolved function is not an ``AggregateFunction``,
            or if ``window_init`` returns a non-``None`` object without a
            ``serialize_to_bytes`` method.
    """
    from vgi.aggregate_function import WindowPartition
    from vgi.protocol import AggregateWindowInitResponse

    catalog_bytes = self._unwrap_attach(request.attach_opaque_data)
    func_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_bytes,
        function_type=AggregateFunction,
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    storage = self._bound(func_cls.storage, request.execution_id, request)
    params = ProcessParams(
        args=self._load_aggregate_const_args(func_cls, storage),
        init_call=None,
        init_response=None,
        output_schema=request.output_schema,
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    # Decode the wire representation of the partition.
    inputs = pa.ipc.open_stream(request.partition_batch).read_next_batch()
    mask = _unpack_bool_mask(request.filter_mask, request.row_count)
    stats = _unpack_frame_stats(request.frame_stats)
    validity = _unpack_all_valid(request.all_valid, inputs.num_columns)
    partition = WindowPartition(
        inputs=inputs,
        row_count=request.row_count,
        filter_mask=mask,
        frame_stats=stats,
        all_valid=validity,
    )

    # Let user code build optional per-partition state and serialize it.
    window_state = func_cls.window_init(partition, params)
    if window_state is None:
        window_state_bytes: bytes | None = None
        state_class_name = ""
    else:
        if not hasattr(window_state, "serialize_to_bytes"):
            raise TypeError(
                f"{func_cls.__name__}.window_init() must return an ArrowSerializableDataclass "
                f"or None, got {type(window_state).__name__}"
            )
        window_state_bytes = window_state.serialize_to_bytes()
        state_class_name = type(window_state).__name__

    # Persist the raw request bytes so another process can rebuild the
    # partition on a cache miss.
    encoded = _encode_window_partition_cache(
        partition_batch_bytes=request.partition_batch,
        output_schema_bytes=_serialize_schema_bytes(request.output_schema),
        filter_mask_bytes=request.filter_mask,
        frame_stats_bytes=request.frame_stats,
        all_valid_bytes=request.all_valid,
        row_count=request.row_count,
        window_state_bytes=window_state_bytes,
        window_state_class_name=state_class_name,
    )
    storage.state_put(
        FrameworkNS.AGGREGATE_WINDOW_PARTITION,
        BoundStorage.pack_int_key(request.partition_id),
        encoded,
    )

    # Seed the in-process cache with the already-decoded partition so
    # aggregate_window() can skip the storage read + deserialize.
    placeholder: Any = None
    if window_state is not None and window_state_bytes is not None:
        placeholder = _WindowStatePlaceholder(
            raw_bytes=window_state_bytes,
            class_name=state_class_name,
        )

    # window_prepare receives the typed window_state just built (not the
    # placeholder); its result is cached alongside the placeholder.
    prepared = func_cls.window_prepare(partition, window_state, params)
    cache_entry = _CachedWindowPartition(
        partition=partition,
        output_schema=request.output_schema,
        window_state=placeholder,
        prepared_state=prepared,
    )
    _window_partition_cache.put(request.execution_id, request.partition_id, cache_entry)

    return AggregateWindowInitResponse()
|
|
2882
|
+
|
|
2883
|
+
def aggregate_window(
    self,
    request: AggregateWindowRequest,
    ctx: CallContext,
) -> AggregateWindowResponse:
    """Evaluate a windowed aggregate for a single output row.

    Looks up the cached partition for ``request.partition_id``, calls the
    function's ``window`` hook with the row id and its sub-frames, and
    returns the value as a one-row IPC-serialized record batch.

    Raises:
        TypeError: if the resolved function is not an ``AggregateFunction``.
        ValueError: if ``frame_starts`` and ``frame_ends`` differ in length
            (raised by ``zip(..., strict=True)``).
    """
    from vgi.protocol import AggregateWindowResponse

    catalog_bytes = self._unwrap_attach(request.attach_opaque_data)
    func_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_bytes,
        function_type=AggregateFunction,
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    storage = self._bound(func_cls.storage, request.execution_id, request)
    entry = self._load_cached_window_partition(
        func_cls,
        request.execution_id,
        request.partition_id,
        storage,
        request.function_name,
    )
    schema = entry.output_schema

    params = ProcessParams(
        args=self._load_aggregate_const_args(func_cls, storage),
        init_call=None,
        init_response=None,
        output_schema=schema,
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    # After a cold reload from FunctionStorage the entry has no
    # prepared_state yet (the loader has no params to call window_prepare
    # with), so fill it in on first use.
    if entry.prepared_state is None:
        entry.prepared_state = func_cls.window_prepare(entry.partition, entry.window_state, params)

    frame_pairs = list(zip(request.frame_starts, request.frame_ends, strict=True))
    value = func_cls.window(request.rid, frame_pairs, entry.partition, entry.prepared_state, params)

    # Wrap the scalar in a one-row batch matching the output schema and
    # serialize it as an IPC stream.
    one_row = _build_scalar_result_batch(value, schema)
    buffer = pa.BufferOutputStream()
    with pa.ipc.new_stream(buffer, one_row.schema) as ipc_writer:
        ipc_writer.write_batch(one_row)
    return AggregateWindowResponse(result_batch=buffer.getvalue().to_pybytes())
|
|
2941
|
+
|
|
2942
|
+
def _load_cached_window_partition(
    self,
    func_cls: type,
    execution_id: bytes,
    partition_id: int,
    storage: BoundStorage,
    function_name: str,
) -> _CachedWindowPartition:
    """Return the decoded window partition, preferring the in-process cache.

    On a cache miss (multi-process HTTP, LRU eviction, or worker restart)
    the partition is rebuilt from the payload persisted in ``storage`` and
    re-inserted into the cache.

    The rebuilt entry has ``prepared_state=None``; the dispatcher
    (``aggregate_window`` / ``aggregate_window_batch``) fills it lazily via
    ``func_cls.window_prepare``, which keeps this helper params-free.

    Raises:
        OSError: if storage holds no payload for ``partition_id`` —
            window_init never ran, or the destructor already fired.
    """
    from vgi.aggregate_function import WindowPartition

    if (hit := _window_partition_cache.get(execution_id, partition_id)) is not None:
        return hit

    raw = storage.state_get(
        FrameworkNS.AGGREGATE_WINDOW_PARTITION,
        BoundStorage.pack_int_key(partition_id),
    )
    if raw is None:
        raise OSError(
            f"aggregate_window called for unknown partition_id={partition_id} "
            f"(function {function_name}); window_init never ran or destructor already fired"
        )

    fields = _decode_window_partition_cache(raw)
    inputs = pa.ipc.open_stream(fields["partition_batch"]).read_next_batch()
    schema = pa.ipc.open_stream(fields["output_schema"]).schema
    mask = _unpack_bool_mask(fields["filter_mask"], fields["row_count"])
    stats = _unpack_frame_stats(fields["frame_stats"])
    validity = _unpack_all_valid(fields["all_valid"], inputs.num_columns)

    rebuilt_partition = WindowPartition(
        inputs=inputs,
        row_count=fields["row_count"],
        filter_mask=mask,
        frame_stats=stats,
        all_valid=validity,
    )

    # The persisted window state stays as raw bytes + class name; it is
    # wrapped in a placeholder rather than deserialized here.
    placeholder: Any = (
        _WindowStatePlaceholder(
            raw_bytes=fields["window_state"],
            class_name=fields["window_state_class_name"],
        )
        if fields["window_state"] is not None
        else None
    )

    entry = _CachedWindowPartition(
        partition=rebuilt_partition,
        output_schema=schema,
        window_state=placeholder,
        prepared_state=None,  # populated lazily by the dispatcher
    )
    _window_partition_cache.put(execution_id, partition_id, entry)
    return entry
|
|
3001
|
+
|
|
3002
|
+
def aggregate_window_batch(
    self,
    request: AggregateWindowBatchRequest,
    ctx: CallContext,
) -> AggregateWindowBatchResponse:
    """Compute ``count`` window output rows in a single batched RPC.

    Rows are numbered ``row_idx .. row_idx + count - 1``. Their sub-frames
    arrive flattened: ``frame_starts`` / ``frame_ends`` are concatenated
    across all rows and ``frames_per_row[i]`` is the number of frames that
    belong to row ``i``.

    Returns:
        ``AggregateWindowBatchResponse`` whose ``result_batch`` is the
        IPC-serialized result batch.

    Raises:
        TypeError: if the resolved function is not an ``AggregateFunction``.
        ValueError: if ``count`` disagrees with ``len(frames_per_row)``, or
            the flattened frame arrays do not hold exactly
            ``sum(frames_per_row)`` entries each.
    """
    from vgi.protocol import AggregateWindowBatchResponse

    func_cls = self._resolve_function_by_name(
        request.function_name, self._unwrap_attach(request.attach_opaque_data), function_type=AggregateFunction
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    storage = self._bound(func_cls.storage, request.execution_id, request)
    cached = self._load_cached_window_partition(
        func_cls, request.execution_id, request.partition_id, storage, request.function_name
    )
    partition = cached.partition
    output_schema = cached.output_schema

    const_args = self._load_aggregate_const_args(func_cls, storage)
    params = ProcessParams(
        args=const_args,
        init_call=None,
        init_response=None,
        output_schema=output_schema,
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    # Lazily populate prepared_state on first batch — covers cold reloads
    # from FunctionStorage where _load_cached_window_partition can't yet
    # call window_prepare (no params available there).
    if cached.prepared_state is None:
        cached.prepared_state = func_cls.window_prepare(
            partition,
            cached.window_state,
            params,
        )

    # Unflatten subframes: frame_starts/frame_ends are concatenated across
    # all rows, frames_per_row[i] gives the slice length for row i.
    starts = request.frame_starts
    ends = request.frame_ends
    frames_per_row = request.frames_per_row
    if request.count != len(frames_per_row):
        raise ValueError(
            f"aggregate_window_batch: count={request.count} but frames_per_row has {len(frames_per_row)} entries"
        )
    # Reject inconsistent flattened arrays up front. Without this, short
    # arrays surface as a bare IndexError and surplus/mismatched entries
    # are silently ignored (the single-row path enforces the same
    # invariant through zip(..., strict=True)).
    total_frames = sum(frames_per_row)
    if len(starts) != total_frames or len(ends) != total_frames:
        raise ValueError(
            f"aggregate_window_batch: frames_per_row sums to {total_frames} but got "
            f"{len(starts)} frame_starts and {len(ends)} frame_ends"
        )

    row_ids: list[int] = [request.row_idx + i for i in range(request.count)]
    subframes_per_row: list[list[tuple[int, int]]] = []
    offset = 0
    for n in frames_per_row:
        subframes_per_row.append([(starts[offset + k], ends[offset + k]) for k in range(n)])
        offset += n

    # User code may override window_batch to build the output as a single
    # pa.Array — bypassing per-row Python object construction and the
    # subsequent pa.array(...) conversion. The default falls back to
    # window() per row, preserving prior behaviour.
    results = func_cls.window_batch(
        row_ids,
        subframes_per_row,
        partition,
        cached.prepared_state,
        params,
    )

    result_batch = _build_batch_result(results, output_schema, expected_count=request.count)
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, result_batch.schema) as writer:
        writer.write_batch(result_batch)
    return AggregateWindowBatchResponse(result_batch=sink.getvalue().to_pybytes())
|
|
3079
|
+
|
|
3080
|
+
def aggregate_window_destructor(
    self,
    request: AggregateWindowDestructorRequest,
    ctx: CallContext,
) -> AggregateWindowDestructorResponse:
    """Remove one window partition from storage and from the in-process cache.

    Raises:
        TypeError: if the resolved function is not an ``AggregateFunction``.
    """
    from vgi.protocol import AggregateWindowDestructorResponse

    catalog_bytes = self._unwrap_attach(request.attach_opaque_data)
    aggregate_cls = self._resolve_function_by_name(
        request.function_name,
        catalog_bytes,
        function_type=AggregateFunction,
    )
    if not issubclass(aggregate_cls, AggregateFunction):
        raise TypeError(
            f"Function '{request.function_name}' is not an AggregateFunction (got {aggregate_cls.__name__})"
        )

    # Persisted copy first, then the in-memory mirror.
    bound = self._bound(aggregate_cls.storage, request.execution_id, request)
    partition_key = BoundStorage.pack_int_key(request.partition_id)
    bound.state_delete(FrameworkNS.AGGREGATE_WINDOW_PARTITION, [partition_key])
    _window_partition_cache.delete(request.execution_id, request.partition_id)
    return AggregateWindowDestructorResponse()
|
|
3098
|
+
|
|
3099
|
+
def aggregate_streaming_open(
    self,
    request: AggregateStreamingOpenRequest,
    ctx: CallContext,
) -> AggregateStreamingOpenResponse:
    """Open a streaming-partitioned aggregate session.

    Mints a fresh ``execution_id``, stashes the positional const arguments
    (when present), calls the function's ``streaming_open`` hook and records
    the resulting session both in the in-process cache and in
    FunctionStorage.

    Returns:
        ``AggregateStreamingOpenResponse`` carrying the new ``execution_id``.

    Raises:
        TypeError: if the resolved function is not an ``AggregateFunction``.
    """
    from vgi.protocol import AggregateStreamingOpenResponse

    func_cls = self._resolve_function_by_name(
        request.function_name, self._unwrap_attach(request.attach_opaque_data), function_type=AggregateFunction
    )
    if not issubclass(func_cls, AggregateFunction):
        raise TypeError(f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})")

    execution_id = uuid.uuid4().bytes
    # Bind storage once; the const-args stash, the params and the session
    # persist below all target the same execution.
    storage = self._bound(func_cls.storage, execution_id, request)

    # Stash const args (mirrors aggregate_bind behavior) so streaming_chunk
    # can rehydrate them via _load_aggregate_const_args if the function
    # declares const params.
    if request.arguments and request.arguments.positional:
        storage.state_put(
            FrameworkNS.AGGREGATE_STATE, BoundStorage.pack_int_key(-2), request.arguments.serialize_to_bytes()
        )

    const_args = self._load_aggregate_const_args(func_cls, storage)
    params = ProcessParams(
        args=const_args,
        init_call=None,
        init_response=None,
        output_schema=request.output_schema,
        settings=_batch_to_scalar_dict(request.settings),
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    streaming_state = func_cls.streaming_open(params)

    session = _StreamingSession(
        func_cls=func_cls,
        streaming_state=streaming_state,
        output_schema=request.output_schema,
        partition_key_count=request.partition_key_count,
        order_key_count=request.order_key_count,
    )
    _streaming_session_cache.put(execution_id, session)
    # Also persist to FunctionStorage so a chunk RPC landing on a
    # different pool worker can reload the session.
    storage.state_put(
        FrameworkNS.STREAMING_SESSION,
        BoundStorage.pack_int_key(_STREAMING_SESSION_STORAGE_KEY),
        _encode_streaming_session(session),
    )
    return AggregateStreamingOpenResponse(execution_id=execution_id)
|
|
3155
|
+
|
|
3156
|
+
def aggregate_streaming_chunk(
    self,
    request: AggregateStreamingChunkRequest,
    ctx: CallContext,
) -> AggregateStreamingChunkResponse:
    """Process one chunk of streaming input.

    Uses the cached session for ``request.execution_id`` when available,
    otherwise reloads it from FunctionStorage. Runs the function's
    ``streaming_chunk`` hook, re-persists the session, and returns the
    per-row results as an IPC-serialized one-column record batch.

    Raises:
        TypeError: on a cold reload, if the resolved function is not an
            ``AggregateFunction``.
        OSError: on a cold reload, if no session is stored for the
            execution id.
        ValueError: if the hook returns a different number of values than
            the chunk has rows.
    """
    # Single function-scope import; used by both the cold-load timing and
    # the persist timing below.
    import time as _t

    from vgi.protocol import AggregateStreamingChunkResponse

    session = _streaming_session_cache.get(request.execution_id)
    if session is None:
        # Cold reload — the session may have been opened on a different
        # pool worker. Look it up in FunctionStorage.
        func_cls = self._resolve_function_by_name(
            request.function_name, self._unwrap_attach(request.attach_opaque_data), function_type=AggregateFunction
        )
        if not issubclass(func_cls, AggregateFunction):
            raise TypeError(
                f"Function '{request.function_name}' is not an AggregateFunction (got {func_cls.__name__})"
            )
        cold_storage = self._bound(func_cls.storage, request.execution_id, request)

        t_get_start = _t.perf_counter()
        payload = cold_storage.state_get(
            FrameworkNS.STREAMING_SESSION,
            BoundStorage.pack_int_key(_STREAMING_SESSION_STORAGE_KEY),
        )
        t_get = _t.perf_counter() - t_get_start
        # The lookup is counted as a cold load even when it finds nothing.
        with _streaming_persist_lock:
            _streaming_persist_stats["storage_get_seconds"] += t_get
            _streaming_persist_stats["n_cold_loads"] += 1
        if payload is None:
            raise OSError(
                "aggregate_streaming_chunk: unknown execution_id (streaming_open never ran or close already fired)"
            )
        session = _decode_streaming_session(payload, func_cls)
        _streaming_session_cache.put(request.execution_id, session)

    chunk = pa.ipc.open_stream(request.input_batch).read_next_batch()

    storage = self._bound(session.func_cls.storage, request.execution_id, request)
    const_args = self._load_aggregate_const_args(session.func_cls, storage)
    params = ProcessParams(
        args=const_args,
        init_call=None,
        init_response=None,
        output_schema=session.output_schema,
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )

    result = session.func_cls.streaming_chunk(
        chunk,
        session.streaming_state,
        session.partition_key_count,
        session.order_key_count,
        params,
    )

    # Accept either a pa.Array (preferred) or a Python list. Coerce to
    # an Arrow array of the function's output type, then wrap in a
    # one-column RecordBatch for IPC transport.
    if isinstance(result, pa.Array):
        result_array = result
    else:
        result_array = pa.array(result, type=session.output_schema.field(0).type)

    if len(result_array) != chunk.num_rows:
        raise ValueError(f"streaming_chunk returned {len(result_array)} values for {chunk.num_rows} input rows")

    result_batch = pa.RecordBatch.from_arrays([result_array], schema=session.output_schema)
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, result_batch.schema) as writer:
        writer.write_batch(result_batch)

    # Persist updated session so the next chunk (possibly on a different
    # pool worker) sees the same state. Encode and put are timed separately.
    t_enc_start = _t.perf_counter()
    payload = _encode_streaming_session(session)
    t_enc = _t.perf_counter() - t_enc_start
    t_put_start = _t.perf_counter()
    storage.state_put(
        FrameworkNS.STREAMING_SESSION,
        BoundStorage.pack_int_key(_STREAMING_SESSION_STORAGE_KEY),
        payload,
    )
    t_put = _t.perf_counter() - t_put_start
    _record_persist_timing(t_enc, t_put, len(payload))
    with _streaming_persist_lock:
        _streaming_persist_stats["n_chunks"] += 1
    return AggregateStreamingChunkResponse(result_batch=sink.getvalue().to_pybytes())
|
|
3250
|
+
|
|
3251
|
+
def aggregate_streaming_close(
    self,
    request: AggregateStreamingCloseRequest,
    ctx: CallContext,
) -> AggregateStreamingCloseResponse:
    """Close a streaming-partitioned aggregate session.

    Removes the session from the in-process cache (or reloads it from
    storage when it is not cached), invokes the function's
    ``streaming_close`` hook, deletes the persisted session, then snapshots
    and resets the module-level persist counters, logging a summary when at
    least one chunk was recorded. Closing an unknown session is a no-op.
    """
    from vgi.protocol import AggregateStreamingCloseResponse

    session = _streaming_session_cache.pop(request.execution_id)
    if session is None:
        # Cold close: the session may live on another pool worker. Try to
        # load it so the user's streaming_close still fires; a function
        # that cannot be resolved simply means there is nothing to load.
        try:
            resolved_cls = self._resolve_function_by_name(
                request.function_name,
                self._unwrap_attach(request.attach_opaque_data),
                function_type=AggregateFunction,
            )
        except Exception:  # noqa: BLE001
            resolved_cls = None
        if resolved_cls is not None and issubclass(resolved_cls, AggregateFunction):
            cold_storage = self._bound(resolved_cls.storage, request.execution_id, request)
            stored = cold_storage.state_get(
                FrameworkNS.STREAMING_SESSION,
                BoundStorage.pack_int_key(_STREAMING_SESSION_STORAGE_KEY),
            )
            if stored is not None:
                session = _decode_streaming_session(stored, resolved_cls)
                cold_storage.state_delete(
                    FrameworkNS.STREAMING_SESSION,
                    [BoundStorage.pack_int_key(_STREAMING_SESSION_STORAGE_KEY)],
                )

    if session is None:
        # Idempotent: nothing to clean up.
        return AggregateStreamingCloseResponse()

    session_cls = session.func_cls
    storage = self._bound(session_cls.storage, request.execution_id, request)
    params = ProcessParams(
        args=self._load_aggregate_const_args(session_cls, storage),
        init_call=None,
        init_response=None,
        output_schema=session.output_schema,
        settings={},
        secrets={},
        storage=storage,
        auth_context=ctx.auth,
    )
    try:
        session_cls.streaming_close(session.streaming_state, params)
    except Exception:  # noqa: BLE001
        # A failing user hook must not prevent the session from being dropped.
        _logger.exception("streaming_close raised; session dropped anyway")

    # Drop the persisted state.
    storage.state_delete(
        FrameworkNS.STREAMING_SESSION,
        [BoundStorage.pack_int_key(_STREAMING_SESSION_STORAGE_KEY)],
    )

    # Snapshot the worker-side persist counters, then zero them.
    with _streaming_persist_lock:
        snapshot = dict(_streaming_persist_stats)
        for seconds_key in (
            "encode_session_seconds",
            "storage_put_seconds",
            "storage_get_seconds",
            "rpc_chunk_total_seconds",
        ):
            _streaming_persist_stats[seconds_key] = 0.0
        for counter_key in ("n_chunks", "n_persists", "n_cold_loads", "bytes_persisted"):
            _streaming_persist_stats[counter_key] = 0

    if snapshot["n_chunks"] <= 0:
        return AggregateStreamingCloseResponse()

    n = snapshot["n_chunks"]
    mb = snapshot["bytes_persisted"] / (1024 * 1024)
    _logger.info(
        "streaming_persist_summary chunks=%d encode=%.3fs put=%.3fs bytes=%.1fMB cold_loads=%d",
        n,
        snapshot["encode_session_seconds"],
        snapshot["storage_put_seconds"],
        mb,
        snapshot["n_cold_loads"],
    )
    # Also stderr so it's visible in the SQL bench output.
    import sys as _sys

    print(
        f"[streaming_persist_summary] chunks={n} "
        f"encode={snapshot['encode_session_seconds']:.3f}s "
        f"put={snapshot['storage_put_seconds']:.3f}s "
        f"bytes={mb:.1f}MB "
        f"cold_loads={snapshot['n_cold_loads']}",
        file=_sys.stderr,
        flush=True,
    )
    return AggregateStreamingCloseResponse()
|
|
3348
|
+
|
|
3349
|
+
# ========== Function Invocation ==========
|
|
3350
|
+
|
|
3351
|
+
def init(self, request: InitRequest, ctx: CallContext) -> Stream[ProcessState, GlobalInitResponse]:
    """Initialize a function execution and return a processing stream.

    Implements VgiProtocol.init(). Resolves the function class named by
    ``request.bind_call``, obtains the ``GlobalInitResponse`` (echoing the
    request's ``execution_id`` / ``init_opaque_data`` for a secondary init,
    otherwise calling the function's ``global_init``), and then creates the
    state object that matches the function type and init phase.

    Args:
        request: Init request carrying the bind call, the init phase, the
            projection ids and the output schema.
        ctx: Call context; ``ctx.auth`` is recorded on the current span and
            handed to the function's params.

    Returns:
        A ``Stream`` whose header is the ``GlobalInitResponse``. Its
        ``input_schema`` is the bind call's input schema for scalar and
        table-in-out INPUT streams, and an empty schema otherwise.

    Raises:
        ValueError: If a secondary init carries no ``execution_id``, if the
            phase is not valid for the function type, if
            ``TABLE_BUFFERING_FINALIZE`` has no ``finalize_state_id``, or if
            the resolved function is of an unknown type.
    """
    self._vgi_tracer.set_current_span_attributes(
        {
            "vgi.function.name": request.bind_call.function_name,
            "vgi.function.type": request.bind_call.function_type.value,
            "vgi.init.is_secondary": request.is_secondary,
            "vgi.principal": ctx.auth.principal,
            "vgi.auth_domain": ctx.auth.domain,
            "vgi.authenticated": ctx.auth.authenticated,
        }
    )
    # vgi.attach_opaque_data / vgi.transaction_opaque_data are auto-tagged by vgi-rpc's
    # Sentry dispatch hook (short-hash form) on every method that
    # carries them — including this one (descends bind_call).
    #
    # Unwrap once to the full framework plaintext ``uuid(16) || catalog_bytes``.
    # Storage shards on the leading UUID; function bodies get the catalog
    # bytes (uuid stripped) via params.attach_opaque_data. The streaming /
    # buffering states below persist this **full** plaintext into their
    # serialized tokens (the auth-scoped seal can't be reopened on a later,
    # possibly different-auth, produce/finalize turn — so we can't re-unwrap
    # then), and shard their cold-built storage on its UUID.
    attach_plaintext = self._unwrap_attach_full(getattr(request.bind_call, "attach_opaque_data", None))
    func_cls = self._resolve_function(request.bind_call)
    instance = func_cls(logger=_logger)

    # Determine if this is a secondary init
    if request.is_secondary:
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let a None execution_id through and
        # fail later with an unrelated AttributeError.
        if request.execution_id is None:
            raise ValueError("Secondary init requires execution_id")
        init_response = GlobalInitResponse(
            execution_id=request.execution_id,
            opaque_data=request.init_opaque_data,
        )
    else:
        if isinstance(instance, TableFunctionBase):
            init_response = instance.global_init(request, ctx=ctx, attach_plaintext=attach_plaintext)
        elif isinstance(instance, ScalarFunctionGenerator):
            init_response = instance.global_init(request, attach_plaintext=attach_plaintext)
        else:
            init_response = instance.global_init(request)  # type: ignore[attr-defined]

    self._vgi_tracer.set_current_span_attributes(
        {
            "vgi.init.execution_id": init_response.execution_id.hex(),
        }
    )
    if request.phase is not None:
        self._vgi_tracer.set_current_span_attributes(
            {
                "vgi.init.phase": request.phase.value,
            }
        )

    # Build common ProcessParams for table/table-in-out functions
    proj_ids = _effective_projection_ids(func_cls, request.projection_ids)
    output_schema = project_schema(proj_ids, request.output_schema)

    # Determine state and input_schema based on function type
    state: ProcessState
    input_schema: pa.Schema | None

    if isinstance(instance, ScalarFunctionGenerator) and not isinstance(instance, TableInOutGenerator):
        # Scalar function: exchange state with per-batch process()
        state = ScalarExchangeState(
            _func_cls=type(instance),
            _init_call=request,
            _init_response=init_response,
            _plaintext_attach=attach_plaintext,
            _vgi_tracer=self._vgi_tracer,
        )
        input_schema = request.bind_call.input_schema

    elif isinstance(instance, TableBufferingFunction):
        # Table sink+source function. Two init phases reuse the
        # TABLE_BUFFERING (sink init) / TABLE_BUFFERING_FINALIZE (stream init) enum
        # values (renaming the phase strings is in task #6).
        cold_storage = BoundStorage(
            type(instance).storage,
            init_response.execution_id,
            request=request,
            attach_plaintext=self._unwrap_attach_full_for(request),
        )
        # ``attach_id`` is the user-facing plaintext attach identity (used to
        # pin attach-time config lookups) — the catalog's bytes with the
        # framework UUID prefix stripped.
        attach_id = bytes(attach_catalog_bytes(attach_plaintext) or b"")
        # Widen to ProcessParams so the later TableInOutGenerator /
        # TableFunctionGenerator branches can rebind `params` without
        # tripping mypy's local-type narrowing.
        params: ProcessParams[Any]
        params = TableBufferingParams(
            args=type(instance)._parse_arguments(
                type(instance).FunctionArguments,
                request.bind_call.arguments,
            ),
            init_call=request,
            init_response=init_response,
            output_schema=output_schema,
            settings=_batch_to_scalar_dict(request.bind_call.settings),
            secrets=SecretsAccessor(request.bind_call.secrets).to_dict(),
            storage=cold_storage,
            auth_context=ctx.auth,
            attach_opaque_data=attach_catalog_bytes(attach_plaintext),
            execution_id=init_response.execution_id,
            attach_id=attach_id,
            transaction_id=getattr(
                request.bind_call,
                "transaction_opaque_data",
                None,
            ),
            function_name=request.bind_call.function_name,
            worker_path=None,
        )
        if request.phase == TableInOutFunctionInitPhase.TABLE_BUFFERING:
            # Sink init: persist init metadata so any pool worker
            # can cold-load + serve subsequent process/combine RPCs.
            cold_storage.state_put(
                FrameworkNS.BUFFERING_INIT,
                BoundStorage.pack_int_key(_TABLE_BUFFERING_INIT_KEY),
                _encode_table_buffering_init(request, init_response),
            )
            state = BufferedFinalizeState(
                execution_id=init_response.execution_id,
                ns=FrameworkNS.BUFFERING_INIT,
                key=b"__never_written__",
                attach_opaque_data=attach_plaintext,
            )
            input_schema = None
        elif request.phase == TableInOutFunctionInitPhase.TABLE_BUFFERING_FINALIZE:
            if request.finalize_state_id is None:
                raise ValueError("TABLE_BUFFERING_FINALIZE phase requires finalize_state_id")
            state = TableBufferingFinalizeState(
                function_name=request.bind_call.function_name,
                execution_id=init_response.execution_id,
                transaction_id=getattr(
                    request.bind_call,
                    "transaction_opaque_data",
                    None,
                ),
                finalize_state_id=bytes(request.finalize_state_id),
                attach_opaque_data=attach_plaintext,
                # Thread pushdown info from the InitRequest onto the producer
                # state so it survives HTTP rehydration. The state is wire-
                # serialized between produce() ticks; carrying projection_ids
                # + pushdown_filters here lets every tick (potentially on a
                # different worker process) narrow output_schema and apply
                # filters consistently with the C++ Sink+Source contract.
                projection_ids=request.projection_ids,
                pushdown_filters=request.pushdown_filters,
            )
            input_schema = None
        else:
            raise ValueError(f"Unsupported init phase for TableBufferingFunction: {request.phase}")

    elif isinstance(instance, TableInOutGenerator):
        # Table-in-out function: separate INPUT and FINALIZE phases
        params = ProcessParams(
            args=type(instance)._parse_arguments(type(instance).FunctionArguments, request.bind_call.arguments),
            init_call=request,
            init_response=init_response,
            output_schema=output_schema,
            settings=_batch_to_scalar_dict(request.bind_call.settings),
            secrets=SecretsAccessor(request.bind_call.secrets).to_dict(),
            storage=BoundStorage(
                type(instance).storage,
                init_response.execution_id,
                request=request,
                attach_plaintext=self._unwrap_attach_full_for(request),
            ),
            auth_context=ctx.auth,
            attach_opaque_data=attach_catalog_bytes(attach_plaintext),
        )

        if request.phase == TableInOutFunctionInitPhase.INPUT:
            user_state = type(instance).initial_state(params)
            state = TableInOutExchangeState(
                _init_call=request,
                _init_response=init_response,
                _plaintext_attach=attach_plaintext,
                _func_cls=type(instance),
                _params=params,
                _user_state=user_state,
                _vgi_tracer=self._vgi_tracer,
            )
            input_schema = request.bind_call.input_schema
        elif request.phase == TableInOutFunctionInitPhase.FINALIZE:
            # Streaming-shape FINALIZE: materialize the user's
            # finalize() return into BoundStorage so the framework
            # can stream it via cursor — same shape as buffered.
            # User-facing API unchanged: finalize(params) -> list[batch].
            # Removes the prior _batches: Transient anti-pattern that
            # silently truncated streams over HTTP.
            finalize_batches = type(instance).finalize(params)
            for batch in finalize_batches:
                # Each batch is serialized as its own Arrow IPC stream and
                # appended under the shared streaming-finalize key.
                sink = pa.BufferOutputStream()
                with pa.ipc.new_stream(sink, batch.schema) as w:
                    w.write_batch(batch)
                params.storage.state_append(
                    FrameworkNS.STREAMING_FINALIZE,
                    _STREAMING_FINALIZE_KEY,
                    sink.getvalue().to_pybytes(),
                )
            state = BufferedFinalizeState(
                execution_id=init_response.execution_id,
                ns=FrameworkNS.STREAMING_FINALIZE,
                key=_STREAMING_FINALIZE_KEY,
                attach_opaque_data=attach_plaintext,
            )
            input_schema = None  # Producer — no input
        else:
            raise ValueError(f"Unknown init phase for table-in-out function: {request.phase}")

    elif isinstance(instance, TableFunctionGenerator):
        # Table function: producer state with per-tick process()
        params = ProcessParams(
            args=type(instance)._parse_arguments(type(instance).FunctionArguments, request.bind_call.arguments),
            init_call=request,
            init_response=init_response,
            output_schema=output_schema,
            settings=_batch_to_scalar_dict(request.bind_call.settings),
            secrets=SecretsAccessor(request.bind_call.secrets).to_dict(),
            storage=BoundStorage(
                type(instance).storage,
                init_response.execution_id,
                request=request,
                attach_plaintext=self._unwrap_attach_full_for(request),
            ),
            auth_context=ctx.auth,
            attach_opaque_data=attach_catalog_bytes(attach_plaintext),
        )
        user_state = type(instance).initial_state(params)
        state = TableProducerState(
            _init_call=request,
            _init_response=init_response,
            _plaintext_attach=attach_plaintext,
            _func_cls=type(instance),
            _params=params,
            _user_state=user_state,
            _vgi_tracer=self._vgi_tracer,
        )
        input_schema = None  # Producer — no input

    else:
        raise ValueError(f"Unknown function type: {type(instance).__name__}")

    return Stream(
        output_schema=output_schema,
        state=state,
        # Producer streams have no input; fall back to an empty schema.
        input_schema=input_schema or pa.schema([]),
        header=init_response,
    )
|
|
3607
|
+
|
|
3608
|
+
# ---------------------------------------------------------------------------
|
|
3609
|
+
# VgiProtocol implementation - Catalog Discovery
|
|
3610
|
+
# ---------------------------------------------------------------------------
|
|
3611
|
+
|
|
3612
|
+
def _enrich_catalog_span(self, **attrs: Any) -> None:
|
|
3613
|
+
"""Add catalog-specific attributes to the current vgi_rpc span."""
|
|
3614
|
+
self._vgi_tracer.set_current_span_attributes(attrs)
|
|
3615
|
+
|
|
3616
|
+
def _log_catalog_lifecycle(self, event: str, **fields: Any) -> None:
    """Write one structured log line, plus a Sentry breadcrumb, for a catalog event.

    ``event`` is a dotted name such as ``"catalog.attach"``; it serves as the
    log message and as the breadcrumb's category and message. ``fields`` with
    a ``None`` value are dropped; the rest become the log record's ``extra``
    and the breadcrumb's data. Callers must not pass credentials; see
    :meth:`CatalogInterface.loggable_attach_options` for the opt-in
    option-redaction hook.

    ``attach_opaque_data`` / ``transaction_opaque_data`` are
    implementation-chosen byte strings that may carry credentials. They are
    short-hashed here, once, before the payload is handed to the logger, the
    Sentry breadcrumb, or the Sentry scope tags, so none of those sinks
    receives the raw value.
    """
    opaque_fields = ("attach_opaque_data", "transaction_opaque_data")

    # Keep only the fields that carry a value, so logs and breadcrumbs stay tidy.
    payload = {key: value for key, value in fields.items() if value is not None}

    # Swap each non-empty opaque value for its short hash before anything is emitted.
    for name in opaque_fields:
        if payload.get(name):
            payload[name] = _short_hash(payload[name])

    _logger.info(event, extra=payload)

    # Sentry is only touched when something else has already imported it.
    if "sentry_sdk" not in sys.modules:
        return
    import sentry_sdk

    if not sentry_sdk.is_initialized():
        return

    scope = sentry_sdk.get_current_scope()
    for name in opaque_fields:
        digest = payload.get(name)
        if digest:
            scope.set_tag(f"vgi.{name}", digest)
    sentry_sdk.add_breadcrumb(
        category=event,
        message=event,
        level="info",
        data=payload,
    )
|
|
3657
|
+
|
|
3658
|
+
def catalog_catalogs(self) -> CatalogsResponse:
    """Return the catalog discovery records reported by the catalog."""
    discovery_infos = list(self._get_catalog().catalogs())
    return CatalogsResponse.from_infos(discovery_infos)
|
|
3662
|
+
|
|
3663
|
+
# ---------------------------------------------------------------------------
|
|
3664
|
+
# VgiProtocol implementation - Catalog Lifecycle
|
|
3665
|
+
# ---------------------------------------------------------------------------
|
|
3666
|
+
|
|
3667
|
+
def catalog_attach(
    self,
    request: CatalogAttachRequest,
    *,
    ctx: CallContext | None = None,
) -> CatalogAttachResult:
    """Attach to the named catalog and return the sealed attach result.

    The catalog's own ``attach_opaque_data`` (when present) is prefixed with a
    fresh UUID and sealed before it is returned; the attach is then recorded
    as a ``catalog.attach`` lifecycle event.
    """
    catalog_name = request.name
    version_spec = request.data_version_spec
    impl_version = request.implementation_version

    self._enrich_catalog_span(vgi_catalog_name=catalog_name)
    self._vgi_tracer.set_current_span_attributes(
        {
            "vgi.catalog.name": catalog_name,
            "vgi.data_version_spec": version_spec,
            "vgi.implementation_version": impl_version,
        }
    )

    catalog = self._get_catalog()
    options = self._options_batch_to_dict(request.options)
    result = catalog.catalog_attach(
        name=catalog_name,
        options=options,
        data_version_spec=version_spec,
        implementation_version=impl_version,
        ctx=ctx,
    )

    # Mint the shard identity: a fresh framework UUID goes in front of the
    # catalog's plaintext (``uuid(16) || catalog_bytes``) and the whole thing
    # is sealed. Storage shards on this UUID — stable across re-seals and
    # globally unique, unlike the random-nonce ciphertext or the (possibly
    # non-unique) catalog bytes. ``_unwrap_attach`` strips the UUID back off,
    # so the catalog only ever sees its own bytes. See ``_AttachUnwrapper``.
    catalog_bytes = result.attach_opaque_data
    if catalog_bytes is not None:
        framed = uuid.uuid4().bytes + bytes(catalog_bytes)
        result = _dataclass_replace(result, attach_opaque_data=self._seal_attach(framed))

    loggable = dict(catalog.loggable_attach_options(options))
    envelope = result.attach_opaque_data
    envelope_hex = envelope.hex() if envelope else None
    self._log_catalog_lifecycle(
        "catalog.attach",
        catalog_name=catalog_name,
        attach_opaque_data=envelope_hex,
        data_version_spec=version_spec,
        implementation_version=impl_version,
        options=loggable or None,
    )
    return result
|
|
3710
|
+
|
|
3711
|
+
def catalog_detach(self, attach_opaque_data: bytes) -> None:
    """Detach from a catalog and record a ``catalog.detach`` lifecycle event.

    The catalog receives the unwrapped attach value; the lifecycle log
    receives the hex form of the value as it was passed in.
    """
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    catalog.catalog_detach(attach_opaque_data=plain_attach)

    envelope_hex = attach_opaque_data.hex()
    self._log_catalog_lifecycle("catalog.detach", attach_opaque_data=envelope_hex)
|
|
3716
|
+
|
|
3717
|
+
def catalog_create(self, request: CatalogCreateRequest) -> None:
    """Create a new catalog and record a ``catalog.create`` lifecycle event.

    Only the options returned by the catalog's ``loggable_attach_options``
    hook are included in the lifecycle event.
    """
    catalog_name = request.name
    self._enrich_catalog_span(vgi_catalog_name=catalog_name)

    catalog = self._get_catalog()
    options = self._options_batch_to_dict(request.options)
    catalog.catalog_create(
        name=catalog_name,
        on_conflict=request.on_conflict,
        options=options,
    )

    loggable = dict(catalog.loggable_attach_options(options))
    log_fields = {
        "catalog_name": catalog_name,
        "on_conflict": request.on_conflict.value,
        "options": loggable or None,
    }
    self._log_catalog_lifecycle("catalog.create", **log_fields)
|
|
3730
|
+
|
|
3731
|
+
def catalog_drop(self, name: str) -> None:
    """Drop the catalog called ``name``, tagging the current span with that name."""
    self._enrich_catalog_span(vgi_catalog_name=name)
    catalog = self._get_catalog()
    catalog.catalog_drop(name=name)
|
|
3736
|
+
|
|
3737
|
+
def catalog_version(
    self,
    attach_opaque_data: bytes,
    transaction_opaque_data: bytes | None = None,
    *,
    ctx: CallContext | None = None,
) -> CatalogVersionResponse:
    """Return the catalog's current version wrapped in a ``CatalogVersionResponse``.

    The attach value is always unwrapped; the transaction value is unwrapped
    only when one is supplied (non-empty), otherwise ``None`` is passed on.
    """
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    current = catalog.catalog_version(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        ctx=ctx,
    )
    return CatalogVersionResponse(version=current)
|
|
3754
|
+
|
|
3755
|
+
# ---------------------------------------------------------------------------
|
|
3756
|
+
# VgiProtocol implementation - Catalog Transactions
|
|
3757
|
+
# ---------------------------------------------------------------------------
|
|
3758
|
+
|
|
3759
|
+
def catalog_transaction_begin(self, attach_opaque_data: bytes) -> TransactionBeginResponse:
    """Begin a transaction and return its sealed identifier.

    When the catalog returns an empty or ``None`` transaction value, nothing
    is sealed and the response carries ``None``.
    """
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    tx_id = catalog.catalog_transaction_begin(attach_opaque_data=plain_attach)

    # Seal the implementation's plaintext transaction value, binding it to
    # the caller's identity *and* the parent attach envelope it was minted
    # under, before it leaves the worker.
    if tx_id:
        sealed_tx = self._seal_transaction(bytes(tx_id), attach_opaque_data)
    else:
        sealed_tx = None

    tx_hex = sealed_tx.hex() if sealed_tx else None
    self._log_catalog_lifecycle(
        "catalog.transaction.begin",
        attach_opaque_data=attach_opaque_data.hex(),
        transaction_opaque_data=tx_hex,
    )
    return TransactionBeginResponse(transaction_opaque_data=sealed_tx)
|
|
3773
|
+
|
|
3774
|
+
def catalog_transaction_commit(self, attach_opaque_data: bytes, transaction_opaque_data: bytes) -> None:
    """Commit a transaction and record a ``catalog.transaction.commit`` event.

    The catalog is called with the unwrapped attach and transaction values;
    the lifecycle log receives the hex form of the values as passed in.
    """
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    catalog.catalog_transaction_commit(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
    )

    log_fields = {
        "attach_opaque_data": attach_opaque_data.hex(),
        "transaction_opaque_data": transaction_opaque_data.hex(),
    }
    self._log_catalog_lifecycle("catalog.transaction.commit", **log_fields)
|
|
3786
|
+
|
|
3787
|
+
def catalog_transaction_rollback(self, attach_opaque_data: bytes, transaction_opaque_data: bytes) -> None:
    """Roll back a transaction and record a ``catalog.transaction.rollback`` event.

    The catalog is called with the unwrapped attach and transaction values;
    the lifecycle log receives the hex form of the values as passed in.
    """
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    catalog.catalog_transaction_rollback(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
    )

    log_fields = {
        "attach_opaque_data": attach_opaque_data.hex(),
        "transaction_opaque_data": transaction_opaque_data.hex(),
    }
    self._log_catalog_lifecycle("catalog.transaction.rollback", **log_fields)
|
|
3799
|
+
|
|
3800
|
+
# ---------------------------------------------------------------------------
|
|
3801
|
+
# VgiProtocol implementation - Catalog Schemas
|
|
3802
|
+
# ---------------------------------------------------------------------------
|
|
3803
|
+
|
|
3804
|
+
def catalog_schemas(
    self, attach_opaque_data: bytes, transaction_opaque_data: bytes | None = None
) -> SchemasResponse:
    """Return every schema the catalog reports, as a ``SchemasResponse``.

    The transaction value is unwrapped only when one is supplied
    (non-empty); otherwise ``None`` is passed to the catalog.
    """
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    schema_infos = catalog.schemas(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
    )
    return SchemasResponse.from_infos(list(schema_infos))
|
|
3816
|
+
|
|
3817
|
+
def catalog_schema_get(
    self, attach_opaque_data: bytes, name: str, transaction_opaque_data: bytes | None = None
) -> SchemasResponse:
    """Look up one schema by name; the response holds 0 or 1 items.

    The transaction value is unwrapped only when one is supplied
    (non-empty); otherwise ``None`` is passed to the catalog.
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    schema_info = catalog.schema_get(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        name=name,
    )
    return SchemasResponse.from_optional(schema_info)
|
|
3831
|
+
|
|
3832
|
+
def catalog_schema_create(
    self,
    attach_opaque_data: bytes,
    name: str,
    on_conflict: OnConflict = OnConflict.ERROR,
    comment: str | None = None,
    tags: dict[str, str] | None = None,
    transaction_opaque_data: bytes | None = None,
) -> None:
    """Create a schema called ``name`` in the attached catalog.

    A missing or empty ``tags`` argument is passed to the catalog as a new
    empty dict. The transaction value is unwrapped only when one is supplied
    (non-empty).
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    effective_tags = tags if tags else {}
    catalog.schema_create(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        name=name,
        on_conflict=on_conflict,
        comment=comment,
        tags=effective_tags,
    )
|
|
3854
|
+
|
|
3855
|
+
def catalog_schema_drop(
|
|
3856
|
+
self,
|
|
3857
|
+
attach_opaque_data: bytes,
|
|
3858
|
+
name: str,
|
|
3859
|
+
ignore_not_found: bool = False,
|
|
3860
|
+
cascade: bool = False,
|
|
3861
|
+
transaction_opaque_data: bytes | None = None,
|
|
3862
|
+
) -> None:
|
|
3863
|
+
"""Drop a schema."""
|
|
3864
|
+
self._enrich_catalog_span(vgi_schema_name=name)
|
|
3865
|
+
cat = self._get_catalog()
|
|
3866
|
+
cat.schema_drop(
|
|
3867
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
3868
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
3869
|
+
if transaction_opaque_data
|
|
3870
|
+
else None,
|
|
3871
|
+
name=name,
|
|
3872
|
+
ignore_not_found=ignore_not_found,
|
|
3873
|
+
cascade=cascade,
|
|
3874
|
+
)
|
|
3875
|
+
|
|
3876
|
+
def catalog_schema_contents_tables(
    self,
    attach_opaque_data: bytes,
    name: str,
    transaction_opaque_data: bytes | None = None,
) -> TablesResponse:
    """Return the tables in schema ``name`` as a ``TablesResponse``.

    Asks the catalog for the schema's contents of type
    ``SchemaObjectType.TABLE``. The transaction value is unwrapped only when
    one is supplied (non-empty).
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    table_infos = catalog.schema_contents(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        name=name,
        type=SchemaObjectType.TABLE,
    )
    return TablesResponse.from_infos(list(table_infos))
|
|
3894
|
+
|
|
3895
|
+
def catalog_schema_contents_views(
    self,
    attach_opaque_data: bytes,
    name: str,
    transaction_opaque_data: bytes | None = None,
) -> ViewsResponse:
    """Return the views in schema ``name`` as a ``ViewsResponse``.

    Asks the catalog for the schema's contents of type
    ``SchemaObjectType.VIEW``. The transaction value is unwrapped only when
    one is supplied (non-empty).
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    view_infos = catalog.schema_contents(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        name=name,
        type=SchemaObjectType.VIEW,
    )
    return ViewsResponse.from_infos(list(view_infos))
|
|
3913
|
+
|
|
3914
|
+
def catalog_schema_contents_functions(
    self,
    attach_opaque_data: bytes,
    name: str,
    type: SchemaObjectType,
    transaction_opaque_data: bytes | None = None,
) -> FunctionsResponse:
    """Return the functions (scalar or table) in schema ``name``.

    ``type`` is passed through to the catalog's ``schema_contents`` unchanged.
    The transaction value is unwrapped only when one is supplied (non-empty).
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    function_infos = catalog.schema_contents(  # type: ignore[call-overload]
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        name=name,
        type=type,
    )
    return FunctionsResponse.from_infos(list(function_infos))
|
|
3933
|
+
|
|
3934
|
+
# ---------------------------------------------------------------------------
|
|
3935
|
+
# VgiProtocol implementation - Catalog Tables
|
|
3936
|
+
# ---------------------------------------------------------------------------
|
|
3937
|
+
|
|
3938
|
+
def catalog_table_get(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    at_unit: str | None = None,
    at_value: str | None = None,
    transaction_opaque_data: bytes | None = None,
) -> TablesResponse:
    """Look up one table by schema and name; the response holds 0 or 1 items.

    ``at_unit`` / ``at_value`` are checked by ``_validate_at_params`` before
    anything else happens and are then forwarded to the catalog. The
    transaction value is unwrapped only when one is supplied (non-empty).
    """
    _validate_at_params(at_unit, at_value)
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
    catalog = self._get_catalog()
    plain_attach = self._unwrap_attach(attach_opaque_data)
    plain_tx = None
    if transaction_opaque_data:
        plain_tx = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)

    table_info = catalog.table_get(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        schema_name=schema_name,
        name=name,
        at_unit=at_unit,
        at_value=at_value,
    )
    return TablesResponse.from_optional(table_info)
|
|
3962
|
+
|
|
3963
|
+
def catalog_table_create(self, request: TableCreateRequest) -> None:
    """Create the table described by ``request`` in the attached catalog.

    Constraint collections on the request are copied into plain lists before
    they are handed to the catalog. Primary-key and foreign-key constraints
    that are missing or empty are passed on as ``None``.
    """
    self._enrich_catalog_span(vgi_schema_name=request.schema_name, vgi_table_name=request.name)
    catalog = self._get_catalog()

    plain_attach = self._unwrap_attach(request.attach_opaque_data)
    plain_tx = None
    if request.transaction_opaque_data:
        plain_tx = self._unwrap_transaction(request.transaction_opaque_data, request.attach_opaque_data)

    columns = SerializedSchema(request.columns)
    conflict_policy = request.on_conflict
    not_null = list(request.not_null_constraints)
    unique = [list(group) for group in request.unique_constraints]
    checks = list(request.check_constraints)

    primary_keys = None
    if request.primary_key_constraints:
        primary_keys = [list(group) for group in request.primary_key_constraints]

    foreign_keys = None
    if request.foreign_key_constraints:
        foreign_keys = list(request.foreign_key_constraints)

    catalog.table_create(
        attach_opaque_data=plain_attach,
        transaction_opaque_data=plain_tx,
        schema_name=request.schema_name,
        name=request.name,
        columns=columns,
        on_conflict=conflict_policy,
        not_null_constraints=not_null,
        unique_constraints=unique,
        check_constraints=checks,
        primary_key_constraints=primary_keys,
        foreign_key_constraints=foreign_keys,
    )
|
|
3988
|
+
|
|
3989
|
+
def catalog_table_drop(
|
|
3990
|
+
self,
|
|
3991
|
+
attach_opaque_data: bytes,
|
|
3992
|
+
schema_name: str,
|
|
3993
|
+
name: str,
|
|
3994
|
+
ignore_not_found: bool = False,
|
|
3995
|
+
cascade: bool = False,
|
|
3996
|
+
transaction_opaque_data: bytes | None = None,
|
|
3997
|
+
) -> None:
|
|
3998
|
+
"""Drop a table."""
|
|
3999
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4000
|
+
cat = self._get_catalog()
|
|
4001
|
+
cat.table_drop(
|
|
4002
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4003
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4004
|
+
if transaction_opaque_data
|
|
4005
|
+
else None,
|
|
4006
|
+
schema_name=schema_name,
|
|
4007
|
+
name=name,
|
|
4008
|
+
ignore_not_found=ignore_not_found,
|
|
4009
|
+
cascade=cascade,
|
|
4010
|
+
)
|
|
4011
|
+
|
|
4012
|
+
def catalog_table_scan_function_get(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    at_unit: str | None = None,
    at_value: str | None = None,
    transaction_opaque_data: bytes | None = None,
) -> bytes:
    """Look up a table's scan function and return it in serialized form.

    ``at_unit``/``at_value`` are checked with ``_validate_at_params`` before
    anything else happens. The value returned is whatever ``serialize()``
    yields on the catalog's result (a ScanFunctionResult as IPC bytes per
    the original documentation).
    """
    _validate_at_params(at_unit, at_value)
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    scan_function = catalog.table_scan_function_get(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
        at_unit=at_unit,
        at_value=at_value,
    )
    return scan_function.serialize()
|
|
4036
|
+
|
|
4037
|
+
def catalog_table_scan_branches_get(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    at_unit: str | None = None,
    at_value: str | None = None,
    transaction_opaque_data: bytes | None = None,
) -> bytes:
    """Return the scan branches of a (possibly multi-branch) table.

    The result of the catalog's ``table_scan_branches_get`` is serialized
    and returned (a ScanBranchesResult as IPC bytes per the original
    documentation). According to that documentation, the CatalogInterface
    base ships a default shim that wraps the legacy
    ``table_scan_function_get`` as a one-branch result, so single-source
    workers answer this call without further code changes.
    """
    _validate_at_params(at_unit, at_value)
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    branches = catalog.table_scan_branches_get(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
        at_unit=at_unit,
        at_value=at_value,
    )
    return branches.serialize()
|
|
4068
|
+
|
|
4069
|
+
def catalog_table_column_statistics_get(
|
|
4070
|
+
self,
|
|
4071
|
+
attach_opaque_data: bytes,
|
|
4072
|
+
schema_name: str,
|
|
4073
|
+
name: str,
|
|
4074
|
+
transaction_opaque_data: bytes | None = None,
|
|
4075
|
+
) -> bytes | None:
|
|
4076
|
+
"""Get column statistics for a table. Returns IPC bytes or None."""
|
|
4077
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4078
|
+
cat = self._get_catalog()
|
|
4079
|
+
result = cat.table_column_statistics_get(
|
|
4080
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4081
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4082
|
+
if transaction_opaque_data
|
|
4083
|
+
else None,
|
|
4084
|
+
schema_name=schema_name,
|
|
4085
|
+
name=name,
|
|
4086
|
+
)
|
|
4087
|
+
if result is None:
|
|
4088
|
+
return None
|
|
4089
|
+
return serialize_column_statistics(result.statistics, result.cache_max_age_seconds)
|
|
4090
|
+
|
|
4091
|
+
def catalog_table_insert_function_get(
|
|
4092
|
+
self,
|
|
4093
|
+
attach_opaque_data: bytes,
|
|
4094
|
+
schema_name: str,
|
|
4095
|
+
name: str,
|
|
4096
|
+
transaction_opaque_data: bytes | None = None,
|
|
4097
|
+
writable_branch_function_name: str | None = None,
|
|
4098
|
+
) -> bytes:
|
|
4099
|
+
"""Get the insert function for a table. Returns WriteFunctionResult as IPC bytes."""
|
|
4100
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4101
|
+
cat = self._get_catalog()
|
|
4102
|
+
result = cat.table_insert_function_get(
|
|
4103
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4104
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4105
|
+
if transaction_opaque_data
|
|
4106
|
+
else None,
|
|
4107
|
+
schema_name=schema_name,
|
|
4108
|
+
name=name,
|
|
4109
|
+
writable_branch_function_name=writable_branch_function_name,
|
|
4110
|
+
)
|
|
4111
|
+
return result.serialize()
|
|
4112
|
+
|
|
4113
|
+
def catalog_table_update_function_get(
|
|
4114
|
+
self,
|
|
4115
|
+
attach_opaque_data: bytes,
|
|
4116
|
+
schema_name: str,
|
|
4117
|
+
name: str,
|
|
4118
|
+
transaction_opaque_data: bytes | None = None,
|
|
4119
|
+
) -> bytes:
|
|
4120
|
+
"""Get the update function for a table. Returns WriteFunctionResult as IPC bytes."""
|
|
4121
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4122
|
+
cat = self._get_catalog()
|
|
4123
|
+
result = cat.table_update_function_get(
|
|
4124
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4125
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4126
|
+
if transaction_opaque_data
|
|
4127
|
+
else None,
|
|
4128
|
+
schema_name=schema_name,
|
|
4129
|
+
name=name,
|
|
4130
|
+
)
|
|
4131
|
+
return result.serialize()
|
|
4132
|
+
|
|
4133
|
+
def catalog_table_delete_function_get(
|
|
4134
|
+
self,
|
|
4135
|
+
attach_opaque_data: bytes,
|
|
4136
|
+
schema_name: str,
|
|
4137
|
+
name: str,
|
|
4138
|
+
transaction_opaque_data: bytes | None = None,
|
|
4139
|
+
) -> bytes:
|
|
4140
|
+
"""Get the delete function for a table. Returns WriteFunctionResult as IPC bytes."""
|
|
4141
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4142
|
+
cat = self._get_catalog()
|
|
4143
|
+
result = cat.table_delete_function_get(
|
|
4144
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4145
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4146
|
+
if transaction_opaque_data
|
|
4147
|
+
else None,
|
|
4148
|
+
schema_name=schema_name,
|
|
4149
|
+
name=name,
|
|
4150
|
+
)
|
|
4151
|
+
return result.serialize()
|
|
4152
|
+
|
|
4153
|
+
def catalog_table_comment_set(
|
|
4154
|
+
self,
|
|
4155
|
+
attach_opaque_data: bytes,
|
|
4156
|
+
schema_name: str,
|
|
4157
|
+
name: str,
|
|
4158
|
+
comment: str | None = None,
|
|
4159
|
+
ignore_not_found: bool = False,
|
|
4160
|
+
transaction_opaque_data: bytes | None = None,
|
|
4161
|
+
) -> None:
|
|
4162
|
+
"""Set or clear the comment on a table."""
|
|
4163
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4164
|
+
cat = self._get_catalog()
|
|
4165
|
+
cat.table_comment_set(
|
|
4166
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4167
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4168
|
+
if transaction_opaque_data
|
|
4169
|
+
else None,
|
|
4170
|
+
schema_name=schema_name,
|
|
4171
|
+
name=name,
|
|
4172
|
+
comment=comment,
|
|
4173
|
+
ignore_not_found=ignore_not_found,
|
|
4174
|
+
)
|
|
4175
|
+
|
|
4176
|
+
def catalog_table_column_comment_set(
|
|
4177
|
+
self,
|
|
4178
|
+
attach_opaque_data: bytes,
|
|
4179
|
+
schema_name: str,
|
|
4180
|
+
name: str,
|
|
4181
|
+
column_name: str,
|
|
4182
|
+
comment: str | None = None,
|
|
4183
|
+
ignore_not_found: bool = False,
|
|
4184
|
+
transaction_opaque_data: bytes | None = None,
|
|
4185
|
+
) -> None:
|
|
4186
|
+
"""Set or clear the comment on a table column."""
|
|
4187
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4188
|
+
cat = self._get_catalog()
|
|
4189
|
+
cat.table_column_comment_set(
|
|
4190
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4191
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4192
|
+
if transaction_opaque_data
|
|
4193
|
+
else None,
|
|
4194
|
+
schema_name=schema_name,
|
|
4195
|
+
name=name,
|
|
4196
|
+
column_name=column_name,
|
|
4197
|
+
comment=comment,
|
|
4198
|
+
ignore_not_found=ignore_not_found,
|
|
4199
|
+
)
|
|
4200
|
+
|
|
4201
|
+
def catalog_table_rename(
|
|
4202
|
+
self,
|
|
4203
|
+
attach_opaque_data: bytes,
|
|
4204
|
+
schema_name: str,
|
|
4205
|
+
name: str,
|
|
4206
|
+
new_name: str,
|
|
4207
|
+
ignore_not_found: bool = False,
|
|
4208
|
+
transaction_opaque_data: bytes | None = None,
|
|
4209
|
+
) -> None:
|
|
4210
|
+
"""Rename a table."""
|
|
4211
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4212
|
+
cat = self._get_catalog()
|
|
4213
|
+
cat.table_rename(
|
|
4214
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4215
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4216
|
+
if transaction_opaque_data
|
|
4217
|
+
else None,
|
|
4218
|
+
schema_name=schema_name,
|
|
4219
|
+
name=name,
|
|
4220
|
+
new_name=new_name,
|
|
4221
|
+
ignore_not_found=ignore_not_found,
|
|
4222
|
+
)
|
|
4223
|
+
|
|
4224
|
+
def catalog_table_column_add(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    column_definition: bytes,
    ignore_not_found: bool = False,
    if_column_not_exists: bool = False,
    transaction_opaque_data: bytes | None = None,
) -> None:
    """Add a column to a table via the catalog's ``table_column_add``.

    The raw ``column_definition`` bytes are wrapped in ``SerializedSchema``
    before being handed to the catalog.
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    catalog.table_column_add(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
        column_definition=SerializedSchema(column_definition),
        ignore_not_found=ignore_not_found,
        if_column_not_exists=if_column_not_exists,
    )
|
|
4248
|
+
|
|
4249
|
+
def catalog_table_column_drop(
|
|
4250
|
+
self,
|
|
4251
|
+
attach_opaque_data: bytes,
|
|
4252
|
+
schema_name: str,
|
|
4253
|
+
name: str,
|
|
4254
|
+
column_name: str,
|
|
4255
|
+
ignore_not_found: bool = False,
|
|
4256
|
+
if_column_exists: bool = False,
|
|
4257
|
+
cascade: bool = False,
|
|
4258
|
+
transaction_opaque_data: bytes | None = None,
|
|
4259
|
+
) -> None:
|
|
4260
|
+
"""Drop a column from a table."""
|
|
4261
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4262
|
+
cat = self._get_catalog()
|
|
4263
|
+
cat.table_column_drop(
|
|
4264
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4265
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4266
|
+
if transaction_opaque_data
|
|
4267
|
+
else None,
|
|
4268
|
+
schema_name=schema_name,
|
|
4269
|
+
name=name,
|
|
4270
|
+
column_name=column_name,
|
|
4271
|
+
ignore_not_found=ignore_not_found,
|
|
4272
|
+
if_column_exists=if_column_exists,
|
|
4273
|
+
cascade=cascade,
|
|
4274
|
+
)
|
|
4275
|
+
|
|
4276
|
+
def catalog_table_column_rename(
|
|
4277
|
+
self,
|
|
4278
|
+
attach_opaque_data: bytes,
|
|
4279
|
+
schema_name: str,
|
|
4280
|
+
name: str,
|
|
4281
|
+
column_name: str,
|
|
4282
|
+
new_column_name: str,
|
|
4283
|
+
ignore_not_found: bool = False,
|
|
4284
|
+
transaction_opaque_data: bytes | None = None,
|
|
4285
|
+
) -> None:
|
|
4286
|
+
"""Rename a column."""
|
|
4287
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4288
|
+
cat = self._get_catalog()
|
|
4289
|
+
cat.table_column_rename(
|
|
4290
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4291
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4292
|
+
if transaction_opaque_data
|
|
4293
|
+
else None,
|
|
4294
|
+
schema_name=schema_name,
|
|
4295
|
+
name=name,
|
|
4296
|
+
column_name=column_name,
|
|
4297
|
+
new_column_name=new_column_name,
|
|
4298
|
+
ignore_not_found=ignore_not_found,
|
|
4299
|
+
)
|
|
4300
|
+
|
|
4301
|
+
def catalog_table_column_default_set(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    column_name: str,
    expression: str,
    ignore_not_found: bool = False,
    transaction_opaque_data: bytes | None = None,
) -> None:
    """Assign a default-value expression to a column.

    The ``expression`` string is wrapped in ``SqlExpression`` before being
    passed to the catalog's ``table_column_default_set``.
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    catalog.table_column_default_set(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
        column_name=column_name,
        expression=SqlExpression(expression),
        ignore_not_found=ignore_not_found,
    )
|
|
4325
|
+
|
|
4326
|
+
def catalog_table_column_default_drop(
|
|
4327
|
+
self,
|
|
4328
|
+
attach_opaque_data: bytes,
|
|
4329
|
+
schema_name: str,
|
|
4330
|
+
name: str,
|
|
4331
|
+
column_name: str,
|
|
4332
|
+
ignore_not_found: bool = False,
|
|
4333
|
+
transaction_opaque_data: bytes | None = None,
|
|
4334
|
+
) -> None:
|
|
4335
|
+
"""Remove the default value from a column."""
|
|
4336
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4337
|
+
cat = self._get_catalog()
|
|
4338
|
+
cat.table_column_default_drop(
|
|
4339
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4340
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4341
|
+
if transaction_opaque_data
|
|
4342
|
+
else None,
|
|
4343
|
+
schema_name=schema_name,
|
|
4344
|
+
name=name,
|
|
4345
|
+
column_name=column_name,
|
|
4346
|
+
ignore_not_found=ignore_not_found,
|
|
4347
|
+
)
|
|
4348
|
+
|
|
4349
|
+
def catalog_table_column_type_change(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    column_definition: bytes,
    expression: str | None = None,
    ignore_not_found: bool = False,
    transaction_opaque_data: bytes | None = None,
) -> None:
    """Change a column's type via the catalog's ``table_column_type_change``.

    ``column_definition`` is wrapped in ``SerializedSchema``. A truthy
    ``expression`` is wrapped in ``SqlExpression``; a missing or empty one
    is forwarded as ``None``.
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    catalog.table_column_type_change(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
        column_definition=SerializedSchema(column_definition),
        expression=None if not expression else SqlExpression(expression),
        ignore_not_found=ignore_not_found,
    )
|
|
4373
|
+
|
|
4374
|
+
def catalog_table_not_null_drop(
|
|
4375
|
+
self,
|
|
4376
|
+
attach_opaque_data: bytes,
|
|
4377
|
+
schema_name: str,
|
|
4378
|
+
name: str,
|
|
4379
|
+
column_name: str,
|
|
4380
|
+
ignore_not_found: bool = False,
|
|
4381
|
+
transaction_opaque_data: bytes | None = None,
|
|
4382
|
+
) -> None:
|
|
4383
|
+
"""Remove NOT NULL constraint from a column."""
|
|
4384
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4385
|
+
cat = self._get_catalog()
|
|
4386
|
+
cat.table_not_null_drop(
|
|
4387
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4388
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4389
|
+
if transaction_opaque_data
|
|
4390
|
+
else None,
|
|
4391
|
+
schema_name=schema_name,
|
|
4392
|
+
name=name,
|
|
4393
|
+
column_name=column_name,
|
|
4394
|
+
ignore_not_found=ignore_not_found,
|
|
4395
|
+
)
|
|
4396
|
+
|
|
4397
|
+
def catalog_table_not_null_set(
|
|
4398
|
+
self,
|
|
4399
|
+
attach_opaque_data: bytes,
|
|
4400
|
+
schema_name: str,
|
|
4401
|
+
name: str,
|
|
4402
|
+
column_name: str,
|
|
4403
|
+
ignore_not_found: bool = False,
|
|
4404
|
+
transaction_opaque_data: bytes | None = None,
|
|
4405
|
+
) -> None:
|
|
4406
|
+
"""Add NOT NULL constraint to a column."""
|
|
4407
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_table_name=name)
|
|
4408
|
+
cat = self._get_catalog()
|
|
4409
|
+
cat.table_not_null_set(
|
|
4410
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4411
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4412
|
+
if transaction_opaque_data
|
|
4413
|
+
else None,
|
|
4414
|
+
schema_name=schema_name,
|
|
4415
|
+
name=name,
|
|
4416
|
+
column_name=column_name,
|
|
4417
|
+
ignore_not_found=ignore_not_found,
|
|
4418
|
+
)
|
|
4419
|
+
|
|
4420
|
+
# ---------------------------------------------------------------------------
|
|
4421
|
+
# VgiProtocol implementation - Catalog Views
|
|
4422
|
+
# ---------------------------------------------------------------------------
|
|
4423
|
+
|
|
4424
|
+
def catalog_view_get(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    transaction_opaque_data: bytes | None = None,
) -> ViewsResponse:
    """Look up a single view.

    Whatever the catalog's ``view_get`` returns is wrapped with
    ``ViewsResponse.from_optional`` (zero or one items per the original
    documentation).
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_view_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    view_info = catalog.view_get(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
    )
    return ViewsResponse.from_optional(view_info)
|
|
4443
|
+
|
|
4444
|
+
def catalog_view_create(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    definition: str,
    on_conflict: OnConflict,
    transaction_opaque_data: bytes | None = None,
) -> None:
    """Create a view via the catalog's ``view_create``.

    ``definition`` and ``on_conflict`` are forwarded unchanged.
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_view_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    catalog.view_create(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
        definition=definition,
        on_conflict=on_conflict,
    )
|
|
4466
|
+
|
|
4467
|
+
def catalog_view_drop(
|
|
4468
|
+
self,
|
|
4469
|
+
attach_opaque_data: bytes,
|
|
4470
|
+
schema_name: str,
|
|
4471
|
+
name: str,
|
|
4472
|
+
ignore_not_found: bool = False,
|
|
4473
|
+
cascade: bool = False,
|
|
4474
|
+
transaction_opaque_data: bytes | None = None,
|
|
4475
|
+
) -> None:
|
|
4476
|
+
"""Drop a view."""
|
|
4477
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_view_name=name)
|
|
4478
|
+
cat = self._get_catalog()
|
|
4479
|
+
cat.view_drop(
|
|
4480
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4481
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4482
|
+
if transaction_opaque_data
|
|
4483
|
+
else None,
|
|
4484
|
+
schema_name=schema_name,
|
|
4485
|
+
name=name,
|
|
4486
|
+
ignore_not_found=ignore_not_found,
|
|
4487
|
+
cascade=cascade,
|
|
4488
|
+
)
|
|
4489
|
+
|
|
4490
|
+
def catalog_view_rename(
|
|
4491
|
+
self,
|
|
4492
|
+
attach_opaque_data: bytes,
|
|
4493
|
+
schema_name: str,
|
|
4494
|
+
name: str,
|
|
4495
|
+
new_name: str,
|
|
4496
|
+
ignore_not_found: bool = False,
|
|
4497
|
+
transaction_opaque_data: bytes | None = None,
|
|
4498
|
+
) -> None:
|
|
4499
|
+
"""Rename a view."""
|
|
4500
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_view_name=name)
|
|
4501
|
+
cat = self._get_catalog()
|
|
4502
|
+
cat.view_rename(
|
|
4503
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4504
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4505
|
+
if transaction_opaque_data
|
|
4506
|
+
else None,
|
|
4507
|
+
schema_name=schema_name,
|
|
4508
|
+
name=name,
|
|
4509
|
+
new_name=new_name,
|
|
4510
|
+
ignore_not_found=ignore_not_found,
|
|
4511
|
+
)
|
|
4512
|
+
|
|
4513
|
+
def catalog_view_comment_set(
|
|
4514
|
+
self,
|
|
4515
|
+
attach_opaque_data: bytes,
|
|
4516
|
+
schema_name: str,
|
|
4517
|
+
name: str,
|
|
4518
|
+
comment: str | None = None,
|
|
4519
|
+
ignore_not_found: bool = False,
|
|
4520
|
+
transaction_opaque_data: bytes | None = None,
|
|
4521
|
+
) -> None:
|
|
4522
|
+
"""Set or clear the comment on a view."""
|
|
4523
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_view_name=name)
|
|
4524
|
+
cat = self._get_catalog()
|
|
4525
|
+
cat.view_comment_set(
|
|
4526
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4527
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4528
|
+
if transaction_opaque_data
|
|
4529
|
+
else None,
|
|
4530
|
+
schema_name=schema_name,
|
|
4531
|
+
name=name,
|
|
4532
|
+
comment=comment,
|
|
4533
|
+
ignore_not_found=ignore_not_found,
|
|
4534
|
+
)
|
|
4535
|
+
|
|
4536
|
+
# ---------------------------------------------------------------------------
|
|
4537
|
+
# VgiProtocol implementation - Catalog Macros
|
|
4538
|
+
# ---------------------------------------------------------------------------
|
|
4539
|
+
|
|
4540
|
+
def catalog_macro_get(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    transaction_opaque_data: bytes | None = None,
) -> MacrosResponse:
    """Look up a single macro.

    Whatever the catalog's ``macro_get`` returns is wrapped with
    ``MacrosResponse.from_optional`` (zero or one items per the original
    documentation).
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_macro_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    macro_info = catalog.macro_get(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
    )
    return MacrosResponse.from_optional(macro_info)
|
|
4559
|
+
|
|
4560
|
+
def catalog_macro_create(self, request: MacroCreateRequest) -> None:
    """Create a macro from the fields carried by ``request``.

    The attach handle is unwrapped; the transaction handle is unwrapped only
    when the request carries a truthy one. All remaining request fields are
    forwarded unchanged to the catalog's ``macro_create``.
    """
    self._enrich_catalog_span(vgi_schema_name=request.schema_name, vgi_macro_name=request.name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(request.attach_opaque_data)
    txn = None
    if request.transaction_opaque_data:
        txn = self._unwrap_transaction(request.transaction_opaque_data, request.attach_opaque_data)
    catalog.macro_create(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=request.schema_name,
        name=request.name,
        macro_type=request.macro_type,
        parameters=request.parameters,
        definition=request.definition,
        on_conflict=request.on_conflict,
        parameter_default_values=request.parameter_default_values,
    )
|
|
4579
|
+
|
|
4580
|
+
def catalog_macro_drop(
|
|
4581
|
+
self,
|
|
4582
|
+
attach_opaque_data: bytes,
|
|
4583
|
+
schema_name: str,
|
|
4584
|
+
name: str,
|
|
4585
|
+
ignore_not_found: bool = False,
|
|
4586
|
+
transaction_opaque_data: bytes | None = None,
|
|
4587
|
+
) -> None:
|
|
4588
|
+
"""Drop a macro."""
|
|
4589
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_macro_name=name)
|
|
4590
|
+
cat = self._get_catalog()
|
|
4591
|
+
cat.macro_drop(
|
|
4592
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4593
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4594
|
+
if transaction_opaque_data
|
|
4595
|
+
else None,
|
|
4596
|
+
schema_name=schema_name,
|
|
4597
|
+
name=name,
|
|
4598
|
+
ignore_not_found=ignore_not_found,
|
|
4599
|
+
)
|
|
4600
|
+
|
|
4601
|
+
def catalog_schema_contents_macros(
    self,
    attach_opaque_data: bytes,
    name: str,
    type: SchemaObjectType,
    transaction_opaque_data: bytes | None = None,
) -> MacrosResponse:
    """List the macros of schema ``name`` for the requested object ``type``.

    The original documentation describes ``type`` as selecting scalar or
    table macros. The catalog's ``schema_contents`` result is materialized
    into a list and wrapped with ``MacrosResponse.from_infos``.
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    macro_infos = catalog.schema_contents(  # type: ignore[call-overload]
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        name=name,
        type=type,
    )
    return MacrosResponse.from_infos(list(macro_infos))
|
|
4620
|
+
|
|
4621
|
+
# ---------------------------------------------------------------------------
|
|
4622
|
+
# VgiProtocol implementation - Catalog Indexes
|
|
4623
|
+
# ---------------------------------------------------------------------------
|
|
4624
|
+
|
|
4625
|
+
def catalog_index_get(
    self,
    attach_opaque_data: bytes,
    schema_name: str,
    name: str,
    transaction_opaque_data: bytes | None = None,
) -> IndexesResponse:
    """Look up a single index.

    Whatever the catalog's ``index_get`` returns is wrapped with
    ``IndexesResponse.from_optional`` (zero or one items per the original
    documentation).
    """
    self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_index_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    index_info = catalog.index_get(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=schema_name,
        name=name,
    )
    return IndexesResponse.from_optional(index_info)
|
|
4644
|
+
|
|
4645
|
+
def catalog_index_create(self, request: IndexCreateRequest) -> None:
    """Create an index from the fields carried by ``request``.

    ``expressions`` is copied into a new list. ``options`` is copied into a
    new dict when truthy and forwarded as ``None`` otherwise. The
    transaction handle is unwrapped only when the request carries a truthy
    one.
    """
    self._enrich_catalog_span(vgi_schema_name=request.schema_name, vgi_index_name=request.name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(request.attach_opaque_data)
    txn = None
    if request.transaction_opaque_data:
        txn = self._unwrap_transaction(request.transaction_opaque_data, request.attach_opaque_data)
    catalog.index_create(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        schema_name=request.schema_name,
        name=request.name,
        table_name=request.table_name,
        index_type=request.index_type,
        constraint_type=request.constraint_type,
        expressions=list(request.expressions),
        on_conflict=request.on_conflict,
        options=None if not request.options else dict(request.options),
    )
|
|
4665
|
+
|
|
4666
|
+
def catalog_index_drop(
|
|
4667
|
+
self,
|
|
4668
|
+
attach_opaque_data: bytes,
|
|
4669
|
+
schema_name: str,
|
|
4670
|
+
name: str,
|
|
4671
|
+
ignore_not_found: bool = False,
|
|
4672
|
+
cascade: bool = False,
|
|
4673
|
+
transaction_opaque_data: bytes | None = None,
|
|
4674
|
+
) -> None:
|
|
4675
|
+
"""Drop an index."""
|
|
4676
|
+
self._enrich_catalog_span(vgi_schema_name=schema_name, vgi_index_name=name)
|
|
4677
|
+
cat = self._get_catalog()
|
|
4678
|
+
cat.index_drop(
|
|
4679
|
+
attach_opaque_data=self._unwrap_attach(attach_opaque_data),
|
|
4680
|
+
transaction_opaque_data=self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
|
|
4681
|
+
if transaction_opaque_data
|
|
4682
|
+
else None,
|
|
4683
|
+
schema_name=schema_name,
|
|
4684
|
+
name=name,
|
|
4685
|
+
ignore_not_found=ignore_not_found,
|
|
4686
|
+
cascade=cascade,
|
|
4687
|
+
)
|
|
4688
|
+
|
|
4689
|
+
def catalog_schema_contents_indexes(
    self,
    attach_opaque_data: bytes,
    name: str,
    transaction_opaque_data: bytes | None = None,
) -> IndexesResponse:
    """List the indexes that live in schema ``name``.

    Calls the catalog's ``schema_contents`` with
    ``SchemaObjectType.INDEX``, materializes the result into a list and
    wraps it with ``IndexesResponse.from_infos``.
    """
    self._enrich_catalog_span(vgi_schema_name=name)
    catalog = self._get_catalog()
    attach = self._unwrap_attach(attach_opaque_data)
    txn = None
    if transaction_opaque_data:
        txn = self._unwrap_transaction(transaction_opaque_data, attach_opaque_data)
    index_infos = catalog.schema_contents(
        attach_opaque_data=attach,
        transaction_opaque_data=txn,
        name=name,
        type=SchemaObjectType.INDEX,
    )
    return IndexesResponse.from_infos(list(index_infos))
|
|
4707
|
+
|
|
4708
|
+
# ---------------------------------------------------------------------------
|
|
4709
|
+
# Lifecycle
|
|
4710
|
+
# ---------------------------------------------------------------------------
|
|
4711
|
+
|
|
4712
|
+
def __init__(self, *, quiet: bool = False, log_level: int = logging.INFO) -> None:
    """Set up quiet mode, a no-op tracer and the ``vgi`` logger level.

    Args:
        quiet: When truthy, startup output is suppressed. The same effect
            is obtained by exporting ``VGI_QUIET=1`` in the environment.
        log_level: Numeric level applied to the ``vgi`` logger hierarchy.

    """
    env_requests_quiet = os.environ.get("VGI_QUIET") == "1"
    self._quiet = quiet or env_requests_quiet
    # Start with the no-op tracer obtained from get_noop_tracer().
    self._vgi_tracer: VgiTracer = get_noop_tracer()
    vgi_logger = logging.getLogger("vgi")
    vgi_logger.setLevel(log_level)
|
|
4724
|
+
|
|
4725
|
+
def run(self, otel_config: Any = None) -> None:
|
|
4726
|
+
"""Run the worker, reading from stdin and writing to stdout.
|
|
4727
|
+
|
|
4728
|
+
Args:
|
|
4729
|
+
otel_config: Optional ``OtelConfig`` for OpenTelemetry instrumentation.
|
|
4730
|
+
When provided, instruments the RPC server and creates a VGI tracer.
|
|
4731
|
+
|
|
4732
|
+
"""
|
|
4733
|
+
# Warn if stdin is a terminal - user likely ran worker directly
|
|
4734
|
+
if sys.stdin.isatty() and not self._quiet:
|
|
4735
|
+
sys.stderr.write(
|
|
4736
|
+
"\n"
|
|
4737
|
+
"Warning: This worker expects Arrow IPC binary data on stdin.\n"
|
|
4738
|
+
"It is not meant to be run interactively in a terminal.\n"
|
|
4739
|
+
"\n"
|
|
4740
|
+
"Usage:\n"
|
|
4741
|
+
" - Use vgi-client to invoke functions\n"
|
|
4742
|
+
" - Use DuckDB with VGI extension\n"
|
|
4743
|
+
"\n"
|
|
4744
|
+
"To suppress this warning: --quiet or VGI_QUIET=1\n"
|
|
4745
|
+
"\n"
|
|
4746
|
+
)
|
|
4747
|
+
sys.stderr.flush()
|
|
4748
|
+
|
|
4749
|
+
_logger.info("worker_starting")
|
|
4750
|
+
|
|
4751
|
+
try:
|
|
4752
|
+
server = RpcServer(self.protocol_class, self, server_version=_get_vgi_version())
|
|
4753
|
+
if otel_config is not None:
|
|
4754
|
+
from vgi_rpc.otel import instrument_server
|
|
4755
|
+
|
|
4756
|
+
instrument_server(server, otel_config)
|
|
4757
|
+
self._vgi_tracer = VgiTracer.create(otel_config)
|
|
4758
|
+
serve_stdio(server)
|
|
4759
|
+
except KeyboardInterrupt:
|
|
4760
|
+
_logger.debug("worker_interrupted")
|
|
4761
|
+
sys.exit(130)
|