vgi-python 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vgi/__init__.py +152 -0
- vgi/_duckdb.py +62 -0
- vgi/_storage_profile.py +132 -0
- vgi/_test_fixtures/__init__.py +20 -0
- vgi/_test_fixtures/accumulate/__init__.py +19 -0
- vgi/_test_fixtures/accumulate/worker.py +762 -0
- vgi/_test_fixtures/aggregate/__init__.py +62 -0
- vgi/_test_fixtures/aggregate/_common.py +21 -0
- vgi/_test_fixtures/aggregate/basic.py +232 -0
- vgi/_test_fixtures/aggregate/dynamic.py +409 -0
- vgi/_test_fixtures/aggregate/generic.py +86 -0
- vgi/_test_fixtures/aggregate/listagg.py +71 -0
- vgi/_test_fixtures/aggregate/percentile.py +107 -0
- vgi/_test_fixtures/aggregate/streaming.py +192 -0
- vgi/_test_fixtures/aggregate/varargs.py +75 -0
- vgi/_test_fixtures/aggregate/window.py +380 -0
- vgi/_test_fixtures/attach_options.py +308 -0
- vgi/_test_fixtures/bad_protocol.py +62 -0
- vgi/_test_fixtures/cancellable.py +336 -0
- vgi/_test_fixtures/catalog.py +813 -0
- vgi/_test_fixtures/http_server.py +394 -0
- vgi/_test_fixtures/nest_tensor.py +614 -0
- vgi/_test_fixtures/orchard_catalog.py +47 -0
- vgi/_test_fixtures/projection_repro/__init__.py +6 -0
- vgi/_test_fixtures/projection_repro/worker.py +454 -0
- vgi/_test_fixtures/scalar/__init__.py +116 -0
- vgi/_test_fixtures/scalar/_common.py +69 -0
- vgi/_test_fixtures/scalar/arithmetic.py +321 -0
- vgi/_test_fixtures/scalar/binary.py +120 -0
- vgi/_test_fixtures/scalar/formatting.py +176 -0
- vgi/_test_fixtures/scalar/geo.py +300 -0
- vgi/_test_fixtures/scalar/null_handling.py +107 -0
- vgi/_test_fixtures/scalar/random_demo.py +171 -0
- vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
- vgi/_test_fixtures/scalar/type_info.py +219 -0
- vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
- vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
- vgi/_test_fixtures/simple_writable.py +793 -0
- vgi/_test_fixtures/table/__init__.py +221 -0
- vgi/_test_fixtures/table/_common.py +162 -0
- vgi/_test_fixtures/table/batch_index.py +283 -0
- vgi/_test_fixtures/table/batch_index_broken.py +200 -0
- vgi/_test_fixtures/table/catalog_scans.py +162 -0
- vgi/_test_fixtures/table/filters.py +1005 -0
- vgi/_test_fixtures/table/late_materialization.py +249 -0
- vgi/_test_fixtures/table/make_series.py +273 -0
- vgi/_test_fixtures/table/misc.py +499 -0
- vgi/_test_fixtures/table/order_modes.py +164 -0
- vgi/_test_fixtures/table/pairs.py +437 -0
- vgi/_test_fixtures/table/partition_columns.py +472 -0
- vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
- vgi/_test_fixtures/table/profiling_example.py +195 -0
- vgi/_test_fixtures/table/required_filters.py +234 -0
- vgi/_test_fixtures/table/sequence.py +710 -0
- vgi/_test_fixtures/table/settings.py +426 -0
- vgi/_test_fixtures/table/transaction_storage.py +162 -0
- vgi/_test_fixtures/table/tt_pushdown.py +191 -0
- vgi/_test_fixtures/table/versioned.py +230 -0
- vgi/_test_fixtures/table_in_out.py +1392 -0
- vgi/_test_fixtures/versioned.py +155 -0
- vgi/_test_fixtures/versioned_tables.py +595 -0
- vgi/_test_fixtures/worker.py +1631 -0
- vgi/_test_fixtures/writable/__init__.py +8 -0
- vgi/_test_fixtures/writable/generic.py +236 -0
- vgi/_test_fixtures/writable/table.py +149 -0
- vgi/_test_fixtures/writable/worker.py +1148 -0
- vgi/aggregate_function.py +607 -0
- vgi/argument_spec.py +472 -0
- vgi/arguments.py +1747 -0
- vgi/auth.py +55 -0
- vgi/catalog/__init__.py +88 -0
- vgi/catalog/attach_option.py +206 -0
- vgi/catalog/catalog_interface.py +2767 -0
- vgi/catalog/descriptors.py +870 -0
- vgi/catalog/duckdb_statistics.py +377 -0
- vgi/catalog/secret_type.py +96 -0
- vgi/catalog/setting.py +253 -0
- vgi/catalog/storage.py +372 -0
- vgi/client/__init__.py +67 -0
- vgi/client/catalog_mixin.py +1251 -0
- vgi/client/cli.py +582 -0
- vgi/client/cli_catalog.py +182 -0
- vgi/client/cli_schema.py +270 -0
- vgi/client/cli_table.py +907 -0
- vgi/client/cli_transaction.py +97 -0
- vgi/client/cli_utils.py +441 -0
- vgi/client/cli_view.py +303 -0
- vgi/client/client.py +2183 -0
- vgi/exceptions.py +205 -0
- vgi/function.py +245 -0
- vgi/function_storage.py +1636 -0
- vgi/function_storage_azure_sql.py +922 -0
- vgi/function_storage_cf_do.py +740 -0
- vgi/http/__init__.py +25 -0
- vgi/http/demo_storage.py +212 -0
- vgi/http/worker_page.py +1252 -0
- vgi/invocation.py +154 -0
- vgi/logging_config.py +93 -0
- vgi/meta_worker.py +661 -0
- vgi/metadata.py +1403 -0
- vgi/otel.py +406 -0
- vgi/protocol.py +2418 -0
- vgi/protocol_version.txt +1 -0
- vgi/py.typed +0 -0
- vgi/scalar_function.py +1211 -0
- vgi/schema_utils.py +234 -0
- vgi/secret_protocol.py +124 -0
- vgi/secret_service.py +238 -0
- vgi/serve.py +769 -0
- vgi/table_buffering_function.py +443 -0
- vgi/table_filter_pushdown.py +1528 -0
- vgi/table_function.py +1130 -0
- vgi/table_in_out_function.py +383 -0
- vgi/transactor/__init__.py +24 -0
- vgi/transactor/_duckdb_compat.py +27 -0
- vgi/transactor/client.py +137 -0
- vgi/transactor/protocol.py +149 -0
- vgi/transactor/server.py +740 -0
- vgi/worker.py +4761 -0
- vgi_python-0.8.0.dist-info/METADATA +735 -0
- vgi_python-0.8.0.dist-info/RECORD +124 -0
- vgi_python-0.8.0.dist-info/WHEEL +4 -0
- vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
- vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
vgi/scalar_function.py
ADDED
|
@@ -0,0 +1,1211 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""Scalar functions: per-row transforms with single-column output.
|
|
4
|
+
|
|
5
|
+
Scalar functions are the simplest function type in VGI. They transform each
|
|
6
|
+
input row into exactly one output value, producing a single column of results.
|
|
7
|
+
|
|
8
|
+
Key characteristics:
|
|
9
|
+
- **1:1 row mapping**: Output has exactly the same number of rows as input
|
|
10
|
+
- **Single column output**: Output schema has exactly one column named "result"
|
|
11
|
+
- **No finish()**: Processing ends when the caller closes the input stream.
|
|
12
|
+
|
|
13
|
+
Common use cases:
|
|
14
|
+
- Mathematical operations: multiply, add, abs
|
|
15
|
+
- String transforms: upper, lower, concat, trim
|
|
16
|
+
- Type conversions: cast, parse
|
|
17
|
+
- Field extraction: get nested values, parse JSON fields
|
|
18
|
+
|
|
19
|
+
This module provides two base classes:
|
|
20
|
+
|
|
21
|
+
ScalarFunction (recommended)
|
|
22
|
+
Declarative API using Param/ConstParam/Returns annotations on compute().
|
|
23
|
+
Also supports Setting, Secret, and OutputLength annotations.
|
|
24
|
+
Override output_type() only if the output type depends on input schema.
|
|
25
|
+
|
|
26
|
+
ScalarFunctionGenerator (advanced)
|
|
27
|
+
Per-batch callback API for fine-grained control.
|
|
28
|
+
Override output_type() and process().
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from __future__ import annotations
|
|
33
|
+
|
|
34
|
+
import contextlib
|
|
35
|
+
import inspect
|
|
36
|
+
import logging
|
|
37
|
+
import uuid
|
|
38
|
+
from abc import abstractmethod
|
|
39
|
+
from dataclasses import dataclass
|
|
40
|
+
from typing import TYPE_CHECKING, Any, cast, final, get_args, get_origin, get_type_hints
|
|
41
|
+
|
|
42
|
+
import pyarrow as pa
|
|
43
|
+
from vgi_rpc import ArrowSerializableDataclass
|
|
44
|
+
from vgi_rpc.rpc import AuthContext, CallContext
|
|
45
|
+
|
|
46
|
+
import vgi.function
|
|
47
|
+
from vgi.arguments import (
|
|
48
|
+
_PYTHON_TO_ARROW,
|
|
49
|
+
ARRAY_CLASS_TO_DATATYPE,
|
|
50
|
+
COMPLEX_ARRAY_CLASSES,
|
|
51
|
+
Arg,
|
|
52
|
+
Arguments,
|
|
53
|
+
ArgumentValidationError,
|
|
54
|
+
Auth,
|
|
55
|
+
ConstParam,
|
|
56
|
+
OutputLength,
|
|
57
|
+
Param,
|
|
58
|
+
Returns,
|
|
59
|
+
Secret,
|
|
60
|
+
SecretLookupEntry,
|
|
61
|
+
_extract_setting_secret_params,
|
|
62
|
+
)
|
|
63
|
+
from vgi.function_storage import BoundStorage, attach_catalog_bytes
|
|
64
|
+
from vgi.invocation import (
|
|
65
|
+
BaseInitResponse,
|
|
66
|
+
BindResponse,
|
|
67
|
+
GlobalInitResponse,
|
|
68
|
+
)
|
|
69
|
+
from vgi.schema_utils import schema
|
|
70
|
+
from vgi.table_function import SecretsAccessor, _struct_scalar_to_dict
|
|
71
|
+
|
|
72
|
+
if TYPE_CHECKING:
|
|
73
|
+
from vgi.protocol import BindRequest, InitRequest
|
|
74
|
+
|
|
75
|
+
logger = logging.getLogger(__name__)
|
|
76
|
+
|
|
77
|
+
__all__ = [
|
|
78
|
+
"BindParameters",
|
|
79
|
+
"BindResult",
|
|
80
|
+
"RowCountMismatchError",
|
|
81
|
+
"ScalarFunction",
|
|
82
|
+
"ScalarFunctionGenerator",
|
|
83
|
+
"TypeMismatchError",
|
|
84
|
+
]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass(slots=True, frozen=True)
class BindResult(ArrowSerializableDataclass):
    """Outcome of a scalar function's bind() step.

    A table function reports a whole schema at bind time; a scalar function
    yields one value per row, so a single Arrow output type is enough.

    Attributes:
        output_type: Arrow data type of the produced value.
        opaque_data: Optional serializable payload that the caller treats as
            opaque; it is handed on to global_init() and process().

    """

    # Arrow type of the single output column.
    output_type: pa.DataType
    # Typed bind-time payload; None when the function has nothing to carry forward.
    opaque_data: ArrowSerializableDataclass | None = None
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass(slots=True, frozen=True)
class BindParameters:
    """Everything a scalar function's bind() method receives.

    Attributes:
        constant_arguments: Constant arguments known at query planning time.
        arguments_schema: Schema describing the input columns.
        settings: DuckDB settings as a single-row RecordBatch, or None.
        secrets: SecretsAccessor giving access to resolved and dynamic secrets.
        auth_context: Authentication context for the current request.
        attach_opaque_data: Catalog attach ID when the function is invoked
            through an ATTACHed catalog, otherwise None.
        transaction_opaque_data: Catalog transaction ID when invoked inside a
            catalog transaction, otherwise None.

    """

    constant_arguments: Arguments
    arguments_schema: pa.Schema
    settings: pa.RecordBatch | None
    secrets: SecretsAccessor
    # The default is evaluated once, when the class body executes.
    auth_context: AuthContext = AuthContext.anonymous()
    attach_opaque_data: bytes | None = None
    transaction_opaque_data: bytes | None = None
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _resolve_explicit_arrow_type(arrow_type: pa.DataType | type) -> pa.DataType:
    """Turn an explicit ``arrow_type`` value into a ``pa.DataType``.

    A ``pa.DataType`` instance is returned unchanged. Anything else must be a
    key of ``_PYTHON_TO_ARROW`` (the Python types int/str/float/bool/bytes,
    per the error message below) and is mapped through that table.

    Args:
        arrow_type: The value given as ``arrow_type`` in Param()/ConstParam().

    Returns:
        The corresponding Arrow data type.

    Raises:
        TypeError: If the value is neither a ``pa.DataType`` nor a key of
            ``_PYTHON_TO_ARROW``.

    """
    if not isinstance(arrow_type, pa.DataType):
        if arrow_type not in _PYTHON_TO_ARROW:
            raise TypeError(
                f"Cannot convert type '{arrow_type}' to Arrow type. "
                f"Use pa.DataType, Python type (int/str/float/bool/bytes), "
                f"or None for AnyArrow."
            )
        return _PYTHON_TO_ARROW[arrow_type]
    # Already an Arrow type: nothing to convert.
    return arrow_type
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _param_to_arg(param: Param, base_type: type, position: int) -> Arg[Any]:
    """Convert a Param annotation into an internal Arg, inferring its Arrow type.

    Type resolution order:
    1. An explicit ``arrow_type`` in Param() always wins.
    2. Simple array classes (keys of ``ARRAY_CLASS_TO_DATATYPE``, e.g.
       pa.Int64Array) are mapped to their Arrow type automatically.
    3. Complex/parameterized array classes (members of
       ``COMPLEX_ARRAY_CLASSES``, e.g. pa.StructArray) require an explicit
       ``arrow_type`` and raise otherwise.
    4. Anything else (pa.Array, pa.Array[Any], Any, other generics) is treated
       as AnyArrow: ``is_any=True`` with ``pa.null()`` as a placeholder type.

    For varargs params annotated as ``list[X]``, inference runs on the element
    type ``X`` rather than on the list wrapper.

    Args:
        param: The Param metadata from an Annotated type hint.
        base_type: The first argument of the Annotated hint (e.g. pa.Int64Array
            from Annotated[pa.Int64Array, Param(...)]).
        position: The parameter's position in the compute() signature.

    Returns:
        Arg instance built from the Param's doc, type_bound and varargs settings.

    Raises:
        TypeError: If the type is a complex array class and no explicit
            ``arrow_type`` was given, or if an explicit ``arrow_type`` cannot
            be converted.

    """
    # For varargs params, unwrap list[X] to get the element type X.
    infer_type = base_type
    if param.varargs and get_origin(base_type) is list:
        type_args = get_args(base_type)
        if type_args:
            infer_type = type_args[0]

    is_any = False
    arrow_type: pa.DataType

    if param.arrow_type is not None:
        arrow_type = _resolve_explicit_arrow_type(param.arrow_type)
    elif infer_type in ARRAY_CLASS_TO_DATATYPE:
        # Simple array class (pa.Int64Array -> pa.int64()).
        arrow_type = ARRAY_CLASS_TO_DATATYPE[infer_type]
    elif infer_type in COMPLEX_ARRAY_CLASSES:
        # Report the class that actually needs the explicit type. For a
        # varargs hint such as list[pa.StructArray], base_type.__name__ is
        # forwarded to the list origin and would blame "list" instead.
        offending = getattr(infer_type, "__name__", repr(infer_type))
        raise TypeError(
            f"{offending} requires explicit arrow_type in Param(). "
            f"Example: Param(arrow_type=pa.list_(pa.int64()), doc='...')"
        )
    else:
        # Covers pa.Array, pa.Array[Any], Any, and other generic types.
        is_any = True
        arrow_type = pa.null()  # Placeholder for AnyArrow

    return Arg[Any](
        position,
        doc=param.doc,
        arrow_type=arrow_type,
        type_bound=param.type_bound,
        varargs=param.varargs,
        is_any=is_any,
    )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _const_param_to_arg(const_param: ConstParam, base_type: type, position: int) -> Arg[Any]:
    """Build the internal Arg for a ConstParam annotation.

    The Arrow type comes from the ConstParam's explicit ``arrow_type`` when
    present; otherwise it is looked up from ``base_type`` in
    ``_PYTHON_TO_ARROW``.

    Args:
        const_param: The ConstParam metadata from an Annotated type hint.
        base_type: The first argument of the Annotated hint (e.g. int from
            Annotated[int, ConstParam(...)]).
        position: The parameter's position in the const arguments.

    Returns:
        Arg instance created with ``const=True``.

    Raises:
        TypeError: If no Arrow type can be determined.

    """
    explicit = const_param.arrow_type
    resolved: pa.DataType

    if explicit is not None:
        resolved = _resolve_explicit_arrow_type(explicit)
    elif base_type in _PYTHON_TO_ARROW:
        # Fall back to the Python type given as the Annotated first argument.
        resolved = _PYTHON_TO_ARROW[base_type]
    else:
        raise TypeError(
            f"Cannot infer Arrow type from {base_type}. "
            f"Use a supported type (int/str/float/bool/bytes) or specify arrow_type."
        )

    return Arg[Any](position, doc=const_param.doc, arrow_type=resolved, const=True)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
# =============================================================================
|
|
242
|
+
# Descriptors for Param/ConstParam Arguments
|
|
243
|
+
# =============================================================================
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class _ArgDescriptor:
    """Descriptor backing a Param argument.

    Looked up on the class, it hands back the wrapped Arg metadata object.
    Looked up on an instance, it returns whatever ``Arg._resolve(instance)``
    produces for that instance.
    """

    __slots__ = ("arg", "name")

    def __init__(self, arg: Arg[Any], name: str) -> None:
        self.arg, self.name = arg, name

    def __get__(self, obj: object | None, _objtype: type | None = None) -> Any:
        # obj is None for class-level access: expose the metadata itself.
        # Otherwise resolve the argument against the owning instance.
        return self.arg if obj is None else self.arg._resolve(obj)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class _ConstArgDescriptor:
    """Descriptor backing a constant-folded ConstParam argument.

    Looked up on the class, it hands back the wrapped Arg metadata object.
    Looked up on an instance, it returns ``Arg._resolve(instance)``, which for
    a const parameter is intended to be the scalar value rather than an array.

    It is kept as a class separate from _ArgDescriptor because the two
    ``__get__`` methods are meant to return different kinds of value:
    - _ArgDescriptor returns the column value (array) for a regular Param
    - _ConstArgDescriptor returns the scalar value for a ConstParam
    """

    __slots__ = ("arg", "name")

    def __init__(self, arg: Arg[Any], name: str) -> None:
        self.arg, self.name = arg, name

    def __get__(self, obj: object | None, _objtype: type | None = None) -> Any:
        # Class-level access exposes the metadata; instance access resolves it.
        return self.arg if obj is None else self.arg._resolve(obj)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class RowCountMismatchError(Exception):
    """Signals that a scalar function's output row count differs from its input.

    A scalar function must emit exactly one output row per input row; this
    error reports that compute() returned an array of the wrong length.

    Attributes:
        input_rows: Number of rows in the input batch.
        output_rows: Number of rows in the output batch.
        function_name: Name of the function that produced the mismatch.

    """

    def __init__(
        self,
        message: str,
        *,
        input_rows: int | None = None,
        output_rows: int | None = None,
        function_name: str = "",
    ) -> None:
        """Store the row counts and compose the exception text.

        Args:
            message: Base error message.
            input_rows: Number of input rows.
            output_rows: Number of output rows.
            function_name: Name of the function class.

        """
        self.input_rows = input_rows
        self.output_rows = output_rows
        self.function_name = function_name

        # The detailed report needs both counts; otherwise use the bare message.
        if input_rows is None or output_rows is None:
            super().__init__(message)
            return
        super().__init__(self._build_detailed_message(message, input_rows, output_rows))

    def _build_detailed_message(self, base_message: str, input_rows: int, output_rows: int) -> str:
        """Compose the multi-line report: header, row counts, then guidance."""
        report = [base_message, ""]
        if self.function_name:
            report.append(f" Function: {self.function_name}")
        report += [f" Input rows: {input_rows}", f" Output rows: {output_rows}", ""]

        # Guidance depends on the direction of the mismatch. Any case that is
        # not "fewer" is reported as "more".
        if output_rows < input_rows:
            problem = " Problem: Output has fewer rows than input."
            causes = [
                " - compute() is filtering rows (not allowed in scalar)",
                " - compute() is aggregating (not allowed in scalar)",
            ]
            remedy = " For filtering or aggregation, use a table function."
        else:
            problem = " Problem: Output has more rows than input."
            causes = [
                " - compute() is expanding rows (not allowed in scalar)",
                " - compute() is unnesting arrays",
            ]
            remedy = " For row expansion (1→N), use a table function."

        report += [problem, "", " Possible causes:", *causes, " - Bug in array construction", ""]
        report += [" Scalar functions require 1:1 row mapping.", remedy]
        return "\n".join(report)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class TypeMismatchError(TypeError):
    """Signals that an array's type differs from the declared parameter or return type.

    The declared type comes from Param() or Returns(); the actual type is the
    one observed on the array at runtime.

    Attributes:
        param_name: Name of the parameter with the type mismatch.
        expected_type: The declared Arrow type.
        actual_type: The actual Arrow type found.
        function_name: Name of the function class.

    """

    def __init__(
        self,
        message: str,
        *,
        param_name: str = "",
        expected_type: pa.DataType | None = None,
        actual_type: pa.DataType | None = None,
        function_name: str = "",
    ) -> None:
        """Store the mismatch details and compose the exception text.

        Args:
            message: Base error message.
            param_name: Name of the parameter.
            expected_type: Expected Arrow type.
            actual_type: Actual Arrow type found.
            function_name: Name of the function class.

        """
        self.param_name = param_name
        self.expected_type = expected_type
        self.actual_type = actual_type
        self.function_name = function_name

        # The detailed report needs both types; otherwise use the bare message.
        if expected_type is None or actual_type is None:
            super().__init__(message)
        else:
            super().__init__(self._build_detailed_message(message, param_name, expected_type, actual_type))

    def _build_detailed_message(
        self,
        base_message: str,
        param_name: str,
        expected_type: pa.DataType,
        actual_type: pa.DataType,
    ) -> str:
        """Compose the multi-line report: header, optional names, then both types."""
        details: list[str] = []
        if self.function_name:
            details.append(f" Function: {self.function_name}")
        if param_name:
            details.append(f" Parameter: {param_name}")
        details += [f" Expected type: {expected_type}", f" Actual type: {actual_type}"]

        return "\n".join([base_message, "", *details])
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class ScalarFunctionGenerator(vgi.function.Function):
|
|
436
|
+
"""Per-batch callback base class for scalar functions.
|
|
437
|
+
|
|
438
|
+
This is the advanced API for scalar functions. For most use cases,
|
|
439
|
+
use ScalarFunction instead, which provides a simpler compute() callback.
|
|
440
|
+
|
|
441
|
+
Scalar functions have these constraints:
|
|
442
|
+
- **1:1 row mapping**: Output row count must equal input row count
|
|
443
|
+
- **Single value output**: Produces one value per input row
|
|
444
|
+
- **No finalization**: Processing ends when input is exhausted
|
|
445
|
+
|
|
446
|
+
Methods to Override
|
|
447
|
+
-------------------
|
|
448
|
+
output_type(params) -> pa.DataType
|
|
449
|
+
Return the Arrow type for the output value. Required.
|
|
450
|
+
|
|
451
|
+
process(...) -> pa.RecordBatch
|
|
452
|
+
Process one input batch. Must return output with same row count.
|
|
453
|
+
Required.
|
|
454
|
+
|
|
455
|
+
on_bind(params) -> BindResult
|
|
456
|
+
Optional. Override to perform custom bind-time logic.
|
|
457
|
+
|
|
458
|
+
on_init(...) -> GlobalInitResponse
|
|
459
|
+
Optional. Override to perform custom initialization.
|
|
460
|
+
|
|
461
|
+
Protocol Entry Points (called by worker, do not override)
|
|
462
|
+
---------------------------------------------------------
|
|
463
|
+
bind(input) -> BindResponse
|
|
464
|
+
Handles the bind API call
|
|
465
|
+
|
|
466
|
+
global_init(input) -> GlobalInitResponse
|
|
467
|
+
Handles the global_init API call
|
|
468
|
+
|
|
469
|
+
"""
|
|
470
|
+
|
|
471
|
+
@final
@classmethod
def _validate_row_count(cls, output_batch: pa.RecordBatch, input_batch: pa.RecordBatch) -> None:
    """Check the 1:1 row contract between an input batch and its output batch.

    Raises:
        RowCountMismatchError: If the two batches differ in ``num_rows``. The
            error carries both counts and this class's name.

    """
    if output_batch.num_rows == input_batch.num_rows:
        return
    raise RowCountMismatchError(
        "Scalar function output must have same row count as input.",
        input_rows=input_batch.num_rows,
        output_rows=output_batch.num_rows,
        function_name=cls.__name__,
    )
|
|
482
|
+
|
|
483
|
+
@classmethod
@abstractmethod
def output_type(cls, params: BindParameters) -> pa.DataType:
    """Report the Arrow type of the value this function produces.

    Abstract: subclasses supply the implementation. The default on_bind()
    calls this and wraps the returned type in a BindResult.

    Args:
        params: Bind parameters including arguments and input schema.

    Returns:
        The Arrow data type of the output value.

    """
    ...
|
|
493
|
+
|
|
494
|
+
@classmethod
def on_bind(
    cls,
    params: BindParameters,
) -> BindResult:
    """Produce the bind result for the bind API call.

    The default asks output_type() for the Arrow type and returns it with no
    opaque data. Override to add bind-time logic such as validating arguments
    or computing a dynamic output type.

    Args:
        params: Bind parameters including arguments and schema.

    Returns:
        BindResult with output_type and optional opaque_data.

    """
    resolved_type = cls.output_type(params)
    return BindResult(resolved_type)
|
|
512
|
+
|
|
513
|
+
@classmethod
def catalog_output_schema(cls) -> pa.Schema:
    """Describe the output schema for catalog introspection.

    A generator-style scalar function only learns its output type at bind
    time, so there is no static type to report here. The schema therefore
    holds one ``result`` column typed ``null()`` and tagged with the
    ``vgi:any`` metadata key, which marks the type as resolved-at-bind for
    catalog consumers. ScalarFunction overrides this when a static
    ``Returns()`` type is available.

    Returns:
        A one-field schema containing the dynamic ``result`` column.

    """
    return pa.schema([pa.field("result", pa.null(), metadata={b"vgi:any": b"true"})])
|
|
525
|
+
|
|
526
|
+
@final
@classmethod
def _validate_param_type_bounds(cls, input_schema: pa.Schema) -> None:
    """Check AnyArrow Params that declare a type_bound against the bind-time schema.

    Running this at bind time reports a bad column type before any data is
    processed. Only classes that define ``_compute_params`` are checked; when
    the attribute is missing or empty this is a no-op.

    For a varargs Param every field from its resolution index to the end of
    the schema is validated; otherwise only the field at that index.

    Args:
        input_schema: The input schema from the bind call.

    Raises:
        Whatever ``Arg.validate_type_bound`` raises for a field type outside
        the bound (the original docstring names SchemaValidationError;
        NOTE(review): confirm against vgi.arguments).

    """
    declared: dict[str, Arg[Any]] | None = getattr(cls, "_compute_params", None)
    if not declared:
        return

    for arg in declared.values():
        # Only dynamically typed parameters with a bound need checking.
        if arg.is_any and arg.type_bound is not None:
            start = cast(int, arg._resolution_index)
            indices = range(start, len(input_schema)) if arg.varargs else (start,)
            for field_index in indices:
                arg.validate_type_bound(input_schema.field(field_index).type)
|
|
557
|
+
|
|
558
|
+
@final
@classmethod
def _validate_param_varargs_min(cls, input_schema: pa.Schema) -> None:
    """Reject a bind in which a varargs Param would receive no columns.

    This mirrors the ``Arg.varargs`` rule for table-function arguments: a
    ``Param(varargs=True)`` must bind to at least one input column. Without
    the guard, ``on_bind`` implementations that index ``arguments_schema``
    would fail with an opaque ``IndexError`` when a varargs scalar is invoked
    with zero values.

    Classes without a non-empty ``_compute_params`` attribute are skipped.

    Args:
        input_schema: The input schema from the bind call.

    Raises:
        ArgumentValidationError: If a varargs parameter's resolution index is
            at or past the end of ``input_schema``.

    """
    declared: dict[str, Arg[Any]] | None = getattr(cls, "_compute_params", None)
    if not declared:
        return

    for name, arg in declared.items():
        # A varargs parameter starting at or past the end has no columns to consume.
        if arg.varargs and cast(int, arg._resolution_index) >= len(input_schema):
            raise ArgumentValidationError(
                f"Varargs parameter '{name}' requires at least 1 value.",
                arg_name=name,
                position=arg.position,
                constraint="varargs requires at least 1 value",
                doc=arg.doc or None,
            )
|
|
584
|
+
|
|
585
|
+
@final
@classmethod
def bind(
    cls,
    input: BindRequest,
    *,
    ctx: CallContext | None = None,
    attach_plaintext: bytes | None = None,
) -> BindResponse:
    """Bind protocol entry point. Do not override; customize via on_bind().

    Steps, in order:
    1. Validate varargs minimums and type bounds against the input schema.
    2. If the class declares Secret() parameters (``_secret_params``) and the
       request carries no secrets and is not a resolved-secrets retry, return
       a secret scope request listing them.
    3. Build BindParameters and call on_bind().
    4. If on_bind() registered pending secret lookups, return a secret scope
       request for those instead of a result.
    5. Otherwise serialize any opaque data and return the bind response with
       a single ``result`` column of the bound output type.

    Args:
        input: The bind request; its ``input_schema`` must not be None.
        ctx: Call context supplying the auth context; anonymous auth is used
            when None.
        attach_plaintext: Full framework plaintext (uuid||catalog bytes); only
            the catalog bytes are exposed to on_bind().

    Returns:
        Either a secret scope request or the final BindResponse.

    """
    assert input.input_schema is not None
    cls._validate_param_varargs_min(input.input_schema)
    cls._validate_param_type_bounds(input.input_schema)

    # Secrets declared via Secret() annotations on compute() are requested
    # automatically on the first bind call, before they have been resolved.
    # Only ScalarFunction defines _secret_params; ScalarFunctionGenerator does not.
    declared_secrets: dict[str, Secret] = getattr(cls, "_secret_params", {})
    if declared_secrets and not input.resolved_secrets_provided and input.secrets is None:
        lookups = []
        for declared in declared_secrets.values():
            lookups.append(
                SecretLookupEntry(
                    secret_type=declared.secret_type,
                    scope=declared.scope,
                    secret_name=declared.name,
                )
            )
        return BindResponse.secret_scope_request(lookups)

    auth = AuthContext.anonymous() if ctx is None else ctx.auth
    accessor = SecretsAccessor(input.secrets, is_retry=input.resolved_secrets_provided)
    bind_result = cls.on_bind(
        BindParameters(
            constant_arguments=input.arguments,
            arguments_schema=input.input_schema,
            settings=input.settings,
            secrets=accessor,
            auth_context=auth,
            # ``attach_plaintext`` is the full framework plaintext (uuid||catalog
            # bytes); the body sees only the catalog bytes. transaction_opaque_data
            # is already sealed.
            attach_opaque_data=attach_catalog_bytes(attach_plaintext),
            transaction_opaque_data=input.transaction_opaque_data,
        )
    )

    # on_bind() may have asked for secrets that are not available yet.
    if accessor.needs_resolution:
        return BindResponse.secret_scope_request(accessor.pending_lookups)

    # BindResult.opaque_data is typed for users (ArrowSerializableDataclass |
    # None) but travels as bytes on the wire. The framework performs that
    # conversion here, once, so workers need no per-callsite serialization
    # code. See vgi/invocation.py:BindResponse.opaque_data for the contract.
    payload = bind_result.opaque_data
    wire_bytes: bytes | None = payload.serialize_to_bytes() if payload is not None else None

    return BindResponse(
        output_schema=schema(result=bind_result.output_type),
        opaque_data=wire_bytes,
    )
|
|
656
|
+
|
|
657
|
+
@classmethod
def on_init(
    cls,
    *,
    bind_call: BindRequest,
    opaque_data: bytes | None,
    storage: BoundStorage,
) -> GlobalInitResponse:
    """Initialize the function during the init API call.

    Override to perform one-time setup that should happen after bind
    but before processing batches. The default implementation returns a
    default-constructed ``GlobalInitResponse``.

    Args:
        bind_call: The original BindRequest with arguments and schema.
        opaque_data: Bytes from on_bind()'s ``BindResult.opaque_data``
            (after the framework's serialize-to-bytes shim), or None
            if on_bind didn't set it. Reconstruct via
            ``MyConcreteDataclass.deserialize_from_bytes(opaque_data)``
            — the consumer always knows what concrete type to expect,
            so explicit reconstruction is preferred over a framework-
            level class-name registry.
        storage: BoundStorage for storing data across calls.

    Returns:
        GlobalInitResponse. ``global_init()`` copies its ``max_workers``
        and ``opaque_data`` fields into the response it sends on, and
        supplies the ``execution_id`` itself.

    """
    return GlobalInitResponse()
|
|
686
|
+
|
|
687
|
+
@final
@classmethod
def global_init(cls, input: InitRequest, *, attach_plaintext: bytes | None = None) -> GlobalInitResponse:
    """Global init protocol entry point. Do not override; use on_init() instead.

    Generates a fresh random execution id, hands the bind request and the
    bind-time opaque bytes (unchanged) to ``on_init()`` together with a
    ``BoundStorage`` keyed by that id, and repackages the hook's
    ``max_workers`` / ``opaque_data`` with the execution id.

    Args:
        input: The init request carrying ``bind_call`` and
            ``bind_opaque_data``.
        attach_plaintext: Forwarded to ``BoundStorage`` only; it is not
            passed to ``on_init()``. Per the framework notes this is the
            full plaintext (uuid||catalog bytes) the worker unwrapped.

    Returns:
        A new GlobalInitResponse built from the on_init() result plus
        the generated execution id.

    """
    run_id = uuid.uuid4().bytes

    # Read the request fields first, then build the storage handle, so the
    # evaluation order matches a single inline call expression.
    original_bind = input.bind_call
    bind_blob = input.bind_opaque_data
    bound_storage = BoundStorage(cls.storage, run_id, request=input, attach_plaintext=attach_plaintext)

    hook_response = cls.on_init(bind_call=original_bind, opaque_data=bind_blob, storage=bound_storage)

    return GlobalInitResponse(
        max_workers=hook_response.max_workers,
        execution_id=run_id,
        opaque_data=hook_response.opaque_data,
    )
|
|
711
|
+
|
|
712
|
+
@classmethod
@abstractmethod
def process(
    cls,
    *,
    batch: pa.RecordBatch,
    init_call: InitRequest,
    init_response: BaseInitResponse,
    storage: BoundStorage,
    auth_context: AuthContext,
) -> pa.RecordBatch:
    """Transform a single input batch into a single output batch.

    Abstract hook: concrete subclasses supply the scalar transformation
    here. The contract is 1:1 on rows — the returned RecordBatch must
    contain exactly as many rows as ``batch``.

    Args:
        batch: The input RecordBatch to process.
        init_call: The request that was passed to global_init.
        init_response: The response produced by the init call.
        storage: BoundStorage for storing data across calls.
        auth_context: Authentication context for the current request.

    Returns:
        Output RecordBatch with the same row count as the input.

    """
    ...
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
class ScalarFunction(ScalarFunctionGenerator):
|
|
744
|
+
"""Base class for scalar functions (1:1 row mapping, single output column).
|
|
745
|
+
|
|
746
|
+
Scalar functions transform each input row to exactly one output value.
|
|
747
|
+
Use Param/ConstParam/Returns annotations on compute() to declare types.
|
|
748
|
+
|
|
749
|
+
Type Validation
|
|
750
|
+
---------------
|
|
751
|
+
Input and output types are validated at runtime:
|
|
752
|
+
- Param types are checked against actual array types
|
|
753
|
+
- Returns type is checked against compute() result
|
|
754
|
+
- AnyArrow parameters skip validation
|
|
755
|
+
- TypeMismatchError is raised on mismatch
|
|
756
|
+
|
|
757
|
+
Methods to Override
|
|
758
|
+
-------------------
|
|
759
|
+
compute(self, ...) -> pa.Array
|
|
760
|
+
Transform input arrays to output. Use Param/ConstParam annotations.
|
|
761
|
+
|
|
762
|
+
output_type(params) -> pa.DataType (classmethod)
|
|
763
|
+
Override when output type depends on input schema or arguments.
|
|
764
|
+
|
|
765
|
+
"""
|
|
766
|
+
|
|
767
|
+
# For TYPE_CHECKING, allow dynamic attribute access for Param/ConstParam
|
|
768
|
+
if TYPE_CHECKING:

    def __getattr__(self, _name: str) -> Any:
        """Type-checker-only stub so Param/ConstParam descriptor attributes resolve as ``Any``."""
        ...
|
|
773
|
+
|
|
774
|
+
_compute_params: dict[str, Arg[Any]] # Regular Param() arguments (arrays)
|
|
775
|
+
_const_params: dict[str, Arg[Any]] # ConstParam() arguments (scalars)
|
|
776
|
+
_setting_params: dict[str, str] # Setting params: param_name -> setting_key
|
|
777
|
+
_secret_params: dict[str, Secret] # Secret params: param_name -> Secret instance
|
|
778
|
+
_output_length_param: str | None # OutputLength param name (batch row count)
|
|
779
|
+
_auth_param: str | None # Auth param name (receives AuthContext)
|
|
780
|
+
_returns_output_type: pa.DataType | None # Output type from Returns()
|
|
781
|
+
|
|
782
|
+
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Collect parameter metadata from the subclass's compute() signature.

    Reads the ``Annotated[...]`` metadata on compute()'s parameters and
    return annotation and records Param, ConstParam, Setting, Secret,
    OutputLength, Auth and Returns information on the class. Abstract
    subclasses are left untouched.

    Raises:
        TypeError: If a concrete subclass has no compute(), if a complex
            array return class is used without an explicit arrow_type in
            Returns(), or if more than one Auth parameter is declared.

    """
    super().__init_subclass__(**kwargs)

    # Abstract intermediates have nothing to extract yet.
    if inspect.isabstract(cls):
        return

    compute_fn = getattr(cls, "compute", None)
    if compute_fn is None:
        raise TypeError(f"{cls.__name__} must define a compute() method.\n\n")

    signature = inspect.signature(compute_fn)

    # First attempt: let typing resolve the annotations (covers plain
    # annotations and PEP 563 string annotations when everything resolves).
    resolved: dict[str, Any] = {}
    with contextlib.suppress(Exception):
        resolved = get_type_hints(compute_fn, include_extras=True)

    # Fallback: evaluate the raw annotations ourselves when resolution
    # failed or produced nothing (e.g. string annotations from
    # `from __future__ import annotations`).
    if not resolved:
        declared = getattr(compute_fn, "__annotations__", {})
        from vgi import arguments as vgi_args

        # Stand-in for the pa module whose Array accepts subscripts during
        # eval (pa.Array[Any] isn't subscriptable in PyArrow).
        class _MockArray:
            def __class_getitem__(cls, _item: Any) -> Any:
                return Any

        class _MockPa:
            Array = _MockArray

            def __getattr__(self, name: str) -> Any:
                return getattr(pa, name)

        from typing import Annotated

        # Start from compute()'s own globals, then overlay the framework names.
        namespace = dict(getattr(compute_fn, "__globals__", {}))
        namespace.update(
            {
                "Annotated": Annotated,
                "Param": vgi_args.Param,
                "ConstParam": vgi_args.ConstParam,
                "Setting": vgi_args.Setting,
                "Secret": vgi_args.Secret,
                "Auth": vgi_args.Auth,
                "OutputLength": vgi_args.OutputLength,
                "Returns": vgi_args.Returns,
                "AnyArrow": vgi_args.AnyArrow,
                "pa": _MockPa(),
            }
        )
        for ann_name, ann_value in declared.items():
            if not isinstance(ann_value, str):
                resolved[ann_name] = ann_value
                continue
            # Best effort: an annotation that cannot be evaluated is skipped.
            with contextlib.suppress(Exception):
                resolved[ann_name] = eval(ann_value, namespace)  # noqa: S307

    # ---- Return annotation: Annotated[..., Returns(...)] ----
    returns_output_type: pa.DataType | None = None
    return_hint = resolved.get("return")
    if return_hint is not None and hasattr(return_hint, "__metadata__"):
        # Only the first Returns marker is considered.
        returns_marker = next((m for m in return_hint.__metadata__ if isinstance(m, Returns)), None)
        if returns_marker is not None:
            if returns_marker.arrow_type is not None:
                # Priority 1: explicit arrow_type in Returns()
                returns_output_type = returns_marker.arrow_type
            else:
                # Priority 2: infer from the first Annotated argument
                annotated_args = get_args(return_hint)
                if annotated_args:
                    declared_class = annotated_args[0]
                    if declared_class in ARRAY_CLASS_TO_DATATYPE:
                        returns_output_type = ARRAY_CLASS_TO_DATATYPE[declared_class]
                    elif declared_class in COMPLEX_ARRAY_CLASSES:
                        raise TypeError(
                            f"{declared_class.__name__} requires explicit "
                            f"arrow_type in Returns(). "
                            f"Example: Returns(arrow_type=pa.list_(pa.int64()))"
                        )
                    # Otherwise: AnyArrow (returns_output_type stays None)

    # ---- Parameter annotations ----
    compute_params: dict[str, Arg[Any]] = {}
    const_params: dict[str, Arg[Any]] = {}
    output_length_param: str | None = None  # param that receives batch row count
    auth_param: str | None = None  # param that receives AuthContext

    # Three independent counters:
    #   call_position - overall call order, used for metadata/client use
    #   next_column   - index into the batch columns (Param)
    #   next_const    - index into Arguments.positional (ConstParam)
    call_position = 0
    next_column = 0
    next_const = 0

    for name in signature.parameters:
        if name == "self":
            continue

        hint = resolved.get(name)
        if hint is None:
            continue

        # Only Annotated[...] hints carry __metadata__; the first marker of a
        # recognised kind decides how the parameter is treated.
        for marker in getattr(hint, "__metadata__", ()):
            if isinstance(marker, Param):
                # Column input (array); base type comes from Annotated's first argument.
                annotated_args = get_args(hint)
                base_type = annotated_args[0] if annotated_args else pa.Array
                arg = _param_to_arg(marker, base_type, call_position)
                arg._name = name
                # _resolution_index holds the batch column index for lookup.
                arg._resolution_index = next_column
                compute_params[name] = arg
                setattr(cls, name, _ArgDescriptor(arg, name))
                call_position += 1
                next_column += 1
            elif isinstance(marker, ConstParam):
                # Constant input (scalar).
                annotated_args = get_args(hint)
                base_type = cast(type, annotated_args[0] if annotated_args else Any)
                arg = _const_param_to_arg(marker, base_type, call_position)
                arg._name = name
                # _resolution_index points into Arguments.positional.
                arg._resolution_index = next_const
                const_params[name] = arg
                setattr(cls, name, _ConstArgDescriptor(arg, name))
                call_position += 1
                next_const += 1
            elif isinstance(marker, OutputLength):
                # Receives the batch row count; not a call argument, so the
                # position counters are left alone.
                output_length_param = name
            elif isinstance(marker, Auth):
                # Receives the AuthContext; not a call argument either.
                if auth_param is not None:
                    raise TypeError(
                        f"{cls.__name__}.compute() has multiple Auth parameters: {auth_param!r} and {name!r}"
                    )
                auth_param = name
            else:
                # Unrelated metadata: keep scanning this parameter's markers.
                continue
            break

    # Setting/Secret params come from the shared helper.
    setting_params, secret_params = _extract_setting_secret_params(compute_fn)

    cls._compute_params = compute_params
    cls._const_params = const_params
    cls._setting_params = setting_params
    cls._secret_params = secret_params
    cls._output_length_param = output_length_param
    cls._auth_param = auth_param
    cls._returns_output_type = returns_output_type
|
|
956
|
+
|
|
957
|
+
@final
@classmethod
def catalog_output_schema(cls) -> pa.Schema:
    """Build the one-field output schema used for catalog introspection.

    The schema always has a single field named "result". Its type is the
    one recorded from the Returns() annotation; when none was recorded
    (dynamic output type) the field is typed null() and tagged with
    ``vgi:any`` metadata instead.
    """
    declared = getattr(cls, "_returns_output_type", None)
    if declared is not None:
        return schema({"result": declared})

    # No explicit Returns type: advertise a null-typed "any" placeholder.
    placeholder = pa.field("result", pa.null(), metadata={b"vgi:any": b"true"})
    return pa.schema([placeholder])
|
|
972
|
+
|
|
973
|
+
@classmethod
def output_type(cls, params: BindParameters) -> pa.DataType:
    """Return the Arrow type for the output column.

    The default implementation returns the type recorded from the
    Returns() annotation. Override when the output type depends on the
    input schema or arguments carried by ``params``.

    Args:
        params: Bind parameters including arguments and input schema.

    Returns:
        The declared output Arrow type.

    Raises:
        NotImplementedError: If no explicit output type was recorded
            (dynamic output type) and the method was not overridden.

    """
    # ``_returns_output_type`` is only assigned by __init_subclass__ for
    # non-abstract subclasses; the class-body annotation alone does not
    # create the attribute. Look it up defensively (as
    # catalog_output_schema does) so a missing attribute produces the
    # NotImplementedError below rather than an AttributeError.
    declared = getattr(cls, "_returns_output_type", None)
    if declared is not None:
        return declared

    raise NotImplementedError(
        f"{cls.__name__}.output_type must be overridden when using Returns() "
        f"without an explicit type (dynamic output type)."
    )
|
|
992
|
+
|
|
993
|
+
# Note: compute() is NOT defined here. Subclasses define it with their own
|
|
994
|
+
# keyword-only signature. This avoids mypy override errors for users.
|
|
995
|
+
# See class docstring for compute() signature requirements.
|
|
996
|
+
# Validated at class definition time by __init_subclass__.
|
|
997
|
+
|
|
998
|
+
@final
@classmethod
def _extract_compute_kwargs(
    cls,
    batch: pa.RecordBatch,
    bind_call: BindRequest,
    auth_context: AuthContext,
) -> dict[str, Any]:
    """Assemble the keyword arguments for one compute() call.

    The result is typed dict[str, Any] because the values are
    heterogeneous: arrays, lists of arrays (varargs), Python scalars,
    settings/secret values, the row count and the auth context, all
    keyed by compute() parameter name.

    Args:
        batch: Input RecordBatch.
        bind_call: The BindRequest with arguments, settings, and secrets.
        auth_context: Authentication context for the current request.

    Returns:
        Dict mapping parameter names to their resolved values.

    """
    resolved: dict[str, Any] = {}

    # Column params: _resolution_index is the batch column index.
    for param_name, spec in cls._compute_params.items():
        first_col = cast(int, spec._resolution_index)
        if not spec.varargs:
            resolved[param_name] = batch.column(first_col)
            continue
        # Varargs swallow every column from their position to the end.
        resolved[param_name] = [batch.column(pos) for pos in range(first_col, batch.num_columns)]

    # Const params: _resolution_index points into Arguments.positional;
    # the stored scalar is converted to a plain Python value.
    for param_name, spec in cls._const_params.items():
        stored = bind_call.arguments.positional[cast(int, spec._resolution_index)]
        resolved[param_name] = None if stored is None else stored.as_py()

    # Setting params: first-row value of the matching settings column,
    # or None when the column is absent. Skipped when no settings batch
    # was supplied.
    settings_batch = bind_call.settings
    if settings_batch is not None and cls._setting_params:
        settings_fields = settings_batch.schema
        for param_name, setting_key in cls._setting_params.items():
            position = settings_fields.get_field_index(setting_key)
            if position >= 0:
                resolved[param_name] = settings_batch.column(position)[0]
            else:
                resolved[param_name] = None

    # Secret params: first-row struct converted to a dict, looked up by
    # secret type; None when the column is absent.
    secrets_batch = bind_call.secrets
    if secrets_batch is not None and cls._secret_params:
        secrets_fields = secrets_batch.schema
        for param_name, secret in cls._secret_params.items():
            position = secrets_fields.get_field_index(secret.secret_type)
            if position >= 0:
                resolved[param_name] = _struct_scalar_to_dict(secrets_batch.column(position)[0])
            else:
                resolved[param_name] = None

    # OutputLength param receives the batch row count.
    length_param = cls._output_length_param
    if length_param is not None:
        resolved[length_param] = batch.num_rows

    # Auth param receives the AuthContext.
    auth_param = cls._auth_param
    if auth_param is not None:
        resolved[auth_param] = auth_context

    return resolved
|
|
1065
|
+
|
|
1066
|
+
@final
@classmethod
def _validate_single_param_type(cls, arg: Arg[Any], arr: pa.Array[Any], display_name: str) -> pa.Array[Any]:
    """Check one array against its parameter declaration, casting if needed.

    ``is_any`` parameters are returned untouched (after the optional
    type_bound check). For typed parameters an array whose type differs
    from the declaration is cast to the declared type; only when the cast
    itself fails is the mismatch reported.

    Args:
        arg: The Arg metadata for the parameter.
        arr: The actual array to validate.
        display_name: Name used in error messages (e.g. "x" or "x[0]").

    Returns:
        The original array, or a cast copy of it.

    Raises:
        TypeMismatchError: If the array type differs and cannot be cast.

    """
    if arg.is_any:
        if arg.type_bound is not None:
            arg.validate_type_bound(arr.type)
        return arr

    # Nothing to do when no type is declared or the types already agree.
    if arg.arrow_type is None or not (arr.type != arg.arrow_type):
        return arr

    try:
        converted = arr.cast(arg.arrow_type)
    except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
        # Drop the Arrow traceback: the mismatch error carries the details.
        raise TypeMismatchError(
            f"Input type mismatch for parameter '{display_name}'.",
            param_name=display_name,
            expected_type=arg.arrow_type,
            actual_type=arr.type,
            function_name=cls.__name__,
        ) from None

    logger.debug("Cast parameter '%s' from %s to %s", display_name, arr.type, arg.arrow_type)
    return converted
|
|
1104
|
+
|
|
1105
|
+
@final
@classmethod
def _validate_param_types(cls, kwargs: dict[str, Any]) -> None:
    """Validate every column parameter in ``kwargs`` against its declaration.

    Each Param entry (or each element of a varargs list) is run through
    ``_validate_single_param_type`` and the returned array — which may be
    a cast copy — is written back into ``kwargs`` in place.

    Args:
        kwargs: Dict of parameter names to arrays (from _extract_compute_kwargs).

    Raises:
        TypeMismatchError: If any array type doesn't match its declared type.
        SchemaValidationError: If any array type fails type_bound validation.

    """
    check = cls._validate_single_param_type
    for param_name, spec in cls._compute_params.items():
        supplied = kwargs[param_name]
        if not spec.varargs:
            kwargs[param_name] = check(spec, supplied, param_name)
            continue
        # Varargs: validate element by element, labelling each by index.
        kwargs[param_name] = [check(spec, item, f"{param_name}[{pos}]") for pos, item in enumerate(supplied)]
|
|
1129
|
+
|
|
1130
|
+
@final
@classmethod
def _validate_output_type(cls, result: pa.Array[Any]) -> None:
    """Ensure compute()'s result has the type declared via Returns().

    No check is made when the class has no declared output type.

    Args:
        result: The output array from compute().

    Raises:
        TypeMismatchError: If output type doesn't match declared type.

    """
    declared = cls._returns_output_type
    # AnyArrow / unspecified output, or an exact match: nothing to report.
    if declared is None or not (result.type != declared):
        return

    raise TypeMismatchError(
        "Output type mismatch.",
        param_name="return",
        expected_type=declared,
        actual_type=result.type,
        function_name=cls.__name__,
    )
|
|
1153
|
+
|
|
1154
|
+
@classmethod
def on_bind(cls, params: BindParameters) -> BindResult:
    """Resolve the output type for the bind phase.

    The default asks ``output_type(params)`` for the type and attaches no
    opaque data. Override for custom bind-time logic such as validating
    arguments, examining the input schema, or computing a dynamic output
    type.

    Args:
        params: Bind parameters including arguments, input schema,
            settings, and secrets.

    Returns:
        BindResult with output_type and optional opaque_data.

    Note:
        Constant arguments needed during process() are automatically
        serialized by the protocol. The opaque_data field is for
        additional bind-time state you need to pass forward.

    """
    resolved_type = cls.output_type(params)
    return BindResult(output_type=resolved_type, opaque_data=None)
|
|
1175
|
+
|
|
1176
|
+
@final
@classmethod
def process(
    cls,
    *,
    batch: pa.RecordBatch,
    init_call: InitRequest,
    init_response: BaseInitResponse,
    storage: BoundStorage,
    auth_context: AuthContext,
) -> pa.RecordBatch:
    """Run the subclass's compute() for one batch and wrap the result.

    Builds compute()'s keyword arguments from the batch and the original
    bind request, validates (and possibly casts) the inputs, calls
    compute(), validates the output type, and returns the result as a
    single-column RecordBatch using the init request's output schema.

    """
    target_schema = init_call.output_schema

    # Resolve a value for every declared compute() parameter.
    call_kwargs = cls._extract_compute_kwargs(batch, init_call.bind_call, auth_context)

    # Input arrays must match (or be castable to) their declared Param types.
    cls._validate_param_types(call_kwargs)

    # compute() is defined only on subclasses, with a signature that varies
    # per subclass, so go through Any to avoid an attr-defined error.
    compute = cast(Any, cls).compute
    output_column = compute(**call_kwargs)

    # The result must match the declared Returns type, if any.
    cls._validate_output_type(output_column)

    return pa.RecordBatch.from_arrays([output_column], schema=target_schema)
|