gaard-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gaard_core-0.1.0/PKG-INFO +23 -0
- gaard_core-0.1.0/README.md +10 -0
- gaard_core-0.1.0/pyproject.toml +28 -0
- gaard_core-0.1.0/setup.cfg +4 -0
- gaard_core-0.1.0/src/gaard_core/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/audit/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/errors.py +79 -0
- gaard_core-0.1.0/src/gaard_core/evaluation/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/execution/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/execution/mock_executor.py +25 -0
- gaard_core-0.1.0/src/gaard_core/investigation/__init__.py +25 -0
- gaard_core-0.1.0/src/gaard_core/investigation/llm_readiness_agent.py +220 -0
- gaard_core-0.1.0/src/gaard_core/investigation/loop.py +83 -0
- gaard_core-0.1.0/src/gaard_core/investigation/mock_readiness_agent.py +20 -0
- gaard_core-0.1.0/src/gaard_core/investigation/models.py +62 -0
- gaard_core-0.1.0/src/gaard_core/json_utils.py +58 -0
- gaard_core-0.1.0/src/gaard_core/llm_output.py +12 -0
- gaard_core-0.1.0/src/gaard_core/policy_engine/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/intent_classification_prompt.py +58 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/investigation_readiness_prompt.py +84 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/models.py +19 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/result_classification_prompt.py +62 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/result_interpretation_prompt.py +73 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/schema_formatter.py +43 -0
- gaard_core-0.1.0/src/gaard_core/prompt_compiler/sql_generation_prompt.py +105 -0
- gaard_core-0.1.0/src/gaard_core/query_intent/__init__.py +1 -0
- gaard_core-0.1.0/src/gaard_core/query_intent/llm_classifier.py +112 -0
- gaard_core-0.1.0/src/gaard_core/query_intent/mock_classifier.py +14 -0
- gaard_core-0.1.0/src/gaard_core/query_pipeline/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/query_pipeline/llm_sql_generator.py +85 -0
- gaard_core-0.1.0/src/gaard_core/query_pipeline/mock_sql_generator.py +33 -0
- gaard_core-0.1.0/src/gaard_core/query_pipeline/models.py +57 -0
- gaard_core-0.1.0/src/gaard_core/query_pipeline/pipeline.py +124 -0
- gaard_core-0.1.0/src/gaard_core/result_classifier/__init__.py +1 -0
- gaard_core-0.1.0/src/gaard_core/result_classifier/llm_classifier.py +87 -0
- gaard_core-0.1.0/src/gaard_core/result_classifier/mock_classifier.py +10 -0
- gaard_core-0.1.0/src/gaard_core/result_interpreter/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/result_interpreter/llm_interpreter.py +66 -0
- gaard_core-0.1.0/src/gaard_core/result_interpreter/mock_interpreter.py +25 -0
- gaard_core-0.1.0/src/gaard_core/schema/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/schema/cache.py +59 -0
- gaard_core-0.1.0/src/gaard_core/schema/context.py +40 -0
- gaard_core-0.1.0/src/gaard_core/schema/models.py +27 -0
- gaard_core-0.1.0/src/gaard_core/security/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/semantic_layer/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/sql_validator/__init__.py +0 -0
- gaard_core-0.1.0/src/gaard_core/sql_validator/select_only.py +37 -0
- gaard_core-0.1.0/src/gaard_core.egg-info/PKG-INFO +23 -0
- gaard_core-0.1.0/src/gaard_core.egg-info/SOURCES.txt +66 -0
- gaard_core-0.1.0/src/gaard_core.egg-info/dependency_links.txt +1 -0
- gaard_core-0.1.0/src/gaard_core.egg-info/requires.txt +7 -0
- gaard_core-0.1.0/src/gaard_core.egg-info/top_level.txt +1 -0
- gaard_core-0.1.0/tests/test_investigation_readiness.py +150 -0
- gaard_core-0.1.0/tests/test_json_utils.py +30 -0
- gaard_core-0.1.0/tests/test_llm_output.py +13 -0
- gaard_core-0.1.0/tests/test_llm_query_intent_classifier.py +85 -0
- gaard_core-0.1.0/tests/test_llm_result_classifier.py +75 -0
- gaard_core-0.1.0/tests/test_llm_result_interpreter.py +103 -0
- gaard_core-0.1.0/tests/test_llm_sql_generator.py +74 -0
- gaard_core-0.1.0/tests/test_query_pipeline.py +71 -0
- gaard_core-0.1.0/tests/test_result_classification_prompt_compiler.py +17 -0
- gaard_core-0.1.0/tests/test_result_interpretation_prompt_compiler.py +25 -0
- gaard_core-0.1.0/tests/test_schema_context_cache.py +47 -0
- gaard_core-0.1.0/tests/test_schema_context_service.py +38 -0
- gaard_core-0.1.0/tests/test_schema_prompt_formatter.py +52 -0
- gaard_core-0.1.0/tests/test_sql_generation_prompt_compiler.py +41 -0
- gaard_core-0.1.0/tests/test_sql_validator.py +36 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gaard-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Core GAARD query pipeline, prompt compiler, policies and SQL validation
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: pydantic>=2.7.0
|
|
8
|
+
Requires-Dist: sqlglot>=25.0.0
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
|
11
|
+
Requires-Dist: ruff>=0.5.0; extra == "dev"
|
|
12
|
+
Requires-Dist: mypy>=1.10.0; extra == "dev"
|
|
13
|
+
|
|
14
|
+
# GAARD - Governed AI Access to Relational Data
|
|
15
|
+
|
|
16
|
+
GAARD is a self-hosted AI SQL Gateway for governed natural-language access to relational data.
|
|
17
|
+
|
|
18
|
+
GAARD allows applications and users to ask questions about relational databases using natural language while keeping SQL generation, validation, execution, prompts, connectors, and auditability under control.
|
|
19
|
+
|
|
20
|
+
For more information, see https://github.com/pkroliszewski/gaard
|
|
21
|
+
|
|
22
|
+
# This package
|
|
23
|
+
The gaard-core package provides the internal logic required by the other GAARD packages.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# GAARD - Governed AI Access to Relational Data
|
|
2
|
+
|
|
3
|
+
GAARD is a self-hosted AI SQL Gateway for governed natural-language access to relational data.
|
|
4
|
+
|
|
5
|
+
GAARD allows applications and users to ask questions about relational databases using natural language while keeping SQL generation, validation, execution, prompts, connectors, and auditability under control.
|
|
6
|
+
|
|
7
|
+
For more information, see https://github.com/pkroliszewski/gaard
|
|
8
|
+
|
|
9
|
+
# This package
|
|
10
|
+
The gaard-core package provides the internal logic required by the other GAARD packages.
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=69", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "gaard-core"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Core GAARD query pipeline, prompt compiler, policies and SQL validation"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"pydantic>=2.7.0",
|
|
13
|
+
"sqlglot>=25.0.0",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
[project.optional-dependencies]
|
|
17
|
+
dev = [
|
|
18
|
+
"pytest>=8.0.0",
|
|
19
|
+
"ruff>=0.5.0",
|
|
20
|
+
"mypy>=1.10.0",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[tool.ruff]
|
|
24
|
+
line-length = 100
|
|
25
|
+
target-version = "py311"
|
|
26
|
+
|
|
27
|
+
[tool.setuptools.packages.find]
|
|
28
|
+
where = ["src"]
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GaardError(Exception):
|
|
5
|
+
code = "GAARD_ERROR"
|
|
6
|
+
status_code = 500
|
|
7
|
+
|
|
8
|
+
def __init__(self, message: str | None = None) -> None:
|
|
9
|
+
self.message = message or "GAARD error."
|
|
10
|
+
super().__init__(self.message)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ConfigurationError(GaardError):
    """Configuration problem, reported as CONFIGURATION_ERROR with status 500."""

    status_code = 500
    code = "CONFIGURATION_ERROR"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SqlGenerationError(GaardError):
    """Failure while producing SQL, reported as SQL_GENERATION_ERROR with status 502."""

    status_code = 502
    code = "SQL_GENERATION_ERROR"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SqlValidationError(GaardError):
    """SQL was rejected by validation (SQL_VALIDATION_ERROR, status 400).

    Attributes:
        sql: The offending SQL text ("" when not supplied).
        error_detail: Extra free-text detail about the rejection.
        metadata: Arbitrary structured context; a fresh empty dict when none is given.
    """

    code = "SQL_VALIDATION_ERROR"
    status_code = 400

    def __init__(
        self,
        message: str | None = None,
        sql: str = "",
        error_detail: str = "",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        self.metadata = metadata if metadata else {}
        self.error_detail = error_detail
        self.sql = sql
        # The base class substitutes a generic message when this one is empty.
        super().__init__(message)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class QueryExecutionError(GaardError):
    """Executing a query failed (QUERY_EXECUTION_ERROR, status 400).

    Attributes:
        sql: The SQL that was being executed ("" when not supplied).
        error_detail: Extra free-text detail, e.g. a driver message.
        metadata: Arbitrary structured context; a fresh empty dict when none is given.
    """

    code = "QUERY_EXECUTION_ERROR"
    status_code = 400

    def __init__(
        self,
        message: str | None = None,
        sql: str = "",
        error_detail: str = "",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        self.metadata = metadata if metadata else {}
        self.error_detail = error_detail
        self.sql = sql
        # The base class substitutes a generic message when this one is empty.
        super().__init__(message)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class QueryPipelineStepError(GaardError):
    """A specific step of the query pipeline failed (status 502).

    Unlike its siblings, the error ``code`` can be chosen per instance via
    ``error_code``. The class-level ``code`` mirrors the default so that
    class-based lookups (``QueryPipelineStepError.code``) agree with a
    default-constructed instance instead of falling back to the inherited
    ``"GAARD_ERROR"``.

    Attributes:
        phase: Name of the pipeline phase that failed ("" when not supplied).
        sql: SQL involved in the failing step, if any.
        error_detail: Extra free-text detail.
        metadata: Arbitrary structured context; a fresh empty dict when none is given.
    """

    code = "QUERY_PIPELINE_STEP_ERROR"
    status_code = 502

    def __init__(
        self,
        message: str | None = None,
        phase: str = "",
        sql: str = "",
        error_code: str = "QUERY_PIPELINE_STEP_ERROR",
        error_detail: str = "",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        # Per-instance override of the class-level code.
        self.code = error_code
        self.phase = phase
        self.sql = sql
        self.error_detail = error_detail
        self.metadata = metadata or {}
        super().__init__(message)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class LlmProviderError(GaardError):
    """LLM provider failure, reported as LLM_PROVIDER_ERROR with status 502."""

    status_code = 502
    code = "LLM_PROVIDER_ERROR"
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from gaard_core.query_pipeline.models import QueryResult
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MockQueryExecutor:
    """Stand-in executor that returns canned results without touching a database."""

    def execute(self, sql: str) -> QueryResult:
        """Return a fixed one-row result chosen by a case-insensitive keyword match.

        SQL mentioning "patients" yields ``patients_count = 124``; anything else
        yields ``value = 1``.
        """
        if "patients" not in sql.lower():
            return QueryResult(columns=["value"], rows=[{"value": 1}])

        return QueryResult(
            columns=["patients_count"],
            rows=[{"patients_count": 124}],
        )
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Public API of the ``gaard_core.investigation`` subpackage.

``LlmInvestigationReadinessAgent`` is resolved lazily. Its module imports the
separate ``gaard_llm`` package, which is not among gaard-core's declared
dependencies (pydantic, sqlglot). Importing it eagerly would make
``import gaard_core.investigation`` fail whenever ``gaard_llm`` is absent, even
for callers that only need the loop, the models or the mock agent.
"""

from typing import TYPE_CHECKING, Any

from gaard_core.investigation.loop import InvestigationLoop
from gaard_core.investigation.mock_readiness_agent import MockInvestigationReadinessAgent
from gaard_core.investigation.models import (
    InvestigationContext,
    InvestigationIteration,
    InvestigationLoopConfig,
    InvestigationLoopResult,
    InvestigationReadinessDecision,
    InvestigationRoute,
    RequiredAnalysisTask,
)

if TYPE_CHECKING:
    from gaard_core.investigation.llm_readiness_agent import LlmInvestigationReadinessAgent

__all__ = [
    "InvestigationContext",
    "InvestigationIteration",
    "InvestigationLoop",
    "InvestigationLoopConfig",
    "InvestigationLoopResult",
    "InvestigationReadinessDecision",
    "InvestigationRoute",
    "LlmInvestigationReadinessAgent",
    "MockInvestigationReadinessAgent",
    "RequiredAnalysisTask",
]


def __getattr__(name: str) -> Any:
    """Import lazily exported names on first access (PEP 562).

    Raises:
        AttributeError: for names this package does not export.
        ImportError: when ``LlmInvestigationReadinessAgent`` is requested but its
            ``gaard_llm`` dependency is not installed.
    """
    if name == "LlmInvestigationReadinessAgent":
        from gaard_core.investigation.llm_readiness_agent import (
            LlmInvestigationReadinessAgent,
        )

        # Cache on the module so later lookups bypass this hook.
        globals()[name] = LlmInvestigationReadinessAgent
        return LlmInvestigationReadinessAgent

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__() -> list[str]:
    """Include lazily exported names in ``dir()`` output."""
    return sorted({*globals(), *__all__})
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Protocol
|
|
3
|
+
|
|
4
|
+
from gaard_core.investigation.models import (
|
|
5
|
+
InvestigationContext,
|
|
6
|
+
InvestigationReadinessDecision,
|
|
7
|
+
InvestigationRoute,
|
|
8
|
+
RequiredAnalysisTask,
|
|
9
|
+
)
|
|
10
|
+
from gaard_core.llm_output import remove_thinking_blocks
|
|
11
|
+
from gaard_core.prompt_compiler.investigation_readiness_prompt import (
|
|
12
|
+
InvestigationReadinessPromptCompiler,
|
|
13
|
+
)
|
|
14
|
+
from gaard_core.prompt_compiler.models import CompiledPrompt
|
|
15
|
+
from gaard_llm.openai_compatible.client import OpenAICompatibleClient
|
|
16
|
+
from gaard_llm.providers.models import ChatCompletionRequest, ChatMessage
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class InvestigationReadinessPromptCompilerProtocol(Protocol):
    """Structural type for anything that can build the readiness-assessment prompt."""

    def compile(self, context: InvestigationContext) -> CompiledPrompt:
        """Render *context* into a CompiledPrompt (system and user prompt text)."""
        ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LlmInvestigationReadinessAgent:
    """Readiness agent backed by an OpenAI-compatible chat-completion client.

    Each assessment compiles a system/user prompt from the context, sends it to
    ``model`` at temperature 0.0 and parses the reply with
    ``parse_investigation_readiness_decision``.
    """

    name = "llm_investigation_readiness"

    def __init__(
        self,
        client: OpenAICompatibleClient,
        model: str,
        extra_body: dict[str, Any] | None = None,
        prompt_compiler: InvestigationReadinessPromptCompilerProtocol | None = None,
    ) -> None:
        self.client = client
        self.model = model
        # Forwarded verbatim with every request; an empty dict when not given.
        self.extra_body = extra_body or {}
        self.prompt_compiler = prompt_compiler or InvestigationReadinessPromptCompiler()

    def assess(self, context: InvestigationContext) -> InvestigationReadinessDecision:
        """Ask the model whether *context* is ready for SQL generation."""
        prompt = self.prompt_compiler.compile(context=context)

        conversation = [
            ChatMessage(role="system", content=prompt.system_prompt),
            ChatMessage(role="user", content=prompt.user_prompt),
        ]
        request = ChatCompletionRequest(
            model=self.model,
            temperature=0.0,
            extra_body=self.extra_body,
            messages=conversation,
        )

        reply = self.client.create_chat_completion(request)
        return parse_investigation_readiness_decision(reply.content)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def parse_investigation_readiness_decision(value: str) -> InvestigationReadinessDecision:
    """Convert a raw readiness-agent reply into an InvestigationReadinessDecision.

    Processing steps:

    * Reasoning blocks are removed with ``remove_thinking_blocks``.
    * A Markdown code fence wrapping the whole reply (an opening line starting
      with three backticks, optionally followed by a language tag such as
      ``json``, and a closing line of three backticks) is unwrapped. Chat models
      frequently add such fences even when asked for bare JSON.
    * The remainder is decoded as JSON.

    The parser is fail-safe. A reply that is not valid JSON, or is not a JSON
    object, produces a not-ready ANALYSIS decision with 0.0 confidence. The raw
    text (or value) is stored under ``model_response["raw"]``.

    Route consistency rules:

    * An SQL route without ``ready_for_sql`` is downgraded to ANALYSIS.
    * An ANALYSIS route always forces ``ready_for_sql`` to False.
    """
    cleaned = remove_thinking_blocks(value).strip()

    try:
        payload = json.loads(_strip_json_code_fence(cleaned))
    except json.JSONDecodeError:
        return InvestigationReadinessDecision(
            ready_for_sql=False,
            route=InvestigationRoute.ANALYSIS,
            confidence=0.0,
            reason="Investigation readiness agent returned invalid JSON.",
            missing_information=["valid readiness JSON"],
            required_analysis=["Retry readiness assessment with a valid JSON response."],
            model_response={"raw": cleaned},
        )

    if not isinstance(payload, dict):
        return InvestigationReadinessDecision(
            ready_for_sql=False,
            route=InvestigationRoute.ANALYSIS,
            confidence=0.0,
            reason="Investigation readiness agent returned a non-object JSON value.",
            missing_information=["valid readiness JSON object"],
            required_analysis=["Retry readiness assessment with a JSON object response."],
            model_response={"raw": payload},
        )

    ready_for_sql = parse_bool(payload.get("ready_for_sql"))
    route = parse_route(payload.get("route"), ready_for_sql)

    # Never route to SQL unless the model also asserted readiness.
    if route == InvestigationRoute.SQL and not ready_for_sql:
        route = InvestigationRoute.ANALYSIS

    # An explicit analysis route overrides a stray ready_for_sql=true.
    if route == InvestigationRoute.ANALYSIS:
        ready_for_sql = False

    missing_information = parse_string_list(payload.get("missing_information"))
    required_analysis = parse_string_list(payload.get("required_analysis"))

    return InvestigationReadinessDecision(
        ready_for_sql=ready_for_sql,
        route=route,
        confidence=parse_confidence(payload.get("confidence")),
        reason=str(payload.get("reason") or ""),
        missing_information=missing_information,
        required_analysis=required_analysis,
        required_analysis_tasks=parse_required_analysis_tasks(
            payload.get("required_analysis_tasks"),
            missing_information,
            required_analysis,
        ),
        assumptions=parse_string_list(payload.get("assumptions")),
        model_response=payload,
    )


def _strip_json_code_fence(text: str) -> str:
    """Return the inside of a Markdown code fence that wraps all of *text*.

    Only a fence whose first line starts with three backticks and whose last
    line is exactly three backticks is removed. Any other text is returned
    unchanged.
    """
    lines = text.splitlines()
    if len(lines) >= 2 and lines[0].startswith("```") and lines[-1].strip() == "```":
        return "\n".join(lines[1:-1]).strip()
    return text
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def parse_route(value: object, ready_for_sql: bool) -> InvestigationRoute:
    """Map the model's free-form ``route`` field onto an InvestigationRoute.

    Matching is case-insensitive, ignores surrounding whitespace, and treats
    hyphens and spaces as underscores. Unrecognised or non-string values fall
    back to the route implied by *ready_for_sql*.
    """
    fallback = InvestigationRoute.SQL if ready_for_sql else InvestigationRoute.ANALYSIS
    if not isinstance(value, str):
        return fallback

    key = value.strip().lower().replace("-", "_").replace(" ", "_")
    aliases = {
        "sql": InvestigationRoute.SQL,
        "ready": InvestigationRoute.SQL,
        "ready_for_sql": InvestigationRoute.SQL,
        "analysis": InvestigationRoute.ANALYSIS,
        "analyze": InvestigationRoute.ANALYSIS,
        "requires_analysis": InvestigationRoute.ANALYSIS,
    }
    return aliases.get(key, fallback)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def parse_bool(value: object) -> bool:
|
|
131
|
+
if isinstance(value, bool):
|
|
132
|
+
return value
|
|
133
|
+
|
|
134
|
+
if isinstance(value, str):
|
|
135
|
+
return value.strip().lower() in {"true", "yes", "tak", "1"}
|
|
136
|
+
|
|
137
|
+
return False
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def parse_confidence(value: object) -> float:
    """Coerce a model-reported confidence into the closed range [0.0, 1.0].

    Numbers and numeric strings are converted with ``float`` and clamped. The
    following yield 0.0, so malformed output can never grant readiness:

    * missing or non-numeric values;
    * integers too large to convert to a float;
    * NaN.
    """
    import math

    try:
        confidence = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0

    # NaN must be rejected explicitly: min(1.0, nan) evaluates to 1.0, which
    # would silently turn an invalid value into full confidence. (json.loads
    # accepts a bare NaN token, so this can come straight from the model.)
    if math.isnan(confidence):
        return 0.0

    return max(0.0, min(1.0, confidence))
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def parse_string_list(value: object) -> list[str]:
    """Normalize a JSON array into a list of non-empty, trimmed strings.

    Non-list input yields []. ``None`` entries are skipped; other entries are
    converted with ``str`` and dropped if blank after trimming.
    """
    if not isinstance(value, list):
        return []

    trimmed = (str(entry).strip() for entry in value if entry is not None)
    return [text for text in trimmed if text]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def parse_required_analysis_tasks(
    value: object,
    missing_information: list[str],
    required_analysis: list[str],
) -> list[RequiredAnalysisTask]:
    """Build the structured analysis-task list for a readiness decision.

    Prefer the model's own ``required_analysis_tasks`` array: dict items are
    parsed and tasks with an empty ``required_analysis`` are discarded. When that
    yields nothing usable (or *value* is not a list), fall back to pairing the
    flat *missing_information* / *required_analysis* lists.
    """
    if not isinstance(value, list):
        return required_analysis_tasks_from_lists(missing_information, required_analysis)

    candidates = (parse_required_analysis_task(item) for item in value if isinstance(item, dict))
    structured = [task for task in candidates if task.required_analysis]
    return structured or required_analysis_tasks_from_lists(
        missing_information, required_analysis
    )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def parse_required_analysis_task(value: dict[str, object]) -> RequiredAnalysisTask:
    """Turn one model-supplied task object into a RequiredAnalysisTask.

    Text fields are stringified and trimmed, with missing or falsy values
    becoming "". The category is passed through ``normalize_analysis_category``.
    """

    def text(key: str) -> str:
        return str(value.get(key) or "").strip()

    return RequiredAnalysisTask(
        missing_information=text("missing_information"),
        required_analysis=text("required_analysis"),
        category=normalize_analysis_category(value.get("category")),
        expected_output=text("expected_output"),
    )
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def required_analysis_tasks_from_lists(
    missing_information: list[str],
    required_analysis: list[str],
) -> list[RequiredAnalysisTask]:
    """Pair two flat lists positionally into RequiredAnalysisTask objects.

    One task is produced per *required_analysis* entry. The entry at the same
    index in *missing_information* is attached when it exists, otherwise "".
    Surplus *missing_information* entries are ignored.
    """
    shortfall = max(0, len(required_analysis) - len(missing_information))
    padded_missing = [*missing_information, *([""] * shortfall)]

    return [
        RequiredAnalysisTask(missing_information=gap, required_analysis=question)
        for question, gap in zip(required_analysis, padded_missing)
    ]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def normalize_analysis_category(value: object) -> str:
    """Canonicalize a task category to one of the known labels.

    The value is stringified (falsy values become "unknown"), trimmed,
    lower-cased, and hyphens/spaces become underscores. Anything outside the
    known set maps to "unknown".
    """
    known = frozenset(
        {
            "dictionary_value",
            "relationship_logic",
            "filter_logic",
            "aggregation_logic",
            "entity_mapping",
            "unknown",
        }
    )

    candidate = str(value or "unknown").strip().lower()
    for separator in ("-", " "):
        candidate = candidate.replace(separator, "_")

    return candidate if candidate in known else "unknown"
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from typing import Protocol
|
|
2
|
+
|
|
3
|
+
from gaard_core.investigation.models import (
|
|
4
|
+
InvestigationContext,
|
|
5
|
+
InvestigationIteration,
|
|
6
|
+
InvestigationLoopConfig,
|
|
7
|
+
InvestigationLoopResult,
|
|
8
|
+
InvestigationReadinessDecision,
|
|
9
|
+
InvestigationRoute,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InvestigationReadinessAgent(Protocol):
    """Structural interface for agents consulted by InvestigationLoop."""

    # Identifier recorded on every InvestigationIteration the agent produces.
    name: str

    def assess(self, context: InvestigationContext) -> InvestigationReadinessDecision:
        """Return a readiness verdict for *context*."""
        ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class InvestigationLoop:
    """Ask a readiness agent whether a question may proceed to SQL generation.

    Each pass asks ``readiness_agent`` to assess the context. The decision is
    normalized against ``config.readiness_confidence_threshold`` and recorded
    as an InvestigationIteration. The loop stops at the first decision routed
    to SQL. If none is reached within ``config.max_iterations`` passes, the
    result is routed to ANALYSIS.
    """

    def __init__(
        self,
        readiness_agent: InvestigationReadinessAgent,
        config: InvestigationLoopConfig | None = None,
    ) -> None:
        self.readiness_agent = readiness_agent
        self.config = config or InvestigationLoopConfig()

    def run(self, context: InvestigationContext) -> InvestigationLoopResult:
        """Assess *context* up to ``max_iterations`` times and return the routing result."""
        iterations: list[InvestigationIteration] = []

        for iteration_number in range(1, self.config.max_iterations + 1):
            decision = self._normalize_decision(self.readiness_agent.assess(context))
            iterations.append(
                InvestigationIteration(
                    iteration=iteration_number,
                    agent=self.readiness_agent.name,
                    decision=decision,
                )
            )

            if decision.route == InvestigationRoute.SQL:
                return self._build_result(InvestigationRoute.SQL, iterations)

        # Reached only when no pass produced an SQL-ready decision.
        return self._build_result(InvestigationRoute.ANALYSIS, iterations)

    def _build_result(
        self,
        route: InvestigationRoute,
        iterations: list[InvestigationIteration],
    ) -> InvestigationLoopResult:
        """Package the final route, the loop configuration and the recorded iterations."""
        return InvestigationLoopResult(
            route=route,
            ready_for_sql=route == InvestigationRoute.SQL,
            max_iterations=self.config.max_iterations,
            confidence_threshold=self.config.readiness_confidence_threshold,
            iterations=iterations,
        )

    def _normalize_decision(
        self,
        decision: InvestigationReadinessDecision,
    ) -> InvestigationReadinessDecision:
        """Re-derive readiness and route from the agent's claim and the confidence threshold.

        A decision is SQL-ready only if the agent set ``ready_for_sql`` AND its
        confidence meets the configured threshold; the route is set to match.
        """
        ready = (
            decision.ready_for_sql
            and decision.confidence >= self.config.readiness_confidence_threshold
        )
        route = InvestigationRoute.SQL if ready else InvestigationRoute.ANALYSIS

        return decision.model_copy(
            update={
                "ready_for_sql": ready,
                "route": route,
            }
        )
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from gaard_core.investigation.models import (
|
|
2
|
+
InvestigationContext,
|
|
3
|
+
InvestigationReadinessDecision,
|
|
4
|
+
InvestigationRoute,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MockInvestigationReadinessAgent:
    """Readiness agent that always returns one preconfigured decision.

    Without an explicit decision it approves the SQL route with full confidence,
    which lets the normal pipeline run in tests and local setups.
    """

    name = "mock_investigation_readiness"

    def __init__(self, decision: InvestigationReadinessDecision | None = None) -> None:
        self.decision = decision or self._allow_sql_decision()

    @staticmethod
    def _allow_sql_decision() -> InvestigationReadinessDecision:
        """Build the default 'ready for SQL' decision."""
        return InvestigationReadinessDecision(
            ready_for_sql=True,
            route=InvestigationRoute.SQL,
            confidence=1.0,
            reason="Mock readiness agent allows the normal SQL pipeline.",
        )

    def assess(self, context: InvestigationContext) -> InvestigationReadinessDecision:
        """Ignore *context* and return the configured decision."""
        return self.decision
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from enum import StrEnum
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class InvestigationRoute(StrEnum):
|
|
8
|
+
SQL = "sql"
|
|
9
|
+
ANALYSIS = "analysis"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InvestigationContext(BaseModel):
    """Everything a readiness agent receives about one natural-language question."""

    # The user's question; validation rejects an empty string.
    question: str = Field(min_length=1)
    datasource_id: str = "default"
    user_id: str = "local-admin"
    # Schema description already rendered as text (presumably by the schema
    # prompt formatter - TODO confirm).
    formatted_schema: str = ""
    business_logic: str = ""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RequiredAnalysisTask(BaseModel):
    """One follow-up analysis item attached to a readiness decision."""

    # What is not yet known.
    missing_information: str = ""
    # The analysis/question that would resolve it.
    required_analysis: str = ""
    # Free-form label; not validated here (the LLM parser normalizes it).
    category: str = "unknown"
    expected_output: str = ""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class InvestigationReadinessDecision(BaseModel):
    """A readiness verdict produced by a readiness agent.

    The defaults describe a not-ready decision routed to ANALYSIS with zero
    confidence.
    """

    # ``model_response`` starts with pydantic's reserved "model_" prefix. Pydantic
    # 2.7-2.9 (allowed by this package's ``pydantic>=2.7.0`` requirement) emits
    # a UserWarning for such fields at class creation unless the protected
    # namespace check is disabled.
    model_config = {"protected_namespaces": ()}

    ready_for_sql: bool = False
    route: InvestigationRoute = InvestigationRoute.ANALYSIS
    # Not range-checked by the model itself; producers are expected to supply [0, 1].
    confidence: float = 0.0
    reason: str = ""
    # Parallel free-text lists: what is unknown, and which analysis would resolve it.
    missing_information: list[str] = Field(default_factory=list)
    required_analysis: list[str] = Field(default_factory=list)
    # Structured counterpart of the two lists above.
    required_analysis_tasks: list[RequiredAnalysisTask] = Field(default_factory=list)
    assumptions: list[str] = Field(default_factory=list)
    # Raw payload returned by the agent's model, kept alongside the parsed fields.
    model_response: dict[str, Any] = Field(default_factory=dict)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class InvestigationIteration(BaseModel):
    """Record of one readiness assessment performed by the investigation loop."""

    # 1-based pass number within a run.
    iteration: int
    # ``name`` of the agent that produced the decision.
    agent: str
    # The decision after threshold normalization.
    decision: InvestigationReadinessDecision
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class InvestigationLoopConfig(BaseModel):
    """Tuning parameters for InvestigationLoop."""

    # Upper bound on readiness assessments per run; must be at least 1.
    max_iterations: int = Field(default=1, ge=1)
    # Minimum confidence a ready_for_sql decision needs to be routed to SQL.
    readiness_confidence_threshold: float = Field(default=0.85, ge=0.0, le=1.0)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class InvestigationLoopResult(BaseModel):
    """Outcome of an investigation loop run, plus the configuration it ran with."""

    route: InvestigationRoute
    ready_for_sql: bool
    max_iterations: int
    confidence_threshold: float
    iterations: list[InvestigationIteration] = Field(default_factory=list)

    @property
    def final_decision(self) -> InvestigationReadinessDecision | None:
        """Decision of the last recorded iteration, or None when nothing was recorded."""
        return self.iterations[-1].decision if self.iterations else None
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import math
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from datetime import date, datetime, time
|
|
5
|
+
from decimal import Decimal
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def to_jsonable(value: Any) -> Any:
|
|
10
|
+
if value is None or isinstance(value, str | int | float | bool):
|
|
11
|
+
return value
|
|
12
|
+
|
|
13
|
+
if isinstance(value, Decimal):
|
|
14
|
+
return _decimal_to_jsonable(value)
|
|
15
|
+
|
|
16
|
+
if isinstance(value, datetime | date | time):
|
|
17
|
+
return value.isoformat()
|
|
18
|
+
|
|
19
|
+
if isinstance(value, bytes | bytearray | memoryview):
|
|
20
|
+
return _bytes_to_jsonable(value)
|
|
21
|
+
|
|
22
|
+
if isinstance(value, Mapping):
|
|
23
|
+
return {str(key): to_jsonable(item) for key, item in value.items()}
|
|
24
|
+
|
|
25
|
+
if isinstance(value, list | tuple | set | frozenset):
|
|
26
|
+
return [to_jsonable(item) for item in value]
|
|
27
|
+
|
|
28
|
+
if hasattr(value, "model_dump"):
|
|
29
|
+
return to_jsonable(value.model_dump())
|
|
30
|
+
|
|
31
|
+
return str(value)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def json_dumps(value: Any, **kwargs: Any) -> str:
    """Serialize *value* to a JSON string after normalizing it with ``to_jsonable``.

    Keyword arguments are forwarded unchanged to ``json.dumps``.
    """
    normalized = to_jsonable(value)
    return json.dumps(normalized, **kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _decimal_to_jsonable(value: Decimal) -> int | float | str:
|
|
39
|
+
if not value.is_finite():
|
|
40
|
+
return str(value)
|
|
41
|
+
|
|
42
|
+
if value == value.to_integral_value():
|
|
43
|
+
return int(value)
|
|
44
|
+
|
|
45
|
+
as_float = float(value)
|
|
46
|
+
if math.isfinite(as_float):
|
|
47
|
+
return as_float
|
|
48
|
+
|
|
49
|
+
return str(value)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _bytes_to_jsonable(value: bytes | bytearray | memoryview) -> str:
|
|
53
|
+
raw = bytes(value)
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
return raw.decode("utf-8")
|
|
57
|
+
except UnicodeDecodeError:
|
|
58
|
+
return raw.hex()
|
|
File without changes
|
|
File without changes
|