gaard-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gaard_api/__init__.py +0 -0
- gaard_api/admin/__init__.py +1 -0
- gaard_api/admin/database.py +513 -0
- gaard_api/admin/defaults.py +271 -0
- gaard_api/admin/models.py +253 -0
- gaard_api/admin/prompt_runtime.py +237 -0
- gaard_api/admin/security.py +45 -0
- gaard_api/admin/services.py +2142 -0
- gaard_api/admin-web/assets/main.js +2056 -0
- gaard_api/admin-web/assets/styles.css +1041 -0
- gaard_api/admin-web/index.html +13 -0
- gaard_api/admin-web/src/main.ts +2343 -0
- gaard_api/api/__init__.py +0 -0
- gaard_api/api/v1/__init__.py +0 -0
- gaard_api/api/v1/admin.py +2174 -0
- gaard_api/api/v1/prompts.py +55 -0
- gaard_api/api/v1/query.py +1109 -0
- gaard_api/api/v1/schema.py +60 -0
- gaard_api/cli.py +19 -0
- gaard_api/cli_commands.py +18 -0
- gaard_api/core/__init__.py +0 -0
- gaard_api/core/error_handlers.py +18 -0
- gaard_api/core/schema_cache.py +7 -0
- gaard_api/core/settings.py +103 -0
- gaard_api/dependencies/__init__.py +0 -0
- gaard_api/main.py +53 -0
- gaard_api/schemas/__init__.py +0 -0
- gaard_api/server_cli.py +25 -0
- gaard_api/services/__init__.py +0 -0
- gaard_api-0.1.0.dist-info/METADATA +44 -0
- gaard_api-0.1.0.dist-info/RECORD +34 -0
- gaard_api-0.1.0.dist-info/WHEEL +5 -0
- gaard_api-0.1.0.dist-info/entry_points.txt +7 -0
- gaard_api-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1109 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from queue import Queue
|
|
5
|
+
from threading import Thread
|
|
6
|
+
from typing import Any, Callable
|
|
7
|
+
|
|
8
|
+
from fastapi import APIRouter
|
|
9
|
+
from fastapi.responses import StreamingResponse
|
|
10
|
+
|
|
11
|
+
from gaard_connectors.sqlalchemy.executor import SQLAlchemyQueryExecutor
|
|
12
|
+
from gaard_connectors.sqlalchemy.introspector import SQLAlchemySchemaIntrospector
|
|
13
|
+
from gaard_core.errors import (
|
|
14
|
+
ConfigurationError,
|
|
15
|
+
LlmProviderError,
|
|
16
|
+
QueryExecutionError,
|
|
17
|
+
QueryPipelineStepError,
|
|
18
|
+
SqlValidationError,
|
|
19
|
+
)
|
|
20
|
+
from gaard_core.investigation import (
|
|
21
|
+
InvestigationContext,
|
|
22
|
+
InvestigationLoop,
|
|
23
|
+
InvestigationLoopConfig,
|
|
24
|
+
InvestigationLoopResult,
|
|
25
|
+
InvestigationRoute,
|
|
26
|
+
LlmInvestigationReadinessAgent,
|
|
27
|
+
MockInvestigationReadinessAgent,
|
|
28
|
+
RequiredAnalysisTask,
|
|
29
|
+
)
|
|
30
|
+
from gaard_core.query_intent.llm_classifier import LlmQueryIntentClassifier
|
|
31
|
+
from gaard_core.query_intent.mock_classifier import MockQueryIntentClassifier
|
|
32
|
+
from gaard_core.query_pipeline.llm_sql_generator import LlmSqlGenerator
|
|
33
|
+
from gaard_core.query_pipeline.mock_sql_generator import MockSqlGenerator
|
|
34
|
+
from gaard_core.query_pipeline.models import (
|
|
35
|
+
OutputClassification,
|
|
36
|
+
QueryIntentClassification,
|
|
37
|
+
QueryIntentDecision,
|
|
38
|
+
QueryMode,
|
|
39
|
+
QueryRequest,
|
|
40
|
+
QueryResponse,
|
|
41
|
+
)
|
|
42
|
+
from gaard_core.query_pipeline.pipeline import QueryPipeline
|
|
43
|
+
from gaard_core.result_classifier.llm_classifier import LlmResultClassifier
|
|
44
|
+
from gaard_core.result_classifier.mock_classifier import MockResultClassifier
|
|
45
|
+
from gaard_core.result_interpreter.llm_interpreter import LlmResultInterpreter
|
|
46
|
+
from gaard_core.result_interpreter.mock_interpreter import MockResultInterpreter
|
|
47
|
+
from gaard_core.schema.context import SchemaContextService
|
|
48
|
+
from gaard_core.sql_validator.select_only import SelectOnlySqlValidator
|
|
49
|
+
from gaard_llm.openai_compatible.client import OpenAICompatibleClient
|
|
50
|
+
|
|
51
|
+
from gaard_api.admin.models import DatasourceConnector, DatasourceSchemaCache
|
|
52
|
+
from gaard_api.admin.prompt_runtime import (
|
|
53
|
+
get_investigation_readiness_prompt_compiler,
|
|
54
|
+
get_intent_classification_prompt_compiler,
|
|
55
|
+
get_result_classification_prompt_compiler,
|
|
56
|
+
get_result_interpretation_prompt_compiler,
|
|
57
|
+
get_sql_generation_prompt_compiler,
|
|
58
|
+
)
|
|
59
|
+
from gaard_api.admin.services import (
|
|
60
|
+
ACCESS_ERROR_INTENT_CLASSIFICATION,
|
|
61
|
+
ACCESS_ERROR_SQL_VALIDATION,
|
|
62
|
+
get_active_business_logic_prompt_safe,
|
|
63
|
+
get_datasource_schema_context_safe,
|
|
64
|
+
get_llm_runtime_config_safe,
|
|
65
|
+
get_query_runtime_config_safe,
|
|
66
|
+
learn_business_logic_from_sql_error,
|
|
67
|
+
LlmRuntimeConfig,
|
|
68
|
+
QueryRuntimeConfig,
|
|
69
|
+
record_data_query_access_error_audit,
|
|
70
|
+
record_data_query_audit,
|
|
71
|
+
record_data_query_pipeline_error_audit,
|
|
72
|
+
record_data_query_sql_error_audit,
|
|
73
|
+
upsert_investigation_analysis_business_logic_suggestion,
|
|
74
|
+
)
|
|
75
|
+
from gaard_api.api.v1.schema import get_schema_cache_key
|
|
76
|
+
from gaard_api.core.schema_cache import schema_context_cache
|
|
77
|
+
from gaard_api.core.settings import settings
|
|
78
|
+
|
|
79
|
+
router = APIRouter()
|
|
80
|
+
|
|
81
|
+
DatasourceContext = tuple[DatasourceConnector, DatasourceSchemaCache]
|
|
82
|
+
ProgressCallback = Callable[[dict[str, Any]], None]
|
|
83
|
+
|
|
84
|
+
READ_ONLY_REFUSAL_ANSWER = (
|
|
85
|
+
"Nie mogę tego zrobić. GAARD obsługuje tylko odczyt danych i nie wykonuje "
|
|
86
|
+
"operacji modyfikujących, usuwających ani tworzących dane."
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
READ_ONLY_SCOPE_REFUSAL_ANSWER = (
|
|
90
|
+
"Nie mogę obsłużyć tego zapytania. GAARD odpowiada tylko na pytania, "
|
|
91
|
+
"które można zrealizować jako bezpieczny odczyt danych SQL SELECT."
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
ALLOWLIST_REFUSAL_ANSWER = (
|
|
95
|
+
"Nie mogę wykonać tego zapytania, ponieważ wymagane tabele, kolumny, "
|
|
96
|
+
"relacje albo dowody są poza zakresem dozwolonym dla tego zadania."
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
SQL_SYNTAX_REFUSAL_ANSWER = (
|
|
100
|
+
"Nie mogę wykonać tego zapytania, ponieważ wygenerowany SQL nie przeszedł "
|
|
101
|
+
"walidacji składni lub zasad pojedynczego zapytania."
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
CLARIFICATION_REFUSAL_ANSWER = (
|
|
105
|
+
"Potrzebuję doprecyzowania, zanim bezpiecznie rozpocznę tę analizę."
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
ANALYSIS_MODE_PENDING_ANSWER = (
|
|
109
|
+
"Tryb Investigation wymaga dodatkowej analizy przed wygenerowaniem SQL. "
|
|
110
|
+
"Ścieżka Analysis nie jest jeszcze zaimplementowana."
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
VALIDATION_SQL_PREFIXES = (
|
|
114
|
+
"Only SELECT queries are allowed. ",
|
|
115
|
+
"DDL and DML statements are not allowed. ",
|
|
116
|
+
"Only single-statement SQL queries are allowed. SQL: ",
|
|
117
|
+
"Invalid SQL syntax. ",
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def create_llm_client(
|
|
122
|
+
llm_config: LlmRuntimeConfig | None = None,
|
|
123
|
+
) -> OpenAICompatibleClient:
|
|
124
|
+
llm_config = llm_config or get_llm_runtime_config_safe()
|
|
125
|
+
|
|
126
|
+
if llm_config.provider != "openai-compatible":
|
|
127
|
+
raise ConfigurationError(f"Unsupported GAARD_LLM_PROVIDER: {llm_config.provider}")
|
|
128
|
+
|
|
129
|
+
if llm_config.api_key == "change-me":
|
|
130
|
+
raise ConfigurationError("GAARD_LLM_API_KEY must be configured when using LLM mode.")
|
|
131
|
+
|
|
132
|
+
return OpenAICompatibleClient(
|
|
133
|
+
base_url=llm_config.base_url,
|
|
134
|
+
api_key=llm_config.api_key,
|
|
135
|
+
timeout_seconds=llm_config.timeout_seconds,
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def create_sql_generator(
|
|
140
|
+
datasource_context: DatasourceContext | None = None,
|
|
141
|
+
llm_config: LlmRuntimeConfig | None = None,
|
|
142
|
+
runtime_config: QueryRuntimeConfig | None = None,
|
|
143
|
+
) -> MockSqlGenerator | LlmSqlGenerator:
|
|
144
|
+
runtime_config = runtime_config or get_query_runtime_config_safe()
|
|
145
|
+
|
|
146
|
+
if runtime_config.sql_generation_mode == "mock":
|
|
147
|
+
return MockSqlGenerator()
|
|
148
|
+
|
|
149
|
+
if runtime_config.sql_generation_mode == "llm":
|
|
150
|
+
llm_config = llm_config or get_llm_runtime_config_safe()
|
|
151
|
+
|
|
152
|
+
if datasource_context is None:
|
|
153
|
+
datasource_context = get_datasource_schema_context_safe()
|
|
154
|
+
|
|
155
|
+
if datasource_context is not None:
|
|
156
|
+
connector, schema_cache = datasource_context
|
|
157
|
+
formatted_schema = append_business_logic_to_schema(
|
|
158
|
+
schema_cache.formatted_schema,
|
|
159
|
+
connector.id,
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
return LlmSqlGenerator(
|
|
163
|
+
client=create_llm_client(llm_config),
|
|
164
|
+
model=llm_config.model,
|
|
165
|
+
formatted_schema=formatted_schema,
|
|
166
|
+
dialect=connector.sql_dialect,
|
|
167
|
+
max_rows=runtime_config.query_max_rows,
|
|
168
|
+
extra_body=llm_config.extra_body,
|
|
169
|
+
prompt_compiler=get_sql_generation_prompt_compiler(),
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
introspector = SQLAlchemySchemaIntrospector(
|
|
173
|
+
database_url=settings.gaard_datasource_url,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
schema_context_service = SchemaContextService(
|
|
177
|
+
introspector=introspector,
|
|
178
|
+
cache=schema_context_cache,
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
schema_context = schema_context_service.get_schema_context(
|
|
182
|
+
get_schema_cache_key()
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
return LlmSqlGenerator(
|
|
186
|
+
client=create_llm_client(llm_config),
|
|
187
|
+
model=llm_config.model,
|
|
188
|
+
formatted_schema=schema_context.formatted_schema,
|
|
189
|
+
dialect=settings.gaard_sql_dialect,
|
|
190
|
+
max_rows=runtime_config.query_max_rows,
|
|
191
|
+
extra_body=llm_config.extra_body,
|
|
192
|
+
prompt_compiler=get_sql_generation_prompt_compiler(),
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
raise ConfigurationError(
|
|
196
|
+
f"Unsupported GAARD_SQL_GENERATION_MODE: {runtime_config.sql_generation_mode}"
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def resolve_intent_classification_mode() -> str:
|
|
201
|
+
runtime_config = get_query_runtime_config_safe()
|
|
202
|
+
|
|
203
|
+
if runtime_config.intent_classification_mode == "auto":
|
|
204
|
+
return "llm" if runtime_config.sql_generation_mode == "llm" else "mock"
|
|
205
|
+
|
|
206
|
+
return runtime_config.intent_classification_mode
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def create_intent_classifier(
|
|
210
|
+
llm_config: LlmRuntimeConfig | None = None,
|
|
211
|
+
) -> MockQueryIntentClassifier | LlmQueryIntentClassifier:
|
|
212
|
+
intent_classification_mode = resolve_intent_classification_mode()
|
|
213
|
+
|
|
214
|
+
if intent_classification_mode == "mock":
|
|
215
|
+
return MockQueryIntentClassifier()
|
|
216
|
+
|
|
217
|
+
if intent_classification_mode == "llm":
|
|
218
|
+
llm_config = llm_config or get_llm_runtime_config_safe()
|
|
219
|
+
|
|
220
|
+
return LlmQueryIntentClassifier(
|
|
221
|
+
client=create_llm_client(llm_config),
|
|
222
|
+
model=llm_config.model,
|
|
223
|
+
extra_body=llm_config.extra_body,
|
|
224
|
+
prompt_compiler=get_intent_classification_prompt_compiler(),
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
raise ConfigurationError(
|
|
228
|
+
"Unsupported GAARD_INTENT_CLASSIFICATION_MODE: "
|
|
229
|
+
f"{get_query_runtime_config_safe().intent_classification_mode}"
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def create_result_interpreter(
|
|
234
|
+
llm_config: LlmRuntimeConfig | None = None,
|
|
235
|
+
runtime_config: QueryRuntimeConfig | None = None,
|
|
236
|
+
) -> MockResultInterpreter | LlmResultInterpreter:
|
|
237
|
+
runtime_config = runtime_config or get_query_runtime_config_safe()
|
|
238
|
+
|
|
239
|
+
if runtime_config.result_interpretation_mode == "mock":
|
|
240
|
+
return MockResultInterpreter()
|
|
241
|
+
|
|
242
|
+
if runtime_config.result_interpretation_mode == "llm":
|
|
243
|
+
llm_config = llm_config or get_llm_runtime_config_safe()
|
|
244
|
+
|
|
245
|
+
return LlmResultInterpreter(
|
|
246
|
+
client=create_llm_client(llm_config),
|
|
247
|
+
model=llm_config.model,
|
|
248
|
+
extra_body=llm_config.extra_body,
|
|
249
|
+
prompt_compiler=get_result_interpretation_prompt_compiler(),
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
raise ConfigurationError(
|
|
253
|
+
"Unsupported GAARD_RESULT_INTERPRETATION_MODE: "
|
|
254
|
+
f"{runtime_config.result_interpretation_mode}"
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def resolve_output_classification_mode() -> str:
|
|
259
|
+
runtime_config = get_query_runtime_config_safe()
|
|
260
|
+
|
|
261
|
+
if runtime_config.output_classification_mode == "auto":
|
|
262
|
+
return "llm" if runtime_config.result_interpretation_mode == "llm" else "mock"
|
|
263
|
+
|
|
264
|
+
return runtime_config.output_classification_mode
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def create_result_classifier(
|
|
268
|
+
llm_config: LlmRuntimeConfig | None = None,
|
|
269
|
+
runtime_config: QueryRuntimeConfig | None = None,
|
|
270
|
+
) -> MockResultClassifier | LlmResultClassifier:
|
|
271
|
+
runtime_config = runtime_config or get_query_runtime_config_safe()
|
|
272
|
+
output_classification_mode = (
|
|
273
|
+
"llm"
|
|
274
|
+
if runtime_config.output_classification_mode == "auto"
|
|
275
|
+
and runtime_config.result_interpretation_mode == "llm"
|
|
276
|
+
else (
|
|
277
|
+
"mock"
|
|
278
|
+
if runtime_config.output_classification_mode == "auto"
|
|
279
|
+
else runtime_config.output_classification_mode
|
|
280
|
+
)
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
if output_classification_mode == "mock":
|
|
284
|
+
return MockResultClassifier()
|
|
285
|
+
|
|
286
|
+
if output_classification_mode == "llm":
|
|
287
|
+
llm_config = llm_config or get_llm_runtime_config_safe()
|
|
288
|
+
|
|
289
|
+
return LlmResultClassifier(
|
|
290
|
+
client=create_llm_client(llm_config),
|
|
291
|
+
model=llm_config.model,
|
|
292
|
+
extra_body=llm_config.extra_body,
|
|
293
|
+
prompt_compiler=get_result_classification_prompt_compiler(),
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
raise ConfigurationError(
|
|
297
|
+
"Unsupported GAARD_OUTPUT_CLASSIFICATION_MODE: "
|
|
298
|
+
f"{runtime_config.output_classification_mode}"
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def create_pipeline(datasource_context: DatasourceContext | None = None) -> QueryPipeline:
|
|
303
|
+
if datasource_context is None:
|
|
304
|
+
datasource_context = get_datasource_schema_context_safe()
|
|
305
|
+
|
|
306
|
+
runtime_config = get_query_runtime_config_safe()
|
|
307
|
+
database_url = (
|
|
308
|
+
datasource_context[0].database_url
|
|
309
|
+
if datasource_context is not None
|
|
310
|
+
else settings.gaard_datasource_url
|
|
311
|
+
)
|
|
312
|
+
sql_dialect = (
|
|
313
|
+
datasource_context[0].sql_dialect
|
|
314
|
+
if datasource_context is not None
|
|
315
|
+
else settings.gaard_sql_dialect
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
executor = SQLAlchemyQueryExecutor(
|
|
319
|
+
database_url=database_url,
|
|
320
|
+
max_rows=runtime_config.query_max_rows,
|
|
321
|
+
)
|
|
322
|
+
llm_config = (
|
|
323
|
+
get_llm_runtime_config_safe()
|
|
324
|
+
if "llm"
|
|
325
|
+
in {
|
|
326
|
+
runtime_config.sql_generation_mode,
|
|
327
|
+
runtime_config.result_interpretation_mode,
|
|
328
|
+
(
|
|
329
|
+
"llm"
|
|
330
|
+
if runtime_config.output_classification_mode == "auto"
|
|
331
|
+
and runtime_config.result_interpretation_mode == "llm"
|
|
332
|
+
else runtime_config.output_classification_mode
|
|
333
|
+
),
|
|
334
|
+
}
|
|
335
|
+
else None
|
|
336
|
+
)
|
|
337
|
+
output_classification_mode = (
|
|
338
|
+
"llm"
|
|
339
|
+
if runtime_config.output_classification_mode == "auto"
|
|
340
|
+
and runtime_config.result_interpretation_mode == "llm"
|
|
341
|
+
else (
|
|
342
|
+
"mock"
|
|
343
|
+
if runtime_config.output_classification_mode == "auto"
|
|
344
|
+
else runtime_config.output_classification_mode
|
|
345
|
+
)
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
return QueryPipeline(
|
|
349
|
+
sql_generator=create_sql_generator(datasource_context, llm_config, runtime_config),
|
|
350
|
+
sql_validator=SelectOnlySqlValidator(dialect=sql_dialect),
|
|
351
|
+
executor=executor,
|
|
352
|
+
interpreter=create_result_interpreter(llm_config, runtime_config),
|
|
353
|
+
classifier=create_result_classifier(llm_config, runtime_config),
|
|
354
|
+
sql_generation_mode=runtime_config.sql_generation_mode,
|
|
355
|
+
result_interpretation_mode=runtime_config.result_interpretation_mode,
|
|
356
|
+
output_classification_mode=output_classification_mode,
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def append_business_logic_to_schema(formatted_schema: str, connector_id: int) -> str:
|
|
361
|
+
business_logic = get_active_business_logic_prompt_safe(connector_id)
|
|
362
|
+
|
|
363
|
+
if not business_logic:
|
|
364
|
+
return formatted_schema
|
|
365
|
+
|
|
366
|
+
return f"{formatted_schema}\n\n{business_logic}"
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def schema_and_business_logic_for_investigation(
|
|
370
|
+
datasource_context: DatasourceContext | None,
|
|
371
|
+
) -> tuple[str, str]:
|
|
372
|
+
if datasource_context is not None:
|
|
373
|
+
connector, schema_cache = datasource_context
|
|
374
|
+
return (
|
|
375
|
+
schema_cache.formatted_schema,
|
|
376
|
+
get_active_business_logic_prompt_safe(connector.id),
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
introspector = SQLAlchemySchemaIntrospector(
|
|
380
|
+
database_url=settings.gaard_datasource_url,
|
|
381
|
+
)
|
|
382
|
+
schema_context_service = SchemaContextService(
|
|
383
|
+
introspector=introspector,
|
|
384
|
+
cache=schema_context_cache,
|
|
385
|
+
)
|
|
386
|
+
schema_context = schema_context_service.get_schema_context(get_schema_cache_key())
|
|
387
|
+
|
|
388
|
+
return schema_context.formatted_schema, ""
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def create_investigation_context(
|
|
392
|
+
request: QueryRequest,
|
|
393
|
+
datasource_context: DatasourceContext | None,
|
|
394
|
+
) -> InvestigationContext:
|
|
395
|
+
formatted_schema, business_logic = schema_and_business_logic_for_investigation(
|
|
396
|
+
datasource_context
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
return InvestigationContext(
|
|
400
|
+
question=request.question,
|
|
401
|
+
datasource_id=request.datasource_id,
|
|
402
|
+
user_id=request.user_id,
|
|
403
|
+
formatted_schema=formatted_schema,
|
|
404
|
+
business_logic=business_logic,
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def resolve_investigation_mode(runtime_config: QueryRuntimeConfig) -> str:
|
|
409
|
+
if runtime_config.investigation_mode == "auto":
|
|
410
|
+
return "llm" if runtime_config.sql_generation_mode == "llm" else "mock"
|
|
411
|
+
|
|
412
|
+
return runtime_config.investigation_mode
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def create_investigation_readiness_agent(
|
|
416
|
+
runtime_config: QueryRuntimeConfig | None = None,
|
|
417
|
+
llm_config: LlmRuntimeConfig | None = None,
|
|
418
|
+
) -> MockInvestigationReadinessAgent | LlmInvestigationReadinessAgent:
|
|
419
|
+
runtime_config = runtime_config or get_query_runtime_config_safe()
|
|
420
|
+
investigation_mode = resolve_investigation_mode(runtime_config)
|
|
421
|
+
|
|
422
|
+
if investigation_mode == "mock":
|
|
423
|
+
return MockInvestigationReadinessAgent()
|
|
424
|
+
|
|
425
|
+
if investigation_mode == "llm":
|
|
426
|
+
llm_config = llm_config or get_llm_runtime_config_safe()
|
|
427
|
+
return LlmInvestigationReadinessAgent(
|
|
428
|
+
client=create_llm_client(llm_config),
|
|
429
|
+
model=llm_config.model,
|
|
430
|
+
extra_body=llm_config.extra_body,
|
|
431
|
+
prompt_compiler=get_investigation_readiness_prompt_compiler(),
|
|
432
|
+
)
|
|
433
|
+
|
|
434
|
+
raise ConfigurationError(
|
|
435
|
+
f"Unsupported GAARD_INVESTIGATION_MODE: {runtime_config.investigation_mode}"
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def investigation_iteration_metadata(result: InvestigationLoopResult) -> list[dict[str, Any]]:
|
|
440
|
+
return [
|
|
441
|
+
{
|
|
442
|
+
"iteration": item.iteration,
|
|
443
|
+
"agent": item.agent,
|
|
444
|
+
**item.decision.model_dump(mode="json"),
|
|
445
|
+
}
|
|
446
|
+
for item in result.iterations
|
|
447
|
+
]
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def investigation_metadata(
|
|
451
|
+
result: InvestigationLoopResult,
|
|
452
|
+
investigation_mode: str,
|
|
453
|
+
) -> dict[str, Any]:
|
|
454
|
+
steps = investigation_iteration_metadata(result)
|
|
455
|
+
metadata = {
|
|
456
|
+
"query_mode": QueryMode.INVESTIGATION.value,
|
|
457
|
+
"investigation_backend_status": "readiness_gate_active",
|
|
458
|
+
"investigation_mode": investigation_mode,
|
|
459
|
+
"investigation_route": result.route.value,
|
|
460
|
+
"investigation_loop": {
|
|
461
|
+
"max_iterations": result.max_iterations,
|
|
462
|
+
"iterations_run": len(result.iterations),
|
|
463
|
+
"confidence_threshold": result.confidence_threshold,
|
|
464
|
+
},
|
|
465
|
+
"investigation_steps": steps,
|
|
466
|
+
"investigation_audit_trail": steps,
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
if result.route == InvestigationRoute.ANALYSIS:
|
|
470
|
+
metadata["analysis_mode_status"] = "not_implemented"
|
|
471
|
+
|
|
472
|
+
return metadata
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def required_analysis_tasks_from_result(
|
|
476
|
+
result: InvestigationLoopResult,
|
|
477
|
+
) -> list[RequiredAnalysisTask]:
|
|
478
|
+
decision = result.final_decision
|
|
479
|
+
if decision is None:
|
|
480
|
+
return []
|
|
481
|
+
|
|
482
|
+
if decision.required_analysis_tasks:
|
|
483
|
+
return [
|
|
484
|
+
task
|
|
485
|
+
for task in decision.required_analysis_tasks
|
|
486
|
+
if task.required_analysis.strip()
|
|
487
|
+
]
|
|
488
|
+
|
|
489
|
+
tasks: list[RequiredAnalysisTask] = []
|
|
490
|
+
for index, required_analysis in enumerate(decision.required_analysis):
|
|
491
|
+
tasks.append(
|
|
492
|
+
RequiredAnalysisTask(
|
|
493
|
+
missing_information=decision.missing_information[index]
|
|
494
|
+
if index < len(decision.missing_information)
|
|
495
|
+
else "",
|
|
496
|
+
required_analysis=required_analysis,
|
|
497
|
+
)
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
return tasks
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def emit_investigation_progress(
|
|
504
|
+
progress_callback: ProgressCallback | None,
|
|
505
|
+
payload: dict[str, Any],
|
|
506
|
+
) -> None:
|
|
507
|
+
if progress_callback is not None:
|
|
508
|
+
progress_callback(payload)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def record_investigation_readiness_audit(
|
|
512
|
+
effective_request: QueryRequest,
|
|
513
|
+
result: InvestigationLoopResult,
|
|
514
|
+
metadata: dict[str, Any],
|
|
515
|
+
) -> int | None:
|
|
516
|
+
decision = result.final_decision
|
|
517
|
+
response = QueryResponse(
|
|
518
|
+
question=effective_request.question,
|
|
519
|
+
answer=decision.reason if decision is not None else "No readiness decision.",
|
|
520
|
+
sql="",
|
|
521
|
+
rows=[],
|
|
522
|
+
metadata={
|
|
523
|
+
"duration_ms": 0,
|
|
524
|
+
"datasource_id": effective_request.datasource_id,
|
|
525
|
+
"user_id": effective_request.user_id,
|
|
526
|
+
"output_classification": OutputClassification.UNKNOWN.value,
|
|
527
|
+
**metadata,
|
|
528
|
+
"investigation_step": "readiness",
|
|
529
|
+
},
|
|
530
|
+
)
|
|
531
|
+
audit_log = record_data_query_audit(effective_request, response)
|
|
532
|
+
return audit_log.id if audit_log is not None else None
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def datasource_connector_id(datasource_context: DatasourceContext | None) -> int | None:
|
|
536
|
+
return datasource_context[0].id if datasource_context is not None else None
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def run_investigation_analysis_tasks(
|
|
540
|
+
effective_request: QueryRequest,
|
|
541
|
+
datasource_context: DatasourceContext | None,
|
|
542
|
+
result: InvestigationLoopResult,
|
|
543
|
+
progress_callback: ProgressCallback | None = None,
|
|
544
|
+
) -> list[dict[str, Any]]:
|
|
545
|
+
analysis_results: list[dict[str, Any]] = []
|
|
546
|
+
tasks = required_analysis_tasks_from_result(result)
|
|
547
|
+
|
|
548
|
+
for index, task in enumerate(tasks):
|
|
549
|
+
emit_investigation_progress(
|
|
550
|
+
progress_callback,
|
|
551
|
+
{
|
|
552
|
+
"step": "analysis_sql",
|
|
553
|
+
"analysis_task_index": index,
|
|
554
|
+
"data_question": task.required_analysis,
|
|
555
|
+
"decisions": [
|
|
556
|
+
f"Running analysis SQL task {index + 1} of {len(tasks)}."
|
|
557
|
+
],
|
|
558
|
+
},
|
|
559
|
+
)
|
|
560
|
+
analysis_request = effective_request.model_copy(
|
|
561
|
+
update={
|
|
562
|
+
"question": task.required_analysis,
|
|
563
|
+
"mode": QueryMode.SQL,
|
|
564
|
+
}
|
|
565
|
+
)
|
|
566
|
+
analysis_metadata = {
|
|
567
|
+
"query_mode": QueryMode.INVESTIGATION.value,
|
|
568
|
+
"investigation_backend_status": "readiness_gate_active",
|
|
569
|
+
"investigation_route": InvestigationRoute.ANALYSIS.value,
|
|
570
|
+
"investigation_step": "analysis_sql",
|
|
571
|
+
"analysis_task_index": index,
|
|
572
|
+
"analysis_missing_information": task.missing_information,
|
|
573
|
+
"analysis_required_analysis": task.required_analysis,
|
|
574
|
+
"analysis_category": task.category,
|
|
575
|
+
"analysis_expected_output": task.expected_output,
|
|
576
|
+
"original_question": effective_request.question,
|
|
577
|
+
}
|
|
578
|
+
analysis_response = run_sql_request(
|
|
579
|
+
analysis_request,
|
|
580
|
+
datasource_context,
|
|
581
|
+
analysis_metadata,
|
|
582
|
+
)
|
|
583
|
+
learning_result = record_analysis_business_logic_if_possible(
|
|
584
|
+
effective_request=effective_request,
|
|
585
|
+
datasource_context=datasource_context,
|
|
586
|
+
task=task,
|
|
587
|
+
task_index=index,
|
|
588
|
+
analysis_response=analysis_response,
|
|
589
|
+
)
|
|
590
|
+
task_result = {
|
|
591
|
+
"analysis_task_index": index,
|
|
592
|
+
"missing_information": task.missing_information,
|
|
593
|
+
"required_analysis": task.required_analysis,
|
|
594
|
+
"category": task.category,
|
|
595
|
+
"expected_output": task.expected_output,
|
|
596
|
+
"sql": analysis_response.sql,
|
|
597
|
+
"rows": analysis_response.rows,
|
|
598
|
+
"answer": analysis_response.answer,
|
|
599
|
+
"audit_log_id": analysis_response.metadata.get("data_query_audit_id"),
|
|
600
|
+
"business_logic_learning": learning_result,
|
|
601
|
+
}
|
|
602
|
+
analysis_results.append(task_result)
|
|
603
|
+
emit_investigation_progress(
|
|
604
|
+
progress_callback,
|
|
605
|
+
{
|
|
606
|
+
"step": "analysis_sql_complete",
|
|
607
|
+
"analysis_task_index": index,
|
|
608
|
+
"data_question": task.required_analysis,
|
|
609
|
+
"decisions": [
|
|
610
|
+
f"Analysis SQL task {index + 1} completed.",
|
|
611
|
+
business_logic_progress_message(learning_result),
|
|
612
|
+
],
|
|
613
|
+
},
|
|
614
|
+
)
|
|
615
|
+
|
|
616
|
+
return analysis_results
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def record_analysis_business_logic_if_possible(
|
|
620
|
+
effective_request: QueryRequest,
|
|
621
|
+
datasource_context: DatasourceContext | None,
|
|
622
|
+
task: RequiredAnalysisTask,
|
|
623
|
+
task_index: int,
|
|
624
|
+
analysis_response: QueryResponse,
|
|
625
|
+
) -> dict[str, Any]:
|
|
626
|
+
if analysis_response.metadata.get("blocked"):
|
|
627
|
+
learning_result = {
|
|
628
|
+
"status": "skipped",
|
|
629
|
+
"reason": "Analysis SQL task was blocked and did not produce evidence.",
|
|
630
|
+
}
|
|
631
|
+
else:
|
|
632
|
+
source_audit_id = analysis_response.metadata.get("data_query_audit_id")
|
|
633
|
+
learning_result = upsert_investigation_analysis_business_logic_suggestion(
|
|
634
|
+
connector_id=datasource_connector_id(datasource_context),
|
|
635
|
+
source_audit_id=source_audit_id if isinstance(source_audit_id, int) else None,
|
|
636
|
+
missing_information=task.missing_information,
|
|
637
|
+
required_analysis=task.required_analysis,
|
|
638
|
+
category=task.category,
|
|
639
|
+
analysis_response=analysis_response,
|
|
640
|
+
)
|
|
641
|
+
|
|
642
|
+
record_investigation_business_logic_audit(
|
|
643
|
+
effective_request=effective_request,
|
|
644
|
+
task=task,
|
|
645
|
+
task_index=task_index,
|
|
646
|
+
analysis_response=analysis_response,
|
|
647
|
+
learning_result=learning_result,
|
|
648
|
+
)
|
|
649
|
+
return learning_result
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def record_investigation_business_logic_audit(
|
|
653
|
+
effective_request: QueryRequest,
|
|
654
|
+
task: RequiredAnalysisTask,
|
|
655
|
+
task_index: int,
|
|
656
|
+
analysis_response: QueryResponse,
|
|
657
|
+
learning_result: dict[str, Any],
|
|
658
|
+
) -> None:
|
|
659
|
+
response = QueryResponse(
|
|
660
|
+
question=effective_request.question,
|
|
661
|
+
answer=business_logic_progress_message(learning_result),
|
|
662
|
+
sql=analysis_response.sql,
|
|
663
|
+
rows=[],
|
|
664
|
+
metadata={
|
|
665
|
+
"duration_ms": 0,
|
|
666
|
+
"datasource_id": effective_request.datasource_id,
|
|
667
|
+
"user_id": effective_request.user_id,
|
|
668
|
+
"output_classification": OutputClassification.UNKNOWN.value,
|
|
669
|
+
"query_mode": QueryMode.INVESTIGATION.value,
|
|
670
|
+
"investigation_backend_status": "readiness_gate_active",
|
|
671
|
+
"investigation_route": InvestigationRoute.ANALYSIS.value,
|
|
672
|
+
"investigation_step": "analysis_business_logic",
|
|
673
|
+
"analysis_task_index": task_index,
|
|
674
|
+
"analysis_missing_information": task.missing_information,
|
|
675
|
+
"analysis_required_analysis": task.required_analysis,
|
|
676
|
+
"analysis_category": task.category,
|
|
677
|
+
"analysis_source_audit_log_id": analysis_response.metadata.get(
|
|
678
|
+
"data_query_audit_id"
|
|
679
|
+
),
|
|
680
|
+
"business_logic_learning": learning_result,
|
|
681
|
+
},
|
|
682
|
+
)
|
|
683
|
+
record_data_query_audit(effective_request, response)
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
def business_logic_progress_message(learning_result: dict[str, Any]) -> str:
|
|
687
|
+
status = str(learning_result.get("status") or "")
|
|
688
|
+
if status == "created":
|
|
689
|
+
return "Business logic suggestion was created and is pending approval."
|
|
690
|
+
if status == "existing":
|
|
691
|
+
return "Business logic suggestion already exists; no duplicate was created."
|
|
692
|
+
if status == "skipped":
|
|
693
|
+
return f"Business logic suggestion was skipped: {learning_result.get('reason')}"
|
|
694
|
+
return "Business logic suggestion step completed."
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def extract_sql_from_validation_error(error_message: str) -> str:
|
|
698
|
+
for prefix in VALIDATION_SQL_PREFIXES:
|
|
699
|
+
if error_message.startswith(prefix):
|
|
700
|
+
return error_message.removeprefix(prefix).strip()
|
|
701
|
+
|
|
702
|
+
return ""
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def validation_error_metadata(exc: SqlValidationError) -> dict[str, Any]:
|
|
706
|
+
if exc.metadata.get("primary_error_category"):
|
|
707
|
+
return exc.metadata
|
|
708
|
+
|
|
709
|
+
category = "sql.validation.write_operation" if any(
|
|
710
|
+
text in exc.message
|
|
711
|
+
for text in (
|
|
712
|
+
"Only SELECT queries are allowed",
|
|
713
|
+
"DDL and DML statements are not allowed",
|
|
714
|
+
)
|
|
715
|
+
) else (
|
|
716
|
+
"sql.validation.syntax"
|
|
717
|
+
if any(
|
|
718
|
+
text in exc.message
|
|
719
|
+
for text in (
|
|
720
|
+
"Only single-statement SQL queries are allowed",
|
|
721
|
+
"Invalid SQL syntax",
|
|
722
|
+
)
|
|
723
|
+
)
|
|
724
|
+
else "sql.validation.disallowed_column"
|
|
725
|
+
)
|
|
726
|
+
return {
|
|
727
|
+
**exc.metadata,
|
|
728
|
+
"primary_error_category": category,
|
|
729
|
+
"error_categories": [category],
|
|
730
|
+
}
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
def is_read_only_intent(intent: QueryIntentClassification) -> bool:
    """Return True when *intent* was classified as a read-only data question."""
    read_only_decision = QueryIntentDecision.READ_ONLY_DATA_QUESTION
    return intent.decision == read_only_decision
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
def intent_metadata(
    intent: QueryIntentClassification,
    intent_classification_mode: str,
) -> dict[str, Any]:
    """Flatten an intent classification into response/audit metadata keys.

    ``intent_model_response`` is the classifier's raw model response when that
    is truthy. Otherwise it is a minimal dict rebuilt from the decision,
    confidence and reason, so the key is always populated.
    """
    decision_value = intent.decision.value
    raw_response = intent.model_response
    if not raw_response:
        raw_response = {
            "decision": decision_value,
            "confidence": intent.confidence,
            "reason": intent.reason,
        }

    return {
        "intent_classification_mode": intent_classification_mode,
        "intent_decision": decision_value,
        "intent_confidence": intent.confidence,
        "intent_reason": intent.reason,
        "intent_model_response": raw_response,
    }
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
def build_access_refusal_response(
    request: QueryRequest,
    reason: str,
    sql: str = "",
    metadata: dict[str, Any] | None = None,
) -> QueryResponse:
    """Build the empty, blocked ``QueryResponse`` returned when a query is refused.

    Args:
        request: The (effective) query request being refused.
        reason: Machine-readable refusal reason. It is stored as
            ``blocked_reason`` and, together with *metadata*, selects the
            answer text via ``access_refusal_answer``.
        sql: SQL to echo back in the response, if any.
        metadata: Extra metadata merged in last, so its keys override the
            defaults built here (including ``blocked`` and ``blocked_reason``).

    Returns:
        A response with no rows, ``duration_ms`` of 0 and ``blocked`` set to True.
    """
    extra = metadata if metadata else {}
    base_metadata: dict[str, Any] = {
        "duration_ms": 0,
        "datasource_id": request.datasource_id,
        "user_id": request.user_id,
        "output_classification": OutputClassification.UNKNOWN.value,
        "blocked": True,
        "blocked_reason": reason,
    }
    return QueryResponse(
        question=request.question,
        answer=access_refusal_answer(reason, extra),
        sql=sql,
        rows=[],
        metadata={**base_metadata, **extra},
    )
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def access_refusal_answer(reason: str, metadata: dict[str, Any]) -> str:
    """Choose the user-facing refusal text for a blocked query.

    Any refusal whose *reason* is not ``ACCESS_ERROR_SQL_VALIDATION`` gets
    ``READ_ONLY_SCOPE_REFUSAL_ANSWER``. For SQL-validation refusals, the
    categories from ``metadata["error_categories"]`` plus
    ``metadata["primary_error_category"]`` are checked in priority order:

    1. write operation
    2. ambiguous intent
    3. syntax
    4. any allowlist-related category

    The first match wins; with no match the scope answer is used.
    """
    # NOTE(review): because of this early return, the
    # "intent.ambiguous_requires_clarification" rule below can only fire for
    # SQL-validation refusals, never for intent-classification refusals.
    # Confirm this is intended.
    if reason != ACCESS_ERROR_SQL_VALIDATION:
        return READ_ONLY_SCOPE_REFUSAL_ANSWER

    categories = set(metadata.get("error_categories") or [])
    primary = metadata.get("primary_error_category")
    if primary:
        categories.add(str(primary))

    allowlist_categories = {
        "sql.validation.disallowed_column",
        "sql.validation.disallowed_table",
        "sql.validation.disallowed_relationship",
        "sql.validation.select_star",
        "task.inconsistent_allowlist",
        "task.insufficient_evidence",
    }
    # Ordered (trigger categories, answer) pairs.
    ordered_rules = (
        ({"sql.validation.write_operation"}, READ_ONLY_REFUSAL_ANSWER),
        ({"intent.ambiguous_requires_clarification"}, CLARIFICATION_REFUSAL_ANSWER),
        ({"sql.validation.syntax"}, SQL_SYNTAX_REFUSAL_ANSWER),
        (allowlist_categories, ALLOWLIST_REFUSAL_ANSWER),
    )
    for triggers, answer in ordered_rules:
        if categories & triggers:
            return answer
    return READ_ONLY_SCOPE_REFUSAL_ANSWER
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def _learn_from_failed_query(
    datasource_context: DatasourceContext | None,
    audit_log: Any,
) -> None:
    """Hand a failed query's connector and audit record to business-logic learning."""
    learn_business_logic_from_sql_error(
        connector_id=datasource_context[0].id if datasource_context is not None else None,
        audit_id=audit_log.id if audit_log is not None else None,
    )


def run_sql_request(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
    extra_metadata: dict[str, Any] | None = None,
) -> QueryResponse:
    """Classify a request's intent, run the SQL pipeline and audit the outcome.

    Flow:

    1. Classify intent; the LLM runtime config is loaded only when the
       resolved mode is ``"llm"``. An ``LlmProviderError`` is audited as an
       ``intent_classification`` pipeline error, handed to business-logic
       learning, and re-raised.
    2. Intents that are not read-only data questions get an audited
       access-refusal response.
    3. ``QueryExecutionError`` and ``QueryPipelineStepError`` from the
       pipeline are audited, handed to learning, and re-raised.
       ``SqlValidationError`` is instead turned into an audited refusal
       response.
    4. A successful response gets the intent and extra metadata and is audited.

    Whenever an audit record is created for a returned response, its id is
    stored in ``response.metadata["data_query_audit_id"]``.

    Args:
        effective_request: Request with the active datasource already bound.
        datasource_context: Active datasource context, or ``None``.
        extra_metadata: Caller metadata (e.g. Investigation details) added to
            audits and to the response; it overrides intent keys on conflict.

    Raises:
        LlmProviderError: Intent classification failed.
        QueryExecutionError: The generated SQL failed to execute.
        QueryPipelineStepError: A pipeline step failed.
    """
    extra_metadata = extra_metadata or {}
    intent_mode = resolve_intent_classification_mode()
    intent_llm_config = (
        get_llm_runtime_config_safe() if intent_mode == "llm" else None
    )
    try:
        intent = create_intent_classifier(intent_llm_config).classify(effective_request)
    except LlmProviderError as exc:
        audit_log = record_data_query_pipeline_error_audit(
            request=effective_request,
            sql="",
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.message,
            pipeline_phase="intent_classification",
            metadata={**extra_metadata, "intent_classification_mode": intent_mode},
        )
        _learn_from_failed_query(datasource_context, audit_log)
        raise

    current_intent_metadata = intent_metadata(intent, intent_mode)
    # Extra metadata wins over intent keys, matching the update order applied
    # to successful responses at the end of this function.
    audit_metadata = {**current_intent_metadata, **extra_metadata}

    if not is_read_only_intent(intent):
        response = build_access_refusal_response(
            effective_request,
            ACCESS_ERROR_INTENT_CLASSIFICATION,
            metadata=audit_metadata,
        )
        audit_log = record_data_query_access_error_audit(
            request=effective_request,
            answer=response.answer,
            reason=ACCESS_ERROR_INTENT_CLASSIFICATION,
            metadata=audit_metadata,
        )
        if audit_log is not None:
            response.metadata["data_query_audit_id"] = audit_log.id
        return response

    pipeline = create_pipeline(datasource_context)
    try:
        response = pipeline.handle(effective_request)
    except QueryExecutionError as exc:
        audit_log = record_data_query_sql_error_audit(
            request=effective_request,
            sql=exc.sql,
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.error_detail,
            # Previously only extra_metadata was recorded here, so execution-error
            # audits lacked the intent classification that every other audit has.
            metadata=audit_metadata,
        )
        _learn_from_failed_query(datasource_context, audit_log)
        raise
    except QueryPipelineStepError as exc:
        audit_log = record_data_query_pipeline_error_audit(
            request=effective_request,
            sql=exc.sql,
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.error_detail,
            pipeline_phase=exc.phase,
            metadata=audit_metadata,
        )
        _learn_from_failed_query(datasource_context, audit_log)
        raise
    except SqlValidationError as exc:
        # Validation failures are expected outcomes: answer with a refusal
        # instead of propagating the error.
        validation_metadata = validation_error_metadata(exc)
        response = build_access_refusal_response(
            effective_request,
            ACCESS_ERROR_SQL_VALIDATION,
            sql=extract_sql_from_validation_error(exc.message),
            metadata={**audit_metadata, **validation_metadata},
        )
        audit_log = record_data_query_access_error_audit(
            request=effective_request,
            answer=response.answer,
            reason=ACCESS_ERROR_SQL_VALIDATION,
            sql=response.sql,
            error_code=exc.code,
            error_detail=exc.message,
            metadata={**audit_metadata, **validation_metadata},
        )
        if audit_log is not None:
            response.metadata["data_query_audit_id"] = audit_log.id
        _learn_from_failed_query(datasource_context, audit_log)
        return response

    response.metadata.update(current_intent_metadata)
    response.metadata.update(extra_metadata)
    audit_log = record_data_query_audit(effective_request, response)
    if audit_log is not None:
        response.metadata["data_query_audit_id"] = audit_log.id

    return response
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def run_investigation_request(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
    progress_callback: ProgressCallback | None = None,
) -> QueryResponse:
    """Run an Investigation-mode query: a readiness gate, then SQL or analysis.

    A single-iteration ``InvestigationLoop`` decides the route.

    * If the route is ``InvestigationRoute.SQL``, the request is delegated to
      ``run_sql_request`` with the investigation metadata attached.
    * Otherwise the analysis tasks are run, and a placeholder response
      (``ANALYSIS_MODE_PENDING_ANSWER``, no SQL, no rows) carrying their
      results is returned and audited.

    Progress events are emitted through *progress_callback* at each step.

    Args:
        effective_request: Request with the active datasource already bound.
        datasource_context: Active datasource context, or ``None``.
        progress_callback: Optional sink for progress events; the streaming
            endpoint uses it to forward them as NDJSON lines.

    Raises:
        LlmProviderError: The readiness check's LLM call failed. The failure
            is audited as an ``investigation_readiness`` pipeline error before
            being re-raised.
    """
    started_at = time.perf_counter()
    runtime_config = get_query_runtime_config_safe()
    investigation_mode = resolve_investigation_mode(runtime_config)
    context = create_investigation_context(effective_request, datasource_context)
    loop_config = InvestigationLoopConfig(max_iterations=1)

    emit_investigation_progress(
        progress_callback,
        {
            "step": "readiness",
            "data_question": effective_request.question,
            "decisions": ["Running Investigation readiness check."],
        },
    )
    try:
        result = InvestigationLoop(
            readiness_agent=create_investigation_readiness_agent(runtime_config),
            config=loop_config,
        ).run(context)
    except LlmProviderError as exc:
        record_data_query_pipeline_error_audit(
            request=effective_request,
            sql="",
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.message,
            pipeline_phase="investigation_readiness",
            metadata={
                "query_mode": QueryMode.INVESTIGATION.value,
                "investigation_mode": investigation_mode,
                "investigation_backend_status": "readiness_gate_active",
            },
        )
        raise

    metadata = investigation_metadata(result, investigation_mode)
    metadata["investigation_readiness_duration_ms"] = round(
        (time.perf_counter() - started_at) * 1000,
        2,
    )
    readiness_audit_log_id = record_investigation_readiness_audit(
        effective_request,
        result,
        metadata,
    )
    if readiness_audit_log_id is not None:
        metadata["readiness_audit_log_id"] = readiness_audit_log_id
    emit_investigation_progress(
        progress_callback,
        {
            "step": "readiness_complete",
            "data_question": effective_request.question,
            "decisions": [
                f"Investigation readiness route: {result.route.value}."
            ],
        },
    )

    if result.route == InvestigationRoute.SQL:
        emit_investigation_progress(
            progress_callback,
            {
                "step": "sql",
                "data_question": effective_request.question,
                "decisions": ["Readiness passed; running the normal SQL pipeline."],
            },
        )
        return run_sql_request(
            effective_request,
            datasource_context,
            {**metadata, "investigation_step": "sql"},
        )

    analysis_results = run_investigation_analysis_tasks(
        effective_request=effective_request,
        datasource_context=datasource_context,
        result=result,
        progress_callback=progress_callback,
    )
    metadata["analysis_results"] = analysis_results
    metadata["analysis_tasks_count"] = len(analysis_results)
    metadata["investigation_step"] = "final"

    response = QueryResponse(
        question=effective_request.question,
        answer=ANALYSIS_MODE_PENDING_ANSWER,
        sql="",
        rows=[],
        metadata={
            "duration_ms": round((time.perf_counter() - started_at) * 1000, 2),
            "datasource_id": effective_request.datasource_id,
            "user_id": effective_request.user_id,
            "output_classification": OutputClassification.UNKNOWN.value,
            **metadata,
        },
    )
    audit_log = record_data_query_audit(effective_request, response)
    if audit_log is not None:
        # Expose the audit record id, as run_sql_request does on every path;
        # previously the id was discarded for analysis-mode responses.
        response.metadata["data_query_audit_id"] = audit_log.id
    return response
|
|
1023
|
+
|
|
1024
|
+
|
|
1025
|
+
def effective_query_request(request: QueryRequest) -> tuple[QueryRequest, DatasourceContext | None]:
    """Resolve the active datasource and bind its key onto the request.

    Returns a pair of (request, datasource context):

    * When a datasource context is available, the request is a copy with
      ``datasource_id`` replaced by the context connector's ``connector_key``.
    * Otherwise the original request object is returned with ``None``.
    """
    datasource_context = get_datasource_schema_context_safe()
    if datasource_context is None:
        return request, None

    connector = datasource_context[0]
    bound_request = request.model_copy(
        update={"datasource_id": connector.connector_key}
    )
    return bound_request, datasource_context
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def ndjson_line(payload: dict[str, Any]) -> str:
    """Serialise *payload* as one newline-terminated NDJSON record.

    Non-ASCII characters are written as-is rather than escaped.
    """
    encoded = json.dumps(payload, ensure_ascii=False)
    return encoded + "\n"
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
def stream_investigation_response(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
) -> Iterator[str]:
    """Stream an Investigation run as NDJSON lines.

    The investigation runs on a daemon worker thread. The worker enqueues:

    * each progress payload, as ``{"progress": ...}``;
    * then either ``{"final": ...}`` (the JSON-dumped response) or
      ``{"error": {"message": ..., "type": ...}}``;
    * and finally, in all cases, a ``None`` sentinel.

    This generator yields each queued event in order and stops at the sentinel.
    """
    events: Queue[dict[str, Any] | None] = Queue()

    def on_progress(payload: dict[str, Any]) -> None:
        events.put({"progress": payload})

    def run_in_background() -> None:
        try:
            final_response = run_investigation_request(
                effective_request,
                datasource_context,
                progress_callback=on_progress,
            )
            events.put({"final": final_response.model_dump(mode="json")})
        except Exception as exc:
            events.put(
                {"error": {"message": str(exc), "type": exc.__class__.__name__}}
            )
        finally:
            # Always terminate the stream, even after a failure.
            events.put(None)

    worker = Thread(target=run_in_background, daemon=True)
    worker.start()

    # iter(callable, sentinel) keeps calling events.get() until it returns None.
    for event in iter(events.get, None):
        yield ndjson_line(event)

    worker.join()
|
|
1081
|
+
|
|
1082
|
+
|
|
1083
|
+
@router.post("/query", response_model=QueryResponse)
def query(request: QueryRequest) -> QueryResponse:
    # Non-streaming entry point: bind the active datasource, then dispatch on
    # the requested mode (Investigation vs. the plain SQL pipeline).
    # Documented with comments rather than a docstring because FastAPI would
    # publish a docstring as this endpoint's OpenAPI description.
    effective_request, datasource_context = effective_query_request(request)
    handler = (
        run_investigation_request
        if effective_request.mode == QueryMode.INVESTIGATION
        else run_sql_request
    )
    return handler(effective_request, datasource_context)
|
|
1091
|
+
|
|
1092
|
+
|
|
1093
|
+
@router.post("/query/stream")
def query_stream(request: QueryRequest) -> StreamingResponse:
    # Streaming entry point producing NDJSON (application/x-ndjson). Documented
    # with comments rather than a docstring because FastAPI would publish a
    # docstring as this endpoint's OpenAPI description.
    #
    # Modes:
    # - Investigation streams progress events from a worker thread.
    # - Every other mode runs the normal SQL request and emits one "final" line.
    #
    # In both modes a failure is reported as an {"error": {...}} line. Once the
    # stream has started, the HTTP status and headers are already sent, so an
    # exception escaping the generator would only cut the connection.
    effective_request, datasource_context = effective_query_request(request)

    if effective_request.mode != QueryMode.INVESTIGATION:
        def single_response() -> Iterator[str]:
            try:
                # Reuse the already-resolved datasource context instead of
                # calling query(), which would resolve it a second time.
                response = run_sql_request(effective_request, datasource_context)
            except Exception as exc:
                import logging

                logging.getLogger(__name__).exception("Streaming SQL query failed")
                yield ndjson_line(
                    {"error": {"message": str(exc), "type": exc.__class__.__name__}}
                )
                return
            yield ndjson_line({"final": response.model_dump(mode="json")})

        return StreamingResponse(
            single_response(),
            media_type="application/x-ndjson",
        )

    return StreamingResponse(
        stream_investigation_response(effective_request, datasource_context),
        media_type="application/x-ndjson",
    )
|