sqlrooms 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlrooms/cli.py +576 -0
- sqlrooms/web/__init__.py +0 -0
- sqlrooms/web/db_bridge/__init__.py +28 -0
- sqlrooms/web/db_bridge/connectors/__init__.py +9 -0
- sqlrooms/web/db_bridge/connectors/base.py +59 -0
- sqlrooms/web/db_bridge/connectors/postgres.py +118 -0
- sqlrooms/web/db_bridge/connectors/snowflake.py +161 -0
- sqlrooms/web/db_bridge/factory.py +91 -0
- sqlrooms/web/db_bridge/registry.py +113 -0
- sqlrooms/web/db_bridge/types.py +29 -0
- sqlrooms/web/db_bridge/utils.py +45 -0
- sqlrooms/web/launcher.py +1215 -0
- sqlrooms/web/static/assets/AiSlice-x2gVCmwI.js +137 -0
- sqlrooms/web/static/assets/CommandSlice-DPSuuiIV.js +23 -0
- sqlrooms/web/static/assets/DockLayout-DhgcIQET.js +1 -0
- sqlrooms/web/static/assets/GridLayout-CBVgs-6H.css +1 -0
- sqlrooms/web/static/assets/GridLayout-fXJZYHbE.js +253 -0
- sqlrooms/web/static/assets/LayoutRendererContext-BKO2wB-W.js +1 -0
- sqlrooms/web/static/assets/LeafLayout-DPFHUP6B.js +1 -0
- sqlrooms/web/static/assets/LeafLayout-ekhNDEEg.js +1 -0
- sqlrooms/web/static/assets/RenderNodeContext-BdrX8FaE.js +1 -0
- sqlrooms/web/static/assets/RendererSwitcher-DnVbhqg4.js +1 -0
- sqlrooms/web/static/assets/SplitLayout-fPLAPJN-.js +1 -0
- sqlrooms/web/static/assets/TabsLayout-C0N-7wmx.js +1 -0
- sqlrooms/web/static/assets/TabsLayout-T3iApyr5.js +41 -0
- sqlrooms/web/static/assets/chunk-jRWAZmH_.js +1 -0
- sqlrooms/web/static/assets/codicon-ngg6Pgfi.ttf +0 -0
- sqlrooms/web/static/assets/core.esm-DdCldPzV.js +5 -0
- sqlrooms/web/static/assets/css.worker-Wv5dxAWO.js +89 -0
- sqlrooms/web/static/assets/devtools-BNUn8Jb2.js +2 -0
- sqlrooms/web/static/assets/dist-dwKeDPoe.js +1 -0
- sqlrooms/web/static/assets/html.worker-CQP8QQsS.js +502 -0
- sqlrooms/web/static/assets/index-D9UP9D4f.js +316286 -0
- sqlrooms/web/static/assets/index-DioDnqnf.css +1 -0
- sqlrooms/web/static/assets/json.worker-DzV-CpCQ.js +58 -0
- sqlrooms/web/static/assets/loro_wasm_bg-DP4dC0x3.wasm +0 -0
- sqlrooms/web/static/assets/loro_wasm_bg-VQ4j4Qa9.js +9 -0
- sqlrooms/web/static/assets/loro_wasm_bg-oL0xMWtE.js +3630 -0
- sqlrooms/web/static/assets/maplibre-gl-C-a91wbz.js +748 -0
- sqlrooms/web/static/assets/node-sql-parser-ChfKIXD7.js +68 -0
- sqlrooms/web/static/assets/prop-types-DybOnnvg.js +1 -0
- sqlrooms/web/static/assets/react-dom-liMHu8hH.js +1 -0
- sqlrooms/web/static/assets/resizable-DYr7VLR3.js +1 -0
- sqlrooms/web/static/assets/scroll-area-ZmzNHGEm.js +1 -0
- sqlrooms/web/static/assets/tooltip-mgpsA9tW.js +1 -0
- sqlrooms/web/static/assets/ts.worker-Dth06zuC.js +67734 -0
- sqlrooms/web/static/assets/utils-yJ4l7ARz.js +1 -0
- sqlrooms/web/static/assets/webgl-device-CgQl7NRd.js +1 -0
- sqlrooms/web/static/assets/webgl-device-CtgDFnYR.js +13 -0
- sqlrooms/web/static/index.html +32 -0
- sqlrooms/web/static/logo.png +0 -0
- sqlrooms/web/ui.py +37 -0
- sqlrooms-0.1.0.dist-info/METADATA +274 -0
- sqlrooms-0.1.0.dist-info/RECORD +57 -0
- sqlrooms-0.1.0.dist-info/WHEEL +4 -0
- sqlrooms-0.1.0.dist-info/entry_points.txt +2 -0
- sqlrooms-0.1.0.dist-info/licenses/LICENSE +9 -0
sqlrooms/web/launcher.py
ADDED
|
@@ -0,0 +1,1215 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import errno
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import re
|
|
9
|
+
import secrets
|
|
10
|
+
import socket
|
|
11
|
+
import tempfile
|
|
12
|
+
import threading
|
|
13
|
+
import webbrowser
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Dict
|
|
16
|
+
from urllib.parse import urlsplit, urlunsplit
|
|
17
|
+
|
|
18
|
+
import uvicorn
|
|
19
|
+
from fastapi import FastAPI, File, UploadFile
|
|
20
|
+
from fastapi import Request, WebSocket, WebSocketDisconnect
|
|
21
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
22
|
+
from fastapi.responses import (
|
|
23
|
+
FileResponse,
|
|
24
|
+
JSONResponse,
|
|
25
|
+
RedirectResponse,
|
|
26
|
+
Response,
|
|
27
|
+
StreamingResponse,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from sqlrooms.server import db_async
|
|
31
|
+
from sqlrooms.server.cache import QueryCache
|
|
32
|
+
from sqlrooms.server.server import server as duckdb_ws_server
|
|
33
|
+
|
|
34
|
+
from .db_bridge import (
|
|
35
|
+
ENGINE_CONFIG_FIELDS,
|
|
36
|
+
SUPPORTED_ENGINES,
|
|
37
|
+
PostgresConnectorSettings,
|
|
38
|
+
SnowflakeConnectorSettings,
|
|
39
|
+
UnknownBridgeConnectionError,
|
|
40
|
+
build_cli_db_bridge_registry,
|
|
41
|
+
build_ephemeral_connector,
|
|
42
|
+
)
|
|
43
|
+
from .ui import BuiltinUiProvider, DirectoryUiProvider, UiProvider
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
DB_BRIDGE_ID = "sqlrooms-cli-http-bridge"
|
|
47
|
+
UPLOAD_COPY_CHUNK_SIZE = 1024 * 1024
|
|
48
|
+
NO_STORE_HEADERS = {"Cache-Control": "no-store"}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
async def _write_upload_to_path(file: UploadFile, target: Path) -> int:
    """Copy an uploaded file to ``target`` chunk by chunk.

    Reads at most ``UPLOAD_COPY_CHUNK_SIZE`` bytes per ``file.read`` call and
    stops at the first empty chunk. ``target`` is opened in binary write mode,
    so any existing content is replaced.

    Returns:
        The total number of bytes written.
    """
    total = 0
    with open(target, "wb") as sink:
        while True:
            block = await file.read(UPLOAD_COPY_CHUNK_SIZE)
            if not block:
                break
            sink.write(block)
            total += len(block)
    return total
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _normalize_config_string(value: Any) -> str | None:
|
|
61
|
+
if isinstance(value, (int, float)):
|
|
62
|
+
return str(value)
|
|
63
|
+
if not isinstance(value, str):
|
|
64
|
+
return None
|
|
65
|
+
normalized = value.strip()
|
|
66
|
+
return normalized or None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _write_ai_settings_to_toml(config_path: Path, payload: dict[str, Any]) -> None:
    """Write ``[ai]`` settings into a TOML config file.

    Preserves unrelated sections and merges provider entries so existing
    ``api_key_env`` references are not replaced with resolved secret values.

    The ``[ai]`` table itself is rebuilt from *payload* and swapped in as a
    whole: only ``default_provider``, ``default_model``, ``providers``,
    ``custom_models`` and ``model_parameters`` are written, and providers that
    are absent from the payload are not carried over.

    Args:
        config_path: TOML file to update; created (with parent directories)
            when it does not exist yet.
        payload: Mapping read for ``defaultProvider``/``defaultModel`` and for
            a nested ``settings`` mapping. When ``settings`` is missing or not
            a dict, *payload* itself is used as the settings mapping.

    Raises:
        ValueError: If ``providers`` is not a dict, ``customModels`` is not a
            list, or ``modelParameters.maxSteps`` is present but not an
            integer (booleans are rejected).
    """
    import tomlkit

    if config_path.exists():
        doc = tomlkit.parse(config_path.read_text(encoding="utf-8"))
    else:
        doc = tomlkit.document()
        config_path.parent.mkdir(parents=True, exist_ok=True)

    # Accept either {"settings": {...}} or a flat payload.
    settings = payload.get("settings")
    if not isinstance(settings, dict):
        settings = payload

    providers = settings.get("providers") or {}
    if not isinstance(providers, dict):
        raise ValueError("'settings.providers' must be an object.")

    # The defaults are always read from the top-level payload, not from ``settings``.
    default_provider = _normalize_config_string(payload.get("defaultProvider"))
    default_model = _normalize_config_string(payload.get("defaultModel"))

    # Index the provider entries already in the file by id so that their
    # ``api_key_env`` reference can be consulted below.
    existing_by_id: dict[str, dict[str, Any]] = {}
    existing_ai = doc.get("ai")
    if isinstance(existing_ai, dict):
        for entry in existing_ai.get("providers") or []:
            if isinstance(entry, dict):
                provider_id = _normalize_config_string(entry.get("id"))
                if provider_id:
                    existing_by_id[provider_id] = dict(entry)

    ai_table = tomlkit.table()
    if default_provider:
        ai_table.add("default_provider", default_provider)
    if default_model:
        ai_table.add("default_model", default_model)

    providers_aot = tomlkit.aot()
    for provider_id, provider in providers.items():
        # Entries that are not dicts, or whose id normalizes to nothing, are skipped silently.
        if not isinstance(provider, dict):
            continue
        provider_name = _normalize_config_string(provider_id)
        if not provider_name:
            continue

        existing = existing_by_id.get(provider_name, {})
        item = tomlkit.table()
        item.add("id", provider_name)
        item.add("base_url", _normalize_config_string(provider.get("baseUrl")) or "")

        # Keep the env-var reference when the incoming key is empty or is just
        # the current value of that env var; otherwise write a non-empty key
        # literally. With neither, no key field is written at all.
        api_key = _normalize_config_string(provider.get("apiKey")) or ""
        api_key_env = _normalize_config_string(existing.get("api_key_env"))
        env_value = os.environ.get(api_key_env, "") if api_key_env else ""
        if api_key_env and (not api_key or api_key == env_value):
            item.add("api_key_env", api_key_env)
        elif api_key:
            item.add("api_key", api_key)

        # Only the ``modelName`` of each model dict is persisted, as a single-line array.
        models = tomlkit.array()
        models.multiline(False)
        for model in provider.get("models") or []:
            if not isinstance(model, dict):
                continue
            model_name = _normalize_config_string(model.get("modelName"))
            if model_name:
                models.append(model_name)
        item.add("models", models)
        providers_aot.append(item)
    ai_table.add("providers", providers_aot)

    custom_models_raw = settings.get("customModels") or []
    if not isinstance(custom_models_raw, list):
        raise ValueError("'settings.customModels' must be an array.")
    custom_models_aot = tomlkit.aot()
    for custom_model in custom_models_raw:
        if not isinstance(custom_model, dict):
            continue
        model_name = _normalize_config_string(custom_model.get("modelName"))
        base_url = _normalize_config_string(custom_model.get("baseUrl"))
        # A custom model needs both a name and a base URL to be written.
        if not model_name or not base_url:
            continue
        item = tomlkit.table()
        item.add("model_name", model_name)
        item.add("base_url", base_url)
        api_key = _normalize_config_string(custom_model.get("apiKey"))
        if api_key:
            item.add("api_key", api_key)
        custom_models_aot.append(item)
    # Unlike ``providers``, an empty custom-model list is omitted entirely.
    if custom_models_aot:
        ai_table.add("custom_models", custom_models_aot)

    model_parameters = settings.get("modelParameters") or {}
    if isinstance(model_parameters, dict):
        params_table = tomlkit.table()
        if "maxSteps" in model_parameters:
            max_steps = model_parameters.get("maxSteps")
            # bool is a subclass of int, so it has to be excluded explicitly.
            if not isinstance(max_steps, int) or isinstance(max_steps, bool):
                raise ValueError(
                    "'settings.modelParameters.maxSteps' must be an integer."
                )
            params_table.add("max_steps", max_steps)
        additional_instruction = model_parameters.get("additionalInstruction")
        if isinstance(additional_instruction, str):
            params_table.add("additional_instruction", additional_instruction)
        if params_table:
            ai_table.add("model_parameters", params_table)

    # Replace the whole [ai] table with the freshly built one.
    if "ai" in doc:
        del doc["ai"]
    doc.add("ai", ai_table)

    raw = tomlkit.dumps(doc)
    # Collapse runs of three or more newlines down to a single blank line.
    raw = re.sub(r"\n{3,}", "\n\n", raw)
    config_path.write_text(raw, encoding="utf-8")
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _write_db_connectors_to_toml(
    config_path: Path, connections: list[dict[str, Any]]
) -> None:
    """Write ``[[db.connectors]]`` entries into a TOML config file.

    Merges incoming connection metadata with existing TOML entries so that
    engine-specific fields the frontend doesn't know about (e.g.
    ``account``, ``password``) are preserved. Connections whose ``id`` is no
    longer present in *connections* are removed.

    Preserves all non-``[db]`` sections. If the file does not exist yet it is
    created.

    Args:
        config_path: TOML file to update; parent directories are created when
            the file is new.
        connections: Connection dicts using frontend key names. ``engineId``
            is stored as ``engine``; the keys ``runtimeSupport``,
            ``requiresBridge``, ``bridgeId``, ``isCore`` and ``config`` are
            not written as-is. Entries of a nested ``config`` dict are merged
            into the connector table, and an empty or ``None`` value removes
            the corresponding key.
    """
    import tomlkit

    if config_path.exists():
        doc = tomlkit.parse(config_path.read_text(encoding="utf-8"))
    else:
        doc = tomlkit.document()
        config_path.parent.mkdir(parents=True, exist_ok=True)

    # Snapshot the connectors currently in the file, keyed by id, as merge bases.
    existing_by_id: dict[str, dict[str, Any]] = {}
    if "db" in doc and "connectors" in doc["db"]:
        for entry in doc["db"]["connectors"]:
            eid = entry.get("id")
            if eid:
                existing_by_id[eid] = dict(entry)

    # Frontend key -> TOML key. ``None`` marks keys that must not be written.
    frontend_to_toml = {
        "engineId": "engine",
        "runtimeSupport": None,
        "requiresBridge": None,
        "bridgeId": None,
        "isCore": None,
        "config": None,
    }

    # Config keys declared per engine; used to drop keys that only made sense
    # for the previous engine when a connection switches engines.
    engine_to_keys: dict[str, set[str]] = {
        eng: {f["key"] for f in fields}
        for eng, fields in ENGINE_CONFIG_FIELDS.items()
    }

    connectors_aot = tomlkit.aot()
    for conn in connections:
        conn_id = conn.get("id")
        base = dict(existing_by_id.get(conn_id, {})) if conn_id else {}

        old_engine = base.get("engine")
        new_engine = conn.get("engineId") or old_engine

        for k, v in conn.items():
            if v is None:
                continue
            toml_key = frontend_to_toml.get(k, k)
            if toml_key is None:
                continue
            base[toml_key] = v

        if old_engine and new_engine and old_engine != new_engine:
            stale_keys = engine_to_keys.get(old_engine, set()) - engine_to_keys.get(
                new_engine, set()
            )
            for sk in stale_keys:
                base.pop(sk, None)

        # Applied after the stale-key cleanup so explicitly supplied values win.
        engine_config = conn.get("config")
        if isinstance(engine_config, dict):
            for ck, cv in engine_config.items():
                if cv is not None and cv != "":
                    base[ck] = cv
                else:
                    base.pop(ck, None)

        item = tomlkit.table()
        for k, v in base.items():
            item.add(k, v)
        connectors_aot.append(item)

    if "db" not in doc:
        doc.add("db", tomlkit.table())
    db_section = doc["db"]
    if "connectors" in db_section:
        del db_section["connectors"]  # type: ignore[arg-type]
    db_section["connectors"] = connectors_aot  # type: ignore[index]

    raw = tomlkit.dumps(doc)
    # Collapse runs of blank lines that tomlkit may accumulate on repeated writes
    raw = re.sub(r"\n{3,}", "\n\n", raw)
    config_path.write_text(raw, encoding="utf-8")
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _sanitize_filename(name: str) -> str:
|
|
286
|
+
safe = os.path.basename(name.strip().replace("\\", "/"))
|
|
287
|
+
return safe or "upload.dat"
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _localhost_probe_hosts(host: str) -> tuple[str, ...]:
|
|
291
|
+
return ("127.0.0.1", "::1") if host == "localhost" else (host,)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _can_bind_single_host_port(
|
|
295
|
+
host: str,
|
|
296
|
+
port: int,
|
|
297
|
+
*,
|
|
298
|
+
ignore_unavailable: bool = False,
|
|
299
|
+
) -> bool:
|
|
300
|
+
is_ipv6 = ":" in host and host != "0.0.0.0"
|
|
301
|
+
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
|
|
302
|
+
sock = socket.socket(family, socket.SOCK_STREAM)
|
|
303
|
+
try:
|
|
304
|
+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
305
|
+
if family == socket.AF_INET6:
|
|
306
|
+
sock.bind((host, port, 0, 0))
|
|
307
|
+
else:
|
|
308
|
+
sock.bind((host, port))
|
|
309
|
+
return True
|
|
310
|
+
except OSError as exc:
|
|
311
|
+
if exc.errno in {errno.EADDRINUSE, errno.EACCES}:
|
|
312
|
+
return False
|
|
313
|
+
if ignore_unavailable and exc.errno in {
|
|
314
|
+
errno.EADDRNOTAVAIL,
|
|
315
|
+
errno.EAFNOSUPPORT,
|
|
316
|
+
}:
|
|
317
|
+
return True
|
|
318
|
+
raise
|
|
319
|
+
finally:
|
|
320
|
+
sock.close()
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _can_bind_port(host: str, port: int) -> bool:
    """Return whether ``port`` can be bound on every address implied by ``host``.

    The addresses come from ``_localhost_probe_hosts``. For ``"localhost"``
    an address that is unavailable on this machine does not count as a
    failure; for any other host such errors propagate.
    """
    tolerate_missing = host == "localhost"
    for candidate in _localhost_probe_hosts(host):
        bindable = _can_bind_single_host_port(
            candidate,
            port,
            ignore_unavailable=tolerate_missing,
        )
        if not bindable:
            return False
    return True
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def _pick_free_port(
|
|
335
|
+
host: str,
|
|
336
|
+
start_port: int | None = None,
|
|
337
|
+
*,
|
|
338
|
+
reserved_ports: set[int] | None = None,
|
|
339
|
+
) -> int:
|
|
340
|
+
"""
|
|
341
|
+
Pick an available TCP port for a local background server.
|
|
342
|
+
|
|
343
|
+
If ``start_port`` is provided, scan upward from that port. Otherwise bind to
|
|
344
|
+
port 0 and let the OS select a free port.
|
|
345
|
+
"""
|
|
346
|
+
is_ipv6 = ":" in host and host != "0.0.0.0"
|
|
347
|
+
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
|
|
348
|
+
reserved = reserved_ports or set()
|
|
349
|
+
if start_port is not None:
|
|
350
|
+
for port in range(start_port, 65536):
|
|
351
|
+
if port in reserved:
|
|
352
|
+
continue
|
|
353
|
+
if _can_bind_port(host, port):
|
|
354
|
+
return port
|
|
355
|
+
raise RuntimeError(f"No available port found starting from {start_port}.")
|
|
356
|
+
|
|
357
|
+
sock = socket.socket(family, socket.SOCK_STREAM)
|
|
358
|
+
try:
|
|
359
|
+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
360
|
+
if family == socket.AF_INET6:
|
|
361
|
+
sock.bind((host, 0, 0, 0))
|
|
362
|
+
else:
|
|
363
|
+
sock.bind((host, 0))
|
|
364
|
+
return int(sock.getsockname()[1])
|
|
365
|
+
finally:
|
|
366
|
+
sock.close()
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def _normalize_sql_for_policy(sql: str) -> str:
|
|
370
|
+
normalized = sql.strip()
|
|
371
|
+
if normalized.endswith(";"):
|
|
372
|
+
normalized = normalized[:-1].strip()
|
|
373
|
+
return normalized
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def _redact_sql_literals(sql: str) -> str:
|
|
377
|
+
# Redact quoted strings and numeric literals before logging user SQL.
|
|
378
|
+
redacted = re.sub(r"'(?:''|[^'])*'", "'***'", sql)
|
|
379
|
+
redacted = re.sub(r'"(?:""|[^"])*"', '"***"', redacted)
|
|
380
|
+
redacted = re.sub(r"\b\d+(?:\.\d+)?\b", "?", redacted)
|
|
381
|
+
return redacted
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _is_select_only_sql(sql: str) -> bool:
    """Return whether ``sql`` looks like a single SELECT/WITH statement.

    The statement is normalized first (trimmed, one trailing ``;`` removed).
    It is rejected when empty or when any ``;`` remains, including one inside
    a string literal. Otherwise it is accepted only if it starts with
    ``select`` or ``with`` followed by whitespace, case-insensitively.
    """
    statement = _normalize_sql_for_policy(sql)
    if not statement or ";" in statement:
        return False
    # NOTE(review): only the leading keyword is inspected; what follows a
    # WITH clause is not checked here.
    return re.match(r"^(select|with)\s", statement, re.IGNORECASE) is not None
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def _references_internal_namespace(sql: str, namespace: str) -> bool:
|
|
395
|
+
escaped = re.escape(namespace)
|
|
396
|
+
pattern = re.compile(
|
|
397
|
+
rf"(^|[^A-Za-z0-9_])(?:{escaped}|\"{escaped}\"|`{escaped}`|\[{escaped}\])\s*\.",
|
|
398
|
+
re.IGNORECASE,
|
|
399
|
+
)
|
|
400
|
+
return bool(pattern.search(sql))
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def _encode_stream_frame(
|
|
404
|
+
frame_type: str,
|
|
405
|
+
*,
|
|
406
|
+
query_id: str,
|
|
407
|
+
payload: bytes = b"",
|
|
408
|
+
error: str | None = None,
|
|
409
|
+
) -> bytes:
|
|
410
|
+
header = {
|
|
411
|
+
"type": frame_type,
|
|
412
|
+
"queryId": query_id,
|
|
413
|
+
"payloadLength": len(payload),
|
|
414
|
+
}
|
|
415
|
+
if error:
|
|
416
|
+
header["error"] = error
|
|
417
|
+
header_bytes = json.dumps(header).encode("utf-8")
|
|
418
|
+
return len(header_bytes).to_bytes(4, byteorder="big") + header_bytes + payload
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def _derive_ws_proxy_url(external_url: str) -> str:
|
|
422
|
+
parsed = urlsplit(external_url.rstrip("/"))
|
|
423
|
+
scheme = {"http": "ws", "https": "wss"}.get(parsed.scheme, parsed.scheme)
|
|
424
|
+
base_path = parsed.path.rstrip("/")
|
|
425
|
+
return urlunsplit((scheme, parsed.netloc, f"{base_path}/ws/duckdb", "", ""))
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class SqlroomsHttpServer:
|
|
429
|
+
def __init__(
    self,
    db_path: str | Path,
    host: str,
    port: int,
    ws_port: int | None,
    *,
    sync_enabled: bool = False,
    meta_db: str | None = None,
    meta_namespace: str = "__sqlrooms",
    llm_provider: str | None = None,
    llm_model: str | None = None,
    api_key: str | None = None,
    ai_providers: dict[str, dict[str, Any]] | None = None,
    ai_custom_models: list[dict[str, Any]] | None = None,
    ai_model_parameters: dict[str, Any] | None = None,
    connector_settings: list[PostgresConnectorSettings | SnowflakeConnectorSettings]
    | None = None,
    open_browser: bool = True,
    ui_dir: str | None = None,
    serve_ui: bool = True,
    experimental_enabled: bool = False,
    config_path: Path | None = None,
    external_url: str | None = None,
    external_ws_url: str | None = None,
    ai_devtools: bool = False,
) -> None:
    """Store the server configuration and prepare per-instance state.

    Construction has side effects: the upload directory is created, a random
    session token is generated, the DB bridge registry is built, and a free
    websocket port is picked when ``ws_port`` is ``None``.

    Args:
        db_path: DuckDB database file, or ``":memory:"`` for an in-memory
            database.
        host: Host for the HTTP server.
        port: Port for the HTTP server.
        ws_port: DuckDB websocket port; ``None`` selects a free port.
        ui_dir: When set, the UI is served from this directory instead of
            the built-in bundle.
        external_url: Optional externally visible base URL; trailing slashes
            are removed.
        external_ws_url: Optional externally visible websocket URL.

    The remaining keyword arguments are stored on the instance, with ``None``
    collections replaced by empty ones and the flags coerced to ``bool``.
    """
    db_path_str = str(db_path)
    self.is_in_memory = db_path_str == ":memory:"
    if self.is_in_memory:
        # No database file exists, so uploads go under the system temp directory.
        self.db_path: Path | None = None
        self.duckdb_database = ":memory:"
        base_dir = Path(tempfile.gettempdir()) / "sqlrooms"
    else:
        # Uploads are kept next to the database file.
        self.db_path = Path(db_path).expanduser().resolve()
        self.duckdb_database = str(self.db_path)
        base_dir = self.db_path.parent

    # ``host`` must be assigned before ``_public_host()`` is used below.
    self.host = host
    self.port = port
    if ws_port is None:
        # socketify listens on all interfaces; we pick a free local port for convenience
        # to avoid collisions when multiple dev servers are running.
        self.ws_port = _pick_free_port(self._public_host())
    else:
        self.ws_port = ws_port
    self.llm_provider = llm_provider
    self.llm_model = llm_model
    self.api_key = api_key
    self.ai_providers = ai_providers or {}
    self.ai_custom_models = ai_custom_models or []
    self.ai_model_parameters = ai_model_parameters or {}
    self.open_browser = open_browser
    self.serve_ui = serve_ui
    self.experimental_enabled = bool(experimental_enabled)
    self.ai_devtools = bool(ai_devtools)
    self.sync_enabled = bool(sync_enabled)
    self.meta_db = meta_db
    self.meta_namespace = meta_namespace
    # Fresh random token for every server instance.
    self.session_token = secrets.token_urlsafe(24)
    self.db_bridge_registry = build_cli_db_bridge_registry(
        bridge_id=DB_BRIDGE_ID,
        connector_settings=connector_settings,
    )
    self.config_path = config_path
    self.connector_settings = connector_settings or []
    self.external_url = external_url.rstrip("/") if external_url else None
    self.external_ws_url = external_ws_url if external_ws_url else None

    self.ui_provider: UiProvider = (
        DirectoryUiProvider(ui_dir) if ui_dir else BuiltinUiProvider()
    )
    self.static_dir = self.ui_provider.static_dir()
    self.index_html = self.ui_provider.index_html()
    self.upload_dir = base_dir / "sqlrooms_uploads"
    self.upload_dir.mkdir(parents=True, exist_ok=True)
    # State of the background DuckDB websocket server thread.
    self._duckdb_thread: threading.Thread | None = None
    self._duckdb_ready = threading.Event()
    self._duckdb_start_error: BaseException | None = None
|
|
508
|
+
|
|
509
|
+
async def start(self) -> None:
    """Start the DuckDB backend and serve the HTTP app.

    Verifies the UI bundle, logs the active configuration, starts the
    DuckDB websocket backend, optionally schedules opening a browser tab
    after one second, and then awaits the uvicorn server's ``serve()``.
    """
    logger.info("Starting sqlrooms CLI server")
    self._assert_ui_available()

    if not self.meta_db:
        logger.info(
            "Meta DB is DISABLED (using schema=%s within main DB)",
            self.meta_namespace,
        )
    else:
        logger.info(
            "Meta DB is ENABLED (db=%s, namespace=%s)",
            self.meta_db,
            self.meta_namespace,
        )
    if self.sync_enabled:
        logger.info("CRDT sync is ENABLED")

    self._start_duckdb_backend()
    asgi_app = self._build_app()

    # A browser tab is only useful when this server also serves the UI.
    should_launch_browser = self.open_browser and self.serve_ui
    if should_launch_browser:
        threading.Timer(1.0, self._open_browser).start()

    logger.info("SQLRooms UI URL: %s", self._ui_url())
    logger.info("DuckDB websocket URL: %s", self._ws_url())

    uvicorn_server = uvicorn.Server(
        uvicorn.Config(
            asgi_app, host=self.host, port=self.port, log_level="info", loop="asyncio"
        )
    )
    await uvicorn_server.serve()
|
|
539
|
+
|
|
540
|
+
def _open_browser(self) -> None:
    """Open the UI URL in a new browser tab.

    Any exception is logged at debug level and otherwise ignored; a call
    that does not raise is logged at info level.
    """
    target = self._ui_url()
    try:
        webbrowser.open_new_tab(target)
    except Exception as exc:
        logger.debug("Failed to open browser: %s", exc)
        return
    logger.info("Opened browser at %s", target)
|
|
548
|
+
|
|
549
|
+
def _public_host(self) -> str:
|
|
550
|
+
return (
|
|
551
|
+
"localhost"
|
|
552
|
+
if self.host in ("0.0.0.0", "::", "127.0.0.1", "::1")
|
|
553
|
+
else self.host
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
def _ui_host(self) -> str:
|
|
557
|
+
return "localhost" if self.host in ("0.0.0.0", "::") else self.host
|
|
558
|
+
|
|
559
|
+
@staticmethod
|
|
560
|
+
def _host_for_url(host: str) -> str:
|
|
561
|
+
if ":" in host and not host.startswith("["):
|
|
562
|
+
return f"[{host}]"
|
|
563
|
+
return host
|
|
564
|
+
|
|
565
|
+
def _ui_url(self) -> str:
|
|
566
|
+
return f"http://{self._host_for_url(self._ui_host())}:{self.port}"
|
|
567
|
+
|
|
568
|
+
def _ws_url(self) -> str:
|
|
569
|
+
return (
|
|
570
|
+
self.external_ws_url
|
|
571
|
+
or f"ws://{self._host_for_url(self._public_host())}:{self.ws_port}"
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
def _assert_ui_available(self) -> None:
|
|
575
|
+
if not self.serve_ui:
|
|
576
|
+
return
|
|
577
|
+
if self.index_html.exists():
|
|
578
|
+
return
|
|
579
|
+
if isinstance(self.ui_provider, DirectoryUiProvider):
|
|
580
|
+
raise RuntimeError(
|
|
581
|
+
f"SQLRooms UI bundle is missing: {self.index_html}. "
|
|
582
|
+
"Build the UI bundle or pass --no-ui to start only the API server."
|
|
583
|
+
)
|
|
584
|
+
raise RuntimeError(
|
|
585
|
+
f"Bundled SQLRooms UI is missing from this installation: {self.index_html}. "
|
|
586
|
+
"Reinstall sqlrooms or report a packaging issue. "
|
|
587
|
+
"Developers can rebuild it with `pnpm --filter sqlrooms-python build:ui`."
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
def _index_response(self) -> FileResponse:
    """Return the UI entry page as a file response with ``Cache-Control: no-store``."""
    entry_page = self.index_html
    return FileResponse(entry_page, headers=NO_STORE_HEADERS)
|
|
592
|
+
|
|
593
|
+
def _resolve_static_file(self, full_path: str) -> Path | None:
|
|
594
|
+
static_root = self.static_dir.resolve()
|
|
595
|
+
candidate = (static_root / full_path).resolve()
|
|
596
|
+
try:
|
|
597
|
+
candidate.relative_to(static_root)
|
|
598
|
+
except ValueError:
|
|
599
|
+
return None
|
|
600
|
+
if candidate.is_file() and candidate.name != "index.html":
|
|
601
|
+
return candidate
|
|
602
|
+
return None
|
|
603
|
+
|
|
604
|
+
def _stale_entry_asset_redirect(self, full_path: str) -> RedirectResponse | None:
|
|
605
|
+
requested = Path(full_path)
|
|
606
|
+
if (
|
|
607
|
+
len(requested.parts) != 2
|
|
608
|
+
or requested.parts[0] != "assets"
|
|
609
|
+
or not requested.name.startswith("index-")
|
|
610
|
+
or requested.suffix not in {".css", ".js"}
|
|
611
|
+
):
|
|
612
|
+
return None
|
|
613
|
+
|
|
614
|
+
assets_dir = self.static_dir / "assets"
|
|
615
|
+
matches = sorted(assets_dir.glob(f"index-*{requested.suffix}"))
|
|
616
|
+
if len(matches) != 1:
|
|
617
|
+
return None
|
|
618
|
+
|
|
619
|
+
return RedirectResponse(
|
|
620
|
+
url=f"/assets/{matches[0].name}",
|
|
621
|
+
status_code=302,
|
|
622
|
+
headers=NO_STORE_HEADERS,
|
|
623
|
+
)
|
|
624
|
+
|
|
625
|
+
def _start_duckdb_backend(self) -> None:
    """Launch the DuckDB websocket server in a daemon thread.

    Resets the readiness event and the recorded start error, starts the
    thread, and waits up to 10 seconds for readiness. The outcome (failed,
    still starting, started) is only logged; nothing is raised.
    """
    self._duckdb_ready.clear()
    self._duckdb_start_error = None

    worker = threading.Thread(
        target=self._run_duckdb_server,
        daemon=True,
        name="duckdb-ws-server",
    )
    worker.start()
    self._duckdb_thread = worker

    self._duckdb_ready.wait(timeout=10)
    if self._duckdb_start_error is not None:
        logger.error("Failed to start DuckDB websocket backend")
    elif not self._duckdb_ready.is_set():
        logger.warning(
            "DuckDB websocket backend is still starting after 10 seconds"
        )
    else:
        logger.info(
            "Started DuckDB websocket backend at ws://%s:%s",
            self._public_host(),
            self.ws_port,
        )
|
|
649
|
+
|
|
650
|
+
def _format_startup_error(self, exc: BaseException) -> Dict[str, str]:
|
|
651
|
+
details: list[str] = []
|
|
652
|
+
current: BaseException | None = exc
|
|
653
|
+
while current is not None:
|
|
654
|
+
error_type = type(current).__name__
|
|
655
|
+
message = str(current) or error_type
|
|
656
|
+
details.append(f"{error_type}: {message}")
|
|
657
|
+
current = current.__cause__ or current.__context__
|
|
658
|
+
|
|
659
|
+
return {
|
|
660
|
+
"message": str(exc) or type(exc).__name__,
|
|
661
|
+
"details": "\nCaused by: ".join(details),
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
def _runtime_status(self) -> Dict[str, Any]:
|
|
665
|
+
duckdb_status: Dict[str, Any]
|
|
666
|
+
if self._duckdb_start_error is not None:
|
|
667
|
+
error = self._format_startup_error(self._duckdb_start_error)
|
|
668
|
+
duckdb_status = {
|
|
669
|
+
"status": "error",
|
|
670
|
+
"message": "DuckDB websocket backend failed to start",
|
|
671
|
+
"error": error["message"],
|
|
672
|
+
"details": error["details"],
|
|
673
|
+
}
|
|
674
|
+
elif self._duckdb_ready.is_set():
|
|
675
|
+
duckdb_status = {"status": "ready"}
|
|
676
|
+
else:
|
|
677
|
+
duckdb_status = {"status": "starting"}
|
|
678
|
+
|
|
679
|
+
status = "ready" if duckdb_status["status"] == "ready" else "degraded"
|
|
680
|
+
return {
|
|
681
|
+
"status": status,
|
|
682
|
+
"components": {
|
|
683
|
+
"duckdbWebSocket": duckdb_status,
|
|
684
|
+
},
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
def _runtime_config(self) -> Dict[str, Any]:
    """Assemble the runtime configuration dict.

    The websocket URL is chosen in this order: the explicit external
    websocket URL, a URL derived from the external base URL, then the local
    websocket URL. The same value is used for ``wsUrl`` and ``crdtWsUrl``.
    """
    if self.external_ws_url:
        ws_url = self.external_ws_url
    elif self.external_url:
        ws_url = _derive_ws_proxy_url(self.external_url) or self._ws_url()
    else:
        ws_url = self._ws_url()

    room_id = f"sqlrooms:{self.meta_namespace}:{self.duckdb_database or 'memory'}"
    ai_settings = {
        "providers": self.ai_providers,
        "customModels": self.ai_custom_models,
        "modelParameters": self.ai_model_parameters,
    }
    registry = self.db_bridge_registry

    return {
        "wsUrl": ws_url,
        "wsAuthToken": self.session_token,
        "apiBaseUrl": self.external_url or "",
        "llmProvider": self.llm_provider,
        "llmModel": self.llm_model,
        "configWritable": self.config_path is not None,
        "experimentalEnabled": self.experimental_enabled,
        "aiDevtools": self.ai_devtools,
        "syncEnabled": self.sync_enabled,
        "crdtWsUrl": ws_url,
        "crdtRoomId": room_id,
        "aiProviders": self.ai_providers,
        "aiSettings": ai_settings,
        "dbPath": self.duckdb_database,
        "metaNamespace": self.meta_namespace,
        "startupStatus": self._runtime_status(),
        "dbBridge": {
            "id": registry.bridge_id,
            "connections": registry.runtime_connections(),
            "diagnostics": registry.runtime_diagnostics(),
            "supportedEngines": SUPPORTED_ENGINES,
            "engineConfigFields": ENGINE_CONFIG_FIELDS,
        },
    }
|
|
725
|
+
|
|
726
|
+
def _is_authorized_request(self, request: Request) -> bool:
    """Decide whether *request* may call a protected API endpoint.

    Requests whose client host is empty or one of the loopback/test names
    are always allowed. Any other client must present the session token,
    either as ``Authorization: Bearer <token>`` or in the
    ``X-SQLRooms-Token`` header.

    Args:
        request: The incoming HTTP request.

    Returns:
        ``True`` when the request is authorized, ``False`` otherwise.
    """
    import hmac

    client_host = (request.client.host if request.client else "") or ""
    if client_host in {"", "127.0.0.1", "::1", "localhost", "testclient"}:
        return True

    # Without a non-empty configured token nothing can be matched. In
    # particular an absent header (normalised to "") must never compare
    # equal to an empty session token.
    expected = self.session_token
    if not isinstance(expected, str) or not expected:
        return False
    expected_bytes = expected.encode("utf-8")

    def _token_matches(candidate: str) -> bool:
        # Constant-time comparison avoids leaking the token through timing;
        # bytes are used because compare_digest rejects non-ASCII str.
        return hmac.compare_digest(candidate.encode("utf-8"), expected_bytes)

    auth_header = (request.headers.get("authorization") or "").strip()
    if auth_header.lower().startswith("bearer "):
        if _token_matches(auth_header[7:].strip()):
            return True
    token_header = (request.headers.get("x-sqlrooms-token") or "").strip()
    return _token_matches(token_header)
|
|
737
|
+
|
|
738
|
+
def _require_api_auth(self, request: Request):
    """Return a 401 JSON response for unauthorized requests, else ``None``.

    Route handlers call this first and return the result directly when it
    is not ``None``.
    """
    if not self._is_authorized_request(request):
        return JSONResponse({"error": "unauthorized"}, status_code=401)
    return None
|
|
742
|
+
|
|
743
|
+
def _build_app(self) -> FastAPI:
    """Create the FastAPI application with every HTTP and websocket route.

    Registers, in order: CORS and cross-origin-isolation middleware, the
    runtime config/status endpoints, the ``/ws/duckdb`` websocket proxy,
    the settings, upload and DB-bridge endpoints, the read-only project
    query endpoint and, when UI serving is enabled and the static
    directory exists, the SPA index and fallback routes.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(title="sqlrooms", version="0.1.0")
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            f"http://localhost:{self.port}",
            f"http://127.0.0.1:{self.port}",
            f"http://{self._public_host()}:{self.port}",
        ],
        allow_credentials=False,
        allow_methods=["GET", "POST", "PUT", "OPTIONS"],
        allow_headers=["Authorization", "Content-Type", "X-SQLRooms-Token"],
    )

    @app.middleware("http")
    async def add_cross_origin_isolation_headers(request: Request, call_next):
        response = await call_next(request)
        # WebContainer requires cross-origin isolation to transfer SharedArrayBuffer.
        response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
        response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
        response.headers["Cross-Origin-Resource-Policy"] = "cross-origin"
        return response

    @app.get("/api/config")
    async def get_config():
        return self._runtime_config()

    @app.get("/config.json")
    async def get_config_json():
        return self._runtime_config()

    @app.websocket("/ws/duckdb")
    async def duckdb_websocket_proxy(client_ws: WebSocket):
        # NOTE(review): the connection is accepted without any session-token
        # check, unlike the HTTP endpoints guarded by _require_api_auth.
        # Confirm whether the upstream server authenticates clients itself.
        await client_ws.accept()
        try:
            import websockets
        except ImportError:
            await client_ws.close(code=1011, reason="websockets package missing")
            return

        upstream_url = f"ws://127.0.0.1:{self.ws_port}"
        try:
            async with websockets.connect(
                upstream_url,
                max_size=None,
            ) as upstream_ws:

                async def client_to_upstream() -> None:
                    # Forward browser frames until the client disconnects.
                    while True:
                        message = await client_ws.receive()
                        if message["type"] == "websocket.disconnect":
                            await upstream_ws.close()
                            return
                        if message.get("bytes") is not None:
                            await upstream_ws.send(message["bytes"])
                        elif message.get("text") is not None:
                            await upstream_ws.send(message["text"])

                async def upstream_to_client() -> None:
                    # Forward upstream frames, preserving binary vs. text.
                    async for message in upstream_ws:
                        if isinstance(message, bytes):
                            await client_ws.send_bytes(message)
                        else:
                            await client_ws.send_text(message)

                done, pending = await asyncio.wait(
                    {
                        asyncio.create_task(client_to_upstream()),
                        asyncio.create_task(upstream_to_client()),
                    },
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in pending:
                    task.cancel()
                # Wait for the cancelled direction to actually finish so it
                # is not left pending while the upstream socket is closed.
                if pending:
                    await asyncio.gather(*pending, return_exceptions=True)
                # Re-raise any failure from the direction that finished.
                for task in done:
                    task.result()
        except WebSocketDisconnect:
            return
        except Exception:
            logger.exception("DuckDB websocket proxy failed")
            try:
                await client_ws.close(code=1011)
            except Exception:
                # Best effort: the client socket may already be closed.
                pass

    @app.get("/api/status")
    async def get_status():
        return self._runtime_status()

    @app.get("/status.json")
    async def get_status_json():
        return self._runtime_status()

    @app.get("/api/db/settings")
    async def get_db_settings():
        connections = self.db_bridge_registry.runtime_connections()
        diagnostics = self.db_bridge_registry.runtime_diagnostics()
        return {
            "connections": connections,
            "diagnostics": diagnostics,
            "supportedEngines": SUPPORTED_ENGINES,
            "engineConfigFields": ENGINE_CONFIG_FIELDS,
        }

    @app.put("/api/db/settings")
    async def put_db_settings(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        if self.config_path is None:
            return JSONResponse(
                {
                    "error": "No config file available (started with --no-config or no config found)."
                },
                status_code=400,
            )
        connections = payload.get("connections")
        if not isinstance(connections, list):
            return JSONResponse(
                {"error": "'connections' must be an array."},
                status_code=400,
            )
        try:
            _write_db_connectors_to_toml(self.config_path, connections)
        except Exception as exc:
            logger.error(
                "Failed to write db settings to %s: %s", self.config_path, exc
            )
            return JSONResponse({"error": str(exc)}, status_code=500)
        return {"ok": True, "configPath": str(self.config_path)}

    @app.put("/api/ai/settings")
    async def put_ai_settings(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        if self.config_path is None:
            return JSONResponse(
                {
                    "error": "No config file available (started with --no-config or no config found)."
                },
                status_code=400,
            )
        try:
            _write_ai_settings_to_toml(self.config_path, payload)
        except ValueError as exc:
            return JSONResponse({"error": str(exc)}, status_code=400)
        except Exception as exc:
            logger.error(
                "Failed to write AI settings to %s: %s", self.config_path, exc
            )
            return JSONResponse({"error": str(exc)}, status_code=500)

        # Mirror the persisted values onto the running launcher, so later
        # _runtime_config() calls report them.
        settings = payload.get("settings")
        if isinstance(settings, dict):
            providers = settings.get("providers")
            custom_models = settings.get("customModels")
            model_parameters = settings.get("modelParameters")
            if isinstance(providers, dict):
                self.ai_providers = providers
            if isinstance(custom_models, list):
                self.ai_custom_models = custom_models
            if isinstance(model_parameters, dict):
                self.ai_model_parameters = model_parameters
        default_provider = _normalize_config_string(payload.get("defaultProvider"))
        default_model = _normalize_config_string(payload.get("defaultModel"))
        if default_provider:
            self.llm_provider = default_provider
        if default_model:
            self.llm_model = default_model
        return {"ok": True, "configPath": str(self.config_path)}

    @app.post("/api/upload")
    async def upload_file(request: Request, file: UploadFile = File(...)):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        safe_name = _sanitize_filename(file.filename)
        target = self.upload_dir / safe_name
        await _write_upload_to_path(file, target)
        return {"path": str(target)}

    # NOTE(review): the db_bridge_registry calls in the handlers below are
    # plain synchronous calls made from async handlers; if they block on the
    # database they would stall the event loop (including cancel-query).
    # Confirm against the registry implementation.

    @app.post("/api/db/test-connection")
    async def test_connection(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        connection_id = payload.get("connectionId")
        engine = payload.get("engine")
        config = payload.get("config")

        try:
            # An explicit engine+config pair takes precedence over a
            # registered connection id.
            if isinstance(engine, str) and isinstance(config, dict):
                connector = build_ephemeral_connector(engine, config)
                ok = connector.test_connection()
                return {"ok": bool(ok)}

            if isinstance(connection_id, str) and connection_id.strip():
                ok = self.db_bridge_registry.test_connection(connection_id)
                return {"ok": bool(ok)}

            return {
                "ok": False,
                "error": "Provide either engine+config or connectionId",
            }
        except Exception as exc:
            # Unknown connections and connector failures are reported the
            # same way, so one handler covers both.
            return {"ok": False, "error": str(exc)}

    @app.post("/api/db/list-catalog")
    async def list_catalog(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        connection_id = payload.get("connectionId")
        if not isinstance(connection_id, str) or not connection_id.strip():
            return {
                "databases": [],
                "schemas": [],
                "tables": [],
                "error": "connectionId is required",
            }
        try:
            return self.db_bridge_registry.list_catalog(connection_id)
        except Exception as exc:
            # Every failure yields an empty catalog plus the error text.
            return {"databases": [], "schemas": [], "tables": [], "error": str(exc)}

    @app.post("/api/db/execute-query")
    async def execute_query(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        connection_id = payload.get("connectionId")
        if not isinstance(connection_id, str) or not connection_id.strip():
            return JSONResponse(
                {"error": "connectionId is required"}, status_code=400
            )
        sql = payload.get("sql", "")
        query_type = payload.get("queryType", "json")
        if query_type not in {"json", "exec"}:
            return JSONResponse(
                {"error": "queryType must be either 'json' or 'exec'"},
                status_code=400,
            )
        if not isinstance(sql, str) or not sql.strip():
            return JSONResponse({"error": "sql is required"}, status_code=400)
        try:
            return self.db_bridge_registry.execute_query(
                connection_id=connection_id,
                sql=sql,
                query_type=query_type,
            )
        except UnknownBridgeConnectionError as exc:
            return JSONResponse({"error": str(exc)}, status_code=404)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)

    @app.post("/api/db/fetch-arrow")
    async def fetch_arrow(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        connection_id = payload.get("connectionId")
        if not isinstance(connection_id, str) or not connection_id.strip():
            return JSONResponse(
                {"error": "connectionId is required"}, status_code=400
            )
        sql = payload.get("sql", "")
        if not isinstance(sql, str) or not sql.strip():
            return JSONResponse({"error": "sql is required"}, status_code=400)
        try:
            arrow_bytes = self.db_bridge_registry.fetch_arrow_bytes(
                connection_id=connection_id,
                sql=sql,
            )
            return Response(
                content=arrow_bytes,
                media_type="application/vnd.apache.arrow.stream",
            )
        except UnknownBridgeConnectionError as exc:
            return JSONResponse({"error": str(exc)}, status_code=404)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)

    @app.post("/api/db/fetch-arrow-stream")
    async def fetch_arrow_stream(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        connection_id = payload.get("connectionId")
        if not isinstance(connection_id, str) or not connection_id.strip():
            return JSONResponse(
                {"error": "connectionId is required"}, status_code=400
            )
        sql = payload.get("sql", "")
        if not isinstance(sql, str) or not sql.strip():
            return JSONResponse({"error": "sql is required"}, status_code=400)
        query_id = payload.get("queryId")
        if not isinstance(query_id, str) or not query_id.strip():
            # No usable id supplied: generate one so frames can be tagged.
            query_id = f"bridge_{os.urandom(8).hex()}"
        chunk_rows = payload.get("chunkRows")
        # bool is a subclass of int, so JSON true/false must be rejected
        # explicitly before the positive-integer check.
        if (
            isinstance(chunk_rows, bool)
            or not isinstance(chunk_rows, int)
            or chunk_rows <= 0
        ):
            chunk_rows = 5000

        async def _stream():
            try:
                batches = self.db_bridge_registry.stream_arrow_batches(
                    connection_id=connection_id,
                    sql=sql,
                    chunk_rows=chunk_rows,
                    query_id=query_id,
                )
                try:
                    for batch in batches:
                        if await request.is_disconnected():
                            break
                        yield _encode_stream_frame(
                            "batch", query_id=query_id, payload=batch
                        )
                finally:
                    # Close the batch iterator on early exit or cancellation
                    # instead of leaving its cleanup to garbage collection.
                    close_batches = getattr(batches, "close", None)
                    if callable(close_batches):
                        close_batches()
                yield _encode_stream_frame("end", query_id=query_id)
            except Exception as exc:
                # Errors are reported in-band as a final "error" frame.
                yield _encode_stream_frame(
                    "error", query_id=query_id, error=str(exc)
                )

        return StreamingResponse(_stream(), media_type="application/octet-stream")

    @app.post("/api/db/cancel-query")
    async def cancel_query(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        query_id = payload.get("queryId")
        connection_id = payload.get("connectionId")
        if not isinstance(query_id, str) or not query_id.strip():
            return {"cancelled": False, "error": "queryId is required"}
        if not isinstance(connection_id, str) or not connection_id.strip():
            return {"cancelled": False}
        try:
            cancelled = self.db_bridge_registry.cancel_query(
                connection_id=connection_id,
                query_id=query_id,
            )
            return {"cancelled": bool(cancelled)}
        except Exception:
            # Cancellation is best effort; any failure reports "not cancelled".
            return {"cancelled": False}

    @app.post("/api/project/query")
    async def project_query(payload: Dict[str, Any], request: Request):
        unauthorized = self._require_api_auth(request)
        if unauthorized is not None:
            return unauthorized
        sql = str(payload.get("sql") or "")
        if not _is_select_only_sql(sql):
            return JSONResponse(
                {"error": "Only SELECT statements are allowed"},
                status_code=400,
            )
        if _references_internal_namespace(sql, self.meta_namespace):
            return JSONResponse(
                {
                    "error": f"Access to internal schema {self.meta_namespace} is denied"
                },
                status_code=403,
            )

        logger.debug(
            "project_query sql=%s",
            _redact_sql_literals(_normalize_sql_for_policy(sql))[:2000],
        )

        def _run(cur):
            cur.execute(sql)
            columns = [d[0] for d in (cur.description or [])]
            # Fetch one row beyond the 5000-row cap to detect truncation.
            fetched = cur.fetchmany(5001)
            truncated = len(fetched) > 5000
            limited = fetched[:5000]
            return {
                "columns": columns,
                "rows": [dict(zip(columns, row)) for row in limited],
                "rowCount": len(limited),
                "truncated": truncated,
            }

        try:
            data = await db_async.run_db_task(_run)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=400)
        return data

    if self.serve_ui and self.static_dir.exists():

        @app.api_route("/", methods=["GET", "HEAD"])
        async def spa_index():
            return self._index_response()

        @app.api_route("/{full_path:path}", methods=["GET", "HEAD"])
        async def spa_fallback(full_path: str):
            static_file = self._resolve_static_file(full_path)
            if static_file is not None:
                return FileResponse(static_file)

            stale_entry_redirect = self._stale_entry_asset_redirect(full_path)
            if stale_entry_redirect is not None:
                return stale_entry_redirect

            # Unknown API and asset paths get a JSON 404 rather than the
            # SPA index page.
            if full_path == "api" or full_path.startswith("api/"):
                return JSONResponse(
                    {"error": "Endpoint not found"}, status_code=404
                )

            if full_path.startswith("assets/"):
                return JSONResponse(
                    {"error": "Static asset not found"}, status_code=404
                )

            if self.index_html.exists():
                return self._index_response()
            return JSONResponse({"error": "UI bundle not found"}, status_code=404)
    elif self.serve_ui:
        logger.warning(
            "Static bundle missing at %s. UI will not load until built.",
            self.index_html,
        )
    else:
        logger.info(
            "Static UI serving is disabled; API endpoints remain available."
        )

    return app
|
|
1180
|
+
|
|
1181
|
+
def _run_duckdb_server(self) -> None:
    """Initialise the global DuckDB connection and run the websocket server.

    ``_duckdb_ready`` is set once the connection is initialised, and again
    if anything raises; in the failure case the exception is stored on
    ``_duckdb_start_error`` and logged. ``signal.signal`` is replaced with
    a no-op for the duration of the call and always restored afterwards.
    """
    import signal

    # In some environments (notably when embedding), signal handlers can only be
    # registered from the main thread. The websocket server itself does not rely
    # on signals, so we no-op signal registration in this background thread.
    saved_signal_fn = signal.signal
    signal.signal = lambda *_args, **_kwargs: None  # type: ignore
    try:
        db_async.init_global_connection(self.duckdb_database, extensions=["httpfs"])
        self._duckdb_start_error = None
        self._duckdb_ready.set()
        query_cache = QueryCache()
        server_options = {
            "auth_token": None,
            "sync_enabled": self.sync_enabled,
            "meta_db_path": self.meta_db,
            "meta_namespace": self.meta_namespace,
            # Client snapshots are only allowed for a synced in-memory database.
            "allow_client_snapshots": bool(
                self.sync_enabled and self.duckdb_database == ":memory:"
            ),
            "local_only": True,
        }
        duckdb_ws_server(query_cache, self.ws_port, **server_options)
    except Exception as exc:
        self._duckdb_start_error = exc
        # Also signal readiness on failure, after the error has been stored.
        self._duckdb_ready.set()
        logger.exception("DuckDB websocket backend failed to start")
    finally:
        signal.signal = saved_signal_fn  # type: ignore
|