query-farm-airport-test-server 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- query_farm_airport_test_server/__init__.py +6 -0
- query_farm_airport_test_server/auth.py +9 -0
- query_farm_airport_test_server/py.typed +0 -0
- query_farm_airport_test_server/server.py +1799 -0
- query_farm_airport_test_server/utils.py +182 -0
- query_farm_airport_test_server-0.1.0.dist-info/METADATA +40 -0
- query_farm_airport_test_server-0.1.0.dist-info/RECORD +9 -0
- query_farm_airport_test_server-0.1.0.dist-info/WHEEL +4 -0
- query_farm_airport_test_server-0.1.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,1799 @@
import hashlib
import json
import re
from collections.abc import Callable, Generator, Iterator
from dataclasses import dataclass, field
from typing import Any, Literal, TypeVar, overload

import click
import duckdb
import msgpack
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.flight as flight
import query_farm_duckdb_json_serialization.expression
import query_farm_flight_server.auth as auth
import query_farm_flight_server.auth_manager as auth_manager
import query_farm_flight_server.auth_manager_naive as auth_manager_naive
import query_farm_flight_server.flight_handling as flight_handling
import query_farm_flight_server.flight_inventory as flight_inventory
import query_farm_flight_server.middleware as base_middleware
import query_farm_flight_server.parameter_types as parameter_types
import query_farm_flight_server.schema_uploader as schema_uploader
import query_farm_flight_server.server as base_server
import structlog
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

from .utils import CaseInsensitiveDict

log = structlog.get_logger()

def read_recordbatch(source: bytes) -> pa.RecordBatch:
    """
    Read a record batch from a byte string.
    """
    buffer = pa.BufferReader(source)
    ipc_stream = pa.ipc.open_stream(buffer)
    return next(ipc_stream)

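# A minimal round-trip sketch (hypothetical helper, not part of the published
# module): it shows the Arrow IPC stream framing that read_recordbatch() expects.
def _example_read_recordbatch_roundtrip() -> None:
    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
    # The serialized bytes contain the schema followed by the batch.
    assert read_recordbatch(sink.getvalue().to_pybytes()).equals(batch)
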
def conform_nullable(schema: pa.Schema, table: pa.Table) -> pa.Table:
    """
    Conform the table to the nullable flags as defined in the schema.

    There shouldn't be null values in the columns.

    This is needed because DuckDB doesn't send the nullable flag in the schema
    it sends via the DoExchange call.
    """
    for idx, table_field in enumerate(schema):
        if not table_field.nullable:
            # Mark the field as non-nullable to match the declared schema.
            new_field = table_field.with_nullable(False)

            # Check for null values.
            if table.column(idx).null_count > 0:
                raise flight.FlightServerError(
                    f"Column {table_field.name} has null values, but the schema does not allow nulls."
                )

            table = table.set_column(idx, new_field, table.column(idx))
    return table


def check_schema_is_subset_of_schema(existing_schema: pa.Schema, new_schema: pa.Schema) -> None:
    """
    Check that the new schema is a subset of the existing schema.
    """
    existing_contents = {(field.name, field.type) for field in existing_schema}
    new_contents = {(field.name, field.type) for field in new_schema}

    unknown_fields = new_contents - existing_contents
    if unknown_fields:
        raise flight.FlightServerError(f"Unknown fields in insert: {unknown_fields}")
    return

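# A quick illustration (hypothetical helper, not part of the published module)
# of how the two helpers above cooperate when ingesting a DuckDB exchange: the
# incoming data loses the not-null flag, so it is re-applied before tables are
# concatenated.
def _example_conform_nullable() -> None:
    declared = pa.schema([pa.field("id", pa.int64(), nullable=False)])
    # DuckDB-style input: same column, but the nullable flag was dropped.
    incoming = pa.table({"id": pa.array([1, 2], type=pa.int64())})
    check_schema_is_subset_of_schema(declared, incoming.schema)
    conformed = conform_nullable(declared, incoming)
    assert not conformed.schema.field("id").nullable
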
# class FlightTicketDataTableFunction(flight_handling.FlightTicketData):
#     action_name: str
#     schema_name: str
#     parameters: bytes


# def model_selector(flight_name: str, src: bytes) -> flight_handling.FlightTicketData | FlightTicketDataTableFunction:
#     return flight_handling.FlightTicketData.unpack(src)


@dataclass
class TableFunctionDynamicOutput:
    # The method that will determine the output schema from the input parameters.
    schema_creator: Callable[[pa.RecordBatch, pa.Schema | None], pa.Schema]

    # The default parameters for the function, if not called with any.
    default_values: tuple[pa.RecordBatch, pa.Schema | None]


@dataclass
class TableFunction:
    # The input schema for the function.
    input_schema: pa.Schema

    output_schema_source: pa.Schema | TableFunctionDynamicOutput

    # The function to call to process a chunk of rows.
    handler: Callable[
        [parameter_types.TableFunctionParameters, pa.Schema],
        Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch],
    ]

    estimated_rows: int | Callable[[parameter_types.TableFunctionFlightInfo], int] = -1

    def output_schema(
        self,
        parameters: pa.RecordBatch | None = None,
        input_schema: pa.Schema | None = None,
    ) -> pa.Schema:
        if isinstance(self.output_schema_source, pa.Schema):
            return self.output_schema_source
        if parameters is None:
            return self.output_schema_source.schema_creator(*self.output_schema_source.default_values)
        assert isinstance(parameters, pa.RecordBatch)
        result = self.output_schema_source.schema_creator(parameters, input_schema)
        return result

    def flight_info(
        self,
        *,
        name: str,
        catalog_name: str,
        schema_name: str,
        parameters: parameter_types.TableFunctionFlightInfo | None = None,
    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
        """
        It's often necessary to create a FlightInfo object,
        so standardize doing that here.
        """
        assert name != ""
        assert catalog_name != ""
        assert schema_name != ""

        if isinstance(self.estimated_rows, int):
            estimated_rows = self.estimated_rows
        else:
            assert parameters is not None
            estimated_rows = self.estimated_rows(parameters)

        metadata = flight_inventory.FlightSchemaMetadata(
            type="table_function",
            catalog=catalog_name,
            schema=schema_name,
            name=name,
            comment=None,
            input_schema=self.input_schema,
        )
        flight_info = flight.FlightInfo(
            self.output_schema(parameters.parameters, parameters.table_input_schema)
            if parameters
            else self.output_schema(),
            # This will always be the same descriptor, so that we can use the action
            # name to determine which table function to execute.
            descriptor_pack_(catalog_name, schema_name, "table_function", name),
            [],
            estimated_rows,
            -1,
            app_metadata=metadata.serialize(),
        )
        return (flight_info, metadata)


@dataclass
class ScalarFunction:
    # The input schema for the function.
    input_schema: pa.Schema
    # The output schema for the function, should only have a single column.
    output_schema: pa.Schema

    # The function to call to process a chunk of rows.
    handler: Callable[[pa.Table], pa.Array]

    def flight_info(
        self, *, name: str, catalog_name: str, schema_name: str
    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
        """
        It's often necessary to create a FlightInfo object,
        so standardize doing that here.
        """
        metadata = flight_inventory.FlightSchemaMetadata(
            type="scalar_function",
            catalog=catalog_name,
            schema=schema_name,
            name=name,
            comment=None,
            input_schema=self.input_schema,
        )
        flight_info = flight.FlightInfo(
            self.output_schema,
            descriptor_pack_(catalog_name, schema_name, "scalar_function", name),
            [],
            -1,
            -1,
            app_metadata=metadata.serialize(),
        )
        return (flight_info, metadata)

@dataclass
class TableInfo:
    # To enable version history, keep track of every version of the table.
    table_versions: list[pa.Table] = field(default_factory=list)

    # The next row id to assign.
    row_id_counter: int = 0

    def update_table(self, table: pa.Table) -> None:
        assert table is not None
        assert isinstance(table, pa.Table)
        self.table_versions.append(table)

    def version(self, version: int | None = None) -> pa.Table:
        """
        Get the specified version of the table, defaulting to the latest.
        """
        assert len(self.table_versions) > 0
        if version is None:
            return self.table_versions[-1]

        assert version < len(self.table_versions)
        return self.table_versions[version]

    def flight_info(
        self,
        *,
        name: str,
        catalog_name: str,
        schema_name: str,
        version: int | None = None,
    ) -> tuple[flight.FlightInfo, flight_inventory.FlightSchemaMetadata]:
        """
        It's often necessary to create a FlightInfo object for the table,
        so standardize doing that here.
        """
        metadata = flight_inventory.FlightSchemaMetadata(
            type="table",
            catalog=catalog_name,
            schema=schema_name,
            name=name,
            comment=None,
        )
        flight_info = flight.FlightInfo(
            self.version(version).schema,
            descriptor_pack_(catalog_name, schema_name, "table", name),
            [],
            -1,
            -1,
            app_metadata=metadata.serialize(),
        )
        return (flight_info, metadata)

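# A small sketch (hypothetical helper, not part of the published module) of the
# append-only versioning TableInfo provides: every mutation appends a new
# immutable pa.Table, and version() can read any historical snapshot.
def _example_table_info_versions() -> None:
    info = TableInfo()
    info.update_table(pa.table({"x": [1]}))
    info.update_table(pa.table({"x": [1, 2]}))
    assert info.version(0).num_rows == 1  # first snapshot
    assert info.version().num_rows == 2  # latest snapshot
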
ObjectTypeName = Literal["table", "scalar_function", "table_function"]


@dataclass
class SchemaCollection:
    tables_by_name: CaseInsensitiveDict[TableInfo] = field(default_factory=CaseInsensitiveDict[TableInfo])
    scalar_functions_by_name: CaseInsensitiveDict[ScalarFunction] = field(
        default_factory=CaseInsensitiveDict[ScalarFunction]
    )
    table_functions_by_name: CaseInsensitiveDict[TableFunction] = field(
        default_factory=CaseInsensitiveDict[TableFunction]
    )

    def containers(
        self,
    ) -> list[
        CaseInsensitiveDict[TableInfo] | CaseInsensitiveDict[ScalarFunction] | CaseInsensitiveDict[TableFunction]
    ]:
        return [
            self.tables_by_name,
            self.scalar_functions_by_name,
            self.table_functions_by_name,
        ]

    @overload
    def by_name(self, type: Literal["table"], name: str) -> TableInfo: ...

    @overload
    def by_name(self, type: Literal["scalar_function"], name: str) -> ScalarFunction: ...

    @overload
    def by_name(self, type: Literal["table_function"], name: str) -> TableFunction: ...

    def by_name(self, type: ObjectTypeName, name: str) -> TableInfo | ScalarFunction | TableFunction:
        assert name is not None
        assert name != ""
        if type == "table":
            table = self.tables_by_name.get(name)
            if not table:
                raise flight.FlightServerError(f"Table {name} does not exist.")
            return table
        elif type == "scalar_function":
            scalar_function = self.scalar_functions_by_name.get(name)
            if not scalar_function:
                raise flight.FlightServerError(f"Scalar function {name} does not exist.")
            return scalar_function
        elif type == "table_function":
            table_function = self.table_functions_by_name.get(name)
            if not table_function:
                raise flight.FlightServerError(f"Table function {name} does not exist.")
            return table_function


@dataclass
class DatabaseContents:
    # Collection of schemas by name.
    schemas_by_name: CaseInsensitiveDict[SchemaCollection] = field(
        default_factory=CaseInsensitiveDict[SchemaCollection]
    )

    # The version of the database, updated on each schema change.
    version: int = 1

    def by_name(self, name: str) -> SchemaCollection:
        if name not in self.schemas_by_name:
            raise flight.FlightServerError(f"Schema {name} does not exist.")
        return self.schemas_by_name[name]


@dataclass
class DescriptorParts:
    """
    The fields that are encoded in the flight descriptor.
    """

    catalog_name: str
    schema_name: str
    type: ObjectTypeName
    name: str


@dataclass
class DatabaseLibrary:
    """
    The database library, which contains all of the databases, organized by name.
    """

    # Collection of databases by name.
    databases_by_name: CaseInsensitiveDict[DatabaseContents] = field(
        default_factory=CaseInsensitiveDict[DatabaseContents]
    )

    def by_name(self, name: str) -> DatabaseContents:
        if name not in self.databases_by_name:
            raise flight.FlightServerError(f"Database {name} does not exist.")
        return self.databases_by_name[name]


def descriptor_pack_(
    catalog_name: str,
    schema_name: str,
    type: ObjectTypeName,
    name: str,
) -> flight.FlightDescriptor:
    """
    Pack the descriptor fields into a FlightDescriptor.
    """
    return flight.FlightDescriptor.for_path(f"{catalog_name}/{schema_name}/{type}/{name}")


def descriptor_unpack_(descriptor: flight.FlightDescriptor) -> DescriptorParts:
    """
    Split the descriptor into its components.
    """
    assert descriptor.descriptor_type == flight.DescriptorType.PATH
    assert len(descriptor.path) == 1
    path = descriptor.path[0].decode("utf-8")
    parts = path.split("/")
    if len(parts) != 4:
        raise flight.FlightServerError(f"Invalid descriptor path: {path}")

    descriptor_type: ObjectTypeName
    if parts[2] == "table":
        descriptor_type = "table"
    elif parts[2] == "scalar_function":
        descriptor_type = "scalar_function"
    elif parts[2] == "table_function":
        descriptor_type = "table_function"
    else:
        raise flight.FlightServerError(f"Invalid descriptor type: {parts[2]}")

    return DescriptorParts(
        catalog_name=parts[0],
        schema_name=parts[1],
        type=descriptor_type,
        name=parts[3],
    )

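# Round-trip sketch (hypothetical helper, not part of the published module) for
# the descriptor helpers above: every catalog object is addressed by a single
# "catalog/schema/type/name" path element.
def _example_descriptor_roundtrip() -> None:
    descriptor = descriptor_pack_("db1", "utils", "scalar_function", "test_add")
    parts = descriptor_unpack_(descriptor)
    assert (parts.catalog_name, parts.schema_name, parts.type, parts.name) == (
        "db1",
        "utils",
        "scalar_function",
        "test_add",
    )
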
class FlightTicketData(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    descriptor: flight.FlightDescriptor

    where_clause: str | None = None

    # These are the parameters for the table returning function.
    table_function_parameters: pa.RecordBatch | None = None
    table_function_input_schema: pa.Schema | None = None

    at_unit: str | None = None
    at_value: str | None = None

    _validate_table_function_parameters = field_validator("table_function_parameters", mode="before")(
        parameter_types.deserialize_record_batch_or_none
    )

    _validate_table_function_input_schema = field_validator("table_function_input_schema", mode="before")(
        parameter_types.deserialize_schema_or_none
    )

    @field_serializer("table_function_parameters")
    def serialize_table_function_parameters(self, value: pa.RecordBatch, info: Any) -> bytes | None:
        return parameter_types.serialize_record_batch(value, info)

    @field_serializer("table_function_input_schema")
    def serialize_table_function_input_schema(self, value: pa.Schema, info: Any) -> bytes | None:
        return parameter_types.serialize_schema(value, info)

    _validate_flight_descriptor = field_validator("descriptor", mode="before")(
        parameter_types.deserialize_flight_descriptor
    )

    @field_serializer("descriptor")
    def serialize_flight_descriptor(self, value: flight.FlightDescriptor, info: Any) -> bytes:
        return parameter_types.serialize_flight_descriptor(value, info)


T = TypeVar("T", bound=BaseModel)

class InMemoryArrowFlightServer(base_server.BasicFlightServer[auth.Account, auth.AccountToken]):
    def __init__(
        self,
        *,
        location: str | None,
        auth_manager: auth_manager.AuthManager[auth.Account, auth.AccountToken],
        **kwargs: dict[str, Any],
    ) -> None:
        self.service_name = "test_server"
        self._auth_manager = auth_manager
        super().__init__(location=location, **kwargs)

        # token -> database name -> schema -> table_name
        self.contents: dict[str, DatabaseLibrary] = {}

        self.ROWID_FIELD_NAME = "rowid"
        self.rowid_field = pa.field(self.ROWID_FIELD_NAME, pa.int64(), metadata={"is_rowid": "1"})

    def action_endpoints(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.Endpoints,
    ) -> list[flight.FlightEndpoint]:
        assert context.caller is not None

        descriptor_parts = descriptor_unpack_(parameters.descriptor)
        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)

        filter_sql_where_clause: str | None = None
        if parameters.parameters.json_filters is not None:
            context.logger.debug("duckdb_input", input=json.dumps(parameters.parameters.json_filters.filters))
            filter_sql_where_clause, filter_sql_field_type_info = (
                query_farm_duckdb_json_serialization.expression.convert_to_sql(
                    source=parameters.parameters.json_filters.filters,
                    bound_column_names=parameters.parameters.json_filters.column_binding_names_by_index,
                )
            )
            if filter_sql_where_clause == "":
                filter_sql_where_clause = None

        if descriptor_parts.type == "table":
            schema.by_name("table", descriptor_parts.name)

            ticket_data = FlightTicketData(
                descriptor=parameters.descriptor,
                where_clause=filter_sql_where_clause,
                at_unit=parameters.parameters.at_unit,
                at_value=parameters.parameters.at_value,
            )

            return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
        elif descriptor_parts.type == "table_function":
            # The table function may not exist, because it's a dynamic descriptor.

            schema.by_name("table_function", descriptor_parts.name)

            ticket_data = FlightTicketData(
                descriptor=parameters.descriptor,
                where_clause=filter_sql_where_clause,
                table_function_parameters=parameters.parameters.table_function_parameters,
                table_function_input_schema=parameters.parameters.table_function_input_schema,
                at_unit=parameters.parameters.at_unit,
                at_value=parameters.parameters.at_value,
            )
            return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
        else:
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

    def action_list_schemas(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.ListSchemas,
    ) -> base_server.AirportSerializedCatalogRoot:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog_name)

        dynamic_inventory: dict[str, dict[str, list[flight_inventory.FlightInventoryWithMetadata]]] = {}

        catalog_contents = dynamic_inventory.setdefault(parameters.catalog_name, {})

        for schema_name, schema in database.schemas_by_name.items():
            schema_contents = catalog_contents.setdefault(schema_name, [])
            for coll in schema.containers():
                for name, obj in coll.items():
                    schema_contents.append(
                        obj.flight_info(
                            name=name,
                            catalog_name=parameters.catalog_name,
                            schema_name=schema_name,
                        )
                    )

        return flight_inventory.upload_and_generate_schema_list(
            upload_parameters=flight_inventory.UploadParameters(
                s3_client=None,
                base_url="http://localhost",
                bucket_name="test_bucket",
                bucket_prefix="test_prefix",
            ),
            flight_service_name=self.service_name,
            flight_inventory=dynamic_inventory,
            schema_details={},
            skip_upload=True,
            serialize_inline=True,
            catalog_version=1,
            catalog_version_fixed=False,
        )

    def impl_list_flights(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        criteria: bytes,
    ) -> Iterator[flight.FlightInfo]:
        assert context.caller is not None
        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]

        def yield_flight_infos() -> Generator[flight.FlightInfo, None, None]:
            for db_name, db in library.databases_by_name.items():
                for schema_name, schema in db.schemas_by_name.items():
                    for coll in schema.containers():
                        for name, obj in coll.items():
                            yield obj.flight_info(name=name, catalog_name=db_name, schema_name=schema_name)[0]

        return yield_flight_infos()

    def impl_get_flight_info(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        descriptor: flight.FlightDescriptor,
    ) -> flight.FlightInfo:
        assert context.caller is not None
        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())

        descriptor_parts = descriptor_unpack_(descriptor)
        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)

        obj = schema.by_name(descriptor_parts.type, descriptor_parts.name)
        return obj.flight_info(
            name=descriptor_parts.name,
            catalog_name=descriptor_parts.catalog_name,
            schema_name=descriptor_parts.schema_name,
        )[0]

    def action_catalog_version(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.CatalogVersion,
    ) -> base_server.GetCatalogVersionResult:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog_name)

        context.logger.debug(
            "catalog_version_result",
            catalog_name=parameters.catalog_name,
            version=database.version,
        )
        return base_server.GetCatalogVersionResult(catalog_version=database.version, is_fixed=False)

    def action_create_transaction(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.CreateTransaction,
    ) -> base_server.CreateTransactionResult:
        return base_server.CreateTransactionResult(identifier=None)

    def action_create_schema(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.CreateSchema,
    ) -> base_server.AirportSerializedContentsWithSHA256Hash:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog_name)

        if database.schemas_by_name.get(parameters.schema_name) is not None:
            raise flight.FlightServerError(f"Schema {parameters.schema_name} already exists")

        database.schemas_by_name[parameters.schema_name] = SchemaCollection()
        database.version += 1

        # FIXME: this needs to be handled better on the server side...
        # rather than calling into internal methods.
        packed_data = msgpack.packb([])
        assert packed_data
        compressed_data = schema_uploader._compress_and_prefix_with_length(packed_data, compression_level=3)

        empty_hash = hashlib.sha256(compressed_data).hexdigest()
        return base_server.AirportSerializedContentsWithSHA256Hash(
            url=None, sha256=empty_hash, serialized=compressed_data
        )

    def action_drop_table(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.DropObject,
    ) -> None:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog_name)
        schema = database.by_name(parameters.schema_name)

        schema.by_name("table", parameters.name)

        del schema.tables_by_name[parameters.name]
        database.version += 1

    def action_drop_schema(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.DropObject,
    ) -> None:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog_name)

        if database.schemas_by_name.get(parameters.name) is None:
            raise flight.FlightServerError(f"Schema '{parameters.name}' does not exist")

        del database.schemas_by_name[parameters.name]
        database.version += 1

    def action_create_table(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.CreateTable,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog_name)
        schema = database.by_name(parameters.schema_name)

        if parameters.table_name in schema.tables_by_name:
            raise flight.FlightServerError(
                f"Table {parameters.table_name} already exists for token {context.caller.token}"
            )

        # FIXME: may want to add a row_id column that is not visible to the user, so that inserts and
        # deletes can be tested.

        assert self.ROWID_FIELD_NAME not in parameters.arrow_schema.names

        schema_with_row_id = parameters.arrow_schema.append(self.rowid_field)

        table_info = TableInfo([schema_with_row_id.empty_table()], 0)

        schema.tables_by_name[parameters.table_name] = table_info

        database.version += 1

        return table_info.flight_info(
            name=parameters.table_name,
            catalog_name=parameters.catalog_name,
            schema_name=parameters.schema_name,
        )[0]

    def impl_do_action(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        action: flight.Action,
    ) -> Iterator[bytes]:
        assert context.caller is not None

        if action.type == "reset":
            context.logger.debug("Resetting server state")
            if context.caller.token.token in self.contents:
                del self.contents[context.caller.token.token]
            return iter([])
        elif action.type == "create_database":
            database_name = action.body.to_pybytes().decode("utf-8")
            context.logger.debug("Creating database", database_name=database_name)
            self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
            library = self.contents[context.caller.token.token]
            library.databases_by_name[database_name] = DatabaseContents()

            # Since we are creating a new database, let's load it with a few example
            # scalar and table functions.

            def add_handler(table: pa.Table) -> pa.Array:
                assert table.num_columns == 2
                return pc.add(table.column(0), table.column(1))

            def uppercase_handler(table: pa.Table) -> pa.Array:
                assert table.num_columns == 1
                return pc.utf8_upper(table.column(0))

            def any_type_handler(table: pa.Table) -> pa.Array:
                return table.column(0)

            def echo_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                # Just echo the parameters back as a single row.
                assert parameters.parameters
                yield pa.RecordBatch.from_arrays(
                    [parameters.parameters.column(0)],
                    schema=pa.schema([pa.field("result", pa.string())]),
                )

            def long_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                assert parameters.parameters
                for i in range(100):
                    yield pa.RecordBatch.from_arrays([[f"{i}"] * 3000] * len(output_schema), schema=output_schema)

            def repeat_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                # Echo the first parameter back, repeated "count" times, one row per batch.
                assert parameters.parameters
                for _i in range(parameters.parameters.column(1).to_pylist()[0]):
                    yield pa.RecordBatch.from_arrays(
                        [parameters.parameters.column(0)],
                        schema=output_schema,
                    )

            def wide_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                # Produce "count" rows across the twenty result columns.
                assert parameters.parameters
                rows = []
                for _i in range(parameters.parameters.column(0).to_pylist()[0]):
                    rows.append({f"result_{idx}": idx for idx in range(20)})

                yield pa.RecordBatch.from_pylist(rows, schema=output_schema)

            def dynamic_schema_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                yield parameters.parameters

            def dynamic_schema_handler_output_schema(
                parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
            ) -> pa.Schema:
                # This is the schema that will be returned to the client.
                # It will be used to create the table function.
                assert isinstance(parameters, pa.RecordBatch)
                return parameters.schema

            def in_out_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
                assert input_schema is not None
                return pa.schema([parameters.schema.field(0), input_schema.field(0)])

            def in_out_wide_schema_handler(
                parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
            ) -> pa.Schema:
                assert input_schema is not None
                return pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)])

            def in_out_echo_schema_handler(
                parameters: pa.RecordBatch, input_schema: pa.Schema | None = None
            ) -> pa.Schema:
                assert input_schema is not None
                return input_schema

            def in_out_echo_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                result = output_schema.empty_table()

                while True:
                    input_chunk = yield result

                    if input_chunk is None:
                        break

                    result = input_chunk

                return

            def in_out_wide_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                result = output_schema.empty_table()

                while True:
                    input_chunk = yield result

                    if input_chunk is None:
                        break

                    result = pa.RecordBatch.from_arrays(
                        [[i] * len(input_chunk) for i in range(20)],
                        schema=output_schema,
                    )

                return

            def in_out_handler(
                parameters: parameter_types.TableFunctionParameters,
                output_schema: pa.Schema,
            ) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
                result = output_schema.empty_table()

                while True:
                    input_chunk = yield result

                    if input_chunk is None:
                        break

                    assert parameters.parameters
                    result = pa.RecordBatch.from_arrays(
                        [
                            parameters.parameters.column(0),
                            input_chunk.column(0),
                        ],
                        schema=output_schema,
                    )

                return pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)

            util_schema = SchemaCollection(
                scalar_functions_by_name=CaseInsensitiveDict(
                    {
                        "test_uppercase": ScalarFunction(
                            input_schema=pa.schema([pa.field("a", pa.string())]),
                            output_schema=pa.schema([pa.field("result", pa.string())]),
                            handler=uppercase_handler,
                        ),
                        "test_any_type": ScalarFunction(
                            input_schema=pa.schema([pa.field("a", pa.string(), metadata={"is_any_type": "1"})]),
                            output_schema=pa.schema([pa.field("result", pa.string())]),
                            handler=any_type_handler,
                        ),
                        "test_add": ScalarFunction(
                            input_schema=pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
                            output_schema=pa.schema([pa.field("result", pa.int64())]),
                            handler=add_handler,
                        ),
                    }
                ),
                table_functions_by_name=CaseInsensitiveDict(
                    {
                        "test_echo": TableFunction(
                            input_schema=pa.schema([pa.field("input", pa.string())]),
                            output_schema_source=pa.schema([pa.field("result", pa.string())]),
                            handler=echo_handler,
                        ),
                        "test_wide": TableFunction(
                            input_schema=pa.schema([pa.field("count", pa.int32())]),
                            output_schema_source=pa.schema([pa.field(f"result_{i}", pa.int32()) for i in range(20)]),
                            handler=wide_handler,
                        ),
                        "test_long": TableFunction(
                            input_schema=pa.schema([pa.field("input", pa.string())]),
                            output_schema_source=pa.schema(
                                [
                                    pa.field("result", pa.string()),
                                    pa.field("result2", pa.string()),
                                ]
                            ),
                            handler=long_handler,
                        ),
                        "test_repeat": TableFunction(
                            input_schema=pa.schema(
                                [
                                    pa.field("input", pa.string()),
                                    pa.field("count", pa.int32()),
                                ]
                            ),
                            output_schema_source=pa.schema([pa.field("result", pa.string())]),
                            handler=repeat_handler,
                        ),
                        "test_dynamic_schema": TableFunction(
                            input_schema=pa.schema(
                                [
                                    pa.field(
                                        "input",
                                        pa.string(),
                                        metadata={"is_any_type": "1"},
                                    )
                                ]
                            ),
                            output_schema_source=TableFunctionDynamicOutput(
                                schema_creator=dynamic_schema_handler_output_schema,
                                default_values=(
                                    pa.RecordBatch.from_arrays(
                                        [pa.array([1], type=pa.int32())],
                                        schema=pa.schema([pa.field("input", pa.int32())]),
                                    ),
                                    None,
                                ),
                            ),
                            handler=dynamic_schema_handler,
                        ),
                        "test_dynamic_schema_named_parameters": TableFunction(
                            input_schema=pa.schema(
                                [
                                    pa.field("name", pa.string()),
                                    pa.field(
                                        "location",
                                        pa.string(),
                                        metadata={"is_named_parameter": "1"},
                                    ),
                                    pa.field(
                                        "input",
                                        pa.string(),
                                        metadata={"is_any_type": "1"},
                                    ),
                                    pa.field("city", pa.string()),
                                ]
                            ),
                            output_schema_source=TableFunctionDynamicOutput(
                                schema_creator=dynamic_schema_handler_output_schema,
                                default_values=(
                                    pa.RecordBatch.from_arrays(
                                        [pa.array([1], type=pa.int32())],
                                        schema=pa.schema([pa.field("input", pa.int32())]),
                                    ),
                                    None,
                                ),
                            ),
                            handler=dynamic_schema_handler,
                        ),
                        "test_table_in_out": TableFunction(
                            input_schema=pa.schema(
                                [
                                    pa.field("input", pa.string()),
                                    pa.field(
                                        "table_input",
                                        pa.string(),
                                        metadata={"is_table_type": "1"},
                                    ),
                                ]
                            ),
                            output_schema_source=TableFunctionDynamicOutput(
                                schema_creator=in_out_schema_handler,
                                default_values=(
                                    pa.RecordBatch.from_arrays(
                                        [pa.array([1], type=pa.int32())],
                                        schema=pa.schema([pa.field("input", pa.int32())]),
                                    ),
                                    pa.schema([pa.field("input", pa.int32())]),
                                ),
                            ),
                            handler=in_out_handler,
                        ),
                        "test_table_in_out_wide": TableFunction(
                            input_schema=pa.schema(
                                [
                                    pa.field("input", pa.string()),
                                    pa.field(
                                        "table_input",
                                        pa.string(),
                                        metadata={"is_table_type": "1"},
                                    ),
                                ]
                            ),
                            output_schema_source=TableFunctionDynamicOutput(
                                schema_creator=in_out_wide_schema_handler,
                                default_values=(
                                    pa.RecordBatch.from_arrays(
                                        [pa.array([1], type=pa.int32())],
                                        schema=pa.schema([pa.field("input", pa.int32())]),
                                    ),
                                    pa.schema([pa.field("input", pa.int32())]),
                                ),
                            ),
                            handler=in_out_wide_handler,
                        ),
                        "test_table_in_out_echo": TableFunction(
                            input_schema=pa.schema(
                                [
                                    pa.field(
                                        "table_input",
                                        pa.string(),
                                        metadata={"is_table_type": "1"},
                                    ),
                                ]
                            ),
                            output_schema_source=TableFunctionDynamicOutput(
                                schema_creator=in_out_echo_schema_handler,
                                default_values=(
                                    pa.RecordBatch.from_arrays(
                                        [pa.array([1], type=pa.int32())],
                                        schema=pa.schema([pa.field("input", pa.int32())]),
                                    ),
                                    pa.schema([pa.field("input", pa.int32())]),
                                ),
                            ),
                            handler=in_out_echo_handler,
                        ),
                    }
                ),
            )

            library.databases_by_name[database_name].schemas_by_name["utils"] = util_schema

            return iter([])
        elif action.type == "drop_database":
            database_name = action.body.to_pybytes().decode("utf-8")
            context.logger.debug("Dropping database", database_name=database_name)

            self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
            library = self.contents[context.caller.token.token]
            if database_name not in library.databases_by_name:
                raise flight.FlightServerError(f"Database {database_name} does not exist")
            del library.databases_by_name[database_name]
            return iter([])

        raise flight.FlightServerError(f"Unsupported action type: {action.type}")

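    # The in/out handlers above use the generator send() protocol: each input
    # chunk is pushed in with send(), and the value yielded back is the output
    # chunk. A driver sketch (hypothetical helper, not part of the published
    # module; the real exchange loop lives in the base server):
    def _example_drive_in_out_generator(self) -> None:
        def echo() -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
            result = pa.RecordBatch.from_pydict({"x": pa.array([], type=pa.int64())})
            while True:
                input_chunk = yield result
                if input_chunk is None:
                    break
                result = input_chunk

        gen = echo()
        next(gen)  # prime; the first yield returns the initial empty batch
        chunk = pa.RecordBatch.from_pydict({"x": [1, 2]})
        assert gen.send(chunk) is chunk  # each send() returns the next output chunk
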
    def exchange_update(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        descriptor: flight.FlightDescriptor,
        reader: flight.MetadataRecordBatchReader,
        writer: flight.MetadataRecordBatchWriter,
        return_chunks: bool,
    ) -> int:
        assert context.caller is not None

        descriptor_parts = descriptor_unpack_(descriptor)

        if descriptor_parts.type != "table":
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)
        table_info = schema.by_name("table", descriptor_parts.name)

        existing_table = table_info.version()

        writer.begin(existing_table.schema)

        rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
        assert rowid_index != -1

        change_count = 0

        for chunk in reader:
            if chunk.data is not None:
                chunk_table = pa.Table.from_batches([chunk.data])
                assert chunk_table.num_rows > 0

                # This chunk will contain any updated columns and the row id.

                input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)

                # To perform an update, first remove the rows from the table,
                # then assign new row ids to the incoming rows, and append them to the table.
                table_mask = pc.is_in(
                    existing_table.column(rowid_index),
                    value_set=chunk_table.column(input_rowid_index),
                )

                # This is the table with updated rows removed; we'll be adding rows to it later on.
                table_without_updated_rows = pc.filter(existing_table, pc.invert(table_mask))

                # These are the rows that are being updated. Since the update may not send all
                # the columns, we need to filter the table to get the updated rows, to persist the
                # values that aren't being updated.
                changed_rows = pc.filter(existing_table, table_mask)

                # Get a list of all of the columns that are not being updated, so that a join
                # can be performed.
                unchanged_column_names = set(existing_table.schema.names) - set(chunk_table.schema.names)

                joined_table = pa.Table.join(
                    changed_rows.select(list(unchanged_column_names) + [self.ROWID_FIELD_NAME]),
                    chunk_table,
                    keys=[self.ROWID_FIELD_NAME],
                    join_type="inner",
                )

                # Add the new row id column.
                chunk_length = len(joined_table)
                rowid_values = list(
                    range(
                        table_info.row_id_counter,
                        table_info.row_id_counter + chunk_length,
                    )
                )
                updated_rows = joined_table.set_column(
                    joined_table.schema.get_field_index(self.ROWID_FIELD_NAME),
                    self.rowid_field,
                    [rowid_values],
                )
                table_info.row_id_counter += chunk_length
                change_count += chunk_length

                # Now the columns may be in a different order, so we need to realign them.
                updated_rows = updated_rows.select(existing_table.schema.names)

                check_schema_is_subset_of_schema(existing_table.schema, updated_rows.schema)

                updated_rows = conform_nullable(existing_table.schema, updated_rows)

                updated_table = pa.concat_tables(
                    [
                        table_without_updated_rows,
                        updated_rows.select(table_without_updated_rows.schema.names),
                    ]
                )

                if return_chunks:
                    writer.write_table(updated_rows)

                existing_table = updated_table

        table_info.update_table(existing_table)

        return change_count

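    # The update above is a mask-and-join in miniature. A standalone sketch of
    # the same pattern (hypothetical helper, not part of the published module;
    # two columns, updating only one of them):
    def _example_update_by_rowid(self) -> None:
        existing = pa.table({"rowid": [0, 1, 2], "a": ["x", "y", "z"], "b": [10, 20, 30]})
        update = pa.table({"rowid": [1], "b": [99]})  # only column "b" is updated

        mask = pc.is_in(existing.column("rowid"), value_set=update.column("rowid"))
        untouched = existing.filter(pc.invert(mask))
        # Join the untouched columns of the matched rows with the incoming values.
        merged = existing.filter(mask).select(["rowid", "a"]).join(update, keys=["rowid"])
        result = pa.concat_tables([untouched, merged.select(existing.schema.names)])
        assert result.num_rows == 3
        assert set(result.column("b").to_pylist()) == {10, 30, 99}
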
    def exchange_delete(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        descriptor: flight.FlightDescriptor,
        reader: flight.MetadataRecordBatchReader,
        writer: flight.MetadataRecordBatchWriter,
        return_chunks: bool,
    ) -> int:
        assert context.caller is not None

        descriptor_parts = descriptor_unpack_(descriptor)

        if descriptor_parts.type != "table":
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)
        table_info = schema.by_name("table", descriptor_parts.name)

        existing_table = table_info.version()
        writer.begin(existing_table.schema)

        rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
        assert rowid_index != -1

        change_count = 0

        for chunk in reader:
            if chunk.data is not None:
                chunk_table = pa.Table.from_batches([chunk.data])
                assert chunk_table.num_rows > 0

                # Should only be getting the row id.
                assert chunk_table.num_columns == 1
                # The rowid field doesn't come back the same, since it's missing the
                # not-null flag and the metadata, so we can't compare the schemas.
                input_rowid_index = chunk_table.schema.get_field_index(self.ROWID_FIELD_NAME)

                # Now do an anti-join to get the rows that are not in the delete set.
                target_rowids = chunk_table.column(input_rowid_index)
                existing_row_ids = existing_table.column(rowid_index)

                mask = pc.is_in(existing_row_ids, value_set=target_rowids)

                target_rows = pc.filter(existing_table, mask)
                changed_table = pc.filter(existing_table, pc.invert(mask))

                change_count += target_rows.num_rows

                if return_chunks:
                    writer.write_table(target_rows)

                existing_table = changed_table

        table_info.update_table(existing_table)

        return change_count

    def exchange_insert(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        descriptor: flight.FlightDescriptor,
        reader: flight.MetadataRecordBatchReader,
        writer: flight.MetadataRecordBatchWriter,
        return_chunks: bool,
    ) -> int:
        assert context.caller is not None

        descriptor_parts = descriptor_unpack_(descriptor)

        if descriptor_parts.type != "table":
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)
        table_info = schema.by_name("table", descriptor_parts.name)

        existing_table = table_info.version()
        writer.begin(existing_table.schema)
        change_count = 0

        rowid_index = existing_table.schema.get_field_index(self.ROWID_FIELD_NAME)
        assert rowid_index != -1

        # Check that the data being read matches the table without the rowid column.

        # DuckDB won't send field metadata when it sends us the schema that it uses
        # to perform an insert, so we need some way to adapt the schema we receive.
        check_schema_is_subset_of_schema(existing_table.schema, reader.schema)

        # FIXME: need to handle the case of rowids.

        for chunk in reader:
            if chunk.data is not None:
                new_rows = pa.Table.from_batches([chunk.data])
                assert new_rows.num_rows > 0

                # Append the row id column to the new rows.
                chunk_length = new_rows.num_rows
                rowid_values = list(
                    range(
                        table_info.row_id_counter,
                        table_info.row_id_counter + chunk_length,
                    )
                )
                new_rows = new_rows.append_column(self.rowid_field, [rowid_values])
                table_info.row_id_counter += chunk_length
                change_count += chunk_length

                if return_chunks:
                    writer.write_table(new_rows)

                # Since the table could have had columns added and removed, use .select
                # to ensure that the columns are aligned in the same order as the original table.

                # It turns out that DuckDB doesn't send the "not null" flag in the Arrow schema,
                # which means we can't concat the tables without those flags matching.
                # for field_name in existing_table.schema.names:
                #     field = existing_table.schema.field(field_name)
                #     if not field.nullable:
                #         field_index = new_rows.schema.get_field_index(field_name)
                #         new_rows = new_rows.set_column(
                #             field_index, field.with_nullable(False), new_rows.column(field_index)
                #         )
                new_rows = conform_nullable(existing_table.schema, new_rows)

                table_info.update_table(
                    pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
                )

        return change_count

def exchange_table_function_in_out(
|
1320
|
+
self,
|
1321
|
+
*,
|
1322
|
+
context: base_server.CallContext[auth.Account, auth.AccountToken],
|
1323
|
+
descriptor: flight.FlightDescriptor,
|
1324
|
+
parameters: parameter_types.TableFunctionParameters,
|
1325
|
+
input_schema: pa.Schema,
|
1326
|
+
) -> tuple[pa.Schema, Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch]]:
|
1327
|
+
assert context.caller is not None
|
1328
|
+
|
1329
|
+
descriptor_parts = descriptor_unpack_(descriptor)
|
1330
|
+
|
1331
|
+
if descriptor_parts.type != "table_function":
|
1332
|
+
raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
|
1333
|
+
library = self.contents[context.caller.token.token]
|
1334
|
+
database = library.by_name(descriptor_parts.catalog_name)
|
1335
|
+
schema = database.by_name(descriptor_parts.schema_name)
|
1336
|
+
table_info = schema.by_name("table_function", descriptor_parts.name)
|
1337
|
+
|
1338
|
+
output_schema = table_info.output_schema(parameters=parameters.parameters, input_schema=input_schema)
|
1339
|
+
gen = table_info.handler(parameters, output_schema)
|
1340
|
+
return (output_schema, gen)
|
1341
|
+
|
1342
|
+
def action_add_column(
|
1343
|
+
self,
|
1344
|
+
*,
|
1345
|
+
context: base_server.CallContext[auth.Account, auth.AccountToken],
|
1346
|
+
parameters: parameter_types.AddColumn,
|
1347
|
+
) -> flight.FlightInfo:
|
1348
|
+
assert context.caller is not None
|
1349
|
+
|
1350
|
+
self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
|
1351
|
+
library = self.contents[context.caller.token.token]
|
1352
|
+
database = library.by_name(parameters.catalog)
|
1353
|
+
schema = database.by_name(parameters.schema_name)
|
1354
|
+
|
1355
|
+
table_info = schema.by_name("table", parameters.name)
|
1356
|
+
|
1357
|
+
assert len(parameters.column_schema.names) == 1
|
1358
|
+
|
1359
|
+
existing_table = table_info.version()
|
1360
|
+
# Don't allow duplicate colum names.
|
1361
|
+
assert parameters.column_schema.field(0).name not in existing_table.schema.names
|
1362
|
+
|
1363
|
+
table_info.update_table(
|
1364
|
+
existing_table.append_column(
|
1365
|
+
parameters.column_schema.field(0).name,
|
1366
|
+
[
|
1367
|
+
pa.nulls(
|
1368
|
+
existing_table.num_rows,
|
1369
|
+
type=parameters.column_schema.field(0).type,
|
1370
|
+
)
|
1371
|
+
],
|
1372
|
+
)
|
1373
|
+
)
|
1374
|
+
database.version += 1
|
1375
|
+
|
1376
|
+
return table_info.flight_info(
|
1377
|
+
name=parameters.name,
|
1378
|
+
catalog_name=parameters.catalog,
|
1379
|
+
schema_name=parameters.schema_name,
|
1380
|
+
)[0]
|
1381
|
+
|
1382
|
+
def action_remove_column(
|
1383
|
+
self,
|
1384
|
+
*,
|
1385
|
+
context: base_server.CallContext[auth.Account, auth.AccountToken],
|
1386
|
+
parameters: parameter_types.RemoveColumn,
|
1387
|
+
) -> flight.FlightInfo:
|
1388
|
+
assert context.caller is not None
|
1389
|
+
|
1390
|
+
self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
|
1391
|
+
library = self.contents[context.caller.token.token]
|
1392
|
+
database = library.by_name(parameters.catalog)
|
1393
|
+
schema = database.by_name(parameters.schema_name)
|
1394
|
+
|
1395
|
+
table_info = schema.by_name("table", parameters.name)
|
1396
|
+
|
1397
|
+
table_info.update_table(table_info.version().drop(parameters.removed_column))
|
1398
|
+
database.version += 1
|
1399
|
+
|
1400
|
+
return table_info.flight_info(
|
1401
|
+
name=parameters.name,
|
1402
|
+
catalog_name=parameters.catalog,
|
1403
|
+
schema_name=parameters.schema_name,
|
1404
|
+
)[0]
|
1405
|
+
|
1406
|
+
    def action_rename_column(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.RenameColumn,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog)
        schema = database.by_name(parameters.schema_name)

        table_info = schema.by_name("table", parameters.name)

        table_info.update_table(table_info.version().rename_columns({parameters.old_name: parameters.new_name}))
        database.version += 1

        return table_info.flight_info(
            name=parameters.name,
            catalog_name=parameters.catalog,
            schema_name=parameters.schema_name,
        )[0]
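
Note that rename_columns is passed a mapping rather than a full list of names; recent pyarrow releases accept this form, renaming only the listed columns. A standalone sketch:

import pyarrow as pa

t = pa.table({"a": [1], "b": [2]})
# Mapping form renames only "a"; "b" keeps its name.
t2 = t.rename_columns({"a": "x"})
assert t2.column_names == ["x", "b"]
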
    def action_rename_table(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.RenameTable,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog)
        schema = database.by_name(parameters.schema_name)

        table_info = schema.by_name("table", parameters.name)

        schema.tables_by_name[parameters.new_table_name] = schema.tables_by_name.pop(parameters.name)

        database.version += 1

        return table_info.flight_info(
            name=parameters.new_table_name,
            catalog_name=parameters.catalog,
            schema_name=parameters.schema_name,
        )[0]

    def action_set_default(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.SetDefault,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog)
        schema = database.by_name(parameters.schema_name)
        table_info = schema.by_name("table", parameters.name)

        # Defaults are set as metadata on a field.

        t = table_info.version()
        field_index = t.schema.get_field_index(parameters.column_name)
        field = t.schema.field(parameters.column_name)
        new_metadata: dict[str, Any] = {}
        if field.metadata:
            new_metadata = {**field.metadata}

        new_metadata["default"] = parameters.expression

        table_info.update_table(
            t.set_column(
                field_index,
                field.with_metadata(new_metadata),
                t.column(field_index),
            )
        )

        database.version += 1

        return table_info.flight_info(
            name=parameters.name,
            catalog_name=parameters.catalog,
            schema_name=parameters.schema_name,
        )[0]
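
Arrow fields have no native notion of a column default, so the server stores the default expression as field metadata under the "default" key. How such metadata round-trips in pyarrow (a standalone sketch):

import pyarrow as pa

field = pa.field("qty", pa.int64())
field_with_default = field.with_metadata({"default": "0"})
# pyarrow normalizes field metadata to a bytes -> bytes mapping.
assert field_with_default.metadata[b"default"] == b"0"
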
    def action_set_not_null(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.SetNotNull,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog)
        schema = database.by_name(parameters.schema_name)

        table_info = schema.by_name("table", parameters.name)

        t = table_info.version()
        field_index = t.schema.get_field_index(parameters.column_name)
        field = t.schema.field(parameters.column_name)

        if t.column(field_index).null_count > 0:
            raise flight.FlightServerError(
                f"Cannot set column {parameters.column_name} to NOT NULL: it contains null values"
            )

        table_info.update_table(
            t.set_column(
                field_index,
                field.with_nullable(False),
                t.column(field_index),
            )
        )

        database.version += 1

        return table_info.flight_info(
            name=parameters.name,
            catalog_name=parameters.catalog,
            schema_name=parameters.schema_name,
        )[0]

    def action_drop_not_null(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.DropNotNull,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog)
        schema = database.by_name(parameters.schema_name)

        table_info = schema.by_name("table", parameters.name)

        t = table_info.version()
        field_index = t.schema.get_field_index(parameters.column_name)
        field = t.schema.field(parameters.column_name)

        table_info.update_table(
            t.set_column(
                field_index,
                field.with_nullable(True),
                t.column(field_index),
            )
        )

        database.version += 1

        return table_info.flight_info(
            name=parameters.name,
            catalog_name=parameters.catalog,
            schema_name=parameters.schema_name,
        )[0]

    def action_change_column_type(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.ChangeColumnType,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        self.contents.setdefault(context.caller.token.token, DatabaseLibrary())
        library = self.contents[context.caller.token.token]
        database = library.by_name(parameters.catalog)
        schema = database.by_name(parameters.schema_name)

        table_info = schema.by_name("table", parameters.name)

        # Cast the column to the new type, preserving any field metadata.

        t = table_info.version()
        column_name = parameters.column_schema.field(0).name
        field_index = t.schema.get_field_index(column_name)
        field = t.schema.field(column_name)

        new_type = parameters.column_schema.field(0).type
        new_field = pa.field(field.name, new_type, metadata=field.metadata)
        new_data = pc.cast(t.column(field_index), new_type)

        table_info.update_table(
            t.set_column(
                field_index,
                new_field,
                new_data,
            )
        )

        database.version += 1

        return table_info.flight_info(
            name=parameters.name,
            catalog_name=parameters.catalog,
            schema_name=parameters.schema_name,
        )[0]
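
The type change is a plain pyarrow compute cast, which raises ArrowInvalid when a value cannot be represented in the target type. A standalone sketch:

import pyarrow as pa
import pyarrow.compute as pc

col = pa.chunked_array([["1", "2", "3"]])
# Every string parses as an integer, so the cast succeeds;
# a value like "abc" would raise ArrowInvalid instead.
as_ints = pc.cast(col, pa.int64())
assert as_ints.type == pa.int64()
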
    def impl_do_get(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        ticket: flight.Ticket,
    ) -> flight.RecordBatchStream:
        assert context.caller is not None

        ticket_data = flight_handling.decode_ticket_model(ticket, FlightTicketData)

        descriptor_parts = descriptor_unpack_(ticket_data.descriptor)
        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)

        if descriptor_parts.type == "table":
            table = schema.by_name("table", descriptor_parts.name)

            if ticket_data.at_unit == "VERSION":
                assert ticket_data.at_value is not None

                # at_value arrives as a string; it must encode an integer version.
                if not re.match(r"^\d+$", ticket_data.at_value):
                    raise flight.FlightServerError(f"Invalid version: {ticket_data.at_value}")

                table_version = table.version(int(ticket_data.at_value))
            elif ticket_data.at_unit == "TIMESTAMP":
                raise flight.FlightServerError("Timestamp not supported for table versioning")
            else:
                table_version = table.version()

            if descriptor_parts.schema_name == "test_predicate_pushdown" and ticket_data.where_clause is not None:
                # Apply the pushed-down predicate to the data held in memory.
                # If JSON filters were sent, this also verifies that they decode.
                with duckdb.connect(":memory:") as connection:
                    connection.execute("SET TimeZone = 'UTC'")
                    sql = f"select * from table_version where {ticket_data.where_clause}"
                    try:
                        results = connection.execute(sql).fetch_arrow_table()
                    except Exception as e:
                        raise flight.FlightServerError(f"Failed to execute predicate pushdown: {e} sql: {sql}") from e
                table_version = results

            return flight.RecordBatchStream(table_version)
        elif descriptor_parts.type == "table_function":
            table_function = schema.by_name("table_function", descriptor_parts.name)

            output_schema = table_function.output_schema(
                ticket_data.table_function_parameters,
                ticket_data.table_function_input_schema,
            )

            return flight.GeneratorStream(
                output_schema,
                table_function.handler(
                    parameter_types.TableFunctionParameters(
                        parameters=ticket_data.table_function_parameters, where_clause=ticket_data.where_clause
                    ),
                    output_schema,
                ),
            )
        else:
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
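
The pushdown branch works because DuckDB's Python client resolves unknown table names against local variables (a replacement scan), so the SQL can reference the in-memory Arrow table by the name table_version. The mechanism in isolation (a standalone sketch):

import duckdb
import pyarrow as pa

table_version = pa.table({"id": [1, 2, 3]})
with duckdb.connect(":memory:") as connection:
    # "table_version" in the SQL resolves to the local Arrow table above.
    filtered = connection.execute(
        "select * from table_version where id > 1"
    ).fetch_arrow_table()
assert filtered.num_rows == 2
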
    def action_table_function_flight_info(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.TableFunctionFlightInfo,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        library = self.contents[context.caller.token.token]
        descriptor_parts = descriptor_unpack_(parameters.descriptor)

        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)

        # Getting the table function is a bit harder, since they are named by action.
        table_function = schema.by_name("table_function", descriptor_parts.name)

        return table_function.flight_info(
            name=descriptor_parts.name,
            catalog_name=descriptor_parts.catalog_name,
            schema_name=descriptor_parts.schema_name,
            # Pass the real parameters here.
            parameters=parameters,
        )[0]

    def action_flight_info(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        parameters: parameter_types.FlightInfo,
    ) -> flight.FlightInfo:
        assert context.caller is not None

        library = self.contents[context.caller.token.token]
        descriptor_parts = descriptor_unpack_(parameters.descriptor)

        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)

        if descriptor_parts.type == "table_function":
            raise flight.FlightServerError("Table function flight info not supported")
        elif descriptor_parts.type == "table":
            table = schema.by_name("table", descriptor_parts.name)

            version_id = None
            if parameters.at_unit is not None:
                if parameters.at_unit == "VERSION":
                    assert parameters.at_value is not None

                    # at_value arrives as a string; it must encode an integer version.
                    if not re.match(r"^\d+$", parameters.at_value):
                        raise flight.FlightServerError(f"Invalid version: {parameters.at_value}")

                    version_id = int(parameters.at_value)
                elif parameters.at_unit == "TIMESTAMP":
                    raise flight.FlightServerError("Timestamp not supported for table versioning")
                else:
                    raise flight.FlightServerError(f"Unsupported at_unit: {parameters.at_unit}")

            return table.flight_info(
                name=descriptor_parts.name,
                catalog_name=descriptor_parts.catalog_name,
                schema_name=descriptor_parts.schema_name,
                version=version_id,
            )[0]
        else:
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

    def exchange_scalar_function(
        self,
        *,
        context: base_server.CallContext[auth.Account, auth.AccountToken],
        descriptor: flight.FlightDescriptor,
        reader: flight.MetadataRecordBatchReader,
        writer: flight.MetadataRecordBatchWriter,
    ) -> None:
        assert context.caller is not None

        descriptor_parts = descriptor_unpack_(descriptor)

        if descriptor_parts.type != "scalar_function":
            raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")

        library = self.contents[context.caller.token.token]
        database = library.by_name(descriptor_parts.catalog_name)
        schema = database.by_name(descriptor_parts.schema_name)
        scalar_function_info = schema.by_name("scalar_function", descriptor_parts.name)

        writer.begin(scalar_function_info.output_schema)

        for chunk in reader:
            if chunk.data is not None:
                new_rows = pa.Table.from_batches([chunk.data])
                assert new_rows.num_rows > 0

                result = scalar_function_info.handler(new_rows)

                writer.write_table(pa.Table.from_arrays([result], schema=scalar_function_info.output_schema))
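
In the exchange loop above, a scalar function's handler receives each incoming batch as a pa.Table and must return one output value per input row. A compatible (hypothetical) handler:

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical handler: upper-cases the first input column, producing
# exactly one output value per input row, as the exchange loop expects.
def uppercase_handler(rows: pa.Table) -> pa.ChunkedArray:
    return pc.utf8_upper(rows.column(0))
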
@click.command()
@click.option(
    "--location",
    type=str,
    default="grpc://127.0.0.1:50312",
    help="The location where the server should listen.",
)
def run(location: str) -> None:
    log.info("Starting server", location=location)

    auth_manager = auth_manager_naive.AuthManagerNaive[auth.Account, auth.AccountToken](
        account_type=auth.Account,
        token_type=auth.AccountToken,
        allow_anonymous_access=False,
    )

    server = InMemoryArrowFlightServer(
        middleware={
            "headers": base_middleware.SaveHeadersMiddlewareFactory(),
            "auth": base_middleware.AuthManagerMiddlewareFactory(auth_manager=auth_manager),
        },
        location=location,
        auth_manager=auth_manager,
    )
    server.serve()
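
Once running, the server is reachable by any Arrow Flight client at the configured location. A minimal connectivity sketch; since allow_anonymous_access is False, real calls must carry credentials accepted by the AuthManagerNaive instance, and the username and password below are placeholders:

import pyarrow.flight as flight

client = flight.connect("grpc://127.0.0.1:50312")
# Obtain a bearer-token header pair via Flight's basic-auth handshake
# (assuming the server implements the handshake for this auth manager).
token_pair = client.authenticate_basic_token(b"user", b"password")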