gaard-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1109 @@
1
+ import json
2
+ import time
3
+ from collections.abc import Iterator
4
+ from queue import Queue
5
+ from threading import Thread
6
+ from typing import Any, Callable
7
+
8
+ from fastapi import APIRouter
9
+ from fastapi.responses import StreamingResponse
10
+
11
+ from gaard_connectors.sqlalchemy.executor import SQLAlchemyQueryExecutor
12
+ from gaard_connectors.sqlalchemy.introspector import SQLAlchemySchemaIntrospector
13
+ from gaard_core.errors import (
14
+ ConfigurationError,
15
+ LlmProviderError,
16
+ QueryExecutionError,
17
+ QueryPipelineStepError,
18
+ SqlValidationError,
19
+ )
20
+ from gaard_core.investigation import (
21
+ InvestigationContext,
22
+ InvestigationLoop,
23
+ InvestigationLoopConfig,
24
+ InvestigationLoopResult,
25
+ InvestigationRoute,
26
+ LlmInvestigationReadinessAgent,
27
+ MockInvestigationReadinessAgent,
28
+ RequiredAnalysisTask,
29
+ )
30
+ from gaard_core.query_intent.llm_classifier import LlmQueryIntentClassifier
31
+ from gaard_core.query_intent.mock_classifier import MockQueryIntentClassifier
32
+ from gaard_core.query_pipeline.llm_sql_generator import LlmSqlGenerator
33
+ from gaard_core.query_pipeline.mock_sql_generator import MockSqlGenerator
34
+ from gaard_core.query_pipeline.models import (
35
+ OutputClassification,
36
+ QueryIntentClassification,
37
+ QueryIntentDecision,
38
+ QueryMode,
39
+ QueryRequest,
40
+ QueryResponse,
41
+ )
42
+ from gaard_core.query_pipeline.pipeline import QueryPipeline
43
+ from gaard_core.result_classifier.llm_classifier import LlmResultClassifier
44
+ from gaard_core.result_classifier.mock_classifier import MockResultClassifier
45
+ from gaard_core.result_interpreter.llm_interpreter import LlmResultInterpreter
46
+ from gaard_core.result_interpreter.mock_interpreter import MockResultInterpreter
47
+ from gaard_core.schema.context import SchemaContextService
48
+ from gaard_core.sql_validator.select_only import SelectOnlySqlValidator
49
+ from gaard_llm.openai_compatible.client import OpenAICompatibleClient
50
+
51
+ from gaard_api.admin.models import DatasourceConnector, DatasourceSchemaCache
52
+ from gaard_api.admin.prompt_runtime import (
53
+ get_investigation_readiness_prompt_compiler,
54
+ get_intent_classification_prompt_compiler,
55
+ get_result_classification_prompt_compiler,
56
+ get_result_interpretation_prompt_compiler,
57
+ get_sql_generation_prompt_compiler,
58
+ )
59
+ from gaard_api.admin.services import (
60
+ ACCESS_ERROR_INTENT_CLASSIFICATION,
61
+ ACCESS_ERROR_SQL_VALIDATION,
62
+ get_active_business_logic_prompt_safe,
63
+ get_datasource_schema_context_safe,
64
+ get_llm_runtime_config_safe,
65
+ get_query_runtime_config_safe,
66
+ learn_business_logic_from_sql_error,
67
+ LlmRuntimeConfig,
68
+ QueryRuntimeConfig,
69
+ record_data_query_access_error_audit,
70
+ record_data_query_audit,
71
+ record_data_query_pipeline_error_audit,
72
+ record_data_query_sql_error_audit,
73
+ upsert_investigation_analysis_business_logic_suggestion,
74
+ )
75
+ from gaard_api.api.v1.schema import get_schema_cache_key
76
+ from gaard_api.core.schema_cache import schema_context_cache
77
+ from gaard_api.core.settings import settings
78
+
79
# Router object for this module; route registrations are not visible in this part of the file.
router = APIRouter()

# A connector paired with its cached schema. Functions in this module unpack it
# as ``connector, schema_cache`` or read the connector via index 0.
DatasourceContext = tuple[DatasourceConnector, DatasourceSchemaCache]
# Callback that receives one progress payload dict per investigation step.
ProgressCallback = Callable[[dict[str, Any]], None]

# User-facing refusal texts (Polish). They are runtime strings returned in the
# ``answer`` field of refusal responses and must not be edited as if they were comments.

# Returned by access_refusal_answer for the "sql.validation.write_operation" category.
READ_ONLY_REFUSAL_ANSWER = (
    "Nie mogę tego zrobić. GAARD obsługuje tylko odczyt danych i nie wykonuje "
    "operacji modyfikujących, usuwających ani tworzących dane."
)

# Default refusal: used for non-validation refusal reasons and as the final fallback.
READ_ONLY_SCOPE_REFUSAL_ANSWER = (
    "Nie mogę obsłużyć tego zapytania. GAARD odpowiada tylko na pytania, "
    "które można zrealizować jako bezpieczny odczyt danych SQL SELECT."
)

# Returned for disallowed table/column/relationship, SELECT *, and task allowlist/evidence categories.
ALLOWLIST_REFUSAL_ANSWER = (
    "Nie mogę wykonać tego zapytania, ponieważ wymagane tabele, kolumny, "
    "relacje albo dowody są poza zakresem dozwolonym dla tego zadania."
)

# Returned for the "sql.validation.syntax" category.
SQL_SYNTAX_REFUSAL_ANSWER = (
    "Nie mogę wykonać tego zapytania, ponieważ wygenerowany SQL nie przeszedł "
    "walidacji składni lub zasad pojedynczego zapytania."
)

# Returned for the "intent.ambiguous_requires_clarification" category.
CLARIFICATION_REFUSAL_ANSWER = (
    "Potrzebuję doprecyzowania, zanim bezpiecznie rozpocznę tę analizę."
)

# NOTE(review): no use of this constant is visible in this part of the file;
# presumably consumed by the investigation request path - confirm.
ANALYSIS_MODE_PENDING_ANSWER = (
    "Tryb Investigation wymaga dodatkowej analizy przed wygenerowaniem SQL. "
    "Ścieżka Analysis nie jest jeszcze zaimplementowana."
)

# Message prefixes that extract_sql_from_validation_error strips from a
# validation error message to recover the text that follows them.
VALIDATION_SQL_PREFIXES = (
    "Only SELECT queries are allowed. ",
    "DDL and DML statements are not allowed. ",
    "Only single-statement SQL queries are allowed. SQL: ",
    "Invalid SQL syntax. ",
)
119
+
120
+
121
def create_llm_client(
    llm_config: LlmRuntimeConfig | None = None,
) -> OpenAICompatibleClient:
    """Build an OpenAI-compatible client from the LLM runtime configuration.

    Args:
        llm_config: Configuration to use. When falsy, it is loaded via
            ``get_llm_runtime_config_safe()``.

    Returns:
        A client built from the configured base URL, API key and timeout.

    Raises:
        ConfigurationError: If the provider is not ``"openai-compatible"`` or
            the API key still equals the ``"change-me"`` placeholder.
    """
    config = llm_config if llm_config else get_llm_runtime_config_safe()

    provider = config.provider
    if provider != "openai-compatible":
        raise ConfigurationError(f"Unsupported GAARD_LLM_PROVIDER: {provider}")

    api_key = config.api_key
    if api_key == "change-me":
        raise ConfigurationError("GAARD_LLM_API_KEY must be configured when using LLM mode.")

    return OpenAICompatibleClient(
        base_url=config.base_url,
        api_key=api_key,
        timeout_seconds=config.timeout_seconds,
    )
137
+
138
+
139
def create_sql_generator(
    datasource_context: DatasourceContext | None = None,
    llm_config: LlmRuntimeConfig | None = None,
    runtime_config: QueryRuntimeConfig | None = None,
) -> MockSqlGenerator | LlmSqlGenerator:
    """Create the SQL generator selected by the runtime configuration.

    Args:
        datasource_context: Connector and schema cache pair. When ``None`` it
            is looked up with ``get_datasource_schema_context_safe()``; if that
            also yields ``None`` the schema is introspected from
            ``settings.gaard_datasource_url`` instead.
        llm_config: LLM configuration; loaded on demand in ``"llm"`` mode.
        runtime_config: Query runtime configuration; loaded when falsy.

    Returns:
        ``MockSqlGenerator`` in ``"mock"`` mode, ``LlmSqlGenerator`` in
        ``"llm"`` mode.

    Raises:
        ConfigurationError: For any other ``sql_generation_mode`` value.
    """
    runtime_config = runtime_config or get_query_runtime_config_safe()
    mode = runtime_config.sql_generation_mode

    if mode == "mock":
        return MockSqlGenerator()

    if mode != "llm":
        raise ConfigurationError(
            f"Unsupported GAARD_SQL_GENERATION_MODE: {mode}"
        )

    llm_config = llm_config or get_llm_runtime_config_safe()

    if datasource_context is None:
        datasource_context = get_datasource_schema_context_safe()

    if datasource_context is None:
        # No registered datasource: introspect the schema of the default database.
        schema_service = SchemaContextService(
            introspector=SQLAlchemySchemaIntrospector(
                database_url=settings.gaard_datasource_url,
            ),
            cache=schema_context_cache,
        )
        formatted_schema = schema_service.get_schema_context(
            get_schema_cache_key()
        ).formatted_schema
        dialect = settings.gaard_sql_dialect
    else:
        connector, schema_cache = datasource_context
        # Only the registered-datasource path appends business logic to the schema.
        formatted_schema = append_business_logic_to_schema(
            schema_cache.formatted_schema,
            connector.id,
        )
        dialect = connector.sql_dialect

    return LlmSqlGenerator(
        client=create_llm_client(llm_config),
        model=llm_config.model,
        formatted_schema=formatted_schema,
        dialect=dialect,
        max_rows=runtime_config.query_max_rows,
        extra_body=llm_config.extra_body,
        prompt_compiler=get_sql_generation_prompt_compiler(),
    )
198
+
199
+
200
def resolve_intent_classification_mode() -> str:
    """Return the effective intent classification mode.

    The configured mode is returned as is unless it is ``"auto"``; ``"auto"``
    becomes ``"llm"`` when SQL generation runs in ``"llm"`` mode and
    ``"mock"`` otherwise.
    """
    config = get_query_runtime_config_safe()
    configured_mode = config.intent_classification_mode

    if configured_mode != "auto":
        return configured_mode

    if config.sql_generation_mode == "llm":
        return "llm"
    return "mock"
207
+
208
+
209
def create_intent_classifier(
    llm_config: LlmRuntimeConfig | None = None,
) -> MockQueryIntentClassifier | LlmQueryIntentClassifier:
    """Create the intent classifier for the effective classification mode.

    Args:
        llm_config: LLM configuration used in ``"llm"`` mode; loaded via
            ``get_llm_runtime_config_safe()`` when falsy.

    Returns:
        A mock classifier in ``"mock"`` mode, an LLM-backed one in ``"llm"`` mode.

    Raises:
        ConfigurationError: If the resolved mode is neither ``"mock"`` nor
            ``"llm"``; the message reports the configured (unresolved) value.
    """
    mode = resolve_intent_classification_mode()

    if mode == "llm":
        config = llm_config or get_llm_runtime_config_safe()
        return LlmQueryIntentClassifier(
            client=create_llm_client(config),
            model=config.model,
            extra_body=config.extra_body,
            prompt_compiler=get_intent_classification_prompt_compiler(),
        )

    if mode == "mock":
        return MockQueryIntentClassifier()

    configured_mode = get_query_runtime_config_safe().intent_classification_mode
    raise ConfigurationError(
        f"Unsupported GAARD_INTENT_CLASSIFICATION_MODE: {configured_mode}"
    )
231
+
232
+
233
def create_result_interpreter(
    llm_config: LlmRuntimeConfig | None = None,
    runtime_config: QueryRuntimeConfig | None = None,
) -> MockResultInterpreter | LlmResultInterpreter:
    """Create the result interpreter selected by ``result_interpretation_mode``.

    Args:
        llm_config: LLM configuration used in ``"llm"`` mode; loaded when falsy.
        runtime_config: Query runtime configuration; loaded when falsy.

    Returns:
        A mock interpreter in ``"mock"`` mode, an LLM-backed one in ``"llm"`` mode.

    Raises:
        ConfigurationError: For any other mode value.
    """
    runtime_config = runtime_config or get_query_runtime_config_safe()
    mode = runtime_config.result_interpretation_mode

    if mode == "llm":
        config = llm_config or get_llm_runtime_config_safe()
        return LlmResultInterpreter(
            client=create_llm_client(config),
            model=config.model,
            extra_body=config.extra_body,
            prompt_compiler=get_result_interpretation_prompt_compiler(),
        )

    if mode == "mock":
        return MockResultInterpreter()

    raise ConfigurationError(
        f"Unsupported GAARD_RESULT_INTERPRETATION_MODE: {mode}"
    )
256
+
257
+
258
def resolve_output_classification_mode() -> str:
    """Return the effective output classification mode.

    A configured value other than ``"auto"`` is returned unchanged; ``"auto"``
    follows result interpretation: ``"llm"`` when interpretation is ``"llm"``,
    ``"mock"`` otherwise.
    """
    config = get_query_runtime_config_safe()
    configured_mode = config.output_classification_mode

    if configured_mode != "auto":
        return configured_mode

    if config.result_interpretation_mode == "llm":
        return "llm"
    return "mock"
265
+
266
+
267
def create_result_classifier(
    llm_config: LlmRuntimeConfig | None = None,
    runtime_config: QueryRuntimeConfig | None = None,
) -> MockResultClassifier | LlmResultClassifier:
    """Create the output classifier for the effective classification mode.

    ``"auto"`` resolves to ``"llm"`` when result interpretation is ``"llm"``
    and to ``"mock"`` otherwise; any other configured value is used as is.

    Args:
        llm_config: LLM configuration used in ``"llm"`` mode; loaded when falsy.
        runtime_config: Query runtime configuration; loaded when falsy.

    Returns:
        A mock classifier in ``"mock"`` mode, an LLM-backed one in ``"llm"`` mode.

    Raises:
        ConfigurationError: If the effective mode is neither ``"mock"`` nor
            ``"llm"``; the message reports the configured value.
    """
    runtime_config = runtime_config or get_query_runtime_config_safe()
    configured_mode = runtime_config.output_classification_mode

    if configured_mode != "auto":
        effective_mode = configured_mode
    elif runtime_config.result_interpretation_mode == "llm":
        effective_mode = "llm"
    else:
        effective_mode = "mock"

    if effective_mode == "llm":
        config = llm_config or get_llm_runtime_config_safe()
        return LlmResultClassifier(
            client=create_llm_client(config),
            model=config.model,
            extra_body=config.extra_body,
            prompt_compiler=get_result_classification_prompt_compiler(),
        )

    if effective_mode == "mock":
        return MockResultClassifier()

    raise ConfigurationError(
        f"Unsupported GAARD_OUTPUT_CLASSIFICATION_MODE: {configured_mode}"
    )
300
+
301
+
302
def create_pipeline(datasource_context: DatasourceContext | None = None) -> QueryPipeline:
    """Assemble a ``QueryPipeline`` from the current runtime configuration.

    Args:
        datasource_context: Connector and schema cache pair. When ``None`` it
            is looked up with ``get_datasource_schema_context_safe()``; if the
            lookup also yields ``None`` the database URL and SQL dialect come
            from ``settings``.

    Returns:
        A pipeline wired with the SQL generator, select-only validator,
        executor, interpreter and classifier.
    """
    if datasource_context is None:
        datasource_context = get_datasource_schema_context_safe()

    runtime_config = get_query_runtime_config_safe()

    if datasource_context is None:
        database_url = settings.gaard_datasource_url
        sql_dialect = settings.gaard_sql_dialect
    else:
        connector = datasource_context[0]
        database_url = connector.database_url
        sql_dialect = connector.sql_dialect

    executor = SQLAlchemyQueryExecutor(
        database_url=database_url,
        max_rows=runtime_config.query_max_rows,
    )

    generation_mode = runtime_config.sql_generation_mode
    interpretation_mode = runtime_config.result_interpretation_mode

    # "auto" output classification follows the result interpretation mode.
    output_classification_mode = runtime_config.output_classification_mode
    if output_classification_mode == "auto":
        output_classification_mode = "llm" if interpretation_mode == "llm" else "mock"

    # The LLM configuration is loaded only when at least one stage runs in "llm" mode.
    uses_llm = "llm" in (generation_mode, interpretation_mode, output_classification_mode)
    llm_config = get_llm_runtime_config_safe() if uses_llm else None

    return QueryPipeline(
        sql_generator=create_sql_generator(datasource_context, llm_config, runtime_config),
        sql_validator=SelectOnlySqlValidator(dialect=sql_dialect),
        executor=executor,
        interpreter=create_result_interpreter(llm_config, runtime_config),
        classifier=create_result_classifier(llm_config, runtime_config),
        sql_generation_mode=generation_mode,
        result_interpretation_mode=interpretation_mode,
        output_classification_mode=output_classification_mode,
    )
358
+
359
+
360
def append_business_logic_to_schema(formatted_schema: str, connector_id: int) -> str:
    """Append the connector's active business-logic prompt to the schema text.

    Args:
        formatted_schema: Schema description to extend.
        connector_id: Connector whose business-logic prompt is looked up.

    Returns:
        The schema followed by a blank line and the business logic, or the
        schema unchanged when the looked-up prompt is falsy.
    """
    business_logic = get_active_business_logic_prompt_safe(connector_id)
    if business_logic:
        return f"{formatted_schema}\n\n{business_logic}"
    return formatted_schema
367
+
368
+
369
def schema_and_business_logic_for_investigation(
    datasource_context: DatasourceContext | None,
) -> tuple[str, str]:
    """Return the ``(formatted_schema, business_logic)`` pair for an investigation.

    Args:
        datasource_context: Connector and schema cache pair, or ``None``.

    Returns:
        With a datasource context: the cached formatted schema and the
        connector's active business-logic prompt. Without one: the schema
        introspected from ``settings.gaard_datasource_url`` and an empty string.
    """
    if datasource_context is None:
        schema_service = SchemaContextService(
            introspector=SQLAlchemySchemaIntrospector(
                database_url=settings.gaard_datasource_url,
            ),
            cache=schema_context_cache,
        )
        introspected = schema_service.get_schema_context(get_schema_cache_key())
        return introspected.formatted_schema, ""

    connector, schema_cache = datasource_context
    formatted_schema = schema_cache.formatted_schema
    business_logic = get_active_business_logic_prompt_safe(connector.id)
    return formatted_schema, business_logic
389
+
390
+
391
def create_investigation_context(
    request: QueryRequest,
    datasource_context: DatasourceContext | None,
) -> InvestigationContext:
    """Build the ``InvestigationContext`` for a request.

    Args:
        request: Request supplying the question, datasource id and user id.
        datasource_context: Connector and schema cache pair, or ``None``.

    Returns:
        A context combining the request fields with the schema text and
        business logic resolved for ``datasource_context``.
    """
    schema_text, logic_text = schema_and_business_logic_for_investigation(
        datasource_context
    )
    context_fields = {
        "question": request.question,
        "datasource_id": request.datasource_id,
        "user_id": request.user_id,
        "formatted_schema": schema_text,
        "business_logic": logic_text,
    }
    return InvestigationContext(**context_fields)
406
+
407
+
408
def resolve_investigation_mode(runtime_config: QueryRuntimeConfig) -> str:
    """Return the effective investigation mode for ``runtime_config``.

    A configured value other than ``"auto"`` is returned unchanged; ``"auto"``
    becomes ``"llm"`` when SQL generation is ``"llm"`` and ``"mock"`` otherwise.
    """
    configured_mode = runtime_config.investigation_mode
    if configured_mode != "auto":
        return configured_mode

    if runtime_config.sql_generation_mode == "llm":
        return "llm"
    return "mock"
413
+
414
+
415
def create_investigation_readiness_agent(
    runtime_config: QueryRuntimeConfig | None = None,
    llm_config: LlmRuntimeConfig | None = None,
) -> MockInvestigationReadinessAgent | LlmInvestigationReadinessAgent:
    """Create the readiness agent for the effective investigation mode.

    Args:
        runtime_config: Query runtime configuration; loaded when falsy.
        llm_config: LLM configuration used in ``"llm"`` mode; loaded when falsy.

    Returns:
        A mock agent in ``"mock"`` mode, an LLM-backed agent in ``"llm"`` mode.

    Raises:
        ConfigurationError: If the resolved mode is neither ``"mock"`` nor
            ``"llm"``; the message reports the configured value.
    """
    runtime_config = runtime_config or get_query_runtime_config_safe()
    mode = resolve_investigation_mode(runtime_config)

    if mode == "llm":
        config = llm_config or get_llm_runtime_config_safe()
        return LlmInvestigationReadinessAgent(
            client=create_llm_client(config),
            model=config.model,
            extra_body=config.extra_body,
            prompt_compiler=get_investigation_readiness_prompt_compiler(),
        )

    if mode == "mock":
        return MockInvestigationReadinessAgent()

    raise ConfigurationError(
        f"Unsupported GAARD_INVESTIGATION_MODE: {runtime_config.investigation_mode}"
    )
437
+
438
+
439
def investigation_iteration_metadata(result: InvestigationLoopResult) -> list[dict[str, Any]]:
    """Flatten each loop iteration into a JSON-friendly dict.

    Every entry holds the iteration number and agent name, followed by the
    fields of the decision dumped with ``model_dump(mode="json")``. Decision
    fields win when they share a key with the leading entries.
    """
    steps: list[dict[str, Any]] = []
    for entry in result.iterations:
        step: dict[str, Any] = {
            "iteration": entry.iteration,
            "agent": entry.agent,
        }
        step.update(entry.decision.model_dump(mode="json"))
        steps.append(step)
    return steps
448
+
449
+
450
def investigation_metadata(
    result: InvestigationLoopResult,
    investigation_mode: str,
) -> dict[str, Any]:
    """Build the response metadata describing an investigation loop run.

    Args:
        result: Outcome of the investigation loop.
        investigation_mode: Effective investigation mode to report.

    Returns:
        Metadata with the route, loop statistics and per-iteration steps. The
        same step list is stored under both ``investigation_steps`` and
        ``investigation_audit_trail``. When the route is ``ANALYSIS`` the key
        ``analysis_mode_status`` is added with value ``"not_implemented"``.
    """
    steps = investigation_iteration_metadata(result)
    route = result.route

    metadata: dict[str, Any] = {
        "query_mode": QueryMode.INVESTIGATION.value,
        "investigation_backend_status": "readiness_gate_active",
        "investigation_mode": investigation_mode,
        "investigation_route": route.value,
        "investigation_loop": {
            "max_iterations": result.max_iterations,
            "iterations_run": len(result.iterations),
            "confidence_threshold": result.confidence_threshold,
        },
        "investigation_steps": steps,
        "investigation_audit_trail": steps,
    }

    if route == InvestigationRoute.ANALYSIS:
        metadata["analysis_mode_status"] = "not_implemented"

    return metadata
473
+
474
+
475
def required_analysis_tasks_from_result(
    result: InvestigationLoopResult,
) -> list[RequiredAnalysisTask]:
    """Collect the non-blank analysis tasks requested by the final decision.

    Structured ``required_analysis_tasks`` take precedence. When they are
    absent, tasks are built from the parallel ``required_analysis`` and
    ``missing_information`` lists, paired by position.

    Blank ``required_analysis`` entries are dropped on both paths, so an empty
    question is never turned into an analysis task. The fallback path
    previously kept them, unlike the structured path.

    Args:
        result: Investigation loop outcome whose ``final_decision`` is read.

    Returns:
        The tasks to run; empty when there is no final decision or every
        requested analysis is blank.
    """
    decision = result.final_decision
    if decision is None:
        return []

    structured_tasks = decision.required_analysis_tasks
    if structured_tasks:
        return [
            task
            for task in structured_tasks
            if task.required_analysis.strip()
        ]

    missing_information = decision.missing_information
    # Enumerate before filtering so each analysis keeps the position used to
    # pair it with its missing-information entry.
    return [
        RequiredAnalysisTask(
            missing_information=(
                missing_information[index]
                if index < len(missing_information)
                else ""
            ),
            required_analysis=required_analysis,
        )
        for index, required_analysis in enumerate(decision.required_analysis)
        if required_analysis.strip()
    ]
501
+
502
+
503
def emit_investigation_progress(
    progress_callback: ProgressCallback | None,
    payload: dict[str, Any],
) -> None:
    """Forward ``payload`` to ``progress_callback``; do nothing when it is ``None``."""
    if progress_callback is None:
        return
    progress_callback(payload)
509
+
510
+
511
def record_investigation_readiness_audit(
    effective_request: QueryRequest,
    result: InvestigationLoopResult,
    metadata: dict[str, Any],
) -> int | None:
    """Record an audit entry for the readiness step of an investigation.

    Args:
        effective_request: Request being audited.
        result: Loop outcome; the final decision's reason becomes the audited
            answer, or ``"No readiness decision."`` when there is none.
        metadata: Extra metadata merged over the base fields.
            ``investigation_step`` is always set to ``"readiness"``.

    Returns:
        The id of the created audit log, or ``None`` when no log was returned.
    """
    decision = result.final_decision
    answer = "No readiness decision." if decision is None else decision.reason

    audit_metadata: dict[str, Any] = {
        "duration_ms": 0,
        "datasource_id": effective_request.datasource_id,
        "user_id": effective_request.user_id,
        "output_classification": OutputClassification.UNKNOWN.value,
    }
    audit_metadata.update(metadata)
    # Set last so caller-supplied metadata cannot override the step name.
    audit_metadata["investigation_step"] = "readiness"

    response = QueryResponse(
        question=effective_request.question,
        answer=answer,
        sql="",
        rows=[],
        metadata=audit_metadata,
    )
    audit_log = record_data_query_audit(effective_request, response)
    if audit_log is None:
        return None
    return audit_log.id
533
+
534
+
535
def datasource_connector_id(datasource_context: DatasourceContext | None) -> int | None:
    """Return the id of the context's connector, or ``None`` without a context."""
    if datasource_context is None:
        return None
    connector = datasource_context[0]
    return connector.id
537
+
538
+
539
def run_investigation_analysis_tasks(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
    result: InvestigationLoopResult,
    progress_callback: ProgressCallback | None = None,
) -> list[dict[str, Any]]:
    """Run every required analysis task as an SQL request and collect the outcomes.

    For each task this emits an ``analysis_sql`` progress event, runs the
    task's question through ``run_sql_request`` in SQL mode, records a
    business-logic suggestion where possible, and emits an
    ``analysis_sql_complete`` event.

    Args:
        effective_request: Original request; copied per task with the question
            and mode replaced.
        datasource_context: Connector and schema cache pair, or ``None``.
        result: Investigation loop outcome that defines the tasks.
        progress_callback: Optional receiver of progress payloads.

    Returns:
        One dict per task with the task description, SQL, rows, answer, audit
        log id and business-logic learning result.
    """
    tasks = required_analysis_tasks_from_result(result)
    total_tasks = len(tasks)
    collected: list[dict[str, Any]] = []

    for index, task in enumerate(tasks):
        position = index + 1

        emit_investigation_progress(
            progress_callback,
            {
                "step": "analysis_sql",
                "analysis_task_index": index,
                "data_question": task.required_analysis,
                "decisions": [
                    f"Running analysis SQL task {position} of {total_tasks}."
                ],
            },
        )

        # The task's analysis text becomes the question of a plain SQL request.
        task_request = effective_request.model_copy(
            update={
                "question": task.required_analysis,
                "mode": QueryMode.SQL,
            }
        )
        task_metadata = {
            "query_mode": QueryMode.INVESTIGATION.value,
            "investigation_backend_status": "readiness_gate_active",
            "investigation_route": InvestigationRoute.ANALYSIS.value,
            "investigation_step": "analysis_sql",
            "analysis_task_index": index,
            "analysis_missing_information": task.missing_information,
            "analysis_required_analysis": task.required_analysis,
            "analysis_category": task.category,
            "analysis_expected_output": task.expected_output,
            "original_question": effective_request.question,
        }
        task_response = run_sql_request(
            task_request,
            datasource_context,
            task_metadata,
        )

        learning_result = record_analysis_business_logic_if_possible(
            effective_request=effective_request,
            datasource_context=datasource_context,
            task=task,
            task_index=index,
            analysis_response=task_response,
        )

        collected.append(
            {
                "analysis_task_index": index,
                "missing_information": task.missing_information,
                "required_analysis": task.required_analysis,
                "category": task.category,
                "expected_output": task.expected_output,
                "sql": task_response.sql,
                "rows": task_response.rows,
                "answer": task_response.answer,
                "audit_log_id": task_response.metadata.get("data_query_audit_id"),
                "business_logic_learning": learning_result,
            }
        )

        emit_investigation_progress(
            progress_callback,
            {
                "step": "analysis_sql_complete",
                "analysis_task_index": index,
                "data_question": task.required_analysis,
                "decisions": [
                    f"Analysis SQL task {position} completed.",
                    business_logic_progress_message(learning_result),
                ],
            },
        )

    return collected
617
+
618
+
619
def record_analysis_business_logic_if_possible(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
    task: RequiredAnalysisTask,
    task_index: int,
    analysis_response: QueryResponse,
) -> dict[str, Any]:
    """Record a business-logic suggestion for one analysis task and audit the outcome.

    A response whose metadata marks it as blocked is skipped; otherwise the
    suggestion is upserted. The outcome is audited in both cases.

    Args:
        effective_request: Original request, used for the audit entry.
        datasource_context: Connector and schema cache pair, or ``None``.
        task: Analysis task that produced ``analysis_response``.
        task_index: Position of the task within the investigation.
        analysis_response: Response of the task's SQL request.

    Returns:
        The learning result: a ``skipped`` dict for blocked responses,
        otherwise whatever the upsert call returned.
    """
    response_metadata = analysis_response.metadata

    if not response_metadata.get("blocked"):
        source_audit_id = response_metadata.get("data_query_audit_id")
        # Only an integer audit id is passed on as the suggestion's source.
        if not isinstance(source_audit_id, int):
            source_audit_id = None
        learning_result = upsert_investigation_analysis_business_logic_suggestion(
            connector_id=datasource_connector_id(datasource_context),
            source_audit_id=source_audit_id,
            missing_information=task.missing_information,
            required_analysis=task.required_analysis,
            category=task.category,
            analysis_response=analysis_response,
        )
    else:
        learning_result = {
            "status": "skipped",
            "reason": "Analysis SQL task was blocked and did not produce evidence.",
        }

    record_investigation_business_logic_audit(
        effective_request=effective_request,
        task=task,
        task_index=task_index,
        analysis_response=analysis_response,
        learning_result=learning_result,
    )
    return learning_result
650
+
651
+
652
def record_investigation_business_logic_audit(
    effective_request: QueryRequest,
    task: RequiredAnalysisTask,
    task_index: int,
    analysis_response: QueryResponse,
    learning_result: dict[str, Any],
) -> None:
    """Audit the business-logic learning outcome of one analysis task.

    The audited response carries the original question, the progress message
    for ``learning_result`` as its answer, the task's SQL and no rows.

    Args:
        effective_request: Original request being audited.
        task: Analysis task the outcome belongs to.
        task_index: Position of the task within the investigation.
        analysis_response: Response of the task's SQL request.
        learning_result: Outcome of the business-logic learning step.
    """
    source_audit_id = analysis_response.metadata.get("data_query_audit_id")

    audit_metadata: dict[str, Any] = {
        "duration_ms": 0,
        "datasource_id": effective_request.datasource_id,
        "user_id": effective_request.user_id,
        "output_classification": OutputClassification.UNKNOWN.value,
        "query_mode": QueryMode.INVESTIGATION.value,
        "investigation_backend_status": "readiness_gate_active",
        "investigation_route": InvestigationRoute.ANALYSIS.value,
        "investigation_step": "analysis_business_logic",
        "analysis_task_index": task_index,
        "analysis_missing_information": task.missing_information,
        "analysis_required_analysis": task.required_analysis,
        "analysis_category": task.category,
        "analysis_source_audit_log_id": source_audit_id,
        "business_logic_learning": learning_result,
    }

    audited_response = QueryResponse(
        question=effective_request.question,
        answer=business_logic_progress_message(learning_result),
        sql=analysis_response.sql,
        rows=[],
        metadata=audit_metadata,
    )
    record_data_query_audit(effective_request, audited_response)
684
+
685
+
686
def business_logic_progress_message(learning_result: dict[str, Any]) -> str:
    """Return the progress message for a business-logic learning result.

    The message depends on ``learning_result["status"]``: ``created``,
    ``existing`` and ``skipped`` have dedicated texts (the last one includes
    the ``reason`` entry); any other status yields a generic completion message.
    """
    status = str(learning_result.get("status") or "")

    fixed_messages = {
        "created": "Business logic suggestion was created and is pending approval.",
        "existing": "Business logic suggestion already exists; no duplicate was created.",
    }
    if status in fixed_messages:
        return fixed_messages[status]

    if status == "skipped":
        reason = learning_result.get("reason")
        return f"Business logic suggestion was skipped: {reason}"

    return "Business logic suggestion step completed."
695
+
696
+
697
def extract_sql_from_validation_error(error_message: str) -> str:
    """Return the text that follows a known validation-error prefix.

    The first entry of ``VALIDATION_SQL_PREFIXES`` that ``error_message``
    starts with is removed and the remainder is stripped of surrounding
    whitespace. Returns ``""`` when no prefix matches.
    """
    matched_prefix = next(
        (
            prefix
            for prefix in VALIDATION_SQL_PREFIXES
            if error_message.startswith(prefix)
        ),
        None,
    )
    if matched_prefix is None:
        return ""
    return error_message[len(matched_prefix):].strip()
703
+
704
+
705
def validation_error_metadata(exc: SqlValidationError) -> dict[str, Any]:
    """Return error metadata with an error category for a validation error.

    When ``exc.metadata`` already has a truthy ``primary_error_category`` it
    is returned unchanged (same object). Otherwise the category is derived
    from the error message: write-operation wording first, then syntax or
    single-statement wording, and ``sql.validation.disallowed_column`` for
    everything else.

    Args:
        exc: Validation error carrying ``metadata`` and ``message``.

    Returns:
        Metadata containing ``primary_error_category`` and ``error_categories``.
    """
    metadata = exc.metadata
    if metadata.get("primary_error_category"):
        return metadata

    message = exc.message
    write_markers = (
        "Only SELECT queries are allowed",
        "DDL and DML statements are not allowed",
    )
    syntax_markers = (
        "Only single-statement SQL queries are allowed",
        "Invalid SQL syntax",
    )

    if any(marker in message for marker in write_markers):
        category = "sql.validation.write_operation"
    elif any(marker in message for marker in syntax_markers):
        category = "sql.validation.syntax"
    else:
        category = "sql.validation.disallowed_column"

    enriched = dict(metadata)
    enriched["primary_error_category"] = category
    enriched["error_categories"] = [category]
    return enriched
731
+
732
+
733
def is_read_only_intent(intent: QueryIntentClassification) -> bool:
    """Return whether the intent was classified as a read-only data question."""
    decision = intent.decision
    return decision == QueryIntentDecision.READ_ONLY_DATA_QUESTION
735
+
736
+
737
def intent_metadata(
    intent: QueryIntentClassification,
    intent_classification_mode: str,
) -> dict[str, Any]:
    """Describe an intent classification as response/audit metadata.

    Args:
        intent: Classification result to describe.
        intent_classification_mode: Effective classification mode to report.

    Returns:
        Metadata with the mode, decision, confidence, reason and the model
        response. When the classification has no (truthy) model response, a
        summary built from decision, confidence and reason is used in its place.
    """
    model_response = intent.model_response
    if not model_response:
        model_response = {
            "decision": intent.decision.value,
            "confidence": intent.confidence,
            "reason": intent.reason,
        }

    metadata: dict[str, Any] = {
        "intent_classification_mode": intent_classification_mode,
        "intent_decision": intent.decision.value,
        "intent_confidence": intent.confidence,
        "intent_reason": intent.reason,
        "intent_model_response": model_response,
    }
    return metadata
754
+
755
+
756
def build_access_refusal_response(
    request: QueryRequest,
    reason: str,
    sql: str = "",
    metadata: dict[str, Any] | None = None,
) -> QueryResponse:
    """Build the response returned when a request is refused.

    Args:
        request: Refused request.
        reason: Refusal reason, stored as ``blocked_reason`` and used to pick
            the answer text.
        sql: SQL associated with the refusal, if any.
        metadata: Extra metadata merged over the base fields; its entries win
            on key collisions.

    Returns:
        A response with no rows, ``blocked`` set to ``True`` and the refusal
        answer selected by ``access_refusal_answer``.
    """
    extra = metadata or {}
    answer = access_refusal_answer(reason, extra)

    response_metadata: dict[str, Any] = {
        "duration_ms": 0,
        "datasource_id": request.datasource_id,
        "user_id": request.user_id,
        "output_classification": OutputClassification.UNKNOWN.value,
        "blocked": True,
        "blocked_reason": reason,
    }
    response_metadata.update(extra)

    return QueryResponse(
        question=request.question,
        answer=answer,
        sql=sql,
        rows=[],
        metadata=response_metadata,
    )
778
+
779
+
780
def access_refusal_answer(reason: str, metadata: dict[str, Any]) -> str:
    """Choose the user-facing refusal text for a blocked request.

    Any reason other than ``ACCESS_ERROR_SQL_VALIDATION`` gets the generic
    read-only scope answer. For validation refusals the error categories in
    ``metadata`` (``error_categories`` plus ``primary_error_category``) are
    checked in priority order: write operation, clarification needed, syntax,
    then the allowlist/evidence group. The generic answer is the fallback.
    """
    if reason != ACCESS_ERROR_SQL_VALIDATION:
        return READ_ONLY_SCOPE_REFUSAL_ANSWER

    categories = set(metadata.get("error_categories") or [])
    primary = metadata.get("primary_error_category")
    if primary:
        categories.add(str(primary))

    # Order matters: the first matching category decides the answer.
    prioritized_answers = (
        ("sql.validation.write_operation", READ_ONLY_REFUSAL_ANSWER),
        ("intent.ambiguous_requires_clarification", CLARIFICATION_REFUSAL_ANSWER),
        ("sql.validation.syntax", SQL_SYNTAX_REFUSAL_ANSWER),
    )
    for category, answer in prioritized_answers:
        if category in categories:
            return answer

    allowlist_categories = {
        "sql.validation.disallowed_column",
        "sql.validation.disallowed_table",
        "sql.validation.disallowed_relationship",
        "sql.validation.select_star",
        "task.inconsistent_allowlist",
        "task.insufficient_evidence",
    }
    if not categories.isdisjoint(allowlist_categories):
        return ALLOWLIST_REFUSAL_ANSWER

    return READ_ONLY_SCOPE_REFUSAL_ANSWER
805
+
806
+
807
def run_sql_request(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
    extra_metadata: dict[str, Any] | None = None,
) -> QueryResponse:
    """Classify a request's intent, run it through the query pipeline and audit it.

    Flow:
        1. Classify the intent. An LLM provider failure is audited and re-raised.
        2. A non-read-only intent produces an audited refusal response.
        3. Otherwise the pipeline handles the request. Execution and pipeline
           step errors are audited and re-raised; SQL validation errors are
           turned into an audited refusal response.
        4. A successful response gets the intent and extra metadata merged in
           and is audited.

    Args:
        effective_request: Request to process.
        datasource_context: Connector and schema cache pair, or ``None``.
        extra_metadata: Additional metadata for responses and audit entries.

    Returns:
        The pipeline response, or a refusal response for blocked requests.
        ``data_query_audit_id`` is added to the metadata when an audit log was
        created.

    Raises:
        LlmProviderError: When intent classification fails at the provider.
        QueryExecutionError: When SQL execution fails.
        QueryPipelineStepError: When a pipeline step fails.
    """
    extra_metadata = extra_metadata or {}

    def _learn_from_failure(failure_audit: Any) -> None:
        # Hand the audited failure to the business-logic learning hook.
        learn_business_logic_from_sql_error(
            connector_id=datasource_connector_id(datasource_context),
            audit_id=None if failure_audit is None else failure_audit.id,
        )

    def _attach_audit_id(target: QueryResponse, audit_entry: Any) -> None:
        # Expose the audit log id to the caller when an entry was created.
        if audit_entry is not None:
            target.metadata["data_query_audit_id"] = audit_entry.id

    intent_mode = resolve_intent_classification_mode()
    intent_llm_config = None
    if intent_mode == "llm":
        intent_llm_config = get_llm_runtime_config_safe()

    try:
        classifier = create_intent_classifier(intent_llm_config)
        intent = classifier.classify(effective_request)
    except LlmProviderError as exc:
        failure_audit = record_data_query_pipeline_error_audit(
            request=effective_request,
            sql="",
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.message,
            pipeline_phase="intent_classification",
            metadata={**extra_metadata, "intent_classification_mode": intent_mode},
        )
        _learn_from_failure(failure_audit)
        raise

    current_intent_metadata = intent_metadata(intent, intent_mode)
    # Extra metadata wins over intent metadata on key collisions.
    audit_metadata = {**current_intent_metadata, **extra_metadata}

    if not is_read_only_intent(intent):
        refusal = build_access_refusal_response(
            effective_request,
            ACCESS_ERROR_INTENT_CLASSIFICATION,
            metadata=audit_metadata,
        )
        refusal_audit = record_data_query_access_error_audit(
            request=effective_request,
            answer=refusal.answer,
            reason=ACCESS_ERROR_INTENT_CLASSIFICATION,
            metadata=audit_metadata,
        )
        _attach_audit_id(refusal, refusal_audit)
        return refusal

    pipeline = create_pipeline(datasource_context)
    try:
        response = pipeline.handle(effective_request)
    except QueryExecutionError as exc:
        # NOTE(review): this audit receives only extra_metadata, while the
        # other handlers pass audit_metadata (which includes the intent
        # fields) - confirm whether that difference is intentional.
        failure_audit = record_data_query_sql_error_audit(
            request=effective_request,
            sql=exc.sql,
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.error_detail,
            metadata=extra_metadata,
        )
        _learn_from_failure(failure_audit)
        raise
    except QueryPipelineStepError as exc:
        failure_audit = record_data_query_pipeline_error_audit(
            request=effective_request,
            sql=exc.sql,
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.error_detail,
            pipeline_phase=exc.phase,
            metadata=audit_metadata,
        )
        _learn_from_failure(failure_audit)
        raise
    except SqlValidationError as exc:
        # Validation failures are reported as a refusal instead of being raised.
        validation_metadata = validation_error_metadata(exc)
        refusal = build_access_refusal_response(
            effective_request,
            ACCESS_ERROR_SQL_VALIDATION,
            sql=extract_sql_from_validation_error(exc.message),
            metadata={**audit_metadata, **validation_metadata},
        )
        refusal_audit = record_data_query_access_error_audit(
            request=effective_request,
            answer=refusal.answer,
            reason=ACCESS_ERROR_SQL_VALIDATION,
            sql=refusal.sql,
            error_code=exc.code,
            error_detail=exc.message,
            metadata={**audit_metadata, **validation_metadata},
        )
        _attach_audit_id(refusal, refusal_audit)
        _learn_from_failure(refusal_audit)
        return refusal

    response.metadata.update(current_intent_metadata)
    response.metadata.update(extra_metadata)
    _attach_audit_id(response, record_data_query_audit(effective_request, response))

    return response
918
+
919
+
920
def run_investigation_request(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
    progress_callback: ProgressCallback | None = None,
) -> QueryResponse:
    """Run the Investigation-mode flow for a single request.

    A readiness check is executed through ``InvestigationLoop`` (configured
    for a single iteration). If the readiness result routes to SQL, the
    request is delegated to ``run_sql_request`` with the readiness metadata
    attached. Otherwise the required analysis tasks are run and a response
    carrying ``ANALYSIS_MODE_PENDING_ANSWER`` (no SQL, no rows) is returned.

    Args:
        effective_request: The request to investigate.
        datasource_context: Datasource context for the request, or ``None``.
        progress_callback: Optional callable passed to
            ``emit_investigation_progress`` for each progress payload.

    Returns:
        The response from the SQL pipeline, or the analysis-pending response.

    Raises:
        LlmProviderError: Re-raised from the readiness check after a pipeline
            error audit has been recorded.
    """
    started_at = time.perf_counter()
    runtime_config = get_query_runtime_config_safe()
    investigation_mode = resolve_investigation_mode(runtime_config)
    context = create_investigation_context(effective_request, datasource_context)
    loop_config = InvestigationLoopConfig(max_iterations=1)

    emit_investigation_progress(
        progress_callback,
        {
            "step": "readiness",
            "data_question": effective_request.question,
            "decisions": ["Running Investigation readiness check."],
        },
    )
    try:
        result = InvestigationLoop(
            readiness_agent=create_investigation_readiness_agent(runtime_config),
            config=loop_config,
        ).run(context)
    except LlmProviderError as exc:
        # Record the failure before propagating it; no SQL exists at this phase.
        record_data_query_pipeline_error_audit(
            request=effective_request,
            sql="",
            error_code=exc.code,
            error_message=exc.message,
            error_detail=exc.message,
            pipeline_phase="investigation_readiness",
            metadata={
                "query_mode": QueryMode.INVESTIGATION.value,
                "investigation_mode": investigation_mode,
                "investigation_backend_status": "readiness_gate_active",
            },
        )
        raise

    metadata = investigation_metadata(result, investigation_mode)
    metadata["investigation_readiness_duration_ms"] = round(
        (time.perf_counter() - started_at) * 1000,
        2,
    )
    readiness_audit_log_id = record_investigation_readiness_audit(
        effective_request,
        result,
        metadata,
    )
    if readiness_audit_log_id is not None:
        metadata["readiness_audit_log_id"] = readiness_audit_log_id
    emit_investigation_progress(
        progress_callback,
        {
            "step": "readiness_complete",
            "data_question": effective_request.question,
            "decisions": [
                f"Investigation readiness route: {result.route.value}."
            ],
        },
    )

    if result.route == InvestigationRoute.SQL:
        emit_investigation_progress(
            progress_callback,
            {
                "step": "sql",
                "data_question": effective_request.question,
                "decisions": ["Readiness passed; running the normal SQL pipeline."],
            },
        )
        return run_sql_request(
            effective_request,
            datasource_context,
            {**metadata, "investigation_step": "sql"},
        )

    analysis_results = run_investigation_analysis_tasks(
        effective_request=effective_request,
        datasource_context=datasource_context,
        result=result,
        progress_callback=progress_callback,
    )
    metadata["analysis_results"] = analysis_results
    metadata["analysis_tasks_count"] = len(analysis_results)
    metadata["investigation_step"] = "final"

    response = QueryResponse(
        question=effective_request.question,
        answer=ANALYSIS_MODE_PENDING_ANSWER,
        sql="",
        rows=[],
        metadata={
            "duration_ms": round((time.perf_counter() - started_at) * 1000, 2),
            "datasource_id": effective_request.datasource_id,
            "user_id": effective_request.user_id,
            "output_classification": OutputClassification.UNKNOWN.value,
            # Investigation metadata is spread last so it wins on key clashes.
            **metadata,
        },
    )
    # Expose the audit id on the response, matching the SQL request paths.
    audit_log = record_data_query_audit(effective_request, response)
    if audit_log is not None:
        response.metadata["data_query_audit_id"] = audit_log.id
    return response
1023
+
1024
+
1025
def effective_query_request(request: QueryRequest) -> tuple[QueryRequest, DatasourceContext | None]:
    """Resolve the datasource context and pin the request to it.

    Returns the request unchanged, paired with ``None``, when no datasource
    context is available. Otherwise returns a copy of the request whose
    ``datasource_id`` is the ``connector_key`` of the context's first element,
    paired with that context.
    """
    datasource_context = get_datasource_schema_context_safe()
    if datasource_context is None:
        return request, datasource_context

    connector_key = datasource_context[0].connector_key
    pinned_request = request.model_copy(update={"datasource_id": connector_key})
    return pinned_request, datasource_context
1035
+
1036
+
1037
def ndjson_line(payload: dict[str, Any]) -> str:
    """Serialize *payload* as one newline-terminated JSON line.

    Non-ASCII characters are emitted as-is rather than escaped.
    """
    encoded = json.dumps(payload, ensure_ascii=False)
    return encoded + "\n"
1039
+
1040
+
1041
def stream_investigation_response(
    effective_request: QueryRequest,
    datasource_context: DatasourceContext | None,
) -> Iterator[str]:
    """Yield NDJSON lines for an investigation run executed on a worker thread.

    Progress payloads are emitted as ``{"progress": ...}``, the finished
    response as ``{"final": ...}``, and any exception raised by the run as
    ``{"error": {"message": ..., "type": ...}}``. The stream ends once the
    worker posts its ``None`` sentinel.
    """
    events: Queue[dict[str, Any] | None] = Queue()

    def forward_progress(payload: dict[str, Any]) -> None:
        events.put({"progress": payload})

    def run_in_background() -> None:
        try:
            outcome = run_investigation_request(
                effective_request,
                datasource_context,
                progress_callback=forward_progress,
            )
            events.put({"final": outcome.model_dump(mode="json")})
        except Exception as exc:
            # Thread boundary: report the failure on the stream instead of
            # letting it die silently inside the worker.
            failure = {
                "message": str(exc),
                "type": exc.__class__.__name__,
            }
            events.put({"error": failure})
        finally:
            # Sentinel that tells the consumer loop below to stop.
            events.put(None)

    runner = Thread(target=run_in_background, daemon=True)
    runner.start()

    # Block on the queue until the sentinel arrives.
    for event in iter(events.get, None):
        yield ndjson_line(event)

    runner.join()
1081
+
1082
+
1083
@router.post("/query", response_model=QueryResponse)
def query(request: QueryRequest) -> QueryResponse:
    """Handle ``POST /query``.

    Requests in investigation mode go to ``run_investigation_request``;
    all other requests go to ``run_sql_request``.
    """
    effective_request, datasource_context = effective_query_request(request)

    is_investigation = effective_request.mode == QueryMode.INVESTIGATION
    handler = run_investigation_request if is_investigation else run_sql_request
    return handler(effective_request, datasource_context)
1091
+
1092
+
1093
@router.post("/query/stream")
def query_stream(request: QueryRequest) -> StreamingResponse:
    """Handle ``POST /query/stream`` with an NDJSON streaming response.

    Investigation-mode requests stream the lines produced by
    ``stream_investigation_response``. Any other mode streams a single
    ``{"final": ...}`` line holding the SQL pipeline response.
    """
    effective_request, datasource_context = effective_query_request(request)

    if effective_request.mode == QueryMode.INVESTIGATION:
        body = stream_investigation_response(effective_request, datasource_context)
    else:

        def single_response() -> Iterator[str]:
            # Use the already-resolved request and context directly. Going
            # back through the `query` route handler would resolve the
            # datasource context a second time.
            response = run_sql_request(effective_request, datasource_context)
            yield ndjson_line({"final": response.model_dump(mode="json")})

        body = single_response()

    return StreamingResponse(body, media_type="application/x-ndjson")