query-farm-airport-test-server 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1799 @@
1
+ import hashlib
2
+ import json
3
+ import re
4
+ from collections.abc import Callable, Generator, Iterator
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Literal, TypeVar, overload
7
+
8
+ import click
9
+ import duckdb
10
+ import msgpack
11
+ import pyarrow as pa
12
+ import pyarrow.compute as pc
13
+ import pyarrow.flight as flight
14
+ import query_farm_duckdb_json_serialization.expression
15
+ import query_farm_flight_server.auth as auth
16
+ import query_farm_flight_server.auth_manager as auth_manager
17
+ import query_farm_flight_server.auth_manager_naive as auth_manager_naive
18
+ import query_farm_flight_server.flight_handling as flight_handling
19
+ import query_farm_flight_server.flight_inventory as flight_inventory
20
+ import query_farm_flight_server.middleware as base_middleware
21
+ import query_farm_flight_server.parameter_types as parameter_types
22
+ import query_farm_flight_server.schema_uploader as schema_uploader
23
+ import query_farm_flight_server.server as base_server
24
+ import structlog
25
+ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
26
+
27
+ from .utils import CaseInsensitiveDict
28
+
29
+ log = structlog.get_logger()
30
+
31
+
32
+ def read_recordbatch(source: bytes) -> pa.RecordBatch:
33
+ """
34
+ Read a record batch from a byte string.
35
+ """
36
+ buffer = pa.BufferReader(source)
37
+ ipc_stream = pa.ipc.open_stream(buffer)
38
+ return next(ipc_stream)
39
+
40
+
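+ # Illustrative round trip for read_recordbatch, a minimal sketch using only
+ # pyarrow's IPC stream API (the sample batch is hypothetical):
+ #
+ #   batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
+ #   sink = pa.BufferOutputStream()
+ #   with pa.ipc.new_stream(sink, batch.schema) as writer:
+ #       writer.write_batch(batch)
+ #   assert read_recordbatch(sink.getvalue().to_pybytes()).equals(batch)
+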
41
+ def conform_nullable(schema: pa.Schema, table: pa.Table) -> pa.Table:
42
+ """
43
+ Conform the table to the nullable flags as defined in the schema.
44
+
45
+ Non-nullable columns must not contain null values; a FlightServerError is raised if they do.
46
+
47
+ This is needed because DuckDB doesn't send the nullable flag in the schema
48
+ it sends via the DoExchange call.
49
+ """
50
+ for idx, table_field in enumerate(schema):
51
+ if not table_field.nullable:
52
+ # Re-mark the column as non-nullable to match the declared schema.
53
+ new_field = table_field.with_nullable(False)
54
+
55
+ # Check for null values.
56
+ if table.column(idx).null_count > 0:
57
+ raise flight.FlightServerError(
58
+ f"Column {table_field.name} has null values, but the schema does not allow nulls."
59
+ )
60
+
61
+ table = table.set_column(idx, new_field, table.column(idx))
62
+ return table
63
+
64
+
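+ # A minimal sketch of what conform_nullable repairs (hypothetical data): a
+ # table built without nullability info regains the declared NOT NULL flag.
+ #
+ #   declared = pa.schema([pa.field("a", pa.int64(), nullable=False)])
+ #   received = pa.table({"a": pa.array([1, 2], type=pa.int64())})
+ #   assert not conform_nullable(declared, received).schema.field("a").nullable
+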
65
+ def check_schema_is_subset_of_schema(existing_schema: pa.Schema, new_schema: pa.Schema) -> None:
66
+ """
67
+ Check that the new schema is a subset of the existing schema.
68
+ """
69
+ existing_contents = {(field.name, field.type) for field in existing_schema}
70
+ new_contents = {(field.name, field.type) for field in new_schema}
71
+
72
+ unknown_fields = new_contents - existing_contents
73
+ if unknown_fields:
74
+ raise flight.FlightServerError(f"Unknown fields in insert: {unknown_fields}")
75
+ return
76
+
77
+
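+ # Illustrative behaviour (hypothetical schemas): matching or narrower schemas
+ # pass, while a schema introducing an unknown field raises.
+ #
+ #   existing = pa.schema([pa.field("a", pa.int64())])
+ #   widened = existing.append(pa.field("b", pa.string()))
+ #   check_schema_is_subset_of_schema(existing, existing)  # ok
+ #   check_schema_is_subset_of_schema(existing, widened)   # FlightServerError
+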
78
+ # class FlightTicketDataTableFunction(flight_handling.FlightTicketData):
79
+ # action_name: str
80
+ # schema_name: str
81
+ # parameters: bytes
82
+
83
+
84
+ # def model_selector(flight_name: str, src: bytes) -> flight_handling.FlightTicketData | FlightTicketDataTableFunction:
85
+ # return flight_handling.FlightTicketData.unpack(src)
86
+
87
+
88
+ @dataclass
89
+ class TableFunctionDynamicOutput:
90
+ # The method that will determine the output schema from the input parameters
91
+ schema_creator: Callable[[pa.RecordBatch, pa.Schema | None], pa.Schema]
92
+
93
+ # The default parameters for the function, if not called with any.
94
+ default_values: tuple[pa.RecordBatch, pa.Schema | None]
95
+
96
+
97
+ @dataclass
98
+ class TableFunction:
99
+ # The input schema for the function.
100
+ input_schema: pa.Schema
101
+
102
+ output_schema_source: pa.Schema | TableFunctionDynamicOutput
103
+
104
+ # The function to call to process a chunk of rows.
105
+ handler: Callable[
106
+ [parameter_types.TableFunctionParameters, pa.Schema],
107
+ Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch],
108
+ ]
109
+
110
+ estimated_rows: int | Callable[[parameter_types.TableFunctionFlightInfo], int] = -1
111
+
112
+ def output_schema(
113
+ self,
114
+ parameters: pa.RecordBatch | None = None,
115
+ input_schema: pa.Schema | None = None,
116
+ ) -> pa.Schema:
117
+ if isinstance(self.output_schema_source, pa.Schema):
118
+ return self.output_schema_source
119
+ if parameters is None:
120
+ return self.output_schema_source.schema_creator(*self.output_schema_source.default_values)
121
+ assert isinstance(parameters, pa.RecordBatch)
122
+ result = self.output_schema_source.schema_creator(parameters, input_schema)
123
+ return result
124
+
125
+ def flight_info(
126
+ self,
127
+ *,
128
+ name: str,
129
+ catalog_name: str,
130
+ schema_name: str,
131
+ parameters: parameter_types.TableFunctionFlightInfo | None = None,
132
+ ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
133
+ """
134
+ It's often necessary to create a FlightInfo object;
135
+ standardize doing that here.
136
+ """
137
+ assert name != ""
138
+ assert catalog_name != ""
139
+ assert schema_name != ""
140
+
141
+ if isinstance(self.estimated_rows, int):
142
+ estimated_rows = self.estimated_rows
143
+ else:
144
+ assert parameters is not None
145
+ estimated_rows = self.estimated_rows(parameters)
146
+
147
+ metadata = flight_inventory.FlightSchemaMetadata(
148
+ type="table_function",
149
+ catalog=catalog_name,
150
+ schema=schema_name,
151
+ name=name,
152
+ comment=None,
153
+ input_schema=self.input_schema,
154
+ )
155
+ flight_info = flight.FlightInfo(
156
+ self.output_schema(parameters.parameters, parameters.table_input_schema)
157
+ if parameters
158
+ else self.output_schema(),
159
+ # This will always be the same descriptor, so that we can use the action
160
+ # name to determine which table function to execute.
161
+ descriptor_pack_(catalog_name, schema_name, "table_function", name),
162
+ [],
163
+ estimated_rows,
164
+ -1,
165
+ app_metadata=metadata.serialize(),
166
+ )
167
+ return (flight_info, metadata)
168
+
169
+
170
+ @dataclass
171
+ class ScalarFunction:
172
+ # The input schema for the function.
173
+ input_schema: pa.Schema
174
+ # The output schema for the function, should only have a single column.
175
+ output_schema: pa.Schema
176
+
177
+ # The function to call to process a chunk of rows.
178
+ handler: Callable[[pa.Table], pa.Array]
179
+
180
+ def flight_info(
181
+ self, *, name: str, catalog_name: str, schema_name: str
182
+ ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
183
+ """
184
+ It's often necessary to create a FlightInfo object;
185
+ standardize doing that here.
186
+ """
187
+ metadata = flight_inventory.FlightSchemaMetadata(
188
+ type="scalar_function",
189
+ catalog=catalog_name,
190
+ schema=schema_name,
191
+ name=name,
192
+ comment=None,
193
+ input_schema=self.input_schema,
194
+ )
195
+ flight_info = flight.FlightInfo(
196
+ self.output_schema,
197
+ descriptor_pack_(catalog_name, schema_name, "scalar_function", name),
198
+ [],
199
+ -1,
200
+ -1,
201
+ app_metadata=metadata.serialize(),
202
+ )
203
+ return (flight_info, metadata)
204
+
205
+
206
+ @dataclass
207
+ class TableInfo:
208
+ # To enable version history, keep every version of the table.
209
+ table_versions: list[pa.Table] = field(default_factory=list)
210
+
211
+ # the next row id to assign.
212
+ row_id_counter: int = 0
213
+
214
+ def update_table(self, table: pa.Table) -> None:
215
+ assert table is not None
216
+ assert isinstance(table, pa.Table)
217
+ self.table_versions.append(table)
218
+
219
+ def version(self, version: int | None = None) -> pa.Table:
220
+ """
221
+ Get the version of the table.
222
+ """
223
+ assert len(self.table_versions) > 0
224
+ if version is None:
225
+ return self.table_versions[-1]
226
+
227
+ assert version < len(self.table_versions)
228
+ return self.table_versions[version]
229
+
230
+ def flight_info(
231
+ self,
232
+ *,
233
+ name: str,
234
+ catalog_name: str,
235
+ schema_name: str,
236
+ version: int | None = None,
237
+ ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
238
+ """
239
+ It's often necessary to create a FlightInfo object for the table;
240
+ standardize doing that here.
241
+ """
242
+ metadata = flight_inventory.FlightSchemaMetadata(
243
+ type="table",
244
+ catalog=catalog_name,
245
+ schema=schema_name,
246
+ name=name,
247
+ comment=None,
248
+ )
249
+ flight_info = flight.FlightInfo(
250
+ self.version(version).schema,
251
+ descriptor_pack_(catalog_name, schema_name, "table", name),
252
+ [],
253
+ -1,
254
+ -1,
255
+ app_metadata=metadata.serialize(),
256
+ )
257
+ return (flight_info, metadata)
258
+
259
+
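+ # A minimal sketch of the version history (hypothetical table contents):
+ #
+ #   info = TableInfo()
+ #   info.update_table(pa.table({"a": [1]}))
+ #   info.update_table(pa.table({"a": [1, 2]}))
+ #   assert info.version(0).num_rows == 1  # first recorded version
+ #   assert info.version().num_rows == 2   # latest version
+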
260
+ ObjectTypeName = Literal["table", "scalar_function", "table_function"]
261
+
262
+
263
+ @dataclass
264
+ class SchemaCollection:
265
+ tables_by_name: CaseInsensitiveDict[TableInfo] = field(default_factory=CaseInsensitiveDict[TableInfo])
266
+ scalar_functions_by_name: CaseInsensitiveDict[ScalarFunction] = field(
267
+ default_factory=CaseInsensitiveDict[ScalarFunction]
268
+ )
269
+ table_functions_by_name: CaseInsensitiveDict[TableFunction] = field(
270
+ default_factory=CaseInsensitiveDict[TableFunction]
271
+ )
272
+
273
+ def containers(
274
+ self,
275
+ ) -> list[
276
+ CaseInsensitiveDict[TableInfo] | CaseInsensitiveDict[ScalarFunction] | CaseInsensitiveDict[TableFunction]
277
+ ]:
278
+ return [
279
+ self.tables_by_name,
280
+ self.scalar_functions_by_name,
281
+ self.table_functions_by_name,
282
+ ]
283
+
284
+ @overload
285
+ def by_name(self, type: Literal["table"], name: str) -> TableInfo: ...
286
+
287
+ @overload
288
+ def by_name(self, type: Literal["scalar_function"], name: str) -> ScalarFunction: ...
289
+
290
+ @overload
291
+ def by_name(self, type: Literal["table_function"], name: str) -> TableFunction: ...
292
+
293
+ def by_name(self, type: ObjectTypeName, name: str) -> TableInfo | ScalarFunction | TableFunction:
294
+ assert name is not None
295
+ assert name != ""
296
+ if type == "table":
297
+ table = self.tables_by_name.get(name)
298
+ if not table:
299
+ raise flight.FlightServerError(f"Table {name} does not exist.")
300
+ return table
301
+ elif type == "scalar_function":
302
+ scalar_function = self.scalar_functions_by_name.get(name)
303
+ if not scalar_function:
304
+ raise flight.FlightServerError(f"Scalar function {name} does not exist.")
305
+ return scalar_function
306
+ elif type == "table_function":
307
+ table_function = self.table_functions_by_name.get(name)
308
+ if not table_function:
309
+ raise flight.FlightServerError(f"Table function {name} does not exist.")
310
+ return table_function
311
+
312
+
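+ # Illustrative lookup (hypothetical names): the overloads above let callers
+ # get a precisely typed object back for each literal type string.
+ #
+ #   schemas = SchemaCollection()
+ #   schemas.tables_by_name["events"] = TableInfo()
+ #   events = schemas.by_name("table", "events")    # typed as TableInfo
+ #   schemas.by_name("scalar_function", "missing")  # FlightServerError
+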
313
+ @dataclass
314
+ class DatabaseContents:
315
+ # Collection of schemas by name.
316
+ schemas_by_name: CaseInsensitiveDict[SchemaCollection] = field(
317
+ default_factory=CaseInsensitiveDict[SchemaCollection]
318
+ )
319
+
320
+ # The version of the database, updated on each schema change.
321
+ version: int = 1
322
+
323
+ def by_name(self, name: str) -> SchemaCollection:
324
+ if name not in self.schemas_by_name:
325
+ raise flight.FlightServerError(f"Schema {name} does not exist.")
326
+ return self.schemas_by_name[name]
327
+
328
+
329
+ @dataclass
330
+ class DescriptorParts:
331
+ """
332
+ The fields that are encoded in the flight descriptor.
333
+ """
334
+
335
+ catalog_name: str
336
+ schema_name: str
337
+ type: ObjectTypeName
338
+ name: str
339
+
340
+
341
+ @dataclass
342
+ class DatabaseLibrary:
343
+ """
344
+ The database library for a single account token, containing all of its databases organized by name.
345
+ """
346
+
347
+ # Collection of databases by name.
348
+ databases_by_name: CaseInsensitiveDict[DatabaseContents] = field(
349
+ default_factory=CaseInsensitiveDict[DatabaseContents]
350
+ )
351
+
352
+ def by_name(self, name: str) -> DatabaseContents:
353
+ if name not in self.databases_by_name:
354
+ raise flight.FlightServerError(f"Database {name} does not exist.")
355
+ return self.databases_by_name[name]
356
+
357
+
358
+ def descriptor_pack_(
359
+ catalog_name: str,
360
+ schema_name: str,
361
+ type: ObjectTypeName,
362
+ name: str,
363
+ ) -> flight.FlightDescriptor:
364
+ """
365
+ Pack the descriptor into a FlightDescriptor.
366
+ """
367
+ return flight.FlightDescriptor.for_path(f"{catalog_name}/{schema_name}/{type}/{name}")
368
+
369
+
370
+ def descriptor_unpack_(descriptor: flight.FlightDescriptor) -> DescriptorParts:
371
+ """
372
+ Split the descriptor into its components.
373
+ """
374
+ assert descriptor.descriptor_type == flight.DescriptorType.PATH
375
+ assert len(descriptor.path) == 1
376
+ path = descriptor.path[0].decode("utf-8")
377
+ parts = path.split("/")
378
+ if len(parts) != 4:
379
+ raise flight.FlightServerError(f"Invalid descriptor path: {path}")
380
+
381
+ descriptor_type: ObjectTypeName
382
+ if parts[2] == "table":
383
+ descriptor_type = "table"
384
+ elif parts[2] == "scalar_function":
385
+ descriptor_type = "scalar_function"
386
+ elif parts[2] == "table_function":
387
+ descriptor_type = "table_function"
388
+ else:
389
+ raise flight.FlightServerError(f"Invalid descriptor type: {parts[2]}")
390
+
391
+ return DescriptorParts(
392
+ catalog_name=parts[0],
393
+ schema_name=parts[1],
394
+ type=descriptor_type,
395
+ name=parts[3],
396
+ )
397
+
398
+
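+ # Illustrative round trip through the path-encoded descriptor (hypothetical
+ # names; note that names containing "/" would break this encoding):
+ #
+ #   desc = descriptor_pack_("db", "main", "table", "events")
+ #   parts = descriptor_unpack_(desc)
+ #   assert (parts.catalog_name, parts.schema_name, parts.type, parts.name) \
+ #       == ("db", "main", "table", "events")
+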
399
+ class FlightTicketData(BaseModel):
400
+ model_config = ConfigDict(arbitrary_types_allowed=True) # for Pydantic v2
401
+ descriptor: flight.FlightDescriptor
402
+
403
+ where_clause: str | None = None
404
+
405
+ # These are the parameters for the table returning function.
406
+ table_function_parameters: pa.RecordBatch | None = None
407
+ table_function_input_schema: pa.Schema | None = None
408
+
409
+ at_unit: str | None = None
410
+ at_value: str | None = None
411
+
412
+ _validate_table_function_parameters = field_validator("table_function_parameters", mode="before")(
413
+ parameter_types.deserialize_record_batch_or_none
414
+ )
415
+
416
+ _validate_table_function_input_schema = field_validator("table_function_input_schema", mode="before")(
417
+ parameter_types.deserialize_schema_or_none
418
+ )
419
+
420
+ @field_serializer("table_function_parameters")
421
+ def serialize_table_function_parameters(self, value: pa.RecordBatch, info: Any) -> bytes | None:
422
+ return parameter_types.serialize_record_batch(value, info)
423
+
424
+ @field_serializer("table_function_input_schema")
425
+ def serialize_table_function_input_schema(self, value: pa.Schema, info: Any) -> bytes | None:
426
+ return parameter_types.serialize_schema(value, info)
427
+
428
+ _validate_flight_descriptor = field_validator("descriptor", mode="before")(
429
+ parameter_types.deserialize_flight_descriptor
430
+ )
431
+
432
+ @field_serializer("descriptor")
433
+ def serialize_flight_descriptor(self, value: flight.FlightDescriptor, info: Any) -> bytes:
434
+ return parameter_types.serialize_flight_descriptor(value, info)
435
+
436
+
437
+ T = TypeVar("T", bound=BaseModel)
438
+
439
+
440
+ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth.AccountToken]):
441
+ def __init__(
442
+ self,
443
+ *,
444
+ location: str | None,
445
+ auth_manager: auth_manager.AuthManager[auth.Account, auth.AccountToken],
446
+ **kwargs: dict[str, Any],
447
+ ) -> None:
448
+ self.service_name = "test_server"
449
+ self._auth_manager = auth_manager
450
+ super().__init__(location=location, **kwargs)
451
+
452
+ # Contents nesting: token -> database name -> schema name -> object name.
453
+ self.contents: dict[str, DatabaseLibrary] = {}
454
+
455
+ self.ROWID_FIELD_NAME = "rowid"
456
+ self.rowid_field = pa.field(self.ROWID_FIELD_NAME, pa.int64(), metadata={"is_rowid": "1"})
457
+
458
+ def action_endpoints(
459
+ self,
460
+ *,
461
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
462
+ parameters: parameter_types.Endpoints,
463
+ ) -> list[flight.FlightEndpoint]:
464
+ assert context.caller is not None
465
+
466
+ descriptor_parts = descriptor_unpack_(parameters.descriptor)
467
+ library = self.contents[context.caller.token.token]
468
+ database = library.by_name(descriptor_parts.catalog_name)
469
+ schema = database.by_name(descriptor_parts.schema_name)
470
+
471
+ filter_sql_where_clause: str | None = None
472
+ if parameters.parameters.json_filters is not None:
473
+ context.logger.debug("duckdb_input", input=json.dumps(parameters.parameters.json_filters.filters))
474
+ filter_sql_where_clause, filter_sql_field_type_info = (
475
+ query_farm_duckdb_json_serialization.expression.convert_to_sql(
476
+ source=parameters.parameters.json_filters.filters,
477
+ bound_column_names=parameters.parameters.json_filters.column_binding_names_by_index,
478
+ )
479
+ )
480
+ if filter_sql_where_clause == "":
481
+ filter_sql_where_clause = None
482
+
483
+ if descriptor_parts.type == "table":
484
+ schema.by_name("table", descriptor_parts.name)
485
+
486
+ ticket_data = FlightTicketData(
487
+ descriptor=parameters.descriptor,
488
+ where_clause=filter_sql_where_clause,
489
+ at_unit=parameters.parameters.at_unit,
490
+ at_value=parameters.parameters.at_value,
491
+ )
492
+
493
+ return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
494
+ elif descriptor_parts.type == "table_function":
495
+ # The table function may not exist, because it's a dynamic descriptor.
496
+
497
+ schema.by_name("table_function", descriptor_parts.name)
498
+
499
+ ticket_data = FlightTicketData(
500
+ descriptor=parameters.descriptor,
501
+ where_clause=filter_sql_where_clause,
502
+ table_function_parameters=parameters.parameters.table_function_parameters,
503
+ table_function_input_schema=parameters.parameters.table_function_input_schema,
504
+ at_unit=parameters.parameters.at_unit,
505
+ at_value=parameters.parameters.at_value,
506
+ )
507
+ return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
508
+ else:
509
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
510
+
511
+ def action_list_schemas(
512
+ self,
513
+ *,
514
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
515
+ parameters: parameter_types.ListSchemas,
516
+ ) -> base_server.AirportSerializedCatalogRoot:
517
+ assert context.caller is not None
518
+
519
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
520
+ library = self.contents[context.caller.token.token]
521
+ database = library.by_name(parameters.catalog_name)
522
+
523
+ dynamic_inventory: dict[str, dict[str, list[flight_inventory.FlightInventoryWithMetadata]]] = {}
524
+
525
+ catalog_contents = dynamic_inventory.setdefault(parameters.catalog_name, {})
526
+
527
+ for schema_name, schema in database.schemas_by_name.items():
528
+ schema_contents = catalog_contents.setdefault(schema_name, [])
529
+ for coll in schema.containers():
530
+ for name, obj in coll.items():
531
+ schema_contents.append(
532
+ obj.flight_info(
533
+ name=name,
534
+ catalog_name=parameters.catalog_name,
535
+ schema_name=schema_name,
536
+ )
537
+ )
538
+
539
+ return flight_inventory.upload_and_generate_schema_list(
540
+ upload_parameters=flight_inventory.UploadParameters(
541
+ s3_client=None,
542
+ base_url="http://localhost",
543
+ bucket_name="test_bucket",
544
+ bucket_prefix="test_prefix",
545
+ ),
546
+ flight_service_name=self.service_name,
547
+ flight_inventory=dynamic_inventory,
548
+ schema_details={},
549
+ skip_upload=True,
550
+ serialize_inline=True,
551
+ catalog_version=1,
552
+ catalog_version_fixed=False,
553
+ )
554
+
555
+ def impl_list_flights(
556
+ self,
557
+ *,
558
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
559
+ criteria: bytes,
560
+ ) -> Iterator[flight.FlightInfo]:
561
+ assert context.caller is not None
562
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
563
+ library = self.contents[context.caller.token.token]
564
+
565
+ def yield_flight_infos() -> Generator[flight.FlightInfo, None, None]:
566
+ for db_name, db in library.databases_by_name.items():
567
+ for schema_name, schema in db.schemas_by_name.items():
568
+ for coll in schema.containers():
569
+ for name, obj in coll.items():
570
+ yield obj.flight_info(name=name, catalog_name=db_name, schema_name=schema_name)[0]
571
+
572
+ return yield_flight_infos()
573
+
574
+ def impl_get_flight_info(
575
+ self,
576
+ *,
577
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
578
+ descriptor: flight.FlightDescriptor,
579
+ ) -> flight.FlightInfo:
580
+ assert context.caller is not None
581
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
582
+
583
+ descriptor_parts = descriptor_unpack_(descriptor)
584
+ library = self.contents[context.caller.token.token]
585
+ database = library.by_name(descriptor_parts.catalog_name)
586
+ schema = database.by_name(descriptor_parts.schema_name)
587
+
588
+ obj = schema.by_name(descriptor_parts.type, descriptor_parts.name)
589
+ return obj.flight_info(
590
+ name=descriptor_parts.name,
591
+ catalog_name=descriptor_parts.catalog_name,
592
+ schema_name=descriptor_parts.schema_name,
593
+ )[0]
594
+
595
+ def action_catalog_version(
596
+ self,
597
+ *,
598
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
599
+ parameters: parameter_types.CatalogVersion,
600
+ ) -> base_server.GetCatalogVersionResult:
601
+ assert context.caller is not None
602
+
603
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
604
+ library = self.contents[context.caller.token.token]
605
+ database = library.by_name(parameters.catalog_name)
606
+
607
+ context.logger.debug(
608
+ "catalog_version_result",
609
+ catalog_name=parameters.catalog_name,
610
+ version=database.version,
611
+ )
612
+ return base_server.GetCatalogVersionResult(catalog_version=database.version, is_fixed=False)
613
+
614
+ def action_create_transaction(
615
+ self,
616
+ *,
617
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
618
+ parameters: parameter_types.CreateTransaction,
619
+ ) -> base_server.CreateTransactionResult:
620
+ return base_server.CreateTransactionResult(identifier=None)
621
+
622
+ def action_create_schema(
623
+ self,
624
+ *,
625
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
626
+ parameters: parameter_types.CreateSchema,
627
+ ) -> base_server.AirportSerializedContentsWithSHA256Hash:
628
+ assert context.caller is not None
629
+
630
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
631
+ library = self.contents[context.caller.token.token]
632
+ database = library.by_name(parameters.catalog_name)
633
+
634
+ if database.schemas_by_name.get(parameters.schema_name) is not None:
635
+ raise flight.FlightServerError(f"Schema {parameters.schema_name} already exists")
636
+
637
+ database.schemas_by_name[parameters.schema_name] = SchemaCollection()
638
+ database.version += 1
639
+
640
+ # FIXME: this needs to be handled better on the server side...
641
+ # rather than calling into internal methods.
642
+ packed_data = msgpack.packb([])
643
+ assert packed_data
644
+ compressed_data = schema_uploader._compress_and_prefix_with_length(packed_data, compression_level=3)
645
+
646
+ empty_hash = hashlib.sha256(compressed_data).hexdigest()
647
+ return base_server.AirportSerializedContentsWithSHA256Hash(
648
+ url=None, sha256=empty_hash, serialized=compressed_data
649
+ )
650
+
651
+ def action_drop_table(
652
+ self,
653
+ *,
654
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
655
+ parameters: parameter_types.DropObject,
656
+ ) -> None:
657
+ assert context.caller is not None
658
+
659
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
660
+ library = self.contents[context.caller.token.token]
661
+ database = library.by_name(parameters.catalog_name)
662
+ schema = database.by_name(parameters.schema_name)
663
+
664
+ schema.by_name("table", parameters.name)
665
+
666
+ del schema.tables_by_name[parameters.name]
667
+ database.version += 1
668
+
669
+ def action_drop_schema(
670
+ self,
671
+ *,
672
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
673
+ parameters: parameter_types.DropObject,
674
+ ) -> None:
675
+ assert context.caller is not None
676
+
677
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
678
+ library = self.contents[context.caller.token.token]
679
+ database = library.by_name(parameters.catalog_name)
680
+
681
+ if database.schemas_by_name.get(parameters.name) is None:
682
+ raise flight.FlightServerError(f"Schema '{parameters.name}' does not exist")
683
+
684
+ del database.schemas_by_name[parameters.name]
685
+ database.version += 1
686
+
687
+ def action_create_table(
688
+ self,
689
+ *,
690
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
691
+ parameters: parameter_types.CreateTable,
692
+ ) -> flight.FlightInfo:
693
+ assert context.caller is not None
694
+
695
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
696
+ library = self.contents[context.caller.token.token]
697
+ database = library.by_name(parameters.catalog_name)
698
+ schema = database.by_name(parameters.schema_name)
699
+
700
+ if parameters.table_name in schema.tables_by_name:
701
+ raise flight.FlightServerError(
702
+ f"Table {parameters.table_name} already exists for token {context.caller.token}"
703
+ )
704
+
705
+ # FIXME: may want to add a row_id column that is not visible to the user, so that inserts and
706
+ # deletes can be tested.
707
+
708
+ assert "_rowid" not in parameters.arrow_schema.names
709
+
710
+ schema_with_row_id = parameters.arrow_schema.append(self.rowid_field)
711
+
712
+ table_info = TableInfo([schema_with_row_id.empty_table()], 0)
713
+
714
+ schema.tables_by_name[parameters.table_name] = table_info
715
+
716
+ database.version += 1
717
+
718
+ return table_info.flight_info(
719
+ name=parameters.table_name,
720
+ catalog_name=parameters.catalog_name,
721
+ schema_name=parameters.schema_name,
722
+ )[0]
723
+
724
+ def impl_do_action(
725
+ self,
726
+ *,
727
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
728
+ action: flight.Action,
729
+ ) -> Iterator[bytes]:
730
+ assert context.caller is not None
731
+
732
+ if action.type == "reset":
733
+ context.logger.debug("Resetting server state")
734
+ if context.caller.token.token in self.contents:
735
+ del self.contents[context.caller.token.token]
736
+ return iter([])
737
+ elif action.type == "create_database":
738
+ database_name = action.body.to_pybytes().decode("utf-8")
739
+ context.logger.debug("Creating database", database_name=database_name)
740
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
741
+ library = self.contents[context.caller.token.token]
742
+ library.databases_by_name[database_name] = DatabaseContents()
743
+
744
+ # Since we are creating a new database, let's load it with a few example
745
+ # scalar and table functions.
746
+
747
+ def add_handler(table: pa.Table) -> pa.Array:
748
+ assert table.num_columns == 2
749
+ return pc.add(table.column(0), table.column(1))
750
+
751
+ def uppercase_handler(table: pa.Table) -> pa.Array:
752
+ assert table.num_columns == 1
753
+ return pc.utf8_upper(table.column(0))
754
+
755
+ def any_type_handler(table: pa.Table) -> pa.Array:
756
+ return table.column(0)
757
+
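+ # Illustrative call (hypothetical data): a scalar handler maps the columns
+ # of an input table to a single output array, e.g. for add_handler above:
+ #
+ #   t = pa.table({"a": [1, 2], "b": [10, 20]})
+ #   assert add_handler(t).to_pylist() == [11, 22]
+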
758
+ def echo_handler(
759
+ parameters: parameter_types.TableFunctionParameters,
760
+ output_schema: pa.Schema,
761
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
762
+ # Just echo the parameters back as a single row.
763
+ assert parameters.parameters
764
+ yield pa.RecordBatch.from_arrays(
765
+ [parameters.parameters.column(0)],
766
+ schema=pa.schema([pa.field("result", pa.string())]),
767
+ )
768
+
769
+ def long_handler(
770
+ parameters: parameter_types.TableFunctionParameters,
771
+ output_schema: pa.Schema,
772
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
773
+ assert parameters.parameters
774
+ for i in range(100):
775
+ yield pa.RecordBatch.from_arrays([[f"{i}"] * 3000] * len(output_schema), schema=output_schema)
776
+
777
+ def repeat_handler(
778
+ parameters: parameter_types.TableFunctionParameters,
779
+ output_schema: pa.Schema,
780
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
781
+ # Repeat the input value `count` times, one single-row batch per repetition.
782
+ assert parameters.parameters
783
+ for _i in range(parameters.parameters.column(1).to_pylist()[0]):
784
+ yield pa.RecordBatch.from_arrays(
785
+ [parameters.parameters.column(0)],
786
+ schema=output_schema,
787
+ )
788
+
789
+ def wide_handler(
790
+ parameters: parameter_types.TableFunctionParameters,
791
+ output_schema: pa.Schema,
792
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
793
+ # Build `count` identical wide rows and emit them as one batch.
794
+ assert parameters.parameters
795
+ rows = []
796
+ for i in range(parameters.parameters.column(0).to_pylist()[0]):
797
+ rows.append({f"result_{idx}": idx for idx in range(20)})
798
+
799
+ yield pa.RecordBatch.from_pylist(rows, schema=output_schema)
800
+
801
+ def dynamic_schema_handler(
802
+ parameters: parameter_types.TableFunctionParameters,
803
+ output_schema: pa.Schema,
804
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
805
+ yield parameters.parameters
806
+
807
+ def dynamic_schema_handler_output_schema(
808
+ parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
809
+ ) -> pa.Schema:
810
+ # This is the schema that will be returned to the client.
811
+ # It will be used to create the table function.
812
+ assert isinstance(parameters, pa.RecordBatch)
813
+ return parameters.schema
814
+
815
+ def in_out_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
816
+ assert input_schema is not None
817
+ return pa.schema([parameters.schema.field(0), input_schema.field(0)])
818
+
819
+ def in_out_wide_schema_handler(
820
+ parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
821
+ ) -> pa.Schema:
822
+ assert input_schema is not None
823
+ return pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)])
824
+
825
+ def in_out_echo_schema_handler(
826
+ parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
827
+ ) -> pa.Schema:
828
+ assert input_schema is not None
829
+ return input_schema
830
+
831
+ def in_out_echo_handler(
832
+ parameters: parameter_types.TableFunctionParameters,
833
+ output_schema: pa.Schema,
834
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
835
+ result = output_schema.empty_table()
836
+
837
+ while True:
838
+ input_chunk = yield result
839
+
840
+ if input_chunk is None:
841
+ break
842
+
843
+ result = input_chunk
844
+
845
+ return
846
+
847
+ def in_out_wide_handler(
848
+ parameters: parameter_types.TableFunctionParameters,
849
+ output_schema: pa.Schema,
850
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
851
+ result = output_schema.empty_table()
852
+
853
+ while True:
854
+ input_chunk = yield result
855
+
856
+ if input_chunk is None:
857
+ break
858
+
859
+ result = pa.RecordBatch.from_arrays(
860
+ [[i] * len(input_chunk) for i in range(20)],
861
+ schema=output_schema,
862
+ )
863
+
864
+ return
865
+
866
+ def in_out_handler(
867
+ parameters: parameter_types.TableFunctionParameters,
868
+ output_schema: pa.Schema,
869
+ ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
870
+ result = output_schema.empty_table()
871
+
872
+ while True:
873
+ input_chunk = yield result
874
+
875
+ if input_chunk is None:
876
+ break
877
+
878
+ assert parameters.parameters
879
+ result = pa.RecordBatch.from_arrays(
880
+ [
881
+ parameters.parameters.column(0),
882
+ input_chunk.column(0),
883
+ ],
884
+ schema=output_schema,
885
+ )
886
+
887
+ return pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)
888
+
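+ # A minimal sketch of how the in/out handlers are driven (params and schema
+ # are hypothetical here): the server primes the generator, then send()s each
+ # input chunk and receives the result produced for the previous chunk;
+ # sending None ends the stream, and StopIteration.value may carry a final chunk.
+ #
+ #   gen = in_out_echo_handler(params, schema)
+ #   first = next(gen)        # prime; yields an initial empty result
+ #   out = gen.send(chunk)    # echoes the chunk back
+ #   try:
+ #       gen.send(None)
+ #   except StopIteration as stop:
+ #       final_chunk = stop.value  # None for the echo handler
+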
889
+ util_schema = SchemaCollection(
890
+ scalar_functions_by_name=CaseInsensitiveDict(
891
+ {
892
+ "test_uppercase": ScalarFunction(
893
+ input_schema=pa.schema([pa.field("a", pa.string())]),
894
+ output_schema=pa.schema([pa.field("result", pa.string())]),
895
+ handler=uppercase_handler,
896
+ ),
897
+ "test_any_type": ScalarFunction(
898
+ input_schema=pa.schema([pa.field("a", pa.string(), metadata={"is_any_type": "1"})]),
899
+ output_schema=pa.schema([pa.field("result", pa.string())]),
900
+ handler=any_type_handler,
901
+ ),
902
+ "test_add": ScalarFunction(
903
+ input_schema=pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
904
+ output_schema=pa.schema([pa.field("result", pa.int64())]),
905
+ handler=add_handler,
906
+ ),
907
+ }
908
+ ),
909
+ table_functions_by_name=CaseInsensitiveDict(
910
+ {
911
+ "test_echo": TableFunction(
912
+ input_schema=pa.schema([pa.field("input", pa.string())]),
913
+ output_schema_source=pa.schema([pa.field("result", pa.string())]),
914
+ handler=echo_handler,
915
+ ),
916
+ "test_wide": TableFunction(
917
+ input_schema=pa.schema([pa.field("count", pa.int32())]),
918
+ output_schema_source=pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)]),
919
+ handler=wide_handler,
920
+ ),
921
+ "test_long": TableFunction(
922
+ input_schema=pa.schema([pa.field("input", pa.string())]),
923
+ output_schema_source=pa.schema(
924
+ [
925
+ pa.field("result", pa.string()),
926
+ pa.field("result2", pa.string()),
927
+ ]
928
+ ),
929
+ handler=long_handler,
930
+ ),
931
+ "test_repeat": TableFunction(
932
+ input_schema=pa.schema(
933
+ [
934
+ pa.field("input", pa.string()),
935
+ pa.field("count", pa.int32()),
936
+ ]
937
+ ),
938
+ output_schema_source=pa.schema([pa.field("result", pa.string())]),
939
+ handler=repeat_handler,
940
+ ),
941
+ "test_dynamic_schema": TableFunction(
942
+ input_schema=pa.schema(
943
+ [
944
+ pa.field(
945
+ "input",
946
+ pa.string(),
947
+ metadata={"is_any_type": "1"},
948
+ )
949
+ ]
950
+ ),
951
+ output_schema_source=TableFunctionDynamicOutput(
952
+ schema_creator=dynamic_schema_handler_output_schema,
953
+ default_values=(
954
+ pa.RecordBatch.from_arrays(
955
+ [pa.array([1], type=pa.int32())],
956
+ schema=pa.schema([pa.field("input", pa.int32())]),
957
+ ),
958
+ None,
959
+ ),
960
+ ),
961
+ handler=dynamic_schema_handler,
962
+ ),
963
+ "test_dynamic_schema_named_parameters": TableFunction(
964
+ input_schema=pa.schema(
965
+ [
966
+ pa.field("name", pa.string()),
967
+ pa.field(
968
+ "location",
969
+ pa.string(),
970
+ metadata={"is_named_parameter": "1"},
971
+ ),
972
+ pa.field(
973
+ "input",
974
+ pa.string(),
975
+ metadata={"is_any_type": "1"},
976
+ ),
977
+ pa.field("city", pa.string()),
978
+ ]
979
+ ),
980
+ output_schema_source=TableFunctionDynamicOutput(
981
+ schema_creator=dynamic_schema_handler_output_schema,
982
+ default_values=(
983
+ pa.RecordBatch.from_arrays(
984
+ [pa.array([1], type=pa.int32())],
985
+ schema=pa.schema([pa.field("input", pa.int32())]),
986
+ ),
987
+ None,
988
+ ),
989
+ ),
990
+ handler=dynamic_schema_handler,
991
+ ),
992
+ "test_table_in_out": TableFunction(
993
+ input_schema=pa.schema(
994
+ [
995
+ pa.field("input", pa.string()),
996
+ pa.field(
997
+ "table_input",
998
+ pa.string(),
999
+ metadata={"is_table_type": "1"},
1000
+ ),
1001
+ ]
1002
+ ),
1003
+ output_schema_source=TableFunctionDynamicOutput(
1004
+ schema_creator=in_out_schema_handler,
1005
+ default_values=(
1006
+ pa.RecordBatch.from_arrays(
1007
+ [pa.array([1], type=pa.int32())],
1008
+ schema=pa.schema([pa.field("input", pa.int32())]),
1009
+ ),
1010
+ pa.schema([pa.field("input", pa.int32())]),
1011
+ ),
1012
+ ),
1013
+ handler=in_out_handler,
1014
+ ),
1015
+ "test_table_in_out_wide": TableFunction(
1016
+ input_schema=pa.schema(
1017
+ [
1018
+ pa.field("input", pa.string()),
1019
+ pa.field(
1020
+ "table_input",
1021
+ pa.string(),
1022
+ metadata={"is_table_type": "1"},
1023
+ ),
1024
+ ]
1025
+ ),
1026
+ output_schema_source=TableFunctionDynamicOutput(
1027
+ schema_creator=in_out_wide_schema_handler,
1028
+ default_values=(
1029
+ pa.RecordBatch.from_arrays(
1030
+ [pa.array([1], type=pa.int32())],
1031
+ schema=pa.schema([pa.field("input", pa.int32())]),
1032
+ ),
1033
+ pa.schema([pa.field("input", pa.int32())]),
1034
+ ),
1035
+ ),
1036
+ handler=in_out_wide_handler,
1037
+ ),
1038
+ "test_table_in_out_echo": TableFunction(
1039
+ input_schema=pa.schema(
1040
+ [
1041
+ pa.field(
1042
+ "table_input",
1043
+ pa.string(),
1044
+ metadata={"is_table_type": "1"},
1045
+ ),
1046
+ ]
1047
+ ),
1048
+ output_schema_source=TableFunctionDynamicOutput(
1049
+ schema_creator=in_out_echo_schema_handler,
1050
+ default_values=(
1051
+ pa.RecordBatch.from_arrays(
1052
+ [pa.array([1], type=pa.int32())],
1053
+ schema=pa.schema([pa.field("input", pa.int32())]),
1054
+ ),
1055
+ pa.schema([pa.field("input", pa.int32())]),
1056
+ ),
1057
+ ),
1058
+ handler=in_out_echo_handler,
1059
+ ),
1060
+ }
1061
+ ),
1062
+ )
1063
+
1064
+ library.databases_by_name[database_name].schemas_by_name["utils"] = util_schema
1065
+
1066
+ return iter([])
1067
+ elif action.type == "drop_database":
1068
+ database_name = action.body.to_pybytes().decode("utf-8")
1069
+ context.logger.debug("Dropping database", database_name=database_name)
1070
+
1071
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1072
+ library = self.contents[context.caller.token.token]
1073
+ if database_name not in library.databases_by_name:
1074
+ raise flight.FlightServerError(f"Database {database_name} does not exist")
1075
+ del library.databases_by_name[database_name]
1076
+ return iter([])
1077
+
1078
+ raise flight.FlightServerError(f"Unsupported action type: {action.type}")
1079
+
1080
+ def exchange_update(
1081
+ self,
1082
+ *,
1083
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1084
+ descriptor: flight.FlightDescriptor,
1085
+ reader: flight.MetadataRecordBatchReader,
1086
+ writer: flight.MetadataRecordBatchWriter,
1087
+ return_chunks: bool,
1088
+ ) -> int:
1089
+ assert context.caller is not None
1090
+
1091
+ descriptor_parts = descriptor_unpack_(descriptor)
1092
+
1093
+ if descriptor_parts.type != "table":
1094
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1095
+ library = self.contents[context.caller.token.token]
1096
+ database = library.by_name(descriptor_parts.catalog_name)
1097
+ schema = database.by_name(descriptor_parts.schema_name)
1098
+ table_info = schema.by_name("table", descriptor_parts.name)
1099
+
1100
+ existing_table = table_info.version()
1101
+
1102
+ writer.begin(existing_table.schema)
1103
+
1104
+ rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1105
+ assert rowid_index != -1
1106
+
1107
+ change_count = 0
1108
+
1109
+ for chunk in reader:
1110
+ if chunk.data is not None:
1111
+ chunk_table = pa.Table.from_batches([chunk.data])
1112
+ assert chunk_table.num_rows > 0
1113
+
1114
+ # This chunk contains the updated columns plus the row id column.
1115
+
1116
+ input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1117
+
1118
+ # To perform an update, first remove the rows from the table,
1119
+ # then assign new row ids to the incoming rows, and append them to the table.
1120
+ table_mask = pc.is_in(
1121
+ existing_table.column(rowid_index),
1122
+ value_set=chunk_table.column(input_rowid_index),
1123
+ )
1124
+
1125
+ # This is the table with updated rows removed, we'll be adding rows to it later on.
1126
+ table_without_updated_rows = pc.filter(existing_table, pc.invert(table_mask))
1127
+
1128
+ # These are the rows that are being updated, since the update may not send all
1129
+ # the columns, we need to filter the table to get the updated rows to persist the
1130
+ # values that aren't being updated.
1131
+ changed_rows = pc.filter(existing_table, table_mask)
1132
+
1133
+ # Get a list of all of the columns that are not being updated, so that a join
1134
+ # can be performed.
1135
+ unchanged_column_names = set(existing_table.schema.names) - set(chunk_table.schema.names)
1136
+
1137
+ joined_table = pa.Table.join(
1138
+ changed_rows.select(list(unchanged_column_names) + [self.ROWID_FIELD_NAME]),
1139
+ chunk_table,
1140
+ keys=[self.ROWID_FIELD_NAME],
1141
+ join_type="inner",
1142
+ )
1143
+
1144
+ # Add the new row id column.
1145
+ chunk_length = len(joined_table)
1146
+ rowid_values = list(range(table_info.row_id_counter, table_info.row_id_counter + chunk_length))
1153
+ updated_rows = joined_table.set_column(
1154
+ joined_table.schema.get_field_index(self.ROWID_FIELD_NAME),
1155
+ self.rowid_field,
1156
+ [rowid_values],
1157
+ )
1158
+ table_info.row_id_counter += chunk_length
+ change_count += chunk_length
1159
+
1160
+ # Now the columns may be in a different order, so we need to realign them.
1161
+ updated_rows = updated_rows.select(existing_table.schema.names)
1162
+
1163
+ check_schema_is_subset_of_schema(existing_table.schema, updated_rows.schema)
1164
+
1165
+ updated_rows = conform_nullable(existing_table.schema, updated_rows)
1166
+
1167
+ updated_table = pa.concat_tables(
1168
+ [
1169
+ table_without_updated_rows,
1170
+ updated_rows.select(table_without_updated_rows.schema.names),
1171
+ ]
1172
+ )
1173
+
1174
+ if return_chunks:
1175
+ writer.write_table(updated_rows)
1176
+
1177
+ existing_table = updated_table
1178
+
1179
+ table_info.update_table(existing_table)
1180
+
1181
+ return change_count
1182
+
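+ # The update above is delete-then-append; a standalone sketch of the masking
+ # step with hypothetical data (new rowids are assigned in the real path):
+ #
+ #   existing = pa.table({"v": [10, 20], "rowid": [0, 1]})
+ #   incoming = pa.table({"v": [99], "rowid": [0]})
+ #   mask = pc.is_in(existing.column("rowid"), value_set=incoming.column("rowid"))
+ #   kept = pc.filter(existing, pc.invert(mask))
+ #   result = pa.concat_tables([kept, incoming])  # 2 rows: v=20 and v=99
+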
1183
+ def exchange_delete(
1184
+ self,
1185
+ *,
1186
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1187
+ descriptor: flight.FlightDescriptor,
1188
+ reader: flight.MetadataRecordBatchReader,
1189
+ writer: flight.MetadataRecordBatchWriter,
1190
+ return_chunks: bool,
1191
+ ) -> int:
1192
+ assert context.caller is not None
1193
+
1194
+ descriptor_parts = descriptor_unpack_(descriptor)
1195
+
1196
+ if descriptor_parts.type != "table":
1197
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1198
+ library = self.contents[context.caller.token.token]
1199
+ database = library.by_name(descriptor_parts.catalog_name)
1200
+ schema = database.by_name(descriptor_parts.schema_name)
1201
+ table_info = schema.by_name("table", descriptor_parts.name)
1202
+
1203
+ existing_table = table_info.version()
1204
+ writer.begin(existing_table.schema)
1205
+
1206
+ rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1207
+ assert rowid_index != -1
1208
+
1209
+ change_count = 0
1210
+
1211
+ for chunk in reader:
1212
+ if chunk.data is not None:
1213
+ chunk_table = pa.Table.from_batches([chunk.data])
1214
+ assert chunk_table.num_rows > 0
1215
+
1216
+ # Should only be getting the row id.
1217
+ assert chunk_table.num_columns == 1
1218
+ # The rowid field doesn't come back the same, since it's missing the
1219
+ # not-null flag and the metadata, so we can't compare the schemas directly.
1220
+ input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1221
+
1222
+ # Now do an anti-join to keep only the rows whose rowids are not being deleted.
1223
+ target_rowids = chunk_table.column(input_rowid_index)
1224
+ existing_row_ids = existing_table.column(rowid_index)
1225
+
1226
+ mask = pc.is_in(existing_row_ids, value_set=target_rowids)
1227
+
1228
+ target_rows = pc.filter(existing_table, mask)
1229
+ changed_table = pc.filter(existing_table, pc.invert(mask))
1230
+
1231
+ change_count += target_rows.num_rows
1232
+
1233
+ if return_chunks:
1234
+ writer.write_table(target_rows)
1235
+
1236
+ existing_table = changed_table
1237
+
1238
+ table_info.update_table(existing_table)
1239
+
1240
+ return change_count
1241
+
1242
+ def exchange_insert(
1243
+ self,
1244
+ *,
1245
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1246
+ descriptor: flight.FlightDescriptor,
1247
+ reader: flight.MetadataRecordBatchReader,
1248
+ writer: flight.MetadataRecordBatchWriter,
1249
+ return_chunks: bool,
1250
+ ) -> int:
1251
+ assert context.caller is not None
1252
+
1253
+ descriptor_parts = descriptor_unpack_(descriptor)
1254
+
1255
+ if descriptor_parts.type != "table":
1256
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1257
+ library = self.contents[context.caller.token.token]
1258
+ database = library.by_name(descriptor_parts.catalog_name)
1259
+ schema = database.by_name(descriptor_parts.schema_name)
1260
+ table_info = schema.by_name("table", descriptor_parts.name)
1261
+
1262
+ existing_table = table_info.version()
1263
+ writer.begin(existing_table.schema)
1264
+ change_count = 0
1265
+
1266
+ rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1267
+ assert rowid_index != -1
1268
+
1269
+ # Check that the data being read matches the table without the rowid column.
1270
+
1271
+ # DuckDB won't send field metadata when it sends us the schema that it uses
1272
+ # to perform an insert, so we need some way to adapt the schema we
1273
+ check_schema_is_subset_of_schema(existing_table.schema, reader.schema)
1274
+
1275
+ # FIXME: need to handle the case of rowids.
1276
+
1277
+ for chunk in reader:
1278
+ if chunk.data is not None:
1279
+ new_rows = pa.Table.from_batches([chunk.data])
1280
+ assert new_rows.num_rows > 0
1281
+
1282
+ # append the row id column to the new rows.
1283
+ chunk_length = new_rows.num_rows
1284
+ rowid_values = list(range(table_info.row_id_counter, table_info.row_id_counter + chunk_length))
1291
+ new_rows = new_rows.append_column(self.rowid_field, [rowid_values])
1292
+ table_info.row_id_counter += chunk_length
1293
+ change_count += chunk_length
1294
+
1295
+ if return_chunks:
1296
+ writer.write_table(new_rows)
1297
+
1298
+ # Since the table could have had columns added or removed, use .select
1299
+ # to ensure that the columns are aligned in the same order as the original table.
1300
+
1301
+ # So it turns out that DuckDB doesn't send the "not null" flag in the arrow schema.
1302
+ #
1303
+ # This means we can't concat the tables without those flags matching.
1304
+ # for field_name in existing_table.schema.names:
1305
+ # field = existing_table.schema.field(field_name)
1306
+ # if not field.nullable:
1307
+ # field_index = new_rows.schema.get_field_index(field_name)
1308
+ # new_rows = new_rows.set_column(
1309
+ # field_index, field.with_nullable(False), new_rows.column(field_index)
1310
+ # )
1311
+ new_rows = conform_nullable(existing_table.schema, new_rows)
1312
+
1313
+ table_info.update_table(
1314
+ pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
1315
+ )
1316
+
1317
+ return change_count
1318
+
1319
+ def exchange_table_function_in_out(
1320
+ self,
1321
+ *,
1322
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1323
+ descriptor: flight.FlightDescriptor,
1324
+ parameters: parameter_types.TableFunctionParameters,
1325
+ input_schema: pa.Schema,
1326
+ ) -> tuple[pa.Schema, Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch]]:
1327
+ assert context.caller is not None
1328
+
1329
+ descriptor_parts = descriptor_unpack_(descriptor)
1330
+
1331
+ if descriptor_parts.type != "table_function":
1332
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1333
+ library = self.contents[context.caller.token.token]
1334
+ database = library.by_name(descriptor_parts.catalog_name)
1335
+ schema = database.by_name(descriptor_parts.schema_name)
1336
+ table_info = schema.by_name("table_function", descriptor_parts.name)
1337
+
1338
+ output_schema = table_info.output_schema(parameters=parameters.parameters, input_schema=input_schema)
1339
+ gen = table_info.handler(parameters, output_schema)
1340
+ return (output_schema, gen)
1341
+
1342
+ def action_add_column(
1343
+ self,
1344
+ *,
1345
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1346
+ parameters: parameter_types.AddColumn,
1347
+ ) -> flight.FlightInfo:
1348
+ assert context.caller is not None
1349
+
1350
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1351
+ library = self.contents[context.caller.token.token]
1352
+ database = library.by_name(parameters.catalog)
1353
+ schema = database.by_name(parameters.schema_name)
1354
+
1355
+ table_info = schema.by_name("table", parameters.name)
1356
+
1357
+ assert len(parameters.column_schema.names) == 1
1358
+
1359
+ existing_table = table_info.version()
1360
+ # Don't allow duplicate column names.
1361
+ assert parameters.column_schema.field(0).name not in existing_table.schema.names
1362
+
1363
+ table_info.update_table(
1364
+ existing_table.append_column(
1365
+ parameters.column_schema.field(0).name,
1366
+ [
1367
+ pa.nulls(
1368
+ existing_table.num_rows,
1369
+ type=parameters.column_schema.field(0).type,
1370
+ )
1371
+ ],
1372
+ )
1373
+ )
1374
+ database.version += 1
1375
+
1376
+ return table_info.flight_info(
1377
+ name=parameters.name,
1378
+ catalog_name=parameters.catalog,
1379
+ schema_name=parameters.schema_name,
1380
+ )[0]
1381
+
1382
+ def action_remove_column(
1383
+ self,
1384
+ *,
1385
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1386
+ parameters: parameter_types.RemoveColumn,
1387
+ ) -> flight.FlightInfo:
1388
+ assert context.caller is not None
1389
+
1390
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1391
+ library = self.contents[context.caller.token.token]
1392
+ database = library.by_name(parameters.catalog)
1393
+ schema = database.by_name(parameters.schema_name)
1394
+
1395
+ table_info = schema.by_name("table", parameters.name)
1396
+
1397
+ table_info.update_table(table_info.version().drop(parameters.removed_column))
1398
+ database.version += 1
1399
+
1400
+ return table_info.flight_info(
1401
+ name=parameters.name,
1402
+ catalog_name=parameters.catalog,
1403
+ schema_name=parameters.schema_name,
1404
+ )[0]
1405
+
1406
+ def action_rename_column(
1407
+ self,
1408
+ *,
1409
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1410
+ parameters: parameter_types.RenameColumn,
1411
+ ) -> flight.FlightInfo:
1412
+ assert context.caller is not None
1413
+
1414
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1415
+ library = self.contents[context.caller.token.token]
1416
+ database = library.by_name(parameters.catalog)
1417
+ schema = database.by_name(parameters.schema_name)
1418
+
1419
+ table_info = schema.by_name("table", parameters.name)
1420
+
1421
+ table_info.update_table(table_info.version().rename_columns({parameters.old_name: parameters.new_name}))
1422
+ database.version += 1
1423
+
1424
+ return table_info.flight_info(
1425
+ name=parameters.name,
1426
+ catalog_name=parameters.catalog,
1427
+ schema_name=parameters.schema_name,
1428
+ )[0]
1429
+
1430
+ def action_rename_table(
1431
+ self,
1432
+ *,
1433
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1434
+ parameters: parameter_types.RenameTable,
1435
+ ) -> flight.FlightInfo:
1436
+ assert context.caller is not None
1437
+
1438
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1439
+ library = self.contents[context.caller.token.token]
1440
+ database = library.by_name(parameters.catalog)
1441
+ schema = database.by_name(parameters.schema_name)
1442
+
1443
+ table_info = schema.by_name("table", parameters.name)
1444
+
1445
+ schema.tables_by_name[parameters.new_table_name] = schema.tables_by_name.pop(parameters.name)
1446
+
1447
+ database.version += 1
1448
+
1449
+ return table_info.flight_info(
1450
+ name=parameters.new_table_name,
1451
+ catalog_name=parameters.catalog,
1452
+ schema_name=parameters.schema_name,
1453
+ )[0]
1454
+
1455
+ def action_set_default(
1456
+ self,
1457
+ *,
1458
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1459
+ parameters: parameter_types.SetDefault,
1460
+ ) -> flight.FlightInfo:
1461
+ assert context.caller is not None
1462
+
1463
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1464
+ library = self.contents[context.caller.token.token]
1465
+ database = library.by_name(parameters.catalog)
1466
+ schema = database.by_name(parameters.schema_name)
1467
+ table_info = schema.by_name("table", parameters.name)
1468
+
1469
+ # Defaults are set as metadata on a field.
1470
+
1471
+ t = table_info.version()
1472
+ field_index = t.schema.get_field_index(parameters.column_name)
1473
+ field = t.schema.field(parameters.column_name)
1474
+ new_metadata: dict[str, Any] = {}
1475
+ if field.metadata:
1476
+ new_metadata = {**field.metadata}
1477
+
1478
+ new_metadata["default"] = parameters.expression
1479
+
1480
+ table_info.update_table(
1481
+ t.set_column(
1482
+ field_index,
1483
+ field.with_metadata(new_metadata),
1484
+ t.column(field_index),
1485
+ )
1486
+ )
1487
+
1488
+ database.version += 1
1489
+
1490
+ return table_info.flight_info(
1491
+ name=parameters.name,
1492
+ catalog_name=parameters.catalog,
1493
+ schema_name=parameters.schema_name,
1494
+ )[0]
1495
+
1496
+ def action_set_not_null(
1497
+ self,
1498
+ *,
1499
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1500
+ parameters: parameter_types.SetNotNull,
1501
+ ) -> flight.FlightInfo:
1502
+ assert context.caller is not None
1503
+
1504
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1505
+ library = self.contents[context.caller.token.token]
1506
+ database = library.by_name(parameters.catalog)
1507
+ schema = database.by_name(parameters.schema_name)
1508
+
1509
+ table_info = schema.by_name("table", parameters.name)
1510
+
1511
+ t = table_info.version()
1512
+ field_index = t.schema.get_field_index(parameters.column_name)
1513
+ field = t.schema.field(parameters.column_name)
1514
+
1515
+ if t.column(field_index).null_count > 0:
1516
+ raise flight.FlightServerError(f"Cannot set column {parameters.column_name} contains null values")
1517
+
1518
+ table_info.update_table(
1519
+ t.set_column(
1520
+ field_index,
1521
+ field.with_nullable(False),
1522
+ t.column(field_index),
1523
+ )
1524
+ )
1525
+
1526
+ database.version += 1
1527
+
1528
+ return table_info.flight_info(
1529
+ name=parameters.name,
1530
+ catalog_name=parameters.catalog,
1531
+ schema_name=parameters.schema_name,
1532
+ )[0]
1533
+
1534
+ def action_drop_not_null(
1535
+ self,
1536
+ *,
1537
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1538
+ parameters: parameter_types.DropNotNull,
1539
+ ) -> flight.FlightInfo:
1540
+ assert context.caller is not None
1541
+
1542
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1543
+ library = self.contents[context.caller.token.token]
1544
+ database = library.by_name(parameters.catalog)
1545
+ schema = database.by_name(parameters.schema_name)
1546
+
1547
+ table_info = schema.by_name("table", parameters.name)
1548
+
1549
+ t = table_info.version()
1550
+ field_index = t.schema.get_field_index(parameters.column_name)
1551
+ field = t.schema.field(parameters.column_name)
1552
+
1553
+ table_info.update_table(
1554
+ t.set_column(
1555
+ field_index,
1556
+ field.with_nullable(True),
1557
+ t.column(field_index),
1558
+ )
1559
+ )
1560
+
1561
+ database.version += 1
1562
+
1563
+ return table_info.flight_info(
1564
+ name=parameters.name,
1565
+ catalog_name=parameters.catalog,
1566
+ schema_name=parameters.schema_name,
1567
+ )[0]
1568
+
1569
+ def action_change_column_type(
1570
+ self,
1571
+ *,
1572
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1573
+ parameters: parameter_types.ChangeColumnType,
1574
+ ) -> flight.FlightInfo:
1575
+ assert context.caller is not None
1576
+
1577
+ self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1578
+ library = self.contents[context.caller.token.token]
1579
+ database = library.by_name(parameters.catalog)
1580
+ schema = database.by_name(parameters.schema_name)
1581
+
1582
+ table_info = schema.by_name("table", parameters.name)
1583
+
1584
+ # Cast the column to the requested type, preserving the field metadata.
1585
+
1586
+ t = table_info.version()
1587
+ column_name = parameters.column_schema.field(0).name
1588
+ field_index = t.schema.get_field_index(column_name)
1589
+ field = t.schema.field(column_name)
1590
+
1591
+ new_type = parameters.column_schema.field(0).type
1592
+ new_field = pa.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata)
1593
+ new_data = pc.cast(t.column(field_index), new_type)
1594
+
1595
+ table_info.update_table(
1596
+ t.set_column(
1597
+ field_index,
1598
+ new_field,
1599
+ new_data,
1600
+ )
1601
+ )
1602
+
1603
+ database.version += 1
1604
+
1605
+ return table_info.flight_info(
1606
+ name=parameters.name,
1607
+ catalog_name=parameters.catalog,
1608
+ schema_name=parameters.schema_name,
1609
+ )[0]
1610
+
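The type change is a cast-and-replace: the column data goes through pc.cast, and the field is rebuilt with the new type while keeping the old field's metadata, since that is where defaults live. A standalone sketch of the same steps (table, column name, and types are illustrative):

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical table; "price" stands in for the column being retyped.
t = pa.table({"price": pa.array([1, 2, 3], type=pa.int32())})
idx = t.schema.get_field_index("price")
old_field = t.schema.field(idx)
new_type = pa.float64()

# Keep the old field's metadata (e.g. a stored default) on the new field.
new_field = pa.field(old_field.name, new_type, metadata=old_field.metadata)
t = t.set_column(idx, new_field, pc.cast(t.column(idx), new_type))
assert t.schema.field(idx).type == pa.float64()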
1611
+ def impl_do_get(
1612
+ self,
1613
+ *,
1614
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1615
+ ticket: flight.Ticket,
1616
+ ) -> flight.RecordBatchStream:
1617
+ assert context.caller is not None
1618
+
1619
+ ticket_data = flight_handling.decode_ticket_model(ticket, FlightTicketData)
1620
+
1621
+ descriptor_parts = descriptor_unpack_(ticket_data.descriptor)
1622
+ library = self.contents[context.caller.token.token]
1623
+ database = library.by_name(descriptor_parts.catalog_name)
1624
+ schema = database.by_name(descriptor_parts.schema_name)
1625
+
1626
+ if descriptor_parts.type == "table":
1627
+ table = schema.by_name("table", descriptor_parts.name)
1628
+
1629
+ if ticket_data.at_unit == "VERSION":
1630
+ assert ticket_data.at_value is not None
1631
+
1632
+ # at_value is transported as a string; verify it encodes a non-negative integer.
1633
+ if not re.match(r"^\d+$", ticket_data.at_value):
1634
+ raise flight.FlightServerError(f"Invalid version: {ticket_data.at_value}")
1635
+
1636
+ table_version = table.version(int(ticket_data.at_value))
1637
+ elif ticket_data.at_unit == "TIMESTAMP":
1638
+ raise flight.FlightServerError("Timestamp not supported for table versioning")
1639
+ else:
1640
+ table_version = table.version()
1641
+
1642
+ if descriptor_parts.schema_name == "test_predicate_pushdown" and ticket_data.where_clause is not None:
1643
+ # Apply the predicate pushdown by filtering the in-memory data with DuckDB.
1644
+ # If JSON filters were supplied, this also verifies that they can be decoded.
1645
+ with duckdb.connect(":memory:") as connection:
1646
+ connection.execute("SET TimeZone = 'UTC'")
1647
+ sql = f"select * from table_version where {ticket_data.where_clause}"
1648
+ try:
1649
+ results = connection.execute(sql).fetch_arrow_table()
1650
+ except Exception as e:
1651
+ raise flight.FlightServerError(f"Failed to execute predicate pushdown: {e} sql: {sql}") from e
1652
+ table_version = results
1653
+
1654
+ return flight.RecordBatchStream(table_version)
1655
+ elif descriptor_parts.type == "table_function":
1656
+ table_function = schema.by_name("table_function", descriptor_parts.name)
1657
+
1658
+ output_schema = table_function.output_schema(
1659
+ ticket_data.table_function_parameters,
1660
+ ticket_data.table_function_input_schema,
1661
+ )
1662
+
1663
+ return flight.GeneratorStream(
1664
+ output_schema,
1665
+ table_function.handler(
1666
+ parameter_types.TableFunctionParameters(
1667
+ parameters=ticket_data.table_function_parameters, where_clause=ticket_data.where_clause
1668
+ ),
1669
+ output_schema,
1670
+ ),
1671
+ )
1672
+ else:
1673
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1674
+
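The pushdown branch leans on DuckDB's Python replacement scan: a query can name a local pyarrow Table ("table_version") directly in the FROM clause, so the incoming WHERE clause is evaluated without registering anything. A minimal sketch of that mechanism, with made-up data and filter:

import duckdb
import pyarrow as pa

table_version = pa.table({"x": [1, 2, 3, 4]})

with duckdb.connect(":memory:") as connection:
    connection.execute("SET TimeZone = 'UTC'")
    # DuckDB resolves "table_version" to the local pyarrow Table above.
    filtered = connection.execute(
        "select * from table_version where x > 2"
    ).fetch_arrow_table()

assert filtered.num_rows == 2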
1675
+ def action_table_function_flight_info(
1676
+ self,
1677
+ *,
1678
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1679
+ parameters: parameter_types.TableFunctionFlightInfo,
1680
+ ) -> flight.FlightInfo:
1681
+ assert context.caller is not None
1682
+
1683
+ library = self.contents[context.caller.token.token]
1684
+ descriptor_parts = descriptor_unpack_(parameters.descriptor)
1685
+
1686
+ database = library.by_name(descriptor_parts.catalog_name)
1687
+ schema = database.by_name(descriptor_parts.schema_name)
1688
+
1689
+ # Now fetch the table function; this is a bit harder since table functions are named by action.
1690
+ table_function = schema.by_name("table_function", descriptor_parts.name)
1691
+
1692
+ return table_function.flight_info(
1693
+ name=descriptor_parts.name,
1694
+ catalog_name=descriptor_parts.catalog_name,
1695
+ schema_name=descriptor_parts.schema_name,
1696
+ # Pass the real parameters here.
1697
+ parameters=parameters,
1698
+ )[0]
1699
+
1700
+ def action_flight_info(
1701
+ self,
1702
+ *,
1703
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1704
+ parameters: parameter_types.FlightInfo,
1705
+ ) -> flight.FlightInfo:
1706
+ assert context.caller is not None
1707
+
1708
+ library = self.contents[context.caller.token.token]
1709
+ descriptor_parts = descriptor_unpack_(parameters.descriptor)
1710
+
1711
+ database = library.by_name(descriptor_parts.catalog_name)
1712
+ schema = database.by_name(descriptor_parts.schema_name)
1713
+
1714
+ if descriptor_parts.type == "table_function":
1715
+ raise flight.FlightServerError("Table function flight info not supported")
1716
+ elif descriptor_parts.type == "table":
1717
+ table = schema.by_name("table", descriptor_parts.name)
1718
+
1719
+ version_id = None
1720
+ if parameters.at_unit is not None:
1721
+ if parameters.at_unit == "VERSION":
1722
+ assert parameters.at_value is not None
1723
+
1724
+ # at_value is transported as a string; verify it encodes a non-negative integer.
1725
+ if not re.match(r"^\d+$", parameters.at_value):
1726
+ raise flight.FlightServerError(f"Invalid version: {parameters.at_value}")
1727
+
1728
+ version_id = int(parameters.at_value)
1729
+ elif parameters.at_unit == "TIMESTAMP":
1730
+ raise flight.FlightServerError("Timestamp not supported for table versioning")
1731
+ else:
1732
+ raise flight.FlightServerError(f"Unsupported at_unit: {parameters.at_unit}")
1733
+
1734
+ return table.flight_info(
1735
+ name=descriptor_parts.name,
1736
+ catalog_name=descriptor_parts.catalog_name,
1737
+ schema_name=descriptor_parts.schema_name,
1738
+ version=version_id,
1739
+ )[0]
1740
+ else:
1741
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1742
+
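impl_do_get and action_flight_info share the same at_unit handling: VERSION values arrive as strings and must pass a digits-only regex before int(), while TIMESTAMP is rejected outright. A hypothetical helper that factors out that logic (not part of the server, just a sketch of the shared rule):

import re


def parse_version_pin(at_unit: str | None, at_value: str | None) -> int | None:
    """Validate an at_unit/at_value pair; return a version id or None."""
    if at_unit is None:
        return None
    if at_unit == "VERSION":
        assert at_value is not None
        # at_value is transported as a string; require a bare integer.
        if not re.match(r"^\d+$", at_value):
            raise ValueError(f"Invalid version: {at_value}")
        return int(at_value)
    if at_unit == "TIMESTAMP":
        raise ValueError("Timestamp not supported for table versioning")
    raise ValueError(f"Unsupported at_unit: {at_unit}")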
1743
+ def exchange_scalar_function(
1744
+ self,
1745
+ *,
1746
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
1747
+ descriptor: flight.FlightDescriptor,
1748
+ reader: flight.MetadataRecordBatchReader,
1749
+ writer: flight.MetadataRecordBatchWriter,
1750
+ ) -> None:
1751
+ assert context.caller is not None
1752
+
1753
+ descriptor_parts = descriptor_unpack_(descriptor)
1754
+
1755
+ if descriptor_parts.type != "scalar_function":
1756
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1757
+
1758
+ library = self.contents[context.caller.token.token]
1759
+ database = library.by_name(descriptor_parts.catalog_name)
1760
+ schema = database.by_name(descriptor_parts.schema_name)
1761
+ scalar_function_info = schema.by_name("scalar_function", descriptor_parts.name)
1762
+
1763
+ writer.begin(scalar_function_info.output_schema)
1764
+
1765
+ for chunk in reader:
1766
+ if chunk.data is not None:
1767
+ new_rows = pa.Table.from_batches([chunk.data])
1768
+ assert new_rows.num_rows > 0
1769
+
1770
+ result = scalar_function_info.handler(new_rows)
1771
+
1772
+ writer.write_table(pa.Table.from_arrays([result], schema=scalar_function_info.output_schema))
1773
+
1774
+
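From the client side, exchange_scalar_function is driven through DoExchange: stream argument batches up, read one result column back per chunk. A rough sketch with a made-up descriptor path and authentication omitted (the server requires it, so a real call would also have to send credentials):

import pyarrow as pa
import pyarrow.flight as flight

client = flight.connect("grpc://127.0.0.1:50312")

# Path layout is an assumption; the server decodes it via descriptor_unpack_.
descriptor = flight.FlightDescriptor.for_path(b"catalog", b"schema", b"my_fn")
writer, reader = client.do_exchange(descriptor)

args = pa.record_batch({"x": pa.array([1, 2, 3])})
writer.begin(args.schema)   # declare the argument schema
writer.write_batch(args)    # send one batch of arguments
writer.done_writing()

for chunk in reader:        # one result column comes back per chunk
    print(chunk.data)
writer.close()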
1775
+ @click.command()
1776
+ @click.option(
1777
+ "--location",
1778
+ type=str,
1779
+ default="grpc://127.0.0.1:50312",
1780
+ help="The location where the server should listen.",
1781
+ )
1782
+ def run(location: str) -> None:
1783
+ log.info("Starting server", location=location)
1784
+
1785
+ auth_manager = auth_manager_naive.AuthManagerNaive[auth.Account, auth.AccountToken](
1786
+ account_type=auth.Account,
1787
+ token_type=auth.AccountToken,
1788
+ allow_anonymous_access=False,
1789
+ )
1790
+
1791
+ server = InMemoryArrowFlightServer(
1792
+ middleware={
1793
+ "headers": base_middleware.SaveHeadersMiddlewareFactory(),
1794
+ "auth": base_middleware.AuthManagerMiddlewareFactory(auth_manager=auth_manager),
1795
+ },
1796
+ location=location,
1797
+ auth_manager=auth_manager,
1798
+ )
1799
+ server.serve()
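Once run() is serving, a quick way to poke the endpoint is to list its actions. Note that allow_anonymous_access=False, so a real client must authenticate through the middleware first; the bare connection below is only a sketch:

import pyarrow.flight as flight

client = flight.connect("grpc://127.0.0.1:50312")
for action in client.list_actions():
    print(action.type, "-", action.description)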