vgi-python 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. vgi/__init__.py +152 -0
  2. vgi/_duckdb.py +62 -0
  3. vgi/_storage_profile.py +132 -0
  4. vgi/_test_fixtures/__init__.py +20 -0
  5. vgi/_test_fixtures/accumulate/__init__.py +19 -0
  6. vgi/_test_fixtures/accumulate/worker.py +762 -0
  7. vgi/_test_fixtures/aggregate/__init__.py +62 -0
  8. vgi/_test_fixtures/aggregate/_common.py +21 -0
  9. vgi/_test_fixtures/aggregate/basic.py +232 -0
  10. vgi/_test_fixtures/aggregate/dynamic.py +409 -0
  11. vgi/_test_fixtures/aggregate/generic.py +86 -0
  12. vgi/_test_fixtures/aggregate/listagg.py +71 -0
  13. vgi/_test_fixtures/aggregate/percentile.py +107 -0
  14. vgi/_test_fixtures/aggregate/streaming.py +192 -0
  15. vgi/_test_fixtures/aggregate/varargs.py +75 -0
  16. vgi/_test_fixtures/aggregate/window.py +380 -0
  17. vgi/_test_fixtures/attach_options.py +308 -0
  18. vgi/_test_fixtures/bad_protocol.py +62 -0
  19. vgi/_test_fixtures/cancellable.py +336 -0
  20. vgi/_test_fixtures/catalog.py +813 -0
  21. vgi/_test_fixtures/http_server.py +394 -0
  22. vgi/_test_fixtures/nest_tensor.py +614 -0
  23. vgi/_test_fixtures/orchard_catalog.py +47 -0
  24. vgi/_test_fixtures/projection_repro/__init__.py +6 -0
  25. vgi/_test_fixtures/projection_repro/worker.py +454 -0
  26. vgi/_test_fixtures/scalar/__init__.py +116 -0
  27. vgi/_test_fixtures/scalar/_common.py +69 -0
  28. vgi/_test_fixtures/scalar/arithmetic.py +321 -0
  29. vgi/_test_fixtures/scalar/binary.py +120 -0
  30. vgi/_test_fixtures/scalar/formatting.py +176 -0
  31. vgi/_test_fixtures/scalar/geo.py +300 -0
  32. vgi/_test_fixtures/scalar/null_handling.py +107 -0
  33. vgi/_test_fixtures/scalar/random_demo.py +171 -0
  34. vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
  35. vgi/_test_fixtures/scalar/type_info.py +219 -0
  36. vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
  37. vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
  38. vgi/_test_fixtures/simple_writable.py +793 -0
  39. vgi/_test_fixtures/table/__init__.py +221 -0
  40. vgi/_test_fixtures/table/_common.py +162 -0
  41. vgi/_test_fixtures/table/batch_index.py +283 -0
  42. vgi/_test_fixtures/table/batch_index_broken.py +200 -0
  43. vgi/_test_fixtures/table/catalog_scans.py +162 -0
  44. vgi/_test_fixtures/table/filters.py +1005 -0
  45. vgi/_test_fixtures/table/late_materialization.py +249 -0
  46. vgi/_test_fixtures/table/make_series.py +273 -0
  47. vgi/_test_fixtures/table/misc.py +499 -0
  48. vgi/_test_fixtures/table/order_modes.py +164 -0
  49. vgi/_test_fixtures/table/pairs.py +437 -0
  50. vgi/_test_fixtures/table/partition_columns.py +472 -0
  51. vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
  52. vgi/_test_fixtures/table/profiling_example.py +195 -0
  53. vgi/_test_fixtures/table/required_filters.py +234 -0
  54. vgi/_test_fixtures/table/sequence.py +710 -0
  55. vgi/_test_fixtures/table/settings.py +426 -0
  56. vgi/_test_fixtures/table/transaction_storage.py +162 -0
  57. vgi/_test_fixtures/table/tt_pushdown.py +191 -0
  58. vgi/_test_fixtures/table/versioned.py +230 -0
  59. vgi/_test_fixtures/table_in_out.py +1392 -0
  60. vgi/_test_fixtures/versioned.py +155 -0
  61. vgi/_test_fixtures/versioned_tables.py +595 -0
  62. vgi/_test_fixtures/worker.py +1631 -0
  63. vgi/_test_fixtures/writable/__init__.py +8 -0
  64. vgi/_test_fixtures/writable/generic.py +236 -0
  65. vgi/_test_fixtures/writable/table.py +149 -0
  66. vgi/_test_fixtures/writable/worker.py +1148 -0
  67. vgi/aggregate_function.py +607 -0
  68. vgi/argument_spec.py +472 -0
  69. vgi/arguments.py +1747 -0
  70. vgi/auth.py +55 -0
  71. vgi/catalog/__init__.py +88 -0
  72. vgi/catalog/attach_option.py +206 -0
  73. vgi/catalog/catalog_interface.py +2767 -0
  74. vgi/catalog/descriptors.py +870 -0
  75. vgi/catalog/duckdb_statistics.py +377 -0
  76. vgi/catalog/secret_type.py +96 -0
  77. vgi/catalog/setting.py +253 -0
  78. vgi/catalog/storage.py +372 -0
  79. vgi/client/__init__.py +67 -0
  80. vgi/client/catalog_mixin.py +1251 -0
  81. vgi/client/cli.py +582 -0
  82. vgi/client/cli_catalog.py +182 -0
  83. vgi/client/cli_schema.py +270 -0
  84. vgi/client/cli_table.py +907 -0
  85. vgi/client/cli_transaction.py +97 -0
  86. vgi/client/cli_utils.py +441 -0
  87. vgi/client/cli_view.py +303 -0
  88. vgi/client/client.py +2183 -0
  89. vgi/exceptions.py +205 -0
  90. vgi/function.py +245 -0
  91. vgi/function_storage.py +1636 -0
  92. vgi/function_storage_azure_sql.py +922 -0
  93. vgi/function_storage_cf_do.py +740 -0
  94. vgi/http/__init__.py +25 -0
  95. vgi/http/demo_storage.py +212 -0
  96. vgi/http/worker_page.py +1252 -0
  97. vgi/invocation.py +154 -0
  98. vgi/logging_config.py +93 -0
  99. vgi/meta_worker.py +661 -0
  100. vgi/metadata.py +1403 -0
  101. vgi/otel.py +406 -0
  102. vgi/protocol.py +2418 -0
  103. vgi/protocol_version.txt +1 -0
  104. vgi/py.typed +0 -0
  105. vgi/scalar_function.py +1211 -0
  106. vgi/schema_utils.py +234 -0
  107. vgi/secret_protocol.py +124 -0
  108. vgi/secret_service.py +238 -0
  109. vgi/serve.py +769 -0
  110. vgi/table_buffering_function.py +443 -0
  111. vgi/table_filter_pushdown.py +1528 -0
  112. vgi/table_function.py +1130 -0
  113. vgi/table_in_out_function.py +383 -0
  114. vgi/transactor/__init__.py +24 -0
  115. vgi/transactor/_duckdb_compat.py +27 -0
  116. vgi/transactor/client.py +137 -0
  117. vgi/transactor/protocol.py +149 -0
  118. vgi/transactor/server.py +740 -0
  119. vgi/worker.py +4761 -0
  120. vgi_python-0.8.0.dist-info/METADATA +735 -0
  121. vgi_python-0.8.0.dist-info/RECORD +124 -0
  122. vgi_python-0.8.0.dist-info/WHEEL +4 -0
  123. vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
  124. vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
vgi/scalar_function.py ADDED
@@ -0,0 +1,1211 @@
1
+ # Copyright 2025, 2026 Query Farm LLC - https://query.farm
2
+
3
+ """Scalar functions: per-row transforms with single-column output.
4
+
5
+ Scalar functions are the simplest function type in VGI. They transform each
6
+ input row into exactly one output value, producing a single column of results.
7
+
8
+ Key characteristics:
9
+ - **1:1 row mapping**: Output has exactly the same number of rows as input
10
+ - **Single column output**: Output schema has exactly one column named "result"
11
+ - **No finish()**: Processing ends when the caller closes the input stream.
12
+
13
+ Common use cases:
14
+ - Mathematical operations: multiply, add, abs
15
+ - String transforms: upper, lower, concat, trim
16
+ - Type conversions: cast, parse
17
+ - Field extraction: get nested values, parse JSON fields
18
+
19
+ This module provides two base classes:
20
+
21
+ ScalarFunction (recommended)
22
+ Declarative API using Param/ConstParam/Returns annotations on compute().
23
+ Also supports Setting, Secret, and OutputLength annotations.
24
+ Override output_type() only if the output type depends on input schema.
25
+
26
+ ScalarFunctionGenerator (advanced)
27
+ Per-batch callback API for fine-grained control.
28
+ Override output_type() and process().
29
+
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import contextlib
35
+ import inspect
36
+ import logging
37
+ import uuid
38
+ from abc import abstractmethod
39
+ from dataclasses import dataclass
40
+ from typing import TYPE_CHECKING, Any, cast, final, get_args, get_origin, get_type_hints
41
+
42
+ import pyarrow as pa
43
+ from vgi_rpc import ArrowSerializableDataclass
44
+ from vgi_rpc.rpc import AuthContext, CallContext
45
+
46
+ import vgi.function
47
+ from vgi.arguments import (
48
+ _PYTHON_TO_ARROW,
49
+ ARRAY_CLASS_TO_DATATYPE,
50
+ COMPLEX_ARRAY_CLASSES,
51
+ Arg,
52
+ Arguments,
53
+ ArgumentValidationError,
54
+ Auth,
55
+ ConstParam,
56
+ OutputLength,
57
+ Param,
58
+ Returns,
59
+ Secret,
60
+ SecretLookupEntry,
61
+ _extract_setting_secret_params,
62
+ )
63
+ from vgi.function_storage import BoundStorage, attach_catalog_bytes
64
+ from vgi.invocation import (
65
+ BaseInitResponse,
66
+ BindResponse,
67
+ GlobalInitResponse,
68
+ )
69
+ from vgi.schema_utils import schema
70
+ from vgi.table_function import SecretsAccessor, _struct_scalar_to_dict
71
+
72
+ if TYPE_CHECKING:
73
+ from vgi.protocol import BindRequest, InitRequest
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+ __all__ = [
78
+ "BindParameters",
79
+ "BindResult",
80
+ "RowCountMismatchError",
81
+ "ScalarFunction",
82
+ "ScalarFunctionGenerator",
83
+ "TypeMismatchError",
84
+ ]
85
+
86
+
87
@dataclass(frozen=True, slots=True)
class BindResult(ArrowSerializableDataclass):
    """Outcome of a scalar function's bind step.

    A scalar function yields one value per input row, so bind reports a
    single output type instead of the full schema a table function returns.

    Attributes:
        output_type: Arrow data type of the produced value.
        opaque_data: Optional payload, opaque to the caller, that is
            forwarded to global_init() and process().

    """

    output_type: pa.DataType
    opaque_data: ArrowSerializableDataclass | None = None
103
+
104
+
105
@dataclass(frozen=True, slots=True)
class BindParameters:
    """Everything a scalar function's bind() hook receives.

    Attributes:
        constant_arguments: Constant arguments supplied at query planning time.
        arguments_schema: Schema describing the input columns.
        settings: DuckDB settings as a single-row RecordBatch, or None.
        secrets: SecretsAccessor exposing resolved and dynamic secrets.
        auth_context: Authentication context of the current request
            (defaults to the anonymous context).
        attach_opaque_data: Catalog attach ID when the function is invoked
            through an ATTACHed catalog, otherwise None.
        transaction_opaque_data: Catalog transaction ID when invoked inside
            a catalog transaction, otherwise None.

    """

    constant_arguments: Arguments
    arguments_schema: pa.Schema
    settings: pa.RecordBatch | None
    secrets: SecretsAccessor
    auth_context: AuthContext = AuthContext.anonymous()
    attach_opaque_data: bytes | None = None
    transaction_opaque_data: bytes | None = None
127
+
128
+
129
def _resolve_explicit_arrow_type(arrow_type: pa.DataType | type) -> pa.DataType:
    """Turn an explicitly declared ``arrow_type`` into a ``pa.DataType``.

    A ``pa.DataType`` instance is returned unchanged; a Python type is
    looked up in the ``_PYTHON_TO_ARROW`` mapping.

    Args:
        arrow_type: Either an Arrow data type or a Python type.

    Returns:
        The corresponding Arrow data type.

    Raises:
        TypeError: If the value is neither an Arrow type nor a key of
            ``_PYTHON_TO_ARROW``.

    """
    # Already an Arrow type: nothing to translate.
    if isinstance(arrow_type, pa.DataType):
        return arrow_type

    # Reject anything the Python -> Arrow table does not know about.
    if arrow_type not in _PYTHON_TO_ARROW:
        raise TypeError(
            f"Cannot convert type '{arrow_type}' to Arrow type. "
            f"Use pa.DataType, Python type (int/str/float/bool/bytes), "
            f"or None for AnyArrow."
        )

    return _PYTHON_TO_ARROW[arrow_type]
147
+
148
+
149
def _param_to_arg(param: Param, base_type: type, position: int) -> Arg[Any]:
    """Convert Param metadata into an internal Arg, inferring the Arrow type.

    The Arrow type is chosen by the first rule that applies:

    1. An explicit ``arrow_type`` on the Param takes priority.
    2. A simple array class (a key of ``ARRAY_CLASS_TO_DATATYPE``, e.g.
       pa.Int64Array) is mapped to its data type.
    3. A complex/parameterized array class (a member of
       ``COMPLEX_ARRAY_CLASSES``) requires an explicit ``arrow_type``.
    4. Anything else (pa.Array, pa.Array[Any], Any, ...) is treated as
       AnyArrow, i.e. a dynamically typed column.

    For varargs params annotated as ``list[X]``, inference is performed on
    the element type ``X``.

    Args:
        param: The Param metadata from an Annotated type hint.
        base_type: The first argument of the Annotated hint (e.g.
            pa.Int64Array from Annotated[pa.Int64Array, Param(...)]).
        position: The parameter's position in the compute() signature.

    Returns:
        Arg instance configured for columnar input.

    Raises:
        TypeError: If the type is a complex array class and no explicit
            ``arrow_type`` was given.

    """
    # For varargs params, unwrap list[X] to get the element type X
    infer_type = base_type
    if param.varargs and get_origin(base_type) is list:
        type_args = get_args(base_type)
        if type_args:
            infer_type = type_args[0]

    is_any = False
    arrow_type: pa.DataType

    if param.arrow_type is not None:
        arrow_type = _resolve_explicit_arrow_type(param.arrow_type)
    elif infer_type in ARRAY_CLASS_TO_DATATYPE:
        # Infer from simple array class (pa.Int64Array -> pa.int64())
        arrow_type = ARRAY_CLASS_TO_DATATYPE[infer_type]
    elif infer_type in COMPLEX_ARRAY_CLASSES:
        # Report the type that was actually inspected. For a varargs
        # ``list[X]`` annotation, ``base_type.__name__`` would name the
        # ``list`` wrapper instead of the offending element class X.
        type_name = getattr(infer_type, "__name__", repr(infer_type))
        raise TypeError(
            f"{type_name} requires explicit arrow_type in Param(). "
            f"Example: Param(arrow_type=pa.list_(pa.int64()), doc='...')"
        )
    else:
        # Covers pa.Array, pa.Array[Any], Any, and other generic types
        is_any = True
        arrow_type = pa.null()  # Placeholder for AnyArrow

    return Arg[Any](
        position,
        doc=param.doc,
        arrow_type=arrow_type,
        type_bound=param.type_bound,
        varargs=param.varargs,
        is_any=is_any,
    )
207
+
208
+
209
def _const_param_to_arg(const_param: ConstParam, base_type: type, position: int) -> Arg[Any]:
    """Build the internal Arg for a ConstParam-annotated parameter.

    The Arrow type comes from the ConstParam's explicit ``arrow_type`` when
    present; otherwise it is looked up from ``base_type`` in
    ``_PYTHON_TO_ARROW``.

    Args:
        const_param: The ConstParam metadata from an Annotated type hint.
        base_type: The first argument of the Annotated hint (e.g. int from
            Annotated[int, ConstParam(...)]).
        position: The parameter's position in the const arguments.

    Returns:
        Arg instance flagged as constant (``const=True``).

    Raises:
        TypeError: If no explicit type is given and ``base_type`` has no
            entry in ``_PYTHON_TO_ARROW``.

    """
    explicit = const_param.arrow_type

    # Nothing declared and nothing inferable: fail with guidance.
    if explicit is None and base_type not in _PYTHON_TO_ARROW:
        raise TypeError(
            f"Cannot infer Arrow type from {base_type}. "
            f"Use a supported type (int/str/float/bool/bytes) or specify arrow_type."
        )

    resolved: pa.DataType = (
        _PYTHON_TO_ARROW[base_type] if explicit is None else _resolve_explicit_arrow_type(explicit)
    )
    return Arg[Any](position, doc=const_param.doc, arrow_type=resolved, const=True)
239
+
240
+
241
+ # =============================================================================
242
+ # Descriptors for Param/ConstParam Arguments
243
+ # =============================================================================
244
+
245
+
246
class _ArgDescriptor:
    """Descriptor exposing a Param argument on a function class.

    Looked up on the class, it yields the wrapped Arg metadata object.
    Looked up on an instance, it yields whatever ``Arg._resolve(instance)``
    returns.
    """

    __slots__ = ("arg", "name")

    def __init__(self, arg: Arg[Any], name: str) -> None:
        self.name = name
        self.arg = arg

    def __get__(self, obj: object | None, _objtype: type | None = None) -> Any:
        # ``obj`` is None for class-level access: hand back the metadata.
        # Otherwise delegate to the Arg to resolve the value for that instance.
        return self.arg if obj is None else self.arg._resolve(obj)
265
+
266
+
267
class _ConstArgDescriptor:
    """Descriptor exposing a constant-folded ConstParam argument.

    Looked up on the class, it yields the wrapped Arg metadata object.
    Looked up on an instance, it yields whatever ``Arg._resolve(instance)``
    returns; for a ConstParam that is intended to be the scalar value rather
    than an array.

    This is kept as a class separate from _ArgDescriptor because the two
    are meant to return different kinds of values on instance access:
    - _ArgDescriptor returns the column value (array) for regular Param
    - _ConstArgDescriptor returns the scalar value for ConstParam
    """

    __slots__ = ("arg", "name")

    def __init__(self, arg: Arg[Any], name: str) -> None:
        self.name = name
        self.arg = arg

    def __get__(self, obj: object | None, _objtype: type | None = None) -> Any:
        # ``obj`` is None for class-level access: hand back the metadata.
        # Otherwise delegate to the Arg to resolve the value for that instance.
        return self.arg if obj is None else self.arg._resolve(obj)
289
+
290
+
291
class RowCountMismatchError(Exception):
    """Signal that a scalar function returned the wrong number of rows.

    A scalar function must emit exactly one output row per input row; this
    exception reports a compute() result whose length differs from the
    input batch.

    Attributes:
        input_rows: Number of rows in the input batch.
        output_rows: Number of rows in the output batch.
        function_name: Name of the function that produced the mismatch.

    """

    def __init__(
        self,
        message: str,
        *,
        input_rows: int | None = None,
        output_rows: int | None = None,
        function_name: str = "",
    ) -> None:
        """Store the row counts and build the exception message.

        Args:
            message: Base error message.
            input_rows: Number of input rows.
            output_rows: Number of output rows.
            function_name: Name of the function class.

        """
        self.input_rows = input_rows
        self.output_rows = output_rows
        self.function_name = function_name

        # The detailed report needs both counts; otherwise use the bare message.
        if input_rows is None or output_rows is None:
            text = message
        else:
            text = self._build_detailed_message(message, input_rows, output_rows)

        super().__init__(text)

    def _build_detailed_message(self, base_message: str, input_rows: int, output_rows: int) -> str:
        """Compose the multi-line report: counts, diagnosis, and guidance."""
        report = [base_message, ""]

        if self.function_name:
            report.append(f" Function: {self.function_name}")

        report += [
            f" Input rows: {input_rows}",
            f" Output rows: {output_rows}",
            "",
        ]

        # Pick the diagnosis text matching the direction of the mismatch.
        if output_rows < input_rows:
            problem = " Problem: Output has fewer rows than input."
            causes = [
                " - compute() is filtering rows (not allowed in scalar)",
                " - compute() is aggregating (not allowed in scalar)",
            ]
            advice = " For filtering or aggregation, use a table function."
        else:
            problem = " Problem: Output has more rows than input."
            causes = [
                " - compute() is expanding rows (not allowed in scalar)",
                " - compute() is unnesting arrays",
            ]
            advice = " For row expansion (1→N), use a table function."

        report += [
            problem,
            "",
            " Possible causes:",
            *causes,
            " - Bug in array construction",
            "",
            " Scalar functions require 1:1 row mapping.",
            advice,
        ]

        return "\n".join(report)
367
+
368
+
369
class TypeMismatchError(TypeError):
    """Signal that an array's type differs from the declared type.

    Used when the type declared via Param() or Returns() does not match the
    Arrow type actually seen at runtime.

    Attributes:
        param_name: Name of the parameter with the type mismatch.
        expected_type: The declared Arrow type.
        actual_type: The actual Arrow type found.
        function_name: Name of the function class.

    """

    def __init__(
        self,
        message: str,
        *,
        param_name: str = "",
        expected_type: pa.DataType | None = None,
        actual_type: pa.DataType | None = None,
        function_name: str = "",
    ) -> None:
        """Store the mismatch details and build the exception message.

        Args:
            message: Base error message.
            param_name: Name of the parameter.
            expected_type: Expected Arrow type.
            actual_type: Actual Arrow type found.
            function_name: Name of the function class.

        """
        self.param_name = param_name
        self.expected_type = expected_type
        self.actual_type = actual_type
        self.function_name = function_name

        # The detailed report needs both types; otherwise use the bare message.
        if expected_type is None or actual_type is None:
            text = message
        else:
            text = self._build_detailed_message(message, param_name, expected_type, actual_type)

        super().__init__(text)

    def _build_detailed_message(
        self,
        base_message: str,
        param_name: str,
        expected_type: pa.DataType,
        actual_type: pa.DataType,
    ) -> str:
        """Compose the multi-line report: context lines, then both types."""
        # Optional context lines appear only when their value is non-empty.
        context = []
        if self.function_name:
            context.append(f" Function: {self.function_name}")
        if param_name:
            context.append(f" Parameter: {param_name}")

        return "\n".join(
            [
                base_message,
                "",
                *context,
                f" Expected type: {expected_type}",
                f" Actual type: {actual_type}",
            ]
        )
433
+
434
+
435
class ScalarFunctionGenerator(vgi.function.Function):
    """Per-batch callback base class for scalar functions.

    This is the advanced scalar API. Most functions should subclass
    ScalarFunction instead, which offers a simpler compute() callback.

    Constraints on every scalar function:
    - **1:1 row mapping**: Output row count must equal input row count
    - **Single value output**: Produces one value per input row
    - **No finalization**: Processing ends when input is exhausted

    Methods to Override
    -------------------
    output_type(params) -> pa.DataType
        Return the Arrow type for the output value. Required.

    process(...) -> pa.RecordBatch
        Process one input batch. Must return output with same row count.
        Required.

    on_bind(params) -> BindResult
        Optional. Override to perform custom bind-time logic.

    on_init(...) -> GlobalInitResponse
        Optional. Override to perform custom initialization.

    Protocol Entry Points (called by worker, do not override)
    ---------------------------------------------------------
    bind(input) -> BindResponse
        Handles the bind API call

    global_init(input) -> GlobalInitResponse
        Handles the global_init API call

    """

    @final
    @classmethod
    def _validate_row_count(cls, output_batch: pa.RecordBatch, input_batch: pa.RecordBatch) -> None:
        """Raise RowCountMismatchError unless both batches have equal row counts."""
        if output_batch.num_rows == input_batch.num_rows:
            return
        raise RowCountMismatchError(
            "Scalar function output must have same row count as input.",
            input_rows=input_batch.num_rows,
            output_rows=output_batch.num_rows,
            function_name=cls.__name__,
        )

    @classmethod
    @abstractmethod
    def output_type(cls, params: BindParameters) -> pa.DataType:
        """Return the Arrow type for the output value.

        Args:
            params: Bind parameters including arguments and input schema.

        """
        ...

    @classmethod
    def on_bind(
        cls,
        params: BindParameters,
    ) -> BindResult:
        """Produce the bind result for the bind API call.

        The default wraps ``output_type(params)`` in a BindResult without
        opaque data. Override to validate arguments, compute a dynamic
        output type, or attach opaque data.

        Args:
            params: Bind parameters including arguments and schema.

        Returns:
            BindResult with output_type and optional opaque_data.

        """
        return BindResult(output_type=cls.output_type(params))

    @classmethod
    def catalog_output_schema(cls) -> pa.Schema:
        """Return the output schema reported for catalog introspection.

        A generator-style scalar function computes its output type at bind
        time, so no static type is known here. The schema therefore holds a
        single ``result`` column of type ``null()`` whose field metadata
        carries ``vgi:any`` = ``true``, marking the type as resolved at
        bind. ScalarFunction overrides this when a static ``Returns()``
        type is available.
        """
        dynamic_result = pa.field("result", pa.null(), metadata={b"vgi:any": b"true"})
        return pa.schema([dynamic_result])

    @final
    @classmethod
    def _validate_param_type_bounds(cls, input_schema: pa.Schema) -> None:
        """Check ``type_bound`` constraints of AnyArrow Params at bind time.

        Every Arg in ``_compute_params`` that is flagged ``is_any`` and
        carries a ``type_bound`` is validated against the type of its input
        column; a varargs Arg is validated against every column from its
        resolution index to the end of the schema.

        ``_compute_params`` is read with ``getattr``; when the attribute is
        missing or empty (e.g. a subclass that does not use the
        Param/ConstParam annotation API) this method does nothing.

        Args:
            input_schema: The input schema from the bind call.

        Raises:
            Whatever ``Arg.validate_type_bound`` raises for a column type
            that violates the bound.

        """
        declared: dict[str, Arg[Any]] | None = getattr(cls, "_compute_params", None)
        if not declared:
            return
        for arg in declared.values():
            if arg.is_any and arg.type_bound is not None:
                start = cast(int, arg._resolution_index)
                # Varargs swallow every remaining column; others own exactly one.
                columns = range(start, len(input_schema)) if arg.varargs else (start,)
                for column in columns:
                    arg.validate_type_bound(input_schema.field(column).type)

    @final
    @classmethod
    def _validate_param_varargs_min(cls, input_schema: pa.Schema) -> None:
        """Ensure varargs Params receive at least one column.

        Mirrors the ``Arg.varargs`` rule for table-function arguments: a
        ``Param(varargs=True)`` must bind to >= 1 input column. Without this
        guard, ``on_bind`` implementations that index ``arguments_schema``
        crash with an opaque ``IndexError`` when callers invoke a varargs
        scalar with zero values.

        Raises:
            ArgumentValidationError: If a varargs Arg's resolution index
                lies at or beyond the end of ``input_schema``.
        """
        declared: dict[str, Arg[Any]] | None = getattr(cls, "_compute_params", None)
        if not declared:
            return
        for name, arg in declared.items():
            if arg.varargs and cast(int, arg._resolution_index) >= len(input_schema):
                raise ArgumentValidationError(
                    f"Varargs parameter '{name}' requires at least 1 value.",
                    arg_name=name,
                    position=arg.position,
                    constraint="varargs requires at least 1 value",
                    doc=arg.doc or None,
                )

    @final
    @classmethod
    def bind(
        cls,
        input: BindRequest,
        *,
        ctx: CallContext | None = None,
        attach_plaintext: bytes | None = None,
    ) -> BindResponse:
        """Bind protocol entry point. Do not override; use on_bind() instead.

        Validates varargs arity and type bounds, builds BindParameters,
        calls on_bind(), and packages the result as a BindResponse. A
        secret scope request is returned instead when compute() declares
        Secret() annotations that have not been resolved yet, or when
        on_bind() registered dynamic secret lookups.

        Args:
            input: The bind request; its ``input_schema`` must be set.
            ctx: Call context supplying the auth context; when None the
                anonymous auth context is used.
            attach_plaintext: Full framework plaintext (uuid||catalog
                bytes) of the attach, or None.

        Returns:
            BindResponse carrying either the output schema (single
            ``result`` column) plus serialized opaque data, or a secret
            scope request.

        """
        assert input.input_schema is not None
        input_schema = input.input_schema
        cls._validate_param_varargs_min(input_schema)
        cls._validate_param_type_bounds(input_schema)

        # Auto-request secrets declared via Secret() annotations on compute()
        # when they haven't been resolved yet (first bind call).
        # _secret_params is only defined on ScalarFunction, not ScalarFunctionGenerator.
        declared_secrets: dict[str, Secret] = getattr(cls, "_secret_params", {})
        if declared_secrets and not input.resolved_secrets_provided and input.secrets is None:
            return BindResponse.secret_scope_request(
                [
                    SecretLookupEntry(
                        secret_type=declared.secret_type,
                        scope=declared.scope,
                        secret_name=declared.name,
                    )
                    for declared in declared_secrets.values()
                ]
            )

        auth = AuthContext.anonymous() if ctx is None else ctx.auth
        accessor = SecretsAccessor(input.secrets, is_retry=input.resolved_secrets_provided)
        bind_params = BindParameters(
            constant_arguments=input.arguments,
            arguments_schema=input_schema,
            settings=input.settings,
            secrets=accessor,
            auth_context=auth,
            # ``attach_plaintext`` is the full framework plaintext (uuid||catalog
            # bytes); the body sees only the catalog bytes. transaction_opaque_data
            # is already sealed.
            attach_opaque_data=attach_catalog_bytes(attach_plaintext),
            transaction_opaque_data=input.transaction_opaque_data,
        )
        outcome = cls.on_bind(bind_params)

        # on_bind() may have registered pending secret lookups; if so, ask
        # the caller to resolve them before a result can be returned.
        if accessor.needs_resolution:
            return BindResponse.secret_scope_request(accessor.pending_lookups)

        # Serialize the typed BindResult.opaque_data to bytes before
        # putting it on the wire. The user-facing API stays typed
        # (BindResult.opaque_data: ArrowSerializableDataclass | None) but
        # the wire field is bytes; the framework owns this single
        # boundary shim so workers don't write per-callsite serialization
        # boilerplate. See vgi/invocation.py:BindResponse.opaque_data for
        # the full contract.
        payload = outcome.opaque_data
        opaque_bytes: bytes | None = None if payload is None else payload.serialize_to_bytes()

        return BindResponse(
            output_schema=schema(result=outcome.output_type),
            opaque_data=opaque_bytes,
        )

    @classmethod
    def on_init(
        cls,
        *,
        bind_call: BindRequest,
        opaque_data: bytes | None,
        storage: BoundStorage,
    ) -> GlobalInitResponse:
        """Initialize the function during the init API call.

        Override to perform one-time setup that should happen after bind
        but before processing batches. The default implementation returns
        a ``GlobalInitResponse()`` constructed with no arguments.

        Args:
            bind_call: The original bind request with arguments and schema.
            opaque_data: Bytes from on_bind()'s ``BindResult.opaque_data``
                (after the framework's serialize-to-bytes shim), or None
                if on_bind didn't set it. Reconstruct via
                ``MyConcreteDataclass.deserialize_from_bytes(opaque_data)``
                — the consumer always knows what concrete type to expect,
                so explicit reconstruction is preferred over a framework-
                level class-name registry.
            storage: BoundStorage for storing data across calls.

        Returns:
            GlobalInitResponse; global_init() forwards its ``max_workers``
            and ``opaque_data`` fields.

        """
        return GlobalInitResponse()

    @final
    @classmethod
    def global_init(cls, input: InitRequest, *, attach_plaintext: bytes | None = None) -> GlobalInitResponse:
        """Global init protocol entry point. Do not override; use on_init() instead.

        Generates a fresh execution id, calls on_init() with the bind
        request, the bind opaque bytes and a BoundStorage for that
        execution, then returns a GlobalInitResponse carrying on_init()'s
        ``max_workers`` and ``opaque_data`` together with the execution id.

        ``attach_plaintext`` is the full framework plaintext (uuid||catalog
        bytes) the worker unwrapped; storage shards on its UUID. Scalar
        ``on_init`` does not expose the attach to bodies, so only storage uses it.
        """
        execution_id = uuid.uuid4().bytes
        bound_storage = BoundStorage(
            cls.storage,
            execution_id,
            request=input,
            attach_plaintext=attach_plaintext,
        )
        result = cls.on_init(
            bind_call=input.bind_call,
            opaque_data=input.bind_opaque_data,
            storage=bound_storage,
        )

        return GlobalInitResponse(
            max_workers=result.max_workers,
            execution_id=execution_id,
            opaque_data=result.opaque_data,
        )

    @classmethod
    @abstractmethod
    def process(
        cls,
        *,
        batch: pa.RecordBatch,
        init_call: InitRequest,
        init_response: BaseInitResponse,
        storage: BoundStorage,
        auth_context: AuthContext,
    ) -> pa.RecordBatch:
        """Process one input batch.

        Override this method to implement your scalar transformation.
        Must return an output RecordBatch with exactly the same number
        of rows as the input batch.

        Args:
            batch: The input RecordBatch to process.
            init_call: The parameters from global_init.
            init_response: The response from the init call.
            storage: BoundStorage for storing data across calls.
            auth_context: Authentication context for the current request.

        Returns:
            Output RecordBatch with same row count as input.

        """
        ...
741
+
742
+
743
+ class ScalarFunction(ScalarFunctionGenerator):
744
+ """Base class for scalar functions (1:1 row mapping, single output column).
745
+
746
+ Scalar functions transform each input row to exactly one output value.
747
+ Use Param/ConstParam/Returns annotations on compute() to declare types.
748
+
749
+ Type Validation
750
+ ---------------
751
+ Input and output types are validated at runtime:
752
+ - Param types are checked against actual array types
753
+ - Returns type is checked against compute() result
754
+ - AnyArrow parameters skip validation
755
+ - TypeMismatchError is raised on mismatch
756
+
757
+ Methods to Override
758
+ -------------------
759
+ compute(self, ...) -> pa.Array
760
+ Transform input arrays to output. Use Param/ConstParam annotations.
761
+
762
+ output_type(params) -> pa.DataType (classmethod)
763
+ Override when output type depends on input schema or arguments.
764
+
765
+ """
766
+
767
+ # For TYPE_CHECKING, allow dynamic attribute access for Param/ConstParam
768
+ if TYPE_CHECKING:
769
+
770
+ def __getattr__(self, _name: str) -> Any:
771
+ """Allow dynamic attribute access for Param/ConstParam descriptors."""
772
+ ...
773
+
774
+ _compute_params: dict[str, Arg[Any]] # Regular Param() arguments (arrays)
775
+ _const_params: dict[str, Arg[Any]] # ConstParam() arguments (scalars)
776
+ _setting_params: dict[str, str] # Setting params: param_name -> setting_key
777
+ _secret_params: dict[str, Secret] # Secret params: param_name -> Secret instance
778
+ _output_length_param: str | None # OutputLength param name (batch row count)
779
+ _auth_param: str | None # Auth param name (receives AuthContext)
780
+ _returns_output_type: pa.DataType | None # Output type from Returns()
781
+
782
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Inspect compute() once, at class-definition time, and cache the result.

    Reads the ``Annotated[...]`` metadata on compute()'s parameters and
    return annotation and stores, on the new subclass:

    - ``_compute_params``: Param() column arguments
    - ``_const_params``: ConstParam() constant arguments
    - ``_setting_params`` / ``_secret_params``: from the shared helper
    - ``_output_length_param`` / ``_auth_param``: injected-value names
    - ``_returns_output_type``: Arrow type declared via Returns(), if any

    Abstract subclasses are skipped entirely.

    Raises:
        TypeError: If the class has no compute(), if a complex array return
            class is used without an explicit ``arrow_type`` in Returns(),
            or if compute() declares more than one Auth parameter.
    """
    super().__init_subclass__(**kwargs)

    # Abstract intermediates are not inspected.
    if inspect.isabstract(cls):
        return

    compute_method = getattr(cls, "compute", None)
    if compute_method is None:
        raise TypeError(f"{cls.__name__} must define a compute() method.\n\n")

    sig = inspect.signature(compute_method)

    # First choice: let typing resolve the hints (covers plain annotations
    # and PEP 563 string annotations). Any failure is deliberately
    # swallowed; the manual fallback below takes over.
    hints: dict[str, Any] = {}
    with contextlib.suppress(Exception):
        hints = get_type_hints(compute_method, include_extras=True)

    if not hints:
        # Fallback: evaluate string annotations by hand with the vgi
        # annotation helpers injected into the evaluation namespace.
        from typing import Annotated

        from vgi import arguments as vgi_args

        # pa.Array[Any] isn't subscriptable in PyArrow, so expose a stand-in
        # "pa" whose Array accepts subscripting; every other attribute is
        # forwarded to the real module.
        class _SubscriptableArray:
            def __class_getitem__(cls, _item: Any) -> Any:
                return Any

        class _PaProxy:
            Array = _SubscriptableArray

            def __getattr__(self, name: str) -> Any:
                return getattr(pa, name)

        eval_namespace = {
            **getattr(compute_method, "__globals__", {}),
            "Annotated": Annotated,
            "Param": vgi_args.Param,
            "ConstParam": vgi_args.ConstParam,
            "Setting": vgi_args.Setting,
            "Secret": vgi_args.Secret,
            "Auth": vgi_args.Auth,
            "OutputLength": vgi_args.OutputLength,
            "Returns": vgi_args.Returns,
            "AnyArrow": vgi_args.AnyArrow,
            "pa": _PaProxy(),
        }
        raw_annotations = getattr(compute_method, "__annotations__", {})
        for hint_name, raw in raw_annotations.items():
            if not isinstance(raw, str):
                hints[hint_name] = raw
                continue
            # An annotation that cannot be evaluated is simply left out.
            with contextlib.suppress(Exception):
                hints[hint_name] = eval(raw, eval_namespace)  # noqa: S307

    # ---- Return annotation: Annotated[..., Returns(...)] ----
    returns_output_type: pa.DataType | None = None
    return_hint = hints.get("return")
    returns_marker = next(
        (m for m in getattr(return_hint, "__metadata__", ()) if isinstance(m, Returns)),
        None,
    )
    if returns_marker is not None:
        if returns_marker.arrow_type is not None:
            # An explicit arrow_type always wins.
            returns_output_type = returns_marker.arrow_type
        else:
            # Otherwise infer from the first Annotated argument.
            annotated_args = get_args(return_hint)
            if annotated_args:
                declared_cls = annotated_args[0]
                if declared_cls in ARRAY_CLASS_TO_DATATYPE:
                    returns_output_type = ARRAY_CLASS_TO_DATATYPE[declared_cls]
                elif declared_cls in COMPLEX_ARRAY_CLASSES:
                    raise TypeError(
                        f"{declared_cls.__name__} requires explicit "
                        f"arrow_type in Returns(). "
                        f"Example: Returns(arrow_type=pa.list_(pa.int64()))"
                    )
                # Anything else: the type stays None (dynamic / AnyArrow).

    # ---- Parameter annotations ----
    compute_params: dict[str, Arg[Any]] = {}
    const_params: dict[str, Arg[Any]] = {}
    output_length_param: str | None = None  # receives the batch row count
    auth_param: str | None = None  # receives the AuthContext

    # Three independent counters:
    #   next_position - call-order position recorded in the Arg metadata
    #   next_column   - index into the batch columns (Param only)
    #   next_const    - index into Arguments.positional (ConstParam only)
    next_position = 0
    next_column = 0
    next_const = 0

    for param_name in sig.parameters:
        if param_name == "self":
            continue
        hint = hints.get(param_name)
        if hint is None:
            continue

        # Only Annotated[...] hints carry __metadata__; the first
        # recognised marker decides how the parameter is treated.
        for marker in getattr(hint, "__metadata__", ()):
            if isinstance(marker, Param):
                annotated_args = get_args(hint)
                base_type = annotated_args[0] if annotated_args else pa.Array
                arg = _param_to_arg(marker, base_type, next_position)
                arg._name = param_name
                # For column params _resolution_index is the batch column index.
                arg._resolution_index = next_column
                compute_params[param_name] = arg
                setattr(cls, param_name, _ArgDescriptor(arg, param_name))
                next_position += 1
                next_column += 1
            elif isinstance(marker, ConstParam):
                annotated_args = get_args(hint)
                base_type = cast(type, annotated_args[0] if annotated_args else Any)
                arg = _const_param_to_arg(marker, base_type, next_position)
                arg._name = param_name
                # For const params _resolution_index indexes Arguments.positional.
                arg._resolution_index = next_const
                const_params[param_name] = arg
                setattr(cls, param_name, _ConstArgDescriptor(arg, param_name))
                next_position += 1
                next_const += 1
            elif isinstance(marker, OutputLength):
                # Injected value, not a call argument: counters untouched.
                output_length_param = param_name
            elif isinstance(marker, Auth):
                if auth_param is not None:
                    raise TypeError(
                        f"{cls.__name__}.compute() has multiple Auth parameters: {auth_param!r} and {param_name!r}"
                    )
                # Injected value, not a call argument: counters untouched.
                auth_param = param_name
            else:
                continue  # unrelated metadata; keep scanning
            break  # a marker was handled; ignore remaining metadata

    # Setting/Secret parameters are discovered by the shared helper.
    setting_params, secret_params = _extract_setting_secret_params(compute_method)

    cls._compute_params = compute_params
    cls._const_params = const_params
    cls._setting_params = setting_params
    cls._secret_params = secret_params
    cls._output_length_param = output_length_param
    cls._auth_param = auth_param
    cls._returns_output_type = returns_output_type
956
+
957
@final
@classmethod
def catalog_output_schema(cls) -> pa.Schema:
    """Build the one-field output schema advertised to the catalog.

    The schema always has a single field named "result". When a type was
    captured from the Returns() annotation it is used as-is; otherwise the
    field is typed ``null`` and tagged with ``vgi:any = true`` metadata to
    signal a dynamic output type.
    """
    declared = getattr(cls, "_returns_output_type", None)
    if declared is not None:
        return schema({"result": declared})

    # No declared type: emit the "any" placeholder field.
    placeholder = pa.field("result", pa.null(), metadata={b"vgi:any": b"true"})
    return pa.schema([placeholder])
972
+
973
@classmethod
def output_type(cls, params: BindParameters) -> pa.DataType:
    """Resolve the Arrow type of the output column.

    The default returns the type captured from the Returns() annotation
    and ignores ``params``. Override this when the output type has to be
    derived from the bind parameters (for example
    ``params.arguments_schema`` or ``params.constant_arguments``).

    Args:
        params: Bind parameters including arguments and input schema.

    Raises:
        NotImplementedError: If no type was captured from Returns() and
            the subclass did not override this method.

    """
    declared = cls._returns_output_type
    if declared is None:
        raise NotImplementedError(
            f"{cls.__name__}.output_type must be overridden when using Returns() "
            f"without an explicit type (dynamic output type)."
        )
    return declared
992
+
993
+ # Note: compute() is NOT defined here. Subclasses define it with their own
994
+ # keyword-only signature. This avoids mypy override errors for users.
995
+ # See class docstring for compute() signature requirements.
996
+ # Validated at class definition time by __init_subclass__.
997
+
998
@final
@classmethod
def _extract_compute_kwargs(
    cls,
    batch: pa.RecordBatch,
    bind_call: BindRequest,
    auth_context: AuthContext,
) -> dict[str, Any]:
    """Assemble the keyword arguments for one compute() invocation.

    The result is keyed by compute() parameter name. Values are
    heterogeneous: batch columns, lists of columns (varargs), Python
    values for constants, scalars for settings, dicts for secrets, an int
    row count, and the AuthContext.

    Args:
        batch: Input RecordBatch supplying the column arguments.
        bind_call: Bind request supplying constant arguments, settings,
            and secrets.
        auth_context: Authentication context for the current request.

    Returns:
        Dict mapping parameter names to their resolved values.

    """
    resolved: dict[str, Any] = {}

    # Column params: _resolution_index is the column's index in the batch.
    for param_name, column_arg in cls._compute_params.items():
        first_col = cast(int, column_arg._resolution_index)
        resolved[param_name] = (
            # Varargs take every column from their index to the end.
            [batch.column(i) for i in range(first_col, batch.num_columns)]
            if column_arg.varargs
            else batch.column(first_col)
        )

    # Const params: _resolution_index indexes Arguments.positional; the
    # scalar is unwrapped to a plain Python value (None stays None).
    for param_name, const_arg in cls._const_params.items():
        position = cast(int, const_arg._resolution_index)
        scalar = bind_call.arguments.positional[position]
        resolved[param_name] = None if scalar is None else scalar.as_py()

    # Setting params: first-row scalar of the column named after the
    # setting key, or None when the column is absent. Nothing is set when
    # the request carries no settings batch.
    if bind_call.settings is not None and cls._setting_params:
        settings_schema = bind_call.settings.schema
        for param_name, setting_key in cls._setting_params.items():
            field_pos = settings_schema.get_field_index(setting_key)
            resolved[param_name] = bind_call.settings.column(field_pos)[0] if field_pos >= 0 else None

    # Secret params: first-row struct of the column named after the secret
    # type, converted to a dict, or None when the column is absent.
    if bind_call.secrets is not None and cls._secret_params:
        secrets_schema = bind_call.secrets.schema
        for param_name, secret in cls._secret_params.items():
            field_pos = secrets_schema.get_field_index(secret.secret_type)
            resolved[param_name] = (
                _struct_scalar_to_dict(bind_call.secrets.column(field_pos)[0]) if field_pos >= 0 else None
            )

    # Injected values that are not call arguments.
    if cls._output_length_param is not None:
        resolved[cls._output_length_param] = batch.num_rows
    if cls._auth_param is not None:
        resolved[cls._auth_param] = auth_context

    return resolved
1065
+
1066
@final
@classmethod
def _validate_single_param_type(cls, arg: Arg[Any], arr: pa.Array[Any], display_name: str) -> pa.Array[Any]:
    """Check one array against its parameter declaration, casting if needed.

    - "Any" parameters are only checked against their type bound (when
      one is set) and returned untouched.
    - Parameters with a declared ``arrow_type`` are returned untouched on
      an exact match; otherwise a cast to the declared type is attempted
      (e.g. int32→int64, decimal128→double).
    - Parameters without a declared type are returned untouched.

    Args:
        arg: The Arg metadata for the parameter.
        arr: The array actually supplied.
        display_name: Name used in error messages (e.g. "x" or "x[0]").

    Returns:
        ``arr`` itself, or a cast copy of it.

    Raises:
        TypeMismatchError: If the types differ and the cast is rejected.

    """
    if arg.is_any:
        if arg.type_bound is not None:
            arg.validate_type_bound(arr.type)
        return arr

    target_type = arg.arrow_type
    cast_required = target_type is not None and arr.type != target_type
    if not cast_required:
        return arr

    try:
        converted = arr.cast(target_type)
    except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
        # Hide the Arrow traceback; the mismatch error carries the details.
        raise TypeMismatchError(
            f"Input type mismatch for parameter '{display_name}'.",
            param_name=display_name,
            expected_type=target_type,
            actual_type=arr.type,
            function_name=cls.__name__,
        ) from None

    logger.debug("Cast parameter '%s' from %s to %s", display_name, arr.type, target_type)
    return converted
1104
+
1105
@final
@classmethod
def _validate_param_types(cls, kwargs: dict[str, Any]) -> None:
    """Validate every column argument in ``kwargs``, updating it in place.

    Each Param entry is passed through ``_validate_single_param_type`` and
    the returned (possibly cast) array is written back under the same key.
    Varargs entries are validated element by element, with error names of
    the form ``name[i]``. Non-column entries are left untouched.

    Args:
        kwargs: Dict of parameter names to values (from
            _extract_compute_kwargs). Mutated in place.

    Raises:
        TypeMismatchError: If an array type doesn't match and cannot be cast.
        SchemaValidationError: If an array type fails type_bound validation.

    """
    check = cls._validate_single_param_type
    for param_name, column_arg in cls._compute_params.items():
        supplied = kwargs[param_name]
        if column_arg.varargs:
            kwargs[param_name] = [
                check(column_arg, item, f"{param_name}[{pos}]") for pos, item in enumerate(supplied)
            ]
        else:
            kwargs[param_name] = check(column_arg, supplied, param_name)
1129
+
1130
@final
@classmethod
def _validate_output_type(cls, result: pa.Array[Any]) -> None:
    """Ensure compute()'s result has exactly the declared Returns type.

    No check is performed when no output type was declared. Unlike input
    validation, no cast is attempted here: any difference is an error.

    Args:
        result: The output array from compute().

    Raises:
        TypeMismatchError: If the result's type differs from the declared type.

    """
    expected = cls._returns_output_type
    if expected is None:
        # Dynamic or unspecified output type: nothing to compare against.
        return

    actual = result.type
    if actual != expected:
        raise TypeMismatchError(
            "Output type mismatch.",
            param_name="return",
            expected_type=expected,
            actual_type=actual,
            function_name=cls.__name__,
        )
1153
+
1154
@classmethod
def on_bind(cls, params: BindParameters) -> BindResult:
    """Bind-phase hook: decide the output type for this call.

    The default asks ``output_type(params)`` for the type and attaches no
    opaque data. Override for custom bind-time work such as validating
    arguments, inspecting the input schema, or computing a dynamic type.

    Args:
        params: Bind parameters including arguments, input schema,
            settings, and secrets.

    Returns:
        BindResult with output_type and optional opaque_data.

    Note:
        Constant arguments needed during process() are automatically
        serialized by the protocol; opaque_data is only for additional
        bind-time state you need to pass forward.

    """
    resolved_type = cls.output_type(params)
    return BindResult(output_type=resolved_type, opaque_data=None)
1175
+
1176
@final
@classmethod
def process(
    cls,
    *,
    batch: pa.RecordBatch,
    init_call: InitRequest,
    init_response: BaseInitResponse,
    storage: BoundStorage,
    auth_context: AuthContext,
) -> pa.RecordBatch:
    """Run the subclass's compute() on one batch and wrap the result.

    Steps: resolve compute()'s keyword arguments from the batch and the
    bind request, validate (and possibly cast) the column inputs, call
    compute(), validate the result's type, then build a single-column
    RecordBatch using the output schema carried by ``init_call``.

    ``init_response`` and ``storage`` are accepted to satisfy the
    generator interface but are not used here.
    """
    result_schema = init_call.output_schema

    call_kwargs = cls._extract_compute_kwargs(batch, init_call.bind_call, auth_context)
    cls._validate_param_types(call_kwargs)

    # compute() is not declared on the base class and its signature varies
    # per subclass, so reach it through an Any-typed view of the class.
    compute_fn = cast(Any, cls).compute
    column = compute_fn(**call_kwargs)

    cls._validate_output_type(column)

    return pa.RecordBatch.from_arrays([column], schema=result_schema)