vgi-python 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. vgi/__init__.py +152 -0
  2. vgi/_duckdb.py +62 -0
  3. vgi/_storage_profile.py +132 -0
  4. vgi/_test_fixtures/__init__.py +20 -0
  5. vgi/_test_fixtures/accumulate/__init__.py +19 -0
  6. vgi/_test_fixtures/accumulate/worker.py +762 -0
  7. vgi/_test_fixtures/aggregate/__init__.py +62 -0
  8. vgi/_test_fixtures/aggregate/_common.py +21 -0
  9. vgi/_test_fixtures/aggregate/basic.py +232 -0
  10. vgi/_test_fixtures/aggregate/dynamic.py +409 -0
  11. vgi/_test_fixtures/aggregate/generic.py +86 -0
  12. vgi/_test_fixtures/aggregate/listagg.py +71 -0
  13. vgi/_test_fixtures/aggregate/percentile.py +107 -0
  14. vgi/_test_fixtures/aggregate/streaming.py +192 -0
  15. vgi/_test_fixtures/aggregate/varargs.py +75 -0
  16. vgi/_test_fixtures/aggregate/window.py +380 -0
  17. vgi/_test_fixtures/attach_options.py +308 -0
  18. vgi/_test_fixtures/bad_protocol.py +62 -0
  19. vgi/_test_fixtures/cancellable.py +336 -0
  20. vgi/_test_fixtures/catalog.py +813 -0
  21. vgi/_test_fixtures/http_server.py +394 -0
  22. vgi/_test_fixtures/nest_tensor.py +614 -0
  23. vgi/_test_fixtures/orchard_catalog.py +47 -0
  24. vgi/_test_fixtures/projection_repro/__init__.py +6 -0
  25. vgi/_test_fixtures/projection_repro/worker.py +454 -0
  26. vgi/_test_fixtures/scalar/__init__.py +116 -0
  27. vgi/_test_fixtures/scalar/_common.py +69 -0
  28. vgi/_test_fixtures/scalar/arithmetic.py +321 -0
  29. vgi/_test_fixtures/scalar/binary.py +120 -0
  30. vgi/_test_fixtures/scalar/formatting.py +176 -0
  31. vgi/_test_fixtures/scalar/geo.py +300 -0
  32. vgi/_test_fixtures/scalar/null_handling.py +107 -0
  33. vgi/_test_fixtures/scalar/random_demo.py +171 -0
  34. vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
  35. vgi/_test_fixtures/scalar/type_info.py +219 -0
  36. vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
  37. vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
  38. vgi/_test_fixtures/simple_writable.py +793 -0
  39. vgi/_test_fixtures/table/__init__.py +221 -0
  40. vgi/_test_fixtures/table/_common.py +162 -0
  41. vgi/_test_fixtures/table/batch_index.py +283 -0
  42. vgi/_test_fixtures/table/batch_index_broken.py +200 -0
  43. vgi/_test_fixtures/table/catalog_scans.py +162 -0
  44. vgi/_test_fixtures/table/filters.py +1005 -0
  45. vgi/_test_fixtures/table/late_materialization.py +249 -0
  46. vgi/_test_fixtures/table/make_series.py +273 -0
  47. vgi/_test_fixtures/table/misc.py +499 -0
  48. vgi/_test_fixtures/table/order_modes.py +164 -0
  49. vgi/_test_fixtures/table/pairs.py +437 -0
  50. vgi/_test_fixtures/table/partition_columns.py +472 -0
  51. vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
  52. vgi/_test_fixtures/table/profiling_example.py +195 -0
  53. vgi/_test_fixtures/table/required_filters.py +234 -0
  54. vgi/_test_fixtures/table/sequence.py +710 -0
  55. vgi/_test_fixtures/table/settings.py +426 -0
  56. vgi/_test_fixtures/table/transaction_storage.py +162 -0
  57. vgi/_test_fixtures/table/tt_pushdown.py +191 -0
  58. vgi/_test_fixtures/table/versioned.py +230 -0
  59. vgi/_test_fixtures/table_in_out.py +1392 -0
  60. vgi/_test_fixtures/versioned.py +155 -0
  61. vgi/_test_fixtures/versioned_tables.py +595 -0
  62. vgi/_test_fixtures/worker.py +1631 -0
  63. vgi/_test_fixtures/writable/__init__.py +8 -0
  64. vgi/_test_fixtures/writable/generic.py +236 -0
  65. vgi/_test_fixtures/writable/table.py +149 -0
  66. vgi/_test_fixtures/writable/worker.py +1148 -0
  67. vgi/aggregate_function.py +607 -0
  68. vgi/argument_spec.py +472 -0
  69. vgi/arguments.py +1747 -0
  70. vgi/auth.py +55 -0
  71. vgi/catalog/__init__.py +88 -0
  72. vgi/catalog/attach_option.py +206 -0
  73. vgi/catalog/catalog_interface.py +2767 -0
  74. vgi/catalog/descriptors.py +870 -0
  75. vgi/catalog/duckdb_statistics.py +377 -0
  76. vgi/catalog/secret_type.py +96 -0
  77. vgi/catalog/setting.py +253 -0
  78. vgi/catalog/storage.py +372 -0
  79. vgi/client/__init__.py +67 -0
  80. vgi/client/catalog_mixin.py +1251 -0
  81. vgi/client/cli.py +582 -0
  82. vgi/client/cli_catalog.py +182 -0
  83. vgi/client/cli_schema.py +270 -0
  84. vgi/client/cli_table.py +907 -0
  85. vgi/client/cli_transaction.py +97 -0
  86. vgi/client/cli_utils.py +441 -0
  87. vgi/client/cli_view.py +303 -0
  88. vgi/client/client.py +2183 -0
  89. vgi/exceptions.py +205 -0
  90. vgi/function.py +245 -0
  91. vgi/function_storage.py +1636 -0
  92. vgi/function_storage_azure_sql.py +922 -0
  93. vgi/function_storage_cf_do.py +740 -0
  94. vgi/http/__init__.py +25 -0
  95. vgi/http/demo_storage.py +212 -0
  96. vgi/http/worker_page.py +1252 -0
  97. vgi/invocation.py +154 -0
  98. vgi/logging_config.py +93 -0
  99. vgi/meta_worker.py +661 -0
  100. vgi/metadata.py +1403 -0
  101. vgi/otel.py +406 -0
  102. vgi/protocol.py +2418 -0
  103. vgi/protocol_version.txt +1 -0
  104. vgi/py.typed +0 -0
  105. vgi/scalar_function.py +1211 -0
  106. vgi/schema_utils.py +234 -0
  107. vgi/secret_protocol.py +124 -0
  108. vgi/secret_service.py +238 -0
  109. vgi/serve.py +769 -0
  110. vgi/table_buffering_function.py +443 -0
  111. vgi/table_filter_pushdown.py +1528 -0
  112. vgi/table_function.py +1130 -0
  113. vgi/table_in_out_function.py +383 -0
  114. vgi/transactor/__init__.py +24 -0
  115. vgi/transactor/_duckdb_compat.py +27 -0
  116. vgi/transactor/client.py +137 -0
  117. vgi/transactor/protocol.py +149 -0
  118. vgi/transactor/server.py +740 -0
  119. vgi/worker.py +4761 -0
  120. vgi_python-0.8.0.dist-info/METADATA +735 -0
  121. vgi_python-0.8.0.dist-info/RECORD +124 -0
  122. vgi_python-0.8.0.dist-info/WHEEL +4 -0
  123. vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
  124. vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
@@ -0,0 +1,740 @@
1
+ # Copyright 2025, 2026 Query Farm LLC - https://query.farm
2
+
3
+ """db-transactor server — multi-database DuckDB transaction manager.
4
+
5
+ Runs as a long-lived subprocess, accepting ``vgi_rpc`` connections over a
6
+ Unix domain socket. Manages multiple DuckDB databases, one per catalog
7
+ attachment (identified by ``attach_opaque_data``).
8
+
9
+ Usage::
10
+
11
+ vgi-transactor --db-dir /path/to/databases --socket /tmp/vgi-transactor.sock
12
+
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import logging
19
+ import os
20
+ import sys
21
+ import threading
22
+ import uuid
23
+ from typing import TYPE_CHECKING, cast
24
+
25
+ import pyarrow as pa
26
+ from vgi_rpc import AnnotatedBatch, OutputCollector, RpcServer
27
+ from vgi_rpc.rpc import CallContext, ExchangeState, ProducerState, Stream, StreamState, serve_unix
28
+
29
+ from vgi._duckdb import connect as engine_connect
30
+ from vgi.schema_utils import schema
31
+ from vgi.transactor._duckdb_compat import subcursor
32
+ from vgi.transactor.protocol import TransactorProtocol
33
+
34
+ if TYPE_CHECKING:
35
+ import duckdb
36
+
37
# Logger shared by everything in this module.
logger = logging.getLogger("vgi.transactor")

# Output schema used by DML exchanges when RETURNING is off: a single int64 ``count`` column.
_COUNT_SCHEMA = schema(count=pa.int64())
40
+
41
+
42
+ class TransactorImpl:
43
+ """Implementation of the TransactorProtocol backed by DuckDB.
44
+
45
+ Manages multiple databases (one per attach_opaque_data). Each transaction gets
46
+ its own DuckDB cursor, allowing multiple concurrent transactions per
47
+ database.
48
+ """
49
+
50
def __init__(self, db_dir: str) -> None:
    """Create the database directory if needed and set up empty registries.

    Args:
        db_dir: Directory in which the per-attachment DuckDB files live.
    """
    os.makedirs(db_dir, exist_ok=True)
    self._db_dir = db_dir
    # One coarse lock protects every registry dictionary below.
    self._lock = threading.Lock()
    # attach_opaque_data → main connection of that database
    self._databases: dict[bytes, duckdb.DuckDBPyConnection] = {}
    # attach_opaque_data → catalog name (for view SQL stripping)
    self._catalog_names: dict[bytes, str] = {}
    # attach_opaque_data → version (incremented on DDL)
    self._catalog_versions: dict[bytes, int] = {}
    # Open transactions, nested as {attach_opaque_data: {tx_id: cursor}}
    self._transactions: dict[bytes, dict[bytes, duckdb.DuckDBPyConnection]] = {}
    # Per-transaction locks, nested the same way as ``_transactions``
    self._tx_locks: dict[bytes, dict[bytes, threading.Lock]] = {}
    logger.info("Transactor started: db_dir=%s", db_dir)
62
+
63
+ # ========== Helpers ==========
64
+
65
+ def _get_db_conn(self, attach_opaque_data: bytes) -> duckdb.DuckDBPyConnection:
66
+ """Get the main connection for a database, raising if not registered."""
67
+ with self._lock:
68
+ conn = self._databases.get(attach_opaque_data)
69
+ if conn is None:
70
+ msg = f"No registered database: {attach_opaque_data.hex()}"
71
+ raise ValueError(msg)
72
+ return conn
73
+
74
+ def _get_tx_conn(self, attach_opaque_data: bytes, tx_id: bytes) -> duckdb.DuckDBPyConnection:
75
+ """Get the cursor for a transaction within a database."""
76
+ with self._lock:
77
+ db_txns = self._transactions.get(attach_opaque_data, {})
78
+ conn = db_txns.get(tx_id)
79
+ if conn is None:
80
+ msg = f"No active transaction: {tx_id.hex()} in db {attach_opaque_data.hex()}"
81
+ raise ValueError(msg)
82
+ return conn
83
+
84
+ def _get_tx_lock(self, attach_opaque_data: bytes, tx_id: bytes) -> threading.Lock:
85
+ """Get the per-transaction lock."""
86
+ with self._lock:
87
+ db_locks = self._tx_locks.setdefault(attach_opaque_data, {})
88
+ if tx_id not in db_locks:
89
+ db_locks[tx_id] = threading.Lock()
90
+ return db_locks[tx_id]
91
+
92
def _table_schema(self, qualified_name: str, attach_opaque_data: bytes, tx_id: bytes) -> pa.Schema:
    """Return the Arrow schema of ``qualified_name`` as seen by the transaction.

    Runs ``SELECT * ... LIMIT 0`` on a subcursor of the transaction cursor
    while holding the per-transaction lock.

    Raises:
        ValueError: If the transaction is unknown (from ``_get_tx_conn``).
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        sub = subcursor(conn)
        try:
            probe = sub.execute(f"SELECT * FROM {qualified_name} LIMIT 0")  # noqa: S608
            # Named ``table_schema`` so the imported ``schema`` helper is not shadowed.
            table_schema: pa.Schema = probe.to_arrow_table().schema
        finally:
            # Close the subcursor even when the probe fails (e.g. missing table).
            sub.close()
    return table_schema
103
+
104
+ # ========== Database lifecycle ==========
105
+
106
def register(
    self, attach_opaque_data: bytes, catalog_name: str = "", ddl_statements: list[str] | None = None
) -> None:
    """Open the database file for ``attach_opaque_data`` and run its initial DDL.

    The connection, catalog name and an initial catalog version of 1 are
    recorded under the attachment key before the DDL statements run.
    """
    key_hex = attach_opaque_data.hex()
    db_path = os.path.join(self._db_dir, f"{key_hex}.duckdb")
    conn = engine_connect(db_path)
    with self._lock:
        self._databases[attach_opaque_data] = conn
        self._catalog_names[attach_opaque_data] = catalog_name
        self._catalog_versions[attach_opaque_data] = 1
    # NOTE(review): the diff rendering lost indentation; the DDL loop is assumed
    # to run after the registry lock is released — confirm against the original.
    for statement in ddl_statements or ():
        conn.execute(statement)
    logger.info("Database registered: %s (catalog=%s) -> %s", key_hex[:8], catalog_name, db_path)
120
+
121
def catalog_version(self, attach_opaque_data: bytes) -> int:
    """Return the current catalog version of a database (1 when none is recorded)."""
    with self._lock:
        version = self._catalog_versions.get(attach_opaque_data, 1)
    return version
125
+
126
+ # ========== Transaction lifecycle ==========
127
+
128
def begin(self, attach_opaque_data: bytes) -> bytes:
    """Open a transaction on a registered database and return its id.

    A fresh cursor is created from the database's main connection, switched
    to ``enable_suspended_queries`` and put into a transaction. The cursor
    and a new per-transaction lock are then recorded under a random 16-byte
    id.

    Raises:
        ValueError: If the database is not registered.
    """
    db_conn = self._get_db_conn(attach_opaque_data)
    tx_id = uuid.uuid4().bytes
    cursor = db_conn.cursor()
    try:
        cursor.execute("SET enable_suspended_queries = true")
        cursor.begin()
    except Exception:
        # The cursor is not tracked yet, so nothing else would ever close it.
        cursor.close()
        raise
    with self._lock:
        self._transactions.setdefault(attach_opaque_data, {})[tx_id] = cursor
        self._tx_locks.setdefault(attach_opaque_data, {})[tx_id] = threading.Lock()
    logger.info("Transaction begun: %s (db %s)", tx_id.hex()[:8], attach_opaque_data.hex()[:8])
    return tx_id
140
+
141
def commit(self, attach_opaque_data: bytes, tx_id: bytes) -> None:
    """Commit a transaction, close its cursor and drop its bookkeeping.

    The cursor is used under the per-transaction lock, like every other
    operation on it, so a commit cannot interleave with an in-flight
    exchange or scan on the same cursor.

    If the commit itself raises, the transaction stays registered so the
    caller can still roll it back.

    Raises:
        ValueError: If the transaction is unknown.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        conn.commit()
        conn.close()
    with self._lock:
        self._transactions.get(attach_opaque_data, {}).pop(tx_id, None)
        self._tx_locks.get(attach_opaque_data, {}).pop(tx_id, None)
    logger.info("Transaction committed: %s", tx_id.hex()[:8])
150
+
151
def rollback(self, attach_opaque_data: bytes, tx_id: bytes) -> None:
    """Roll back a transaction, close its cursor and drop its bookkeeping.

    The cursor is used under the per-transaction lock, like every other
    operation on it. Rollback is the terminal cleanup path, so the cursor is
    closed and the registry entries are removed even when the rollback call
    raises; the error still propagates to the caller.

    Raises:
        ValueError: If the transaction is unknown.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    try:
        with tx_lock:
            try:
                conn.rollback()
            finally:
                conn.close()
    finally:
        with self._lock:
            self._transactions.get(attach_opaque_data, {}).pop(tx_id, None)
            self._tx_locks.get(attach_opaque_data, {}).pop(tx_id, None)
    logger.info("Transaction rolled back: %s", tx_id.hex()[:8])
160
+
161
+ # ========== Write operations (streaming exchange) ==========
162
+
163
def insert(
    self,
    attach_opaque_data: bytes,
    tx_id: bytes,
    table_name: str,
    schema_name: str = "",
    returning: bool = False,
) -> Stream[StreamState]:
    """Build the exchange stream that inserts row batches into a table.

    The input schema is the table schema without any ``rowid`` field. The
    output schema is that same schema when ``returning`` is set, otherwise
    the single-column count schema.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    if schema_name:
        qualified = f"{schema_name}.{table_name}"
    else:
        qualified = table_name
    full_schema = self._table_schema(qualified, attach_opaque_data, tx_id)

    row_schema = pa.schema([field for field in full_schema if field.name != "rowid"])

    return Stream(
        output_schema=row_schema if returning else _COUNT_SCHEMA,
        state=_InsertState(
            conn=subcursor(conn),
            qualified_name=qualified,
            returning=returning,
            table_schema=row_schema,
            tx_lock=tx_lock,
        ),
        input_schema=row_schema,
    )
190
+
191
def delete(
    self, attach_opaque_data: bytes, tx_id: bytes, table_name: str, schema_name: str = "", returning: bool = False
) -> Stream[StreamState]:
    """Build the exchange stream that deletes rows identified by ``rowid``.

    Input batches carry a single int64 ``rowid`` column. The output is the
    deleted rows (table schema without ``rowid``) when ``returning`` is set,
    otherwise the single-column count schema.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    if schema_name:
        qualified = f"{schema_name}.{table_name}"
    else:
        qualified = table_name
    full_schema = self._table_schema(qualified, attach_opaque_data, tx_id)

    rowid_schema = schema(rowid=pa.int64())
    row_schema = pa.schema([field for field in full_schema if field.name != "rowid"])

    return Stream(
        output_schema=row_schema if returning else _COUNT_SCHEMA,
        state=_DeleteState(
            conn=subcursor(conn),
            qualified_name=qualified,
            returning=returning,
            table_schema=row_schema,
            tx_lock=tx_lock,
        ),
        input_schema=rowid_schema,
    )
210
+
211
def update(
    self,
    attach_opaque_data: bytes,
    tx_id: bytes,
    table_name: str,
    schema_name: str = "",
    columns: list[str] | None = None,
    returning: bool = False,
) -> Stream[StreamState]:
    """Build the exchange stream that updates rows identified by ``rowid``.

    Args:
        attach_opaque_data: Key of the registered database.
        tx_id: Transaction the update runs in.
        table_name: Target table.
        schema_name: Optional schema qualifier.
        columns: Columns to update. Names not present in the table are
            skipped. When empty or ``None``, every table column is accepted.
        returning: Emit the updated rows instead of a count.

    The input schema always contains an int64 ``rowid`` field, because the
    update exchange joins incoming batches to the table on ``rowid``. The
    schema probe uses ``SELECT *``, whose result is not relied on to carry
    ``rowid``, so the field is appended whenever it is absent — including
    when no column list is given.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    qualified = f"{schema_name}.{table_name}" if schema_name else table_name
    table_schema = self._table_schema(qualified, attach_opaque_data, tx_id)

    if columns:
        fields = [table_schema.field(c) for c in columns if table_schema.get_field_index(c) >= 0]
    else:
        fields = list(table_schema)
    # Add the join key exactly once.
    if not any(f.name == "rowid" for f in fields):
        fields.append(pa.field("rowid", pa.int64()))
    input_schema = pa.schema(fields)

    ret_schema = pa.schema([f for f in table_schema if f.name != "rowid"])
    output_schema = ret_schema if returning else _COUNT_SCHEMA

    sub = subcursor(conn)
    state = _UpdateState(
        conn=sub, qualified_name=qualified, returning=returning, table_schema=ret_schema, tx_lock=tx_lock
    )
    return Stream(output_schema=output_schema, state=state, input_schema=input_schema)
242
+
243
+ # ========== Read (streaming producer) ==========
244
+
245
def scan(
    self,
    attach_opaque_data: bytes,
    tx_id: bytes,
    table_name: str,
    columns: list[str],
    schema_name: str = "",
    pushdown_filters: bytes | None = None,
) -> Stream[StreamState]:
    """Build the producer stream that scans a table inside the transaction.

    Args:
        attach_opaque_data: Key of the registered database.
        tx_id: Transaction the scan runs in.
        table_name: Table to read.
        columns: Columns to project; an empty list selects ``*``.
        schema_name: Optional schema qualifier.
        pushdown_filters: Optional Arrow IPC stream. Only its first record
            batch is read and deserialized into filters, which become a
            parameterized ``WHERE`` clause.

    The output schema is probed with a ``LIMIT 0`` query on a short-lived
    subcursor. The scan itself runs on a second subcursor whose result is
    handed to the returned stream as an Arrow reader (50,000 rows per batch).
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    qualified = f"{schema_name}.{table_name}" if schema_name else table_name
    col_list = ", ".join(columns) if columns else "*"

    sql = f"SELECT {col_list} FROM {qualified}"  # noqa: S608
    bind_params: list[object] = []
    if pushdown_filters is not None:
        from vgi.table_filter_pushdown import deserialize_filters

        pf_reader = pa.ipc.open_stream(pushdown_filters)
        pf_batch = pf_reader.read_next_batch()
        pf = deserialize_filters(pf_batch)
        if pf and pf.filters:
            where_clause, bind_params = pf.to_sql()
            sql += f" WHERE {where_clause}"

    with tx_lock:
        schema_sub = subcursor(conn)
        try:
            schema_sql = f"SELECT {col_list} FROM {qualified} LIMIT 0"  # noqa: S608
            output_schema = schema_sub.execute(schema_sql).to_arrow_table().schema
        finally:
            # Release the probe cursor even if the table or a column is missing.
            schema_sub.close()

        scan_cursor = subcursor(conn)
        try:
            result = scan_cursor.execute(sql, bind_params) if bind_params else scan_cursor.execute(sql)
            reader = result.to_arrow_reader(batch_size=50_000)
        except Exception:
            # No reader took ownership of the cursor, so close it here.
            scan_cursor.close()
            raise

    state = _ScanState(reader=reader, tx_lock=tx_lock)
    return Stream(output_schema=output_schema, state=state)
284
+
285
+ # ========== DDL ==========
286
+
287
def execute_ddl(self, attach_opaque_data: bytes, sql: str) -> None:
    """Run one DDL statement on the database's main connection.

    The statement runs outside any tracked transaction and under the registry
    lock. The catalog version is bumped once it succeeds.
    """
    conn = self._get_db_conn(attach_opaque_data)
    with self._lock:
        conn.execute(sql)
        bumped = self._catalog_versions.get(attach_opaque_data, 1) + 1
        self._catalog_versions[attach_opaque_data] = bumped
    logger.debug("DDL executed: %s", sql[:100])
294
+
295
def execute_ddl_tx(
    self, attach_opaque_data: bytes, tx_id: bytes, sql: str, strip_catalog: str | None = None
) -> None:
    """Run a DDL statement on a transaction's cursor.

    References to the attachment's catalog are stripped from ``sql`` first.
    The catalog name is ``strip_catalog`` when given, otherwise the name
    recorded by ``register()``. An empty name disables stripping.
    """
    if strip_catalog is not None:
        catalog_name = strip_catalog
    else:
        with self._lock:
            catalog_name = self._catalog_names.get(attach_opaque_data, "")
    if catalog_name:
        sql = self.strip_catalog_refs(sql, catalog_name)

    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        conn.execute(sql)
    # NOTE(review): indentation was lost in the diff rendering; the version bump
    # is assumed to follow the per-transaction lock rather than nest inside it.
    with self._lock:
        bumped = self._catalog_versions.get(attach_opaque_data, 1) + 1
        self._catalog_versions[attach_opaque_data] = bumped
    logger.debug("DDL (tx) executed: %s", sql[:100])
316
+
317
def strip_catalog_refs(self, sql: str, catalog_name: str) -> str:
    """Strip external catalog references from SQL using AST transformation.

    The statement is parsed with sqlglot's DuckDB dialect. Every table
    reference whose catalog equals ``catalog_name`` (case-insensitive) has
    that catalog qualifier removed, and the statement is rendered back to
    DuckDB SQL. If parsing fails, a warning is logged and ``sql`` is returned
    unchanged.
    """
    # Imported lazily: sqlglot is only needed on DDL paths.
    import sqlglot
    from sqlglot import exp

    try:
        parsed = sqlglot.parse_one(sql, dialect="duckdb")
    except sqlglot.errors.ParseError:
        logger.warning("strip_catalog_refs: failed to parse SQL, returning as-is: %s", sql[:100])
        return sql

    for table in parsed.find_all(exp.Table):
        if table.catalog and table.catalog.lower() == catalog_name.lower():
            table.set("catalog", None)
            # NOTE(review): the diff rendering lost indentation. This "main"
            # schema check is assumed to apply only to tables whose catalog
            # was just stripped — confirm against the original file.
            if table.args.get("db") and table.args["db"].name.lower() == "main":
                table.set("db", None)

    return parsed.sql(dialect="duckdb")
335
+
336
+ # ========== Metadata ==========
337
+
338
+ def _query_list(
339
+ self, attach_opaque_data: bytes, tx_id: bytes, sql: str, params: list[object] | None = None
340
+ ) -> list[str]:
341
+ """Execute a query within a transaction and return the first column as a list."""
342
+ conn = self._get_tx_conn(attach_opaque_data, tx_id)
343
+ tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
344
+ with tx_lock:
345
+ result = conn.execute(sql, params or [])
346
+ return [row[0] for row in result.fetchall()]
347
+
348
def list_schemas(self, attach_opaque_data: bytes, tx_id: bytes) -> list[str]:
    """Return the names of all non-internal schemas visible to the transaction."""
    query = "SELECT schema_name FROM duckdb_schemas() WHERE NOT internal"
    return self._query_list(attach_opaque_data, tx_id, query)
353
+
354
def list_user_tables(self, attach_opaque_data: bytes, tx_id: bytes, schema_name: str = "main") -> list[str]:
    """Return the base-table names of ``schema_name`` visible to the transaction."""
    query = "SELECT table_name FROM information_schema.tables WHERE table_schema=? AND table_type='BASE TABLE'"
    return self._query_list(attach_opaque_data, tx_id, query, [schema_name])
362
+
363
def table_schema(self, attach_opaque_data: bytes, table_name: str, tx_id: bytes) -> bytes:
    """Return a table's Arrow schema, with metadata, as serialized IPC bytes.

    Args:
        attach_opaque_data: Key of the registered database.
        table_name: Table name, optionally qualified as ``schema.table``.
        tx_id: Transaction whose view of the catalog is used.

    Returns:
        The serialized schema. It starts with an int64 ``rowid`` field tagged
        with ``is_row_id`` metadata, followed by the table's columns. Column
        comments and defaults are attached as ``comment`` / ``default`` field
        metadata, and the table's constraints are stored as JSON under the
        ``vgi.constraints`` schema metadata key.

    Raises:
        ValueError: If the transaction is unknown or the name is not a table.

    When ``table_name`` is schema-qualified, the table check and the column
    and constraint lookups are all restricted to that schema, so a same-named
    table in another schema cannot contribute metadata.
    """
    import json

    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)

    # "s.t" -> ("s", "t"); "t" -> ("", "t").
    schema_part, _, bare_name = table_name.rpartition(".")
    if schema_part:
        where = "schema_name = ? AND table_name = ?"
        where_params: list[object] = [schema_part, bare_name]
    else:
        where = "table_name = ?"
        where_params = [bare_name]

    with tx_lock:
        sub = subcursor(conn)
        try:
            # ``where`` is one of the two constant fragments above, never user text.
            found = sub.execute(
                f"SELECT COUNT(*) FROM duckdb_tables() WHERE {where}",  # noqa: S608
                where_params,
            ).fetchone()
            if found is None or found[0] <= 0:
                raise ValueError(f"'{table_name}' is not a table")
            arrow_schema = sub.execute(f"SELECT * FROM {table_name} LIMIT 0").to_arrow_table().schema  # noqa: S608
            column_rows = sub.execute(
                f"SELECT column_name, comment, column_default FROM duckdb_columns() WHERE {where}",  # noqa: S608
                where_params,
            ).fetchall()
            constraint_rows = sub.execute(
                "SELECT constraint_type, constraint_column_names, constraint_text, "  # noqa: S608
                "referenced_table, referenced_column_names "
                f"FROM duckdb_constraints() WHERE {where}",
                where_params,
            ).fetchall()
        finally:
            # Always release the subcursor, including on the "not a table" path.
            sub.close()

    # Per-column metadata additions, keyed by column name.
    column_meta: dict[str, dict[bytes | str, bytes | str]] = {}
    for col_name, comment, default in column_rows:
        extra: dict[bytes | str, bytes | str] = {}
        if comment is not None:
            extra[b"comment"] = comment.encode("utf-8")
        if default is not None:
            extra[b"default"] = default.encode("utf-8")
        if extra:
            column_meta[col_name] = extra

    fields = []
    for field in arrow_schema:
        additions = column_meta.get(field.name)
        if additions:
            merged: dict[bytes | str, bytes | str] = dict(field.metadata) if field.metadata else {}
            merged.update(additions)
            field = field.with_metadata(merged)
        fields.append(field)

    constraints_json = json.dumps(
        [
            {
                "type": row[0],
                "columns": row[1],
                "text": row[2],
                "referenced_table": row[3],
                "referenced_columns": row[4],
            }
            for row in constraint_rows
        ]
    )
    schema_meta: dict[bytes | str, bytes | str] = {b"vgi.constraints": constraints_json.encode("utf-8")}

    rowid_field = pa.field("rowid", pa.int64(), metadata={b"is_row_id": b""})
    result_schema = pa.schema([rowid_field, *fields], metadata=schema_meta)
    return result_schema.serialize().to_pybytes()
443
+
444
def table_comment(self, attach_opaque_data: bytes, table_name: str, tx_id: bytes) -> str | None:
    """Return the comment on a table, or ``None`` when it has none.

    Args:
        attach_opaque_data: Key of the registered database.
        table_name: Table name, optionally qualified as ``schema.table``.
        tx_id: Transaction whose view of the catalog is used.

    When the name is schema-qualified the lookup is restricted to that
    schema, so the comment of a same-named table elsewhere is not returned.
    ``None`` is also returned when no matching table exists or the comment is
    empty.
    """
    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)

    # "s.t" -> ("s", "t"); "t" -> ("", "t").
    schema_part, _, bare_name = table_name.rpartition(".")
    sql = "SELECT comment FROM duckdb_tables() WHERE table_name = ?"
    params: list[object] = [bare_name]
    if schema_part:
        sql += " AND schema_name = ?"
        params.append(schema_part)

    with tx_lock:
        row = conn.execute(sql, params).fetchone()
    if row and row[0]:
        return str(row[0])
    return None
457
+
458
def list_user_views(self, attach_opaque_data: bytes, tx_id: bytes, schema_name: str = "main") -> list[str]:
    """Return the non-internal view names of ``schema_name`` visible to the transaction."""
    query = "SELECT view_name FROM duckdb_views() WHERE schema_name = ? AND NOT internal"
    return self._query_list(attach_opaque_data, tx_id, query, [schema_name])
466
+
467
def view_info(self, attach_opaque_data: bytes, view_name: str, tx_id: bytes) -> str:
    """Return a view's definition and comment as a JSON object string.

    The definition is the view's stored SQL with the leading
    ``CREATE ... VIEW <name> [(<columns>)] AS`` header and any trailing
    semicolon removed. The header match tolerates ``OR REPLACE``,
    ``TEMP``/``TEMPORARY``, ``IF NOT EXISTS``, quoted names and an explicit
    column list. If no header is recognised, the stored SQL is returned as
    the definition.

    Raises:
        ValueError: If the transaction is unknown or the view does not exist.
    """
    import json
    import re

    conn = self._get_tx_conn(attach_opaque_data, tx_id)
    tx_lock = self._get_tx_lock(attach_opaque_data, tx_id)
    with tx_lock:
        sub = subcursor(conn)
        try:
            row = sub.execute(
                "SELECT sql, comment FROM duckdb_views() WHERE view_name = ?",
                [view_name],
            ).fetchone()
        finally:
            # Release the subcursor even when the lookup fails.
            sub.close()
    if row is None:
        raise ValueError(f"View '{view_name}' not found")

    definition = row[0] or ""
    # Non-greedy up to the first standalone AS, which ends the header.
    header = re.match(
        r"CREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP(?:ORARY)?\s+)?VIEW\s+(?:IF\s+NOT\s+EXISTS\s+)?.+?\s+AS\s+",
        definition,
        re.IGNORECASE | re.DOTALL,
    )
    if header:
        definition = definition[header.end() :]
    definition = definition.rstrip().rstrip(";")
    return json.dumps({"definition": definition, "comment": row[1]})
489
+
490
+ # ========== Lifecycle ==========
491
+
492
def ping(self) -> None:
    """Health check: does nothing and returns ``None``."""
    return None
494
+
495
def shutdown(self) -> None:
    """Roll back open transactions, close every database, then exit.

    Failures while rolling back a transaction or closing a database are
    logged and do not stop the remaining cleanup. The method ends by calling
    ``sys.exit(0)``.
    """
    logger.info("Shutdown requested")
    with self._lock:
        # Snapshot (db key, tx id, cursor) triples before touching anything.
        open_txns = [
            (db_key, txn_key, cursor)
            for db_key, per_db in self._transactions.items()
            for txn_key, cursor in per_db.items()
        ]
        for db_key, txn_key, cursor in open_txns:
            try:
                cursor.rollback()
                cursor.close()
                logger.info("Rolled back orphan tx: %s (db %s)", txn_key.hex()[:8], db_key.hex()[:8])
            except Exception:
                logger.exception("Failed to rollback tx: %s", txn_key.hex()[:8])
        self._transactions.clear()
        self._tx_locks.clear()

        for db_key, db_conn in tuple(self._databases.items()):
            try:
                db_conn.close()
            except Exception:
                logger.exception("Failed to close database: %s", db_key.hex()[:8])
        self._databases.clear()
    sys.exit(0)
516
+
517
+
518
+ # ============================================================================
519
+ # Stream state implementations
520
+ # ============================================================================
521
+
522
+
523
def _read_result_batch(result: duckdb.DuckDBPyConnection) -> pa.RecordBatch:
    """Collapse a DML result into exactly one Arrow record batch.

    An empty result yields a zero-row batch with the result's schema.
    Several batches are merged into one.
    """
    table = result.to_arrow_table()
    pieces = table.to_batches()
    if len(pieces) == 1:
        return pieces[0]
    if pieces:
        merged = pa.Table.from_batches(pieces, schema=table.schema).combine_chunks()
        return merged.to_batches()[0]
    empty_columns = {field.name: [] for field in table.schema}
    return pa.record_batch(empty_columns, schema=table.schema)
532
+
533
+
534
+ _batch_counter = 0
535
+ _batch_counter_lock = threading.Lock()
536
+
537
+
538
+ def _unique_batch_name(prefix: str) -> str:
539
+ """Generate a unique registered batch name to avoid collisions across transactions."""
540
+ global _batch_counter # noqa: PLW0603
541
+ with _batch_counter_lock:
542
+ _batch_counter += 1
543
+ return f"__{prefix}_{_batch_counter}__"
544
+
545
+
546
class _InsertState(ExchangeState):
    """Insert exchange: receives row batches, inserts into table, returns count or RETURNING rows."""

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        qualified_name: str,
        returning: bool,
        table_schema: pa.Schema,
        tx_lock: threading.Lock,
    ) -> None:
        """Store the cursor, target table, RETURNING flag, row schema and transaction lock."""
        self.conn = conn
        self.qualified_name = qualified_name
        self.returning = returning
        self.table_schema = table_schema
        self.tx_lock = tx_lock

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Insert one input batch and emit either the inserted rows or a count.

        The batch is registered on the cursor under a unique name, used as the
        source of an ``INSERT ... SELECT``, and unregistered again even when
        the insert fails.
        """
        batch = input.batch
        col_names = ", ".join(batch.schema.names)
        view_name = _unique_batch_name("insert")
        sql = f"INSERT INTO {self.qualified_name} ({col_names}) SELECT * FROM {view_name}"  # noqa: S608
        if self.returning:
            ret_cols = ", ".join(self.table_schema.names)
            sql += f" RETURNING {ret_cols}"
        with self.tx_lock:
            self.conn.register(view_name, batch)
            try:
                result_batch = _read_result_batch(self.conn.execute(sql))
            finally:
                # Never leave the temporary view behind on the shared cursor.
                self.conn.unregister(view_name)
            self._emit_result(result_batch, out)

    def _emit_result(self, result_batch: pa.RecordBatch, out: OutputCollector) -> None:
        """Emit RETURNING rows (or an empty batch), or the affected-row count."""
        if self.returning:
            out.emit(
                result_batch
                if result_batch.num_rows > 0
                else pa.record_batch({c: [] for c in self.table_schema.names}, schema=self.table_schema)
            )
        else:
            # The affected-row count is read from the result's "Count" column.
            count = result_batch.column("Count")[0].as_py() if result_batch.num_rows > 0 else 0
            out.emit(pa.record_batch({"count": [count]}, schema=_COUNT_SCHEMA))
588
+
589
+
590
class _DeleteState(ExchangeState):
    """Delete exchange: receives rowid batches, deletes matching rows."""

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        qualified_name: str,
        returning: bool,
        table_schema: pa.Schema,
        tx_lock: threading.Lock,
    ) -> None:
        """Store the cursor, target table, RETURNING flag, row schema and transaction lock."""
        self.conn = conn
        self.qualified_name = qualified_name
        self.returning = returning
        self.table_schema = table_schema
        self.tx_lock = tx_lock

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Delete the rows whose ``rowid`` appears in the input batch.

        With ``returning`` set, the matching rows are selected first and then
        deleted, and those rows are emitted. Otherwise the delete's row count
        is emitted. The registered batch is unregistered even when a
        statement fails.
        """
        batch = input.batch
        view_name = _unique_batch_name("delete")
        target = self.qualified_name
        delete_sql = (
            f"DELETE FROM {target} "  # noqa: S608
            f"USING {view_name} WHERE {target}.rowid = {view_name}.rowid"
        )
        with self.tx_lock:
            self.conn.register(view_name, batch)
            try:
                if self.returning:
                    ret_cols = ", ".join(f"{target}.{c}" for c in self.table_schema.names)
                    select_sql = (
                        f"SELECT {ret_cols} FROM {target} "  # noqa: S608
                        f"JOIN {view_name} ON {target}.rowid = {view_name}.rowid"
                    )
                    # Capture the rows before they are removed.
                    result_batch = _read_result_batch(self.conn.execute(select_sql))
                    self.conn.execute(delete_sql)
                else:
                    result_batch = _read_result_batch(self.conn.execute(delete_sql))
            finally:
                # Never leave the temporary view behind on the shared cursor.
                self.conn.unregister(view_name)

        self._emit_result(result_batch, out)

    def _emit_result(self, result_batch: pa.RecordBatch, out: OutputCollector) -> None:
        """Emit the deleted rows (or an empty batch), or the affected-row count."""
        if self.returning:
            out.emit(
                result_batch
                if result_batch.num_rows > 0
                else pa.record_batch({c: [] for c in self.table_schema.names}, schema=self.table_schema)
            )
        else:
            # The affected-row count is read from the result's "Count" column.
            count = result_batch.column("Count")[0].as_py() if result_batch.num_rows > 0 else 0
            out.emit(pa.record_batch({"count": [count]}, schema=_COUNT_SCHEMA))
640
+
641
+
642
class _UpdateState(ExchangeState):
    """Update exchange: receives rowid + updated columns, updates matching rows."""

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        qualified_name: str,
        returning: bool,
        table_schema: pa.Schema,
        tx_lock: threading.Lock,
    ) -> None:
        """Store the cursor, target table, RETURNING flag, row schema and transaction lock."""
        self.conn = conn
        self.qualified_name = qualified_name
        self.returning = returning
        self.table_schema = table_schema
        self.tx_lock = tx_lock

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Apply one batch of updates, matching rows on ``rowid``.

        Every input column other than ``rowid`` is assigned from the
        registered batch. The batch is unregistered even when the update
        fails.
        """
        batch = input.batch
        view_name = _unique_batch_name("update")
        update_cols = [name for name in batch.schema.names if name != "rowid"]
        set_clause = ", ".join(f"{col} = {view_name}.{col}" for col in update_cols)
        sql = (
            f"UPDATE {self.qualified_name} SET {set_clause} "  # noqa: S608
            f"FROM {view_name} WHERE {self.qualified_name}.rowid = {view_name}.rowid"
        )
        if self.returning:
            ret_cols = ", ".join(self.table_schema.names)
            sql += f" RETURNING {ret_cols}"
        with self.tx_lock:
            self.conn.register(view_name, batch)
            try:
                result_batch = _read_result_batch(self.conn.execute(sql))
            finally:
                # Never leave the temporary view behind on the shared cursor.
                self.conn.unregister(view_name)

        self._emit_result(result_batch, out)

    def _emit_result(self, result_batch: pa.RecordBatch, out: OutputCollector) -> None:
        """Emit the updated rows (or an empty batch), or the affected-row count."""
        if self.returning:
            out.emit(
                result_batch
                if result_batch.num_rows > 0
                else pa.record_batch({c: [] for c in self.table_schema.names}, schema=self.table_schema)
            )
        else:
            # The affected-row count is read from the result's "Count" column.
            count = result_batch.column("Count")[0].as_py() if result_batch.num_rows > 0 else 0
            out.emit(pa.record_batch({"count": [count]}, schema=_COUNT_SCHEMA))
686
+
687
+
688
class _ScanState(ProducerState):
    """Scan producer: hands out the Arrow batches of a query result one call at a time."""

    def __init__(self, reader: pa.RecordBatchReader, tx_lock: threading.Lock) -> None:
        """Keep the result reader and the lock of the owning transaction."""
        self._reader = reader
        self._tx_lock = tx_lock

    def produce(self, out: OutputCollector, ctx: CallContext) -> None:
        """Emit the next non-empty batch, or finish the stream when the reader is exhausted."""
        with self._tx_lock:
            try:
                batch = self._reader.read_next_batch()
                # Skip zero-row batches; each call emits at most one batch.
                while batch.num_rows == 0:
                    batch = self._reader.read_next_batch()
            except StopIteration:
                out.finish()
            else:
                out.emit(batch)
706
+
707
+
708
def _default_log_path(socket_path: str) -> str:
    """Derive the log file path from the socket path.

    A trailing ``.sock`` is swapped for ``.log``; any other path simply gets
    ``.log`` appended, so the result never equals the socket path itself.
    """
    suffix = ".sock"
    stem = socket_path[: -len(suffix)] if socket_path.endswith(suffix) else socket_path
    return f"{stem}.log"


def main() -> None:
    """Entry point for the vgi-transactor command.

    Parses ``--db-dir``, ``--socket`` and ``--log-file``, configures file
    logging at INFO level, and serves the transactor over the Unix socket.
    """
    # sqlglot is imported lazily inside ``strip_catalog_refs``, but DDL paths
    # require it. Surface a clear install message at startup instead of
    # blowing up mid-transaction.
    try:
        import sqlglot  # noqa: F401
    except ImportError:
        sys.exit("vgi-transactor requires the transactor extra. Install with: pip install 'vgi-python[transactor]'")

    parser = argparse.ArgumentParser(description="VGI db-transactor server")
    parser.add_argument("--db-dir", required=True, help="Directory for DuckDB database files")
    parser.add_argument("--socket", required=True, help="Unix domain socket path to listen on")
    parser.add_argument("--log-file", default=None, help="Log file path (default: derived from socket path)")
    args = parser.parse_args()

    log_path = args.log_file or _default_log_path(args.socket)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        filename=log_path,
    )

    impl = TransactorImpl(args.db_dir)
    server = RpcServer(TransactorProtocol, impl)
    logger.info("Serving on %s", args.socket)
    serve_unix(server, args.socket, threaded=True)


if __name__ == "__main__":
    main()