vgi-python 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vgi/__init__.py +152 -0
- vgi/_duckdb.py +62 -0
- vgi/_storage_profile.py +132 -0
- vgi/_test_fixtures/__init__.py +20 -0
- vgi/_test_fixtures/accumulate/__init__.py +19 -0
- vgi/_test_fixtures/accumulate/worker.py +762 -0
- vgi/_test_fixtures/aggregate/__init__.py +62 -0
- vgi/_test_fixtures/aggregate/_common.py +21 -0
- vgi/_test_fixtures/aggregate/basic.py +232 -0
- vgi/_test_fixtures/aggregate/dynamic.py +409 -0
- vgi/_test_fixtures/aggregate/generic.py +86 -0
- vgi/_test_fixtures/aggregate/listagg.py +71 -0
- vgi/_test_fixtures/aggregate/percentile.py +107 -0
- vgi/_test_fixtures/aggregate/streaming.py +192 -0
- vgi/_test_fixtures/aggregate/varargs.py +75 -0
- vgi/_test_fixtures/aggregate/window.py +380 -0
- vgi/_test_fixtures/attach_options.py +308 -0
- vgi/_test_fixtures/bad_protocol.py +62 -0
- vgi/_test_fixtures/cancellable.py +336 -0
- vgi/_test_fixtures/catalog.py +813 -0
- vgi/_test_fixtures/http_server.py +394 -0
- vgi/_test_fixtures/nest_tensor.py +614 -0
- vgi/_test_fixtures/orchard_catalog.py +47 -0
- vgi/_test_fixtures/projection_repro/__init__.py +6 -0
- vgi/_test_fixtures/projection_repro/worker.py +454 -0
- vgi/_test_fixtures/scalar/__init__.py +116 -0
- vgi/_test_fixtures/scalar/_common.py +69 -0
- vgi/_test_fixtures/scalar/arithmetic.py +321 -0
- vgi/_test_fixtures/scalar/binary.py +120 -0
- vgi/_test_fixtures/scalar/formatting.py +176 -0
- vgi/_test_fixtures/scalar/geo.py +300 -0
- vgi/_test_fixtures/scalar/null_handling.py +107 -0
- vgi/_test_fixtures/scalar/random_demo.py +171 -0
- vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
- vgi/_test_fixtures/scalar/type_info.py +219 -0
- vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
- vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
- vgi/_test_fixtures/simple_writable.py +793 -0
- vgi/_test_fixtures/table/__init__.py +221 -0
- vgi/_test_fixtures/table/_common.py +162 -0
- vgi/_test_fixtures/table/batch_index.py +283 -0
- vgi/_test_fixtures/table/batch_index_broken.py +200 -0
- vgi/_test_fixtures/table/catalog_scans.py +162 -0
- vgi/_test_fixtures/table/filters.py +1005 -0
- vgi/_test_fixtures/table/late_materialization.py +249 -0
- vgi/_test_fixtures/table/make_series.py +273 -0
- vgi/_test_fixtures/table/misc.py +499 -0
- vgi/_test_fixtures/table/order_modes.py +164 -0
- vgi/_test_fixtures/table/pairs.py +437 -0
- vgi/_test_fixtures/table/partition_columns.py +472 -0
- vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
- vgi/_test_fixtures/table/profiling_example.py +195 -0
- vgi/_test_fixtures/table/required_filters.py +234 -0
- vgi/_test_fixtures/table/sequence.py +710 -0
- vgi/_test_fixtures/table/settings.py +426 -0
- vgi/_test_fixtures/table/transaction_storage.py +162 -0
- vgi/_test_fixtures/table/tt_pushdown.py +191 -0
- vgi/_test_fixtures/table/versioned.py +230 -0
- vgi/_test_fixtures/table_in_out.py +1392 -0
- vgi/_test_fixtures/versioned.py +155 -0
- vgi/_test_fixtures/versioned_tables.py +595 -0
- vgi/_test_fixtures/worker.py +1631 -0
- vgi/_test_fixtures/writable/__init__.py +8 -0
- vgi/_test_fixtures/writable/generic.py +236 -0
- vgi/_test_fixtures/writable/table.py +149 -0
- vgi/_test_fixtures/writable/worker.py +1148 -0
- vgi/aggregate_function.py +607 -0
- vgi/argument_spec.py +472 -0
- vgi/arguments.py +1747 -0
- vgi/auth.py +55 -0
- vgi/catalog/__init__.py +88 -0
- vgi/catalog/attach_option.py +206 -0
- vgi/catalog/catalog_interface.py +2767 -0
- vgi/catalog/descriptors.py +870 -0
- vgi/catalog/duckdb_statistics.py +377 -0
- vgi/catalog/secret_type.py +96 -0
- vgi/catalog/setting.py +253 -0
- vgi/catalog/storage.py +372 -0
- vgi/client/__init__.py +67 -0
- vgi/client/catalog_mixin.py +1251 -0
- vgi/client/cli.py +582 -0
- vgi/client/cli_catalog.py +182 -0
- vgi/client/cli_schema.py +270 -0
- vgi/client/cli_table.py +907 -0
- vgi/client/cli_transaction.py +97 -0
- vgi/client/cli_utils.py +441 -0
- vgi/client/cli_view.py +303 -0
- vgi/client/client.py +2183 -0
- vgi/exceptions.py +205 -0
- vgi/function.py +245 -0
- vgi/function_storage.py +1636 -0
- vgi/function_storage_azure_sql.py +922 -0
- vgi/function_storage_cf_do.py +740 -0
- vgi/http/__init__.py +25 -0
- vgi/http/demo_storage.py +212 -0
- vgi/http/worker_page.py +1252 -0
- vgi/invocation.py +154 -0
- vgi/logging_config.py +93 -0
- vgi/meta_worker.py +661 -0
- vgi/metadata.py +1403 -0
- vgi/otel.py +406 -0
- vgi/protocol.py +2418 -0
- vgi/protocol_version.txt +1 -0
- vgi/py.typed +0 -0
- vgi/scalar_function.py +1211 -0
- vgi/schema_utils.py +234 -0
- vgi/secret_protocol.py +124 -0
- vgi/secret_service.py +238 -0
- vgi/serve.py +769 -0
- vgi/table_buffering_function.py +443 -0
- vgi/table_filter_pushdown.py +1528 -0
- vgi/table_function.py +1130 -0
- vgi/table_in_out_function.py +383 -0
- vgi/transactor/__init__.py +24 -0
- vgi/transactor/_duckdb_compat.py +27 -0
- vgi/transactor/client.py +137 -0
- vgi/transactor/protocol.py +149 -0
- vgi/transactor/server.py +740 -0
- vgi/worker.py +4761 -0
- vgi_python-0.8.0.dist-info/METADATA +735 -0
- vgi_python-0.8.0.dist-info/RECORD +124 -0
- vgi_python-0.8.0.dist-info/WHEEL +4 -0
- vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
- vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
vgi/transactor/server.py
ADDED
|
@@ -0,0 +1,740 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""db-transactor server — multi-database DuckDB transaction manager.
|
|
4
|
+
|
|
5
|
+
Runs as a long-lived subprocess, accepting ``vgi_rpc`` connections over a
|
|
6
|
+
Unix domain socket. Manages multiple DuckDB databases, one per catalog
|
|
7
|
+
attachment (identified by ``attach_opaque_data``).
|
|
8
|
+
|
|
9
|
+
Usage::
|
|
10
|
+
|
|
11
|
+
vgi-transactor --db-dir /path/to/databases --socket /tmp/vgi-transactor.sock
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import threading
|
|
22
|
+
import uuid
|
|
23
|
+
from typing import TYPE_CHECKING, cast
|
|
24
|
+
|
|
25
|
+
import pyarrow as pa
|
|
26
|
+
from vgi_rpc import AnnotatedBatch, OutputCollector, RpcServer
|
|
27
|
+
from vgi_rpc.rpc import CallContext, ExchangeState, ProducerState, Stream, StreamState, serve_unix
|
|
28
|
+
|
|
29
|
+
from vgi._duckdb import connect as engine_connect
|
|
30
|
+
from vgi.schema_utils import schema
|
|
31
|
+
from vgi.transactor._duckdb_compat import subcursor
|
|
32
|
+
from vgi.transactor.protocol import TransactorProtocol
|
|
33
|
+
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
import duckdb
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger("vgi.transactor")
|
|
38
|
+
|
|
39
|
+
_COUNT_SCHEMA = schema(count=pa.int64())
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class TransactorImpl:
|
|
43
|
+
"""Implementation of the TransactorProtocol backed by DuckDB.
|
|
44
|
+
|
|
45
|
+
Manages multiple databases (one per attach_opaque_data). Each transaction gets
|
|
46
|
+
its own DuckDB cursor, allowing multiple concurrent transactions per
|
|
47
|
+
database.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(self, db_dir: str) -> None:
    """Set up empty registries and make sure the database directory exists.

    Args:
        db_dir: Directory that holds one ``<hex>.duckdb`` file per attachment.
    """
    self._db_dir = db_dir
    os.makedirs(db_dir, exist_ok=True)

    # Single lock guarding every registry dictionary below.
    self._lock = threading.Lock()

    # attach_opaque_data -> main connection of that database
    self._databases: dict[bytes, duckdb.DuckDBPyConnection] = {}
    # attach_opaque_data -> catalog name (for view SQL stripping)
    self._catalog_names: dict[bytes, str] = {}
    # attach_opaque_data -> version (incremented on DDL)
    self._catalog_versions: dict[bytes, int] = {}

    # Per-database transaction state: {attach_opaque_data: {tx_id: cursor}}
    self._transactions: dict[bytes, dict[bytes, duckdb.DuckDBPyConnection]] = {}
    # Same nesting as above, but holding the lock that serializes each transaction.
    self._tx_locks: dict[bytes, dict[bytes, threading.Lock]] = {}

    logger.info("Transactor started: db_dir=%s", db_dir)
|
|
62
|
+
|
|
63
|
+
# ========== Helpers ==========
|
|
64
|
+
|
|
65
|
+
def _get_db_conn(self, attach_opaque_data: bytes) -> duckdb.DuckDBPyConnection:
    """Look up the main connection registered for ``attach_opaque_data``.

    Raises:
        ValueError: If no database has been registered under that key.
    """
    with self._lock:
        registered = self._databases.get(attach_opaque_data)
    if registered is not None:
        return registered
    msg = f"No registered database: {attach_opaque_data.hex()}"
    raise ValueError(msg)
|
|
73
|
+
|
|
74
|
+
def _get_tx_conn(self, attach_opaque_data: bytes, tx_id: bytes) -> duckdb.DuckDBPyConnection:
    """Look up the cursor that backs transaction ``tx_id`` of a database.

    Raises:
        ValueError: If the database has no such active transaction.
    """
    with self._lock:
        # .get() with a throwaway default so unknown databases are not inserted.
        cursor = self._transactions.get(attach_opaque_data, {}).get(tx_id)
    if cursor is not None:
        return cursor
    msg = f"No active transaction: {tx_id.hex()} in db {attach_opaque_data.hex()}"
    raise ValueError(msg)
|
|
83
|
+
|
|
84
|
+
def _get_tx_lock(self, attach_opaque_data: bytes, tx_id: bytes) -> threading.Lock:
    """Return the lock for a transaction, creating it on first request."""
    with self._lock:
        per_db = self._tx_locks.setdefault(attach_opaque_data, {})
        lock = per_db.get(tx_id)
        if lock is None:
            lock = per_db[tx_id] = threading.Lock()
        return lock
|
|
91
|
+
|
|
92
|
+
def _table_schema(self, qualified_name: str, attach_opaque_data: bytes, tx_id: bytes) -> pa.Schema:
    """Return the Arrow schema of ``qualified_name`` as seen by the transaction.

    Runs ``SELECT * ... LIMIT 0`` on a subcursor of the transaction's cursor
    while holding the per-transaction lock.

    Args:
        qualified_name: Table name, optionally schema-qualified; interpolated
            into the SQL text as-is.
        attach_opaque_data: Key of the registered database.
        tx_id: Identifier of the active transaction.

    Raises:
        ValueError: If the transaction is unknown (from ``_get_tx_conn``).
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        sub = subcursor(conn)
        try:
            sql = f"SELECT * FROM {qualified_name} LIMIT 0"  # noqa: S608
            # Named ``table_schema`` so the imported ``schema`` helper is not shadowed.
            table_schema: pa.Schema = sub.execute(sql).to_arrow_table().schema
        finally:
            # Close the subcursor even if the probe query fails (e.g. unknown table).
            sub.close()
    return table_schema
|
|
103
|
+
|
|
104
|
+
# ========== Database lifecycle ==========
|
|
105
|
+
|
|
106
|
+
def register(
    self, attach_opaque_data: bytes, catalog_name: str = "", ddl_statements: list[str] | None = None
) -> None:
    """Open the database file for ``attach_opaque_data`` and run initial DDL.

    The file lives in the configured database directory and is named after
    the hex form of ``attach_opaque_data``. The connection, catalog name and
    an initial catalog version of 1 are recorded before the DDL runs.
    """
    key_hex = attach_opaque_data.hex()
    db_path = os.path.join(self._db_dir, f"{key_hex}.duckdb")
    conn = engine_connect(db_path)

    with self._lock:
        self._databases[attach_opaque_data] = conn
        self._catalog_names[attach_opaque_data] = catalog_name
        self._catalog_versions[attach_opaque_data] = 1

    for statement in ddl_statements or ():
        conn.execute(statement)

    logger.info("Database registered: %s (catalog=%s) -> %s", key_hex[:8], catalog_name, db_path)
|
|
120
|
+
|
|
121
|
+
def catalog_version(self, attach_opaque_data: bytes) -> int:
    """Return the current catalog version, or 1 when none is recorded."""
    with self._lock:
        version = self._catalog_versions.get(attach_opaque_data, 1)
    return version
|
|
125
|
+
|
|
126
|
+
# ========== Transaction lifecycle ==========
|
|
127
|
+
|
|
128
|
+
def begin(self, attach_opaque_data: bytes) -> bytes:
    """Begin a transaction on the database and return its ``tx_id``.

    A new cursor is opened on the database's main connection, configured with
    ``enable_suspended_queries`` and put into a transaction. The cursor and a
    fresh per-transaction lock are then recorded under a random UUID.

    Raises:
        ValueError: If the database is not registered (from ``_get_db_conn``).
    """
    db_conn = self._get_db_conn(attach_opaque_data)
    tx_id = uuid.uuid4().bytes
    cursor = db_conn.cursor()
    try:
        cursor.execute("SET enable_suspended_queries = true")
        cursor.begin()
    except Exception:
        # The cursor has not been recorded yet, so nobody else could close it.
        cursor.close()
        raise
    with self._lock:
        self._transactions.setdefault(attach_opaque_data, {})[tx_id] = cursor
        self._tx_locks.setdefault(attach_opaque_data, {})[tx_id] = threading.Lock()
    logger.info("Transaction begun: %s (db %s)", tx_id.hex()[:8], attach_opaque_data.hex()[:8])
    return tx_id
|
|
140
|
+
|
|
141
|
+
def commit(self, attach_opaque_data: bytes, tx_id: bytes) -> None:
    """Commit the transaction, close its cursor and forget its bookkeeping.

    Raises:
        ValueError: If the transaction is unknown (from ``_get_tx_conn``).
    """
    tx_cursor = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_cursor.commit()
    tx_cursor.close()

    with self._lock:
        open_txns = self._transactions.get(attach_opaque_data)
        if open_txns is not None:
            open_txns.pop(tx_id, None)
        open_locks = self._tx_locks.get(attach_opaque_data)
        if open_locks is not None:
            open_locks.pop(tx_id, None)

    logger.info("Transaction committed: %s", tx_id.hex()[:8])
|
|
150
|
+
|
|
151
|
+
def rollback(self, attach_opaque_data: bytes, tx_id: bytes) -> None:
    """Roll back the transaction, close its cursor and forget its bookkeeping.

    Raises:
        ValueError: If the transaction is unknown (from ``_get_tx_conn``).
    """
    tx_cursor = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_cursor.rollback()
    tx_cursor.close()

    with self._lock:
        open_txns = self._transactions.get(attach_opaque_data)
        if open_txns is not None:
            open_txns.pop(tx_id, None)
        open_locks = self._tx_locks.get(attach_opaque_data)
        if open_locks is not None:
            open_locks.pop(tx_id, None)

    logger.info("Transaction rolled back: %s", tx_id.hex()[:8])
|
|
160
|
+
|
|
161
|
+
# ========== Write operations (streaming exchange) ==========
|
|
162
|
+
|
|
163
|
+
def insert(
    self,
    attach_opaque_data: bytes,
    tx_id: bytes,
    table_name: str,
    schema_name: str = "",
    returning: bool = False,
) -> Stream[StreamState]:
    """Build the exchange stream that inserts incoming batches into a table.

    The input schema is the table's schema minus any column named ``rowid``.
    The output schema is that same schema when ``returning`` is set, otherwise
    the single-column count schema.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    qualified = table_name if not schema_name else f"{schema_name}.{table_name}"

    full_schema = self._table_schema(qualified, attach_opaque_data, tx_id)
    input_schema = pa.schema([field for field in full_schema if field.name != "rowid"])

    state = _InsertState(
        conn=subcursor(conn),
        qualified_name=qualified,
        returning=returning,
        table_schema=input_schema,
        tx_lock=tx_lock,
    )
    return Stream(
        output_schema=input_schema if returning else _COUNT_SCHEMA,
        state=state,
        input_schema=input_schema,
    )
|
|
190
|
+
|
|
191
|
+
def delete(
    self, attach_opaque_data: bytes, tx_id: bytes, table_name: str, schema_name: str = "", returning: bool = False
) -> Stream[StreamState]:
    """Build the exchange stream that deletes rows identified by ``rowid``.

    Input batches carry a single int64 ``rowid`` column. The output is the
    table's columns (minus ``rowid``) when ``returning`` is set, otherwise
    the single-column count schema.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    qualified = table_name if not schema_name else f"{schema_name}.{table_name}"

    full_schema = self._table_schema(qualified, attach_opaque_data, tx_id)
    rowid_only = schema(rowid=pa.int64())
    ret_schema = pa.schema([field for field in full_schema if field.name != "rowid"])

    state = _DeleteState(
        conn=subcursor(conn),
        qualified_name=qualified,
        returning=returning,
        table_schema=ret_schema,
        tx_lock=tx_lock,
    )
    return Stream(
        output_schema=ret_schema if returning else _COUNT_SCHEMA,
        state=state,
        input_schema=rowid_only,
    )
|
|
210
|
+
|
|
211
|
+
def update(
    self,
    attach_opaque_data: bytes,
    tx_id: bytes,
    table_name: str,
    schema_name: str = "",
    columns: list[str] | None = None,
    returning: bool = False,
) -> Stream[StreamState]:
    """Build the exchange stream that updates rows of a table.

    When ``columns`` is given, the input schema is those columns that exist in
    the table (unknown names are dropped) followed by an int64 ``rowid``.
    Otherwise the input schema is the table schema as probed.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    qualified = table_name if not schema_name else f"{schema_name}.{table_name}"
    full_schema = self._table_schema(qualified, attach_opaque_data, tx_id)

    if not columns:
        input_schema = full_schema
    else:
        known = [full_schema.field(name) for name in columns if full_schema.get_field_index(name) >= 0]
        input_schema = pa.schema([*known, pa.field("rowid", pa.int64())])

    ret_schema = pa.schema([field for field in full_schema if field.name != "rowid"])

    state = _UpdateState(
        conn=subcursor(conn),
        qualified_name=qualified,
        returning=returning,
        table_schema=ret_schema,
        tx_lock=tx_lock,
    )
    return Stream(
        output_schema=ret_schema if returning else _COUNT_SCHEMA,
        state=state,
        input_schema=input_schema,
    )
|
|
242
|
+
|
|
243
|
+
# ========== Read (streaming producer) ==========
|
|
244
|
+
|
|
245
|
+
def scan(
    self,
    attach_opaque_data: bytes,
    tx_id: bytes,
    table_name: str,
    columns: list[str],
    schema_name: str = "",
    pushdown_filters: bytes | None = None,
) -> Stream[StreamState]:
    """Create a scan producer stream within the transaction.

    Args:
        attach_opaque_data: Key of the registered database.
        tx_id: Identifier of the active transaction.
        table_name: Table to read; interpolated into the SQL text as-is.
        columns: Columns to project; an empty list selects ``*``.
        schema_name: Optional schema qualifier.
        pushdown_filters: Optional Arrow IPC stream whose first batch holds
            serialized filters; they become a ``WHERE`` clause with bind
            parameters.

    Raises:
        ValueError: If the transaction is unknown (from ``_get_tx_conn``).
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    qualified = f"{schema_name}.{table_name}" if schema_name else table_name
    col_list = ", ".join(columns) if columns else "*"

    sql = f"SELECT {col_list} FROM {qualified}"  # noqa: S608
    bind_params: list[object] = []
    if pushdown_filters is not None:
        from vgi.table_filter_pushdown import deserialize_filters

        pf_reader = pa.ipc.open_stream(pushdown_filters)
        pf_batch = pf_reader.read_next_batch()
        pf = deserialize_filters(pf_batch)
        if pf and pf.filters:
            where_clause, bind_params = pf.to_sql()
            sql += f" WHERE {where_clause}"

    # All work on the transaction's cursor is serialized by the per-transaction lock.
    with tx_lock:
        schema_sub = subcursor(conn)
        try:
            schema_sql = f"SELECT {col_list} FROM {qualified} LIMIT 0"  # noqa: S608
            output_schema = schema_sub.execute(schema_sql).to_arrow_table().schema
        finally:
            # Release the probe cursor even when the table or a column is unknown.
            schema_sub.close()

        scan_cursor = subcursor(conn)
        try:
            result = scan_cursor.execute(sql, bind_params) if bind_params else scan_cursor.execute(sql)
            reader = result.to_arrow_reader(batch_size=50_000)
        except Exception:
            # The reader never came into existence, so nothing else would close this cursor.
            scan_cursor.close()
            raise

    state = _ScanState(reader=reader, tx_lock=tx_lock)
    return Stream(output_schema=output_schema, state=state)
|
|
284
|
+
|
|
285
|
+
# ========== DDL ==========
|
|
286
|
+
|
|
287
|
+
def execute_ddl(self, attach_opaque_data: bytes, sql: str) -> None:
    """Run a DDL statement on the database's main connection (non-transactional).

    The statement executes while the registry lock is held, and the catalog
    version of the database is incremented afterwards.
    """
    conn = self._get_db_conn(attach_opaque_data)
    with self._lock:
        conn.execute(sql)
        previous = self._catalog_versions.get(attach_opaque_data, 1)
        self._catalog_versions[attach_opaque_data] = previous + 1
    logger.debug("DDL executed: %s", sql[:100])
|
|
294
|
+
|
|
295
|
+
def execute_ddl_tx(
    self, attach_opaque_data: bytes, tx_id: bytes, sql: str, strip_catalog: str | None = None
) -> None:
    """Run a DDL statement on the transaction's cursor.

    References to the external catalog are stripped from ``sql`` first. The
    catalog name is ``strip_catalog`` when given, otherwise the name recorded
    by ``register()``; an empty name disables stripping. The catalog version
    is incremented after the statement has run.
    """
    if strip_catalog is not None:
        catalog_name = strip_catalog
    else:
        with self._lock:
            catalog_name = self._catalog_names.get(attach_opaque_data, "")
    if catalog_name:
        sql = self.strip_catalog_refs(sql, catalog_name)

    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        conn.execute(sql)

    with self._lock:
        previous = self._catalog_versions.get(attach_opaque_data, 1)
        self._catalog_versions[attach_opaque_data] = previous + 1
    logger.debug("DDL (tx) executed: %s", sql[:100])
|
|
316
|
+
|
|
317
|
+
def strip_catalog_refs(self, sql: str, catalog_name: str) -> str:
    """Remove references to ``catalog_name`` from table names in ``sql``.

    The statement is parsed with sqlglot's DuckDB dialect and re-rendered. If
    parsing fails, a warning is logged and ``sql`` is returned unchanged.
    Catalog comparison is case-insensitive.
    """
    import sqlglot
    from sqlglot import exp

    try:
        tree = sqlglot.parse_one(sql, dialect="duckdb")
    except sqlglot.errors.ParseError:
        logger.warning("strip_catalog_refs: failed to parse SQL, returning as-is: %s", sql[:100])
        return sql

    for table in tree.find_all(exp.Table):
        if not table.catalog or table.catalog.lower() != catalog_name.lower():
            continue
        table.set("catalog", None)
        # NOTE(review): the "main" schema is dropped only for tables whose catalog
        # matched above — confirm this nesting against the upstream file.
        if table.args.get("db") and table.args["db"].name.lower() == "main":
            table.set("db", None)

    return tree.sql(dialect="duckdb")
|
|
335
|
+
|
|
336
|
+
# ========== Metadata ==========
|
|
337
|
+
|
|
338
|
+
def _query_list(
    self, attach_opaque_data: bytes, tx_id: bytes, sql: str, params: list[object] | None = None
) -> list[str]:
    """Run ``sql`` on the transaction's cursor and return the first column of every row."""
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    with self._get_tx_lock(attach_opaque_data, tx_id):
        rows = conn.execute(sql, params or []).fetchall()
        first_column = [row[0] for row in rows]
    return first_column
|
|
347
|
+
|
|
348
|
+
def list_schemas(self, attach_opaque_data: bytes, tx_id: bytes) -> list[str]:
    """Return the names of all non-internal schemas visible to the transaction."""
    query = "SELECT schema_name FROM duckdb_schemas() WHERE NOT internal"
    return self._query_list(attach_opaque_data, tx_id, query)
|
|
353
|
+
|
|
354
|
+
def list_user_tables(self, attach_opaque_data: bytes, tx_id: bytes, schema_name: str = "main") -> list[str]:
    """Return the names of base tables in ``schema_name`` visible to the transaction."""
    query = "SELECT table_name FROM information_schema.tables WHERE table_schema=? AND table_type='BASE TABLE'"
    return self._query_list(attach_opaque_data, tx_id, query, [schema_name])
|
|
362
|
+
|
|
363
|
+
def table_schema(self, attach_opaque_data: bytes, table_name: str, tx_id: bytes) -> bytes:
    """Get Arrow schema for a table as serialized IPC bytes within a transaction.

    The returned schema starts with an int64 ``rowid`` field tagged with
    ``is_row_id`` metadata, followed by the table's columns. Column comments
    and defaults are attached as field metadata (``comment`` / ``default``),
    and the table's constraints are stored as JSON under the schema metadata
    key ``vgi.constraints``.

    Args:
        attach_opaque_data: Key of the registered database.
        table_name: Table name, optionally qualified as ``schema.table``. When
            qualified, every catalog lookup is restricted to that schema so a
            same-named table in another schema cannot leak its metadata.
        tx_id: Identifier of the active transaction.

    Raises:
        ValueError: If the transaction is unknown or ``table_name`` does not
            name a table.
    """
    import json

    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)

    # Split on the last dot; without a dot the whole string is the bare name.
    schema_part, dot, bare_name = table_name.rpartition(".")
    if dot:
        scope_sql = "schema_name = ? AND table_name = ?"
        scope_params: list[object] = [schema_part, bare_name]
    else:
        scope_sql = "table_name = ?"
        scope_params = [bare_name]

    with tx_lock:
        sub = subcursor(conn)
        try:
            count_row = sub.execute(
                f"SELECT COUNT(*) FROM duckdb_tables() WHERE {scope_sql}",  # noqa: S608
                scope_params,
            ).fetchone()
            if count_row is None or not count_row[0] > 0:
                raise ValueError(f"'{table_name}' is not a table")

            arrow_schema = sub.execute(f"SELECT * FROM {table_name} LIMIT 0").to_arrow_table().schema  # noqa: S608

            column_rows = sub.execute(
                f"SELECT column_name, comment, column_default FROM duckdb_columns() WHERE {scope_sql}",  # noqa: S608
                scope_params,
            ).fetchall()
            constraint_rows = sub.execute(
                "SELECT constraint_type, constraint_column_names, constraint_text, "  # noqa: S608
                "referenced_table, referenced_column_names "
                f"FROM duckdb_constraints() WHERE {scope_sql}",
                scope_params,
            ).fetchall()
        finally:
            # Always release the subcursor, including on the "not a table" path.
            sub.close()

    # Collect per-column comment/default metadata.
    meta_updates: dict[str, dict[bytes | str, bytes | str]] = {}
    for col_name, comment, default in column_rows:
        updates: dict[bytes | str, bytes | str] = {}
        if comment is not None:
            updates[b"comment"] = comment.encode("utf-8")
        if default is not None:
            updates[b"default"] = default.encode("utf-8")
        if updates:
            meta_updates[col_name] = updates

    if meta_updates:
        fields = list(arrow_schema)
        for i, f in enumerate(fields):
            if f.name in meta_updates:
                metadata: dict[bytes | str, bytes | str] = (
                    dict(cast("dict[bytes | str, bytes | str]", f.metadata)) if f.metadata else {}
                )
                metadata.update(meta_updates[f.name])
                fields[i] = f.with_metadata(metadata)
        arrow_schema = pa.schema(fields)

    constraints_json = json.dumps(
        [
            {
                "type": row[0],
                "columns": row[1],
                "text": row[2],
                "referenced_table": row[3],
                "referenced_columns": row[4],
            }
            for row in constraint_rows
        ]
    )
    schema_meta: dict[bytes | str, bytes | str] = {b"vgi.constraints": constraints_json.encode("utf-8")}

    rowid_field = pa.field("rowid", pa.int64(), metadata={b"is_row_id": b""})
    result_schema = pa.schema([rowid_field, *arrow_schema], metadata=schema_meta)
    return result_schema.serialize().to_pybytes()
|
|
443
|
+
|
|
444
|
+
def table_comment(self, attach_opaque_data: bytes, table_name: str, tx_id: bytes) -> str | None:
    """Get the comment on a table within a transaction.

    Args:
        attach_opaque_data: Key of the registered database.
        table_name: Table name, optionally qualified as ``schema.table``. When
            qualified, the lookup is restricted to that schema so the comment
            of a same-named table in another schema is never returned.
        tx_id: Identifier of the active transaction.

    Returns:
        The comment as a string, or ``None`` when the table is not found or
        has no (non-empty) comment.

    Raises:
        ValueError: If the transaction is unknown (from ``_get_tx_conn``).
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)

    # Split on the last dot; without a dot the whole string is the bare name.
    schema_part, dot, bare_name = table_name.rpartition(".")
    if dot:
        sql = "SELECT comment FROM duckdb_tables() WHERE schema_name = ? AND table_name = ?"
        params: list[object] = [schema_part, bare_name]
    else:
        sql = "SELECT comment FROM duckdb_tables() WHERE table_name = ?"
        params = [bare_name]

    with tx_lock:
        result = conn.execute(sql, params).fetchone()
    if result and result[0]:
        return str(result[0])
    return None
|
|
457
|
+
|
|
458
|
+
def list_user_views(self, attach_opaque_data: bytes, tx_id: bytes, schema_name: str = "main") -> list[str]:
    """Return the names of non-internal views in ``schema_name`` visible to the transaction."""
    query = "SELECT view_name FROM duckdb_views() WHERE schema_name = ? AND NOT internal"
    return self._query_list(attach_opaque_data, tx_id, query, [schema_name])
|
|
466
|
+
|
|
467
|
+
def view_info(self, attach_opaque_data: bytes, view_name: str, tx_id: bytes) -> str:
    """Return view info as JSON (definition, comment).

    The definition is the SQL stored in ``duckdb_views()`` with a leading
    ``CREATE VIEW <name> AS`` prefix removed when present, and trailing
    whitespace and ``;`` stripped.

    Args:
        attach_opaque_data: Key of the registered database.
        view_name: Name of the view to look up.
        tx_id: Identifier of the active transaction.

    Raises:
        ValueError: If the transaction is unknown or the view does not exist.
    """
    import json
    import re

    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        sub = subcursor(conn)
        try:
            result = sub.execute(
                "SELECT sql, comment FROM duckdb_views() WHERE view_name = ?",
                [view_name],
            ).fetchone()
        finally:
            # Close the subcursor even if the lookup query fails.
            sub.close()
    if result is None:
        raise ValueError(f"View '{view_name}' not found")

    definition = result[0] or ""
    match = re.match(r"CREATE\s+VIEW\s+\S+\s+AS\s+", definition, re.IGNORECASE)
    if match:
        definition = definition[match.end() :]
    definition = definition.rstrip().rstrip(";")
    return json.dumps({"definition": definition, "comment": result[1]})
|
|
489
|
+
|
|
490
|
+
# ========== Lifecycle ==========
|
|
491
|
+
|
|
492
|
+
def ping(self) -> None:
    """Health check: does nothing and returns ``None``."""
    return None
|
|
494
|
+
|
|
495
|
+
def shutdown(self) -> None:
    """Roll back every open transaction, close every database, then exit.

    Failures while rolling back or closing are logged and do not stop the
    remaining cleanup. The process is terminated with ``sys.exit(0)``.
    """
    logger.info("Shutdown requested")
    with self._lock:
        # Snapshot the registries so cleanup never iterates a dict being mutated.
        open_transactions = [
            (attach_opaque_data, tx_id, tx_cursor)
            for attach_opaque_data, db_txns in list(self._transactions.items())
            for tx_id, tx_cursor in list(db_txns.items())
        ]
        for attach_opaque_data, tx_id, tx_cursor in open_transactions:
            try:
                tx_cursor.rollback()
                tx_cursor.close()
                logger.info("Rolled back orphan tx: %s (db %s)", tx_id.hex()[:8], attach_opaque_data.hex()[:8])
            except Exception:
                logger.exception("Failed to rollback tx: %s", tx_id.hex()[:8])
        self._transactions.clear()
        self._tx_locks.clear()

        for attach_opaque_data, db_conn in list(self._databases.items()):
            try:
                db_conn.close()
            except Exception:
                logger.exception("Failed to close database: %s", attach_opaque_data.hex()[:8])
        self._databases.clear()
    sys.exit(0)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
# ============================================================================
|
|
519
|
+
# Stream state implementations
|
|
520
|
+
# ============================================================================
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def _read_result_batch(result: duckdb.DuckDBPyConnection) -> pa.RecordBatch:
    """Materialize a DML result as exactly one Arrow record batch.

    An empty result yields a zero-row batch with the result's schema; a
    multi-batch result is combined into a single batch.
    """
    table = result.to_arrow_table()
    batches = table.to_batches()
    if len(batches) == 1:
        return batches[0]
    if batches:
        combined = pa.Table.from_batches(batches, schema=table.schema).combine_chunks()
        return combined.to_batches()[0]
    empty_columns = {field.name: [] for field in table.schema}
    return pa.record_batch(empty_columns, schema=table.schema)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
# Process-wide sequence number used to name registered batches; guarded by the lock below.
_batch_counter = 0
_batch_counter_lock = threading.Lock()


def _unique_batch_name(prefix: str) -> str:
    """Return ``__<prefix>_<n>__`` where ``n`` is the next value of a process-wide counter."""
    global _batch_counter  # noqa: PLW0603
    with _batch_counter_lock:
        _batch_counter += 1
        sequence = _batch_counter
    return f"__{prefix}_{sequence}__"
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
class _InsertState(ExchangeState):
    """Insert exchange: receives row batches, inserts into table, returns count or RETURNING rows."""

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        qualified_name: str,
        returning: bool,
        table_schema: pa.Schema,
        tx_lock: threading.Lock,
    ) -> None:
        """Store the cursor, target table and output mode for later exchanges.

        Args:
            conn: Cursor on which the INSERT statements run.
            qualified_name: Target table, interpolated into the SQL text as-is.
            returning: Whether to emit the inserted rows instead of a count.
            table_schema: Schema of the rows emitted when ``returning`` is set.
            tx_lock: Lock serializing all work on the owning transaction.
        """
        self.conn = conn
        self.qualified_name = qualified_name
        self.returning = returning
        self.table_schema = table_schema
        self.tx_lock = tx_lock

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Insert one input batch and emit either the row count or the RETURNING rows."""
        with self.tx_lock:
            batch = input.batch
            col_names = ", ".join(batch.schema.names)
            view_name = _unique_batch_name("insert")
            sql = f"INSERT INTO {self.qualified_name} ({col_names}) SELECT * FROM {view_name}"  # noqa: S608
            if self.returning:
                ret_cols = ", ".join(self.table_schema.names)
                sql += f" RETURNING {ret_cols}"
            self.conn.register(view_name, batch)
            try:
                result_batch = _read_result_batch(self.conn.execute(sql))
            finally:
                # Drop the temporary view even if the INSERT fails, so failed
                # exchanges do not accumulate registered batches on the cursor.
                self.conn.unregister(view_name)
            self._emit_result(result_batch, out)

    def _emit_result(self, result_batch: pa.RecordBatch, out: OutputCollector) -> None:
        """Emit the RETURNING rows, or a one-row ``count`` batch when not returning."""
        if self.returning:
            out.emit(
                result_batch
                if result_batch.num_rows > 0
                else pa.record_batch({c: [] for c in self.table_schema.names}, schema=self.table_schema)
            )
        else:
            # The DML result reports the number of affected rows in a "Count" column.
            count = result_batch.column("Count")[0].as_py() if result_batch.num_rows > 0 else 0
            out.emit(pa.record_batch({"count": [count]}, schema=_COUNT_SCHEMA))
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
class _DeleteState(ExchangeState):
    """Delete exchange: receives rowid batches, deletes matching rows.

    Every ``exchange`` call registers the incoming batch on the connection
    under a unique name and deletes the table rows whose ``rowid`` matches a
    ``rowid`` in that batch, emitting either the deleted rows or a count.
    """

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        qualified_name: str,
        returning: bool,
        table_schema: pa.Schema,
        tx_lock: threading.Lock,
    ) -> None:
        """Store the connection, target table and options used by ``exchange``.

        Args:
            conn: Connection the statements are executed on.
            qualified_name: Target table name; interpolated verbatim into the SQL.
            returning: When true, the matching rows are selected and emitted.
            table_schema: Schema whose column names are selected in RETURNING
                mode and which is used to build the empty RETURNING batch.
            tx_lock: Lock held while the statements run on ``conn``.
        """
        self.conn = conn
        self.qualified_name = qualified_name
        self.returning = returning
        self.table_schema = table_schema
        self.tx_lock = tx_lock

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Delete the rows identified by ``input.batch`` and emit the result.

        Args:
            input: Annotated batch whose ``batch`` has a ``rowid`` column.
            out: Collector that receives the count or the deleted rows.
            ctx: Call context (unused here).
        """
        with self.tx_lock:
            batch = input.batch
            view_name = _unique_batch_name("delete")
            delete_sql = (
                f"DELETE FROM {self.qualified_name} "  # noqa: S608
                f"USING {view_name} WHERE {self.qualified_name}.rowid = {view_name}.rowid"
            )
            self.conn.register(view_name, batch)
            try:
                if self.returning:
                    ret_cols = ", ".join(f"{self.qualified_name}.{c}" for c in self.table_schema.names)
                    select_sql = (
                        f"SELECT {ret_cols} FROM {self.qualified_name} "  # noqa: S608
                        f"JOIN {view_name} ON {self.qualified_name}.rowid = {view_name}.rowid"
                    )
                    # Read the matching rows before deleting them; the DELETE's
                    # own result is not used in this mode.
                    result_batch = _read_result_batch(self.conn.execute(select_sql))
                    self.conn.execute(delete_sql)
                else:
                    result_batch = _read_result_batch(self.conn.execute(delete_sql))
            finally:
                # Drop the temporary registration even when a statement fails,
                # so the batch does not stay attached to the connection.
                self.conn.unregister(view_name)

        self._emit_result(result_batch, out)

    def _emit_result(self, result_batch: pa.RecordBatch, out: OutputCollector) -> None:
        """Emit the deleted rows, or a single-row ``count`` batch.

        In RETURNING mode a zero-row result is replaced by an empty batch built
        from ``table_schema``. Otherwise the first value of the result's
        ``Count`` column is emitted (``0`` when the result has no rows).
        """
        if self.returning:
            out.emit(
                result_batch
                if result_batch.num_rows > 0
                else pa.record_batch({c: [] for c in self.table_schema.names}, schema=self.table_schema)
            )
        else:
            count = result_batch.column("Count")[0].as_py() if result_batch.num_rows > 0 else 0
            out.emit(pa.record_batch({"count": [count]}, schema=_COUNT_SCHEMA))
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
class _UpdateState(ExchangeState):
    """Update exchange: receives rowid + updated columns, updates matching rows.

    Every ``exchange`` call registers the incoming batch on the connection
    under a unique name and runs ``UPDATE ... FROM <name>`` joined on
    ``rowid``, emitting either the RETURNING rows or a one-row count batch.
    """

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        qualified_name: str,
        returning: bool,
        table_schema: pa.Schema,
        tx_lock: threading.Lock,
    ) -> None:
        """Store the connection, target table and options used by ``exchange``.

        Args:
            conn: Connection the UPDATE statements are executed on.
            qualified_name: Target table name; interpolated verbatim into the SQL.
            returning: When true, a RETURNING clause is appended and its rows emitted.
            table_schema: Schema whose column names form the RETURNING list and
                which is used to build the empty RETURNING batch.
            tx_lock: Lock held while the statement runs on ``conn``.
        """
        self.conn = conn
        self.qualified_name = qualified_name
        self.returning = returning
        self.table_schema = table_schema
        self.tx_lock = tx_lock

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Apply the column values in ``input.batch`` to the rows with matching ``rowid``.

        Every batch column other than ``rowid`` becomes one ``SET`` assignment.

        Args:
            input: Annotated batch whose ``batch`` has ``rowid`` plus the columns to set.
            out: Collector that receives the count or RETURNING batch.
            ctx: Call context (unused here).
        """
        with self.tx_lock:
            batch = input.batch
            view_name = _unique_batch_name("update")
            update_cols = [name for name in batch.schema.names if name != "rowid"]
            set_clause = ", ".join(f"{col} = {view_name}.{col}" for col in update_cols)
            sql = (
                f"UPDATE {self.qualified_name} SET {set_clause} "  # noqa: S608
                f"FROM {view_name} WHERE {self.qualified_name}.rowid = {view_name}.rowid"
            )
            if self.returning:
                ret_cols = ", ".join(self.table_schema.names)
                sql += f" RETURNING {ret_cols}"
            self.conn.register(view_name, batch)
            try:
                result = self.conn.execute(sql)
                result_batch = _read_result_batch(result)
            finally:
                # Drop the temporary registration even when the UPDATE fails, so
                # the batch does not stay attached to the connection.
                self.conn.unregister(view_name)

        self._emit_result(result_batch, out)

    def _emit_result(self, result_batch: pa.RecordBatch, out: OutputCollector) -> None:
        """Emit RETURNING rows, or a single-row ``count`` batch.

        In RETURNING mode a zero-row result is replaced by an empty batch built
        from ``table_schema``. Otherwise the first value of the result's
        ``Count`` column is emitted (``0`` when the result has no rows).
        """
        if self.returning:
            out.emit(
                result_batch
                if result_batch.num_rows > 0
                else pa.record_batch({c: [] for c in self.table_schema.names}, schema=self.table_schema)
            )
        else:
            count = result_batch.column("Count")[0].as_py() if result_batch.num_rows > 0 else 0
            out.emit(pa.record_batch({"count": [count]}, schema=_COUNT_SCHEMA))
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
class _ScanState(ProducerState):
    """Scan producer: hands out the Arrow batches of a query result one call at a time."""

    def __init__(self, reader: pa.RecordBatchReader, tx_lock: threading.Lock) -> None:
        """Keep the batch reader and the lock that is held while reading from it."""
        self._tx_lock = tx_lock
        self._reader = reader

    def produce(self, out: OutputCollector, ctx: CallContext) -> None:
        """Emit the next non-empty batch, or finish once the reader is exhausted.

        Zero-row batches are skipped. When the reader raises ``StopIteration``
        the collector is finished and nothing is emitted.
        """
        with self._tx_lock:
            chunk = None
            # Keep pulling until a batch with at least one row shows up.
            while chunk is None or chunk.num_rows <= 0:
                try:
                    chunk = self._reader.read_next_batch()
                except StopIteration:
                    out.finish()
                    return
            out.emit(chunk)
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def main() -> None:
    """Entry point for the vgi-transactor command.

    Parses ``--db-dir``, ``--socket`` and ``--log-file``, configures file
    logging, and serves a ``TransactorImpl`` over the given Unix domain socket.
    Exits with an install hint when ``sqlglot`` cannot be imported.
    """
    # sqlglot is imported lazily inside ``strip_catalog_refs``, but DDL paths
    # require it. Surface a clear install message at startup instead of
    # blowing up mid-transaction.
    try:
        import sqlglot  # noqa: F401
    except ImportError:
        import sys as _sys

        _sys.exit("vgi-transactor requires the transactor extra. Install with: pip install 'vgi-python[transactor]'")

    parser = argparse.ArgumentParser(description="VGI db-transactor server")
    parser.add_argument("--db-dir", required=True, help="Directory for DuckDB database files")
    parser.add_argument("--socket", required=True, help="Unix domain socket path to listen on")
    parser.add_argument("--log-file", default=None, help="Log file path (default: derived from socket path)")
    args = parser.parse_args()

    log_path = args.log_file
    if not log_path:
        # Derive the log path from the socket path. Only a trailing ".sock" is
        # swapped for ".log"; any other path gets ".log" appended. A blanket
        # ``str.replace`` would also rewrite ".sock" inside directory names and
        # would return the socket path itself when it contains no ".sock",
        # pointing the log file at the socket.
        sock_suffix = ".sock"
        if args.socket.endswith(sock_suffix):
            log_path = args.socket[: -len(sock_suffix)] + ".log"
        else:
            log_path = args.socket + ".log"

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        filename=log_path,
    )

    impl = TransactorImpl(args.db_dir)
    server = RpcServer(TransactorProtocol, impl)
    logger.info("Serving on %s", args.socket)
    serve_unix(server, args.socket, threaded=True)
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
# Start the server only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
|