pgnode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +1 -0
- app/agent/__init__.py +1 -0
- app/agent/agent.py +216 -0
- app/agent/followup.py +127 -0
- app/agent/intent_agent.py +21 -0
- app/agent/memory.py +61 -0
- app/agent/sql_metadata.py +101 -0
- app/agent/sql_quality.py +117 -0
- app/agent/sql_quoter.py +92 -0
- app/agent/validator.py +72 -0
- app/cli/__init__.py +1 -0
- app/cli/diagnostics.py +314 -0
- app/cli/formatting.py +319 -0
- app/cli/history.py +76 -0
- app/cli/main.py +801 -0
- app/core/__init__.py +5 -0
- app/core/config.py +157 -0
- app/db/__init__.py +1 -0
- app/db/connection.py +63 -0
- app/db/query_executor.py +61 -0
- app/db/schema_loader.py +53 -0
- app/db/ssh_tunnel.py +82 -0
- app/llm/__init__.py +1 -0
- app/llm/intent_classifier.py +146 -0
- app/llm/ollama_client.py +33 -0
- app/llm/prompt_builder.py +250 -0
- app/llm/sql_generator.py +63 -0
- pgnode-0.1.0.dist-info/METADATA +223 -0
- pgnode-0.1.0.dist-info/RECORD +32 -0
- pgnode-0.1.0.dist-info/WHEEL +5 -0
- pgnode-0.1.0.dist-info/entry_points.txt +2 -0
- pgnode-0.1.0.dist-info/top_level.txt +1 -0
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Application package."""
|
app/agent/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Agent orchestration: NL to validated SQL execution."""
|
app/agent/agent.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""MVP agent: schema -> LLM SQL -> validate -> execute."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from app.agent.intent_agent import run_intent_agent
|
|
9
|
+
from app.agent.sql_metadata import extract_sql_metadata
|
|
10
|
+
from app.agent.sql_quality import detect_quality_issue
|
|
11
|
+
from app.agent.sql_quoter import quote_schema_identifiers
|
|
12
|
+
from app.agent.validator import validate_sql
|
|
13
|
+
from app.db.query_executor import run_query
|
|
14
|
+
from app.db.schema_loader import get_full_schema
|
|
15
|
+
from app.llm.prompt_builder import strip_sql_fences
|
|
16
|
+
from app.llm.sql_generator import generate_sql, repair_sql
|
|
17
|
+
|
|
18
|
+
# Maximum number of LLM repair attempts after a failed query execution.
_MAX_SQL_RETRIES = 1
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class AgentResult:
    """Outcome of a single agent turn.

    Exactly one of ``validation_error`` / ``llm_error`` / ``execution_result``
    carries the interesting payload, depending on where the turn stopped.
    """

    # Original natural-language request from the user.
    user_prompt: str
    # Final SQL statement (possibly repaired/quoted), or None if none was produced.
    sql: str | None
    # Message from validate_sql when the statement was rejected, else None.
    validation_error: str | None
    # Rows on success, a status/error string from the executor, or the
    # out-of-scope response text; None when execution never happened.
    execution_result: list[dict[str, Any]] | str | None
    # Error raised while talking to the LLM (generation/repair/intent), if any.
    llm_error: str | None = None
    # Best-effort SQL metadata (first keyword), e.g. "SELECT"/"UPDATE".
    operation: str | None = None
    # Table the statement targets, when detectable.
    table: str | None = None
    # Columns assigned by an UPDATE's SET clause, when detectable.
    changed_columns: list[str] | None = None
    # WHERE clause text of the statement, when detectable.
    where_clause: str | None = None
    # True when the intent classifier answered directly instead of generating SQL.
    out_of_scope: bool = False
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def run_agent(
    user_prompt: str,
    *,
    dry_run: bool = False,
    model: str | None = None,
    conversation_context: str | None = None,
    forced_sql: str | None = None,
) -> AgentResult:
    """
    Load schema, generate SQL via local LLM, validate, optionally execute.

    Pipeline: schema load -> intent classification -> SQL generation (or
    ``forced_sql`` passthrough) -> quality check -> validation -> execution,
    with up to ``_MAX_SQL_RETRIES`` LLM-driven repairs after a DB error.

    If dry_run is True, the query executor returns a dry-run message without executing.
    """
    # Stage 1: load the schema; without it nothing downstream can run.
    try:
        schema = get_full_schema()
    except RuntimeError as exc:
        # Schema failure is reported through validation_error (not llm_error).
        return AgentResult(
            user_prompt=user_prompt,
            sql=None,
            validation_error=str(exc),
            execution_result=None,
            llm_error=None,
        )

    # Stage 2: classify intent before spending a generation call.
    try:
        intent = run_intent_agent(
            user_prompt,
            schema,
            model=model,
            conversation_context=conversation_context,
        )
    except RuntimeError as exc:
        return AgentResult(
            user_prompt=user_prompt,
            sql=None,
            validation_error=None,
            execution_result=None,
            llm_error=str(exc),
        )

    if intent.intent == "out_of_scope":
        # The classifier's own response text is surfaced as the result.
        return AgentResult(
            user_prompt=user_prompt,
            sql=None,
            validation_error=None,
            execution_result=intent.response,
            out_of_scope=True,
        )

    # Stage 3: obtain SQL — caller-supplied or LLM-generated.
    if forced_sql is not None:
        # Caller supplied SQL directly (e.g. a grounded follow-up); skip
        # generation and the quality-repair pass.
        sql = quote_schema_identifiers(strip_sql_fences(forced_sql).strip(), schema)
    else:
        try:
            raw_sql = generate_sql(
                user_prompt,
                schema,
                model=model,
                conversation_context=conversation_context,
            )
        except RuntimeError as exc:
            return AgentResult(
                user_prompt=user_prompt,
                sql=None,
                validation_error=None,
                execution_result=None,
                llm_error=str(exc),
            )

        sql = quote_schema_identifiers(strip_sql_fences(raw_sql).strip(), schema)
        # One optional heuristic-driven repair pass before validation.
        sql = _maybe_repair_sql_quality(
            user_prompt=user_prompt,
            sql=sql,
            schema=schema,
            model=model,
            conversation_context=conversation_context,
        )
    metadata = extract_sql_metadata(sql)

    # Stage 4: validate before touching the database.
    validation_error = validate_sql(sql)
    if validation_error is not None:
        return AgentResult(
            user_prompt=user_prompt,
            sql=sql,
            validation_error=validation_error,
            execution_result=None,
            operation=metadata.operation,
            table=metadata.table,
            changed_columns=metadata.changed_columns,
            where_clause=metadata.where_clause,
        )

    # Stage 5: execute, retrying via LLM repair on executor-reported failures.
    execution_result = run_query(sql, dry_run=dry_run)
    attempts = 0
    while (
        not dry_run
        and isinstance(execution_result, str)
        and execution_result.startswith("Query execution failed:")
        and attempts < _MAX_SQL_RETRIES
    ):
        attempts += 1
        try:
            repaired_sql = repair_sql(
                user_request=user_prompt,
                schema=schema,
                previous_sql=sql,
                db_error=execution_result,
                model=model,
                conversation_context=conversation_context,
            ).strip()
        except RuntimeError as exc:
            # Repair call itself failed: report the original DB error plus
            # the repair failure, keeping the pre-repair metadata.
            return AgentResult(
                user_prompt=user_prompt,
                sql=sql,
                validation_error=None,
                execution_result=execution_result,
                llm_error=f"Repair attempt failed: {exc}",
                operation=metadata.operation,
                table=metadata.table,
                changed_columns=metadata.changed_columns,
                where_clause=metadata.where_clause,
            )

        # Repaired SQL must pass validation as well before re-execution.
        # NOTE(review): metadata here still describes the pre-repair SQL
        # while `sql` is the repaired text — confirm this is intended.
        repair_validation_error = validate_sql(repaired_sql)
        if repair_validation_error is not None:
            return AgentResult(
                user_prompt=user_prompt,
                sql=repaired_sql,
                validation_error=repair_validation_error,
                execution_result=execution_result,
                operation=metadata.operation,
                table=metadata.table,
                changed_columns=metadata.changed_columns,
                where_clause=metadata.where_clause,
            )

        sql = quote_schema_identifiers(repaired_sql, schema)
        metadata = extract_sql_metadata(sql)
        execution_result = run_query(sql, dry_run=False)

    return AgentResult(
        user_prompt=user_prompt,
        sql=sql,
        validation_error=None,
        execution_result=execution_result,
        operation=metadata.operation,
        table=metadata.table,
        changed_columns=metadata.changed_columns,
        where_clause=metadata.where_clause,
    )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _maybe_repair_sql_quality(
    *,
    user_prompt: str,
    sql: str,
    schema: dict[str, list[dict[str, str]]],
    model: str | None,
    conversation_context: str | None,
) -> str:
    """Attempt one quality-driven repair of freshly generated SQL.

    When the browse/lookup heuristics flag the statement, ask the LLM for a
    rewrite; fall back to the original SQL whenever the repair call fails or
    the repaired statement does not validate.
    """
    issue = detect_quality_issue(user_prompt, sql)
    if issue is None:
        return sql

    try:
        candidate = repair_sql(
            user_request=user_prompt,
            schema=schema,
            previous_sql=sql,
            db_error=f"Quality issue: {issue}",
            model=model,
            conversation_context=conversation_context,
        ).strip()
    except RuntimeError:
        # LLM unavailable or errored: keep the original statement.
        return sql

    # Only adopt the candidate if it still passes validation.
    if validate_sql(candidate) is None:
        return quote_schema_identifiers(candidate, schema)
    return sql
|
app/agent/followup.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Follow-up prompt grounding after writes and other contextual turns."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
from app.agent.memory import ConversationMemory, MemoryTurn
|
|
8
|
+
from app.agent.sql_metadata import extract_sql_metadata
|
|
9
|
+
|
|
10
|
+
# Verbs that suggest the user wants to read data back.
_SHOW_VERBS = re.compile(
    r"\b(show|display|list|see|view|fetch|get)\b",
    re.IGNORECASE,
)
# Words hinting the read refers to a preceding write ("show the updated rows").
_AFTER_WRITE_HINTS = re.compile(
    r"\b(updated|update|modified|changes?|changed|result|after)\b",
    re.IGNORECASE,
)
# Pronouns/back-references to an earlier turn ("show that again").
_ANAPHORA = re.compile(
    r"\b(that|it|same|previous|those|the change|the update)\b",
    re.IGNORECASE,
)
# Identifier regex fragments.
# NOTE(review): neither fragment appears to be used in this module — confirm
# before removing.
_QUOTED_IDENT = r'"([^"]+)"'
_BARE_IDENT = r"([a-zA-Z_][\w]*)"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def is_show_after_write_prompt(text: str) -> bool:
    """Heuristically decide whether the user wants to read data back after a write."""
    msg = text.lower()
    if _SHOW_VERBS.search(msg) is None:
        return False
    if _AFTER_WRITE_HINTS.search(msg) is not None:
        return True
    # No explicit after-write wording: fall back to generic data nouns.
    return any(word in msg for word in ("table", "rows", "record"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def try_build_show_after_write_sql(
    user_prompt: str,
    memory: ConversationMemory | None,
) -> str | None:
    """
    Construct a SELECT targeting the table/row touched by the latest write.

    Returns None whenever the prompt or memory does not fit the
    "show me what changed" follow-up pattern.
    """
    if memory is None:
        return None
    if not is_show_after_write_prompt(user_prompt):
        return None

    write_turn = _find_last_write_turn(memory)
    if write_turn is None or not write_turn.sql:
        return None

    details = extract_sql_metadata(write_turn.sql)
    target = details.table or _table_from_sql(write_turn.sql)
    if not target:
        return None

    ref = _quote_table_ref(target)
    # Re-use the write's WHERE clause so the read shows the affected rows.
    where = f" WHERE {details.where_clause}" if details.where_clause else ""
    return f"SELECT * FROM {ref}{where} LIMIT 10;"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def ground_followup_prompt(text: str, memory: ConversationMemory | None) -> str:
    """Enrich vague follow-ups using recent conversation turns.

    Returns *text* unchanged when no memory is available or no pattern
    applies; otherwise appends a structured context block that steers the
    SQL generator (e.g. forcing a SELECT after a write).
    """
    if memory is None:
        return text

    # Pattern 1: "show me the updated rows" after an UPDATE/INSERT.
    if is_show_after_write_prompt(text):
        last_write = _find_last_write_turn(memory)
        if last_write is not None:
            metadata = extract_sql_metadata(last_write.sql or "")
            table = metadata.table or _table_from_sql(last_write.sql or "") or "unknown"
            where = metadata.where_clause or ""
            cols = ", ".join(last_write.changed_columns) if last_write.changed_columns else "*"
            # Append explicit instructions plus a suggested SELECT statement.
            return (
                f"{text}\n\n"
                "Context from previous turn:\n"
                f"- Previous operation: {last_write.operation}\n"
                f"- Table: {table}\n"
                f"- Changed columns: {cols}\n"
                f"- Previous filter: {where or '<none>'}\n"
                "The user wants to VIEW current rows after that write.\n"
                "MANDATORY: Generate a SELECT query only (NOT UPDATE/INSERT/DELETE).\n"
                f"Prefer: SELECT * FROM {_quote_table_ref(table)}"
                + (f" WHERE {where}" if where else "")
                + " LIMIT 10;"
            )

    # Pattern 2: anaphoric reference ("that", "the change") to a recent write.
    if not _ANAPHORA.search(text):
        return text

    for turn in reversed(memory.recent(10)):
        if turn.operation in {"UPDATE", "INSERT"} and (turn.table or turn.sql):
            metadata = extract_sql_metadata(turn.sql or "")
            table = metadata.table or turn.table or "unknown"
            cols = ", ".join(turn.changed_columns) if turn.changed_columns else "*"
            where = metadata.where_clause or turn.where_clause or ""
            # Softer guidance than pattern 1: context only, no mandated verb.
            return (
                f"{text}\n"
                "Context from previous turn:\n"
                f"- Last {turn.operation} table: {table}\n"
                f"- Changed columns: {cols}\n"
                f"- Filter: {where or '<none>'}\n"
                "If user asks to show the change, produce SELECT for that table/row."
            )

    return text
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _find_last_write_turn(memory: ConversationMemory) -> MemoryTurn | None:
|
|
111
|
+
for turn in reversed(memory.recent(10)):
|
|
112
|
+
if turn.operation in {"UPDATE", "INSERT"}:
|
|
113
|
+
return turn
|
|
114
|
+
return None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _table_from_sql(sql: str) -> str | None:
    """Convenience wrapper: pull only the table name out of *sql*."""
    return extract_sql_metadata(sql).table
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _quote_table_ref(table: str) -> str:
|
|
123
|
+
if table.startswith('"') and table.endswith('"'):
|
|
124
|
+
return table
|
|
125
|
+
if table != table.lower():
|
|
126
|
+
return f'"{table}"'
|
|
127
|
+
return table
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Intent Agent wrapper: classify user input before SQL generation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from app.llm.intent_classifier import IntentClassification, classify_intent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def run_intent_agent(
    user_prompt: str,
    schema: dict[str, list[dict[str, str]]],
    *,
    model: str | None = None,
    conversation_context: str | None = None,
) -> IntentClassification:
    """Gatekeeper step: decide whether this message belongs to the SQL generator.

    Thin delegation to the LLM intent classifier; kept as its own function so
    the agent pipeline reads as discrete stages.
    """
    classification = classify_intent(
        user_prompt,
        schema,
        model=model,
        conversation_context=conversation_context,
    )
    return classification
|
app/agent/memory.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Ephemeral in-process conversation memory for interactive CLI sessions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class MemoryTurn:
    """Single turn retained in in-memory conversation context."""

    # Raw text the user typed for this turn.
    user_prompt: str
    # SQL produced for the turn, or None when none was generated.
    sql: str | None
    # Validator rejection message, or None when the SQL passed.
    validation_error: str | None
    # Short human-readable description of the execution outcome.
    execution_summary: str | None
    # First SQL keyword (e.g. "SELECT", "UPDATE"), when known.
    operation: str | None
    # Table targeted by the statement, when known.
    table: str | None
    # Columns assigned by an UPDATE's SET clause (empty when not applicable).
    changed_columns: list[str]
    # WHERE clause text of the statement, when present.
    where_clause: str | None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ConversationMemory:
    """Rolling, in-process memory of turns for a single conversation."""

    def __init__(self) -> None:
        # Oldest-first list of recorded turns.
        self._turns: list[MemoryTurn] = []

    def append(self, turn: MemoryTurn) -> None:
        """Record one finished turn."""
        self._turns.append(turn)

    def clear(self) -> None:
        """Forget every recorded turn."""
        self._turns.clear()

    def recent(self, limit: int) -> list[MemoryTurn]:
        """Return up to *limit* most recent turns, oldest first."""
        return self._turns[-limit:] if limit > 0 else []

    def format_for_prompt(self, limit: int) -> str:
        """
        Render the most recent *limit* turns as compact plain-text context
        suitable for inclusion in the SQL generation prompt.
        """
        lines: list[str] = []
        for number, turn in enumerate(self.recent(limit), start=1):
            lines.append(f"Turn {number}:")
            lines.append(f"- User: {turn.user_prompt}")
            lines.extend(self._optional_lines(turn))
        return "\n".join(lines)

    @staticmethod
    def _optional_lines(turn: MemoryTurn) -> list[str]:
        """Formatted lines for turn fields that are only shown when present."""
        extras: list[str] = []
        if turn.sql:
            extras.append(f"- SQL: {turn.sql}")
        if turn.operation:
            extras.append(f"- Operation: {turn.operation}")
        if turn.table:
            extras.append(f"- Table: {turn.table}")
        if turn.changed_columns:
            extras.append(f"- Changed columns: {', '.join(turn.changed_columns)}")
        if turn.where_clause:
            extras.append(f"- Where: {turn.where_clause}")
        if turn.validation_error:
            extras.append(f"- Validation: {turn.validation_error}")
        if turn.execution_summary:
            extras.append(f"- Result: {turn.execution_summary}")
        return extras
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Lightweight SQL metadata extraction for conversational follow-ups."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
# Matches a table identifier: group 1 captures a quoted name (sans quotes),
# group 2 captures a bare identifier. Exactly one group matches.
_TABLE_PATTERN = r'(?:"([^"]+)"|([a-zA-Z_][\w]*))'
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class SqlMetadata:
    """Best-effort structured details from generated SQL."""

    # First SQL keyword, uppercased (e.g. "SELECT"), or None for empty input.
    operation: str | None
    # Primary table referenced by the statement, when detectable.
    table: str | None
    # Columns assigned in an UPDATE's SET clause; empty otherwise.
    changed_columns: list[str]
    # Text following WHERE (trailing semicolon stripped), when present.
    where_clause: str | None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def first_keyword(sql: str) -> str:
    """Uppercased first whitespace-delimited token of *sql* ("" when blank)."""
    tokens = sql.split(maxsplit=1)
    return tokens[0].upper() if tokens else ""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_write_operation(sql: str) -> bool:
    """Whether *sql* begins with a write keyword (INSERT or UPDATE)."""
    return first_keyword(sql) in ("INSERT", "UPDATE")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def extract_sql_metadata(sql: str) -> SqlMetadata:
    """
    Best-effort extraction for SELECT/INSERT/UPDATE.

    This is intentionally simple regex parsing for MVP behavior.
    Statements starting with other keywords yield only ``operation``.
    NOTE(review): without re.DOTALL the UPDATE/SELECT remainder patterns only
    see the first line of a multi-line statement — confirm inputs are
    single-line.
    """
    operation = first_keyword(sql) or None
    table: str | None = None
    changed_columns: list[str] = []
    where_clause: str | None = None
    cleaned = sql.strip().rstrip(";")

    if operation == "UPDATE":
        # _TABLE_PATTERN contributes groups 1 (quoted) and 2 (bare);
        # group 3 is everything after SET.
        match = re.search(
            rf"^\s*UPDATE\s+{_TABLE_PATTERN}\s+SET\s+(.*)$",
            cleaned,
            re.IGNORECASE,
        )
        if match:
            table = match.group(1) or match.group(2)
            remainder = match.group(3)
            # Split the SET expression from an optional WHERE clause.
            where_split = re.split(r"\bWHERE\b", remainder, maxsplit=1, flags=re.IGNORECASE)
            set_expr = where_split[0]
            if len(where_split) > 1:
                where_clause = where_split[1].strip()
            changed_columns = _extract_set_columns(set_expr)
    elif operation == "INSERT":
        match = re.search(
            rf"^\s*INSERT\s+INTO\s+{_TABLE_PATTERN}",
            cleaned,
            re.IGNORECASE,
        )
        if match:
            table = match.group(1) or match.group(2)
    elif operation == "SELECT":
        # First FROM wins; joins/subqueries beyond that are ignored by design.
        match = re.search(
            rf"\bFROM\s+{_TABLE_PATTERN}",
            cleaned,
            re.IGNORECASE,
        )
        if match:
            table = match.group(1) or match.group(2)
        where_match = re.search(r"\bWHERE\b\s+(.*)$", cleaned, re.IGNORECASE)
        if where_match:
            where_clause = where_match.group(1).strip()

    return SqlMetadata(
        operation=operation,
        table=table,
        changed_columns=changed_columns,
        where_clause=where_clause,
    )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _extract_set_columns(set_expr: str) -> list[str]:
|
|
89
|
+
columns: list[str] = []
|
|
90
|
+
for part in set_expr.split(","):
|
|
91
|
+
left = part.split("=", maxsplit=1)[0].strip()
|
|
92
|
+
if not left:
|
|
93
|
+
continue
|
|
94
|
+
quoted = re.match(r'"([^"]+)"', left)
|
|
95
|
+
if quoted:
|
|
96
|
+
columns.append(quoted.group(1))
|
|
97
|
+
continue
|
|
98
|
+
bare = re.match(r"([a-zA-Z_][\w]*)", left)
|
|
99
|
+
if bare:
|
|
100
|
+
columns.append(bare.group(1))
|
|
101
|
+
return columns
|
app/agent/sql_quality.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Heuristic SQL quality checks before execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
# Verbs that indicate a broad browse/list request.
_BROAD_VERBS = re.compile(
    r"\b(show|list|display|get|fetch|give|view|browse|see)\b",
    re.IGNORECASE,
)
# Wording that suggests the user named specific output fields or aggregates,
# which exempts the prompt from the "broad request" heuristic.
_SPECIFIC_FIELD_HINTS = re.compile(
    r"\b(only|just|columns?|fields?|select)\s+[\w,\s]+|"
    r"\b(name|price|email|count|sum|avg|min|max|total)\b",
    re.IGNORECASE,
)
# A SELECT list naming at least this many columns is treated as over-selecting
# for a broad request.
_EXPLICIT_COLUMN_THRESHOLD = 8
# An exact single-quoted string equality, e.g. = 'Acme Corp'.
_EXACT_LABEL_MATCH = re.compile(
    r"""=\s*'[^']+'""",
    re.IGNORECASE,
)
_JOIN_PATTERN = re.compile(r"\bJOIN\b", re.IGNORECASE)
# Column names that typically hold human-entered labels prone to
# casing/spacing mismatches.
_LABEL_COLUMN_NAMES = frozenset({"name", "title", "label", "slug", "code"})
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def detect_quality_issue(user_prompt: str, sql: str) -> str | None:
    """
    Check a generated statement against browse/lookup heuristics.

    Returns a human-readable description of the first issue found, or None
    when the SQL looks acceptable. Only SELECT statements are ever flagged.
    """
    statement = sql.strip()
    if not statement.upper().startswith("SELECT"):
        return None

    if is_broad_list_request(user_prompt) and _over_selects_for_broad_request(statement):
        return (
            "Query lists too many explicit columns for a broad show/list request. "
            "Use SELECT * FROM the target table with a reasonable LIMIT instead."
        )

    if _uses_exact_label_match_on_join(statement):
        return (
            "Query uses an exact string match on a related entity label. "
            "Do not assume user text matches stored casing/spacing. "
            "Resolve the related row with case-insensitive LOWER(...) LIKE LOWER(...) "
            "on likely label columns, then filter the target table by foreign key "
            "using a subquery or IN (...)."
        )

    return None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def is_broad_list_request(user_prompt: str) -> bool:
    """Detect prompts like "show me the orders" that name no specific output fields."""
    request = user_prompt.strip()
    has_broad_verb = bool(request) and _BROAD_VERBS.search(request) is not None
    return has_broad_verb and _SPECIFIC_FIELD_HINTS.search(request) is None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def count_explicit_select_columns(sql: str) -> int:
    """Number of top-level items in the SELECT list; 0 for SELECT * or no match."""
    header = re.search(
        r"^\s*SELECT\s+(DISTINCT\s+)?(.*?)\s+FROM\b",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if header is None:
        return 0

    select_list = header.group(2).strip()
    if select_list in ("", "*"):
        return 0

    # Count items separated by commas at parenthesis depth zero only, so
    # function-call arguments like f(a, b) stay a single item.
    count = 0
    depth = 0
    item = ""
    for ch in select_list:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(0, depth - 1)
        elif ch == "," and depth == 0:
            if item.strip():
                count += 1
            item = ""
            continue
        item += ch
    if item.strip():
        count += 1
    return count
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _over_selects_for_broad_request(sql: str) -> bool:
    """A broad request should be SELECT *; flag long explicit column lists."""
    is_select_star = re.search(r"^\s*SELECT\s+\*\s+FROM\b", sql, re.IGNORECASE) is not None
    if is_select_star:
        return False
    return count_explicit_select_columns(sql) >= _EXPLICIT_COLUMN_THRESHOLD
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _uses_exact_label_match_on_join(sql: str) -> bool:
    """True when a JOIN query filters a label-like column with an exact '...' match."""
    if _JOIN_PATTERN.search(sql) is None or _EXACT_LABEL_MATCH.search(sql) is None:
        return False

    # Isolate the WHERE clause (up to ORDER/GROUP/LIMIT or end of statement).
    where_match = re.search(
        r"\bWHERE\b(.+?)(?:\bORDER\b|\bGROUP\b|\bLIMIT\b|$)", sql, re.IGNORECASE | re.DOTALL
    )
    if where_match is None:
        return False

    predicate = where_match.group(1)
    return any(
        re.search(rf'["\w.]*{column}["\w.]*\s*=\s*\'[^\']+\'', predicate, re.IGNORECASE)
        for column in _LABEL_COLUMN_NAMES
    )
|