askql 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- askql/__init__.py +103 -0
- askql/api.py +215 -0
- askql/audit.py +121 -0
- askql/capability.py +120 -0
- askql/cli.py +243 -0
- askql/compressor.py +106 -0
- askql/config.py +227 -0
- askql/dialects/__init__.py +26 -0
- askql/dialects/mssql.py +58 -0
- askql/dialects/oracle.py +50 -0
- askql/dialects/postgres.py +59 -0
- askql/doctor.py +262 -0
- askql/drivers/__init__.py +97 -0
- askql/drivers/base.py +22 -0
- askql/drivers/jdbc.py +223 -0
- askql/drivers/mssql.py +37 -0
- askql/drivers/oracle.py +49 -0
- askql/drivers/postgres.py +37 -0
- askql/embeddings.py +86 -0
- askql/executor.py +337 -0
- askql/generate.py +292 -0
- askql/orchestrator.py +171 -0
- askql/policy.py +92 -0
- askql/py.typed +0 -0
- askql/ratelimit.py +53 -0
- askql/retriever.py +168 -0
- askql/schema_graph.py +163 -0
- askql/scraper.py +152 -0
- askql/tokenize.py +116 -0
- askql/validator.py +361 -0
- askql-0.2.0.dist-info/METADATA +251 -0
- askql-0.2.0.dist-info/RECORD +36 -0
- askql-0.2.0.dist-info/WHEEL +5 -0
- askql-0.2.0.dist-info/entry_points.txt +3 -0
- askql-0.2.0.dist-info/licenses/LICENSE +21 -0
- askql-0.2.0.dist-info/top_level.txt +1 -0
askql/__init__.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""askql — safe natural-language-to-SQL with defense-in-depth guardrails.
|
|
2
|
+
|
|
3
|
+
A library teams can build on: turn questions into validated, read-only SELECTs and run them
|
|
4
|
+
safely against many SQL engines. Read-only by design; the validator is the single source of
|
|
5
|
+
truth for what is safe to execute.
|
|
6
|
+
|
|
7
|
+
Public API (stable surface — internals may change between minor versions):
|
|
8
|
+
|
|
9
|
+
from askql import validate, compress, ask, execute_sql_text, Settings
|
|
10
|
+
|
|
11
|
+
Quick start (library use)::
|
|
12
|
+
|
|
13
|
+
from askql import validate, Settings
|
|
14
|
+
r = validate("SELECT id FROM s.t LIMIT 10", settings=Settings(dialect="postgres"))
|
|
15
|
+
assert r.ok
|
|
16
|
+
|
|
17
|
+
Optional features install via extras: `askql[postgres]`, `[jdbc]`, `[llm]`, `[llm-openai]`,
|
|
18
|
+
`[api]`. Heavy/optional deps (DB drivers, LLM SDKs, FastAPI, JPype) are imported lazily, so
|
|
19
|
+
`import askql` is light and never fails on a missing optional dependency.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
# Package version string; re-exported via __all__ and reported by the REST API's /health.
__version__ = "0.2.0"
|
|
25
|
+
|
|
26
|
+
# Config / models
|
|
27
|
+
# Schema graph + per-question compression
|
|
28
|
+
from .compressor import compress
|
|
29
|
+
from .config import (
|
|
30
|
+
DatabaseEntry,
|
|
31
|
+
Settings,
|
|
32
|
+
data_dir,
|
|
33
|
+
load_database,
|
|
34
|
+
load_env_file,
|
|
35
|
+
load_sensitive_patterns,
|
|
36
|
+
load_settings,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Connectivity (transport-agnostic)
|
|
40
|
+
from .drivers import DatabaseDriver, get_driver
|
|
41
|
+
|
|
42
|
+
# Execution (read-only)
|
|
43
|
+
from .executor import (
|
|
44
|
+
ExecutionResult,
|
|
45
|
+
execute_sql,
|
|
46
|
+
execute_sql_text,
|
|
47
|
+
format_csv,
|
|
48
|
+
format_markdown,
|
|
49
|
+
sanitize_error,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
# NL->SQL generation (provider-agnostic / BYOM)
|
|
53
|
+
from .generate import GeneratorUnavailable, SqlGenerator, get_generator
|
|
54
|
+
|
|
55
|
+
# Orchestration
|
|
56
|
+
from .orchestrator import ask
|
|
57
|
+
|
|
58
|
+
# RBAC (optional, opt-in by identity)
|
|
59
|
+
from .policy import AccessDenied, Policy, get_policy, resolve_current_user
|
|
60
|
+
from .schema_graph import build_graph, load_graph, write_graph
|
|
61
|
+
|
|
62
|
+
# Validation (the safety core)
|
|
63
|
+
from .validator import ValidationResult, validate
|
|
64
|
+
|
|
65
|
+
# The package's public API (the "stable surface" from the module docstring). These are the
# names `from askql import *` exports; anything not listed here is treated as internal.
__all__ = [
    "__version__",
    # config
    "Settings",
    "DatabaseEntry",
    "load_settings",
    "load_database",
    "load_env_file",
    "load_sensitive_patterns",
    "data_dir",
    # validation
    "validate",
    "ValidationResult",
    # schema
    "compress",
    "build_graph",
    "load_graph",
    "write_graph",
    # execution
    "execute_sql",
    "execute_sql_text",
    "ExecutionResult",
    "format_markdown",
    "format_csv",
    "sanitize_error",
    # connectivity
    "get_driver",
    "DatabaseDriver",
    # rbac
    "get_policy",
    "Policy",
    "AccessDenied",
    "resolve_current_user",
    # generation
    "ask",
    "get_generator",
    "SqlGenerator",
    "GeneratorUnavailable",
]
|
askql/api.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""REST API — connectivity-as-a-service.
|
|
2
|
+
|
|
3
|
+
The "usable by all" rung: the service holds the drivers + credentials; clients (QA, support,
|
|
4
|
+
devs, a web UI, a Slack bot) just call HTTP and hold nothing. Every request still flows through
|
|
5
|
+
the same validator / RBAC / read-only execution / audit, so the safety guarantees are identical
|
|
6
|
+
to the CLI.
|
|
7
|
+
|
|
8
|
+
Auth: set T2S_API_KEYS="key1:alice@corp,key2:bob@corp". The matched identity drives RBAC
|
|
9
|
+
(config/access-control.yaml). For local dev only, T2S_API_ALLOW_OPEN=true permits unauthenticated
|
|
10
|
+
calls (identity from an X-T2S-User header). Authenticated endpoints answer HTTP 503 if neither is set.
|
|
11
|
+
|
|
12
|
+
Run: uvicorn askql.api:app --host 0.0.0.0 --port 8000
|
|
13
|
+
or python -m askql.api
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
from dataclasses import replace
|
|
20
|
+
|
|
21
|
+
from fastapi import Depends, FastAPI, Header, HTTPException
|
|
22
|
+
from fastapi.responses import PlainTextResponse
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
25
|
+
from . import __version__
|
|
26
|
+
from .compressor import compress
|
|
27
|
+
from .config import (
|
|
28
|
+
CONFIG_DIR,
|
|
29
|
+
ROOT,
|
|
30
|
+
_read_yaml,
|
|
31
|
+
load_database,
|
|
32
|
+
load_sensitive_patterns,
|
|
33
|
+
load_settings,
|
|
34
|
+
)
|
|
35
|
+
from .executor import execute_sql_text, format_csv, format_markdown
|
|
36
|
+
from .schema_graph import load_graph
|
|
37
|
+
from .validator import validate
|
|
38
|
+
|
|
39
|
+
# The ASGI application (served by `uvicorn askql.api:app` or main()). title/version/description
# populate the generated OpenAPI document.
app = FastAPI(
    title="askql API",
    version=__version__,
    description="Safe, read-only natural-language-to-SQL as a service.",
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# ── Auth ──────────────────────────────────────────────────────
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _api_keys() -> dict[str, str]:
|
|
50
|
+
keys: dict[str, str] = {}
|
|
51
|
+
for pair in os.environ.get("T2S_API_KEYS", "").split(","):
|
|
52
|
+
if ":" in pair:
|
|
53
|
+
k, ident = pair.split(":", 1)
|
|
54
|
+
keys[k.strip()] = ident.strip()
|
|
55
|
+
return keys
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _open_allowed() -> bool:
|
|
59
|
+
return os.environ.get("T2S_API_ALLOW_OPEN", "").lower() == "true"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def auth(x_api_key: str | None = Header(None), x_t2s_user: str | None = Header(None)) -> str | None:
    """FastAPI dependency that yields the caller identity used for RBAC and audit.

    The caller is resolved in one of three ways:
    - With ``T2S_API_KEYS`` configured, the ``X-API-Key`` header must match a key; the
      mapped identity is returned. Otherwise the result is HTTP 401.
    - Without keys but with ``T2S_API_ALLOW_OPEN=true``, the ``X-T2S-User`` header is
      trusted as-is. It may be ``None`` (pilot mode).
    - With neither configured, the result is HTTP 503, so a misconfigured server never
      runs open by accident.
    """
    known = _api_keys()
    if not known:
        if not _open_allowed():
            raise HTTPException(
                status_code=503,
                detail="API auth not configured: set T2S_API_KEYS (or T2S_API_ALLOW_OPEN=true for dev)",
            )
        return x_t2s_user  # trusted-gateway / dev mode; may be None -> pilot
    # Identities are strings, so a None lookup result can only mean "missing or unknown key".
    identity = known.get(x_api_key) if x_api_key else None
    if identity is None:
        raise HTTPException(status_code=401, detail="invalid or missing X-API-Key")
    return identity
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# ── Models ────────────────────────────────────────────────────
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ValidateReq(BaseModel):
    # Request body for POST /api/v1/validate (the SQL is only validated, never executed there).
    sql: str  # SQL text to check
    database: str | None = None  # optional registry name; when it resolves, its dialect is used
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class CompressReq(BaseModel):
    # Request body for POST /api/v1/compress.
    question: str  # natural-language question the schema graph is compressed for
    max_tables: int | None = None  # falls back to settings.max_tables when unset (or 0)
    max_columns: int | None = None  # falls back to settings.max_columns when unset (or 0)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class QueryReq(BaseModel):
    # Request body for POST /api/v1/query.
    sql: str  # SQL handed to execute_sql_text
    database: str | None = None  # registry name given to load_database (None: presumably the default — confirm)
    max_rows: int | None = None  # forwarded as execute_sql_text(limit=...)
    format: str = "json"  # json | csv | markdown
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class AskReq(BaseModel):
    # Request body for POST /api/v1/ask.
    question: str  # natural-language question turned into SQL by orchestrator.ask
    database: str | None = None  # registry name given to load_database
    max_rows: int | None = None  # forwarded as ask(max_rows=...)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ── Public endpoints ──────────────────────────────────────────
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@app.get("/health")
def health() -> dict:
    # Liveness probe: no auth, no config or database access — just the running version.
    body = {"status": "ok"}
    body["version"] = __version__
    return body
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@app.get("/api/v1/databases")
def databases() -> dict:
    """Registry names only — never connection strings or credentials."""
    # Defensive against sparse YAML: `_read_yaml` may yield None for an empty file (assumption —
    # confirm), and an entry written as `name:` with no body loads as None. Neither should
    # turn this read-only listing into an HTTP 500.
    data = _read_yaml(CONFIG_DIR / "databases.yaml") or {}
    registry = data.get("databases") or {}
    listing = []
    for name, entry in registry.items():
        entry = entry if isinstance(entry, dict) else {}
        # Only these non-secret fields are exposed; everything else stays server-side.
        listing.append(
            {
                "name": name,
                "dialect": entry.get("dialect"),
                "environment": entry.get("environment", "dev"),
            }
        )
    return {"default": data.get("default"), "databases": listing}
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# ── Authenticated endpoints ───────────────────────────────────
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@app.post("/api/v1/validate")
def api_validate(req: ValidateReq, identity: str | None = Depends(auth)) -> dict:
    """Validate SQL without executing it.

    When ``database`` is given it must name a registered database, and that database's
    dialect replaces the configured one for validation. An unknown name answers HTTP 400, as
    /api/v1/query does. Previously it silently validated against the default dialect and
    could report SQL as fine for the wrong engine.
    """
    settings = load_settings()
    if req.database:
        db = load_database(req.database)
        if db is None:
            raise HTTPException(status_code=400, detail=f"unknown database: {req.database}")
        settings = replace(settings, dialect=db.dialect)
    return validate(req.sql, settings=settings).to_dict()
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@app.post("/api/v1/compress")
def api_compress(req: CompressReq, identity: str | None = Depends(auth)) -> dict:
    # Auth only gates access; `identity` is not used to filter the compressed schema.
    # NOTE(review): no RBAC filtering happens here — confirm that is intended.
    cfg = load_settings()
    graph_file = ROOT / "docs" / "schema-graph.json"
    if not graph_file.exists():
        raise HTTPException(
            status_code=409, detail="schema graph not built (run scrape_schema --build-graph)"
        )
    sensitive = None
    if cfg.strip_sensitive_in_compressor:
        sensitive = load_sensitive_patterns()
    graph = load_graph(graph_file)
    # A missing (or 0) per-request budget falls back to the configured one.
    return compress(
        graph,
        req.question,
        max_tables=req.max_tables or cfg.max_tables,
        max_columns=req.max_columns or cfg.max_columns,
        seed_count=cfg.seed_count,
        sensitive_patterns=sensitive,
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@app.post("/api/v1/query")
def api_query(req: QueryReq, identity: str | None = Depends(auth)):
    """Run caller-supplied SQL through the read-only execution path.

    ``format`` selects the response shape:
    - ``json``: execution metadata plus ``rows``;
    - ``markdown``: execution metadata plus a rendered ``markdown`` table;
    - ``csv``: a ``text/csv`` body.

    An unsupported format, an unresolvable database and a failed execution all answer
    HTTP 400.
    """
    if req.format not in ("json", "csv", "markdown"):
        # Checked before touching the database, so a typo is reported instead of
        # silently answering with JSON.
        raise HTTPException(
            status_code=400,
            detail=f"unsupported format {req.format!r}: expected json, csv or markdown",
        )
    db = load_database(req.database)
    if db is None:
        raise HTTPException(status_code=400, detail="no database configured")
    # The caller identity is forwarded as `user` so the same RBAC/audit path as the CLI applies.
    res = execute_sql_text(req.sql, sql_label="api", database=db, limit=req.max_rows, user=identity)
    if not res.ok:
        # Fail before reading result fields; a failed result's rows are not needed.
        raise HTTPException(status_code=400, detail=res.error)
    if req.format == "csv":
        return PlainTextResponse(format_csv(res), media_type="text/csv")
    payload = {
        "ok": res.ok,
        "error": res.error,
        "warning": res.warning,
        "row_count": len(res.rows),
        "truncated": res.truncated,
        "latency_ms": res.latency_ms,
        "columns": res.columns,
    }
    if req.format == "markdown":
        payload["markdown"] = format_markdown(res)
    else:
        payload["rows"] = res.rows
    return payload
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@app.post("/api/v1/ask")
def api_ask(req: AskReq, identity: str | None = Depends(auth)):
    """Full NL->SQL: compress -> generate -> validate -> retry -> execute. Needs an LLM."""
    # Function-scope imports. The package __init__ already imports both modules, so this is
    # not a load-time saving.
    from .generate import GeneratorUnavailable
    from .orchestrator import ask

    database = load_database(req.database)
    if database is None:
        raise HTTPException(status_code=400, detail="no database configured")
    try:
        outcome = ask(req.question, database=database, user=identity, max_rows=req.max_rows)
    except GeneratorUnavailable as exc:
        # No usable LLM backend: a server-side availability problem, not a bad request.
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if outcome["ok"]:
        return outcome
    raise HTTPException(status_code=400, detail=outcome.get("error", "ask failed"))
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def main() -> None:
    """Serve the API with uvicorn, on 127.0.0.1:8000 unless T2S_API_HOST/T2S_API_PORT say otherwise."""
    import uvicorn

    from .config import load_env_file

    # Load .env first: it supplies credentials and may also set the host/port variables below.
    load_env_file()
    bind_host = os.environ.get("T2S_API_HOST", "127.0.0.1")
    bind_port = int(os.environ.get("T2S_API_PORT", "8000"))
    uvicorn.run("askql.api:app", host=bind_host, port=bind_port)


if __name__ == "__main__":
    main()
|
askql/audit.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Audit sink — every execution (success or failure) is recorded.
|
|
2
|
+
|
|
3
|
+
`AuditSink` is an interface so the JSONL pilot implementation can be swapped for a
|
|
4
|
+
DB-backed sink later without changing callers (ARCHITECTURE.md §4). Audit writes must
|
|
5
|
+
NEVER block execution — failures here are swallowed.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import getpass
|
|
11
|
+
import json
|
|
12
|
+
import socket
|
|
13
|
+
from datetime import UTC, datetime
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Protocol
|
|
16
|
+
|
|
17
|
+
from .config import data_dir
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _now_iso() -> str:
|
|
21
|
+
return datetime.now(UTC).isoformat()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class AuditSink(Protocol):
    """Structural interface for audit backends: any object with ``record(entry)`` qualifies."""

    def record(self, entry: dict[str, Any]) -> None:
        """Persist one audit entry; implementations must not let failures reach the caller."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class JsonlAuditSink:
|
|
29
|
+
"""Append one JSON object per line to build/query-audit.jsonl."""
|
|
30
|
+
|
|
31
|
+
def __init__(self, path: Path | None = None) -> None:
|
|
32
|
+
self.path = path or data_dir() / "query-audit.jsonl"
|
|
33
|
+
|
|
34
|
+
def record(self, entry: dict[str, Any]) -> None:
|
|
35
|
+
try:
|
|
36
|
+
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
37
|
+
with self.path.open("a", encoding="utf-8") as fh:
|
|
38
|
+
fh.write(json.dumps(entry) + "\n")
|
|
39
|
+
except Exception: # noqa: BLE001 - audit must never break execution
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def audit_path() -> Path:
    """Location of the default JSONL audit log inside the askql data directory."""
    return data_dir().joinpath("query-audit.jsonl")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def is_failure(entry: dict) -> bool:
    """Classify an audit record as failed: rejected by validation, or errored while executing."""
    if not entry.get("validationOk"):
        return True
    return bool(entry.get("error"))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def read_audit(
|
|
53
|
+
limit: int = 20,
|
|
54
|
+
*,
|
|
55
|
+
failures_only: bool = False,
|
|
56
|
+
database: str | None = None,
|
|
57
|
+
path: Path | None = None,
|
|
58
|
+
) -> list[dict]:
|
|
59
|
+
"""Return the most recent audit entries (newest last). Tolerates partial/corrupt lines."""
|
|
60
|
+
path = path or audit_path()
|
|
61
|
+
if not path.exists():
|
|
62
|
+
return []
|
|
63
|
+
entries: list[dict] = []
|
|
64
|
+
for line in path.read_text(encoding="utf-8").splitlines():
|
|
65
|
+
line = line.strip()
|
|
66
|
+
if not line:
|
|
67
|
+
continue
|
|
68
|
+
try:
|
|
69
|
+
entries.append(json.loads(line))
|
|
70
|
+
except json.JSONDecodeError:
|
|
71
|
+
continue
|
|
72
|
+
if database:
|
|
73
|
+
entries = [e for e in entries if e.get("database") == database]
|
|
74
|
+
if failures_only:
|
|
75
|
+
entries = [e for e in entries if is_failure(e)]
|
|
76
|
+
return entries[-limit:]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def build_entry(
    *,
    sql: str,
    sql_file: str,
    validation_ok: bool,
    row_count: int | None = None,
    truncated: bool = False,
    error: str | None = None,
    latency_ms: int | None = None,
    environment: str = "qa",
    database: str | None = None,
    identity: str | None = None,
    role: str | None = None,
    pii_mode: str | None = None,
) -> dict[str, Any]:
    """Build one sanitized audit record; it never carries credentials.

    The stored SQL is cut to its first 2000 characters. ``user`` is the authenticated
    ``identity`` when present. Otherwise it is the OS login (pilot mode), or ``"unknown"``
    when that cannot be read. ``role`` is the resolved RBAC role, if any.
    """
    stamp = _now_iso()
    who = identity or _safe(getpass.getuser)
    host = _safe(socket.gethostname)
    # Key order is the on-disk JSON order; keep it stable for readers of the log.
    return {
        "timestamp": stamp,
        "user": who,
        "role": role,
        "host": host,
        "environment": environment,
        "database": database,
        "piiMode": pii_mode,
        "sqlFile": sql_file,
        "sql": sql[:2000],
        "validationOk": validation_ok,
        "rowCount": row_count,
        "truncated": truncated,
        "error": error,
        "latencyMs": latency_ms,
    }
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _safe(fn) -> str:
|
|
118
|
+
try:
|
|
119
|
+
return fn()
|
|
120
|
+
except Exception: # noqa: BLE001
|
|
121
|
+
return "unknown"
|
askql/capability.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Non-blocking write-capability advisory.
|
|
2
|
+
|
|
3
|
+
Primary mode is "use the credentials you're given." A read-only DB user is a recommended
|
|
4
|
+
nice-to-have, not a requirement. This module probes (best-effort, read-only) whether the
|
|
5
|
+
connected user can write and, if so, returns a one-line warning — it NEVER blocks execution.
|
|
6
|
+
|
|
7
|
+
The probe runs at most once/day per (database, user) via a small cache so it adds no per-query
|
|
8
|
+
overhead or noise. Any probe error -> unknown -> silent (advisory must not be fragile).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
from datetime import UTC, datetime
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
from .config import data_dir
|
|
18
|
+
|
|
19
|
+
# Best-effort per-dialect scalar probe: returns a count > 0 when the user holds write/DDL privs.
|
|
20
|
+
# Run directly through the driver (not the validator), so referencing catalog views is fine.
|
|
21
|
+
# Best-effort per-dialect scalar probe: returns a count > 0 when the user holds write/DDL privs.
# Run directly through the driver (not the validator), so referencing catalog views is fine.
# Keys are dialect names as passed to probe_write_capability(); an unlisted dialect means
# "unknown" (no advisory). Each query yields a single column aliased `w`.
# NOTE(review): the postgres/redshift/oracle probes filter on grantee = the current user, so
# privileges that arrive only through role membership or PUBLIC may not be counted; the
# postgres probe also counts rolsuper/rolcreatedb. Acceptable for an advisory — confirm.
WRITE_PROBES: dict[str, str] = {
    "postgres": (
        "SELECT "
        "(SELECT count(*) FROM information_schema.role_table_grants "
        " WHERE grantee = current_user "
        " AND privilege_type IN ('INSERT','UPDATE','DELETE','TRUNCATE')) "
        "+ (SELECT count(*) FROM pg_roles "
        " WHERE rolname = current_user AND (rolsuper OR rolcreatedb)) AS w"
    ),
    "redshift": (
        "SELECT count(*) AS w FROM information_schema.role_table_grants "
        "WHERE grantee = current_user AND privilege_type IN ('INSERT','UPDATE','DELETE')"
    ),
    # System privileges active in the session plus direct object grants.
    "oracle": (
        "SELECT ("
        "(SELECT COUNT(*) FROM session_privs WHERE privilege IN "
        "('INSERT ANY TABLE','UPDATE ANY TABLE','DELETE ANY TABLE','CREATE TABLE',"
        "'CREATE ANY TABLE','DROP ANY TABLE','ALTER ANY TABLE'))"
        "+ (SELECT COUNT(*) FROM user_tab_privs WHERE grantee = USER "
        " AND privilege IN ('INSERT','UPDATE','DELETE'))) AS w FROM dual"
    ),
    # Effective database-level permissions of the current login.
    "tsql": (
        "SELECT COUNT(*) AS w FROM fn_my_permissions(NULL,'DATABASE') "
        "WHERE permission_name IN ('INSERT','UPDATE','DELETE','ALTER','CONTROL','CREATE TABLE')"
    ),
}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def probe_write_capability(driver, dialect: str) -> bool | None:
    """Run the dialect's write-privilege probe through ``driver`` and interpret the count.

    Returns ``True`` when the count is positive (can write) and ``False`` when it is zero
    (read-only). Returns ``None`` (unknown) in three cases:
    - there is no probe for the dialect;
    - the probe raises;
    - the first cell is missing or NULL.

    Never raises.
    """
    probe_sql = WRITE_PROBES.get(dialect)
    if not probe_sql:
        return None
    try:
        # The 1 and 10 arguments are passed straight to driver.execute; presumably a row
        # cap and a timeout — confirm against DatabaseDriver.execute.
        _columns, result_rows = driver.execute(probe_sql, 1, 10)
        first = result_rows[0] if result_rows else None
        if first and first[0] is not None:
            return int(first[0]) > 0
    except Exception:  # noqa: BLE001 - advisory must not be fragile
        return None
    return None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _today() -> int:
|
|
64
|
+
return int(datetime.now(UTC).timestamp() // 86400)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def probe_and_advise(
|
|
68
|
+
probe,
|
|
69
|
+
db_name: str,
|
|
70
|
+
environment: str,
|
|
71
|
+
identity: str | None,
|
|
72
|
+
*,
|
|
73
|
+
cache_dir: Path | None = None,
|
|
74
|
+
today: int | None = None,
|
|
75
|
+
) -> str | None:
|
|
76
|
+
"""Return a one-line advisory if the user can write (else None). `probe` is a zero-arg
|
|
77
|
+
callable returning bool|None (run on its own connection). Caches per (db,user)/day so the
|
|
78
|
+
probe runs at most once daily and we warn at most once daily."""
|
|
79
|
+
cache_dir = cache_dir or (data_dir() / ".write-check")
|
|
80
|
+
today = today if today is not None else _today()
|
|
81
|
+
who = identity or "default"
|
|
82
|
+
key = "".join(c if c.isalnum() else "_" for c in f"{db_name}_{who}")
|
|
83
|
+
path = cache_dir / f"{key}.json"
|
|
84
|
+
|
|
85
|
+
try:
|
|
86
|
+
state = json.loads(path.read_text(encoding="utf-8")) if path.exists() else {}
|
|
87
|
+
except (json.JSONDecodeError, OSError):
|
|
88
|
+
state = {}
|
|
89
|
+
|
|
90
|
+
if state.get("day") == today:
|
|
91
|
+
capable = state.get("capable")
|
|
92
|
+
else:
|
|
93
|
+
try:
|
|
94
|
+
capable = probe()
|
|
95
|
+
except Exception: # noqa: BLE001 - advisory must not be fragile
|
|
96
|
+
capable = None
|
|
97
|
+
state = {"day": today, "capable": capable, "warned_day": state.get("warned_day")}
|
|
98
|
+
_save(path, state)
|
|
99
|
+
|
|
100
|
+
if not capable:
|
|
101
|
+
return None
|
|
102
|
+
if state.get("warned_day") == today:
|
|
103
|
+
return None # already warned today
|
|
104
|
+
|
|
105
|
+
state["warned_day"] = today
|
|
106
|
+
_save(path, state)
|
|
107
|
+
return (
|
|
108
|
+
f"advisory: connected user '{who}' appears to have WRITE privileges on "
|
|
109
|
+
f"'{db_name}' ({environment}). askql blocks writes in software, but a read-only DB "
|
|
110
|
+
f"user is recommended for prod/sensitive data (see PB11). "
|
|
111
|
+
f"Set warn_if_writable: false to silence."
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _save(path: Path, state: dict) -> None:
|
|
116
|
+
try:
|
|
117
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
118
|
+
path.write_text(json.dumps(state), encoding="utf-8")
|
|
119
|
+
except OSError:
|
|
120
|
+
pass
|