pgnode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
app/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Application package."""
app/agent/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Agent orchestration: NL to validated SQL execution."""
app/agent/agent.py ADDED
@@ -0,0 +1,216 @@
1
+ """MVP agent: schema -> LLM SQL -> validate -> execute."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ from app.agent.intent_agent import run_intent_agent
9
+ from app.agent.sql_metadata import extract_sql_metadata
10
+ from app.agent.sql_quality import detect_quality_issue
11
+ from app.agent.sql_quoter import quote_schema_identifiers
12
+ from app.agent.validator import validate_sql
13
+ from app.db.query_executor import run_query
14
+ from app.db.schema_loader import get_full_schema
15
+ from app.llm.prompt_builder import strip_sql_fences
16
+ from app.llm.sql_generator import generate_sql, repair_sql
17
+
18
+ _MAX_SQL_RETRIES = 1
19
+
20
+
21
@dataclass
class AgentResult:
    """Outcome of a single agent turn.

    Bundles the prompt, the SQL that was produced (if any), error messages
    from validation or the LLM, the executor's return value, and best-effort
    metadata describing the SQL statement.
    """

    # The user's request exactly as it was passed to the agent.
    user_prompt: str
    # SQL text associated with this turn; None when no SQL was produced.
    sql: str | None
    # Message explaining why the turn stopped before execution, if it did.
    validation_error: str | None
    # Value returned by the query executor, or a plain-text reply; None when
    # nothing was executed.
    execution_result: list[dict[str, Any]] | str | None
    # Message from a failed LLM call, if any.
    llm_error: str | None = None
    # Best-effort statement metadata (all optional).
    operation: str | None = None
    table: str | None = None
    changed_columns: list[str] | None = None
    where_clause: str | None = None
    # True when the intent step classified the prompt as out of scope.
    out_of_scope: bool = False
35
+
36
+
37
def run_agent(
    user_prompt: str,
    *,
    dry_run: bool = False,
    model: str | None = None,
    conversation_context: str | None = None,
    forced_sql: str | None = None,
) -> AgentResult:
    """
    Load schema, classify intent, obtain SQL, validate it, and run it.

    Steps, in order:

    1. Load the schema; a ``RuntimeError`` is reported as ``validation_error``.
    2. Classify the prompt; an ``out_of_scope`` intent returns the classifier's
       response as ``execution_result`` without generating SQL.
    3. Use ``forced_sql`` when given, otherwise generate SQL with the LLM and
       give it one quality-focused repair pass.
    4. Validate the SQL and hand it to ``run_query``.
    5. If execution reports ``"Query execution failed:"`` (and this is not a
       dry run), ask the LLM for a repaired statement up to
       ``_MAX_SQL_RETRIES`` times and re-run it.

    Args:
        user_prompt: Natural-language request from the user.
        dry_run: Forwarded to ``run_query``; also disables the repair loop.
        model: Optional model name forwarded to every LLM call.
        conversation_context: Optional prior-turn context forwarded to the LLM.
        forced_sql: SQL to use instead of calling the generator.

    Returns:
        An ``AgentResult`` describing the turn. ``RuntimeError`` raised by the
        schema loader or the LLM helpers is captured in the result rather than
        propagated.
    """

    def _no_sql_result(
        *,
        validation_error: str | None = None,
        llm_error: str | None = None,
    ) -> AgentResult:
        # Failure that happened before any SQL text existed.
        return AgentResult(
            user_prompt=user_prompt,
            sql=None,
            validation_error=validation_error,
            execution_result=None,
            llm_error=llm_error,
        )

    def _sql_result(
        sql_text: str,
        meta: Any,
        *,
        validation_error: str | None = None,
        execution_result: list[dict[str, Any]] | str | None = None,
        llm_error: str | None = None,
    ) -> AgentResult:
        # Result for a turn that has SQL; metadata must describe `sql_text`.
        return AgentResult(
            user_prompt=user_prompt,
            sql=sql_text,
            validation_error=validation_error,
            execution_result=execution_result,
            llm_error=llm_error,
            operation=meta.operation,
            table=meta.table,
            changed_columns=meta.changed_columns,
            where_clause=meta.where_clause,
        )

    try:
        schema = get_full_schema()
    except RuntimeError as exc:
        return _no_sql_result(validation_error=str(exc))

    try:
        intent = run_intent_agent(
            user_prompt,
            schema,
            model=model,
            conversation_context=conversation_context,
        )
    except RuntimeError as exc:
        return _no_sql_result(llm_error=str(exc))

    if intent.intent == "out_of_scope":
        return AgentResult(
            user_prompt=user_prompt,
            sql=None,
            validation_error=None,
            execution_result=intent.response,
            out_of_scope=True,
        )

    if forced_sql is not None:
        sql = quote_schema_identifiers(strip_sql_fences(forced_sql).strip(), schema)
    else:
        try:
            raw_sql = generate_sql(
                user_prompt,
                schema,
                model=model,
                conversation_context=conversation_context,
            )
        except RuntimeError as exc:
            return _no_sql_result(llm_error=str(exc))

        sql = quote_schema_identifiers(strip_sql_fences(raw_sql).strip(), schema)
        sql = _maybe_repair_sql_quality(
            user_prompt=user_prompt,
            sql=sql,
            schema=schema,
            model=model,
            conversation_context=conversation_context,
        )
    metadata = extract_sql_metadata(sql)

    validation_error = validate_sql(sql)
    if validation_error is not None:
        return _sql_result(sql, metadata, validation_error=validation_error)

    execution_result = run_query(sql, dry_run=dry_run)
    attempts = 0
    # The executor signals failure with a string carrying this prefix.
    while (
        not dry_run
        and isinstance(execution_result, str)
        and execution_result.startswith("Query execution failed:")
        and attempts < _MAX_SQL_RETRIES
    ):
        attempts += 1
        try:
            repaired_sql = repair_sql(
                user_request=user_prompt,
                schema=schema,
                previous_sql=sql,
                db_error=execution_result,
                model=model,
                conversation_context=conversation_context,
            ).strip()
        except RuntimeError as exc:
            return _sql_result(
                sql,
                metadata,
                execution_result=execution_result,
                llm_error=f"Repair attempt failed: {exc}",
            )

        repair_validation_error = validate_sql(repaired_sql)
        if repair_validation_error is not None:
            # Report metadata for the statement actually returned in `sql`
            # (the rejected repair), not for the statement it replaced.
            return _sql_result(
                repaired_sql,
                extract_sql_metadata(repaired_sql),
                validation_error=repair_validation_error,
                execution_result=execution_result,
            )

        sql = quote_schema_identifiers(repaired_sql, schema)
        metadata = extract_sql_metadata(sql)
        execution_result = run_query(sql, dry_run=False)

    return _sql_result(sql, metadata, execution_result=execution_result)
186
+
187
+
188
def _maybe_repair_sql_quality(
    *,
    user_prompt: str,
    sql: str,
    schema: dict[str, list[dict[str, str]]],
    model: str | None,
    conversation_context: str | None,
) -> str:
    """Give generated SQL a single quality-driven repair pass.

    When ``detect_quality_issue`` flags the SQL, the LLM is asked once for a
    replacement. The replacement is used only if the call succeeds and the
    new SQL validates; in every other case the original ``sql`` is returned.
    """
    issue = detect_quality_issue(user_prompt, sql)
    if issue is not None:
        try:
            candidate = repair_sql(
                user_request=user_prompt,
                schema=schema,
                previous_sql=sql,
                db_error=f"Quality issue: {issue}",
                model=model,
                conversation_context=conversation_context,
            ).strip()
        except RuntimeError:
            # Best effort: a failed repair call keeps the original SQL.
            candidate = None

        if candidate is not None and validate_sql(candidate) is None:
            return quote_schema_identifiers(candidate, schema)

    return sql
app/agent/followup.py ADDED
@@ -0,0 +1,127 @@
1
+ """Follow-up prompt grounding after writes and other contextual turns."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+ from app.agent.memory import ConversationMemory, MemoryTurn
8
+ from app.agent.sql_metadata import extract_sql_metadata
9
+
10
# Verbs indicating the user wants to read or display data.
_SHOW_VERBS = re.compile(
    r"\b(show|display|list|see|view|fetch|get)\b",
    re.IGNORECASE,
)
# Words hinting that the request refers to the outcome of an earlier write.
_AFTER_WRITE_HINTS = re.compile(
    r"\b(updated|update|modified|changes?|changed|result|after)\b",
    re.IGNORECASE,
)
# Words and phrases that point back at something from a previous turn.
_ANAPHORA = re.compile(
    r"\b(that|it|same|previous|those|the change|the update)\b",
    re.IGNORECASE,
)
# Regex fragments for a double-quoted and a bare SQL identifier.
# NOTE(review): not referenced in the visible part of this module - confirm
# whether they are still needed.
_QUOTED_IDENT = r'"([^"]+)"'
_BARE_IDENT = r"([a-zA-Z_][\w]*)"
24
+
25
+
26
def is_show_after_write_prompt(text: str) -> bool:
    """Heuristically decide whether the prompt asks to read data back after a write.

    The prompt must contain a show-style verb. It then qualifies if it also
    contains an after-write hint word, or mentions "table", "rows" or
    "record" anywhere in the lower-cased text.
    """
    lowered = text.lower()
    if _SHOW_VERBS.search(lowered) is None:
        return False
    if _AFTER_WRITE_HINTS.search(lowered) is not None:
        return True
    # Plain substring checks, so "records" and "tables" also qualify.
    return any(word in lowered for word in ("table", "rows", "record"))
34
+
35
+
36
def try_build_show_after_write_sql(
    user_prompt: str,
    memory: ConversationMemory | None,
) -> str | None:
    """
    Produce a SELECT that reads back what the latest write touched.

    Returns None when there is no memory, the prompt is not a
    show-after-write request, no recent write turn with SQL exists, or no
    table name can be extracted from that SQL.
    """
    if memory is None:
        return None
    if not is_show_after_write_prompt(user_prompt):
        return None

    write_turn = _find_last_write_turn(memory)
    if write_turn is None or not write_turn.sql:
        return None

    parsed = extract_sql_metadata(write_turn.sql)
    target = parsed.table or _table_from_sql(write_turn.sql)
    if not target:
        return None

    # Reuse the write's own filter, when it had one, to show the same rows.
    query = f"SELECT * FROM {_quote_table_ref(target)}"
    if parsed.where_clause:
        query += f" WHERE {parsed.where_clause}"
    return query + " LIMIT 10;"
61
+
62
+
63
def ground_followup_prompt(text: str, memory: ConversationMemory | None) -> str:
    """Append context from recent turns to a vague follow-up prompt.

    Two enrichments are attempted, in order:

    * a show-after-write prompt with a recent write turn gets that write's
      operation, table, columns and filter plus a suggested SELECT;
    * otherwise, a prompt containing an anaphoric word gets a shorter
      summary of the most recent UPDATE/INSERT turn.

    The prompt is returned unchanged when neither applies or when ``memory``
    is None.
    """
    if memory is None:
        return text

    write_turn = _find_last_write_turn(memory) if is_show_after_write_prompt(text) else None
    if write_turn is not None:
        write_sql = write_turn.sql or ""
        parsed = extract_sql_metadata(write_sql)
        table = parsed.table or _table_from_sql(write_sql) or "unknown"
        where = parsed.where_clause or ""
        cols = ", ".join(write_turn.changed_columns) if write_turn.changed_columns else "*"

        suggestion = f"Prefer: SELECT * FROM {_quote_table_ref(table)}"
        if where:
            suggestion += f" WHERE {where}"
        suggestion += " LIMIT 10;"

        return "".join(
            [
                f"{text}\n\n",
                "Context from previous turn:\n",
                f"- Previous operation: {write_turn.operation}\n",
                f"- Table: {table}\n",
                f"- Changed columns: {cols}\n",
                f"- Previous filter: {where or '<none>'}\n",
                "The user wants to VIEW current rows after that write.\n",
                "MANDATORY: Generate a SELECT query only (NOT UPDATE/INSERT/DELETE).\n",
                suggestion,
            ]
        )

    if _ANAPHORA.search(text) is None:
        return text

    # Newest turn first: the first usable write supplies the context.
    for turn in reversed(memory.recent(10)):
        if turn.operation not in {"UPDATE", "INSERT"}:
            continue
        if not (turn.table or turn.sql):
            continue

        parsed = extract_sql_metadata(turn.sql or "")
        table = parsed.table or turn.table or "unknown"
        cols = ", ".join(turn.changed_columns) if turn.changed_columns else "*"
        where = parsed.where_clause or turn.where_clause or ""
        return "".join(
            [
                f"{text}\n",
                "Context from previous turn:\n",
                f"- Last {turn.operation} table: {table}\n",
                f"- Changed columns: {cols}\n",
                f"- Filter: {where or '<none>'}\n",
                "If user asks to show the change, produce SELECT for that table/row.",
            ]
        )

    return text
108
+
109
+
110
def _find_last_write_turn(memory: ConversationMemory) -> MemoryTurn | None:
    """Return the newest of the last ten turns whose operation is UPDATE or INSERT, else None."""
    newest_first = reversed(memory.recent(10))
    writes = (turn for turn in newest_first if turn.operation in {"UPDATE", "INSERT"})
    return next(writes, None)
115
+
116
+
117
def _table_from_sql(sql: str) -> str | None:
    """Return the table name that ``extract_sql_metadata`` finds in ``sql``, if any."""
    return extract_sql_metadata(sql).table
120
+
121
+
122
def _quote_table_ref(table: str) -> str:
    """Return ``table`` in a form that is safe to splice into a SQL statement.

    * Names already wrapped in double quotes are returned untouched.
    * Lower-case identifiers made of letters, digits, ``_`` and ``$``
      (optionally dot-qualified, e.g. ``public.users``) are returned bare.
    * Anything else (upper-case letters, spaces, hyphens, a leading digit)
      is wrapped in double quotes, with embedded quotes doubled.

    The metadata extractor strips quotes from names such as ``"order items"``,
    so a lower-case name may still need quoting when it is re-emitted here.
    An empty string is returned unchanged.
    """
    if not table:
        return table
    # Require at least two characters so a lone '"' is not taken as quoted.
    if len(table) >= 2 and table.startswith('"') and table.endswith('"'):
        return table
    if re.fullmatch(r"[a-z_][a-z0-9_$]*(?:\.[a-z_][a-z0-9_$]*)*", table):
        return table
    return '"' + table.replace('"', '""') + '"'
@@ -0,0 +1,21 @@
1
+ """Intent Agent wrapper: classify user input before SQL generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from app.llm.intent_classifier import IntentClassification, classify_intent
6
+
7
+
8
def run_intent_agent(
    user_prompt: str,
    schema: dict[str, list[dict[str, str]]],
    *,
    model: str | None = None,
    conversation_context: str | None = None,
) -> IntentClassification:
    """Run intent classification on the prompt ahead of SQL generation.

    Thin wrapper: every argument is passed straight to ``classify_intent``
    and its result is returned as-is.
    """
    classification = classify_intent(
        user_prompt,
        schema,
        model=model,
        conversation_context=conversation_context,
    )
    return classification
app/agent/memory.py ADDED
@@ -0,0 +1,61 @@
1
+ """Ephemeral in-process conversation memory for interactive CLI sessions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+
8
@dataclass
class MemoryTurn:
    """Single turn retained in in-memory conversation context.

    All fields are required; optional details are passed as None (or an
    empty list for ``changed_columns``).
    """

    # The user's message for this turn.
    user_prompt: str
    # SQL associated with the turn, if any.
    sql: str | None
    # Validation message recorded for the turn, if any.
    validation_error: str | None
    # Short text describing the execution outcome, if any.
    execution_summary: str | None
    # Statement keyword such as "UPDATE" or "INSERT", if known.
    operation: str | None
    # Table the statement targeted, if known.
    table: str | None
    # Columns changed by the statement; empty when none are known.
    changed_columns: list[str]
    # Filter text of the statement, if any.
    where_clause: str | None
20
+
21
+
22
class ConversationMemory:
    """In-process list of conversation turns, oldest first."""

    def __init__(self) -> None:
        self._turns: list[MemoryTurn] = []

    def append(self, turn: MemoryTurn) -> None:
        """Record ``turn`` as the newest entry."""
        self._turns.append(turn)

    def clear(self) -> None:
        """Forget every stored turn."""
        self._turns.clear()

    def recent(self, limit: int) -> list[MemoryTurn]:
        """Return up to ``limit`` newest turns, oldest first; empty if ``limit`` <= 0."""
        return self._turns[-limit:] if limit > 0 else []

    def format_for_prompt(self, limit: int) -> str:
        """
        Render the newest ``limit`` turns as plain-text lines for the SQL prompt.

        Each turn contributes a "Turn N:" header and the user text, followed
        by one line for every optional detail that is set.
        """
        lines: list[str] = []
        for number, turn in enumerate(self.recent(limit), start=1):
            lines += [f"Turn {number}:", f"- User: {turn.user_prompt}"]
            if turn.sql:
                lines.append(f"- SQL: {turn.sql}")
            if turn.operation:
                lines.append(f"- Operation: {turn.operation}")
            if turn.table:
                lines.append(f"- Table: {turn.table}")
            if turn.changed_columns:
                lines.append(f"- Changed columns: {', '.join(turn.changed_columns)}")
            if turn.where_clause:
                lines.append(f"- Where: {turn.where_clause}")
            if turn.validation_error:
                lines.append(f"- Validation: {turn.validation_error}")
            if turn.execution_summary:
                lines.append(f"- Result: {turn.execution_summary}")
        return "\n".join(lines)
@@ -0,0 +1,101 @@
1
+ """Lightweight SQL metadata extraction for conversational follow-ups."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from dataclasses import dataclass
7
+
8
# Regex fragment matching one table name: either a double-quoted identifier
# (captured in group 1, without the quotes) or a bare identifier (group 2).
_TABLE_PATTERN = r'(?:"([^"]+)"|([a-zA-Z_][\w]*))'
9
+
10
+
11
@dataclass
class SqlMetadata:
    """Best-effort structured details from generated SQL."""

    # First keyword of the statement, upper-cased; None for empty SQL.
    operation: str | None
    # Table name without surrounding quotes; None when not recognised.
    table: str | None
    # Column names assigned in an UPDATE's SET list; empty otherwise.
    changed_columns: list[str]
    # Text following WHERE, if a WHERE clause was found.
    where_clause: str | None
19
+
20
+
21
def first_keyword(sql: str) -> str:
    """Return the leading whitespace-delimited token of ``sql`` in upper case.

    Empty or whitespace-only input yields an empty string.
    """
    tokens = sql.split(maxsplit=1)
    return tokens[0].upper() if tokens else ""
27
+
28
+
29
def is_write_operation(sql: str) -> bool:
    """Report whether the statement's first keyword is INSERT or UPDATE."""
    keyword = first_keyword(sql)
    return keyword == "INSERT" or keyword == "UPDATE"
32
+
33
+
34
def extract_sql_metadata(sql: str) -> SqlMetadata:
    """
    Best-effort extraction for SELECT/INSERT/UPDATE.

    This is intentionally simple regex parsing for MVP behavior.

    * UPDATE: table, SET column names, and the text after the first WHERE.
    * INSERT: table only.
    * SELECT: first table after FROM, and everything after the first WHERE
      (which therefore includes any trailing ORDER BY / LIMIT text).

    Any other statement yields only ``operation``. Patterns are applied with
    ``re.DOTALL`` so that SQL spread over several lines is parsed the same
    way as the single-line form.
    """
    operation = first_keyword(sql) or None
    table: str | None = None
    changed_columns: list[str] = []
    where_clause: str | None = None
    cleaned = sql.strip().rstrip(";")
    # Without DOTALL, `(.*)$` cannot cross a newline and multi-line
    # statements would produce no metadata at all.
    flags = re.IGNORECASE | re.DOTALL

    if operation == "UPDATE":
        match = re.search(
            rf"^\s*UPDATE\s+{_TABLE_PATTERN}\s+SET\s+(.*)$",
            cleaned,
            flags,
        )
        if match:
            table = match.group(1) or match.group(2)
            remainder = match.group(3)
            # Everything before the first WHERE is the SET list.
            where_split = re.split(r"\bWHERE\b", remainder, maxsplit=1, flags=re.IGNORECASE)
            set_expr = where_split[0]
            if len(where_split) > 1:
                where_clause = where_split[1].strip()
            changed_columns = _extract_set_columns(set_expr)
    elif operation == "INSERT":
        match = re.search(
            rf"^\s*INSERT\s+INTO\s+{_TABLE_PATTERN}",
            cleaned,
            flags,
        )
        if match:
            table = match.group(1) or match.group(2)
    elif operation == "SELECT":
        match = re.search(
            rf"\bFROM\s+{_TABLE_PATTERN}",
            cleaned,
            flags,
        )
        if match:
            table = match.group(1) or match.group(2)
        where_match = re.search(r"\bWHERE\b\s+(.*)$", cleaned, flags)
        if where_match:
            where_clause = where_match.group(1).strip()

    return SqlMetadata(
        operation=operation,
        table=table,
        changed_columns=changed_columns,
        where_clause=where_clause,
    )
86
+
87
+
88
def _extract_set_columns(set_expr: str) -> list[str]:
    """Return the column names assigned in an UPDATE ``SET`` list.

    The list is split on commas that sit outside parentheses and outside
    quoted text, so ``a = concat(x, y)`` or ``note = 'a, b'`` count as one
    assignment. Segments without ``=`` are ignored. A parenthesised target
    such as ``(a, b) = (1, 2)`` contributes every listed column. Quoted
    names are returned without their quotes.
    """
    # Split into assignments at top-level commas only.
    assignments: list[str] = []
    buffer: list[str] = []
    depth = 0
    open_quote: str | None = None
    for char in set_expr:
        if open_quote is not None:
            if char == open_quote:
                open_quote = None
        elif char in ("'", '"'):
            open_quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth = max(0, depth - 1)
        elif char == "," and depth == 0:
            assignments.append("".join(buffer))
            buffer = []
            continue
        buffer.append(char)
    assignments.append("".join(buffer))

    columns: list[str] = []
    for assignment in assignments:
        target, equals, _value = assignment.partition("=")
        target = target.strip()
        if not equals or not target:
            continue
        # "(a, b) = (...)" assigns several columns at once.
        names = target.strip("()").split(",") if target.startswith("(") else [target]
        for name in names:
            name = name.strip()
            quoted = re.match(r'"([^"]+)"', name)
            if quoted:
                columns.append(quoted.group(1))
                continue
            bare = re.match(r"([a-zA-Z_][\w]*)", name)
            if bare:
                columns.append(bare.group(1))
    return columns
@@ -0,0 +1,117 @@
1
+ """Heuristic SQL quality checks before execution."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
# Verbs that mark a prompt as a show/list style request.
_BROAD_VERBS = re.compile(
    r"\b(show|list|display|get|fetch|give|view|browse|see)\b",
    re.IGNORECASE,
)
# Wording that suggests the user named specific output fields or an
# aggregate; such prompts are not treated as broad list requests.
_SPECIFIC_FIELD_HINTS = re.compile(
    r"\b(only|just|columns?|fields?|select)\s+[\w,\s]+|"
    r"\b(name|price|email|count|sum|avg|min|max|total)\b",
    re.IGNORECASE,
)
# A SELECT list with at least this many explicit columns counts as
# over-selecting for a broad request.
_EXPLICIT_COLUMN_THRESHOLD = 8
# An equality comparison against a non-empty single-quoted literal.
_EXACT_LABEL_MATCH = re.compile(
    r"""=\s*'[^']+'""",
    re.IGNORECASE,
)
# Presence of the JOIN keyword anywhere in the statement.
_JOIN_PATTERN = re.compile(r"\bJOIN\b", re.IGNORECASE)
# Column-name fragments treated as human-readable label columns.
_LABEL_COLUMN_NAMES = frozenset({"name", "title", "label", "slug", "code"})
23
+
24
+
25
def detect_quality_issue(user_prompt: str, sql: str) -> str | None:
    """
    Describe the first heuristic quality problem found in ``sql``, if any.

    Only statements starting with SELECT are inspected. Two checks run in
    order: too many explicit columns for a broad show/list prompt, then an
    exact string match on a label column in a query that uses JOIN. Returns
    None when neither check fires.
    """
    statement = sql.strip()
    if statement.upper().startswith("SELECT"):
        if is_broad_list_request(user_prompt) and _over_selects_for_broad_request(statement):
            return (
                "Query lists too many explicit columns for a broad show/list request. "
                "Use SELECT * FROM the target table with a reasonable LIMIT instead."
            )

        if _uses_exact_label_match_on_join(statement):
            return (
                "Query uses an exact string match on a related entity label. "
                "Do not assume user text matches stored casing/spacing. "
                "Resolve the related row with case-insensitive LOWER(...) LIKE LOWER(...) "
                "on likely label columns, then filter the target table by foreign key "
                "using a subquery or IN (...)."
            )

    return None
49
+
50
+
51
def is_broad_list_request(user_prompt: str) -> bool:
    """Report whether the prompt is a show/list request that names no specific fields.

    False for blank prompts and for prompts without a show-style verb;
    otherwise True unless a specific-field hint is present.
    """
    request = user_prompt.strip()
    has_verb = bool(request) and _BROAD_VERBS.search(request) is not None
    return has_verb and _SPECIFIC_FIELD_HINTS.search(request) is None
59
+
60
+
61
def count_explicit_select_columns(sql: str) -> int:
    """Count columns in the SELECT list; returns 0 for SELECT *.

    The SELECT list runs from after ``SELECT [DISTINCT]`` to the first
    ``FROM`` that is outside parentheses and outside quoted text, so
    ``EXTRACT(YEAR FROM created_at)`` or a scalar subquery does not end the
    list early. Items are separated by top-level commas; commas inside
    function calls, subqueries or string literals do not count.

    Returns 0 when the statement does not start with SELECT, has no
    top-level FROM, or selects ``*``.
    """
    head = re.match(r"\s*SELECT\s+(?:DISTINCT\s+)?", sql, re.IGNORECASE)
    if head is None:
        return 0
    body = sql[head.end():]

    # Quoted text is matched as a single token so its contents are skipped.
    token = re.compile(
        r"""'(?:[^']|'')*'|"[^"]*"|\(|\)|,|\bFROM\b""",
        re.IGNORECASE,
    )
    depth = 0
    item_start = 0
    items: list[str] = []
    list_end: int | None = None
    for found in token.finditer(body):
        text = found.group()
        if text == "(":
            depth += 1
        elif text == ")":
            depth = max(0, depth - 1)
        elif depth == 0 and text == ",":
            items.append(body[item_start:found.start()])
            item_start = found.end()
        elif depth == 0 and text.upper() == "FROM":
            list_end = found.start()
            break

    if list_end is None:
        return 0
    items.append(body[item_start:list_end])

    select_list = body[:list_end].strip()
    if not select_list or select_list == "*":
        return 0
    return sum(1 for item in items if item.strip())
92
+
93
+
94
def _over_selects_for_broad_request(sql: str) -> bool:
    """True when a non-``SELECT *`` query lists at least the threshold number of columns."""
    is_star = re.search(r"^\s*SELECT\s+\*\s+FROM\b", sql, re.IGNORECASE) is not None
    return (not is_star) and count_explicit_select_columns(sql) >= _EXPLICIT_COLUMN_THRESHOLD
98
+
99
+
100
def _uses_exact_label_match_on_join(sql: str) -> bool:
    """True when a JOIN query's WHERE clause compares a label-like column to a quoted literal.

    The WHERE text is taken up to the first ORDER, GROUP or LIMIT keyword
    (or the end of the statement). A column counts as label-like when its
    name contains one of ``_LABEL_COLUMN_NAMES``.
    """
    if _JOIN_PATTERN.search(sql) is None or _EXACT_LABEL_MATCH.search(sql) is None:
        return False

    found = re.search(
        r"\bWHERE\b(.+?)(?:\bORDER\b|\bGROUP\b|\bLIMIT\b|$)", sql, re.IGNORECASE | re.DOTALL
    )
    if found is None:
        return False

    condition = found.group(1)
    return any(
        re.search(rf'["\w.]*{column}["\w.]*\s*=\s*\'[^\']+\'', condition, re.IGNORECASE)
        for column in _LABEL_COLUMN_NAMES
    )