nl2sql-agents 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. nl2sql_agents/__init__.py +9 -0
  2. nl2sql_agents/agents/__init__.py +0 -0
  3. nl2sql_agents/agents/base_agent.py +74 -0
  4. nl2sql_agents/agents/discovery/__init__.py +0 -0
  5. nl2sql_agents/agents/discovery/discovery_agent.py +117 -0
  6. nl2sql_agents/agents/discovery/fk_graph_agent.py +75 -0
  7. nl2sql_agents/agents/discovery/keyword_agent.py +61 -0
  8. nl2sql_agents/agents/discovery/semantic_agent.py +61 -0
  9. nl2sql_agents/agents/explainer/__init__.py +0 -0
  10. nl2sql_agents/agents/explainer/explainer_agent.py +45 -0
  11. nl2sql_agents/agents/explainer/explanation_agent.py +32 -0
  12. nl2sql_agents/agents/explainer/optimization_agent.py +31 -0
  13. nl2sql_agents/agents/explainer/safety_report_agent.py +42 -0
  14. nl2sql_agents/agents/query_generator.py +133 -0
  15. nl2sql_agents/agents/schema_formatter.py +69 -0
  16. nl2sql_agents/agents/validator/__init__.py +0 -0
  17. nl2sql_agents/agents/validator/logic_validator.py +59 -0
  18. nl2sql_agents/agents/validator/performance_validator.py +74 -0
  19. nl2sql_agents/agents/validator/security_validator.py +51 -0
  20. nl2sql_agents/agents/validator/syntax_validator.py +74 -0
  21. nl2sql_agents/agents/validator/validator_agent.py +104 -0
  22. nl2sql_agents/cli.py +291 -0
  23. nl2sql_agents/config/__init__.py +0 -0
  24. nl2sql_agents/config/settings.py +66 -0
  25. nl2sql_agents/db/__init__.py +0 -0
  26. nl2sql_agents/db/connector.py +107 -0
  27. nl2sql_agents/filters/__init__.py +0 -0
  28. nl2sql_agents/filters/gate.py +62 -0
  29. nl2sql_agents/filters/security_filter.py +36 -0
  30. nl2sql_agents/models/__init__.py +0 -0
  31. nl2sql_agents/models/schemas.py +120 -0
  32. nl2sql_agents/orchestrator/__init__.py +0 -0
  33. nl2sql_agents/orchestrator/nodes.py +142 -0
  34. nl2sql_agents/orchestrator/pipeline.py +70 -0
  35. nl2sql_agents/py.typed +0 -0
  36. nl2sql_agents-0.1.0.dist-info/METADATA +540 -0
  37. nl2sql_agents-0.1.0.dist-info/RECORD +40 -0
  38. nl2sql_agents-0.1.0.dist-info/WHEEL +4 -0
  39. nl2sql_agents-0.1.0.dist-info/entry_points.txt +2 -0
  40. nl2sql_agents-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,9 @@
1
+ """
2
+ nl2sql-agents — Multi-Agent Natural Language to SQL System.
3
+
4
+ A sophisticated multi-agent orchestration system that converts natural
5
+ language queries into safe, optimized SQL using LangGraph and OpenRouter LLMs.
6
+ """
7
+
8
+ __version__ = "0.1.0"
9
+ __all__ = ["__version__"]
File without changes
@@ -0,0 +1,74 @@
1
+ """
2
+ BASE AGENT - Abstract Interface for all LLM agents.
3
+
4
+ handles:
5
+ - Async LLM calls via ChatOpenAI (langchain-openai)
6
+ - provider-aware: each agent receives an LLM provider
7
+ - token usage logging
8
+
9
+ Each agent implements:
10
+ - build_prompt()
11
+ - parse_response()
12
+ """
13
+
14
+ import logging
15
+ from abc import ABC, abstractmethod
16
+ from typing import Any
17
+ from langchain_openai import ChatOpenAI
18
+ from nl2sql_agents.config.settings import LLMProvider, PRIMARY_PROVIDER
19
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
def _to_langchain_messages(messages: list[dict[str, str]]) -> list[BaseMessage]:
    """Convert role/content dicts into LangChain message objects.

    Any role other than "system" (including "user") maps to HumanMessage.
    """
    converted: list[BaseMessage] = []
    for message in messages:
        msg_cls = SystemMessage if message["role"] == "system" else HumanMessage
        converted.append(msg_cls(content=message["content"]))
    return converted
26
+
27
class BaseAgent(ABC):
    """Common machinery shared by every LLM-backed agent.

    Subclasses implement build_prompt() and parse_response(); this base
    class owns provider-aware LLM construction, the async chat call, and
    token-usage logging.
    """

    def __init__(self, provider: LLMProvider = PRIMARY_PROVIDER) -> None:
        # The provider encapsulates model selection and credentials.
        self.provider = provider
        self.model_name = provider.default_model

    @abstractmethod
    def build_prompt(self, *args, **kwargs) -> list[dict[str, str]]:
        """Return the chat messages (role/content dicts) for this agent."""
        raise NotImplementedError("Abstract Method build_prompt not implemented")

    @abstractmethod
    def parse_response(self, raw: str) -> Any:
        """Turn the raw LLM text into this agent's structured output."""
        raise NotImplementedError("Abstract Method ParseResponse not implemented")

    def _get_llm(self, temperature: float = 0.3, max_tokens: int = 2048) -> ChatOpenAI:
        """Build a chat model from the provider with the given sampling knobs."""
        return self.provider.chat_model(temperature=temperature, max_tokens=max_tokens)

    async def call_llm(
        self,
        messages: list[dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 2048
    ) -> str:
        """Send `messages` to the LLM and return the stripped text reply."""
        logger.debug(
            "%s -> LLM (model=%s, temp=%.1f, msgs=%d)", self.__class__.__name__, self.model_name, temperature, len(messages)
        )

        client = self._get_llm(temperature=temperature, max_tokens=max_tokens)
        reply = await client.ainvoke(_to_langchain_messages(messages))

        # Log token accounting when the provider reports it.
        if reply.usage_metadata:
            logger.debug("%s ← LLM (prompt=%d, completion=%d tokens)",
                self.__class__.__name__,
                reply.usage_metadata.get("input_tokens", 0),
                reply.usage_metadata.get("output_tokens", 0),
            )

        return reply.content.strip()

    async def execute(self, *args, **kwargs) -> Any:
        """build_prompt -> call_llm -> parse_response, honoring kwargs overrides."""
        prompt = self.build_prompt(*args, **kwargs)
        raw = await self.call_llm(
            prompt,
            temperature=kwargs.get("temperature", 0.3),
            max_tokens=kwargs.get("max_tokens", 2048)
        )
        return self.parse_response(raw)
File without changes
@@ -0,0 +1,117 @@
1
+ """
2
+ DISCOVERY AGENT
3
+
4
+ 1. Keyword pre-filter (no llm)
5
+ - runs KeywordAgent on all tables
6
+ - Takes top KEYWORD_PRE_FILTER_TOP_N by keyword score
7
+ 2. PARALLEL - Semantic + FK (on filtered tables only.)
8
+ - runs SemanticAgent + FKGraphAgent in parallel on filtered tables
9
+ - Merges all 3 scores with configurable weights
10
+ - returns full ranked list
11
+ """
12
+
13
+ import logging
14
+ import asyncio
15
+ from collections import defaultdict
16
+
17
+ from .keyword_agent import KeywordAgent
18
+ from .fk_graph_agent import FKGraphAgent
19
+ from .semantic_agent import SemanticAgent
20
+ from nl2sql_agents.config.settings import KEYWORD_PRE_FILTER_TOP_N
21
+ from nl2sql_agents.models.schemas import TableMetaData, DiscoveryResult, ScoredTable
22
+
23
logger = logging.getLogger(__name__)

# Relative weight of each discovery signal when merging per-table scores.
WEIGHTS = {
    "keywords" : 0.35,
    "semantic" : 0.45,
    "fk_graph" : 0.20
}

class DiscoveryAgent:
    """Two-phase table discovery.

    Phase 1: cheap keyword pre-filter (no LLM/embedding cost) keeps the
    top-N tables. Phase 2: semantic and FK-graph scorers run concurrently
    on the survivors; the three signals are merged with WEIGHTS into one
    ranked list.
    """

    def __init__(self) -> None:
        self.keyword_agent = KeywordAgent()
        self.semantic_agent = SemanticAgent()
        self.fk_graph_agent = FKGraphAgent()

    async def run(
        self,
        tables: list[TableMetaData],
        user_query: str,
        pre_filter_n: int = KEYWORD_PRE_FILTER_TOP_N
    ) -> DiscoveryResult:
        """Rank `tables` for relevance to `user_query`.

        Returns a DiscoveryResult with the full ranked list and the top-5
        tables pulled out for convenience.
        """
        # Phase 1: keyword pre-filter over ALL tables.
        logger.info('DiscoveryAgent: Keyword prefilter on %d tables', len(tables))

        kw_scores = await self.keyword_agent.score(tables, user_query)
        sorted_by_kw = sorted(kw_scores.items(), key=lambda x: x[1], reverse=True)
        top_n_names = {name for name, _ in sorted_by_kw[:pre_filter_n]}
        pre_filtered = [t for t in tables if t.table_name in top_n_names]

        logger.info('DiscoveryAgent phase 1: %d -> %d tables', len(tables), len(pre_filtered))

        # Phase 2: the two expensive scorers run concurrently on the survivors.
        logger.info('DiscoveryAgent phase2: Semantic + FK in parallel on %d tables', len(pre_filtered))

        sem_scores, fk_scores = await asyncio.gather(
            self.semantic_agent.score(pre_filtered, user_query),
            self.fk_graph_agent.score(pre_filtered, user_query)
        )

        merged = self._merge_and_rank(pre_filtered, kw_scores, sem_scores, fk_scores)

        logger.info('Discover Agent: ranked %d tables, top-5 = %s', len(merged), [s.table.table_name for s in merged[:5]])

        return DiscoveryResult(
            top_tables=[s.table for s in merged[:5]],
            scored_tables=merged
        )

    def _merge_and_rank(
        self,
        tables: list[TableMetaData],
        kw: dict[str, float],
        sem: dict[str, float],
        fk: dict[str, float]
    ) -> list[ScoredTable]:
        """Weighted-sum merge of the three score maps, ranked descending.

        `kw` covers ALL tables (it was computed before the pre-filter), so
        keyword entries are filtered to the surviving `tables`; `sem`/`fk`
        were computed on the filtered set already.
        """
        # PERF FIX: the original rebuilt {t.table_name for t in tables}
        # inside the keyword loop on every iteration (O(n^2)); table_map now
        # serves once as both the membership index and the lookup table.
        table_map = {t.table_name: t for t in tables}

        agg: dict[str, dict] = defaultdict(
            lambda: {"score": 0.0, "found_by": []}
        )

        for name, score in kw.items():
            if name not in table_map:
                continue
            agg[name]["score"] += score * WEIGHTS["keywords"]
            if score > 0:
                agg[name]["found_by"].append("keyword")

        for name, score in sem.items():
            agg[name]["score"] += score * WEIGHTS["semantic"]
            if score > 0:
                agg[name]["found_by"].append("semantic")

        for name, score in fk.items():
            agg[name]["score"] += score * WEIGHTS["fk_graph"]
            if score > 0:
                agg[name]["found_by"].append("fk_graph")

        ranked = sorted(
            [
                ScoredTable(
                    table=table_map[name],
                    score=round(data['score'], 4),
                    # sorted() instead of list(set(...)) for deterministic output.
                    found_by=sorted(set(data['found_by'])),
                )
                for name, data in agg.items() if name in table_map
            ],
            key=lambda a: a.score,
            reverse=True
        )

        return ranked
@@ -0,0 +1,75 @@
1
+ """
2
+ Sub-Agent 1c - Foreign Key Graph Agent
3
+
4
+ Builds a bi-directional foreign key graph from table metadata.
5
+ BFS-walk from seed tables, and scores by graph distance.
6
+
7
+ | Distance from Seed | Score | Interpretation |
8
+ | -------------------------- | ----- | --------------------------- |
9
+ | 0 (seed itself) | 1.0 | Directly mentioned in query |
10
+ | 1 (neighbors) | 0.5 | Directly linked via FK |
11
+ | 2 (neighbors of neighbors) | 0.25 | Two hops away |
12
+
13
+ """
14
+
15
+ import logging
16
+ from collections import deque
17
+ from nl2sql_agents.models.schemas import TableMetaData
18
+
19
logger = logging.getLogger(__name__)

# Maximum FK-hop distance from a seed that still receives a score.
MAX_DEPTH = 2

class FKGraphAgent:
    """Sub-agent 1c: scores tables by foreign-key distance from seed tables.

    Seeds are tables whose name parts literally appear in the user query.
    Score halves per hop: depth 0 -> 1.0, depth 1 -> 0.5, depth 2 -> 0.25.
    """

    async def score(self, tables: "list[TableMetaData]", user_query: str) -> dict[str, float]:
        """Return {table_name: score} for tables within MAX_DEPTH of a seed."""
        graph = self._build_fk_graph(tables)
        seeds = self._find_seeds(tables, user_query)

        logger.debug('FKGraphAgent: seeds=%s', seeds)
        return self._bfs_score(seeds, graph, MAX_DEPTH)

    def _build_fk_graph(self, tables: "list[TableMetaData]") -> dict[str, set[str]]:
        """Bi-directional FK adjacency map restricted to the given tables.

        BUG FIX: the original guarded on col.reference_column but then used
        col.reference_table (a missing reference_table produced a phantom
        None edge), and it added forward edges to tables outside the
        candidate set, so BFS could score tables that were never candidates.
        Both edges are now added only when reference_table is a known table.
        """
        graph: dict[str, set[str]] = {t.table_name: set() for t in tables}

        for table in tables:
            for col in table.columns:
                if not (col.is_foreign_key and col.reference_table):
                    continue
                if col.reference_table in graph:
                    graph[table.table_name].add(col.reference_table)
                    graph[col.reference_table].add(table.table_name)
        return graph

    def _find_seeds(self, tables: "list[TableMetaData]", user_query: str) -> list[str]:
        """Seed = table whose name parts (>2 chars) appear in the user query."""
        query_lower = user_query.lower()
        seeds = []
        for table in tables:
            parts = table.table_name.lower().replace('_', ' ').split()
            if any(part in query_lower for part in parts if len(part) > 2):
                seeds.append(table.table_name)
        return seeds

    def _bfs_score(self, seeds: list[str], graph: dict[str, set[str]], max_depth: int) -> dict[str, float]:
        """Breadth-first walk from all seeds at once; score = 1 / 2**depth."""
        scores: dict[str, float] = {}
        visited: set[str] = set()
        queue: deque[tuple[str, int]] = deque()

        for seed in seeds:
            queue.append((seed, 0))
            visited.add(seed)

        while queue:
            node, depth = queue.popleft()
            if depth > max_depth:
                continue
            score = 1.0 / (2 ** depth)
            # A node can be reached from several seeds; keep its best score.
            scores[node] = max(scores.get(node, 0.0), score)

            for neighbor in graph.get(node, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, depth + 1))

        return scores
@@ -0,0 +1,61 @@
1
+ """
2
+ Sub-Agent 1a - Keyword Agent (No LLM)
3
+
4
+ Extracts meaningful tokens from the user query and fuzzy-matches them against table names and column names. Returns {tablename: score}
5
+
6
+ """
7
+
8
+ import re
9
+ import logging
10
+ from difflib import SequenceMatcher
11
+ from nl2sql_agents.models.schemas import TableMetaData
12
+
13
logger = logging.getLogger(__name__)

# Query tokens that carry no table/column signal and are always discarded.
STOP_WORDS = {
    "show", "me", "get", "find", "list", "give", "the", "a", "an",
    "of", "for", "in", "on", "by", "with", "from", "where", "top",
    "all", "my", "this", "that", "and", "or", "is", "are", "was",
    "how", "many", "much", "what", "which", "who",
}

class KeywordAgent:
    """Sub-agent 1a: LLM-free lexical scorer.

    Extracts meaningful tokens from the user query and fuzzy-matches them
    against table and column names. score() returns {table_name: score}
    with each score in [0, 1].
    """

    async def score(
        self,
        tables: "list[TableMetaData]",
        user_query: str
    ) -> dict[str, float]:
        """Score every table against the keywords of `user_query`."""
        keywords = self._extract_keywords(user_query)
        logger.debug('KeywordAgent: Keywords=%s', keywords)

        return {
            t.table_name: self._score_table(t, keywords)
            for t in tables
        }

    def _extract_keywords(self, query: str) -> list[str]:
        """Lowercased alphabetic tokens, minus stop words and <=2-char tokens."""
        tokens = re.findall(r"[a-zA-Z]+", query.lower())
        return [t for t in tokens if t not in STOP_WORDS and len(t) > 2]

    def _score_table(self, table: "TableMetaData", keywords: list[str]) -> float:
        """Mean over keywords of the best fuzzy match against name + columns."""
        if not keywords:
            # No usable keywords -> neutral zero score (also avoids div-by-zero).
            return 0.0

        candidates = [table.table_name.lower()] + [
            c.column_name.lower() for c in table.columns
        ]

        # IDIOM: comprehension replaces the original append loop (PERF401).
        best_scores = [
            max(self._fuzzy_score(kw, candidate) for candidate in candidates)
            for kw in keywords
        ]

        return round(sum(best_scores) / len(best_scores), 4)

    def _fuzzy_score(self, keyword: str, target: str) -> float:
        """1.0 on substring hit, else difflib similarity ratio in [0, 1]."""
        if keyword in target:
            return 1.0
        return SequenceMatcher(None, keyword, target).ratio()
@@ -0,0 +1,61 @@
1
+ """
2
+ Sub-Agent 1b - Semantic Agent
3
+
4
+ Embeds the user query and table descriptions using OpenAIEmbeddings.
5
+ Return Cosine similarity scores: {table_name: similarity}
6
+ """
7
+
8
+ import logging
9
+ import numpy as np
10
+ from nl2sql_agents.models.schemas import TableMetaData
11
+ from nl2sql_agents.config.settings import EMBEDDING_PROVIDER
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def _cosine_similarity(a: list[float], b: list[float]) -> float:
16
+ va, vb = np.array(a), np.array(b)
17
+ # Calculate product of magnitudes (L2 norms)
18
+ # ||a|| * ||b||
19
+ norm = np.linalg.norm(va) * np.linalg.norm(vb)
20
+
21
+ """
22
+ # Two similar vectors (both about "sales")
23
+ a = [0.5, 0.8, 0.2, 0.1] # embedding for "revenue"
24
+ b = [0.6, 0.7, 0.3, 0.2] # embedding for "sales"
25
+
26
+ # Dot product: 0.5*0.6 + 0.8*0.7 + 0.2*0.3 + 0.1*0.2 = 0.88
27
+ # ||a|| = sqrt(0.25 + 0.64 + 0.04 + 0.01) = 0.949
28
+ # ||b|| = sqrt(0.36 + 0.49 + 0.09 + 0.04) = 0.990
29
+ # Result: 0.88 / (0.949 * 0.990) ≈ 0.94 ← very similar!
30
+
31
+ # Two unrelated vectors
32
+ c = [0.9, 0.1, 0.0, 0.0] # embedding for "apple" (fruit)
33
+ d = [0.1, 0.9, 0.8, 0.5] # embedding for "car" (vehicle)
34
+ # Result: ≈ 0.15 ← not similar
35
+ """
36
+
37
+ return float(np.dot(va, vb)/norm) if norm > 0 else 0.0
38
+
39
class SemanticAgent:
    """Sub-agent 1b: ranks tables by embedding similarity to the query."""

    def __init__(self) -> None:
        # One embeddings client per agent, built from the configured provider.
        self.embeddings = EMBEDDING_PROVIDER.embeddings_model()

    async def score(
        self, tables: list[TableMetaData], user_query: str
    ) -> dict[str, float]:
        """Return {table_name: cosine similarity to the user query}."""
        logger.debug("SemanticAgent: embedding query + %d tables", len(tables))

        # Embed the query and all table descriptions in one batched call;
        # slot 0 is the query, the rest line up with `tables`.
        corpus = [user_query]
        corpus.extend(self._table_to_text(table) for table in tables)

        vectors = await self.embeddings.aembed_documents(corpus)
        query_vec, table_vecs = vectors[0], vectors[1:]

        result: dict[str, float] = {}
        for table, vec in zip(tables, table_vecs):
            result[table.table_name] = round(_cosine_similarity(query_vec, vec), 4)
        return result

    def _table_to_text(self, table: TableMetaData) -> str:
        """Compact one-line description: table name plus first 20 column names."""
        col_names = ','.join(c.column_name for c in table.columns[:20])
        return f"Table {table.table_name}: columns {col_names}"
File without changes
@@ -0,0 +1,45 @@
1
+ """
2
+ AGENT 5 - EXPLAINER AGENT
3
+
4
+ runs PARALLEL - 3 output tasks concurrently:
5
+ - Explanation (Plain English)
6
+ - Safety Report (audit from validation report)
7
+ - Optimization Hints (LLM)
8
+ """
9
+
10
+ import asyncio
11
+ import logging
12
+
13
+ from nl2sql_agents.agents.explainer.explanation_agent import ExplanationAgent
14
+ from nl2sql_agents.agents.explainer.optimization_agent import OptimizationAgent
15
+ from nl2sql_agents.agents.explainer.safety_report_agent import SafetyReportAgent
16
+ from nl2sql_agents.models.schemas import SQLCandidate, CandidateValidationResult, ExplainerOutput
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
class ExplainerAgent:
    """Agent 5: produces the three user-facing outputs for a winning SQL
    candidate — plain-English explanation, safety report, and optimization
    hints — by running its three sub-agents concurrently."""

    def __init__(self) -> None:
        self.explanation = ExplanationAgent()
        self.safety_report = SafetyReportAgent()
        self.optimization = OptimizationAgent()

    async def explain(
        self,
        candidate: SQLCandidate,
        validation_results: list[CandidateValidationResult],
        user_query: str = "",
    ) -> ExplainerOutput:
        """Run the three sub-agents in parallel and bundle their outputs."""
        # BUG FIX: corrected log message typo ("taaks" -> "tasks").
        logger.info("ExplainerAgent: 3 output tasks in PARALLEL")

        explanation, safety, hints = await asyncio.gather(
            self.explanation.run(candidate.sql, user_query),
            self.safety_report.run(candidate, validation_results),
            self.optimization.run(candidate.sql)
        )

        return ExplainerOutput(
            explanation=explanation,
            safety_report=safety,
            optimization_hints=hints
        )
@@ -0,0 +1,32 @@
1
+ """
2
+ Explanation Agent - Translates SQL into plain English
3
+ """
4
+
5
+ import logging
6
+ from nl2sql_agents.agents.base_agent import BaseAgent
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ SYSTEM_PROMPT = """You are a helpful data analyst explaining SQL to a business user.
11
+ Given a SQL query and the original question, explain in 2-4 plain English sentences what the query does. Do not include SQL syntax in your explanation."""
12
+
13
+
14
+ class ExplanationAgent(BaseAgent):
15
+ def build_prompt(self, sql: str = "", user_query: str = "", **_) -> list[dict[str, str]]:
16
+ USER_PROMPT = (
17
+ f"Original question: {user_query}\n\n"
18
+ f"SQL: \n{sql}\n\n"
19
+ "Explain in plain English."
20
+ )
21
+
22
+ return [
23
+ {"role": "system", "content": SYSTEM_PROMPT},
24
+ {"role": "user", "content": USER_PROMPT}
25
+ ]
26
+
27
+ def parse_response(self, raw: str) -> str:
28
+ return raw.strip()
29
+
30
+ async def run(self, sql: str, user_query: str) -> str:
31
+ messages = self.build_prompt(sql=sql, user_query=user_query)
32
+ return await self.call_llm(messages, temperature=0.3, max_tokens=300)
@@ -0,0 +1,31 @@
1
+ """
2
+ OPTIMIZATION AGENT - LLM SUGGESTS INDEX/PERFORMANCE IMPROVEMENTS
3
+ """
4
+
5
+ import logging
6
+ from nl2sql_agents.agents.base_agent import BaseAgent
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ SYSTEM_PROMPT = """You are a database performance tuning expert.
11
+ Given a SQL query, suggest 1-3 concrete optimization hints.
12
+ Be specific and actionable. If well-optimized, say so.
13
+ Keep each hint to one sentence."""
14
+
15
+ class OptimizationAgent(BaseAgent):
16
+ async def run(self, sql: str) -> str:
17
+ messages = self.build_prompt(sql=sql)
18
+ return await self.call_llm(
19
+ messages=messages,
20
+ temperature=0.2,
21
+ max_tokens=200
22
+ )
23
+
24
+ def build_prompt(self, sql: str = "", **_) -> list[dict[str, str]]:
25
+ return [
26
+ {"role": "system", "content": SYSTEM_PROMPT},
27
+ {"role": "user", "content": f"Optimize this SQl: \n\n{sql}"}
28
+ ]
29
+
30
+ def parse_response(self, raw: str) -> str:
31
+ return raw.strip()
@@ -0,0 +1,42 @@
1
+ """
2
+ SAFETY REPORT AGENT - Generates structured audit report from validation results.
3
+
4
+ No LLM calls - purely derived from check data.
5
+ """
6
+
7
+ import logging
8
+ from nl2sql_agents.models.schemas import SQLCandidate, CandidateValidationResult
9
+
10
logger = logging.getLogger(__name__)

# Status icons for report lines. NOTE(review): the glyphs appear to have
# been lost in an encoding step (values are blank/space) — confirm the
# intended characters (likely check-mark / warning / cross emoji) against
# the original source before changing them.
ICONS = {"passed": "", "warned": " ", "failed": ""}
# Canonical display order for validation checks in the report.
CHECK_ORDER = ["security", "syntax", "logic", "performance"]

class SafetyReportAgent:
    """Builds a plain-text audit report from validation results.

    No LLM calls — the report is derived purely from the check data of the
    validation result that matches the winning candidate.
    """

    async def run(self, candidate: "SQLCandidate", all_results: "list[CandidateValidationResult]") -> str:
        """Render the report for the result whose candidate SQL matches.

        Returns a fallback message when no validation result matches.
        """
        winning = next(
            (r for r in all_results if r.candidate.sql == candidate.sql), None
        )

        if not winning:
            logger.warning("SafetyReportAgent: Safety Report Not Available")
            return "Safety Report Not Available"

        lines = ["Security & Quality Report", " " * 36]

        # BUG FIX: the original sorted with CHECK_ORDER.index(...) directly,
        # which raises ValueError for any check name not in CHECK_ORDER;
        # unknown checks now sort last instead of crashing the report.
        def sort_key(check) -> int:
            name = check.check_name.lower()
            return CHECK_ORDER.index(name) if name in CHECK_ORDER else len(CHECK_ORDER)

        for check in sorted(winning.checks, key=sort_key):
            if check.passed and check.score == 1.0:
                icon = ICONS['passed']
            elif check.passed and check.score < 1.0:
                icon = ICONS['warned']
            else:
                icon = ICONS["failed"]

            lines.append(
                f"{icon} {check.check_name.upper():12} {check.details or ''}"
            )

        lines.append(" " * 36)

        lines.append(f"Total score: {winning.total_score:.1f} / 4.0")
        return "\n".join(lines)