nl2sql-agents 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nl2sql_agents/__init__.py +9 -0
- nl2sql_agents/agents/__init__.py +0 -0
- nl2sql_agents/agents/base_agent.py +74 -0
- nl2sql_agents/agents/discovery/__init__.py +0 -0
- nl2sql_agents/agents/discovery/discovery_agent.py +117 -0
- nl2sql_agents/agents/discovery/fk_graph_agent.py +75 -0
- nl2sql_agents/agents/discovery/keyword_agent.py +61 -0
- nl2sql_agents/agents/discovery/semantic_agent.py +61 -0
- nl2sql_agents/agents/explainer/__init__.py +0 -0
- nl2sql_agents/agents/explainer/explainer_agent.py +45 -0
- nl2sql_agents/agents/explainer/explanation_agent.py +32 -0
- nl2sql_agents/agents/explainer/optimization_agent.py +31 -0
- nl2sql_agents/agents/explainer/safety_report_agent.py +42 -0
- nl2sql_agents/agents/query_generator.py +133 -0
- nl2sql_agents/agents/schema_formatter.py +69 -0
- nl2sql_agents/agents/validator/__init__.py +0 -0
- nl2sql_agents/agents/validator/logic_validator.py +59 -0
- nl2sql_agents/agents/validator/performance_validator.py +74 -0
- nl2sql_agents/agents/validator/security_validator.py +51 -0
- nl2sql_agents/agents/validator/syntax_validator.py +74 -0
- nl2sql_agents/agents/validator/validator_agent.py +104 -0
- nl2sql_agents/cli.py +291 -0
- nl2sql_agents/config/__init__.py +0 -0
- nl2sql_agents/config/settings.py +66 -0
- nl2sql_agents/db/__init__.py +0 -0
- nl2sql_agents/db/connector.py +107 -0
- nl2sql_agents/filters/__init__.py +0 -0
- nl2sql_agents/filters/gate.py +62 -0
- nl2sql_agents/filters/security_filter.py +36 -0
- nl2sql_agents/models/__init__.py +0 -0
- nl2sql_agents/models/schemas.py +120 -0
- nl2sql_agents/orchestrator/__init__.py +0 -0
- nl2sql_agents/orchestrator/nodes.py +142 -0
- nl2sql_agents/orchestrator/pipeline.py +70 -0
- nl2sql_agents/py.typed +0 -0
- nl2sql_agents-0.1.0.dist-info/METADATA +540 -0
- nl2sql_agents-0.1.0.dist-info/RECORD +40 -0
- nl2sql_agents-0.1.0.dist-info/WHEEL +4 -0
- nl2sql_agents-0.1.0.dist-info/entry_points.txt +2 -0
- nl2sql_agents-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
nl2sql-agents — Multi-Agent Natural Language to SQL System.
|
|
3
|
+
|
|
4
|
+
A sophisticated multi-agent orchestration system that converts natural
|
|
5
|
+
language queries into safe, optimized SQL using LangGraph and OpenRouter LLMs.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Version string of the installed nl2sql_agents package.
__version__ = "0.1.0"

# Public names re-exported by ``from nl2sql_agents import *``.
__all__ = ["__version__"]
|
|
File without changes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""
|
|
2
|
+
BASE AGENT - Abstract interface for all LLM agents.
|
|
3
|
+
|
|
4
|
+
Handles:
|
|
5
|
+
- Async LLM calls via ChatOpenAI (langchain-openai)
|
|
6
|
+
- Provider-aware: each agent receives an LLM provider
|
|
7
|
+
- token usage logging
|
|
8
|
+
|
|
9
|
+
Each agent implements:
|
|
10
|
+
- build_prompt()
|
|
11
|
+
- parse_response()
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
from typing import Any
|
|
17
|
+
from langchain_openai import ChatOpenAI
|
|
18
|
+
from nl2sql_agents.config.settings import LLMProvider, PRIMARY_PROVIDER
|
|
19
|
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)

# OpenAI-style chat roles -> LangChain message classes. Built once at import
# time instead of on every call. Roles not listed here fall back to HumanMessage.
_ROLE_TO_MESSAGE: dict[str, type[BaseMessage]] = {
    "system": SystemMessage,
    "user": HumanMessage,
    # Previously missing: assistant turns (e.g. few-shot examples or prior
    # model replies) were silently re-labelled as HumanMessage.
    "assistant": AIMessage,
}


def _to_langchain_messages(messages: list[dict[str, str]]) -> list[BaseMessage]:
    """Convert OpenAI-style ``{"role": ..., "content": ...}`` dicts to LangChain messages.

    ``system`` -> SystemMessage, ``user`` -> HumanMessage and
    ``assistant`` -> AIMessage. Any other role is sent as a HumanMessage.

    Args:
        messages: Chat messages, each with ``"role"`` and ``"content"`` keys.

    Returns:
        LangChain message objects in the same order as ``messages``.

    Raises:
        KeyError: If a message lacks ``"role"`` or ``"content"``.
    """
    return [
        _ROLE_TO_MESSAGE.get(m["role"], HumanMessage)(content=m["content"])
        for m in messages
    ]
|
|
26
|
+
|
|
27
|
+
class BaseAgent(ABC):
    """Abstract base class for agents that make a single chat-completion call.

    Subclasses implement :meth:`build_prompt`, which returns OpenAI-style message
    dicts, and :meth:`parse_response`, which turns the completion text into the
    agent's result. :meth:`execute` chains build_prompt -> call_llm -> parse_response.
    """

    def __init__(self, provider: LLMProvider = PRIMARY_PROVIDER) -> None:
        """Bind the agent to ``provider`` (the configured primary provider by default)."""
        self.provider = provider
        self.model_name = provider.default_model

    @abstractmethod
    def build_prompt(self, *args, **kwargs) -> list[dict[str, str]]:
        """Return the chat messages (``{"role": ..., "content": ...}``) to send to the LLM."""
        raise NotImplementedError("Abstract Method build_prompt not implemented")

    @abstractmethod
    def parse_response(self, raw: str) -> Any:
        """Convert the stripped completion text into this agent's result."""
        # Message fixed: it previously named a non-existent "ParseResponse" method.
        raise NotImplementedError("Abstract Method parse_response not implemented")

    def _get_llm(self, temperature: float = 0.3, max_tokens: int = 2048) -> ChatOpenAI:
        """Ask the provider for a chat model configured with these sampling settings."""
        return self.provider.chat_model(temperature=temperature, max_tokens=max_tokens)

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a LangChain message ``content`` value into plain text.

        ``content`` is normally a ``str``. LangChain also allows a list of content
        parts, either strings or dicts such as ``{"type": "text", "text": ...}``.
        Calling ``.strip()`` on such a list would raise AttributeError. Non-text
        parts are dropped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces: list[str] = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces)
        return str(content)

    async def call_llm(
        self,
        messages: list[dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 2048,
    ) -> str:
        """Send ``messages`` to the provider's chat model and return the reply text.

        Args:
            messages: OpenAI-style role/content dicts.
            temperature: Sampling temperature passed to the chat model.
            max_tokens: Completion token limit passed to the chat model.

        Returns:
            The completion text with surrounding whitespace stripped.
        """
        logger.debug(
            "%s -> LLM (model=%s, temp=%.1f, msgs=%d)",
            self.__class__.__name__, self.model_name, temperature, len(messages),
        )

        llm = self._get_llm(temperature=temperature, max_tokens=max_tokens)
        response = await llm.ainvoke(_to_langchain_messages(messages))

        # usage_metadata may be absent or None, depending on the provider.
        usage = getattr(response, "usage_metadata", None)
        if usage:
            logger.debug("%s ← LLM (prompt=%d, completion=%d tokens)",
                         self.__class__.__name__,
                         usage.get("input_tokens", 0),
                         usage.get("output_tokens", 0),
                         )

        return self._content_to_text(response.content).strip()

    async def execute(self, *args, **kwargs) -> Any:
        """Build the prompt, call the LLM once and parse the reply.

        ``temperature`` and ``max_tokens`` may be given as keyword arguments.
        Defaults are 0.3 and 2048. Like every other argument, they are also
        forwarded to :meth:`build_prompt`.
        """
        messages = self.build_prompt(*args, **kwargs)
        raw = await self.call_llm(
            messages,
            temperature=kwargs.get("temperature", 0.3),
            max_tokens=kwargs.get("max_tokens", 2048),
        )
        return self.parse_response(raw)
|
|
File without changes
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DISCOVERY AGENT
|
|
3
|
+
|
|
4
|
+
1. Keyword pre-filter (no llm)
|
|
5
|
+
- runs KeywordAgent on all tables
|
|
6
|
+
- Takes top KEYWORD_PRE_FILTER_TOP_N by keyword score
|
|
7
|
+
2. PARALLEL - Semantic + FK (on filtered tables only.)
|
|
8
|
+
- runs SemanticAgent + FKGraphAgent in parallel on filtered tables
|
|
9
|
+
- Merges all 3 scores with configurable weights
|
|
10
|
+
- returns full ranked list
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import asyncio
|
|
15
|
+
from collections import defaultdict
|
|
16
|
+
|
|
17
|
+
from .keyword_agent import KeywordAgent
|
|
18
|
+
from .fk_graph_agent import FKGraphAgent
|
|
19
|
+
from .semantic_agent import SemanticAgent
|
|
20
|
+
from nl2sql_agents.config.settings import KEYWORD_PRE_FILTER_TOP_N
|
|
21
|
+
from nl2sql_agents.models.schemas import TableMetaData, DiscoveryResult, ScoredTable
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)

# Weight applied to each sub-agent's score when merging (sums to 1.0).
WEIGHTS = {
    "keywords": 0.35,
    "semantic": 0.45,
    "fk_graph": 0.20,
}


class DiscoveryAgent:
    """Rank candidate tables for a query by combining keyword, semantic and FK scores."""

    def __init__(self) -> None:
        self.keyword_agent = KeywordAgent()
        self.semantic_agent = SemanticAgent()
        self.fk_graph_agent = FKGraphAgent()

    async def run(
        self,
        tables: list[TableMetaData],
        user_query: str,
        pre_filter_n: int = KEYWORD_PRE_FILTER_TOP_N,
        top_k: int = 5,
    ) -> DiscoveryResult:
        """Score ``tables`` against ``user_query`` and return them ranked.

        Phase 1 keeps the ``pre_filter_n`` best tables by keyword score. Phase 2
        runs the semantic and FK-graph agents concurrently on those tables only.
        The scores are then merged with :data:`WEIGHTS`.

        Args:
            tables: All candidate tables.
            user_query: The natural-language question.
            pre_filter_n: Number of tables kept after the keyword pre-filter.
            top_k: Number of best tables exposed as ``top_tables`` (default 5).

        Returns:
            A DiscoveryResult with the ``top_k`` best tables and the full ranking.
        """
        # Phase 1: keyword pre-filter.
        logger.info('DiscoveryAgent: Keyword prefilter on %d tables', len(tables))

        kw_scores = await self.keyword_agent.score(tables, user_query)
        sorted_by_kw = sorted(kw_scores.items(), key=lambda x: x[1], reverse=True)
        top_n_names = {name for name, _ in sorted_by_kw[:pre_filter_n]}
        pre_filtered = [t for t in tables if t.table_name in top_n_names]

        logger.info('DiscoveryAgent phase 1: %d -> %d tables', len(tables), len(pre_filtered))

        if not pre_filtered:
            # Nothing survived the pre-filter. Skip the semantic/FK phase,
            # whose result would be an empty ranking anyway.
            return DiscoveryResult(top_tables=[], scored_tables=[])

        # Phase 2: semantic + FK graph scoring, run concurrently.
        logger.info('DiscoveryAgent phase2: Semantic + FK in parallel on %d tables', len(pre_filtered))

        sem_scores, fk_scores = await asyncio.gather(
            self.semantic_agent.score(pre_filtered, user_query),
            self.fk_graph_agent.score(pre_filtered, user_query),
        )

        merged = self._merge_and_rank(pre_filtered, kw_scores, sem_scores, fk_scores)

        logger.info('DiscoveryAgent: ranked %d tables, top-%d = %s',
                    len(merged), top_k, [s.table.table_name for s in merged[:top_k]])

        return DiscoveryResult(
            top_tables=[s.table for s in merged[:top_k]],
            scored_tables=merged,
        )

    def _merge_and_rank(
        self,
        tables: list[TableMetaData],
        kw: dict[str, float],
        sem: dict[str, float],
        fk: dict[str, float],
    ) -> list[ScoredTable]:
        """Combine the three score maps into a weighted ranking over ``tables``.

        Scores for names not in ``tables`` are ignored. ``found_by`` lists each
        source that gave a table a positive score. The list has no duplicates,
        in the order keyword, semantic, fk_graph. The result is sorted by
        descending combined score.
        """
        # Built once here; the previous code rebuilt a name set on every keyword iteration.
        table_map = {t.table_name: t for t in tables}
        totals: dict[str, float] = defaultdict(float)
        found_by: dict[str, list[str]] = defaultdict(list)

        sources = (
            ("keyword", WEIGHTS["keywords"], kw),
            ("semantic", WEIGHTS["semantic"], sem),
            ("fk_graph", WEIGHTS["fk_graph"], fk),
        )
        for label, weight, scores in sources:
            for name, score in scores.items():
                if name not in table_map:
                    continue
                totals[name] += score * weight
                if score > 0 and label not in found_by[name]:
                    found_by[name].append(label)

        ranked = [
            ScoredTable(
                table=table_map[name],
                score=round(total, 4),
                found_by=found_by[name],
            )
            for name, total in totals.items()
        ]
        ranked.sort(key=lambda s: s.score, reverse=True)
        return ranked
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sub-Agent 1c - Foreign Key Graph Agent
|
|
3
|
+
|
|
4
|
+
Builds a bi-directional foreign-key graph from table metadata.
|
|
5
|
+
Walks the graph breadth-first from seed tables and scores each table by its graph distance.
|
|
6
|
+
|
|
7
|
+
| Distance from Seed | Score | Interpretation |
|
|
8
|
+
| -------------------------- | ----- | --------------------------- |
|
|
9
|
+
| 0 (seed itself) | 1.0 | Directly mentioned in query |
|
|
10
|
+
| 1 (neighbors) | 0.5 | Directly linked via FK |
|
|
11
|
+
| 2 (neighbors of neighbors) | 0.25 | Two hops away |
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from collections import deque
|
|
17
|
+
from nl2sql_agents.models.schemas import TableMetaData
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)

# Default maximum number of FK hops from a seed table that still earns a score.
MAX_DEPTH = 2


class FKGraphAgent:
    """Score tables by foreign-key distance from the tables named in the query.

    A table ``d`` hops from its nearest seed scores ``1 / 2**d`` (1.0, 0.5,
    0.25, ...). Tables beyond ``max_depth`` hops, or unreachable ones, are
    absent from the result.
    """

    async def score(
        self,
        tables: list[TableMetaData],
        user_query: str,
        max_depth: int = MAX_DEPTH,
    ) -> dict[str, float]:
        """Return ``{table_name: score}`` for tables within ``max_depth`` FK hops of a seed."""
        graph = self._build_fk_graph(tables)
        seeds = self._find_seeds(tables, user_query)

        logger.debug('FKGraphAgent: seeds=%s', seeds)
        return self._bfs_score(seeds, graph, max_depth)

    def _build_fk_graph(self, tables: list[TableMetaData]) -> dict[str, set[str]]:
        """Build an undirected adjacency map of FK links between the given tables.

        An edge is added only when both endpoints are in ``tables``. The previous
        code checked ``reference_column`` but inserted ``reference_table``, which
        could add ``None`` or out-of-set tables as one-directional neighbours.
        """
        graph: dict[str, set[str]] = {t.table_name: set() for t in tables}

        for table in tables:
            for col in table.columns:
                target = col.reference_table
                if not (col.is_foreign_key and col.reference_column and target):
                    continue
                if target not in graph:
                    continue
                graph[table.table_name].add(target)
                graph[target].add(table.table_name)
        return graph

    def _find_seeds(self, tables: list[TableMetaData], user_query: str) -> list[str]:
        """Return tables where some name part (split on ``_``, longer than 2 chars)
        occurs as a substring of the lower-cased query."""
        query_lower = user_query.lower()
        seeds = []
        for table in tables:
            parts = table.table_name.lower().replace('_', ' ').split()
            if any(part in query_lower for part in parts if len(part) > 2):
                seeds.append(table.table_name)
        return seeds

    def _bfs_score(self, seeds: list[str], graph: dict[str, set[str]], max_depth: int) -> dict[str, float]:
        """Multi-source BFS from ``seeds``; score each reached node by ``1 / 2**depth``."""
        scores: dict[str, float] = {}
        visited: set[str] = set(seeds)
        queue: deque[tuple[str, int]] = deque((seed, 0) for seed in seeds)

        while queue:
            node, depth = queue.popleft()
            if depth > max_depth:  # only possible when max_depth < 0
                continue
            # BFS reaches each node first at its minimum distance, so no max() is needed.
            scores[node] = 1.0 / (2 ** depth)

            if depth == max_depth:
                continue  # neighbours would be out of range; don't enqueue them
            for neighbor in graph.get(node, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, depth + 1))

        return scores
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sub-Agent 1a - Keyword Agent (No LLM)
|
|
3
|
+
|
|
4
|
+
Extracts meaningful tokens from the user query and fuzzy-matches them against table and column names. Returns {table_name: score}.
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
import logging
|
|
10
|
+
from difflib import SequenceMatcher
|
|
11
|
+
from nl2sql_agents.models.schemas import TableMetaData
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)

# Filler words that never count as keywords.
STOP_WORDS = {
    "show", "me", "get", "find", "list", "give", "the", "a", "an",
    "of", "for", "in", "on", "by", "with", "from", "where", "top",
    "all", "my", "this", "that", "and", "or", "is", "are", "was",
    "how", "many", "much", "what", "which", "who",
}

# A token is a run of ASCII letters; everything else acts as a separator.
_WORD_PATTERN = re.compile(r"[a-zA-Z]+")


class KeywordAgent:
    """Score tables by fuzzy-matching query keywords against table and column names."""

    async def score(
        self,
        tables: list[TableMetaData],
        user_query: str
    ) -> dict[str, float]:
        """Return ``{table_name: score}``. Each score is the mean of the per-keyword best matches."""
        keywords = self._extract_keywords(user_query)
        logger.debug('KeywordAgent: Keywords=%s', keywords)

        results: dict[str, float] = {}
        for table in tables:
            results[table.table_name] = self._score_table(table, keywords)
        return results

    def _extract_keywords(self, query: str) -> list[str]:
        """Lower-case ``query`` and keep alphabetic tokens longer than 2 chars that are not stop words."""
        return [
            token
            for token in _WORD_PATTERN.findall(query.lower())
            if len(token) > 2 and token not in STOP_WORDS
        ]

    def _score_table(self, table: TableMetaData, keywords: list[str]) -> float:
        """Average, over ``keywords``, of the best fuzzy match against the table or column names."""
        if not keywords:
            return 0.0

        names = [table.table_name.lower(), *(col.column_name.lower() for col in table.columns)]

        total = 0.0
        for keyword in keywords:
            total += max(self._fuzzy_score(keyword, name) for name in names)
        return round(total / len(keywords), 4)

    def _fuzzy_score(self, keyword: str, target: str) -> float:
        """Return 1.0 for a substring hit, otherwise the SequenceMatcher ratio."""
        return 1.0 if keyword in target else SequenceMatcher(None, keyword, target).ratio()
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sub-Agent 1b - Semantic Agent
|
|
3
|
+
|
|
4
|
+
Embeds the user query and table descriptions using the configured embedding provider (EMBEDDING_PROVIDER).
|
|
5
|
+
Returns cosine similarity scores: {table_name: similarity}
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import numpy as np
|
|
10
|
+
from nl2sql_agents.models.schemas import TableMetaData
|
|
11
|
+
from nl2sql_agents.config.settings import EMBEDDING_PROVIDER
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)


def _cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity ``dot(a, b) / (||a|| * ||b||)`` as a float.

    Returns ``0.0`` when the product of the norms is not positive, i.e. when
    either vector has zero magnitude, because the ratio is undefined there.

    Examples (the previous inline example had the dot product and norms wrong):

    >>> round(_cosine_similarity([0.5, 0.8, 0.2, 0.1], [0.6, 0.7, 0.3, 0.2]), 4)
    0.9794
    >>> round(_cosine_similarity([0.9, 0.1, 0.0, 0.0], [0.1, 0.9, 0.8, 0.5]), 4)
    0.152
    """
    va, vb = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    # Product of the L2 norms: ||a|| * ||b||.
    norm = np.linalg.norm(va) * np.linalg.norm(vb)
    return float(np.dot(va, vb) / norm) if norm > 0 else 0.0
|
|
38
|
+
|
|
39
|
+
class SemanticAgent:
    """Score tables by embedding similarity between the query and a short table summary."""

    def __init__(self) -> None:
        self.embeddings = EMBEDDING_PROVIDER.embeddings_model()

    async def score(
        self, tables: list[TableMetaData], user_query: str
    ) -> dict[str, float]:
        """Return ``{table_name: cosine_similarity}``, rounded to 4 places.

        The query and all table summaries are embedded in one batched
        ``aembed_documents`` call. With no tables, no embedding request is made
        and ``{}`` is returned.
        """
        if not tables:
            # The previous code still embedded the query here, wasting a request.
            return {}

        logger.debug("SemanticAgent: embedding query + %d tables", len(tables))

        texts = [user_query] + [self._table_to_text(t) for t in tables]

        all_embeddings = await self.embeddings.aembed_documents(texts)

        query_emb = all_embeddings[0]
        table_embs = all_embeddings[1:]

        return {
            t.table_name: round(_cosine_similarity(query_emb, emb), 4)
            for t, emb in zip(tables, table_embs)
        }

    def _table_to_text(self, table: TableMetaData, max_columns: int = 20) -> str:
        """Summarise a table as ``"Table <name>: columns c1,c2,..."``.

        Only the first ``max_columns`` columns are listed (default 20).
        """
        col_names = ','.join(c.column_name for c in table.columns[:max_columns])
        return f"Table {table.table_name}: columns {col_names}"
|
|
File without changes
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AGENT 5 - EXPLAINER AGENT
|
|
3
|
+
|
|
4
|
+
Runs 3 output tasks concurrently:
|
|
5
|
+
- Explanation (Plain English)
|
|
6
|
+
- Safety Report (audit from validation report)
|
|
7
|
+
- Optimization Hints (LLM)
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
from nl2sql_agents.agents.explainer.explanation_agent import ExplanationAgent
|
|
14
|
+
from nl2sql_agents.agents.explainer.optimization_agent import OptimizationAgent
|
|
15
|
+
from nl2sql_agents.agents.explainer.safety_report_agent import SafetyReportAgent
|
|
16
|
+
from nl2sql_agents.models.schemas import SQLCandidate, CandidateValidationResult, ExplainerOutput
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)


class ExplainerAgent:
    """Build the user-facing output for the chosen SQL candidate."""

    def __init__(self) -> None:
        self.explanation = ExplanationAgent()
        self.safety_report = SafetyReportAgent()
        self.optimization = OptimizationAgent()

    async def explain(
        self,
        candidate: SQLCandidate,
        validation_results: list[CandidateValidationResult],
        user_query: str = "",
    ) -> ExplainerOutput:
        """Produce the explanation, safety report and optimization hints concurrently.

        Args:
            candidate: The selected SQL candidate.
            validation_results: Validation results for all candidates. The safety
                report picks the entry whose SQL matches ``candidate``.
            user_query: The original natural-language question, used by the explanation.

        Returns:
            An ExplainerOutput bundling the three results.

        Raises:
            Any exception from a sub-task. ``asyncio.gather`` is used without
            ``return_exceptions``, so the first failure propagates.
        """
        # Log typo fixed: previously "taaks".
        logger.info("ExplainerAgent: 3 output tasks in PARALLEL")

        explanation, safety, hints = await asyncio.gather(
            self.explanation.run(candidate.sql, user_query),
            self.safety_report.run(candidate, validation_results),
            self.optimization.run(candidate.sql),
        )

        return ExplainerOutput(
            explanation=explanation,
            safety_report=safety,
            optimization_hints=hints,
        )
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Explanation Agent - Translates SQL into plain English
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from nl2sql_agents.agents.base_agent import BaseAgent
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """You are a helpful data analyst explaining SQL to a business user.
Given a SQL query and the original question, explain in 2-4 plain English sentences what the query does. Do not include SQL syntax in your explanation."""


class ExplanationAgent(BaseAgent):
    """LLM agent that restates a SQL query as a short plain-English explanation."""

    def build_prompt(self, sql: str = "", user_query: str = "", **_) -> list[dict[str, str]]:
        """Pair the fixed system prompt with a user message carrying the question and the SQL."""
        sections = (
            f"Original question: {user_query}",
            f"SQL: \n{sql}",
            "Explain in plain English.",
        )
        system_msg = {"role": "system", "content": SYSTEM_PROMPT}
        user_msg = {"role": "user", "content": "\n\n".join(sections)}
        return [system_msg, user_msg]

    def parse_response(self, raw: str) -> str:
        """Return the model reply with surrounding whitespace removed."""
        cleaned = raw.strip()
        return cleaned

    async def run(self, sql: str, user_query: str) -> str:
        """Ask the LLM (temperature 0.3, at most 300 tokens) to explain ``sql``."""
        prompt = self.build_prompt(sql=sql, user_query=user_query)
        explanation = await self.call_llm(prompt, temperature=0.3, max_tokens=300)
        return explanation
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OPTIMIZATION AGENT - LLM SUGGESTS INDEX/PERFORMANCE IMPROVEMENTS
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from nl2sql_agents.agents.base_agent import BaseAgent
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """You are a database performance tuning expert.
Given a SQL query, suggest 1-3 concrete optimization hints.
Be specific and actionable. If well-optimized, say so.
Keep each hint to one sentence."""


class OptimizationAgent(BaseAgent):
    """LLM agent that suggests index/performance improvements for a SQL query."""

    async def run(self, sql: str) -> str:
        """Return the LLM's optimization hints for ``sql`` as plain text.

        The system prompt asks for 1-3 hints. The call uses temperature 0.2 and
        at most 200 tokens.
        """
        messages = self.build_prompt(sql=sql)
        raw = await self.call_llm(
            messages=messages,
            temperature=0.2,
            max_tokens=200,
        )
        # Route through parse_response, consistent with BaseAgent.execute.
        return self.parse_response(raw)

    def build_prompt(self, sql: str = "", **_) -> list[dict[str, str]]:
        """Return the system prompt plus a user message containing ``sql``."""
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            # Typo fixed: the prompt previously said "SQl".
            {"role": "user", "content": f"Optimize this SQL: \n\n{sql}"},
        ]

    def parse_response(self, raw: str) -> str:
        """Return the model reply with surrounding whitespace removed."""
        return raw.strip()
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SAFETY REPORT AGENT - Generates structured audit report from validation results.
|
|
3
|
+
|
|
4
|
+
No LLM calls - purely derived from the validation check data.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from nl2sql_agents.models.schemas import SQLCandidate, CandidateValidationResult
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
ICONS={"passed": "", "warned": " ", "failed":""}
|
|
13
|
+
CHECK_ORDER = ["security", "syntax", "logic", "performance"]
|
|
14
|
+
|
|
15
|
+
class SafetyReportAgent:
    """Render a plain-text audit report from validation check results (no LLM calls)."""

    @staticmethod
    def _check_rank(check_name: str) -> int:
        """Return the position of ``check_name`` in CHECK_ORDER.

        Unknown checks sort after all known ones. Previously,
        ``CHECK_ORDER.index`` raised ValueError for them and aborted the report.
        """
        key = check_name.lower()
        return CHECK_ORDER.index(key) if key in CHECK_ORDER else len(CHECK_ORDER)

    async def run(self, candidate: SQLCandidate, all_results: list[CandidateValidationResult]) -> str:
        """Format the validation checks of ``candidate`` as a multi-line report.

        The candidate's result is the first entry in ``all_results`` whose SQL
        equals ``candidate.sql``. If there is none, a "not available" message
        is returned.
        """
        winning = next(
            (r for r in all_results if r.candidate.sql == candidate.sql), None
        )

        if winning is None:
            logger.warning("SafetyReportAgent: Safety Report Not Available")
            return "Safety Report Not Available"

        lines = ["Security & Quality Report", " " * 36]

        for check in sorted(winning.checks, key=lambda c: self._check_rank(c.check_name)):
            if not check.passed:
                icon = ICONS["failed"]
            elif check.score >= 1.0:
                # ">=" rather than "==": a passed check scoring above 1.0 used to
                # fall through to the "failed" icon.
                icon = ICONS["passed"]
            else:
                icon = ICONS["warned"]

            lines.append(
                f"{icon} {check.check_name.upper():12} {check.details or ''}"
            )

        lines.append(" " * 36)

        # NOTE(review): the "/ 4.0" denominator is hard-coded and presumably
        # assumes four checks scored out of 1.0 each. Confirm against the validator.
        lines.append(f"Total score: {winning.total_score:.1f} / 4.0")
        return "\n".join(lines)
|