query-farm-airport-test-server 0.1.0-py3-none-any.whl → 0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- query_farm_airport_test_server/collatz.sql +30 -0
- query_farm_airport_test_server/database_impl.py +1122 -0
- query_farm_airport_test_server/repro.py +17 -0
- query_farm_airport_test_server/repro2.py +15 -0
- query_farm_airport_test_server/server.py +582 -1147
- {query_farm_airport_test_server-0.1.0.dist-info → query_farm_airport_test_server-0.1.1.dist-info}/METADATA +1 -1
- query_farm_airport_test_server-0.1.1.dist-info/RECORD +13 -0
- query_farm_airport_test_server-0.1.0.dist-info/RECORD +0 -9
- {query_farm_airport_test_server-0.1.0.dist-info → query_farm_airport_test_server-0.1.1.dist-info}/WHEEL +0 -0
- {query_farm_airport_test_server-0.1.0.dist-info → query_farm_airport_test_server-0.1.1.dist-info}/entry_points.txt +0 -0
query_farm_airport_test_server/server.py

@@ -1,9 +1,8 @@
 import hashlib
 import json
 import re
-from collections.abc import
-from
-from typing import Any, Literal, TypeVar, overload
+from collections.abc import Generator, Iterator
+from typing import Any, TypeVar

 import click
 import duckdb
@@ -24,7 +23,14 @@ import query_farm_flight_server.server as base_server
 import structlog
 from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

-from .
+from .database_impl import (
+    DatabaseContents,
+    DatabaseLibrary,
+    DatabaseLibraryContext,
+    SchemaCollection,
+    TableInfo,
+    descriptor_unpack_,
+)

 log = structlog.get_logger()

@@ -85,317 +91,6 @@ def check_schema_is_subset_of_schema(existing_schema: pa.Schema, new_schema: pa.
 # return flight_handling.FlightTicketData.unpack(src)


-@dataclass
-class TableFunctionDynamicOutput:
-    # The method that will determine the output schema from the input parameters
-    schema_creator: Callable[[pa.RecordBatch, pa.Schema | None], pa.Schema]
-
-    # The default parameters for the function, if not called with any.
-    default_values: tuple[pa.RecordBatch, pa.Schema | None]
-
-
-@dataclass
-class TableFunction:
-    # The input schema for the function.
-    input_schema: pa.Schema
-
-    output_schema_source: pa.Schema | TableFunctionDynamicOutput
-
-    # The function to call to process a chunk of rows.
-    handler: Callable[
-        [parameter_types.TableFunctionParameters, pa.Schema],
-        Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch],
-    ]
-
-    estimated_rows: int | Callable[[parameter_types.TableFunctionFlightInfo], int] = -1
-
-    def output_schema(
-        self,
-        parameters: pa.RecordBatch | None = None,
-        input_schema: pa.Schema | None = None,
-    ) -> pa.Schema:
-        if isinstance(self.output_schema_source, pa.Schema):
-            return self.output_schema_source
-        if parameters is None:
-            return self.output_schema_source.schema_creator(*self.output_schema_source.default_values)
-        assert isinstance(parameters, pa.RecordBatch)
-        result = self.output_schema_source.schema_creator(parameters, input_schema)
-        return result
-
-    def flight_info(
-        self,
-        *,
-        name: str,
-        catalog_name: str,
-        schema_name: str,
-        parameters: parameter_types.TableFunctionFlightInfo | None = None,
-    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
-        """
-        Often its necessary to create a FlightInfo object
-        standardize doing that here.
-        """
-        assert name != ""
-        assert catalog_name != ""
-        assert schema_name != ""
-
-        if isinstance(self.estimated_rows, int):
-            estimated_rows = self.estimated_rows
-        else:
-            assert parameters is not None
-            estimated_rows = self.estimated_rows(parameters)
-
-        metadata = flight_inventory.FlightSchemaMetadata(
-            type="table_function",
-            catalog=catalog_name,
-            schema=schema_name,
-            name=name,
-            comment=None,
-            input_schema=self.input_schema,
-        )
-        flight_info = flight.FlightInfo(
-            self.output_schema(parameters.parameters, parameters.table_input_schema)
-            if parameters
-            else self.output_schema(),
-            # This will always be the same descriptor, so that we can use the action
-            # name to determine which which table function to execute.
-            descriptor_pack_(catalog_name, schema_name, "table_function", name),
-            [],
-            estimated_rows,
-            -1,
-            app_metadata=metadata.serialize(),
-        )
-        return (flight_info, metadata)
-
-
-@dataclass
-class ScalarFunction:
-    # The input schema for the function.
-    input_schema: pa.Schema
-    # The output schema for the function, should only have a single column.
-    output_schema: pa.Schema
-
-    # The function to call to process a chunk of rows.
-    handler: Callable[[pa.Table], pa.Array]
-
-    def flight_info(
-        self, *, name: str, catalog_name: str, schema_name: str
-    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
-        """
-        Often its necessary to create a FlightInfo object
-        standardize doing that here.
-        """
-        metadata = flight_inventory.FlightSchemaMetadata(
-            type="scalar_function",
-            catalog=catalog_name,
-            schema=schema_name,
-            name=name,
-            comment=None,
-            input_schema=self.input_schema,
-        )
-        flight_info = flight.FlightInfo(
-            self.output_schema,
-            descriptor_pack_(catalog_name, schema_name, "scalar_function", name),
-            [],
-            -1,
-            -1,
-            app_metadata=metadata.serialize(),
-        )
-        return (flight_info, metadata)
-
-
-@dataclass
-class TableInfo:
-    # To enable version history keep track of tables.
-    table_versions: list[pa.Table] = field(default_factory=list)
-
-    # the next row id to assign.
-    row_id_counter: int = 0
-
-    def update_table(self, table: pa.Table) -> None:
-        assert table is not None
-        assert isinstance(table, pa.Table)
-        self.table_versions.append(table)
-
-    def version(self, version: int | None = None) -> pa.Table:
-        """
-        Get the version of the table.
-        """
-        assert len(self.table_versions) > 0
-        if version is None:
-            return self.table_versions[-1]
-
-        assert version < len(self.table_versions)
-        return self.table_versions[version]
-
-    def flight_info(
-        self,
-        *,
-        name: str,
-        catalog_name: str,
-        schema_name: str,
-        version: int | None = None,
-    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
-        """
-        Often its necessary to create a FlightInfo object for the table,
-        standardize doing that here.
-        """
-        metadata = flight_inventory.FlightSchemaMetadata(
-            type="table",
-            catalog=catalog_name,
-            schema=schema_name,
-            name=name,
-            comment=None,
-        )
-        flight_info = flight.FlightInfo(
-            self.version(version).schema,
-            descriptor_pack_(catalog_name, schema_name, "table", name),
-            [],
-            -1,
-            -1,
-            app_metadata=metadata.serialize(),
-        )
-        return (flight_info, metadata)
-
-
-ObjectTypeName = Literal["table", "scalar_function", "table_function"]
-
-
-@dataclass
-class SchemaCollection:
-    tables_by_name: CaseInsensitiveDict[TableInfo] = field(default_factory=CaseInsensitiveDict[TableInfo])
-    scalar_functions_by_name: CaseInsensitiveDict[ScalarFunction] = field(
-        default_factory=CaseInsensitiveDict[ScalarFunction]
-    )
-    table_functions_by_name: CaseInsensitiveDict[TableFunction] = field(
-        default_factory=CaseInsensitiveDict[TableFunction]
-    )
-
-    def containers(
-        self,
-    ) -> list[
-        CaseInsensitiveDict[TableInfo] | CaseInsensitiveDict[ScalarFunction] | CaseInsensitiveDict[TableFunction]
-    ]:
-        return [
-            self.tables_by_name,
-            self.scalar_functions_by_name,
-            self.table_functions_by_name,
-        ]
-
-    @overload
-    def by_name(self, type: Literal["table"], name: str) -> TableInfo: ...
-
-    @overload
-    def by_name(self, type: Literal["scalar_function"], name: str) -> ScalarFunction: ...
-
-    @overload
-    def by_name(self, type: Literal["table_function"], name: str) -> TableFunction: ...
-
-    def by_name(self, type: ObjectTypeName, name: str) -> TableInfo | ScalarFunction | TableFunction:
-        assert name is not None
-        assert name != ""
-        if type == "table":
-            table = self.tables_by_name.get(name)
-            if not table:
-                raise flight.FlightServerError(f"Table {name} does not exist.")
-            return table
-        elif type == "scalar_function":
-            scalar_function = self.scalar_functions_by_name.get(name)
-            if not scalar_function:
-                raise flight.FlightServerError(f"Scalar function {name} does not exist.")
-            return scalar_function
-        elif type == "table_function":
-            table_function = self.table_functions_by_name.get(name)
-            if not table_function:
-                raise flight.FlightServerError(f"Table function {name} does not exist.")
-            return table_function
-
-
-@dataclass
-class DatabaseContents:
-    # Collection of schemas by name.
-    schemas_by_name: CaseInsensitiveDict[SchemaCollection] = field(
-        default_factory=CaseInsensitiveDict[SchemaCollection]
-    )
-
-    # The version of the database, updated on each schema change.
-    version: int = 1
-
-    def by_name(self, name: str) -> SchemaCollection:
-        if name not in self.schemas_by_name:
-            raise flight.FlightServerError(f"Schema {name} does not exist.")
-        return self.schemas_by_name[name]
-
-
-@dataclass
-class DescriptorParts:
-    """
-    The fields that are encoded in the flight descriptor.
-    """
-
-    catalog_name: str
-    schema_name: str
-    type: ObjectTypeName
-    name: str
-
-
-@dataclass
-class DatabaseLibrary:
-    """
-    The database library, which contains all of the databases, organized by token.
-    """
-
-    # Collection of databases by token.
-    databases_by_name: CaseInsensitiveDict[DatabaseContents] = field(
-        default_factory=CaseInsensitiveDict[DatabaseContents]
-    )
-
-    def by_name(self, name: str) -> DatabaseContents:
-        if name not in self.databases_by_name:
-            raise flight.FlightServerError(f"Database {name} does not exist.")
-        return self.databases_by_name[name]
-
-
-def descriptor_pack_(
-    catalog_name: str,
-    schema_name: str,
-    type: ObjectTypeName,
-    name: str,
-) -> flight.FlightDescriptor:
-    """
-    Pack the descriptor into a FlightDescriptor.
-    """
-    return flight.FlightDescriptor.for_path(f"{catalog_name}/{schema_name}/{type}/{name}")
-
-
-def descriptor_unpack_(descriptor: flight.FlightDescriptor) -> DescriptorParts:
-    """
-    Split the descriptor into its components.
-    """
-    assert descriptor.descriptor_type == flight.DescriptorType.PATH
-    assert len(descriptor.path) == 1
-    path = descriptor.path[0].decode("utf-8")
-    parts = path.split("/")
-    if len(parts) != 4:
-        raise flight.FlightServerError(f"Invalid descriptor path: {path}")
-
-    descriptor_type: ObjectTypeName
-    if parts[2] == "table":
-        descriptor_type = "table"
-    elif parts[2] == "scalar_function":
-        descriptor_type = "scalar_function"
-    elif parts[2] == "table_function":
-        descriptor_type = "table_function"
-    else:
-        raise flight.FlightServerError(f"Invalid descriptor type: {parts[2]}")
-
-    return DescriptorParts(
-        catalog_name=parts[0],
-        schema_name=parts[1],
-        type=descriptor_type,
-        name=parts[3],
-    )
-
-
 class FlightTicketData(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
     descriptor: flight.FlightDescriptor
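The block removed above (these helpers now ship in `database_impl.py`, per the new import list) encodes every catalog object as a four-part, slash-separated Flight path. A minimal round-trip sketch of that packing scheme using only `pyarrow`; the `db1/main/table/orders` values are invented for illustration:

```python
import pyarrow.flight as flight

# Pack catalog/schema/type/name into a path descriptor, as descriptor_pack_ does.
descriptor = flight.FlightDescriptor.for_path("db1/main/table/orders")

# Unpack it the way descriptor_unpack_ does: one UTF-8 path with exactly
# four slash-separated parts.
assert descriptor.descriptor_type == flight.DescriptorType.PATH
catalog_name, schema_name, object_type, name = descriptor.path[0].decode("utf-8").split("/")
print(catalog_name, schema_name, object_type, name)  # db1 main table orders
```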
@@ -449,9 +144,6 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
         self._auth_manager = auth_manager
         super().__init__(location=location, **kwargs)

-        # token, database name, schema, table_name
-        self.contents: dict[str, DatabaseLibrary] = {}
-
         self.ROWID_FIELD_NAME = "rowid"
         self.rowid_field = pa.field(self.ROWID_FIELD_NAME, pa.int64(), metadata={"is_rowid": "1"})

@@ -464,49 +156,51 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
         assert context.caller is not None

         descriptor_parts = descriptor_unpack_(parameters.descriptor)
-
-
-
-
-
-
-
-
-
-
-
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+
+            filter_sql_where_clause: str | None = None
+            if parameters.parameters.json_filters is not None:
+                context.logger.debug("duckdb_input", input=json.dumps(parameters.parameters.json_filters.filters))
+                filter_sql_where_clause, filter_sql_field_type_info = (
+                    query_farm_duckdb_json_serialization.expression.convert_to_sql(
+                        source=parameters.parameters.json_filters.filters,
+                        bound_column_names=parameters.parameters.json_filters.column_binding_names_by_index,
+                    )
                 )
-
-
-        filter_sql_where_clause = None
+                if filter_sql_where_clause == "":
+                    filter_sql_where_clause = None

-
-
+            if descriptor_parts.type == "table":
+                table_info = schema.by_name("table", descriptor_parts.name)

-
-
-
-
-
-
-
-            return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
-        elif descriptor_parts.type == "table_function":
-            # So the table function may not exist, because its a dynamic descriptor.
-
-            schema.by_name("table_function", descriptor_parts.name)
+                ticket_data = FlightTicketData(
+                    descriptor=parameters.descriptor,
+                    where_clause=filter_sql_where_clause,
+                    at_unit=parameters.parameters.at_unit,
+                    at_value=parameters.parameters.at_value,
+                )

-
-
-
-
-
-
-
-
-
-
-
+                if table_info.endpoint_generator is not None:
+                    return table_info.endpoint_generator(ticket_data)
+                return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
+            elif descriptor_parts.type == "table_function":
+                # So the table function may not exist, because its a dynamic descriptor.
+
+                schema.by_name("table_function", descriptor_parts.name)
+
+                ticket_data = FlightTicketData(
+                    descriptor=parameters.descriptor,
+                    where_clause=filter_sql_where_clause,
+                    table_function_parameters=parameters.parameters.table_function_parameters,
+                    table_function_input_schema=parameters.parameters.table_function_input_schema,
+                    at_unit=parameters.parameters.at_unit,
+                    at_value=parameters.parameters.at_value,
+                )
+                return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
+            else:
+                raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

     def action_list_schemas(
         self,
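From here on, every handler swaps direct access to the old `self.contents` dict for the `DatabaseLibraryContext` context manager imported from `database_impl`. That module's code is not part of this diff, so the following is only a sketch of the shape the call sites imply (one library per auth token, an optional `readonly` flag), not the actual implementation:

```python
class DatabaseLibrary:
    """Stand-in for database_impl.DatabaseLibrary; holds one token's databases."""

    def __init__(self) -> None:
        self.databases_by_name: dict[str, object] = {}


class DatabaseLibraryContext:
    """Hypothetical shape of database_impl.DatabaseLibraryContext."""

    _libraries: dict[str, DatabaseLibrary] = {}

    def __init__(self, token: str, readonly: bool = False) -> None:
        self.token = token
        self.readonly = readonly

    def __enter__(self) -> DatabaseLibrary:
        # One library per auth token, created on first use.
        return self._libraries.setdefault(self.token, DatabaseLibrary())

    def __exit__(self, exc_type, exc, tb) -> None:
        # A non-readonly context would be the natural place to commit changes;
        # the real class likely also adds locking and persistence.
        return None


# Usage mirroring the call sites in this diff:
with DatabaseLibraryContext("some-token", readonly=True) as library:
    print(library.databases_by_name)
```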
@@ -516,41 +210,40 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> base_server.AirportSerializedCatalogRoot:
         assert context.caller is not None

-
-
-        database = library.by_name(parameters.catalog_name)
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
+            database = library.by_name(parameters.catalog_name)

-
+            dynamic_inventory: dict[str, dict[str, list[flight_inventory.FlightInventoryWithMetadata]]] = {}

-
+            catalog_contents = dynamic_inventory.setdefault(parameters.catalog_name, {})

-
-
-
-
-
-
-
-
-
+            for schema_name, schema in database.schemas_by_name.items():
+                schema_contents = catalog_contents.setdefault(schema_name, [])
+                for coll in schema.containers():
+                    for name, obj in coll.items():
+                        schema_contents.append(
+                            obj.flight_info(
+                                name=name,
+                                catalog_name=parameters.catalog_name,
+                                schema_name=schema_name,
+                            )
                         )
-            )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            return flight_inventory.upload_and_generate_schema_list(
+                upload_parameters=flight_inventory.UploadParameters(
+                    s3_client=None,
+                    base_url="http://localhost",
+                    bucket_name="test_bucket",
+                    bucket_prefix="test_prefix",
+                ),
+                flight_service_name=self.service_name,
+                flight_inventory=dynamic_inventory,
+                schema_details={},
+                skip_upload=True,
+                serialize_inline=True,
+                catalog_version=1,
+                catalog_version_fixed=False,
+            )

     def impl_list_flights(
         self,
@@ -559,17 +252,16 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
         criteria: bytes,
     ) -> Iterator[flight.FlightInfo]:
         assert context.caller is not None
-
-        library = self.contents[context.caller.token.token]
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:

-
-
-
-
-
-
+            def yield_flight_infos() -> Generator[flight.FlightInfo, None, None]:
+                for db_name, db in library.databases_by_name.items():
+                    for schema_name, schema in db.schemas_by_name.items():
+                        for coll in schema.containers():
+                            for name, obj in coll.items():
+                                yield obj.flight_info(name=name, catalog_name=db_name, schema_name=schema_name)[0]

-
+            return yield_flight_infos()

     def impl_get_flight_info(
         self,
@@ -578,19 +270,18 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
         descriptor: flight.FlightDescriptor,
     ) -> flight.FlightInfo:
         assert context.caller is not None
-        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())

         descriptor_parts = descriptor_unpack_(descriptor)
-
-
-
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)

-
-
-
-
-
-
+            obj = schema.by_name(descriptor_parts.type, descriptor_parts.name)
+            return obj.flight_info(
+                name=descriptor_parts.name,
+                catalog_name=descriptor_parts.catalog_name,
+                schema_name=descriptor_parts.schema_name,
+            )[0]

     def action_catalog_version(
         self,
@@ -600,16 +291,15 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> base_server.GetCatalogVersionResult:
         assert context.caller is not None

-
-
-        database = library.by_name(parameters.catalog_name)
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
+            database = library.by_name(parameters.catalog_name)

-
-
-
-
-
-
+            context.logger.debug(
+                "catalog_version_result",
+                catalog_name=parameters.catalog_name,
+                version=database.version,
+            )
+            return base_server.GetCatalogVersionResult(catalog_version=database.version, is_fixed=False)

     def action_create_transaction(
         self,
@@ -627,26 +317,25 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> base_server.AirportSerializedContentsWithSHA256Hash:
         assert context.caller is not None

-
-
-        database = library.by_name(parameters.catalog_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog_name)

-
-
+            if database.schemas_by_name.get(parameters.schema_name) is not None:
+                raise flight.FlightServerError(f"Schema {parameters.schema_name} already exists")

-
-
+            database.schemas_by_name[parameters.schema_name] = SchemaCollection()
+            database.version += 1

-
-
-
-
-
+            # FIXME: this needs to be handled better on the server side...
+            # rather than calling into internal methods.
+            packed_data = msgpack.packb([])
+            assert packed_data
+            compressed_data = schema_uploader._compress_and_prefix_with_length(packed_data, compression_level=3)

-
-
-
-
+            empty_hash = hashlib.sha256(compressed_data).hexdigest()
+            return base_server.AirportSerializedContentsWithSHA256Hash(
+                url=None, sha256=empty_hash, serialized=compressed_data
+            )

     def action_drop_table(
         self,
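The new transaction body above serializes an empty inventory with `msgpack`, compresses it, and returns its SHA-256 alongside the payload. The compression helper (`schema_uploader._compress_and_prefix_with_length`) is internal to `query_farm_flight_server` and not shown in this diff, so this sketch hashes the uncompressed payload only:

```python
import hashlib

import msgpack

# Serialize an empty list the way the new code does.
packed_data = msgpack.packb([])
assert packed_data

# The real code compresses and length-prefixes packed_data first; that helper
# is not part of this diff, so hash the raw payload here instead.
digest = hashlib.sha256(packed_data).hexdigest()
print(len(packed_data), digest)
```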
@@ -656,15 +345,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> None:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog_name)
+            schema = database.by_name(parameters.schema_name)

-
+            schema.by_name("table", parameters.name)

-
-
+            del schema.tables_by_name[parameters.name]
+            database.version += 1

     def action_drop_schema(
         self,
@@ -674,15 +362,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> None:
         assert context.caller is not None

-
-
-        database = library.by_name(parameters.catalog_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog_name)

-
-
+            if database.schemas_by_name.get(parameters.name) is None:
+                raise flight.FlightServerError(f"Schema '{parameters.name}' does not exist")

-
-
+            del database.schemas_by_name[parameters.name]
+            database.version += 1

     def action_create_table(
         self,
@@ -692,28 +379,27 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog_name)
+            schema = database.by_name(parameters.schema_name)

-
-
-
-
+            if parameters.table_name in schema.tables_by_name:
+                raise flight.FlightServerError(
+                    f"Table {parameters.table_name} already exists for token {context.caller.token}"
+                )

-
-
+            # FIXME: may want to add a row_id column that is not visable to the user, so that inserts and
+            # deletes can be tested.

-
+            assert "_rowid" not in parameters.arrow_schema.names

-
+            schema_with_row_id = parameters.arrow_schema.append(self.rowid_field)

-
+            table_info = TableInfo([schema_with_row_id.empty_table()], 0)

-
+            schema.tables_by_name[parameters.table_name] = table_info

-
+            database.version += 1

         return table_info.flight_info(
             name=parameters.table_name,
@@ -731,348 +417,36 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth

         if action.type == "reset":
             context.logger.debug("Resetting server state")
-
-            del self.contents[context.caller.token.token]
+            DatabaseLibrary.reset(context.caller.token.token)
             return iter([])
+        elif action.type == "generate_error":
+            error_name = action.body.to_pybytes().decode("utf-8")
+            known_errors = {
+                "flight_unavailable": flight.FlightUnavailableError,
+                "flight_server_error": flight.FlightServerError,
+                "flight_unauthenticated": flight.FlightUnauthenticatedError,
+            }
+            if error_name in known_errors:
+                raise known_errors[error_name](f"Testing error: {error_name}")
+            else:
+                context.logger.error("Unknown error type", error_name=error_name)
+                raise flight.FlightServerError(f"Unknown error type: {error_name}")
         elif action.type == "create_database":
             database_name = action.body.to_pybytes().decode("utf-8")
             context.logger.debug("Creating database", database_name=database_name)
-
-
-
-
-            # Since we are creating a new database, lets load it with a few example
-            # scalar functions.
-
-            def add_handler(table: pa.Table) -> pa.Array:
-                assert table.num_columns == 2
-                return pc.add(table.column(0), table.column(1))
-
-            def uppercase_handler(table: pa.Table) -> pa.Array:
-                assert table.num_columns == 1
-                return pc.utf8_upper(table.column(0))
-
-            def any_type_handler(table: pa.Table) -> pa.Array:
-                return table.column(0)
-
-            def echo_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                # Just echo the parameters back as a single row.
-                assert parameters.parameters
-                yield pa.RecordBatch.from_arrays(
-                    [parameters.parameters.column(0)],
-                    schema=pa.schema([pa.field("result", pa.string())]),
-                )
-
-            def long_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                assert parameters.parameters
-                for i in range(100):
-                    yield pa.RecordBatch.from_arrays([[f"{i}"] * 3000] * len(output_schema), schema=output_schema)
-
-            def repeat_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                # Just echo the parameters back as a single row.
-                assert parameters.parameters
-                for _i in range(parameters.parameters.column(1).to_pylist()[0]):
-                    yield pa.RecordBatch.from_arrays(
-                        [parameters.parameters.column(0)],
-                        schema=output_schema,
-                    )
-
-            def wide_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                # Just echo the parameters back as a single row.
-                assert parameters.parameters
-                rows = []
-                for i in range(parameters.parameters.column(0).to_pylist()[0]):
-                    rows.append({f"result_{idx}": idx for idx in range(20)})
-
-                yield pa.RecordBatch.from_pylist(rows, schema=output_schema)
-
-            def dynamic_schema_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                yield parameters.parameters
-
-            def dynamic_schema_handler_output_schema(
-                parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
-            ) -> pa.Schema:
-                # This is the schema that will be returned to the client.
-                # It will be used to create the table function.
-                assert isinstance(parameters, pa.RecordBatch)
-                return parameters.schema
-
-            def in_out_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
-                assert input_schema is not None
-                return pa.schema([parameters.schema.field(0), input_schema.field(0)])
-
-            def in_out_wide_schema_handler(
-                parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
-            ) -> pa.Schema:
-                assert input_schema is not None
-                return pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)])
-
-            def in_out_echo_schema_handler(
-                parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
-            ) -> pa.Schema:
-                assert input_schema is not None
-                return input_schema
-
-            def in_out_echo_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                result = output_schema.empty_table()
-
-                while True:
-                    input_chunk = yield result
-
-                    if input_chunk is None:
-                        break
-
-                    result = input_chunk
-
-                return
-
-            def in_out_wide_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                result = output_schema.empty_table()
-
-                while True:
-                    input_chunk = yield result
-
-                    if input_chunk is None:
-                        break
-
-                    result = pa.RecordBatch.from_arrays(
-                        [[i] * len(input_chunk) for i in range(20)],
-                        schema=output_schema,
-                    )
-
-                return
-
-            def in_out_handler(
-                parameters: parameter_types.TableFunctionParameters,
-                output_schema: pa.Schema,
-            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
-                result = output_schema.empty_table()
-
-                while True:
-                    input_chunk = yield result
-
-                    if input_chunk is None:
-                        break
-
-                    assert parameters.parameters
-                    result = pa.RecordBatch.from_arrays(
-                        [
-                            parameters.parameters.column(0),
-                            input_chunk.column(0),
-                        ],
-                        schema=output_schema,
-                    )
-
-                return pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)
-
-            util_schema = SchemaCollection(
-                scalar_functions_by_name=CaseInsensitiveDict(
-                    {
-                        "test_uppercase": ScalarFunction(
-                            input_schema=pa.schema([pa.field("a", pa.string())]),
-                            output_schema=pa.schema([pa.field("result", pa.string())]),
-                            handler=uppercase_handler,
-                        ),
-                        "test_any_type": ScalarFunction(
-                            input_schema=pa.schema([pa.field("a", pa.string(), metadata={"is_any_type": "1"})]),
-                            output_schema=pa.schema([pa.field("result", pa.string())]),
-                            handler=any_type_handler,
-                        ),
-                        "test_add": ScalarFunction(
-                            input_schema=pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
-                            output_schema=pa.schema([pa.field("result", pa.int64())]),
-                            handler=add_handler,
-                        ),
-                    }
-                ),
-                table_functions_by_name=CaseInsensitiveDict(
-                    {
-                        "test_echo": TableFunction(
-                            input_schema=pa.schema([pa.field("input", pa.string())]),
-                            output_schema_source=pa.schema([pa.field("result", pa.string())]),
-                            handler=echo_handler,
-                        ),
-                        "test_wide": TableFunction(
-                            input_schema=pa.schema([pa.field("count", pa.int32())]),
-                            output_schema_source=pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)]),
-                            handler=wide_handler,
-                        ),
-                        "test_long": TableFunction(
-                            input_schema=pa.schema([pa.field("input", pa.string())]),
-                            output_schema_source=pa.schema(
-                                [
-                                    pa.field("result", pa.string()),
-                                    pa.field("result2", pa.string()),
-                                ]
-                            ),
-                            handler=long_handler,
-                        ),
-                        "test_repeat": TableFunction(
-                            input_schema=pa.schema(
-                                [
-                                    pa.field("input", pa.string()),
-                                    pa.field("count", pa.int32()),
-                                ]
-                            ),
-                            output_schema_source=pa.schema([pa.field("result", pa.string())]),
-                            handler=repeat_handler,
-                        ),
-                        "test_dynamic_schema": TableFunction(
-                            input_schema=pa.schema(
-                                [
-                                    pa.field(
-                                        "input",
-                                        pa.string(),
-                                        metadata={"is_any_type": "1"},
-                                    )
-                                ]
-                            ),
-                            output_schema_source=TableFunctionDynamicOutput(
-                                schema_creator=dynamic_schema_handler_output_schema,
-                                default_values=(
-                                    pa.RecordBatch.from_arrays(
-                                        [pa.array([1], type=pa.int32())],
-                                        schema=pa.schema([pa.field("input", pa.int32())]),
-                                    ),
-                                    None,
-                                ),
-                            ),
-                            handler=dynamic_schema_handler,
-                        ),
-                        "test_dynamic_schema_named_parameters": TableFunction(
-                            input_schema=pa.schema(
-                                [
-                                    pa.field("name", pa.string()),
-                                    pa.field(
-                                        "location",
-                                        pa.string(),
-                                        metadata={"is_named_parameter": "1"},
-                                    ),
-                                    pa.field(
-                                        "input",
-                                        pa.string(),
-                                        metadata={"is_any_type": "1"},
-                                    ),
-                                    pa.field("city", pa.string()),
-                                ]
-                            ),
-                            output_schema_source=TableFunctionDynamicOutput(
-                                schema_creator=dynamic_schema_handler_output_schema,
-                                default_values=(
-                                    pa.RecordBatch.from_arrays(
-                                        [pa.array([1], type=pa.int32())],
-                                        schema=pa.schema([pa.field("input", pa.int32())]),
-                                    ),
-                                    None,
-                                ),
-                            ),
-                            handler=dynamic_schema_handler,
-                        ),
-                        "test_table_in_out": TableFunction(
-                            input_schema=pa.schema(
-                                [
-                                    pa.field("input", pa.string()),
-                                    pa.field(
-                                        "table_input",
-                                        pa.string(),
-                                        metadata={"is_table_type": "1"},
-                                    ),
-                                ]
-                            ),
-                            output_schema_source=TableFunctionDynamicOutput(
-                                schema_creator=in_out_schema_handler,
-                                default_values=(
-                                    pa.RecordBatch.from_arrays(
-                                        [pa.array([1], type=pa.int32())],
-                                        schema=pa.schema([pa.field("input", pa.int32())]),
-                                    ),
-                                    pa.schema([pa.field("input", pa.int32())]),
-                                ),
-                            ),
-                            handler=in_out_handler,
-                        ),
-                        "test_table_in_out_wide": TableFunction(
-                            input_schema=pa.schema(
-                                [
-                                    pa.field("input", pa.string()),
-                                    pa.field(
-                                        "table_input",
-                                        pa.string(),
-                                        metadata={"is_table_type": "1"},
-                                    ),
-                                ]
-                            ),
-                            output_schema_source=TableFunctionDynamicOutput(
-                                schema_creator=in_out_wide_schema_handler,
-                                default_values=(
-                                    pa.RecordBatch.from_arrays(
-                                        [pa.array([1], type=pa.int32())],
-                                        schema=pa.schema([pa.field("input", pa.int32())]),
-                                    ),
-                                    pa.schema([pa.field("input", pa.int32())]),
-                                ),
-                            ),
-                            handler=in_out_wide_handler,
-                        ),
-                        "test_table_in_out_echo": TableFunction(
-                            input_schema=pa.schema(
-                                [
-                                    pa.field(
-                                        "table_input",
-                                        pa.string(),
-                                        metadata={"is_table_type": "1"},
-                                    ),
-                                ]
-                            ),
-                            output_schema_source=TableFunctionDynamicOutput(
-                                schema_creator=in_out_echo_schema_handler,
-                                default_values=(
-                                    pa.RecordBatch.from_arrays(
-                                        [pa.array([1], type=pa.int32())],
-                                        schema=pa.schema([pa.field("input", pa.int32())]),
-                                    ),
-                                    pa.schema([pa.field("input", pa.int32())]),
-                                ),
-                            ),
-                            handler=in_out_echo_handler,
-                        ),
-                    }
-                ),
-            )
-
-            library.databases_by_name[database_name].schemas_by_name["utils"] = util_schema
-
+            with DatabaseLibraryContext(context.caller.token.token) as library:
+                if database_name in library.databases_by_name:
+                    raise flight.FlightServerError(f"Database {database_name} already exists")
+                library.databases_by_name[database_name] = DatabaseContents()
             return iter([])
         elif action.type == "drop_database":
             database_name = action.body.to_pybytes().decode("utf-8")
             context.logger.debug("Dropping database", database_name=database_name)

-
-
-
-
-            del library.databases_by_name[action.body.decode("utf-8")]
+            with DatabaseLibraryContext(context.caller.token.token) as library:
+                if action.body.decode("utf-8") not in library.databases_by_name:
+                    raise flight.FlightServerError(f"Database {action.body.decode('utf-8')} does not exist")
+                del library.databases_by_name[action.body.decode("utf-8")]
             return iter([])

         raise flight.FlightServerError(f"Unsupported action type: {action.type}")
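The `create_database` action no longer builds the `utils` schema of example functions inline; given the new import block, those presumably moved into `database_impl.py` (not shown in this diff). The removed code still documents the scalar-function handler contract, restated here as a standalone sketch:

```python
import pyarrow as pa
import pyarrow.compute as pc


# A scalar handler takes a pa.Table of argument columns and returns one result column.
def add_handler(table: pa.Table) -> pa.ChunkedArray:
    assert table.num_columns == 2
    return pc.add(table.column(0), table.column(1))


args = pa.table({"a": [1, 2], "b": [10, 20]})
print(add_handler(args).to_pylist())  # [11, 22]
```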
@@ -1092,91 +466,92 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth

         if descriptor_parts.type != "table":
             raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
-        library = self.contents[context.caller.token.token]
-        database = library.by_name(descriptor_parts.catalog_name)
-        schema = database.by_name(descriptor_parts.schema_name)
-        table_info = schema.by_name("table", descriptor_parts.name)

-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+            table_info = schema.by_name("table", descriptor_parts.name)

-
+            existing_table = table_info.version()

-
-        assert rowid_index != -1
+            writer.begin(existing_table.schema)

-
+            rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
+            assert rowid_index != -1

-
-        if chunk.data is not None:
-            chunk_table = pa.Table.from_batches([chunk.data])
-            assert chunk_table.num_rows > 0
+            change_count = 0

-
+            for chunk in reader:
+                if chunk.data is not None:
+                    chunk_table = pa.Table.from_batches([chunk.data])
+                    assert chunk_table.num_rows > 0

-
+                    # So this chunk will contain any updated columns and the row id.

-
-            # then assign new row ids to the incoming rows, and append them to the table.
-            table_mask = pc.is_in(
-                existing_table.column(rowid_index),
-                value_set=chunk_table.column(input_rowid_index),
-            )
+                    input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)

-
-
+                    # To perform an update, first remove the rows from the table,
+                    # then assign new row ids to the incoming rows, and append them to the table.
+                    table_mask = pc.is_in(
+                        existing_table.column(rowid_index),
+                        value_set=chunk_table.column(input_rowid_index),
+                    )

-
-
-            # values that aren't being updated.
-            changed_rows = pc.filter(existing_table, table_mask)
+                    # This is the table with updated rows removed, we'll be adding rows to it later on.
+                    table_without_updated_rows = pc.filter(existing_table, pc.invert(table_mask))

-
-
-
+                    # These are the rows that are being updated, since the update may not send all
+                    # the columns, we need to filter the table to get the updated rows to persist the
+                    # values that aren't being updated.
+                    changed_rows = pc.filter(existing_table, table_mask)

-
-
-                chunk_table
-                keys=[self.ROWID_FIELD_NAME],
-                join_type="inner",
-            )
+                    # Get a list of all of the columns that are not being updated, so that a join
+                    # can be performed.
+                    unchanged_column_names = set(existing_table.schema.names) - set(chunk_table.schema.names)

-
-
-
-
-
-                    table_info.row_id_counter,
-                    table_info.row_id_counter + chunk_length,
+                    joined_table = pa.Table.join(
+                        changed_rows.select(list(unchanged_column_names) + [self.ROWID_FIELD_NAME]),
+                        chunk_table,
+                        keys=[self.ROWID_FIELD_NAME],
+                        join_type="inner",
                     )
-            ]
-            updated_rows = joined_table.set_column(
-                joined_table.schema.get_field_index(self.ROWID_FIELD_NAME),
-                self.rowid_field,
-                [rowid_values],
-            )
-            table_info.row_id_counter += chunk_length

-
-
+                    # Add the new row id column.
+                    chunk_length = len(joined_table)
+                    rowid_values = [
+                        x
+                        for x in range(
+                            table_info.row_id_counter,
+                            table_info.row_id_counter + chunk_length,
+                        )
+                    ]
+                    updated_rows = joined_table.set_column(
+                        joined_table.schema.get_field_index(self.ROWID_FIELD_NAME),
+                        self.rowid_field,
+                        [rowid_values],
+                    )
+                    table_info.row_id_counter += chunk_length
+
+                    # Now the columns may be in a different order, so we need to realign them.
+                    updated_rows = updated_rows.select(existing_table.schema.names)

-
+                    check_schema_is_subset_of_schema(existing_table.schema, updated_rows.schema)

-
+                    updated_rows = conform_nullable(existing_table.schema, updated_rows)

-
-
-
-
-
-
+                    updated_table = pa.concat_tables(
+                        [
+                            table_without_updated_rows,
+                            updated_rows.select(table_without_updated_rows.schema.names),
+                        ]
+                    )

-
-
+                    if return_chunks:
+                        writer.write_table(updated_rows)

-
+                    existing_table = updated_table

-
+            table_info.update_table(existing_table)

         return change_count

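The rewritten update path above removes the affected rows with a `pc.is_in` mask, joins the incoming partial columns back against the retained column values on the rowid key, and concatenates the result. A self-contained sketch of that mask-and-join technique (table contents and column names invented):

```python
import pyarrow as pa
import pyarrow.compute as pc

existing = pa.table({"rowid": [0, 1, 2], "a": [10, 20, 30], "b": ["x", "y", "z"]})
# An update chunk carries only the changed column(s) plus the rowid.
chunk = pa.table({"rowid": [1], "a": [99]})

mask = pc.is_in(existing.column("rowid"), value_set=chunk.column("rowid"))
kept = existing.filter(pc.invert(mask))  # rows not being updated
changed = existing.filter(mask)          # old values of the updated rows

# Join the untouched columns of the changed rows to the new values.
unchanged_cols = set(existing.schema.names) - set(chunk.schema.names)
updated = changed.select(list(unchanged_cols) + ["rowid"]).join(
    chunk, keys=["rowid"], join_type="inner"
)

result = pa.concat_tables([kept, updated.select(existing.schema.names)])
print(result.sort_by("rowid").to_pydict())
```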
@@ -1195,47 +570,47 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth

         if descriptor_parts.type != "table":
             raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
-
-
-
-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+            table_info = schema.by_name("table", descriptor_parts.name)

-
-
+            existing_table = table_info.version()
+            writer.begin(existing_table.schema)

-
-
+            rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
+            assert rowid_index != -1

-
+            change_count = 0

-
-
-
-
+            for chunk in reader:
+                if chunk.data is not None:
+                    chunk_table = pa.Table.from_batches([chunk.data])
+                    assert chunk_table.num_rows > 0

-
-
-
-
-
+                    # Should only be getting the row id.
+                    assert chunk_table.num_columns == 1
+                    # the rowid field doesn't come back the same since it missing the
+                    # not null flag and the metadata, so can't compare the schemas
+                    input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)

-
-
-
+                    # Now do an antijoin to get the rows that are not in the delete_rows.
+                    target_rowids = chunk_table.column(input_rowid_index)
+                    existing_row_ids = existing_table.column(rowid_index)

-
+                    mask = pc.is_in(existing_row_ids, value_set=target_rowids)

-
-
+                    target_rows = pc.filter(existing_table, mask)
+                    changed_table = pc.filter(existing_table, pc.invert(mask))

-
+                    change_count += target_rows.num_rows

-
-
+                    if return_chunks:
+                        writer.write_table(target_rows)

-
+                    existing_table = changed_table

-
+            table_info.update_table(existing_table)

         return change_count

@@ -1254,65 +629,66 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth

         if descriptor_parts.type != "table":
             raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
-        library = self.contents[context.caller.token.token]
-        database = library.by_name(descriptor_parts.catalog_name)
-        schema = database.by_name(descriptor_parts.schema_name)
-        table_info = schema.by_name("table", descriptor_parts.name)

-
-
-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+            table_info = schema.by_name("table", descriptor_parts.name)

-
-
+            existing_table = table_info.version()
+            writer.begin(existing_table.schema)
+            change_count = 0

-
+            rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
+            assert rowid_index != -1

-
-        # to perform an insert, so we need some way to adapt the schema we
-        check_schema_is_subset_of_schema(existing_table.schema, reader.schema)
+            # Check that the data being read matches the table without the rowid column.

-
+            # DuckDB won't send field metadata when it sends us the schema that it uses
+            # to perform an insert, so we need some way to adapt the schema we
+            check_schema_is_subset_of_schema(existing_table.schema, reader.schema)

-
-        if chunk.data is not None:
-            new_rows = pa.Table.from_batches([chunk.data])
-            assert new_rows.num_rows > 0
+            # FIXME: need to handle the case of rowids.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            for chunk in reader:
+                if chunk.data is not None:
+                    new_rows = pa.Table.from_batches([chunk.data])
+                    assert new_rows.num_rows > 0
+
+                    # append the row id column to the new rows.
+                    chunk_length = new_rows.num_rows
+                    rowid_values = [
+                        x
+                        for x in range(
+                            table_info.row_id_counter,
+                            table_info.row_id_counter + chunk_length,
+                        )
+                    ]
+                    new_rows = new_rows.append_column(self.rowid_field, [rowid_values])
+                    table_info.row_id_counter += chunk_length
+                    change_count += chunk_length
+
+                    if return_chunks:
+                        writer.write_table(new_rows)
+
+                    # Since the table could have columns removed and deleted, use .select
+                    # the ensure that the columns are aligned in the same order as the original table.
+
+                    # So it turns out that DuckDB doesn't send the "not null" flag in the arrow schema.
+                    #
+                    # This means we can't concat the tables, without those flags matching.
+                    # for field_name in existing_table.schema.names:
+                    #     field = existing_table.schema.field(field_name)
+                    #     if not field.nullable:
+                    #         field_index = new_rows.schema.get_field_index(field_name)
+                    #         new_rows = new_rows.set_column(
+                    #             field_index, field.with_nullable(False), new_rows.column(field_index)
+                    #         )
+                    new_rows = conform_nullable(existing_table.schema, new_rows)
+
+                    existing_table = pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
+
+            table_info.update_table(existing_table)

         return change_count

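The insert path above stamps each incoming chunk with a fresh block of row ids drawn from `table_info.row_id_counter` before appending it to the table. The same bookkeeping in isolation (data values invented; the field matches `rowid_field` from the diff):

```python
import pyarrow as pa

rowid_field = pa.field("rowid", pa.int64(), metadata={"is_rowid": "1"})
row_id_counter = 0

new_rows = pa.table({"a": [1, 2, 3]})

# Assign the next block of row ids and append them as a column.
chunk_length = new_rows.num_rows
rowid_values = list(range(row_id_counter, row_id_counter + chunk_length))
new_rows = new_rows.append_column(rowid_field, [rowid_values])
row_id_counter += chunk_length

print(new_rows.to_pydict())  # {'a': [1, 2, 3], 'rowid': [0, 1, 2]}
```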
@@ -1330,13 +706,13 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth

         if descriptor_parts.type != "table_function":
             raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
-
-
-
-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+            table_info = schema.by_name("table_function", descriptor_parts.name)

-
-
+            output_schema = table_info.output_schema(parameters=parameters.parameters, input_schema=input_schema)
+            gen = table_info.handler(parameters, output_schema)
         return (output_schema, gen)

     def action_add_column(
@@ -1347,31 +723,30 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
+            assert len(parameters.column_schema.names) == 1

-
-
-
+            existing_table = table_info.version()
+            # Don't allow duplicate colum names.
+            assert parameters.column_schema.field(0).name not in existing_table.schema.names

-
-
-
-
-
-
-
-
-
+            table_info.update_table(
+                existing_table.append_column(
+                    parameters.column_schema.field(0).name,
+                    [
+                        pa.nulls(
+                            existing_table.num_rows,
+                            type=parameters.column_schema.field(0).type,
+                        )
+                    ],
+                )
             )
-
-        database.version += 1
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
@@ -1387,15 +762,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
-
+            table_info.update_table(table_info.version().drop(parameters.removed_column))
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
@@ -1411,15 +785,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
-
+            table_info.update_table(table_info.version().rename_columns({parameters.old_name: parameters.new_name}))
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
@@ -1435,16 +808,15 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
+            schema.tables_by_name[parameters.new_table_name] = schema.tables_by_name.pop(parameters.name)

-
+            database.version += 1

         return table_info.flight_info(
             name=parameters.new_table_name,
@@ -1460,32 +832,31 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-
-        table_info = schema.by_name("table", parameters.name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)
+            table_info = schema.by_name("table", parameters.name)

-
+            # Defaults are set as metadata on a field.

-
-
-
-
-
-
+            t = table_info.version()
+            field_index = t.schema.get_field_index(parameters.column_name)
+            field = t.schema.field(parameters.column_name)
+            new_metadata: dict[str, Any] = {}
+            if field.metadata:
+                new_metadata = {**field.metadata}

-
+            new_metadata["default"] = parameters.expression

-
-
-
-
-
+            table_info.update_table(
+                t.set_column(
+                    field_index,
+                    field.with_metadata(new_metadata),
+                    t.column(field_index),
+                )
             )
-        )

-
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
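As the comment in the hunk notes, defaults ride along as field metadata. A small standalone check of that pyarrow mechanism (pyarrow stores metadata keys and values as bytes):

import pyarrow as pa

t = pa.table({"n": [1, 2, 3]})
idx = t.schema.get_field_index("n")

# Fields are immutable, so a new field object carries the default expression.
new_field = t.schema.field("n").with_metadata({"default": "42"})
t2 = t.set_column(idx, new_field, t.column(idx))
assert t2.schema.field("n").metadata == {b"default": b"42"}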
@@ -1501,29 +872,28 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
-
-
+            t = table_info.version()
+            field_index = t.schema.get_field_index(parameters.column_name)
+            field = t.schema.field(parameters.column_name)

-
-
+            if t.column(field_index).null_count > 0:
+                raise flight.FlightServerError(f"Cannot set column {parameters.column_name} to NOT NULL: it contains null values")

-
-
-
-
-
+            table_info.update_table(
+                t.set_column(
+                    field_index,
+                    field.with_nullable(False),
+                    t.column(field_index),
+                )
             )
-        )

-
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
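Flipping nullability only rewrites the schema, not the data, which is why the handler first verifies the column holds no nulls. A standalone illustration:

import pyarrow as pa

t = pa.table({"n": [1, 2, 3]})
idx = t.schema.get_field_index("n")

# Safe only because the column holds no nulls, as the guard above enforces.
assert t.column(idx).null_count == 0
t2 = t.set_column(idx, t.schema.field(idx).with_nullable(False), t.column(idx))
assert not t2.schema.field("n").nullable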
@@ -1539,26 +909,25 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
-
-
+            t = table_info.version()
+            field_index = t.schema.get_field_index(parameters.column_name)
+            field = t.schema.field(parameters.column_name)

-
-
-
-
-
+            table_info.update_table(
+                t.set_column(
+                    field_index,
+                    field.with_nullable(True),
+                    t.column(field_index),
+                )
             )
-        )

-
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
@@ -1574,33 +943,32 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-        schema = database.by_name(parameters.schema_name)
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(parameters.catalog)
+            schema = database.by_name(parameters.schema_name)

-
+            table_info = schema.by_name("table", parameters.name)

-
+            # Defaults are stored as field metadata, so carry the metadata over to the new field.

-
-
-
-
+            t = table_info.version()
+            column_name = parameters.column_schema.field(0).name
+            field_index = t.schema.get_field_index(column_name)
+            field = t.schema.field(column_name)

-
-
-
+            new_type = parameters.column_schema.field(0).type
+            new_field = pa.field(field.name, new_type, metadata=field.metadata)
+            new_data = pc.cast(t.column(field_index), new_type)

-
-
-
-
-
+            table_info.update_table(
+                t.set_column(
+                    field_index,
+                    new_field,
+                    new_data,
+                )
             )
-        )

-
+            database.version += 1

         return table_info.flight_info(
             name=parameters.name,
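The cast path rebuilds the field by hand because pc.cast only converts the data and does not carry over the old field's metadata. A standalone sketch of the same steps:

import pyarrow as pa
import pyarrow.compute as pc

t = pa.table({"n": pa.array([1, 2, 3], type=pa.int32())})
idx = t.schema.get_field_index("n")
field = t.schema.field("n")

new_type = pa.int64()
# Carry the old metadata onto the new field, then cast the column data.
new_field = pa.field(field.name, new_type, metadata=field.metadata)
t2 = t.set_column(idx, new_field, pc.cast(t.column(idx), new_type))
assert t2.schema.field("n").type == pa.int64()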
@@ -1608,6 +976,71 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
             schema_name=parameters.schema_name,
         )[0]

+    def action_column_statistics(
+        self,
+        *,
+        context: base_server.CallContext[auth.Account, auth.AccountToken],
+        parameters: parameter_types.ColumnStatistics,
+    ) -> pa.Table:
+        assert context.caller is not None
+
+        descriptor_parts = descriptor_unpack_(parameters.flight_descriptor)
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+
+            assert descriptor_parts.type == "table"
+            table = schema.by_name("table", descriptor_parts.name)
+
+            contents = table.version().column(parameters.column_name)
+            # The table is a plain PyArrow table, so the statistics have to be computed on the fly.
+            not_null_count = pc.count(contents, "only_valid").as_py()
+            null_count = pc.count(contents, "only_null").as_py()
+            distinct_count = len(set(contents.to_pylist()))
+            sorted_contents = sorted(filter(lambda x: x is not None, contents.to_pylist()))
+            min_value = sorted_contents[0]
+            max_value = sorted_contents[-1]
+
+            additional_values = {}
+            additional_schema_fields = []
+            if contents.type in (pa.string(), pa.utf8(), pa.binary()):
+                max_length = pc.max(pc.binary_length(contents)).as_py()
+
+                additional_values = {"max_string_length": max_length, "contains_unicode": contents.type == pa.utf8()}
+                additional_schema_fields = [
+                    pa.field("max_string_length", pa.uint64()),
+                    pa.field("contains_unicode", pa.bool_()),
+                ]
+
+            if contents.type == pa.uuid():
+                # For UUIDs, emit the min/max as their raw bytes.
+                min_value = min_value.bytes
+                max_value = max_value.bytes
+
+            result_table = pa.Table.from_pylist(
+                [
+                    {
+                        "has_not_null": not_null_count > 0,
+                        "has_null": null_count > 0,
+                        "distinct_count": distinct_count,
+                        "min": min_value,
+                        "max": max_value,
+                        **additional_values,
+                    }
+                ],
+                schema=pa.schema(
+                    [
+                        pa.field("has_not_null", pa.bool_()),
+                        pa.field("has_null", pa.bool_()),
+                        pa.field("distinct_count", pa.uint64()),
+                        pa.field("min", contents.type),
+                        pa.field("max", contents.type),
+                        *additional_schema_fields,
+                    ]
+                ),
+            )
+            return result_table
+
     def impl_do_get(
         self,
         *,
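The statistics action leans entirely on pyarrow.compute aggregates. The same recipe on a toy column; note that distinct_count, computed over the raw Python list, counts None as one of the distinct values:

import pyarrow as pa
import pyarrow.compute as pc

contents = pa.chunked_array([["a", None, "bb", "bb"]])

not_null_count = pc.count(contents, "only_valid").as_py()  # 3
null_count = pc.count(contents, "only_null").as_py()       # 1
distinct_count = len(set(contents.to_pylist()))            # 3, None included
max_length = pc.max(pc.binary_length(contents)).as_py()    # 2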
@@ -1619,58 +1052,60 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
         ticket_data = flight_handling.decode_ticket_model(ticket, FlightTicketData)

         descriptor_parts = descriptor_unpack_(ticket_data.descriptor)
-
-
-
-
-        if descriptor_parts.type == "table":
-            table = schema.by_name("table", descriptor_parts.name)
+        with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)

-            if
-
+            if descriptor_parts.type == "table":
+                table = schema.by_name("table", descriptor_parts.name)

-
-
-                    raise flight.FlightServerError(f"Invalid version: {ticket_data.at_value}")
+                if ticket_data.at_unit == "VERSION":
+                    assert ticket_data.at_value is not None

-
-
-
-            else:
-                table_version = table.version()
-
-            if descriptor_parts.schema_name == "test_predicate_pushdown" and ticket_data.where_clause is not None:
-                # We are going to do the predicate pushdown for filtering the data we have in memory.
-                # At this point if we have JSON filters we should test that we can decode them.
-                with duckdb.connect(":memory:") as connection:
-                    connection.execute("SET TimeZone = 'UTC'")
-                    sql = f"select * from table_version where {ticket_data.where_clause}"
-                    try:
-                        results = connection.execute(sql).fetch_arrow_table()
-                    except Exception as e:
-                        raise flight.FlightServerError(f"Failed to execute predicate pushdown: {e} sql: {sql}") from e
-                table_version = results
-
-            return flight.RecordBatchStream(table_version)
-        elif descriptor_parts.type == "table_function":
-            table_function = schema.by_name("table_function", descriptor_parts.name)
+                    # at_value arrives as a string, so make sure it is purely numeric.
+                    if not re.match(r"^\d+$", ticket_data.at_value):
+                        raise flight.FlightServerError(f"Invalid version: {ticket_data.at_value}")

-
-        ticket_data.
-
-
+                    table_version = table.version(int(ticket_data.at_value))
+                elif ticket_data.at_unit == "TIMESTAMP":
+                    raise flight.FlightServerError("Timestamp not supported for table versioning")
+                else:
+                    table_version = table.version()
+
+                if descriptor_parts.schema_name == "test_predicate_pushdown" and ticket_data.where_clause is not None:
+                    # We are going to do the predicate pushdown for filtering the data we have in memory.
+                    # At this point if we have JSON filters we should test that we can decode them.
+                    with duckdb.connect(":memory:") as connection:
+                        connection.execute("SET TimeZone = 'UTC'")
+                        sql = f"select * from table_version where {ticket_data.where_clause}"
+                        try:
+                            results = connection.execute(sql).fetch_arrow_table()
+                        except Exception as e:
+                            raise flight.FlightServerError(
+                                f"Failed to execute predicate pushdown: {e} sql: {sql}"
+                            ) from e
+                    table_version = results
+
+                return flight.RecordBatchStream(table_version)
+            elif descriptor_parts.type == "table_function":
+                table_function = schema.by_name("table_function", descriptor_parts.name)
+
+                output_schema = table_function.output_schema(
+                    ticket_data.table_function_parameters,
+                    ticket_data.table_function_input_schema,
+                )

-
-            output_schema,
-            table_function.handler(
-                parameter_types.TableFunctionParameters(
-                    parameters=ticket_data.table_function_parameters, where_clause=ticket_data.where_clause
-                ),
+                return flight.GeneratorStream(
                     output_schema,
-
-
-
-
+                    table_function.handler(
+                        parameter_types.TableFunctionParameters(
+                            parameters=ticket_data.table_function_parameters, where_clause=ticket_data.where_clause
+                        ),
+                        output_schema,
+                    ),
+                )
+            else:
+                raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

     def action_table_function_flight_info(
         self,
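The predicate-pushdown branch relies on DuckDB's replacement scans: a pyarrow table bound to a local Python variable can be queried under that variable's name. A minimal reproduction of the pattern used above:

import duckdb
import pyarrow as pa

table_version = pa.table({"id": [1, 2, 3], "v": [10, 20, 30]})

with duckdb.connect(":memory:") as connection:
    connection.execute("SET TimeZone = 'UTC'")
    # DuckDB resolves table_version to the pyarrow table in scope.
    results = connection.execute("select * from table_version where v > 15").fetch_arrow_table()
assert results.num_rows == 2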
@@ -1680,14 +1115,14 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            descriptor_parts = descriptor_unpack_(parameters.descriptor)

-
-
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)

-
-
+            # Now get the table function; it's a bit harder since they are named by action.
+            table_function = schema.by_name("table_function", descriptor_parts.name)

         return table_function.flight_info(
             name=descriptor_parts.name,
@@ -1705,40 +1140,40 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
     ) -> flight.FlightInfo:
         assert context.caller is not None

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            descriptor_parts = descriptor_unpack_(parameters.descriptor)
+
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+
+            if descriptor_parts.type == "table_function":
+                raise flight.FlightServerError("Table function flight info not supported")
+            elif descriptor_parts.type == "table":
+                table = schema.by_name("table", descriptor_parts.name)
+
+                version_id = None
+                if parameters.at_unit is not None:
+                    if parameters.at_unit == "VERSION":
+                        assert parameters.at_value is not None
+
+                        # at_value arrives as a string, so make sure it is purely numeric.
+                        if not re.match(r"^\d+$", parameters.at_value):
+                            raise flight.FlightServerError(f"Invalid version: {parameters.at_value}")
+
+                        version_id = int(parameters.at_value)
+                    elif parameters.at_unit == "TIMESTAMP":
+                        raise flight.FlightServerError("Timestamp not supported for table versioning")
+                    else:
+                        raise flight.FlightServerError(f"Unsupported at_unit: {parameters.at_unit}")
+
+                return table.flight_info(
+                    name=descriptor_parts.name,
+                    catalog_name=descriptor_parts.catalog_name,
+                    schema_name=descriptor_parts.schema_name,
+                    version=version_id,
+                )[0]
+            else:
+                raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

     def exchange_scalar_function(
         self,
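This flight-info path validates at_value with the same digits-only regex as the do_get path before converting it. Restated as a standalone helper (hypothetical name, same logic):

import re

def parse_version(at_value: str) -> int:
    # Reject anything that is not a purely numeric string, as the handlers above do.
    if not re.match(r"^\d+$", at_value):
        raise ValueError(f"Invalid version: {at_value}")
    return int(at_value)

assert parse_version("12") == 12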
@@ -1755,10 +1190,10 @@ class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth
         if descriptor_parts.type != "scalar_function":
             raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

-
-
-
-
+        with DatabaseLibraryContext(context.caller.token.token) as library:
+            database = library.by_name(descriptor_parts.catalog_name)
+            schema = database.by_name(descriptor_parts.schema_name)
+            scalar_function_info = schema.by_name("scalar_function", descriptor_parts.name)

             writer.begin(scalar_function_info.output_schema)
