gaard-core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. gaard_core/__init__.py +0 -0
  2. gaard_core/audit/__init__.py +0 -0
  3. gaard_core/errors.py +79 -0
  4. gaard_core/evaluation/__init__.py +0 -0
  5. gaard_core/execution/__init__.py +0 -0
  6. gaard_core/execution/mock_executor.py +25 -0
  7. gaard_core/investigation/__init__.py +25 -0
  8. gaard_core/investigation/llm_readiness_agent.py +220 -0
  9. gaard_core/investigation/loop.py +83 -0
  10. gaard_core/investigation/mock_readiness_agent.py +20 -0
  11. gaard_core/investigation/models.py +62 -0
  12. gaard_core/json_utils.py +58 -0
  13. gaard_core/llm_output.py +12 -0
  14. gaard_core/policy_engine/__init__.py +0 -0
  15. gaard_core/prompt_compiler/__init__.py +0 -0
  16. gaard_core/prompt_compiler/intent_classification_prompt.py +58 -0
  17. gaard_core/prompt_compiler/investigation_readiness_prompt.py +84 -0
  18. gaard_core/prompt_compiler/models.py +19 -0
  19. gaard_core/prompt_compiler/result_classification_prompt.py +62 -0
  20. gaard_core/prompt_compiler/result_interpretation_prompt.py +73 -0
  21. gaard_core/prompt_compiler/schema_formatter.py +43 -0
  22. gaard_core/prompt_compiler/sql_generation_prompt.py +105 -0
  23. gaard_core/query_intent/__init__.py +1 -0
  24. gaard_core/query_intent/llm_classifier.py +112 -0
  25. gaard_core/query_intent/mock_classifier.py +14 -0
  26. gaard_core/query_pipeline/__init__.py +0 -0
  27. gaard_core/query_pipeline/llm_sql_generator.py +85 -0
  28. gaard_core/query_pipeline/mock_sql_generator.py +33 -0
  29. gaard_core/query_pipeline/models.py +57 -0
  30. gaard_core/query_pipeline/pipeline.py +124 -0
  31. gaard_core/result_classifier/__init__.py +1 -0
  32. gaard_core/result_classifier/llm_classifier.py +87 -0
  33. gaard_core/result_classifier/mock_classifier.py +10 -0
  34. gaard_core/result_interpreter/__init__.py +0 -0
  35. gaard_core/result_interpreter/llm_interpreter.py +66 -0
  36. gaard_core/result_interpreter/mock_interpreter.py +25 -0
  37. gaard_core/schema/__init__.py +0 -0
  38. gaard_core/schema/cache.py +59 -0
  39. gaard_core/schema/context.py +40 -0
  40. gaard_core/schema/models.py +27 -0
  41. gaard_core/security/__init__.py +0 -0
  42. gaard_core/semantic_layer/__init__.py +0 -0
  43. gaard_core/sql_validator/__init__.py +0 -0
  44. gaard_core/sql_validator/select_only.py +37 -0
  45. gaard_core-0.1.0.dist-info/METADATA +23 -0
  46. gaard_core-0.1.0.dist-info/RECORD +48 -0
  47. gaard_core-0.1.0.dist-info/WHEEL +5 -0
  48. gaard_core-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,84 @@
1
+ from gaard_core.investigation.models import InvestigationContext, InvestigationRoute
2
+ from gaard_core.json_utils import json_dumps
3
+ from gaard_core.prompt_compiler.models import CompiledPrompt
4
+
5
+
6
class InvestigationReadinessPromptCompiler:
    """Build the prompt pair used for the investigation-readiness check.

    The system prompt instructs the model to decide whether SQL generation
    may start; the user prompt carries the question, ids, schema and business
    logic of an ``InvestigationContext`` serialized as JSON.
    """

    def compile(self, context: InvestigationContext) -> CompiledPrompt:
        """Return the compiled readiness prompt for ``context``.

        The returned metadata lists every ``InvestigationRoute`` value under
        the ``allowed_routes`` key.
        """
        payload = {
            "question": context.question,
            "datasource_id": context.datasource_id,
            "user_id": context.user_id,
            "schema": context.formatted_schema,
            "business_logic": context.business_logic,
        }

        return CompiledPrompt(
            system_prompt=self._build_system_prompt(),
            user_prompt=self._build_user_prompt(payload),
            metadata={
                "allowed_routes": [item.value for item in InvestigationRoute],
            },
        )

    def _build_system_prompt(self) -> str:
        """Return the fixed system prompt (no interpolation is performed)."""
        return """You are GAARD Investigation Readiness.

Your task is to decide whether GAARD already knows enough to create a correct SQL query for the user's question.

Assume nothing. Verify continuously.

Use only:
- the user's question,
- the active datasource schema,
- the approved or previously saved business logic supplied in the payload.

You do not generate SQL.
You do not answer the user.
You decide only whether normal SQL generation may start safely.

Return ready_for_sql=true only when all information needed for correct SQL is explicit in the question, schema, and business logic:
- requested business entity or metric,
- relevant tables, views and columns,
- required filters and dictionary/status values,
- required joins or relationships,
- requested output shape such as count, list, detail, or aggregation.

Return ready_for_sql=false when any material element is missing, ambiguous, inferred only from the model, or would require checking data values before SQL can be trusted. In that case route must be analysis.

Output rules:
- Return only a JSON object.
- Do not include markdown.
- Do not include reasoning outside the JSON.
- Do not include <think> blocks.
- Use exactly this JSON shape:
{"ready_for_sql":false,"route":"analysis","confidence":0.0,"reason":"short reason","missing_information":[],"required_analysis":[],"required_analysis_tasks":[],"assumptions":[]}

Required analysis task shape:
{"missing_information":"what is missing","required_analysis":"specific read-only data question for SQL analysis","category":"dictionary_value","expected_output":"what kind of result would resolve this"}

Allowed categories:
- dictionary_value
- relationship_logic
- filter_logic
- aggregation_logic
- entity_mapping
- unknown
"""

    def _build_user_prompt(self, payload: dict[str, str]) -> str:
        """Return the user prompt with ``payload`` embedded as indented JSON.

        NOTE(review): the annotation promises string values, but the field
        types of ``InvestigationContext`` are not visible here - confirm.
        """
        return f"""Assess whether normal SQL generation can start.

Input JSON:
{json_dumps(payload, ensure_ascii=False, indent=2)}

Return one JSON object with:
- ready_for_sql: boolean
- route: sql or analysis
- confidence: number from 0 to 1
- reason: short explanation
- missing_information: list of missing or ambiguous items
- required_analysis: list of checks that Analysis mode should perform when ready_for_sql=false
- required_analysis_tasks: list of structured SQL-analysis tasks with missing_information, required_analysis, category, expected_output
- assumptions: list of any assumptions that would affect SQL correctness
"""
@@ -0,0 +1,19 @@
1
+ from typing import Any
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from gaard_core.schema.models import DatabaseSchema
6
+
7
+
8
class SqlGenerationPromptRequest(BaseModel):
    """Input for compiling a SQL-generation prompt.

    The schema can be supplied either as a structured ``database_schema`` or
    as pre-rendered ``formatted_schema`` text; both default to ``None``.
    """

    # The user's question; must contain at least one character.
    question: str = Field(min_length=1)
    # Structured schema description, when available.
    database_schema: DatabaseSchema | None = None
    # Pre-rendered schema text, when available.
    formatted_schema: str | None = None
    # Name of the SQL dialect; defaults to "sqlite".
    dialect: str = "sqlite"
    # Row limit value; defaults to 100.
    max_rows: int = 100
14
+
15
+
16
class CompiledPrompt(BaseModel):
    """A system/user prompt pair plus free-form metadata about how it was built."""

    # Text for the system message.
    system_prompt: str
    # Text for the user message.
    user_prompt: str
    # Arbitrary extra information; a fresh empty dict per instance by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
@@ -0,0 +1,62 @@
1
+ from gaard_core.json_utils import json_dumps
2
+ from gaard_core.prompt_compiler.models import CompiledPrompt
3
+ from gaard_core.query_pipeline.models import OutputClassification, QueryRequest
4
+
5
+
6
class ResultClassificationPromptCompiler:
    """Build the prompt pair that asks a model to classify an interpreted answer."""

    def compile(
        self,
        request: QueryRequest,
        answer: str,
    ) -> CompiledPrompt:
        """Return the compiled classification prompt.

        The user prompt embeds ``request.question`` and ``answer`` as JSON.
        The metadata lists every ``OutputClassification`` value under the
        ``allowed_classifications`` key.
        """
        payload = {
            "question": request.question,
            "answer": answer,
        }

        return CompiledPrompt(
            system_prompt=self._build_system_prompt(),
            user_prompt=self._build_user_prompt(payload),
            metadata={
                "allowed_classifications": [
                    item.value for item in OutputClassification
                ],
            },
        )

    def _build_system_prompt(self) -> str:
        """Return the fixed system prompt (no interpolation is performed)."""
        return """You are GAARD Output Classification.

Your task is to classify the user-facing interpreted answer into exactly one output data class.

Allowed classes:
- personal_data: the answer is about personal data, people, identities, audit events concerning personal data, or aggregates describing personal data access.
- sensitive_data: the answer is about sensitive or special-category data such as health, credentials, secrets, financial risk, legal status, biometric data, or similarly high-risk information.
- technical_data: the answer is about system configuration, schemas, logs, query mechanics, infrastructure, or operational technical metadata.
- neutral_data: the answer is about non-personal, non-sensitive business or aggregate information.
- unknown: the answer cannot be classified reliably.

Priority rules:
1. Choose sensitive_data over personal_data if both apply.
2. Choose personal_data over technical_data if the answer concerns audit or technical records about personal data.
3. Choose unknown instead of guessing when the answer lacks enough context.

Rules:
- Classify only the interpreted answer and the user's question.
- Do not classify raw database rows.
- Return only one allowed class value.
- Do not include explanations.
- Do not include markdown.
- Do not include reasoning.
- Do not include <think> blocks.
"""

    def _build_user_prompt(self, payload: dict[str, str]) -> str:
        """Return the user prompt: ``payload`` as JSON, then the allowed values.

        The allowed values are the ``OutputClassification`` members' values
        joined with ", ".
        """
        return f"""Classify this interpreted result.

Input JSON:
{json_dumps(payload, ensure_ascii=False, indent=2)}

Return exactly one of:
{", ".join(item.value for item in OutputClassification)}
"""
@@ -0,0 +1,73 @@
1
+ from typing import Any
2
+
3
+ from gaard_core.json_utils import json_dumps
4
+ from gaard_core.prompt_compiler.models import CompiledPrompt
5
+ from gaard_core.query_pipeline.models import QueryRequest, QueryResult
6
+
7
+
8
class ResultInterpretationPromptCompiler:
    """Build the prompt pair that asks a model to explain a SQL result."""

    def compile(
        self,
        request: QueryRequest,
        sql: str,
        result: QueryResult,
    ) -> CompiledPrompt:
        """Return the compiled interpretation prompt.

        The user prompt embeds the question, the SQL text and the result's
        columns and rows. The metadata records ``rows_count`` and
        ``columns_count`` taken from ``result``.
        """
        system_prompt = self._build_system_prompt()
        user_prompt = self._build_user_prompt(
            question=request.question,
            sql=sql,
            rows=result.rows,
            columns=result.columns,
        )

        return CompiledPrompt(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            metadata={
                "rows_count": len(result.rows),
                "columns_count": len(result.columns),
            },
        )

    def _build_system_prompt(self) -> str:
        """Return the fixed system prompt (no interpolation is performed)."""
        return """You are GAARD Data Result Interpreter.

Your task is to explain SQL query results to the user.

Rules:
- Answer in the same language as the user's question.
- Pay attention to correct user's language grammar and plural forms.
- Use only the data provided in the result.
- Do not invent facts.
- Be concise.
- Prefer one short paragraph.
- If the result is empty, say that the query returned no rows.
- If the result contains aggregated values, explain the value directly.
- Do not mention that you are an AI model.
- Do not include markdown tables unless explicitly needed.
- Do not include reasoning.
- Do not include <think> blocks.
- Return only the final answer.
"""

    def _build_user_prompt(
        self,
        question: str,
        sql: str,
        rows: list[dict[str, Any]],
        columns: list[str],
    ) -> str:
        """Return the user prompt with all inputs embedded as one JSON object.

        Every row passed in is serialized; no truncation happens here.
        """
        payload = {
            "question": question,
            "sql": sql,
            "columns": columns,
            "rows": rows,
        }

        return f"""Interpret the following SQL result for the user.

Input JSON:
{json_dumps(payload, ensure_ascii=False, indent=2)}

Return only the final user-facing answer.
"""
@@ -0,0 +1,43 @@
1
+ from gaard_core.schema.models import DatabaseSchema, TableInfo
2
+
3
+
4
class SchemaPromptFormatter:
    """Render a database schema as plain text for inclusion in prompts."""

    def format(self, schema: DatabaseSchema) -> str:
        """Return one text section per table or view, ordered by name.

        Sections are separated by a blank line. A schema without tables
        yields a fixed placeholder sentence instead.
        """
        if schema.tables:
            ordered = sorted(schema.tables, key=lambda entry: entry.name)
            return "\n\n".join(self._format_table(entry) for entry in ordered)

        return "No tables or views available."

    def _format_table(self, table: TableInfo) -> str:
        """Return the header, column list and foreign keys of one table or view."""
        kind = "Table"
        if table.object_type == "view":
            kind = "View"

        out: list[str] = [f"{kind}: {table.name}", "Columns:"]

        if table.columns:
            for col in table.columns:
                flags: list[str] = []
                if col.primary_key:
                    flags.append("primary key")
                if not col.nullable:
                    flags.append("not null")

                # Modifiers appear in parentheses only when there are any.
                suffix = f" ({', '.join(flags)})" if flags else ""
                out.append(f"- {col.name}: {col.type}{suffix}")
        else:
            out.append("- No columns available.")

        if table.foreign_keys:
            out.append("Foreign keys:")
            for fk in table.foreign_keys:
                source_cols = ", ".join(fk.constrained_columns)
                target_cols = ", ".join(fk.referred_columns)
                out.append(f"- {source_cols} -> {fk.referred_table}.{target_cols}")

        return "\n".join(out)
@@ -0,0 +1,105 @@
1
+ from gaard_core.errors import ConfigurationError
2
+ from gaard_core.prompt_compiler.models import CompiledPrompt, SqlGenerationPromptRequest
3
+ from gaard_core.prompt_compiler.schema_formatter import SchemaPromptFormatter
4
+
5
+
6
class SqlGenerationPromptCompiler:
    """Build the prompt pair that asks a model to write one SQL SELECT query."""

    def __init__(self, schema_formatter: SchemaPromptFormatter | None = None) -> None:
        """Store ``schema_formatter``, creating a ``SchemaPromptFormatter`` when none is given."""
        self.schema_formatter = schema_formatter or SchemaPromptFormatter()

    def compile(self, request: SqlGenerationPromptRequest) -> CompiledPrompt:
        """Return the compiled SQL-generation prompt for ``request``.

        Raises:
            ConfigurationError: when the request carries neither a formatted
                schema nor a database schema.
        """
        formatted_schema = self._resolve_formatted_schema(request)

        system_prompt = self._build_system_prompt(
            dialect=request.dialect,
            max_rows=request.max_rows,
        )

        user_prompt = self._build_user_prompt(
            schema=formatted_schema,
            question=request.question,
        )

        return CompiledPrompt(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            metadata={
                "dialect": request.dialect,
                "max_rows": request.max_rows,
                # Mirrors the precedence in _resolve_formatted_schema:
                # pre-rendered text wins over the structured schema.
                "schema_source": "formatted_schema"
                if request.formatted_schema is not None
                else "database_schema",
                # None when no structured schema was supplied.
                "tables_count": len(request.database_schema.tables)
                if request.database_schema is not None
                else None,
            },
        )

    def _resolve_formatted_schema(self, request: SqlGenerationPromptRequest) -> str:
        """Return the schema text to embed in the prompt.

        ``request.formatted_schema`` is used as-is when present; otherwise
        ``request.database_schema`` is rendered with the schema formatter.

        Raises:
            ConfigurationError: when both schema fields are ``None``.
        """
        if request.formatted_schema is not None:
            return request.formatted_schema

        if request.database_schema is None:
            raise ConfigurationError(
                "Either database_schema or formatted_schema must be provided."
            )

        return self.schema_formatter.format(request.database_schema)

    def _build_system_prompt(self, dialect: str, max_rows: int) -> str:
        """Return the system prompt with ``dialect`` and ``max_rows`` interpolated."""
        return f"""You are an expert data analyst and SQL specialist.

Your task is to generate exactly one valid SQL SELECT query based on:
- the user's question,
- the provided database schema,
- the provided data rules and descriptions.

You must generate SQL for the {dialect} dialect.

Core rules:
1. Generate only one SQL statement.
2. Generate only a SELECT statement.
3. Do not generate multiple statements.
4. Do not generate INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, MERGE, REPLACE, GRANT or REVOKE.
5. Use only tables, views and columns listed in the provided schema.
6. Do not invent tables, views or columns.
7. Return only raw SQL.
8. Do not use markdown.
9. Do not use code fences.
10. Do not add comments.
11. Do not add explanations.
12. Do not include reasoning.
13. Do not include <think> blocks.

Query construction rules:
1. If the user asks for a count, use COUNT with a clear alias.
2. If the user asks for a breakdown, distribution, comparison by category, or values "by" some dimension, use one SELECT statement with GROUP BY or conditional aggregation.
3. If the user asks for both a total and a breakdown, prefer one SELECT statement that returns grouped rows or conditional aggregate columns.
4. Do not solve one user question by generating multiple separate SELECT statements.
5. Prefer explicit column names over SELECT *.
6. Add LIMIT {max_rows} when the query may return many rows.
7. Do not add LIMIT to pure aggregate queries that return a single row, unless it is already useful for the dialect or safety.
8. Use clear aliases for computed expressions.
9. When the query uses more than one table, every table must have a short, stable alias.
10. When the query uses more than one table, every column reference must be qualified with the correct table alias in SELECT, JOIN, WHERE, GROUP BY, HAVING and ORDER BY.
11. When the query uses table aliases, use those aliases consistently and do not mix aliased and unaliased table references.
12. Do not use unqualified column names in joins or multi-table queries.
13. If the question is ambiguous, choose the most likely interpretation based on the schema, column names, descriptions and data rules.

Output contract:
- Return exactly one SQL SELECT statement.
- The first non-whitespace token must be SELECT or WITH.
- The final output must be executable SQL only.
"""

    def _build_user_prompt(self, schema: str, question: str) -> str:
        """Return the user prompt containing the schema text and the question verbatim."""
        return f"""Database schema:
{schema}

User question:
{question}

Generate exactly one SQL SELECT statement for this question.
If the answer requires multiple values, categories or groups, return them using one query.
Return SQL only.
"""
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,112 @@
1
+ import json
2
+ from typing import Any, Protocol
3
+
4
+ from gaard_core.llm_output import remove_thinking_blocks
5
+ from gaard_core.prompt_compiler.intent_classification_prompt import (
6
+ IntentClassificationPromptCompiler,
7
+ )
8
+ from gaard_core.prompt_compiler.models import CompiledPrompt
9
+ from gaard_core.query_pipeline.models import (
10
+ QueryIntentClassification,
11
+ QueryIntentDecision,
12
+ QueryRequest,
13
+ )
14
+ from gaard_llm.openai_compatible.client import OpenAICompatibleClient
15
+ from gaard_llm.providers.models import ChatCompletionRequest, ChatMessage
16
+
17
+
18
class IntentPromptCompiler(Protocol):
    """Structural type for objects that turn a query request into a prompt."""

    def compile(self, request: QueryRequest) -> CompiledPrompt:
        """Compile ``request`` into a ``CompiledPrompt``."""
        ...
21
+
22
+
23
class LlmQueryIntentClassifier:
    """Classify a query's intent by asking an OpenAI-compatible chat model."""

    def __init__(
        self,
        client: OpenAICompatibleClient,
        model: str,
        extra_body: dict[str, Any] | None = None,
        prompt_compiler: IntentPromptCompiler | None = None,
    ) -> None:
        """Store the collaborators.

        A missing (or otherwise falsy) ``extra_body`` becomes an empty dict
        and a missing ``prompt_compiler`` becomes an
        ``IntentClassificationPromptCompiler``.
        """
        if not prompt_compiler:
            prompt_compiler = IntentClassificationPromptCompiler()

        self.client = client
        self.model = model
        self.extra_body = extra_body if extra_body else {}
        self.prompt_compiler = prompt_compiler

    def classify(self, request: QueryRequest) -> QueryIntentClassification:
        """Send the compiled intent prompt to the model and parse its reply."""
        prompt = self.prompt_compiler.compile(request=request)

        conversation = [
            ChatMessage(role="system", content=prompt.system_prompt),
            ChatMessage(role="user", content=prompt.user_prompt),
        ]
        completion_request = ChatCompletionRequest(
            model=self.model,
            temperature=0.0,
            extra_body=self.extra_body,
            messages=conversation,
        )
        completion = self.client.create_chat_completion(completion_request)

        return parse_query_intent_classification(completion.content)
58
+
59
+
60
def parse_query_intent_classification(value: str) -> QueryIntentClassification:
    """Turn raw model output into a ``QueryIntentClassification``.

    Thinking blocks are removed and the text is trimmed. Text that is not
    valid JSON is treated as a bare decision label. Valid JSON that is not an
    object yields a default-constructed classification.
    """
    text = remove_thinking_blocks(value).strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Not JSON: interpret the whole reply as the decision itself.
        parsed = {"decision": text}

    if isinstance(parsed, dict):
        raw_reason = parsed.get("reason")
        return QueryIntentClassification(
            decision=parse_query_intent_decision(parsed.get("decision")),
            confidence=parse_confidence(parsed.get("confidence")),
            reason=str(raw_reason) if raw_reason else "",
            model_response=parsed,
        )

    return QueryIntentClassification()
77
+
78
+
79
def parse_query_intent_decision(value: object) -> QueryIntentDecision:
    """Map a loosely formatted decision label onto a ``QueryIntentDecision``.

    The label is trimmed, lower-cased, and spaces and hyphens become
    underscores. Known short aliases are tried first, then the enum values
    themselves. Non-strings and unrecognized labels yield ``AMBIGUOUS``.
    """
    if isinstance(value, str):
        key = value.strip().lower().replace(" ", "_").replace("-", "_")

        shortcuts = {
            "read_only": QueryIntentDecision.READ_ONLY_DATA_QUESTION,
            "readonly": QueryIntentDecision.READ_ONLY_DATA_QUESTION,
            "select": QueryIntentDecision.READ_ONLY_DATA_QUESTION,
            "write": QueryIntentDecision.WRITE_OR_MUTATION_REQUEST,
            "mutation": QueryIntentDecision.WRITE_OR_MUTATION_REQUEST,
            "unsafe": QueryIntentDecision.WRITE_OR_MUTATION_REQUEST,
            "non_data": QueryIntentDecision.NON_DATA_REQUEST,
            "not_data": QueryIntentDecision.NON_DATA_REQUEST,
        }

        if key in shortcuts:
            return shortcuts[key]

        # Fall back to an exact match against the enum's own values.
        return next(
            (member for member in QueryIntentDecision if key == member.value),
            QueryIntentDecision.AMBIGUOUS,
        )

    return QueryIntentDecision.AMBIGUOUS
104
+
105
+
106
def parse_confidence(value: object) -> float:
    """Coerce a model-supplied confidence into a float within ``[0.0, 1.0]``.

    Args:
        value: Anything the model returned for "confidence" (number, numeric
            string, ``None``, ...).

    Returns:
        The value clamped to ``[0.0, 1.0]``. Values that cannot be converted
        to a finite-or-infinite float, and NaN, yield ``0.0``.
    """
    import math

    try:
        confidence = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large for a float (json.loads can yield one).
        return 0.0

    # NaN compares false with everything, so min/max clamping would let it
    # come out as 1.0 (full confidence). Treat it as "no confidence" instead.
    if math.isnan(confidence):
        return 0.0

    return max(0.0, min(1.0, confidence))
@@ -0,0 +1,14 @@
1
+ from gaard_core.query_pipeline.models import (
2
+ QueryIntentClassification,
3
+ QueryIntentDecision,
4
+ QueryRequest,
5
+ )
6
+
7
+
8
class MockQueryIntentClassifier:
    """Stand-in classifier that treats every request as a read-only data question."""

    def classify(self, request: QueryRequest) -> QueryIntentClassification:
        """Return a fixed read-only decision with confidence 1.0; ``request`` is not inspected."""
        verdict = QueryIntentDecision.READ_ONLY_DATA_QUESTION
        explanation = "Mock intent classifier allows the request."

        return QueryIntentClassification(
            decision=verdict,
            confidence=1.0,
            reason=explanation,
        )
File without changes
@@ -0,0 +1,85 @@
1
+ from typing import Any, Protocol
2
+
3
+ from gaard_core.prompt_compiler.models import CompiledPrompt, SqlGenerationPromptRequest
4
+ from gaard_core.prompt_compiler.sql_generation_prompt import SqlGenerationPromptCompiler
5
+ from gaard_core.llm_output import remove_thinking_blocks
6
+ from gaard_core.query_pipeline.models import GeneratedSql, QueryRequest
7
+ from gaard_llm.openai_compatible.client import OpenAICompatibleClient
8
+ from gaard_llm.providers.models import ChatCompletionRequest, ChatMessage
9
+
10
+
11
class SqlPromptCompiler(Protocol):
    """Structural type for objects that compile a SQL-generation prompt."""

    def compile(self, request: SqlGenerationPromptRequest) -> CompiledPrompt:
        """Compile ``request`` into a ``CompiledPrompt``."""
        ...
14
+
15
+
16
class LlmSqlGenerator:
    """Generate SQL for a question by prompting an OpenAI-compatible chat model."""

    def __init__(
        self,
        client: OpenAICompatibleClient,
        model: str,
        formatted_schema: str,
        dialect: str = "sqlite",
        max_rows: int = 100,
        extra_body: dict[str, Any] | None = None,
        prompt_compiler: SqlPromptCompiler | None = None,
    ) -> None:
        """Store the collaborators and prompt settings.

        A missing ``extra_body`` becomes an empty dict and a missing
        ``prompt_compiler`` becomes a ``SqlGenerationPromptCompiler``.
        """
        self.client = client
        self.model = model
        self.formatted_schema = formatted_schema
        self.dialect = dialect
        self.max_rows = max_rows
        self.extra_body = extra_body or {}
        self.prompt_compiler = prompt_compiler or SqlGenerationPromptCompiler()

    def generate(self, request: QueryRequest) -> GeneratedSql:
        """Ask the model for SQL answering ``request.question``.

        Returns:
            A ``GeneratedSql`` holding the cleaned model output. Confidence
            is always reported as 0.0 because the model does not supply one.
        """
        compiled_prompt = self.prompt_compiler.compile(
            SqlGenerationPromptRequest(
                question=request.question,
                formatted_schema=self.formatted_schema,
                dialect=self.dialect,
                max_rows=self.max_rows,
            )
        )

        response = self.client.create_chat_completion(
            ChatCompletionRequest(
                model=self.model,
                temperature=0.0,
                extra_body=self.extra_body,
                messages=[
                    ChatMessage(
                        role="system",
                        content=compiled_prompt.system_prompt,
                    ),
                    ChatMessage(
                        role="user",
                        content=compiled_prompt.user_prompt,
                    ),
                ],
            )
        )

        sql = self._clean_sql(response.content)

        return GeneratedSql(
            sql=sql,
            confidence=0.0,
            assumptions=[
                "SQL generated by OpenAI-compatible LLM provider.",
            ],
        )

    def _clean_sql(self, value: str) -> str:
        """Strip thinking blocks, surrounding whitespace and markdown fences.

        The text is trimmed before the fence checks so that a leading
        newline or trailing whitespace (common around fenced output) cannot
        hide the fence markers. An optional ``sql`` language tag after the
        opening fence is removed regardless of letter case.
        """
        cleaned = remove_thinking_blocks(value).strip()

        if cleaned.startswith("```"):
            cleaned = cleaned.removeprefix("```")
            # Drop the optional language tag ("sql", "SQL", ...).
            if cleaned[:3].lower() == "sql":
                cleaned = cleaned[3:]
            cleaned = cleaned.strip()

        if cleaned.endswith("```"):
            cleaned = cleaned.removesuffix("```").strip()

        return cleaned
@@ -0,0 +1,33 @@
1
+ from gaard_core.query_pipeline.models import GeneratedSql, QueryRequest
2
+
3
+
4
class MockSqlGenerator:
    """Keyword-driven stand-in that returns canned SQL for demo questions."""

    def generate(self, request: QueryRequest) -> GeneratedSql:
        """Return the canned query of the first rule whose keyword occurs in the question.

        Matching is done on the lower-cased question and rules are tried in
        order; when none matches, a trivial fallback query is returned.
        """
        text = request.question.lower()

        # (keywords, sql, confidence, assumption) - order matters.
        rules = (
            (
                ("aktywn", "active"),
                "SELECT COUNT(*) AS active_patients_count FROM patients WHERE status = 'active'",
                0.95,
                "Using demo patients table and status = active.",
            ),
            (
                ("pacjent", "patient"),
                "SELECT COUNT(*) AS patients_count FROM patients",
                0.95,
                "Using demo patients table.",
            ),
            (
                ("wizyt", "appointment"),
                "SELECT COUNT(*) AS appointments_count FROM appointments",
                0.9,
                "Using demo appointments table.",
            ),
        )

        for keywords, statement, score, note in rules:
            if any(keyword in text for keyword in keywords):
                return GeneratedSql(
                    sql=statement,
                    confidence=score,
                    assumptions=[note],
                )

        return GeneratedSql(
            sql="SELECT 1 AS value",
            confidence=0.5,
            assumptions=["Fallback mock query."],
        )