query-farm-airport-test-server 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1122 @@
1
+ import datetime
2
+ import os
3
+ import pickle
4
+ import tempfile
5
+ from collections.abc import Callable, Generator
6
+ from dataclasses import dataclass, field
7
+ from decimal import Decimal
8
+ from typing import Any, Literal, overload
9
+
10
+ import pyarrow as pa
11
+ import pyarrow.compute as pc
12
+ import pyarrow.flight as flight
13
+ import query_farm_flight_server.flight_handling as flight_handling
14
+ import query_farm_flight_server.flight_inventory as flight_inventory
15
+ import query_farm_flight_server.parameter_types as parameter_types
16
+
17
+ from .utils import CaseInsensitiveDict
18
+
19
+ # Since we are creating a new database, let's load it with a few example
20
+ # scalar functions.
21
+
22
+
23
@dataclass
class TableFunctionDynamicOutput:
    """A table function output whose schema is computed from the call's parameters."""

    # The method that will determine the output schema from the input parameters
    schema_creator: Callable[[pa.RecordBatch, pa.Schema | None], pa.Schema]

    # The default parameters for the function, if not called with any.
    default_values: tuple[pa.RecordBatch, pa.Schema | None]
30
+
31
+
32
@dataclass
class TableFunction:
    """A table-returning function exposed through the Flight catalog."""

    # The input schema for the function.
    input_schema: pa.Schema

    # Either a fixed output schema or a dynamic one computed from the parameters.
    output_schema_source: pa.Schema | TableFunctionDynamicOutput

    # The function to call to process a chunk of rows.
    handler: Callable[[parameter_types.TableFunctionParameters, pa.Schema], parameter_types.TableFunctionInOutGenerator]

    # Estimated row count reported in FlightInfo; -1 means unknown. May be a
    # callable that derives the estimate from the call's parameters.
    estimated_rows: int | Callable[[parameter_types.TableFunctionFlightInfo], int] = -1

    def output_schema(
        self,
        parameters: pa.RecordBatch | None = None,
        input_schema: pa.Schema | None = None,
    ) -> pa.Schema:
        """Resolve the output schema, invoking the dynamic schema creator if needed."""
        if isinstance(self.output_schema_source, pa.Schema):
            return self.output_schema_source
        if parameters is None:
            # No parameters supplied: fall back to the registered default values.
            return self.output_schema_source.schema_creator(*self.output_schema_source.default_values)
        assert isinstance(parameters, pa.RecordBatch)
        result = self.output_schema_source.schema_creator(parameters, input_schema)
        return result

    def flight_info(
        self,
        *,
        name: str,
        catalog_name: str,
        schema_name: str,
        parameters: parameter_types.TableFunctionFlightInfo | None = None,
    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
        """
        Often it's necessary to create a FlightInfo object;
        standardize doing that here.
        """
        assert name != ""
        assert catalog_name != ""
        assert schema_name != ""

        if isinstance(self.estimated_rows, int):
            estimated_rows = self.estimated_rows
        else:
            # A callable estimate needs the call parameters to compute a row count.
            assert parameters is not None
            estimated_rows = self.estimated_rows(parameters)

        metadata = flight_inventory.FlightSchemaMetadata(
            type="table_function",
            catalog=catalog_name,
            schema=schema_name,
            name=name,
            comment=None,
            input_schema=self.input_schema,
        )
        flight_info = flight.FlightInfo(
            self.output_schema(parameters.parameters, parameters.table_input_schema)
            if parameters
            else self.output_schema(),
            # This will always be the same descriptor, so that we can use the action
            # name to determine which table function to execute.
            descriptor_pack_(catalog_name, schema_name, "table_function", name),
            [],
            estimated_rows,
            -1,
            app_metadata=metadata.serialize(),
        )
        return (flight_info, metadata)
100
+
101
+
102
@dataclass
class ScalarFunction:
    """A scalar function: maps a chunk of input rows to a single output column."""

    # The input schema for the function.
    input_schema: pa.Schema
    # The output schema for the function, should only have a single column.
    output_schema: pa.Schema

    # The function to call to process a chunk of rows.
    handler: Callable[[pa.Table], pa.Array]

    def flight_info(
        self, *, name: str, catalog_name: str, schema_name: str
    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
        """
        Often it's necessary to create a FlightInfo object;
        standardize doing that here.
        """
        metadata = flight_inventory.FlightSchemaMetadata(
            type="scalar_function",
            catalog=catalog_name,
            schema=schema_name,
            name=name,
            comment=None,
            input_schema=self.input_schema,
        )
        flight_info = flight.FlightInfo(
            self.output_schema,
            descriptor_pack_(catalog_name, schema_name, "scalar_function", name),
            [],
            # Row count and byte size are unknown for a function: -1.
            -1,
            -1,
            app_metadata=metadata.serialize(),
        )
        return (flight_info, metadata)
136
+
137
+
138
def serialize_table_data(table: pa.Table) -> bytes:
    """
    Serialize the table data to a byte string (Arrow IPC stream format).
    """
    assert isinstance(table, pa.Table)
    output = pa.BufferOutputStream()
    writer = pa.ipc.new_stream(output, table.schema)
    with writer:
        writer.write_table(table)
    return output.getvalue().to_pybytes()
147
+
148
+
149
def deserialize_table_data(data: bytes) -> pa.Table:
    """
    Deserialize the table data from a byte string (Arrow IPC stream format).
    """
    assert isinstance(data, bytes)
    reader = pa.ipc.open_stream(pa.BufferReader(data))
    return reader.read_all()
157
+
158
+
159
@dataclass
class TableInfo:
    """A versioned table: every update appends a new pa.Table snapshot."""

    # To enable version history keep track of tables.
    table_versions: list[pa.Table] = field(default_factory=list)

    # the next row id to assign.
    row_id_counter: int = 0

    # This cannot be serialized but is convenient for testing.
    endpoint_generator: Callable[[Any], list[flight.FlightEndpoint]] | None = None

    def update_table(self, table: pa.Table) -> None:
        """Append *table* as the newest version."""
        assert table is not None
        assert isinstance(table, pa.Table)
        self.table_versions.append(table)

    def version(self, version: int | None = None) -> pa.Table:
        """
        Get the version of the table. None means "latest".
        """
        assert len(self.table_versions) > 0
        if version is None:
            return self.table_versions[-1]

        assert version < len(self.table_versions)
        return self.table_versions[version]

    def flight_info(
        self,
        *,
        name: str,
        catalog_name: str,
        schema_name: str,
        version: int | None = None,
    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
        """
        Often its necessary to create a FlightInfo object for the table,
        standardize doing that here.
        """
        metadata = flight_inventory.FlightSchemaMetadata(
            type="table",
            catalog=catalog_name,
            schema=schema_name,
            name=name,
            comment=None,
        )
        flight_info = flight.FlightInfo(
            self.version(version).schema,
            descriptor_pack_(catalog_name, schema_name, "table", name),
            [],
            # Row count and byte size are reported as unknown (-1).
            -1,
            -1,
            app_metadata=metadata.serialize(),
        )
        return (flight_info, metadata)

    def serialize(self) -> dict[str, Any]:
        """
        Serialize the TableInfo to a dictionary.
        """
        # endpoint_generator is intentionally omitted; it is code, not data.
        return {
            "table_versions": [serialize_table_data(table) for table in self.table_versions],
            "row_id_counter": self.row_id_counter,
        }

    def deserialize(self, data: dict[str, Any]) -> "TableInfo":
        """
        Deserialize the TableInfo from a dictionary.
        """
        self.table_versions = [deserialize_table_data(table) for table in data["table_versions"]]
        self.row_id_counter = data["row_id_counter"]
        self.endpoint_generator = None
        return self
232
+
233
+
234
# Kinds of catalog objects that can be addressed through a flight descriptor.
ObjectTypeName = Literal["table", "scalar_function", "table_function"]
235
+
236
+
237
@dataclass
class DescriptorParts:
    """
    The fields that are encoded in the flight descriptor.
    """

    catalog_name: str
    schema_name: str
    type: ObjectTypeName
    name: str
247
+
248
+
249
def descriptor_pack_(
    catalog_name: str,
    schema_name: str,
    type: ObjectTypeName,
    name: str,
) -> flight.FlightDescriptor:
    """
    Pack the descriptor into a FlightDescriptor.

    NOTE(review): components containing "/" would make the path ambiguous for
    descriptor_unpack_ — assumed not to occur in practice.
    """
    path = "/".join((catalog_name, schema_name, type, name))
    return flight.FlightDescriptor.for_path(path)
259
+
260
+
261
def descriptor_unpack_(descriptor: flight.FlightDescriptor) -> DescriptorParts:
    """
    Split the descriptor into its components.
    """
    assert descriptor.descriptor_type == flight.DescriptorType.PATH
    assert len(descriptor.path) == 1
    path = descriptor.path[0].decode("utf-8")
    parts = path.split("/")
    if len(parts) != 4:
        raise flight.FlightServerError(f"Invalid descriptor path: {path}")

    catalog, schema, kind, object_name = parts

    # Validate the type component against the known object kinds.
    if kind not in ("table", "scalar_function", "table_function"):
        raise flight.FlightServerError(f"Invalid descriptor type: {kind}")
    descriptor_type: ObjectTypeName = kind  # narrowed by the check above

    return DescriptorParts(
        catalog_name=catalog,
        schema_name=schema,
        type=descriptor_type,
        name=object_name,
    )
288
+
289
+
290
@dataclass
class SchemaCollection:
    """A named collection of tables, scalar functions and table functions."""

    tables_by_name: CaseInsensitiveDict[TableInfo] = field(default_factory=CaseInsensitiveDict[TableInfo])
    scalar_functions_by_name: CaseInsensitiveDict[ScalarFunction] = field(
        default_factory=CaseInsensitiveDict[ScalarFunction]
    )
    table_functions_by_name: CaseInsensitiveDict[TableFunction] = field(
        default_factory=CaseInsensitiveDict[TableFunction]
    )

    def serialize(self) -> dict[str, Any]:
        """Serialize the schema to a dictionary; only tables are persisted (functions are code)."""
        return {
            "tables": {name: table.serialize() for name, table in self.tables_by_name.items()},
        }

    def deserialize(self, data: dict[str, Any]) -> "SchemaCollection":
        """
        Deserialize the schema collection from a dictionary.
        """
        self.tables_by_name = CaseInsensitiveDict[TableInfo](
            {name: TableInfo().deserialize(table) for name, table in data["tables"].items()}
        )
        return self

    def containers(
        self,
    ) -> list[
        CaseInsensitiveDict[TableInfo] | CaseInsensitiveDict[ScalarFunction] | CaseInsensitiveDict[TableFunction]
    ]:
        """All object containers, convenient for iterating everything in the schema."""
        return [
            self.tables_by_name,
            self.scalar_functions_by_name,
            self.table_functions_by_name,
        ]

    @overload
    def by_name(self, type: Literal["table"], name: str) -> TableInfo: ...

    @overload
    def by_name(self, type: Literal["scalar_function"], name: str) -> ScalarFunction: ...

    @overload
    def by_name(self, type: Literal["table_function"], name: str) -> TableFunction: ...

    def by_name(self, type: ObjectTypeName, name: str) -> TableInfo | ScalarFunction | TableFunction:
        """Look up an object by type and name, raising FlightServerError if missing."""
        assert name is not None
        assert name != ""
        if type == "table":
            table = self.tables_by_name.get(name)
            if not table:
                raise flight.FlightServerError(f"Table {name} does not exist.")
            return table
        elif type == "scalar_function":
            scalar_function = self.scalar_functions_by_name.get(name)
            if not scalar_function:
                raise flight.FlightServerError(f"Scalar function {name} does not exist.")
            return scalar_function
        elif type == "table_function":
            table_function = self.table_functions_by_name.get(name)
            if not table_function:
                raise flight.FlightServerError(f"Table function {name} does not exist.")
            return table_function
        # Fix: previously an unrecognized type value fell through and the method
        # implicitly returned None; fail loudly instead.
        raise flight.FlightServerError(f"Invalid object type: {type}")
352
+
353
+
354
@dataclass
class DatabaseContents:
    """A single database: its schemas plus a version counter."""

    # Collection of schemas by name.
    schemas_by_name: CaseInsensitiveDict[SchemaCollection] = field(
        default_factory=CaseInsensitiveDict[SchemaCollection]
    )

    # The version of the database, updated on each schema change.
    version: int = 1

    def __post_init__(self) -> None:
        # Every new database starts with the built-in example schemas installed.
        self.schemas_by_name["remote_data"] = remote_data_schema
        self.schemas_by_name["static_data"] = static_data_schema
        self.schemas_by_name["utils"] = util_schema
        return

    def by_name(self, name: str) -> SchemaCollection:
        """Return the schema for *name*, raising FlightServerError if absent."""
        if name not in self.schemas_by_name:
            raise flight.FlightServerError(f"Schema {name} does not exist.")
        return self.schemas_by_name[name]

    def serialize(self) -> dict[str, Any]:
        """Serialize the database contents (all schemas and the version) to a dictionary."""
        return {
            "schemas": {name: schema.serialize() for name, schema in self.schemas_by_name.items()},
            "version": self.version,
        }

    def deserialize(self, data: dict[str, Any]) -> "DatabaseContents":
        """
        Deserialize the database contents from a dictionary.
        """
        self.schemas_by_name = CaseInsensitiveDict[SchemaCollection](
            {name: SchemaCollection().deserialize(schema) for name, schema in data["schemas"].items()}
        )
        # The built-in schemas are code-defined, so reinstall them after loading.
        self.schemas_by_name["static_data"] = static_data_schema
        self.schemas_by_name["remote_data"] = remote_data_schema
        self.schemas_by_name["utils"] = util_schema

        self.version = data["version"]
        return self
394
+
395
+
396
@dataclass
class DatabaseLibrary:
    """
    The database library, which contains all of the databases, organized by token.
    """

    # Collection of databases by token.
    databases_by_name: CaseInsensitiveDict[DatabaseContents] = field(
        default_factory=CaseInsensitiveDict[DatabaseContents]
    )

    def by_name(self, name: str) -> DatabaseContents:
        """Return the database for *name*, raising FlightServerError if absent."""
        if name not in self.databases_by_name:
            raise flight.FlightServerError(f"Database {name} does not exist.")
        return self.databases_by_name[name]

    def serialize(self) -> dict[str, Any]:
        """Serialize the library (and every database in it) to a plain dictionary."""
        return {
            "databases": {name: db.serialize() for name, db in self.databases_by_name.items()},
        }

    def deserialize(self, data: dict[str, Any]) -> None:
        """
        Deserialize the database library from a dictionary.
        """
        self.databases_by_name = CaseInsensitiveDict[DatabaseContents](
            {name: DatabaseContents().deserialize(db) for name, db in data["databases"].items()}
        )

    @staticmethod
    def filename_for_token(token: str) -> str:
        """
        Get the filename for the database library for a given token.
        """
        assert token is not None
        assert token != ""
        return f"database_library_{token}.pkl"

    @staticmethod
    def reset(token: str) -> None:
        """
        Reset the database library for a given token.
        This will delete the file associated with the token.
        """
        file_path = DatabaseLibrary.filename_for_token(token)
        if os.path.isfile(file_path):
            os.remove(file_path)

    @staticmethod
    def read_from_file(token: str) -> "DatabaseLibrary":
        """
        Read the database library from a file.
        If the file does not exist, return an empty database library.
        """
        library = DatabaseLibrary()

        file_path = DatabaseLibrary.filename_for_token(token)

        if not os.path.isfile(file_path):
            # File doesn't exist — return empty instance
            return library

        try:
            with open(file_path, "rb") as f:
                # NOTE: pickle is only acceptable here because this is a test
                # server reading its own files; never unpickle untrusted data.
                data = pickle.load(f)
            library.deserialize(data)
        except Exception as e:
            raise RuntimeError(f"Failed to read database library from {file_path}: {e}") from e

        return library

    def write_to_file(self, token: str) -> None:
        """
        Write the database library to a temp file, then atomically rename to the destination.
        """
        file_path = DatabaseLibrary.filename_for_token(token)

        data = self.serialize()
        # use pickle
        dir_name = os.path.dirname(file_path) or "."

        tmp_name: str | None = None
        try:
            with tempfile.NamedTemporaryFile("wb", dir=dir_name, delete=False) as tmp_file:
                tmp_name = tmp_file.name
                pickle.dump(data, tmp_file)
            os.replace(tmp_name, file_path)
        except BaseException:
            # Fix: previously a failed dump/rename left the temporary file
            # behind on disk; clean it up before re-raising.
            if tmp_name is not None and os.path.exists(tmp_name):
                os.remove(tmp_name)
            raise
481
+
482
+
483
class DatabaseLibraryContext:
    """Context manager that loads a DatabaseLibrary for a token and saves it on clean exit."""

    def __init__(self, token: str, readonly: bool = False) -> None:
        self.token = token
        self.readonly = readonly

    def __enter__(self) -> DatabaseLibrary:
        # Load (or create) the library backing this token.
        self.db = DatabaseLibrary.read_from_file(self.token)
        return self.db

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if exc_type:
            # Report and let the exception propagate (returning None does not suppress it).
            print(f"An error occurred: {exc_val}")
            # Optionally return True to suppress the exception
            return
        # Persist changes only on a clean, writable exit.
        if not self.readonly:
            self.db.write_to_file(self.token)
499
+
500
+
501
def add_handler(table: pa.Table) -> pa.Array:
    """Element-wise sum of the two input columns."""
    assert table.num_columns == 2
    left, right = table.column(0), table.column(1)
    return pc.add(left, right)
504
+
505
+
506
def uppercase_handler(table: pa.Table) -> pa.Array:
    """Upper-case the single string input column."""
    assert table.num_columns == 1
    input_column = table.column(0)
    return pc.utf8_upper(input_column)
509
+
510
+
511
def any_type_handler(table: pa.Table) -> pa.Array:
    """Identity handler: pass the first column through unchanged."""
    first_column = table.column(0)
    return first_column
513
+
514
+
515
def echo_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
    """Echo the first parameter column back as a single result batch."""
    assert parameters.parameters
    echoed_column = parameters.parameters.column(0)
    result_schema = pa.schema([pa.field("result", pa.string())])
    yield pa.RecordBatch.from_arrays([echoed_column], schema=result_schema)
525
+
526
+
527
def long_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
    """Stream 100 batches of 3000 identical rows each, to exercise long result streams."""
    assert parameters.parameters
    column_count = len(output_schema)
    for batch_index in range(100):
        column_values = [f"{batch_index}"] * 3000
        yield pa.RecordBatch.from_arrays([column_values] * column_count, schema=output_schema)
534
+
535
+
536
def repeat_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
    """Yield the first parameter column once per requested repetition (second parameter)."""
    assert parameters.parameters
    value_column = parameters.parameters.column(0)
    repeat_count = parameters.parameters.column(1).to_pylist()[0]
    for _ in range(repeat_count):
        yield pa.RecordBatch.from_arrays([value_column], schema=output_schema)
547
+
548
+
549
def wide_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
    """Produce the requested number of rows, each with 20 integer columns."""
    assert parameters.parameters
    row_count = parameters.parameters.column(0).to_pylist()[0]
    # Each row maps column "result_i" to the constant i.
    rows = [{f"result_{idx}": idx for idx in range(20)} for _ in range(row_count)]
    yield pa.RecordBatch.from_pylist(rows, schema=output_schema)
560
+
561
+
562
def dynamic_schema_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
    """Echo the raw parameter batch; the output schema was derived from it."""
    yield parameters.parameters
567
+
568
+
569
def dynamic_schema_handler_output_schema(
    parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
) -> pa.Schema:
    """The output schema mirrors the schema of the parameter batch itself."""
    assert isinstance(parameters, pa.RecordBatch)
    derived_schema = parameters.schema
    return derived_schema
576
+
577
+
578
def in_out_long_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
    """Single-column output schema: the first field of the table input."""
    assert input_schema is not None
    first_field = input_schema.field(0)
    return pa.schema([first_field])
581
+
582
+
583
def in_out_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
    """Two-column output schema: first parameter field plus first table-input field."""
    assert input_schema is not None
    fields = [parameters.schema.field(0), input_schema.field(0)]
    return pa.schema(fields)
586
+
587
+
588
def in_out_wide_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
    """Fixed 20-column int32 output schema, regardless of the input."""
    assert input_schema is not None
    result_fields = [pa.field(f"result_{i}", pa.int32()) for i in range(20)]
    return pa.schema(result_fields)
591
+
592
+
593
def in_out_echo_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
    """The output schema is exactly the table input's schema."""
    assert input_schema is not None
    echoed_schema = input_schema
    return echoed_schema
596
+
597
+
598
def in_out_echo_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
    """Echo each received input chunk back on the following exchange."""
    pending = output_schema.empty_table()

    # Each yield hands back the previously received chunk and receives the
    # next one; None signals the end of the input stream.
    while (input_chunk := (yield pending)) is not None:
        pending = input_chunk

    return
613
+
614
+
615
def in_out_wide_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> parameter_types.TableFunctionInOutGenerator:
    """In/out handler producing 20 constant integer columns per input chunk."""
    result = output_schema.empty_table()

    while True:
        # Yield the previous result; True signals readiness for more input.
        input_chunk = yield (result, True)

        if input_chunk is None:
            # End of the input stream.
            break

        if isinstance(input_chunk, bool):
            raise NotImplementedError("Not expecting continuing output for input chunk.")

        chunk_length = len(input_chunk)

        # One output row per input row; column i is filled with the constant i.
        result = pa.RecordBatch.from_arrays(
            [[i] * chunk_length for i in range(20)],
            schema=output_schema,
        )

    return None
638
+
639
+
640
def in_out_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> parameter_types.TableFunctionInOutGenerator:
    """In/out handler pairing the scalar parameter with each input row."""
    result = output_schema.empty_table()

    while True:
        # Yield the previous result; True signals readiness for more input.
        input_chunk = yield (result, True)

        if input_chunk is None:
            # End of the input stream.
            break

        if isinstance(input_chunk, bool):
            raise NotImplementedError("Not expecting continuing output for input chunk.")

        assert parameters.parameters is not None
        parameter_value = parameters.parameters.column(0).to_pylist()[0]

        # Since input chunks could be different sizes, standardize it.
        result = pa.RecordBatch.from_arrays(
            [
                [parameter_value] * len(input_chunk),
                input_chunk.column(0),
            ],
            schema=output_schema,
        )

    # Trailing batch emitted after the input stream is exhausted.
    return [pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)]
668
+
669
+
670
def in_out_long_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> parameter_types.TableFunctionInOutGenerator:
    """
    In/out handler that returns every input chunk ten times.

    The first nine copies are yielded immediately with the continuation flag
    False; the tenth is held back and delivered on the next exchange with
    True (ready for more input).
    """
    result = output_schema.empty_table()

    while True:
        input_chunk = yield (result, True)

        if input_chunk is None:
            break

        if isinstance(input_chunk, bool):
            raise NotImplementedError("Not expecting continuing output for input chunk.")

        # Return the input chunk ten times.
        multiplier = 10
        copied_results = [
            pa.RecordBatch.from_arrays(
                [
                    input_chunk.column(0),
                ],
                schema=output_schema,
            )
            # Fix: the loop variable was named `index` but never used.
            for _ in range(multiplier)
        ]

        for item in copied_results[:-1]:
            yield (item, False)
        result = copied_results[-1]

    return None
702
+
703
+
704
def in_out_huge_chunk_handler(
    parameters: parameter_types.TableFunctionParameters,
    output_schema: pa.Schema,
) -> parameter_types.TableFunctionInOutGenerator:
    """
    In/out handler answering every input chunk with `multiplier` large output
    chunks, then returning three more large "footer" chunks at end of stream.
    """
    result = output_schema.empty_table()
    multiplier = 10
    chunk_length = 5000

    while True:
        # Yield the previous result; True signals readiness for more input.
        input_chunk = yield (result, True)

        if input_chunk is None:
            break

        if isinstance(input_chunk, bool):
            raise NotImplementedError("Not expecting continuing output for input chunk.")

        # Fix: was `for index, _i in enumerate(range(multiplier))` — the
        # enumerate over range was redundant.
        for index in range(multiplier):
            output = pa.RecordBatch.from_arrays(
                [list(range(chunk_length)), [index] * chunk_length],
                schema=output_schema,
            )
            if index < multiplier - 1:
                # Emit immediately without requesting more input.
                yield (output, False)
            else:
                # Hold the final chunk back for the next exchange.
                result = output

    # test big chunks returned as the last results.
    return [
        pa.RecordBatch.from_arrays([list(range(chunk_length)), [footer_id] * chunk_length], schema=output_schema)
        for footer_id in (-1, -2, -3)
    ]
736
+
737
+
738
def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoint]:
    """
    Generate a list of FlightEndpoint objects for the NYC Yellow Taxi dataset.
    """
    files = [
        "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-01.parquet",
        "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-02.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-03.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-04.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-05.parquet",
    ]
    # The location encodes a msgpack read_parquet call so the client fetches
    # the parquet files itself rather than streaming them through this server.
    return [
        flight_handling.endpoint(
            ticket_data=ticket_data,
            locations=[
                flight_handling.dict_to_msgpack_duckdb_call_data_uri(
                    {
                        "function_name": "read_parquet",
                        # So arguments could be a record batch.
                        "data": flight_handling.serialize_arrow_ipc_table(
                            pa.Table.from_pylist(
                                [
                                    {
                                        "arg_0": files,
                                        "hive_partitioning": False,
                                        "union_by_name": True,
                                    }
                                ],
                            )
                        ),
                    }
                )
            ],
        )
    ]
773
+
774
+
775
# The "remote_data" schema: tables whose data lives at remote URLs. The local
# table version only carries the (empty) schema; the endpoint generator points
# clients at the real parquet files.
remote_data_schema = SchemaCollection(
    scalar_functions_by_name=CaseInsensitiveDict(),
    table_functions_by_name=CaseInsensitiveDict(),
    tables_by_name=CaseInsensitiveDict(
        {
            "nyc_yellow_taxi": TableInfo(
                table_versions=[
                    pa.schema(
                        [
                            pa.field("VendorID", pa.int32()),
                            pa.field("tpep_pickup_datetime", pa.timestamp("us")),
                            pa.field("tpep_dropoff_datetime", pa.timestamp("us")),
                            pa.field("passenger_count", pa.int64()),
                            pa.field("trip_distance", pa.float64()),
                            pa.field("RatecodeID", pa.int64()),
                            pa.field("store_and_fwd_flag", pa.string()),
                            pa.field("PULocationID", pa.int32()),
                            pa.field("DOLocationID", pa.int32()),
                            pa.field("payment_type", pa.int64()),
                            pa.field("fare_amount", pa.float64()),
                            pa.field("extra", pa.float64()),
                            pa.field("mta_tax", pa.float64()),
                            pa.field("tip_amount", pa.float64()),
                            pa.field("tolls_amount", pa.float64()),
                            pa.field("improvement_surcharge", pa.float64()),
                            pa.field("total_amount", pa.float64()),
                            pa.field("congestion_surcharge", pa.float64()),
                            pa.field("Airport_fee", pa.float64()),
                            pa.field("cbd_congestion_fee", pa.float64()),
                        ]
                    ).empty_table()
                ],
                row_id_counter=0,
                endpoint_generator=yellow_taxi_endpoint_generator,
            ),
        }
    ),
)
813
+
814
+
815
# The "static_data" schema: fixed in-memory example tables.
static_data_schema = SchemaCollection(
    scalar_functions_by_name=CaseInsensitiveDict(),
    table_functions_by_name=CaseInsensitiveDict(),
    tables_by_name=CaseInsensitiveDict(
        {
            # A single large int64 column, for exercising big result chunks.
            "big_chunk": TableInfo(
                table_versions=[
                    pa.Table.from_arrays(
                        [
                            list(range(100000)),
                        ],
                        schema=pa.schema([pa.field("id", pa.int64())]),
                    )
                ],
                row_id_counter=0,
            ),
            # A small table covering many Arrow types (timestamps, JSON, UUID,
            # decimals, nulls, booleans).
            "employees": TableInfo(
                table_versions=[
                    pa.Table.from_arrays(
                        [
                            ["Emily", "Amy"],
                            [30, 32],
                            [datetime.datetime(2023, 10, 1), datetime.datetime(2024, 10, 2)],
                            ["{}", "[1,2,3]"],
                            [
                                bytes.fromhex("b975e4187a6d4afdb1a41f7174ce1805"),
                                bytes.fromhex("7ef19ab7c7af4f0188c386fae862fd60"),
                            ],
                            [datetime.date(2023, 10, 1), datetime.date(2024, 10, 2)],
                            [True, False],
                            ["Ann", None],
                            [1234.123, 5678.123],
                            [Decimal("12345.678790"), Decimal("67890.123456")],
                        ],
                        schema=pa.schema(
                            [
                                pa.field("name", pa.string()),
                                pa.field("age", pa.int32()),
                                pa.field("start_date", pa.timestamp("ms")),
                                pa.field("json_data", pa.json_(pa.string())),
                                pa.field("id", pa.uuid()),
                                pa.field("birthdate", pa.date32()),
                                pa.field("is_active", pa.bool_()),
                                pa.field("nickname", pa.string()),
                                pa.field("salary", pa.float64()),
                                pa.field("balance", pa.decimal128(12, 6)),
                            ],
                            metadata={"can_produce_statistics": "1"},
                        ),
                    )
                ],
                row_id_counter=2,
            ),
        }
    ),
)
871
+
872
+
873
def collatz_step_count(n: int) -> int:
    """
    Return the number of Collatz steps needed to reach 1 from *n*.

    Raises:
        ValueError: if n < 1 — the original loop never terminated for
            non-positive input.
    """
    if n < 1:
        raise ValueError(f"collatz_step_count requires a positive integer, got {n}")
    steps = 0
    while n != 1:
        if n % 2 == 0:
            n //= 2
        else:
            n = 3 * n + 1
        steps += 1
    return steps
882
+
883
+
884
def collatz(inputs: pa.Array) -> pa.Array:
    """Vectorized wrapper: Collatz step count for each element of *inputs*."""
    step_counts = [collatz_step_count(value) for value in inputs.to_pylist()]
    return pa.array(step_counts, type=pa.int64())
887
+
888
+
889
def collatz_steps(n: int) -> list[int]:
    """
    Return the Collatz trajectory of *n*: every value after n itself, ending
    with 1 (empty list for n == 1).

    Raises:
        ValueError: if n < 1 — the original loop never terminated for
            non-positive input.
    """
    if n < 1:
        raise ValueError(f"collatz_steps requires a positive integer, got {n}")
    # Fix: the original also kept a `steps` counter that was never used.
    results: list[int] = []
    while n != 1:
        if n % 2 == 0:
            n //= 2
        else:
            n = 3 * n + 1
        results.append(n)
    return results
900
+
901
+
902
# The "utils" schema: the example scalar and table functions registered above,
# exercised by the test suite.
util_schema = SchemaCollection(
    scalar_functions_by_name=CaseInsensitiveDict(
        {
            "test_uppercase": ScalarFunction(
                input_schema=pa.schema([pa.field("a", pa.string())]),
                output_schema=pa.schema([pa.field("result", pa.string())]),
                handler=uppercase_handler,
            ),
            # "is_any_type" marks the parameter as accepting any DuckDB type.
            "test_any_type": ScalarFunction(
                input_schema=pa.schema([pa.field("a", pa.string(), metadata={"is_any_type": "1"})]),
                output_schema=pa.schema([pa.field("result", pa.string())]),
                handler=any_type_handler,
            ),
            "test_add": ScalarFunction(
                input_schema=pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
                output_schema=pa.schema([pa.field("result", pa.int64())]),
                handler=add_handler,
            ),
            "collatz": ScalarFunction(
                input_schema=pa.schema([pa.field("n", pa.int64())]),
                output_schema=pa.schema([pa.field("result", pa.int64())]),
                handler=lambda table: collatz(table.column(0)),
            ),
            "collatz_sequence": ScalarFunction(
                input_schema=pa.schema([pa.field("n", pa.int64())]),
                output_schema=pa.schema([pa.field("result", pa.list_(pa.int64()))]),
                handler=lambda table: pa.array(
                    [collatz_steps(n) for n in table.column(0).to_pylist()], type=pa.list_(pa.int64())
                ),
            ),
        }
    ),
    table_functions_by_name=CaseInsensitiveDict(
        {
            "test_echo": TableFunction(
                input_schema=pa.schema([pa.field("input", pa.string())]),
                output_schema_source=pa.schema([pa.field("result", pa.string())]),
                handler=echo_handler,
            ),
            "test_wide": TableFunction(
                input_schema=pa.schema([pa.field("count", pa.int32())]),
                output_schema_source=pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)]),
                handler=wide_handler,
            ),
            "test_long": TableFunction(
                input_schema=pa.schema([pa.field("input", pa.string())]),
                output_schema_source=pa.schema(
                    [
                        pa.field("result", pa.string()),
                        pa.field("result2", pa.string()),
                    ]
                ),
                handler=long_handler,
            ),
            "test_repeat": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field("input", pa.string()),
                        pa.field("count", pa.int32()),
                    ]
                ),
                output_schema_source=pa.schema([pa.field("result", pa.string())]),
                handler=repeat_handler,
            ),
            # Functions below use TableFunctionDynamicOutput: their output
            # schema is computed from the call's parameters and/or table input.
            "test_dynamic_schema": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field(
                            "input",
                            pa.string(),
                            metadata={"is_any_type": "1"},
                        )
                    ]
                ),
                output_schema_source=TableFunctionDynamicOutput(
                    schema_creator=dynamic_schema_handler_output_schema,
                    default_values=(
                        pa.RecordBatch.from_arrays(
                            [pa.array([1], type=pa.int32())],
                            schema=pa.schema([pa.field("input", pa.int32())]),
                        ),
                        None,
                    ),
                ),
                handler=dynamic_schema_handler,
            ),
            "test_dynamic_schema_named_parameters": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field("name", pa.string()),
                        pa.field(
                            "location",
                            pa.string(),
                            metadata={"is_named_parameter": "1"},
                        ),
                        pa.field(
                            "input",
                            pa.string(),
                            metadata={"is_any_type": "1"},
                        ),
                        pa.field("city", pa.string()),
                    ]
                ),
                output_schema_source=TableFunctionDynamicOutput(
                    schema_creator=dynamic_schema_handler_output_schema,
                    default_values=(
                        pa.RecordBatch.from_arrays(
                            [pa.array([1], type=pa.int32())],
                            schema=pa.schema([pa.field("input", pa.int32())]),
                        ),
                        None,
                    ),
                ),
                handler=dynamic_schema_handler,
            ),
            # "is_table_type" marks the parameter that receives the streamed
            # table input for in/out table functions.
            "test_table_in_out": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field("input", pa.string()),
                        pa.field(
                            "table_input",
                            pa.string(),
                            metadata={"is_table_type": "1"},
                        ),
                    ]
                ),
                output_schema_source=TableFunctionDynamicOutput(
                    schema_creator=in_out_schema_handler,
                    default_values=(
                        pa.RecordBatch.from_arrays(
                            [pa.array([1], type=pa.int32())],
                            schema=pa.schema([pa.field("input", pa.int32())]),
                        ),
                        pa.schema([pa.field("input", pa.int32())]),
                    ),
                ),
                handler=in_out_handler,
            ),
            "test_table_in_out_long": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field(
                            "table_input",
                            pa.string(),
                            metadata={"is_table_type": "1"},
                        ),
                    ]
                ),
                output_schema_source=TableFunctionDynamicOutput(
                    schema_creator=in_out_long_schema_handler,
                    default_values=(
                        pa.RecordBatch.from_arrays(
                            [pa.array([1], type=pa.int32())],
                            schema=pa.schema([pa.field("input", pa.int32())]),
                        ),
                        pa.schema([pa.field("input", pa.int32())]),
                    ),
                ),
                handler=in_out_long_handler,
            ),
            "test_table_in_out_huge": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field(
                            "table_input",
                            pa.string(),
                            metadata={"is_table_type": "1"},
                        ),
                    ]
                ),
                output_schema_source=pa.schema([("multiplier", pa.int64()), ("value", pa.int64())]),
                handler=in_out_huge_chunk_handler,
            ),
            "test_table_in_out_wide": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field("input", pa.string()),
                        pa.field(
                            "table_input",
                            pa.string(),
                            metadata={"is_table_type": "1"},
                        ),
                    ]
                ),
                output_schema_source=TableFunctionDynamicOutput(
                    schema_creator=in_out_wide_schema_handler,
                    default_values=(
                        pa.RecordBatch.from_arrays(
                            [pa.array([1], type=pa.int32())],
                            schema=pa.schema([pa.field("input", pa.int32())]),
                        ),
                        pa.schema([pa.field("input", pa.int32())]),
                    ),
                ),
                handler=in_out_wide_handler,
            ),
            "test_table_in_out_echo": TableFunction(
                input_schema=pa.schema(
                    [
                        pa.field(
                            "table_input",
                            pa.string(),
                            metadata={"is_table_type": "1"},
                        ),
                    ]
                ),
                output_schema_source=TableFunctionDynamicOutput(
                    schema_creator=in_out_echo_schema_handler,
                    default_values=(
                        pa.RecordBatch.from_arrays(
                            [pa.array([1], type=pa.int32())],
                            schema=pa.schema([pa.field("input", pa.int32())]),
                        ),
                        pa.schema([pa.field("input", pa.int32())]),
                    ),
                ),
                handler=in_out_echo_handler,
            ),
        }
    ),
)