query-farm-airport-test-server 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,8 @@
1
1
  import hashlib
2
2
  import json
3
3
  import re
4
- from collections.abc import Callable, Generator, Iterator
5
- from dataclasses import dataclass, field
6
- from typing import Any, Literal, TypeVar, overload
4
+ from collections.abc import Generator, Iterator
5
+ from typing import Any, TypeVar
7
6
 
8
7
  import click
9
8
  import duckdb
@@ -24,7 +23,14 @@ import query_farm_flight_server.server as base_server
24
23
  import structlog
25
24
  from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
26
25
 
27
- from .utils import CaseInsensitiveDict
26
+ from .database_impl import (
27
+ DatabaseContents,
28
+ DatabaseLibrary,
29
+ DatabaseLibraryContext,
30
+ SchemaCollection,
31
+ TableInfo,
32
+ descriptor_unpack_,
33
+ )
28
34
 
29
35
  log = structlog.get_logger()
30
36
 
@@ -85,317 +91,6 @@ def check_schema_is_subset_of_schema(existing_schema: pa.Schema, new_schema: pa.
85
91
  # return flight_handling.FlightTicketData.unpack(src)
86
92
 
87
93
 
88
- @dataclass
89
- class TableFunctionDynamicOutput:
90
- # The method that will determine the output schema from the input parameters
91
- schema_creator: Callable[[pa.RecordBatch, pa.Schema | None], pa.Schema]
92
-
93
- # The default parameters for the function, if not called with any.
94
- default_values: tuple[pa.RecordBatch, pa.Schema | None]
95
-
96
-
97
- @dataclass
98
- class TableFunction:
99
- # The input schema for the function.
100
- input_schema: pa.Schema
101
-
102
- output_schema_source: pa.Schema | TableFunctionDynamicOutput
103
-
104
- # The function to call to process a chunk of rows.
105
- handler: Callable[
106
- [parameter_types.TableFunctionParameters, pa.Schema],
107
- Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch],
108
- ]
109
-
110
- estimated_rows: int | Callable[[parameter_types.TableFunctionFlightInfo], int] = -1
111
-
112
- def output_schema(
113
- self,
114
- parameters: pa.RecordBatch | None = None,
115
- input_schema: pa.Schema | None = None,
116
- ) -> pa.Schema:
117
- if isinstance(self.output_schema_source, pa.Schema):
118
- return self.output_schema_source
119
- if parameters is None:
120
- return self.output_schema_source.schema_creator(*self.output_schema_source.default_values)
121
- assert isinstance(parameters, pa.RecordBatch)
122
- result = self.output_schema_source.schema_creator(parameters, input_schema)
123
- return result
124
-
125
- def flight_info(
126
- self,
127
- *,
128
- name: str,
129
- catalog_name: str,
130
- schema_name: str,
131
- parameters: parameter_types.TableFunctionFlightInfo | None = None,
132
- ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
133
- """
134
- Often its necessary to create a FlightInfo object
135
- standardize doing that here.
136
- """
137
- assert name != ""
138
- assert catalog_name != ""
139
- assert schema_name != ""
140
-
141
- if isinstance(self.estimated_rows, int):
142
- estimated_rows = self.estimated_rows
143
- else:
144
- assert parameters is not None
145
- estimated_rows = self.estimated_rows(parameters)
146
-
147
- metadata = flight_inventory.FlightSchemaMetadata(
148
- type="table_function",
149
- catalog=catalog_name,
150
- schema=schema_name,
151
- name=name,
152
- comment=None,
153
- input_schema=self.input_schema,
154
- )
155
- flight_info = flight.FlightInfo(
156
- self.output_schema(parameters.parameters, parameters.table_input_schema)
157
- if parameters
158
- else self.output_schema(),
159
- # This will always be the same descriptor, so that we can use the action
160
- # name to determine which which table function to execute.
161
- descriptor_pack_(catalog_name, schema_name, "table_function", name),
162
- [],
163
- estimated_rows,
164
- -1,
165
- app_metadata=metadata.serialize(),
166
- )
167
- return (flight_info, metadata)
168
-
169
-
170
- @dataclass
171
- class ScalarFunction:
172
- # The input schema for the function.
173
- input_schema: pa.Schema
174
- # The output schema for the function, should only have a single column.
175
- output_schema: pa.Schema
176
-
177
- # The function to call to process a chunk of rows.
178
- handler: Callable[[pa.Table], pa.Array]
179
-
180
- def flight_info(
181
- self, *, name: str, catalog_name: str, schema_name: str
182
- ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
183
- """
184
- Often its necessary to create a FlightInfo object
185
- standardize doing that here.
186
- """
187
- metadata = flight_inventory.FlightSchemaMetadata(
188
- type="scalar_function",
189
- catalog=catalog_name,
190
- schema=schema_name,
191
- name=name,
192
- comment=None,
193
- input_schema=self.input_schema,
194
- )
195
- flight_info = flight.FlightInfo(
196
- self.output_schema,
197
- descriptor_pack_(catalog_name, schema_name, "scalar_function", name),
198
- [],
199
- -1,
200
- -1,
201
- app_metadata=metadata.serialize(),
202
- )
203
- return (flight_info, metadata)
204
-
205
-
206
- @dataclass
207
- class TableInfo:
208
- # To enable version history keep track of tables.
209
- table_versions: list[pa.Table] = field(default_factory=list)
210
-
211
- # the next row id to assign.
212
- row_id_counter: int = 0
213
-
214
- def update_table(self, table: pa.Table) -> None:
215
- assert table is not None
216
- assert isinstance(table, pa.Table)
217
- self.table_versions.append(table)
218
-
219
- def version(self, version: int | None = None) -> pa.Table:
220
- """
221
- Get the version of the table.
222
- """
223
- assert len(self.table_versions) > 0
224
- if version is None:
225
- return self.table_versions[-1]
226
-
227
- assert version < len(self.table_versions)
228
- return self.table_versions[version]
229
-
230
- def flight_info(
231
- self,
232
- *,
233
- name: str,
234
- catalog_name: str,
235
- schema_name: str,
236
- version: int | None = None,
237
- ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
238
- """
239
- Often its necessary to create a FlightInfo object for the table,
240
- standardize doing that here.
241
- """
242
- metadata = flight_inventory.FlightSchemaMetadata(
243
- type="table",
244
- catalog=catalog_name,
245
- schema=schema_name,
246
- name=name,
247
- comment=None,
248
- )
249
- flight_info = flight.FlightInfo(
250
- self.version(version).schema,
251
- descriptor_pack_(catalog_name, schema_name, "table", name),
252
- [],
253
- -1,
254
- -1,
255
- app_metadata=metadata.serialize(),
256
- )
257
- return (flight_info, metadata)
258
-
259
-
260
- ObjectTypeName = Literal["table", "scalar_function", "table_function"]
261
-
262
-
263
- @dataclass
264
- class SchemaCollection:
265
- tables_by_name: CaseInsensitiveDict[TableInfo] = field(default_factory=CaseInsensitiveDict[TableInfo])
266
- scalar_functions_by_name: CaseInsensitiveDict[ScalarFunction] = field(
267
- default_factory=CaseInsensitiveDict[ScalarFunction]
268
- )
269
- table_functions_by_name: CaseInsensitiveDict[TableFunction] = field(
270
- default_factory=CaseInsensitiveDict[TableFunction]
271
- )
272
-
273
- def containers(
274
- self,
275
- ) -> list[
276
- CaseInsensitiveDict[TableInfo] | CaseInsensitiveDict[ScalarFunction] | CaseInsensitiveDict[TableFunction]
277
- ]:
278
- return [
279
- self.tables_by_name,
280
- self.scalar_functions_by_name,
281
- self.table_functions_by_name,
282
- ]
283
-
284
- @overload
285
- def by_name(self, type: Literal["table"], name: str) -> TableInfo: ...
286
-
287
- @overload
288
- def by_name(self, type: Literal["scalar_function"], name: str) -> ScalarFunction: ...
289
-
290
- @overload
291
- def by_name(self, type: Literal["table_function"], name: str) -> TableFunction: ...
292
-
293
- def by_name(self, type: ObjectTypeName, name: str) -> TableInfo | ScalarFunction | TableFunction:
294
- assert name is not None
295
- assert name != ""
296
- if type == "table":
297
- table = self.tables_by_name.get(name)
298
- if not table:
299
- raise flight.FlightServerError(f"Table {name} does not exist.")
300
- return table
301
- elif type == "scalar_function":
302
- scalar_function = self.scalar_functions_by_name.get(name)
303
- if not scalar_function:
304
- raise flight.FlightServerError(f"Scalar function {name} does not exist.")
305
- return scalar_function
306
- elif type == "table_function":
307
- table_function = self.table_functions_by_name.get(name)
308
- if not table_function:
309
- raise flight.FlightServerError(f"Table function {name} does not exist.")
310
- return table_function
311
-
312
-
313
- @dataclass
314
- class DatabaseContents:
315
- # Collection of schemas by name.
316
- schemas_by_name: CaseInsensitiveDict[SchemaCollection] = field(
317
- default_factory=CaseInsensitiveDict[SchemaCollection]
318
- )
319
-
320
- # The version of the database, updated on each schema change.
321
- version: int = 1
322
-
323
- def by_name(self, name: str) -> SchemaCollection:
324
- if name not in self.schemas_by_name:
325
- raise flight.FlightServerError(f"Schema {name} does not exist.")
326
- return self.schemas_by_name[name]
327
-
328
-
329
- @dataclass
330
- class DescriptorParts:
331
- """
332
- The fields that are encoded in the flight descriptor.
333
- """
334
-
335
- catalog_name: str
336
- schema_name: str
337
- type: ObjectTypeName
338
- name: str
339
-
340
-
341
- @dataclass
342
- class DatabaseLibrary:
343
- """
344
- The database library, which contains all of the databases, organized by token.
345
- """
346
-
347
- # Collection of databases by token.
348
- databases_by_name: CaseInsensitiveDict[DatabaseContents] = field(
349
- default_factory=CaseInsensitiveDict[DatabaseContents]
350
- )
351
-
352
- def by_name(self, name: str) -> DatabaseContents:
353
- if name not in self.databases_by_name:
354
- raise flight.FlightServerError(f"Database {name} does not exist.")
355
- return self.databases_by_name[name]
356
-
357
-
358
- def descriptor_pack_(
359
- catalog_name: str,
360
- schema_name: str,
361
- type: ObjectTypeName,
362
- name: str,
363
- ) -> flight.FlightDescriptor:
364
- """
365
- Pack the descriptor into a FlightDescriptor.
366
- """
367
- return flight.FlightDescriptor.for_path(f"{catalog_name}/{schema_name}/{type}/{name}")
368
-
369
-
370
- def descriptor_unpack_(descriptor: flight.FlightDescriptor) -> DescriptorParts:
371
- """
372
- Split the descriptor into its components.
373
- """
374
- assert descriptor.descriptor_type == flight.DescriptorType.PATH
375
- assert len(descriptor.path) == 1
376
- path = descriptor.path[0].decode("utf-8")
377
- parts = path.split("/")
378
- if len(parts) != 4:
379
- raise flight.FlightServerError(f"Invalid descriptor path: {path}")
380
-
381
- descriptor_type: ObjectTypeName
382
- if parts[2] == "table":
383
- descriptor_type = "table"
384
- elif parts[2] == "scalar_function":
385
- descriptor_type = "scalar_function"
386
- elif parts[2] == "table_function":
387
- descriptor_type = "table_function"
388
- else:
389
- raise flight.FlightServerError(f"Invalid descriptor type: {parts[2]}")
390
-
391
- return DescriptorParts(
392
- catalog_name=parts[0],
393
- schema_name=parts[1],
394
- type=descriptor_type,
395
- name=parts[3],
396
- )
397
-
398
-
399
94
  class FlightTicketData(BaseModel):
400
95
  model_config = ConfigDict(arbitrary_types_allowed=True) # for Pydantic v2
401
96
  descriptor: flight.FlightDescriptor
@@ -449,9 +144,6 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
449
144
  self._auth_manager = auth_manager
450
145
  super().__init__(location=location, **kwargs)
451
146
 
452
- # token, database name, schema, table_name
453
- self.contents: dict[str, DatabaseLibrary] = {}
454
-
455
147
  self.ROWID_FIELD_NAME = "rowid"
456
148
  self.rowid_field = pa.field(self.ROWID_FIELD_NAME, pa.int64(), metadata={"is_rowid": "1"})
457
149
 
@@ -464,49 +156,51 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
464
156
  assert context.caller is not None
465
157
 
466
158
  descriptor_parts = descriptor_unpack_(parameters.descriptor)
467
- library = self.contents[context.caller.token.token]
468
- database = library.by_name(descriptor_parts.catalog_name)
469
- schema = database.by_name(descriptor_parts.schema_name)
470
-
471
- filter_sql_where_clause: str | None = None
472
- if parameters.parameters.json_filters is not None:
473
- context.logger.debug("duckdb_input", input=json.dumps(parameters.parameters.json_filters.filters))
474
- filter_sql_where_clause, filter_sql_field_type_info = (
475
- query_farm_duckdb_json_serialization.expression.convert_to_sql(
476
- source=parameters.parameters.json_filters.filters,
477
- bound_column_names=parameters.parameters.json_filters.column_binding_names_by_index,
159
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
160
+ database = library.by_name(descriptor_parts.catalog_name)
161
+ schema = database.by_name(descriptor_parts.schema_name)
162
+
163
+ filter_sql_where_clause: str | None = None
164
+ if parameters.parameters.json_filters is not None:
165
+ context.logger.debug("duckdb_input", input=json.dumps(parameters.parameters.json_filters.filters))
166
+ filter_sql_where_clause, filter_sql_field_type_info = (
167
+ query_farm_duckdb_json_serialization.expression.convert_to_sql(
168
+ source=parameters.parameters.json_filters.filters,
169
+ bound_column_names=parameters.parameters.json_filters.column_binding_names_by_index,
170
+ )
478
171
  )
479
- )
480
- if filter_sql_where_clause == "":
481
- filter_sql_where_clause = None
172
+ if filter_sql_where_clause == "":
173
+ filter_sql_where_clause = None
482
174
 
483
- if descriptor_parts.type == "table":
484
- schema.by_name("table", descriptor_parts.name)
175
+ if descriptor_parts.type == "table":
176
+ table_info = schema.by_name("table", descriptor_parts.name)
485
177
 
486
- ticket_data = FlightTicketData(
487
- descriptor=parameters.descriptor,
488
- where_clause=filter_sql_where_clause,
489
- at_unit=parameters.parameters.at_unit,
490
- at_value=parameters.parameters.at_value,
491
- )
492
-
493
- return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
494
- elif descriptor_parts.type == "table_function":
495
- # So the table function may not exist, because its a dynamic descriptor.
496
-
497
- schema.by_name("table_function", descriptor_parts.name)
178
+ ticket_data = FlightTicketData(
179
+ descriptor=parameters.descriptor,
180
+ where_clause=filter_sql_where_clause,
181
+ at_unit=parameters.parameters.at_unit,
182
+ at_value=parameters.parameters.at_value,
183
+ )
498
184
 
499
- ticket_data = FlightTicketData(
500
- descriptor=parameters.descriptor,
501
- where_clause=filter_sql_where_clause,
502
- table_function_parameters=parameters.parameters.table_function_parameters,
503
- table_function_input_schema=parameters.parameters.table_function_input_schema,
504
- at_unit=parameters.parameters.at_unit,
505
- at_value=parameters.parameters.at_value,
506
- )
507
- return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
508
- else:
509
- raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
185
+ if table_info.endpoint_generator is not None:
186
+ return table_info.endpoint_generator(ticket_data)
187
+ return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
188
+ elif descriptor_parts.type == "table_function":
189
+ # So the table function may not exist, because its a dynamic descriptor.
190
+
191
+ schema.by_name("table_function", descriptor_parts.name)
192
+
193
+ ticket_data = FlightTicketData(
194
+ descriptor=parameters.descriptor,
195
+ where_clause=filter_sql_where_clause,
196
+ table_function_parameters=parameters.parameters.table_function_parameters,
197
+ table_function_input_schema=parameters.parameters.table_function_input_schema,
198
+ at_unit=parameters.parameters.at_unit,
199
+ at_value=parameters.parameters.at_value,
200
+ )
201
+ return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
202
+ else:
203
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
510
204
 
511
205
  def action_list_schemas(
512
206
  self,
@@ -516,41 +210,40 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
516
210
  ) -> base_server.AirportSerializedCatalogRoot:
517
211
  assert context.caller is not None
518
212
 
519
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
520
- library = self.contents[context.caller.token.token]
521
- database = library.by_name(parameters.catalog_name)
213
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
214
+ database = library.by_name(parameters.catalog_name)
522
215
 
523
- dynamic_inventory: dict[str, dict[str, list[flight_inventory.FlightInventoryWithMetadata]]] = {}
216
+ dynamic_inventory: dict[str, dict[str, list[flight_inventory.FlightInventoryWithMetadata]]] = {}
524
217
 
525
- catalog_contents = dynamic_inventory.setdefault(parameters.catalog_name, {})
218
+ catalog_contents = dynamic_inventory.setdefault(parameters.catalog_name, {})
526
219
 
527
- for schema_name, schema in database.schemas_by_name.items():
528
- schema_contents = catalog_contents.setdefault(schema_name, [])
529
- for coll in schema.containers():
530
- for name, obj in coll.items():
531
- schema_contents.append(
532
- obj.flight_info(
533
- name=name,
534
- catalog_name=parameters.catalog_name,
535
- schema_name=schema_name,
220
+ for schema_name, schema in database.schemas_by_name.items():
221
+ schema_contents = catalog_contents.setdefault(schema_name, [])
222
+ for coll in schema.containers():
223
+ for name, obj in coll.items():
224
+ schema_contents.append(
225
+ obj.flight_info(
226
+ name=name,
227
+ catalog_name=parameters.catalog_name,
228
+ schema_name=schema_name,
229
+ )
536
230
  )
537
- )
538
231
 
539
- return flight_inventory.upload_and_generate_schema_list(
540
- upload_parameters=flight_inventory.UploadParameters(
541
- s3_client=None,
542
- base_url="http://localhost",
543
- bucket_name="test_bucket",
544
- bucket_prefix="test_prefix",
545
- ),
546
- flight_service_name=self.service_name,
547
- flight_inventory=dynamic_inventory,
548
- schema_details={},
549
- skip_upload=True,
550
- serialize_inline=True,
551
- catalog_version=1,
552
- catalog_version_fixed=False,
553
- )
232
+ return flight_inventory.upload_and_generate_schema_list(
233
+ upload_parameters=flight_inventory.UploadParameters(
234
+ s3_client=None,
235
+ base_url="http://localhost",
236
+ bucket_name="test_bucket",
237
+ bucket_prefix="test_prefix",
238
+ ),
239
+ flight_service_name=self.service_name,
240
+ flight_inventory=dynamic_inventory,
241
+ schema_details={},
242
+ skip_upload=True,
243
+ serialize_inline=True,
244
+ catalog_version=1,
245
+ catalog_version_fixed=False,
246
+ )
554
247
 
555
248
  def impl_list_flights(
556
249
  self,
@@ -559,17 +252,16 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
559
252
  criteria: bytes,
560
253
  ) -> Iterator[flight.FlightInfo]:
561
254
  assert context.caller is not None
562
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
563
- library = self.contents[context.caller.token.token]
255
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
564
256
 
565
- def yield_flight_infos() -> Generator[flight.FlightInfo, None, None]:
566
- for db_name, db in library.databases_by_name.items():
567
- for schema_name, schema in db.schemas_by_name.items():
568
- for coll in schema.containers():
569
- for name, obj in coll.items():
570
- yield obj.flight_info(name=name, catalog_name=db_name, schema_name=schema_name)[0]
257
+ def yield_flight_infos() -> Generator[flight.FlightInfo, None, None]:
258
+ for db_name, db in library.databases_by_name.items():
259
+ for schema_name, schema in db.schemas_by_name.items():
260
+ for coll in schema.containers():
261
+ for name, obj in coll.items():
262
+ yield obj.flight_info(name=name, catalog_name=db_name, schema_name=schema_name)[0]
571
263
 
572
- return yield_flight_infos()
264
+ return yield_flight_infos()
573
265
 
574
266
  def impl_get_flight_info(
575
267
  self,
@@ -578,19 +270,18 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
578
270
  descriptor: flight.FlightDescriptor,
579
271
  ) -> flight.FlightInfo:
580
272
  assert context.caller is not None
581
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
582
273
 
583
274
  descriptor_parts = descriptor_unpack_(descriptor)
584
- library = self.contents[context.caller.token.token]
585
- database = library.by_name(descriptor_parts.catalog_name)
586
- schema = database.by_name(descriptor_parts.schema_name)
275
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
276
+ database = library.by_name(descriptor_parts.catalog_name)
277
+ schema = database.by_name(descriptor_parts.schema_name)
587
278
 
588
- obj = schema.by_name(descriptor_parts.type, descriptor_parts.name)
589
- return obj.flight_info(
590
- name=descriptor_parts.name,
591
- catalog_name=descriptor_parts.catalog_name,
592
- schema_name=descriptor_parts.schema_name,
593
- )[0]
279
+ obj = schema.by_name(descriptor_parts.type, descriptor_parts.name)
280
+ return obj.flight_info(
281
+ name=descriptor_parts.name,
282
+ catalog_name=descriptor_parts.catalog_name,
283
+ schema_name=descriptor_parts.schema_name,
284
+ )[0]
594
285
 
595
286
  def action_catalog_version(
596
287
  self,
@@ -600,16 +291,15 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
600
291
  ) -> base_server.GetCatalogVersionResult:
601
292
  assert context.caller is not None
602
293
 
603
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
604
- library = self.contents[context.caller.token.token]
605
- database = library.by_name(parameters.catalog_name)
294
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
295
+ database = library.by_name(parameters.catalog_name)
606
296
 
607
- context.logger.debug(
608
- "catalog_version_result",
609
- catalog_name=parameters.catalog_name,
610
- version=database.version,
611
- )
612
- return base_server.GetCatalogVersionResult(catalog_version=database.version, is_fixed=False)
297
+ context.logger.debug(
298
+ "catalog_version_result",
299
+ catalog_name=parameters.catalog_name,
300
+ version=database.version,
301
+ )
302
+ return base_server.GetCatalogVersionResult(catalog_version=database.version, is_fixed=False)
613
303
 
614
304
  def action_create_transaction(
615
305
  self,
@@ -627,26 +317,25 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
627
317
  ) -> base_server.AirportSerializedContentsWithSHA256Hash:
628
318
  assert context.caller is not None
629
319
 
630
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
631
- library = self.contents[context.caller.token.token]
632
- database = library.by_name(parameters.catalog_name)
320
+ with DatabaseLibraryContext(context.caller.token.token) as library:
321
+ database = library.by_name(parameters.catalog_name)
633
322
 
634
- if database.schemas_by_name.get(parameters.schema_name) is not None:
635
- raise flight.FlightServerError(f"Schema {parameters.schema_name} already exists")
323
+ if database.schemas_by_name.get(parameters.schema_name) is not None:
324
+ raise flight.FlightServerError(f"Schema {parameters.schema_name} already exists")
636
325
 
637
- database.schemas_by_name[parameters.schema_name] = SchemaCollection()
638
- database.version += 1
326
+ database.schemas_by_name[parameters.schema_name] = SchemaCollection()
327
+ database.version += 1
639
328
 
640
- # FIXME: this needs to be handled better on the server side...
641
- # rather than calling into internal methods.
642
- packed_data = msgpack.packb([])
643
- assert packed_data
644
- compressed_data = schema_uploader._compress_and_prefix_with_length(packed_data, compression_level=3)
329
+ # FIXME: this needs to be handled better on the server side...
330
+ # rather than calling into internal methods.
331
+ packed_data = msgpack.packb([])
332
+ assert packed_data
333
+ compressed_data = schema_uploader._compress_and_prefix_with_length(packed_data, compression_level=3)
645
334
 
646
- empty_hash = hashlib.sha256(compressed_data).hexdigest()
647
- return base_server.AirportSerializedContentsWithSHA256Hash(
648
- url=None, sha256=empty_hash, serialized=compressed_data
649
- )
335
+ empty_hash = hashlib.sha256(compressed_data).hexdigest()
336
+ return base_server.AirportSerializedContentsWithSHA256Hash(
337
+ url=None, sha256=empty_hash, serialized=compressed_data
338
+ )
650
339
 
651
340
  def action_drop_table(
652
341
  self,
@@ -656,15 +345,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
656
345
  ) -> None:
657
346
  assert context.caller is not None
658
347
 
659
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
660
- library = self.contents[context.caller.token.token]
661
- database = library.by_name(parameters.catalog_name)
662
- schema = database.by_name(parameters.schema_name)
348
+ with DatabaseLibraryContext(context.caller.token.token) as library:
349
+ database = library.by_name(parameters.catalog_name)
350
+ schema = database.by_name(parameters.schema_name)
663
351
 
664
- schema.by_name("table", parameters.name)
352
+ schema.by_name("table", parameters.name)
665
353
 
666
- del schema.tables_by_name[parameters.name]
667
- database.version += 1
354
+ del schema.tables_by_name[parameters.name]
355
+ database.version += 1
668
356
 
669
357
  def action_drop_schema(
670
358
  self,
@@ -674,15 +362,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
674
362
  ) -> None:
675
363
  assert context.caller is not None
676
364
 
677
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
678
- library = self.contents[context.caller.token.token]
679
- database = library.by_name(parameters.catalog_name)
365
+ with DatabaseLibraryContext(context.caller.token.token) as library:
366
+ database = library.by_name(parameters.catalog_name)
680
367
 
681
- if database.schemas_by_name.get(parameters.name) is None:
682
- raise flight.FlightServerError(f"Schema '{parameters.name}' does not exist")
368
+ if database.schemas_by_name.get(parameters.name) is None:
369
+ raise flight.FlightServerError(f"Schema '{parameters.name}' does not exist")
683
370
 
684
- del database.schemas_by_name[parameters.name]
685
- database.version += 1
371
+ del database.schemas_by_name[parameters.name]
372
+ database.version += 1
686
373
 
687
374
  def action_create_table(
688
375
  self,
@@ -692,28 +379,27 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
692
379
  ) -> flight.FlightInfo:
693
380
  assert context.caller is not None
694
381
 
695
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
696
- library = self.contents[context.caller.token.token]
697
- database = library.by_name(parameters.catalog_name)
698
- schema = database.by_name(parameters.schema_name)
382
+ with DatabaseLibraryContext(context.caller.token.token) as library:
383
+ database = library.by_name(parameters.catalog_name)
384
+ schema = database.by_name(parameters.schema_name)
699
385
 
700
- if parameters.table_name in schema.tables_by_name:
701
- raise flight.FlightServerError(
702
- f"Table {parameters.table_name} already exists for token {context.caller.token}"
703
- )
386
+ if parameters.table_name in schema.tables_by_name:
387
+ raise flight.FlightServerError(
388
+ f"Table {parameters.table_name} already exists for token {context.caller.token}"
389
+ )
704
390
 
705
- # FIXME: may want to add a row_id column that is not visable to the user, so that inserts and
706
- # deletes can be tested.
391
+ # FIXME: may want to add a row_id column that is not visable to the user, so that inserts and
392
+ # deletes can be tested.
707
393
 
708
- assert "_rowid" not in parameters.arrow_schema.names
394
+ assert "_rowid" not in parameters.arrow_schema.names
709
395
 
710
- schema_with_row_id = parameters.arrow_schema.append(self.rowid_field)
396
+ schema_with_row_id = parameters.arrow_schema.append(self.rowid_field)
711
397
 
712
- table_info = TableInfo([schema_with_row_id.empty_table()], 0)
398
+ table_info = TableInfo([schema_with_row_id.empty_table()], 0)
713
399
 
714
- schema.tables_by_name[parameters.table_name] = table_info
400
+ schema.tables_by_name[parameters.table_name] = table_info
715
401
 
716
- database.version += 1
402
+ database.version += 1
717
403
 
718
404
  return table_info.flight_info(
719
405
  name=parameters.table_name,
@@ -731,348 +417,36 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
731
417
 
732
418
  if action.type == "reset":
733
419
  context.logger.debug("Resetting server state")
734
- if context.caller.token.token in self.contents:
735
- del self.contents[context.caller.token.token]
420
+ DatabaseLibrary.reset(context.caller.token.token)
736
421
  return iter([])
422
+ elif action.type == "generate_error":
423
+ error_name = action.body.to_pybytes().decode("utf-8")
424
+ known_errors = {
425
+ "flight_unavailable": flight.FlightUnavailableError,
426
+ "flight_server_error": flight.FlightServerError,
427
+ "flight_unauthenticated": flight.FlightUnauthenticatedError,
428
+ }
429
+ if error_name in known_errors:
430
+ raise known_errors[error_name](f"Testing error: {error_name}")
431
+ else:
432
+ context.logger.error("Unknown error type", error_name=error_name)
433
+ raise flight.FlightServerError(f"Unknown error type: {error_name}")
737
434
  elif action.type == "create_database":
738
435
  database_name = action.body.to_pybytes().decode("utf-8")
739
436
  context.logger.debug("Creating database", database_name=database_name)
740
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
741
- library = self.contents[context.caller.token.token]
742
- library.databases_by_name[database_name] = DatabaseContents()
743
-
744
- # Since we are creating a new database, lets load it with a few example
745
- # scalar functions.
746
-
747
- def add_handler(table: pa.Table) -> pa.Array:
748
- assert table.num_columns == 2
749
- return pc.add(table.column(0), table.column(1))
750
-
751
- def uppercase_handler(table: pa.Table) -> pa.Array:
752
- assert table.num_columns == 1
753
- return pc.utf8_upper(table.column(0))
754
-
755
- def any_type_handler(table: pa.Table) -> pa.Array:
756
- return table.column(0)
757
-
758
- def echo_handler(
759
- parameters: parameter_types.TableFunctionParameters,
760
- output_schema: pa.Schema,
761
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
762
- # Just echo the parameters back as a single row.
763
- assert parameters.parameters
764
- yield pa.RecordBatch.from_arrays(
765
- [parameters.parameters.column(0)],
766
- schema=pa.schema([pa.field("result", pa.string())]),
767
- )
768
-
769
- def long_handler(
770
- parameters: parameter_types.TableFunctionParameters,
771
- output_schema: pa.Schema,
772
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
773
- assert parameters.parameters
774
- for i in range(100):
775
- yield pa.RecordBatch.from_arrays([[f"{i}"] * 3000] * len(output_schema), schema=output_schema)
776
-
777
- def repeat_handler(
778
- parameters: parameter_types.TableFunctionParameters,
779
- output_schema: pa.Schema,
780
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
781
- # Just echo the parameters back as a single row.
782
- assert parameters.parameters
783
- for _i in range(parameters.parameters.column(1).to_pylist()[0]):
784
- yield pa.RecordBatch.from_arrays(
785
- [parameters.parameters.column(0)],
786
- schema=output_schema,
787
- )
788
-
789
- def wide_handler(
790
- parameters: parameter_types.TableFunctionParameters,
791
- output_schema: pa.Schema,
792
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
793
- # Just echo the parameters back as a single row.
794
- assert parameters.parameters
795
- rows = []
796
- for i in range(parameters.parameters.column(0).to_pylist()[0]):
797
- rows.append({f"result_{idx}": idx for idx in range(20)})
798
-
799
- yield pa.RecordBatch.from_pylist(rows, schema=output_schema)
800
-
801
- def dynamic_schema_handler(
802
- parameters: parameter_types.TableFunctionParameters,
803
- output_schema: pa.Schema,
804
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
805
- yield parameters.parameters
806
-
807
- def dynamic_schema_handler_output_schema(
808
- parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
809
- ) -> pa.Schema:
810
- # This is the schema that will be returned to the client.
811
- # It will be used to create the table function.
812
- assert isinstance(parameters, pa.RecordBatch)
813
- return parameters.schema
814
-
815
- def in_out_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
816
- assert input_schema is not None
817
- return pa.schema([parameters.schema.field(0), input_schema.field(0)])
818
-
819
- def in_out_wide_schema_handler(
820
- parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
821
- ) -> pa.Schema:
822
- assert input_schema is not None
823
- return pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)])
824
-
825
- def in_out_echo_schema_handler(
826
- parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
827
- ) -> pa.Schema:
828
- assert input_schema is not None
829
- return input_schema
830
-
831
- def in_out_echo_handler(
832
- parameters: parameter_types.TableFunctionParameters,
833
- output_schema: pa.Schema,
834
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
835
- result = output_schema.empty_table()
836
-
837
- while True:
838
- input_chunk = yield result
839
-
840
- if input_chunk is None:
841
- break
842
-
843
- result = input_chunk
844
-
845
- return
846
-
847
- def in_out_wide_handler(
848
- parameters: parameter_types.TableFunctionParameters,
849
- output_schema: pa.Schema,
850
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
851
- result = output_schema.empty_table()
852
-
853
- while True:
854
- input_chunk = yield result
855
-
856
- if input_chunk is None:
857
- break
858
-
859
- result = pa.RecordBatch.from_arrays(
860
- [[i] * len(input_chunk) for i in range(20)],
861
- schema=output_schema,
862
- )
863
-
864
- return
865
-
866
- def in_out_handler(
867
- parameters: parameter_types.TableFunctionParameters,
868
- output_schema: pa.Schema,
869
- ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
870
- result = output_schema.empty_table()
871
-
872
- while True:
873
- input_chunk = yield result
874
-
875
- if input_chunk is None:
876
- break
877
-
878
- assert parameters.parameters
879
- result = pa.RecordBatch.from_arrays(
880
- [
881
- parameters.parameters.column(0),
882
- input_chunk.column(0),
883
- ],
884
- schema=output_schema,
885
- )
886
-
887
- return pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)
888
-
889
- util_schema = SchemaCollection(
890
- scalar_functions_by_name=CaseInsensitiveDict(
891
- {
892
- "test_uppercase": ScalarFunction(
893
- input_schema=pa.schema([pa.field("a", pa.string())]),
894
- output_schema=pa.schema([pa.field("result", pa.string())]),
895
- handler=uppercase_handler,
896
- ),
897
- "test_any_type": ScalarFunction(
898
- input_schema=pa.schema([pa.field("a", pa.string(), metadata={"is_any_type": "1"})]),
899
- output_schema=pa.schema([pa.field("result", pa.string())]),
900
- handler=any_type_handler,
901
- ),
902
- "test_add": ScalarFunction(
903
- input_schema=pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
904
- output_schema=pa.schema([pa.field("result", pa.int64())]),
905
- handler=add_handler,
906
- ),
907
- }
908
- ),
909
- table_functions_by_name=CaseInsensitiveDict(
910
- {
911
- "test_echo": TableFunction(
912
- input_schema=pa.schema([pa.field("input", pa.string())]),
913
- output_schema_source=pa.schema([pa.field("result", pa.string())]),
914
- handler=echo_handler,
915
- ),
916
- "test_wide": TableFunction(
917
- input_schema=pa.schema([pa.field("count", pa.int32())]),
918
- output_schema_source=pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)]),
919
- handler=wide_handler,
920
- ),
921
- "test_long": TableFunction(
922
- input_schema=pa.schema([pa.field("input", pa.string())]),
923
- output_schema_source=pa.schema(
924
- [
925
- pa.field("result", pa.string()),
926
- pa.field("result2", pa.string()),
927
- ]
928
- ),
929
- handler=long_handler,
930
- ),
931
- "test_repeat": TableFunction(
932
- input_schema=pa.schema(
933
- [
934
- pa.field("input", pa.string()),
935
- pa.field("count", pa.int32()),
936
- ]
937
- ),
938
- output_schema_source=pa.schema([pa.field("result", pa.string())]),
939
- handler=repeat_handler,
940
- ),
941
- "test_dynamic_schema": TableFunction(
942
- input_schema=pa.schema(
943
- [
944
- pa.field(
945
- "input",
946
- pa.string(),
947
- metadata={"is_any_type": "1"},
948
- )
949
- ]
950
- ),
951
- output_schema_source=TableFunctionDynamicOutput(
952
- schema_creator=dynamic_schema_handler_output_schema,
953
- default_values=(
954
- pa.RecordBatch.from_arrays(
955
- [pa.array([1], type=pa.int32())],
956
- schema=pa.schema([pa.field("input", pa.int32())]),
957
- ),
958
- None,
959
- ),
960
- ),
961
- handler=dynamic_schema_handler,
962
- ),
963
- "test_dynamic_schema_named_parameters": TableFunction(
964
- input_schema=pa.schema(
965
- [
966
- pa.field("name", pa.string()),
967
- pa.field(
968
- "location",
969
- pa.string(),
970
- metadata={"is_named_parameter": "1"},
971
- ),
972
- pa.field(
973
- "input",
974
- pa.string(),
975
- metadata={"is_any_type": "1"},
976
- ),
977
- pa.field("city", pa.string()),
978
- ]
979
- ),
980
- output_schema_source=TableFunctionDynamicOutput(
981
- schema_creator=dynamic_schema_handler_output_schema,
982
- default_values=(
983
- pa.RecordBatch.from_arrays(
984
- [pa.array([1], type=pa.int32())],
985
- schema=pa.schema([pa.field("input", pa.int32())]),
986
- ),
987
- None,
988
- ),
989
- ),
990
- handler=dynamic_schema_handler,
991
- ),
992
- "test_table_in_out": TableFunction(
993
- input_schema=pa.schema(
994
- [
995
- pa.field("input", pa.string()),
996
- pa.field(
997
- "table_input",
998
- pa.string(),
999
- metadata={"is_table_type": "1"},
1000
- ),
1001
- ]
1002
- ),
1003
- output_schema_source=TableFunctionDynamicOutput(
1004
- schema_creator=in_out_schema_handler,
1005
- default_values=(
1006
- pa.RecordBatch.from_arrays(
1007
- [pa.array([1], type=pa.int32())],
1008
- schema=pa.schema([pa.field("input", pa.int32())]),
1009
- ),
1010
- pa.schema([pa.field("input", pa.int32())]),
1011
- ),
1012
- ),
1013
- handler=in_out_handler,
1014
- ),
1015
- "test_table_in_out_wide": TableFunction(
1016
- input_schema=pa.schema(
1017
- [
1018
- pa.field("input", pa.string()),
1019
- pa.field(
1020
- "table_input",
1021
- pa.string(),
1022
- metadata={"is_table_type": "1"},
1023
- ),
1024
- ]
1025
- ),
1026
- output_schema_source=TableFunctionDynamicOutput(
1027
- schema_creator=in_out_wide_schema_handler,
1028
- default_values=(
1029
- pa.RecordBatch.from_arrays(
1030
- [pa.array([1], type=pa.int32())],
1031
- schema=pa.schema([pa.field("input", pa.int32())]),
1032
- ),
1033
- pa.schema([pa.field("input", pa.int32())]),
1034
- ),
1035
- ),
1036
- handler=in_out_wide_handler,
1037
- ),
1038
- "test_table_in_out_echo": TableFunction(
1039
- input_schema=pa.schema(
1040
- [
1041
- pa.field(
1042
- "table_input",
1043
- pa.string(),
1044
- metadata={"is_table_type": "1"},
1045
- ),
1046
- ]
1047
- ),
1048
- output_schema_source=TableFunctionDynamicOutput(
1049
- schema_creator=in_out_echo_schema_handler,
1050
- default_values=(
1051
- pa.RecordBatch.from_arrays(
1052
- [pa.array([1], type=pa.int32())],
1053
- schema=pa.schema([pa.field("input", pa.int32())]),
1054
- ),
1055
- pa.schema([pa.field("input", pa.int32())]),
1056
- ),
1057
- ),
1058
- handler=in_out_echo_handler,
1059
- ),
1060
- }
1061
- ),
1062
- )
1063
-
1064
- library.databases_by_name[database_name].schemas_by_name["utils"] = util_schema
1065
-
437
+ with DatabaseLibraryContext(context.caller.token.token) as library:
438
+ if database_name in library.databases_by_name:
439
+ raise flight.FlightServerError(f"Database {database_name} already exists")
440
+ library.databases_by_name[database_name] = DatabaseContents()
1066
441
  return iter([])
1067
442
  elif action.type == "drop_database":
1068
443
  database_name = action.body.to_pybytes().decode("utf-8")
1069
444
  context.logger.debug("Dropping database", database_name=database_name)
1070
445
 
1071
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1072
- library = self.contents[context.caller.token.token]
1073
- if action.body.decode("utf-8") not in library.databases_by_name:
1074
- raise flight.FlightServerError(f"Database {action.body.decode('utf-8')} does not exist")
1075
- del library.databases_by_name[action.body.decode("utf-8")]
446
+ with DatabaseLibraryContext(context.caller.token.token) as library:
447
+ if action.body.decode("utf-8") not in library.databases_by_name:
448
+ raise flight.FlightServerError(f"Database {action.body.decode('utf-8')} does not exist")
449
+ del library.databases_by_name[action.body.decode("utf-8")]
1076
450
  return iter([])
1077
451
 
1078
452
  raise flight.FlightServerError(f"Unsupported action type: {action.type}")
@@ -1092,91 +466,92 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1092
466
 
1093
467
  if descriptor_parts.type != "table":
1094
468
  raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1095
- library = self.contents[context.caller.token.token]
1096
- database = library.by_name(descriptor_parts.catalog_name)
1097
- schema = database.by_name(descriptor_parts.schema_name)
1098
- table_info = schema.by_name("table", descriptor_parts.name)
1099
469
 
1100
- existing_table = table_info.version()
470
+ with DatabaseLibraryContext(context.caller.token.token) as library:
471
+ database = library.by_name(descriptor_parts.catalog_name)
472
+ schema = database.by_name(descriptor_parts.schema_name)
473
+ table_info = schema.by_name("table", descriptor_parts.name)
1101
474
 
1102
- writer.begin(existing_table.schema)
475
+ existing_table = table_info.version()
1103
476
 
1104
- rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1105
- assert rowid_index != -1
477
+ writer.begin(existing_table.schema)
1106
478
 
1107
- change_count = 0
479
+ rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
480
+ assert rowid_index != -1
1108
481
 
1109
- for chunk in reader:
1110
- if chunk.data is not None:
1111
- chunk_table = pa.Table.from_batches([chunk.data])
1112
- assert chunk_table.num_rows > 0
482
+ change_count = 0
1113
483
 
1114
- # So this chunk will contain any updated columns and the row id.
484
+ for chunk in reader:
485
+ if chunk.data is not None:
486
+ chunk_table = pa.Table.from_batches([chunk.data])
487
+ assert chunk_table.num_rows > 0
1115
488
 
1116
- input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)
489
+ # So this chunk will contain any updated columns and the row id.
1117
490
 
1118
- # To perform an update, first remove the rows from the table,
1119
- # then assign new row ids to the incoming rows, and append them to the table.
1120
- table_mask = pc.is_in(
1121
- existing_table.column(rowid_index),
1122
- value_set=chunk_table.column(input_rowid_index),
1123
- )
491
+ input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1124
492
 
1125
- # This is the table with updated rows removed, we'll be adding rows to it later on.
1126
- table_without_updated_rows = pc.filter(existing_table, pc.invert(table_mask))
493
+ # To perform an update, first remove the rows from the table,
494
+ # then assign new row ids to the incoming rows, and append them to the table.
495
+ table_mask = pc.is_in(
496
+ existing_table.column(rowid_index),
497
+ value_set=chunk_table.column(input_rowid_index),
498
+ )
1127
499
 
1128
- # These are the rows that are being updated, since the update may not send all
1129
- # the columns, we need to filter the table to get the updated rows to persist the
1130
- # values that aren't being updated.
1131
- changed_rows = pc.filter(existing_table, table_mask)
500
+ # This is the table with updated rows removed, we'll be adding rows to it later on.
501
+ table_without_updated_rows = pc.filter(existing_table, pc.invert(table_mask))
1132
502
 
1133
- # Get a list of all of the columns that are not being updated, so that a join
1134
- # can be performed.
1135
- unchanged_column_names = set(existing_table.schema.names) - set(chunk_table.schema.names)
503
+ # These are the rows that are being updated, since the update may not send all
504
+ # the columns, we need to filter the table to get the updated rows to persist the
505
+ # values that aren't being updated.
506
+ changed_rows = pc.filter(existing_table, table_mask)
1136
507
 
1137
- joined_table = pa.Table.join(
1138
- changed_rows.select(list(unchanged_column_names) + [self.ROWID_FIELD_NAME]),
1139
- chunk_table,
1140
- keys=[self.ROWID_FIELD_NAME],
1141
- join_type="inner",
1142
- )
508
+ # Get a list of all of the columns that are not being updated, so that a join
509
+ # can be performed.
510
+ unchanged_column_names = set(existing_table.schema.names) - set(chunk_table.schema.names)
1143
511
 
1144
- # Add the new row id column.
1145
- chunk_length = len(joined_table)
1146
- rowid_values = [
1147
- x
1148
- for x in range(
1149
- table_info.row_id_counter,
1150
- table_info.row_id_counter + chunk_length,
512
+ joined_table = pa.Table.join(
513
+ changed_rows.select(list(unchanged_column_names) + [self.ROWID_FIELD_NAME]),
514
+ chunk_table,
515
+ keys=[self.ROWID_FIELD_NAME],
516
+ join_type="inner",
1151
517
  )
1152
- ]
1153
- updated_rows = joined_table.set_column(
1154
- joined_table.schema.get_field_index(self.ROWID_FIELD_NAME),
1155
- self.rowid_field,
1156
- [rowid_values],
1157
- )
1158
- table_info.row_id_counter += chunk_length
1159
518
 
1160
- # Now the columns may be in a different order, so we need to realign them.
1161
- updated_rows = updated_rows.select(existing_table.schema.names)
519
+ # Add the new row id column.
520
+ chunk_length = len(joined_table)
521
+ rowid_values = [
522
+ x
523
+ for x in range(
524
+ table_info.row_id_counter,
525
+ table_info.row_id_counter + chunk_length,
526
+ )
527
+ ]
528
+ updated_rows = joined_table.set_column(
529
+ joined_table.schema.get_field_index(self.ROWID_FIELD_NAME),
530
+ self.rowid_field,
531
+ [rowid_values],
532
+ )
533
+ table_info.row_id_counter += chunk_length
534
+
535
+ # Now the columns may be in a different order, so we need to realign them.
536
+ updated_rows = updated_rows.select(existing_table.schema.names)
1162
537
 
1163
- check_schema_is_subset_of_schema(existing_table.schema, updated_rows.schema)
538
+ check_schema_is_subset_of_schema(existing_table.schema, updated_rows.schema)
1164
539
 
1165
- updated_rows = conform_nullable(existing_table.schema, updated_rows)
540
+ updated_rows = conform_nullable(existing_table.schema, updated_rows)
1166
541
 
1167
- updated_table = pa.concat_tables(
1168
- [
1169
- table_without_updated_rows,
1170
- updated_rows.select(table_without_updated_rows.schema.names),
1171
- ]
1172
- )
542
+ updated_table = pa.concat_tables(
543
+ [
544
+ table_without_updated_rows,
545
+ updated_rows.select(table_without_updated_rows.schema.names),
546
+ ]
547
+ )
1173
548
 
1174
- if return_chunks:
1175
- writer.write_table(updated_rows)
549
+ if return_chunks:
550
+ writer.write_table(updated_rows)
1176
551
 
1177
- existing_table = updated_table
552
+ existing_table = updated_table
1178
553
 
1179
- table_info.update_table(existing_table)
554
+ table_info.update_table(existing_table)
1180
555
 
1181
556
  return change_count
1182
557
 
@@ -1195,47 +570,47 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1195
570
 
1196
571
  if descriptor_parts.type != "table":
1197
572
  raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1198
- library = self.contents[context.caller.token.token]
1199
- database = library.by_name(descriptor_parts.catalog_name)
1200
- schema = database.by_name(descriptor_parts.schema_name)
1201
- table_info = schema.by_name("table", descriptor_parts.name)
573
+ with DatabaseLibraryContext(context.caller.token.token) as library:
574
+ database = library.by_name(descriptor_parts.catalog_name)
575
+ schema = database.by_name(descriptor_parts.schema_name)
576
+ table_info = schema.by_name("table", descriptor_parts.name)
1202
577
 
1203
- existing_table = table_info.version()
1204
- writer.begin(existing_table.schema)
578
+ existing_table = table_info.version()
579
+ writer.begin(existing_table.schema)
1205
580
 
1206
- rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1207
- assert rowid_index != -1
581
+ rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
582
+ assert rowid_index != -1
1208
583
 
1209
- change_count = 0
584
+ change_count = 0
1210
585
 
1211
- for chunk in reader:
1212
- if chunk.data is not None:
1213
- chunk_table = pa.Table.from_batches([chunk.data])
1214
- assert chunk_table.num_rows > 0
586
+ for chunk in reader:
587
+ if chunk.data is not None:
588
+ chunk_table = pa.Table.from_batches([chunk.data])
589
+ assert chunk_table.num_rows > 0
1215
590
 
1216
- # Should only be getting the row id.
1217
- assert chunk_table.num_columns == 1
1218
- # the rowid field doesn't come back the same since it missing the
1219
- # not null flag and the metadata, so can't compare the schemas
1220
- input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)
591
+ # Should only be getting the row id.
592
+ assert chunk_table.num_columns == 1
593
+ # the rowid field doesn't come back the same since it missing the
594
+ # not null flag and the metadata, so can't compare the schemas
595
+ input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1221
596
 
1222
- # Now do an antijoin to get the rows that are not in the delete_rows.
1223
- target_rowids = chunk_table.column(input_rowid_index)
1224
- existing_row_ids = existing_table.column(rowid_index)
597
+ # Now do an antijoin to get the rows that are not in the delete_rows.
598
+ target_rowids = chunk_table.column(input_rowid_index)
599
+ existing_row_ids = existing_table.column(rowid_index)
1225
600
 
1226
- mask = pc.is_in(existing_row_ids, value_set=target_rowids)
601
+ mask = pc.is_in(existing_row_ids, value_set=target_rowids)
1227
602
 
1228
- target_rows = pc.filter(existing_table, mask)
1229
- changed_table = pc.filter(existing_table, pc.invert(mask))
603
+ target_rows = pc.filter(existing_table, mask)
604
+ changed_table = pc.filter(existing_table, pc.invert(mask))
1230
605
 
1231
- change_count += target_rows.num_rows
606
+ change_count += target_rows.num_rows
1232
607
 
1233
- if return_chunks:
1234
- writer.write_table(target_rows)
608
+ if return_chunks:
609
+ writer.write_table(target_rows)
1235
610
 
1236
- existing_table = changed_table
611
+ existing_table = changed_table
1237
612
 
1238
- table_info.update_table(existing_table)
613
+ table_info.update_table(existing_table)
1239
614
 
1240
615
  return change_count
1241
616
 
@@ -1254,65 +629,66 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1254
629
 
1255
630
  if descriptor_parts.type != "table":
1256
631
  raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1257
- library = self.contents[context.caller.token.token]
1258
- database = library.by_name(descriptor_parts.catalog_name)
1259
- schema = database.by_name(descriptor_parts.schema_name)
1260
- table_info = schema.by_name("table", descriptor_parts.name)
1261
632
 
1262
- existing_table = table_info.version()
1263
- writer.begin(existing_table.schema)
1264
- change_count = 0
633
+ with DatabaseLibraryContext(context.caller.token.token) as library:
634
+ database = library.by_name(descriptor_parts.catalog_name)
635
+ schema = database.by_name(descriptor_parts.schema_name)
636
+ table_info = schema.by_name("table", descriptor_parts.name)
1265
637
 
1266
- rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
1267
- assert rowid_index != -1
638
+ existing_table = table_info.version()
639
+ writer.begin(existing_table.schema)
640
+ change_count = 0
1268
641
 
1269
- # Check that the data being read matches the table without the rowid column.
642
+ rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
643
+ assert rowid_index != -1
1270
644
 
1271
- # DuckDB won't send field metadata when it sends us the schema that it uses
1272
- # to perform an insert, so we need some way to adapt the schema we
1273
- check_schema_is_subset_of_schema(existing_table.schema, reader.schema)
645
+ # Check that the data being read matches the table without the rowid column.
1274
646
 
1275
- # FIXME: need to handle the case of rowids.
647
+ # DuckDB won't send field metadata when it sends us the schema that it uses
648
+ # to perform an insert, so we need some way to adapt the schema we
649
+ check_schema_is_subset_of_schema(existing_table.schema, reader.schema)
1276
650
 
1277
- for chunk in reader:
1278
- if chunk.data is not None:
1279
- new_rows = pa.Table.from_batches([chunk.data])
1280
- assert new_rows.num_rows > 0
651
+ # FIXME: need to handle the case of rowids.
1281
652
 
1282
- # append the row id column to the new rows.
1283
- chunk_length = new_rows.num_rows
1284
- rowid_values = [
1285
- x
1286
- for x in range(
1287
- table_info.row_id_counter,
1288
- table_info.row_id_counter + chunk_length,
1289
- )
1290
- ]
1291
- new_rows = new_rows.append_column(self.rowid_field, [rowid_values])
1292
- table_info.row_id_counter += chunk_length
1293
- change_count += chunk_length
1294
-
1295
- if return_chunks:
1296
- writer.write_table(new_rows)
1297
-
1298
- # Since the table could have columns removed and deleted, use .select
1299
- # the ensure that the columns are aligned in the same order as the original table.
1300
-
1301
- # So it turns out that DuckDB doesn't send the "not null" flag in the arrow schema.
1302
- #
1303
- # This means we can't concat the tables, without those flags matching.
1304
- # for field_name in existing_table.schema.names:
1305
- # field = existing_table.schema.field(field_name)
1306
- # if not field.nullable:
1307
- # field_index = new_rows.schema.get_field_index(field_name)
1308
- # new_rows = new_rows.set_column(
1309
- # field_index, field.with_nullable(False), new_rows.column(field_index)
1310
- # )
1311
- new_rows = conform_nullable(existing_table.schema, new_rows)
1312
-
1313
- table_info.update_table(
1314
- pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
1315
- )
653
+ for chunk in reader:
654
+ if chunk.data is not None:
655
+ new_rows = pa.Table.from_batches([chunk.data])
656
+ assert new_rows.num_rows > 0
657
+
658
+ # append the row id column to the new rows.
659
+ chunk_length = new_rows.num_rows
660
+ rowid_values = [
661
+ x
662
+ for x in range(
663
+ table_info.row_id_counter,
664
+ table_info.row_id_counter + chunk_length,
665
+ )
666
+ ]
667
+ new_rows = new_rows.append_column(self.rowid_field, [rowid_values])
668
+ table_info.row_id_counter += chunk_length
669
+ change_count += chunk_length
670
+
671
+ if return_chunks:
672
+ writer.write_table(new_rows)
673
+
674
+ # Since the table could have columns removed and deleted, use .select
675
+ # the ensure that the columns are aligned in the same order as the original table.
676
+
677
+ # So it turns out that DuckDB doesn't send the "not null" flag in the arrow schema.
678
+ #
679
+ # This means we can't concat the tables, without those flags matching.
680
+ # for field_name in existing_table.schema.names:
681
+ # field = existing_table.schema.field(field_name)
682
+ # if not field.nullable:
683
+ # field_index = new_rows.schema.get_field_index(field_name)
684
+ # new_rows = new_rows.set_column(
685
+ # field_index, field.with_nullable(False), new_rows.column(field_index)
686
+ # )
687
+ new_rows = conform_nullable(existing_table.schema, new_rows)
688
+
689
+ existing_table = pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
690
+
691
+ table_info.update_table(existing_table)
1316
692
 
1317
693
  return change_count
1318
694
 
@@ -1330,13 +706,13 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1330
706
 
1331
707
  if descriptor_parts.type != "table_function":
1332
708
  raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1333
- library = self.contents[context.caller.token.token]
1334
- database = library.by_name(descriptor_parts.catalog_name)
1335
- schema = database.by_name(descriptor_parts.schema_name)
1336
- table_info = schema.by_name("table_function", descriptor_parts.name)
709
+ with DatabaseLibraryContext(context.caller.token.token) as library:
710
+ database = library.by_name(descriptor_parts.catalog_name)
711
+ schema = database.by_name(descriptor_parts.schema_name)
712
+ table_info = schema.by_name("table_function", descriptor_parts.name)
1337
713
 
1338
- output_schema = table_info.output_schema(parameters=parameters.parameters, input_schema=input_schema)
1339
- gen = table_info.handler(parameters, output_schema)
714
+ output_schema = table_info.output_schema(parameters=parameters.parameters, input_schema=input_schema)
715
+ gen = table_info.handler(parameters, output_schema)
1340
716
  return (output_schema, gen)
1341
717
 
1342
718
  def action_add_column(
@@ -1347,31 +723,30 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1347
723
  ) -> flight.FlightInfo:
1348
724
  assert context.caller is not None
1349
725
 
1350
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1351
- library = self.contents[context.caller.token.token]
1352
- database = library.by_name(parameters.catalog)
1353
- schema = database.by_name(parameters.schema_name)
726
+ with DatabaseLibraryContext(context.caller.token.token) as library:
727
+ database = library.by_name(parameters.catalog)
728
+ schema = database.by_name(parameters.schema_name)
1354
729
 
1355
- table_info = schema.by_name("table", parameters.name)
730
+ table_info = schema.by_name("table", parameters.name)
1356
731
 
1357
- assert len(parameters.column_schema.names) == 1
732
+ assert len(parameters.column_schema.names) == 1
1358
733
 
1359
- existing_table = table_info.version()
1360
- # Don't allow duplicate colum names.
1361
- assert parameters.column_schema.field(0).name not in existing_table.schema.names
734
+ existing_table = table_info.version()
735
+ # Don't allow duplicate colum names.
736
+ assert parameters.column_schema.field(0).name not in existing_table.schema.names
1362
737
 
1363
- table_info.update_table(
1364
- existing_table.append_column(
1365
- parameters.column_schema.field(0).name,
1366
- [
1367
- pa.nulls(
1368
- existing_table.num_rows,
1369
- type=parameters.column_schema.field(0).type,
1370
- )
1371
- ],
738
+ table_info.update_table(
739
+ existing_table.append_column(
740
+ parameters.column_schema.field(0).name,
741
+ [
742
+ pa.nulls(
743
+ existing_table.num_rows,
744
+ type=parameters.column_schema.field(0).type,
745
+ )
746
+ ],
747
+ )
1372
748
  )
1373
- )
1374
- database.version += 1
749
+ database.version += 1
1375
750
 
1376
751
  return table_info.flight_info(
1377
752
  name=parameters.name,
@@ -1387,15 +762,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1387
762
  ) -> flight.FlightInfo:
1388
763
  assert context.caller is not None
1389
764
 
1390
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1391
- library = self.contents[context.caller.token.token]
1392
- database = library.by_name(parameters.catalog)
1393
- schema = database.by_name(parameters.schema_name)
765
+ with DatabaseLibraryContext(context.caller.token.token) as library:
766
+ database = library.by_name(parameters.catalog)
767
+ schema = database.by_name(parameters.schema_name)
1394
768
 
1395
- table_info = schema.by_name("table", parameters.name)
769
+ table_info = schema.by_name("table", parameters.name)
1396
770
 
1397
- table_info.update_table(table_info.version().drop(parameters.removed_column))
1398
- database.version += 1
771
+ table_info.update_table(table_info.version().drop(parameters.removed_column))
772
+ database.version += 1
1399
773
 
1400
774
  return table_info.flight_info(
1401
775
  name=parameters.name,
@@ -1411,15 +785,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1411
785
  ) -> flight.FlightInfo:
1412
786
  assert context.caller is not None
1413
787
 
1414
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1415
- library = self.contents[context.caller.token.token]
1416
- database = library.by_name(parameters.catalog)
1417
- schema = database.by_name(parameters.schema_name)
788
+ with DatabaseLibraryContext(context.caller.token.token) as library:
789
+ database = library.by_name(parameters.catalog)
790
+ schema = database.by_name(parameters.schema_name)
1418
791
 
1419
- table_info = schema.by_name("table", parameters.name)
792
+ table_info = schema.by_name("table", parameters.name)
1420
793
 
1421
- table_info.update_table(table_info.version().rename_columns({parameters.old_name: parameters.new_name}))
1422
- database.version += 1
794
+ table_info.update_table(table_info.version().rename_columns({parameters.old_name: parameters.new_name}))
795
+ database.version += 1
1423
796
 
1424
797
  return table_info.flight_info(
1425
798
  name=parameters.name,
@@ -1435,16 +808,15 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1435
808
  ) -> flight.FlightInfo:
1436
809
  assert context.caller is not None
1437
810
 
1438
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1439
- library = self.contents[context.caller.token.token]
1440
- database = library.by_name(parameters.catalog)
1441
- schema = database.by_name(parameters.schema_name)
811
+ with DatabaseLibraryContext(context.caller.token.token) as library:
812
+ database = library.by_name(parameters.catalog)
813
+ schema = database.by_name(parameters.schema_name)
1442
814
 
1443
- table_info = schema.by_name("table", parameters.name)
815
+ table_info = schema.by_name("table", parameters.name)
1444
816
 
1445
- schema.tables_by_name[parameters.new_table_name] = schema.tables_by_name.pop(parameters.name)
817
+ schema.tables_by_name[parameters.new_table_name] = schema.tables_by_name.pop(parameters.name)
1446
818
 
1447
- database.version += 1
819
+ database.version += 1
1448
820
 
1449
821
  return table_info.flight_info(
1450
822
  name=parameters.new_table_name,
@@ -1460,32 +832,31 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1460
832
  ) -> flight.FlightInfo:
1461
833
  assert context.caller is not None
1462
834
 
1463
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1464
- library = self.contents[context.caller.token.token]
1465
- database = library.by_name(parameters.catalog)
1466
- schema = database.by_name(parameters.schema_name)
1467
- table_info = schema.by_name("table", parameters.name)
835
+ with DatabaseLibraryContext(context.caller.token.token) as library:
836
+ database = library.by_name(parameters.catalog)
837
+ schema = database.by_name(parameters.schema_name)
838
+ table_info = schema.by_name("table", parameters.name)
1468
839
 
1469
- # Defaults are set as metadata on a field.
840
+ # Defaults are set as metadata on a field.
1470
841
 
1471
- t = table_info.version()
1472
- field_index = t.schema.get_field_index(parameters.column_name)
1473
- field = t.schema.field(parameters.column_name)
1474
- new_metadata: dict[str, Any] = {}
1475
- if field.metadata:
1476
- new_metadata = {**field.metadata}
842
+ t = table_info.version()
843
+ field_index = t.schema.get_field_index(parameters.column_name)
844
+ field = t.schema.field(parameters.column_name)
845
+ new_metadata: dict[str, Any] = {}
846
+ if field.metadata:
847
+ new_metadata = {**field.metadata}
1477
848
 
1478
- new_metadata["default"] = parameters.expression
849
+ new_metadata["default"] = parameters.expression
1479
850
 
1480
- table_info.update_table(
1481
- t.set_column(
1482
- field_index,
1483
- field.with_metadata(new_metadata),
1484
- t.column(field_index),
851
+ table_info.update_table(
852
+ t.set_column(
853
+ field_index,
854
+ field.with_metadata(new_metadata),
855
+ t.column(field_index),
856
+ )
1485
857
  )
1486
- )
1487
858
 
1488
- database.version += 1
859
+ database.version += 1
1489
860
 
1490
861
  return table_info.flight_info(
1491
862
  name=parameters.name,
@@ -1501,29 +872,28 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1501
872
  ) -> flight.FlightInfo:
1502
873
  assert context.caller is not None
1503
874
 
1504
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1505
- library = self.contents[context.caller.token.token]
1506
- database = library.by_name(parameters.catalog)
1507
- schema = database.by_name(parameters.schema_name)
875
+ with DatabaseLibraryContext(context.caller.token.token) as library:
876
+ database = library.by_name(parameters.catalog)
877
+ schema = database.by_name(parameters.schema_name)
1508
878
 
1509
- table_info = schema.by_name("table", parameters.name)
879
+ table_info = schema.by_name("table", parameters.name)
1510
880
 
1511
- t = table_info.version()
1512
- field_index = t.schema.get_field_index(parameters.column_name)
1513
- field = t.schema.field(parameters.column_name)
881
+ t = table_info.version()
882
+ field_index = t.schema.get_field_index(parameters.column_name)
883
+ field = t.schema.field(parameters.column_name)
1514
884
 
1515
- if t.column(field_index).null_count > 0:
1516
- raise flight.FlightServerError(f"Cannot set column {parameters.column_name} contains null values")
885
+ if t.column(field_index).null_count > 0:
886
+ raise flight.FlightServerError(f"Cannot set column {parameters.column_name} contains null values")
1517
887
 
1518
- table_info.update_table(
1519
- t.set_column(
1520
- field_index,
1521
- field.with_nullable(False),
1522
- t.column(field_index),
888
+ table_info.update_table(
889
+ t.set_column(
890
+ field_index,
891
+ field.with_nullable(False),
892
+ t.column(field_index),
893
+ )
1523
894
  )
1524
- )
1525
895
 
1526
- database.version += 1
896
+ database.version += 1
1527
897
 
1528
898
  return table_info.flight_info(
1529
899
  name=parameters.name,
@@ -1539,26 +909,25 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1539
909
  ) -> flight.FlightInfo:
1540
910
  assert context.caller is not None
1541
911
 
1542
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1543
- library = self.contents[context.caller.token.token]
1544
- database = library.by_name(parameters.catalog)
1545
- schema = database.by_name(parameters.schema_name)
912
+ with DatabaseLibraryContext(context.caller.token.token) as library:
913
+ database = library.by_name(parameters.catalog)
914
+ schema = database.by_name(parameters.schema_name)
1546
915
 
1547
- table_info = schema.by_name("table", parameters.name)
916
+ table_info = schema.by_name("table", parameters.name)
1548
917
 
1549
- t = table_info.version()
1550
- field_index = t.schema.get_field_index(parameters.column_name)
1551
- field = t.schema.field(parameters.column_name)
918
+ t = table_info.version()
919
+ field_index = t.schema.get_field_index(parameters.column_name)
920
+ field = t.schema.field(parameters.column_name)
1552
921
 
1553
- table_info.update_table(
1554
- t.set_column(
1555
- field_index,
1556
- field.with_nullable(True),
1557
- t.column(field_index),
922
+ table_info.update_table(
923
+ t.set_column(
924
+ field_index,
925
+ field.with_nullable(True),
926
+ t.column(field_index),
927
+ )
1558
928
  )
1559
- )
1560
929
 
1561
- database.version += 1
930
+ database.version += 1
1562
931
 
1563
932
  return table_info.flight_info(
1564
933
  name=parameters.name,
@@ -1574,33 +943,32 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1574
943
  ) -> flight.FlightInfo:
1575
944
  assert context.caller is not None
1576
945
 
1577
- self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
1578
- library = self.contents[context.caller.token.token]
1579
- database = library.by_name(parameters.catalog)
1580
- schema = database.by_name(parameters.schema_name)
946
+ with DatabaseLibraryContext(context.caller.token.token) as library:
947
+ database = library.by_name(parameters.catalog)
948
+ schema = database.by_name(parameters.schema_name)
1581
949
 
1582
- table_info = schema.by_name("table", parameters.name)
950
+ table_info = schema.by_name("table", parameters.name)
1583
951
 
1584
- # Defaults are set as metadata on a field.
952
+ # Defaults are set as metadata on a field.
1585
953
 
1586
- t = table_info.version()
1587
- column_name = parameters.column_schema.field(0).name
1588
- field_index = t.schema.get_field_index(column_name)
1589
- field = t.schema.field(column_name)
954
+ t = table_info.version()
955
+ column_name = parameters.column_schema.field(0).name
956
+ field_index = t.schema.get_field_index(column_name)
957
+ field = t.schema.field(column_name)
1590
958
 
1591
- new_type = parameters.column_schema.field(0).type
1592
- new_field = pa.field(field.name, new_type, metadata=field.metadata)
1593
- new_data = pc.cast(t.column(field_index), new_type)
959
+ new_type = parameters.column_schema.field(0).type
960
+ new_field = pa.field(field.name, new_type, metadata=field.metadata)
961
+ new_data = pc.cast(t.column(field_index), new_type)
1594
962
 
1595
- table_info.update_table(
1596
- t.set_column(
1597
- field_index,
1598
- new_field,
1599
- new_data,
963
+ table_info.update_table(
964
+ t.set_column(
965
+ field_index,
966
+ new_field,
967
+ new_data,
968
+ )
1600
969
  )
1601
- )
1602
970
 
1603
- database.version += 1
971
+ database.version += 1
1604
972
 
1605
973
  return table_info.flight_info(
1606
974
  name=parameters.name,
@@ -1608,6 +976,71 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1608
976
  schema_name=parameters.schema_name,
1609
977
  )[0]
1610
978
 
979
+ def action_column_statistics(
980
+ self,
981
+ *,
982
+ context: base_server.CallContext[auth.Account, auth.AccountToken],
983
+ parameters: parameter_types.ColumnStatistics,
984
+ ) -> pa.Table:
985
+ assert context.caller is not None
986
+
987
+ descriptor_parts = descriptor_unpack_(parameters.flight_descriptor)
988
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
989
+ database = library.by_name(descriptor_parts.catalog_name)
990
+ schema = database.by_name(descriptor_parts.schema_name)
991
+
992
+ assert descriptor_parts.type == "table"
993
+ table = schema.by_name("table", descriptor_parts.name)
994
+
995
+ contents = table.version().column(parameters.column_name)
996
+ # Since the table is a PyArrow table, the statistics have to be computed here from the column contents.
997
+ not_null_count = pc.count(contents, "only_valid").as_py()
998
+ null_count = pc.count(contents, "only_null").as_py()
999
+ distinct_count = len(set(contents.to_pylist()))
1000
+ sorted_contents = sorted(filter(lambda x: x is not None, contents.to_pylist()))
1001
+ min_value = sorted_contents[0]
1002
+ max_value = sorted_contents[-1]
1003
+
1004
+ additional_values = {}
1005
+ additional_schema_fields = []
1006
+ if contents.type in (pa.string(), pa.utf8(), pa.binary()):
1007
+ max_length = pc.max(pc.binary_length(contents)).as_py()
1008
+
1009
+ additional_values = {"max_string_length": max_length, "contains_unicode": contents.type == pa.utf8()}
1010
+ additional_schema_fields = [
1011
+ pa.field("max_string_length", pa.uint64()),
1012
+ pa.field("contains_unicode", pa.bool_()),
1013
+ ]
1014
+
1015
+ if contents.type == pa.uuid():
1016
+ # For UUIDs, convert the min/max values to their raw bytes for the output.
1017
+ min_value = min_value.bytes
1018
+ max_value = max_value.bytes
1019
+
1020
+ result_table = pa.Table.from_pylist(
1021
+ [
1022
+ {
1023
+ "has_not_null": not_null_count > 0,
1024
+ "has_null": null_count > 0,
1025
+ "distinct_count": distinct_count,
1026
+ "min": min_value,
1027
+ "max": max_value,
1028
+ **additional_values,
1029
+ }
1030
+ ],
1031
+ schema=pa.schema(
1032
+ [
1033
+ pa.field("has_not_null", pa.bool_()),
1034
+ pa.field("has_null", pa.bool_()),
1035
+ pa.field("distinct_count", pa.uint64()),
1036
+ pa.field("min", contents.type),
1037
+ pa.field("max", contents.type),
1038
+ *additional_schema_fields,
1039
+ ]
1040
+ ),
1041
+ )
1042
+ return result_table
1043
+
1611
1044
  def impl_do_get(
1612
1045
  self,
1613
1046
  *,
@@ -1619,58 +1052,60 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1619
1052
  ticket_data = flight_handling.decode_ticket_model(ticket, FlightTicketData)
1620
1053
 
1621
1054
  descriptor_parts = descriptor_unpack_(ticket_data.descriptor)
1622
- library = self.contents[context.caller.token.token]
1623
- database = library.by_name(descriptor_parts.catalog_name)
1624
- schema = database.by_name(descriptor_parts.schema_name)
1625
-
1626
- if descriptor_parts.type == "table":
1627
- table = schema.by_name("table", descriptor_parts.name)
1055
+ with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
1056
+ database = library.by_name(descriptor_parts.catalog_name)
1057
+ schema = database.by_name(descriptor_parts.schema_name)
1628
1058
 
1629
- if ticket_data.at_unit == "VERSION":
1630
- assert ticket_data.at_value is not None
1059
+ if descriptor_parts.type == "table":
1060
+ table = schema.by_name("table", descriptor_parts.name)
1631
1061
 
1632
- # Check if at_value is an integer but currently is a string.
1633
- if not re.match(r"^\d+$", ticket_data.at_value):
1634
- raise flight.FlightServerError(f"Invalid version: {ticket_data.at_value}")
1062
+ if ticket_data.at_unit == "VERSION":
1063
+ assert ticket_data.at_value is not None
1635
1064
 
1636
- table_version = table.version(int(ticket_data.at_value))
1637
- elif ticket_data.at_unit == "TIMESTAMP":
1638
- raise flight.FlightServerError("Timestamp not supported for table versioning")
1639
- else:
1640
- table_version = table.version()
1641
-
1642
- if descriptor_parts.schema_name == "test_predicate_pushdown" and ticket_data.where_clause is not None:
1643
- # We are going to do the predicate pushdown for filtering the data we have in memory.
1644
- # At this point if we have JSON filters we should test that we can decode them.
1645
- with duckdb.connect(":memory:") as connection:
1646
- connection.execute("SET TimeZone = 'UTC'")
1647
- sql = f"select * from table_version where {ticket_data.where_clause}"
1648
- try:
1649
- results = connection.execute(sql).fetch_arrow_table()
1650
- except Exception as e:
1651
- raise flight.FlightServerError(f"Failed to execute predicate pushdown: {e} sql: {sql}") from e
1652
- table_version = results
1653
-
1654
- return flight.RecordBatchStream(table_version)
1655
- elif descriptor_parts.type == "table_function":
1656
- table_function = schema.by_name("table_function", descriptor_parts.name)
1065
+ # at_value arrives as a string; check that it contains only digits before converting it to an integer.
1066
+ if not re.match(r"^\d+$", ticket_data.at_value):
1067
+ raise flight.FlightServerError(f"Invalid version: {ticket_data.at_value}")
1657
1068
 
1658
- output_schema = table_function.output_schema(
1659
- ticket_data.table_function_parameters,
1660
- ticket_data.table_function_input_schema,
1661
- )
1069
+ table_version = table.version(int(ticket_data.at_value))
1070
+ elif ticket_data.at_unit == "TIMESTAMP":
1071
+ raise flight.FlightServerError("Timestamp not supported for table versioning")
1072
+ else:
1073
+ table_version = table.version()
1074
+
1075
+ if descriptor_parts.schema_name == "test_predicate_pushdown" and ticket_data.where_clause is not None:
1076
+ # We are going to do the predicate pushdown for filtering the data we have in memory.
1077
+ # At this point if we have JSON filters we should test that we can decode them.
1078
+ with duckdb.connect(":memory:") as connection:
1079
+ connection.execute("SET TimeZone = 'UTC'")
1080
+ sql = f"select * from table_version where {ticket_data.where_clause}"
1081
+ try:
1082
+ results = connection.execute(sql).fetch_arrow_table()
1083
+ except Exception as e:
1084
+ raise flight.FlightServerError(
1085
+ f"Failed to execute predicate pushdown: {e} sql: {sql}"
1086
+ ) from e
1087
+ table_version = results
1088
+
1089
+ return flight.RecordBatchStream(table_version)
1090
+ elif descriptor_parts.type == "table_function":
1091
+ table_function = schema.by_name("table_function", descriptor_parts.name)
1092
+
1093
+ output_schema = table_function.output_schema(
1094
+ ticket_data.table_function_parameters,
1095
+ ticket_data.table_function_input_schema,
1096
+ )
1662
1097
 
1663
- return flight.GeneratorStream(
1664
- output_schema,
1665
- table_function.handler(
1666
- parameter_types.TableFunctionParameters(
1667
- parameters=ticket_data.table_function_parameters, where_clause=ticket_data.where_clause
1668
- ),
1098
+ return flight.GeneratorStream(
1669
1099
  output_schema,
1670
- ),
1671
- )
1672
- else:
1673
- raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1100
+ table_function.handler(
1101
+ parameter_types.TableFunctionParameters(
1102
+ parameters=ticket_data.table_function_parameters, where_clause=ticket_data.where_clause
1103
+ ),
1104
+ output_schema,
1105
+ ),
1106
+ )
1107
+ else:
1108
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1674
1109
 
1675
1110
  def action_table_function_flight_info(
1676
1111
  self,
@@ -1680,14 +1115,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1680
1115
  ) -> flight.FlightInfo:
1681
1116
  assert context.caller is not None
1682
1117
 
1683
- library = self.contents[context.caller.token.token]
1684
- descriptor_parts = descriptor_unpack_(parameters.descriptor)
1118
+ with DatabaseLibraryContext(context.caller.token.token) as library:
1119
+ descriptor_parts = descriptor_unpack_(parameters.descriptor)
1685
1120
 
1686
- database = library.by_name(descriptor_parts.catalog_name)
1687
- schema = database.by_name(descriptor_parts.schema_name)
1121
+ database = library.by_name(descriptor_parts.catalog_name)
1122
+ schema = database.by_name(descriptor_parts.schema_name)
1688
1123
 
1689
- # Now to get the table function, its a bit harder since they are named by action.
1690
- table_function = schema.by_name("table_function", descriptor_parts.name)
1124
+ # Now to get the table function, it's a bit harder since they are named by action.
1125
+ table_function = schema.by_name("table_function", descriptor_parts.name)
1691
1126
 
1692
1127
  return table_function.flight_info(
1693
1128
  name=descriptor_parts.name,
@@ -1705,40 +1140,40 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1705
1140
  ) -> flight.FlightInfo:
1706
1141
  assert context.caller is not None
1707
1142
 
1708
- library = self.contents[context.caller.token.token]
1709
- descriptor_parts = descriptor_unpack_(parameters.descriptor)
1710
-
1711
- database = library.by_name(descriptor_parts.catalog_name)
1712
- schema = database.by_name(descriptor_parts.schema_name)
1713
-
1714
- if descriptor_parts.type == "table_function":
1715
- raise flight.FlightServerError("Table function flight info not supported")
1716
- elif descriptor_parts.type == "table":
1717
- table = schema.by_name("table", descriptor_parts.name)
1718
-
1719
- version_id = None
1720
- if parameters.at_unit is not None:
1721
- if parameters.at_unit == "VERSION":
1722
- assert parameters.at_value is not None
1723
-
1724
- # Check if at_value is an integer but currently is a string.
1725
- if not re.match(r"^\d+$", parameters.at_value):
1726
- raise flight.FlightServerError(f"Invalid version: {parameters.at_value}")
1727
-
1728
- version_id = int(parameters.at_value)
1729
- elif parameters.at_unit == "TIMESTAMP":
1730
- raise flight.FlightServerError("Timestamp not supported for table versioning")
1731
- else:
1732
- raise flight.FlightServerError(f"Unsupported at_unit: {parameters.at_unit}")
1733
-
1734
- return table.flight_info(
1735
- name=descriptor_parts.name,
1736
- catalog_name=descriptor_parts.catalog_name,
1737
- schema_name=descriptor_parts.schema_name,
1738
- version=version_id,
1739
- )[0]
1740
- else:
1741
- raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1143
+ with DatabaseLibraryContext(context.caller.token.token) as library:
1144
+ descriptor_parts = descriptor_unpack_(parameters.descriptor)
1145
+
1146
+ database = library.by_name(descriptor_parts.catalog_name)
1147
+ schema = database.by_name(descriptor_parts.schema_name)
1148
+
1149
+ if descriptor_parts.type == "table_function":
1150
+ raise flight.FlightServerError("Table function flight info not supported")
1151
+ elif descriptor_parts.type == "table":
1152
+ table = schema.by_name("table", descriptor_parts.name)
1153
+
1154
+ version_id = None
1155
+ if parameters.at_unit is not None:
1156
+ if parameters.at_unit == "VERSION":
1157
+ assert parameters.at_value is not None
1158
+
1159
+ # at_value arrives as a string; check that it contains only digits before converting it to an integer.
1160
+ if not re.match(r"^\d+$", parameters.at_value):
1161
+ raise flight.FlightServerError(f"Invalid version: {parameters.at_value}")
1162
+
1163
+ version_id = int(parameters.at_value)
1164
+ elif parameters.at_unit == "TIMESTAMP":
1165
+ raise flight.FlightServerError("Timestamp not supported for table versioning")
1166
+ else:
1167
+ raise flight.FlightServerError(f"Unsupported at_unit: {parameters.at_unit}")
1168
+
1169
+ return table.flight_info(
1170
+ name=descriptor_parts.name,
1171
+ catalog_name=descriptor_parts.catalog_name,
1172
+ schema_name=descriptor_parts.schema_name,
1173
+ version=version_id,
1174
+ )[0]
1175
+ else:
1176
+ raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1742
1177
 
1743
1178
  def exchange_scalar_function(
1744
1179
  self,
@@ -1755,10 +1190,10 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
1755
1190
  if descriptor_parts.type != "scalar_function":
1756
1191
  raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
1757
1192
 
1758
- library = self.contents[context.caller.token.token]
1759
- database = library.by_name(descriptor_parts.catalog_name)
1760
- schema = database.by_name(descriptor_parts.schema_name)
1761
- scalar_function_info = schema.by_name("scalar_function", descriptor_parts.name)
1193
+ with DatabaseLibraryContext(context.caller.token.token) as library:
1194
+ database = library.by_name(descriptor_parts.catalog_name)
1195
+ schema = database.by_name(descriptor_parts.schema_name)
1196
+ scalar_function_info = schema.by_name("scalar_function", descriptor_parts.name)
1762
1197
 
1763
1198
  writer.begin(scalar_function_info.output_schema)
1764
1199