gaard-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2142 @@
1
+ import hashlib
2
+ import json
3
+ import re
4
+ from dataclasses import dataclass
5
+ from datetime import UTC, datetime, timedelta
6
+ from typing import Any
7
+
8
+ from sqlalchemy import create_engine, text
9
+ from sqlalchemy import delete, desc, select
10
+ from sqlalchemy.exc import SQLAlchemyError
11
+ from sqlalchemy.orm import Session
12
+
13
+ from gaard_connectors.sqlalchemy.introspector import SQLAlchemySchemaIntrospector
14
+ from gaard_core.errors import LlmProviderError
15
+ from gaard_core.llm_output import remove_thinking_blocks
16
+ from gaard_core.query_pipeline.models import OutputClassification, QueryRequest, QueryResponse
17
+ from gaard_core.schema.models import ColumnInfo, DatabaseSchema, TableInfo
18
+ from gaard_llm.openai_compatible.client import OpenAICompatibleClient
19
+ from gaard_llm.providers.models import ChatCompletionRequest, ChatMessage
20
+
21
+ from gaard_api.admin.defaults import DEFAULT_GOVERNANCE_POLICY_CONFIG
22
+ from gaard_api.admin.database import create_session
23
+ from gaard_api.admin.models import (
24
+ AdminAuditLog,
25
+ AdminSetting,
26
+ BusinessLogicSuggestion,
27
+ BusinessKnowledgeClaim,
28
+ DataQueryAuditLog,
29
+ DataQueryAuditType,
30
+ DatasourceConnector,
31
+ DatasourceSchemaCache,
32
+ OverviewWidget,
33
+ PromptTemplate,
34
+ )
35
+ from gaard_api.core.settings import settings
36
+
37
+
38
def json_dumps(value: Any) -> str:
    """Serialize *value* to a deterministic JSON string.

    Object keys are sorted, so equal payloads always produce identical text.
    Non-ASCII characters are written as-is rather than escaped.
    """
    options = {"ensure_ascii": False, "sort_keys": True}
    return json.dumps(value, **options)
40
+
41
+
42
def json_loads(value: str) -> Any:
    """Parse a JSON document; a falsy input (``""`` or ``None``) parses as ``{}``."""
    payload = value if value else "{}"
    return json.loads(payload)
44
+
45
+
46
# Accepted URL prefixes for each supported datasource type. The values are
# tuples so they can be passed directly to ``str.startswith``.
SUPPORTED_DATABASE_TYPES = dict(
    sqlite=("sqlite://", "sqlite"),
    postgresql=("postgresql://", "postgresql+psycopg://"),
    mysql=("mysql://", "mysql+pymysql://"),
)
51
+
52
+
53
def validate_datasource_url(database_type: str, database_url: str) -> None:
    """Check that *database_url* uses a URL prefix allowed for *database_type*.

    Raises:
        ValueError: if *database_type* is not a key of
            ``SUPPORTED_DATABASE_TYPES``, or if *database_url* does not start
            with any of the prefixes listed for that type.
    """
    allowed_prefixes = SUPPORTED_DATABASE_TYPES.get(database_type)
    if allowed_prefixes is None:
        raise ValueError("Unsupported datasource type.")

    if any(database_url.startswith(prefix) for prefix in allowed_prefixes):
        return

    raise ValueError(
        f"Datasource URL for {database_type} must start with one of: "
        + ", ".join(allowed_prefixes)
    )
64
+
65
+
66
def mask_database_url(database_url: str) -> str:
    """Return *database_url* with its credentials masked for display.

    ``scheme://user:secret@host/db`` becomes ``scheme://user:***@host/db``.
    ``scheme://secret@host/db`` becomes ``scheme://***@host/db``.

    URLs without ``://``, or without an ``@`` after it, are returned unchanged.

    The user-info part is separated from the host at the *last* ``@``, not the
    first. With the first, a password containing a literal ``@`` (e.g.
    ``p@ss``) would leave its tail (``ss``) visible in the masked output. The
    trade-off is that an ``@`` appearing after the host (for example inside a
    query string) masks more of the URL than necessary. Over-masking is the
    safe failure mode for a display helper.
    """
    if "://" not in database_url or "@" not in database_url:
        return database_url

    scheme, rest = database_url.split("://", 1)
    if "@" not in rest:
        # The only "@" sits before "://", so there is no user-info to mask.
        # Previously this case crashed while unpacking the split result.
        return database_url

    credentials, host = rest.rsplit("@", 1)
    username, has_password, _password = credentials.partition(":")
    if not has_password:
        # No "user:" prefix: treat the whole user-info as secret.
        return f"{scheme}://***@{host}"
    return f"{scheme}://{username}:***@{host}"
78
+
79
+
80
def record_admin_audit(
    session: Session,
    actor: str,
    action: str,
    resource_type: str,
    resource_id: str = "",
    details: dict[str, Any] | None = None,
) -> None:
    """Add an ``AdminAuditLog`` row to *session* without committing it.

    *details* is stored as sorted-key JSON; ``None`` or an empty mapping is
    stored as ``{}``.
    """
    payload = details if details else {}
    entry = AdminAuditLog(
        actor=actor,
        action=action,
        resource_type=resource_type,
        resource_id=resource_id,
        details_json=json_dumps(payload),
    )
    session.add(entry)
97
+
98
+
99
# --- Plain values of the DataQueryAuditType members -----------------------
DATA_QUERY_AUDIT_INFO = DataQueryAuditType.INFO.value
DATA_QUERY_AUDIT_SQL_ERROR = DataQueryAuditType.SQL_ERROR.value
DATA_QUERY_AUDIT_ACCESS_ERROR = DataQueryAuditType.ACCESS_ERROR.value

# --- Access-error category codes ------------------------------------------
ACCESS_ERROR_INTENT_CLASSIFICATION = "access.intent_classification"
ACCESS_ERROR_SQL_VALIDATION = "access.sql_validation"

# --- SQL-error category codes ("<area>.<detail>") -------------------------
SQL_ERROR_SCHEMA_MISSING_TABLE = "schema.missing_table"
SQL_ERROR_SCHEMA_MISSING_COLUMN = "schema.missing_column"
SQL_ERROR_DIALECT_SYNTAX = "dialect.syntax"
SQL_ERROR_PERMISSION_ACCESS_DENIED = "permission.access_denied"
SQL_ERROR_RUNTIME_DATA_TYPE = "runtime.data_type"
SQL_ERROR_LLM_PROVIDER = "llm.provider_error"
SQL_ERROR_UNKNOWN = "unknown"

# --- Business-logic suggestion states -------------------------------------
BUSINESS_LOGIC_STATUS_PENDING = "pending"
BUSINESS_LOGIC_STATUS_ACTIVE = "active"
BUSINESS_LOGIC_SAFETY_SAFE = "safe"
BUSINESS_LOGIC_SAFETY_REVIEW = "review"
BUSINESS_LOGIC_LEARNING_STATUS_SKIPPED = "skipped"

# --- Overview widget kinds and result modes -------------------------------
OVERVIEW_WIDGET_SCALAR = "scalar"
OVERVIEW_WIDGET_TIMESERIES = "timeseries"
OVERVIEW_WIDGET_TABLE = "table"
OVERVIEW_WIDGET_RESULT_DATA = "data"
OVERVIEW_WIDGET_RESULT_INTERPRETATION = "interpretation"

# --- AdminSetting keys for the LLM runtime configuration ------------------
LLM_SETTING_PROVIDER = "gaard_llm_provider"
LLM_SETTING_BASE_URL = "gaard_llm_base_url"
LLM_SETTING_API_KEY = "gaard_llm_api_key"
LLM_SETTING_MODEL = "gaard_llm_model"
LLM_SETTING_TIMEOUT_SECONDS = "gaard_llm_timeout_seconds"
LLM_SETTING_EXTRA_BODY = "gaard_llm_extra_body"

# --- AdminSetting keys for the query runtime and governance policy --------
INTENT_CLASSIFICATION_MODE_SETTING = "gaard_intent_classification_mode"
SQL_GENERATION_MODE_SETTING = "gaard_sql_generation_mode"
RESULT_INTERPRETATION_MODE_SETTING = "gaard_result_interpretation_mode"
OUTPUT_CLASSIFICATION_MODE_SETTING = "gaard_output_classification_mode"
INVESTIGATION_MODE_SETTING = "gaard_investigation_mode"
INVESTIGATION_AMBIGUITY_MODE_SETTING = "gaard_investigation_ambiguity_mode"
QUERY_MAX_ROWS_SETTING = "gaard_query_max_rows"
QUERY_TIMEOUT_SECONDS_SETTING = "gaard_query_timeout_seconds"
GOVERNANCE_POLICY_SETTING = "gaard_governance_policy"

# Connector keys reserved for internal datasources (usage not visible in this
# chunk; presumably excluded from user-managed connectors — confirm).
SYSTEM_DATASOURCE_CONNECTOR_KEYS = {"metadata-db"}
144
+
145
+
146
@dataclass(frozen=True)
class LlmRuntimeConfig:
    """Effective LLM connection settings.

    Each value comes from the stored admin settings, with fallbacks from
    ``settings``.

    ``__repr__`` is written by hand so the API key never appears in logs,
    tracebacks or debug output. The dataclass-generated repr would print
    ``api_key`` verbatim. The key is rendered as ``'***'`` when set and as
    ``''`` when empty. ``@dataclass`` keeps an explicitly defined
    ``__repr__`` instead of generating one.
    """

    provider: str
    base_url: str
    api_key: str  # secret: masked in repr()
    model: str
    extra_body: dict[str, Any]
    timeout_seconds: int

    def __repr__(self) -> str:
        masked_api_key = "***" if self.api_key else ""
        return (
            f"{type(self).__name__}("
            f"provider={self.provider!r}, "
            f"base_url={self.base_url!r}, "
            f"api_key={masked_api_key!r}, "
            f"model={self.model!r}, "
            f"extra_body={self.extra_body!r}, "
            f"timeout_seconds={self.timeout_seconds!r})"
        )
154
+
155
+
156
@dataclass(frozen=True)
class QueryRuntimeConfig:
    """Immutable snapshot of the query-pipeline settings.

    Built by ``get_query_runtime_config`` from stored admin settings, with
    fallbacks from ``settings``.
    """

    # Pipeline mode strings, stored as-is (no validation at this level).
    intent_classification_mode: str
    sql_generation_mode: str
    result_interpretation_mode: str
    output_classification_mode: str
    investigation_mode: str
    investigation_ambiguity_mode: str
    # Numeric limits for query execution.
    query_max_rows: int
    query_timeout_seconds: int
166
+
167
+
168
def record_data_query_audit(
    request: QueryRequest,
    response: QueryResponse,
) -> DataQueryAuditLog | None:
    """Record a completed query as an ``INFO`` audit entry.

    The output classification is taken from
    ``response.metadata["output_classification"]``; missing or unrecognised
    values become ``OutputClassification.UNKNOWN``.

    Returns the stored log, or ``None`` when ``_record_data_query_audit``
    could not write it.
    """
    classification = coerce_output_classification(
        response.metadata.get("output_classification")
    )
    audit_kwargs = dict(
        request=request,
        answer=response.answer,
        sql=response.sql,
        audit_type=DataQueryAuditType.INFO,
        output_classification=classification,
        metadata=response.metadata,
    )
    return _record_data_query_audit(**audit_kwargs)
184
+
185
+
186
def record_data_query_sql_error_audit(
    request: QueryRequest,
    sql: str,
    error_code: str,
    error_message: str,
    error_detail: str = "",
    metadata: dict[str, Any] | None = None,
) -> DataQueryAuditLog | None:
    """Record a failed SQL execution as a ``SQL_ERROR`` audit entry.

    The generated metadata has ``error_category`` set to ``SQL_ERROR_UNKNOWN``.
    Any key in *metadata* overrides the generated value. The error message
    doubles as the stored answer.
    """
    audit_metadata = {
        "error_category": SQL_ERROR_UNKNOWN,
        "error_code": error_code,
        "error_message": error_message,
        "error_detail": error_detail,
        "datasource_id": request.datasource_id,
        "user_id": request.user_id,
        # Caller-supplied keys win over the defaults above.
        **(metadata or {}),
    }

    return _record_data_query_audit(
        request=request,
        answer=error_message,
        sql=sql,
        audit_type=DataQueryAuditType.SQL_ERROR,
        output_classification=OutputClassification.UNKNOWN,
        metadata=audit_metadata,
    )
212
+
213
+
214
def record_data_query_pipeline_error_audit(
    request: QueryRequest,
    sql: str,
    error_code: str,
    error_message: str,
    pipeline_phase: str,
    error_detail: str = "",
    metadata: dict[str, Any] | None = None,
) -> DataQueryAuditLog | None:
    """Record a query-pipeline failure as a ``SQL_ERROR`` audit entry.

    ``error_category`` is ``SQL_ERROR_LLM_PROVIDER`` when *error_code* is
    ``"LLM_PROVIDER_ERROR"`` and ``SQL_ERROR_UNKNOWN`` otherwise. The failing
    *pipeline_phase* is stored in the metadata. Keys in *metadata* override
    the generated ones.
    """
    if error_code == "LLM_PROVIDER_ERROR":
        category = SQL_ERROR_LLM_PROVIDER
    else:
        category = SQL_ERROR_UNKNOWN

    audit_metadata = {
        "error_category": category,
        "error_code": error_code,
        "error_message": error_message,
        "error_detail": error_detail,
        "pipeline_phase": pipeline_phase,
        "datasource_id": request.datasource_id,
        "user_id": request.user_id,
        **(metadata or {}),
    }

    return _record_data_query_audit(
        request=request,
        answer=error_message,
        sql=sql,
        audit_type=DataQueryAuditType.SQL_ERROR,
        output_classification=OutputClassification.UNKNOWN,
        metadata=audit_metadata,
    )
244
+
245
+
246
def record_data_query_access_error_audit(
    request: QueryRequest,
    answer: str,
    reason: str,
    sql: str = "",
    error_code: str = "ACCESS_ERROR",
    error_detail: str = "",
    metadata: dict[str, Any] | None = None,
) -> DataQueryAuditLog | None:
    """Record a denied/blocked query as an ``ACCESS_ERROR`` audit entry.

    *reason* is stored as ``error_category`` and *answer* both as the stored
    answer and as ``error_message``. Keys in *metadata* override the
    generated ones.
    """
    audit_metadata = {
        "error_category": reason,
        "error_code": error_code,
        "error_message": answer,
        "error_detail": error_detail,
        "datasource_id": request.datasource_id,
        "user_id": request.user_id,
        **(metadata or {}),
    }

    return _record_data_query_audit(
        request=request,
        answer=answer,
        sql=sql,
        audit_type=DataQueryAuditType.ACCESS_ERROR,
        output_classification=OutputClassification.UNKNOWN,
        metadata=audit_metadata,
    )
273
+
274
+
275
def _record_data_query_audit(
    request: QueryRequest,
    answer: str,
    sql: str,
    audit_type: DataQueryAuditType | str,
    output_classification: OutputClassification | str,
    metadata: dict[str, Any],
) -> DataQueryAuditLog | None:
    """Persist one ``DataQueryAuditLog`` row in a dedicated session (best effort).

    Steps:

    1. Open a fresh session.
    2. Prune entries older than the retention window.
    3. Insert the new row and commit.

    ``audit_type`` and ``output_classification`` have their own columns, so
    those keys are dropped from a copy of *metadata* before it is serialized.

    Returns:
        The committed log, or ``None`` if it could not be written.

    Failures are reported as ``None`` instead of propagating to the caller:

    * database errors (``SQLAlchemyError``);
    * a malformed retention setting or an unknown audit type (``ValueError``);
    * metadata that is not JSON-serializable (``TypeError``).

    Previously only ``SQLAlchemyError`` was handled here, so the other two
    escaped into the code path being audited.
    """
    try:
        session = create_session()
    except SQLAlchemyError:
        return None

    try:
        apply_data_query_audit_retention(session)
        audit_metadata = dict(metadata)
        # Stored in dedicated columns; avoid duplicating them in the JSON blob.
        audit_metadata.pop("audit_type", None)
        audit_metadata.pop("output_classification", None)
        log = DataQueryAuditLog(
            type=coerce_data_query_audit_type(audit_type),
            output_classification=coerce_output_classification(output_classification),
            user_id=request.user_id,
            datasource_id=request.datasource_id,
            question=request.question,
            answer=answer,
            sql=sql,
            metadata_json=json_dumps(audit_metadata),
        )
        session.add(log)
        session.commit()
        return log
    except (SQLAlchemyError, ValueError, TypeError):
        # Also undoes the retention DELETE issued above.
        session.rollback()
        return None
    finally:
        session.close()
311
+
312
+
313
def get_setting(session: Session, key: str, default: str) -> str:
    """Return the stored ``AdminSetting`` value for *key*, or *default* if none exists."""
    stored = session.get(AdminSetting, key)
    return default if stored is None else stored.value
320
+
321
+
322
def get_int_setting(
    session: Session,
    key: str,
    default: int,
    minimum: int = 1,
) -> int:
    """Read an integer admin setting, clamped from below to *minimum*.

    A missing setting uses *default* (which is clamped too). A stored value
    that ``int()`` cannot parse returns *default* unchanged.
    """
    raw = get_setting(session, key, str(default))

    try:
        number = int(raw)
    except (TypeError, ValueError):
        return default

    return number if number > minimum else minimum
336
+
337
+
338
def set_setting(session: Session, key: str, value: str, actor: str) -> AdminSetting:
    """Create or update the ``AdminSetting`` for *key* (not committed).

    Records *actor* in ``updated_by`` and returns the affected row.
    """
    existing = session.get(AdminSetting, key)

    if existing is not None:
        existing.value = value
        existing.updated_by = actor
        return existing

    created = AdminSetting(key=key, value=value, updated_by=actor)
    session.add(created)
    return created
349
+
350
+
351
def default_governance_policy_config() -> dict[str, Any]:
    """Return a fresh copy of ``DEFAULT_GOVERNANCE_POLICY_CONFIG``.

    The JSON round trip yields an independent deep copy made only of plain
    JSON types, with keys in sorted order. Callers may therefore mutate the
    result freely.
    """
    serialized = json_dumps(DEFAULT_GOVERNANCE_POLICY_CONFIG)
    return json.loads(serialized)
353
+
354
+
355
def normalize_bool_setting(value: Any, field_name: str) -> bool:
    """Return *value* if it is a real ``bool``.

    Truthy non-bools such as ``1`` or ``"true"`` are rejected.

    Raises:
        ValueError: naming *field_name* when *value* is not a bool.
    """
    if not isinstance(value, bool):
        raise ValueError(f"{field_name} must be a boolean.")
    return value
360
+
361
+
362
def normalize_string_list(value: Any, field_name: str, *, lower: bool = False) -> list[str]:
    """Validate and clean a list of strings.

    Each entry is stripped and blank entries are dropped. When *lower* is
    true, entries are lower-cased. Duplicates are then removed, keeping the
    first occurrence and preserving order.

    Raises:
        ValueError: if *value* is not a list, or contains a non-string item.
    """
    if not isinstance(value, list):
        raise ValueError(f"{field_name} must be a list of strings.")

    cleaned: list[str] = []
    for raw in value:
        if not isinstance(raw, str):
            raise ValueError(f"{field_name} must contain only strings.")
        candidate = raw.strip()
        if candidate:
            cleaned.append(candidate.lower() if lower else candidate)

    # dict.fromkeys keeps only the first occurrence of each entry, in order.
    return list(dict.fromkeys(cleaned))
381
+
382
+
383
def normalize_forbidden_columns(value: Any) -> dict[str, list[str]]:
    """Validate a ``{table_name: [column, ...]}`` mapping.

    Table names are stripped. Column lists go through
    ``normalize_string_list`` (case preserved). Tables left with no columns
    are omitted.

    Raises:
        ValueError: if *value* is not a dict, a table name is not a non-blank
            string, or a column list is invalid.
    """
    if not isinstance(value, dict):
        raise ValueError("privacy.forbidden_columns must be an object.")

    result: dict[str, list[str]] = {}
    for raw_table, raw_columns in value.items():
        if not (isinstance(raw_table, str) and raw_table.strip()):
            raise ValueError("privacy.forbidden_columns table names must be non-empty strings.")
        columns = normalize_string_list(
            raw_columns,
            f"privacy.forbidden_columns.{raw_table}",
        )
        if not columns:
            continue
        result[raw_table.strip()] = columns

    return result
399
+
400
+
401
def normalize_pii_column_names(value: Any) -> dict[str, list[str]]:
    """Validate PII column names grouped by category.

    Accepted inputs:

    * a dict of ``{category: [name, ...]}``. Categories are stripped, names
      are lower-cased and deduplicated, and categories with no names left are
      dropped.
    * a bare list. It is normalized and returned under a single
      ``"default"`` category, even if it ends up empty.

    Raises:
        ValueError: for any other shape, blank/non-string categories, or
            invalid name lists.
    """
    if isinstance(value, list):
        return {"default": normalize_string_list(value, "pii_column_names", lower=True)}

    if not isinstance(value, dict):
        raise ValueError("pii_column_names must be an object of string lists.")

    categories: dict[str, list[str]] = {}
    for raw_category, raw_names in value.items():
        if not (isinstance(raw_category, str) and raw_category.strip()):
            raise ValueError("pii_column_names categories must be non-empty strings.")
        names = normalize_string_list(
            raw_names,
            f"pii_column_names.{raw_category}",
            lower=True,
        )
        if not names:
            continue
        categories[raw_category.strip()] = names

    return categories
421
+
422
+
423
def normalize_governance_policy_config(value: Any) -> dict[str, Any]:
    """Validate a governance policy and fill missing keys from the defaults.

    Returns a new dict with four sections:

    * ``final_answer``: two booleans.
    * ``sql``: three booleans plus ``tenant_column``. The tenant column is
      ``None`` or a stripped, non-empty string; a blank string becomes
      ``None``.
    * ``privacy``: ``forbidden_columns`` plus ``record_level_forbidden``.
    * ``pii_column_names``.

    Raises:
        ValueError: on the first invalid field. Checks run in this order:
            top-level object, the three sections, ``sql.tenant_column``,
            then each field in output order.
    """
    if not isinstance(value, dict):
        raise ValueError("Governance policy must be a JSON object.")

    defaults = default_governance_policy_config()
    final_answer_input = value.get("final_answer", {})
    sql_input = value.get("sql", {})
    privacy_input = value.get("privacy", {})

    for section_name, section_input in (
        ("final_answer", final_answer_input),
        ("sql", sql_input),
        ("privacy", privacy_input),
    ):
        if not isinstance(section_input, dict):
            raise ValueError(f"{section_name} must be an object.")

    raw_tenant_column = sql_input.get("tenant_column", defaults["sql"]["tenant_column"])
    if raw_tenant_column is not None and not isinstance(raw_tenant_column, str):
        raise ValueError("sql.tenant_column must be null or a string.")
    tenant_column = (
        (raw_tenant_column.strip() or None)
        if isinstance(raw_tenant_column, str)
        else None
    )

    def _flag(section_input: dict[str, Any], section: str, key: str) -> bool:
        # Read section[key] (falling back to the default) and require a bool.
        return normalize_bool_setting(
            section_input.get(key, defaults[section][key]),
            f"{section}.{key}",
        )

    final_answer = {
        "record_level_pii_allowed": _flag(
            final_answer_input, "final_answer", "record_level_pii_allowed"
        ),
        "prefer_aggregates_for_sensitive_domains": _flag(
            final_answer_input, "final_answer", "prefer_aggregates_for_sensitive_domains"
        ),
    }
    sql = {
        "read_only": _flag(sql_input, "sql", "read_only"),
        "select_star_allowed": _flag(sql_input, "sql", "select_star_allowed"),
        "tenant_filter_required": _flag(sql_input, "sql", "tenant_filter_required"),
        "tenant_column": tenant_column,
    }
    privacy = {
        "forbidden_columns": normalize_forbidden_columns(
            privacy_input.get("forbidden_columns", defaults["privacy"]["forbidden_columns"])
        ),
        "record_level_forbidden": _flag(privacy_input, "privacy", "record_level_forbidden"),
    }
    pii_column_names = normalize_pii_column_names(
        value.get("pii_column_names", defaults["pii_column_names"])
    )

    return {
        "final_answer": final_answer,
        "sql": sql,
        "privacy": privacy,
        "pii_column_names": pii_column_names,
    }
504
+
505
+
506
def get_governance_policy_config(session: Session) -> dict[str, Any]:
    """Load the stored governance policy (or the defaults) and normalize it.

    Raises:
        ValueError: if the stored JSON is malformed or fails validation.
    """
    fallback_json = json_dumps(default_governance_policy_config())
    stored_json = get_setting(session, GOVERNANCE_POLICY_SETTING, fallback_json)
    parsed = json_loads(stored_json)
    return normalize_governance_policy_config(parsed)
513
+
514
+
515
def set_governance_policy_config(
    session: Session,
    config: dict[str, Any],
    actor: str,
) -> dict[str, Any]:
    """Validate *config*, store it as JSON (not committed), and return the normalized form.

    Raises:
        ValueError: if *config* fails validation; nothing is stored then.
    """
    normalized = normalize_governance_policy_config(config)
    serialized = json_dumps(normalized)
    set_setting(session, GOVERNANCE_POLICY_SETTING, serialized, actor)
    return normalized
523
+
524
+
525
def get_governance_policy_sources(session: Session) -> dict[str, str]:
    """Report where the governance policy comes from.

    The value is ``"metadata"`` when a stored setting exists and
    ``"default"`` otherwise.
    """
    stored = session.get(AdminSetting, GOVERNANCE_POLICY_SETTING)
    source = "default" if stored is None else "metadata"
    return {"governance_policy": source}
533
+
534
+
535
def flatten_pii_column_names(config: dict[str, Any]) -> set[str]:
    """Collect every PII column name from all categories, lower-cased, into one set."""
    categories = config.get("pii_column_names", {})
    return {
        str(name).lower()
        for names in categories.values()
        for name in names
    }
540
+
541
+
542
def infer_configured_forbidden_columns(
    schema_summary: dict[str, Any],
    config: dict[str, Any],
) -> dict[str, list[str]]:
    """Find schema columns whose names match a configured PII column name.

    Matching is case-insensitive, but the schema's original spelling is kept.
    Only tables with at least one match appear in the result.

    Assumes ``schema_summary["tables"]`` maps table names to dicts whose
    ``"columns"`` entry iterates over column-name strings (TODO: confirm
    against the schema summary producer).
    """
    sensitive_names = flatten_pii_column_names(config)
    if not sensitive_names:
        return {}

    matches_by_table = {
        table_name: [
            column_name
            for column_name in table.get("columns", {})
            if column_name.lower() in sensitive_names
        ]
        for table_name, table in schema_summary.get("tables", {}).items()
    }
    return {name: columns for name, columns in matches_by_table.items() if columns}
561
+
562
+
563
def merge_forbidden_columns(
    configured: dict[str, list[str]],
    inferred: dict[str, list[str]],
) -> dict[str, list[str]]:
    """Union two forbidden-column mappings without mutating either input.

    Configured columns come first. Inferred columns not already listed for
    that table are appended in order. Tables that end up with no columns are
    dropped.
    """
    merged: dict[str, list[str]] = {
        table: list(columns) for table, columns in configured.items()
    }

    for table_name, extra_columns in inferred.items():
        bucket = merged.setdefault(table_name, [])
        known = set(bucket)
        for column_name in extra_columns:
            if column_name not in known:
                bucket.append(column_name)
                known.add(column_name)

    return {table: columns for table, columns in merged.items() if columns}
579
+
580
+
581
def build_governance_policy_from_config(
    schema_summary: dict[str, Any],
    config: dict[str, Any],
) -> dict[str, Any]:
    """Turn a policy config into the effective policy for one schema.

    The config is normalized first. Its configured forbidden columns are then
    merged with PII columns inferred from *schema_summary*. The
    ``pii_column_names`` section is not carried into the result.

    Raises:
        ValueError: if *config* fails validation.
    """
    normalized = normalize_governance_policy_config(config)
    privacy_section = normalized["privacy"]

    inferred = infer_configured_forbidden_columns(schema_summary, normalized)
    forbidden = merge_forbidden_columns(privacy_section["forbidden_columns"], inferred)

    return {
        "final_answer": normalized["final_answer"],
        "sql": normalized["sql"],
        "privacy": {**privacy_section, "forbidden_columns": forbidden},
    }
599
+
600
+
601
def get_governance_policy_for_schema(
    session: Session,
    schema_summary: dict[str, Any],
) -> dict[str, Any]:
    """Build the effective policy for *schema_summary* from the stored config."""
    stored_config = get_governance_policy_config(session)
    return build_governance_policy_from_config(schema_summary, stored_config)
609
+
610
+
611
def get_governance_policy_for_schema_safe(
    schema_summary: dict[str, Any],
) -> dict[str, Any]:
    """Like ``get_governance_policy_for_schema`` but using its own session.

    The policy built from the default config is computed up front. It is
    returned instead when:

    * the session cannot be created (``SQLAlchemyError``), or
    * loading or validating the stored config raises ``SQLAlchemyError``,
      ``ValueError``, ``TypeError`` or ``json.JSONDecodeError``.

    The session is always closed.
    """
    default_policy = build_governance_policy_from_config(
        schema_summary,
        default_governance_policy_config(),
    )

    try:
        session = create_session()
    except SQLAlchemyError:
        return default_policy

    try:
        policy = get_governance_policy_for_schema(session, schema_summary)
    except (SQLAlchemyError, ValueError, TypeError, json.JSONDecodeError):
        policy = default_policy
    finally:
        session.close()
    return policy
630
+
631
+
632
def get_query_runtime_config(session: Session) -> QueryRuntimeConfig:
    """Build the effective query-pipeline configuration.

    Each field uses its ``AdminSetting`` row when present, otherwise the
    matching ``settings.gaard_*`` value. The two numeric limits are read with
    ``get_int_setting``: minimum 1, and the default on unparsable input.
    """
    mode_sources = (
        (
            "intent_classification_mode",
            INTENT_CLASSIFICATION_MODE_SETTING,
            settings.gaard_intent_classification_mode,
        ),
        (
            "sql_generation_mode",
            SQL_GENERATION_MODE_SETTING,
            settings.gaard_sql_generation_mode,
        ),
        (
            "result_interpretation_mode",
            RESULT_INTERPRETATION_MODE_SETTING,
            settings.gaard_result_interpretation_mode,
        ),
        (
            "output_classification_mode",
            OUTPUT_CLASSIFICATION_MODE_SETTING,
            settings.gaard_output_classification_mode,
        ),
        (
            "investigation_mode",
            INVESTIGATION_MODE_SETTING,
            settings.gaard_investigation_mode,
        ),
        (
            "investigation_ambiguity_mode",
            INVESTIGATION_AMBIGUITY_MODE_SETTING,
            settings.gaard_investigation_ambiguity_mode,
        ),
    )
    modes = {
        field_name: get_setting(session, key, fallback)
        for field_name, key, fallback in mode_sources
    }

    return QueryRuntimeConfig(
        **modes,
        query_max_rows=get_int_setting(
            session, QUERY_MAX_ROWS_SETTING, settings.gaard_query_max_rows
        ),
        query_timeout_seconds=get_int_setting(
            session, QUERY_TIMEOUT_SECONDS_SETTING, settings.gaard_query_timeout_seconds
        ),
    )
675
+
676
+
677
def get_query_runtime_config_safe() -> QueryRuntimeConfig:
    """Load the query runtime config in its own session, with a settings-only fallback.

    The fallback, built purely from ``settings``, is returned when:

    * the session cannot be created (``SQLAlchemyError``), or
    * reading the stored settings raises ``SQLAlchemyError``, ``ValueError``
      or ``TypeError``.

    The session is always closed.
    """
    settings_only = QueryRuntimeConfig(
        intent_classification_mode=settings.gaard_intent_classification_mode,
        sql_generation_mode=settings.gaard_sql_generation_mode,
        result_interpretation_mode=settings.gaard_result_interpretation_mode,
        output_classification_mode=settings.gaard_output_classification_mode,
        investigation_mode=settings.gaard_investigation_mode,
        investigation_ambiguity_mode=settings.gaard_investigation_ambiguity_mode,
        query_max_rows=settings.gaard_query_max_rows,
        query_timeout_seconds=settings.gaard_query_timeout_seconds,
    )

    try:
        session = create_session()
    except SQLAlchemyError:
        return settings_only

    try:
        config = get_query_runtime_config(session)
    except (SQLAlchemyError, ValueError, TypeError):
        config = settings_only
    finally:
        session.close()
    return config
700
+
701
+
702
def set_query_runtime_config(
    session: Session,
    intent_classification_mode: str,
    sql_generation_mode: str,
    result_interpretation_mode: str,
    output_classification_mode: str,
    investigation_mode: str,
    investigation_ambiguity_mode: str,
    query_max_rows: int,
    query_timeout_seconds: int,
    actor: str,
) -> QueryRuntimeConfig:
    """Stage all query-runtime settings on *session* and return the re-read config.

    Values are stored as given (integers as strings) without validation.
    Nothing is committed here.
    """
    updates = (
        (INTENT_CLASSIFICATION_MODE_SETTING, intent_classification_mode),
        (SQL_GENERATION_MODE_SETTING, sql_generation_mode),
        (RESULT_INTERPRETATION_MODE_SETTING, result_interpretation_mode),
        (OUTPUT_CLASSIFICATION_MODE_SETTING, output_classification_mode),
        (INVESTIGATION_MODE_SETTING, investigation_mode),
        (INVESTIGATION_AMBIGUITY_MODE_SETTING, investigation_ambiguity_mode),
        (QUERY_MAX_ROWS_SETTING, str(query_max_rows)),
        (QUERY_TIMEOUT_SECONDS_SETTING, str(query_timeout_seconds)),
    )
    for key, stored_value in updates:
        set_setting(session, key, stored_value, actor)

    return get_query_runtime_config(session)
729
+
730
+
731
def get_llm_runtime_config(session: Session) -> LlmRuntimeConfig:
    """Build the effective LLM configuration.

    Each field uses its ``AdminSetting`` row when present, otherwise the
    matching ``settings.gaard_llm_*`` value. A stored ``extra_body`` that
    parses to something other than a JSON object is replaced by ``{}``.

    Raises:
        ValueError: if the stored ``extra_body`` is not valid JSON
            (``json.JSONDecodeError``).
    """
    raw_extra_body = get_setting(
        session,
        LLM_SETTING_EXTRA_BODY,
        json_dumps(settings.gaard_llm_extra_body),
    )
    parsed_extra_body = json_loads(raw_extra_body)
    extra_body = parsed_extra_body if isinstance(parsed_extra_body, dict) else {}

    provider = get_setting(session, LLM_SETTING_PROVIDER, settings.gaard_llm_provider)
    base_url = get_setting(session, LLM_SETTING_BASE_URL, settings.gaard_llm_base_url)
    api_key = get_setting(session, LLM_SETTING_API_KEY, settings.gaard_llm_api_key)
    model = get_setting(session, LLM_SETTING_MODEL, settings.gaard_llm_model)
    timeout_seconds = get_int_setting(
        session,
        LLM_SETTING_TIMEOUT_SECONDS,
        settings.gaard_llm_timeout_seconds,
    )

    return LlmRuntimeConfig(
        provider=provider,
        base_url=base_url,
        api_key=api_key,
        model=model,
        extra_body=extra_body,
        timeout_seconds=timeout_seconds,
    )
755
+
756
+
757
def get_llm_runtime_config_safe() -> LlmRuntimeConfig:
    """Load the LLM config in its own session, with a settings-only fallback.

    The fallback, built purely from ``settings``, is returned when:

    * the session cannot be created (``SQLAlchemyError``), or
    * reading the stored settings raises ``SQLAlchemyError``, ``ValueError``
      or ``TypeError``.

    The session is always closed.
    """
    settings_only = LlmRuntimeConfig(
        provider=settings.gaard_llm_provider,
        base_url=settings.gaard_llm_base_url,
        api_key=settings.gaard_llm_api_key,
        model=settings.gaard_llm_model,
        extra_body=settings.gaard_llm_extra_body,
        timeout_seconds=settings.gaard_llm_timeout_seconds,
    )

    try:
        session = create_session()
    except SQLAlchemyError:
        return settings_only

    try:
        config = get_llm_runtime_config(session)
    except (SQLAlchemyError, ValueError, TypeError):
        config = settings_only
    finally:
        session.close()
    return config
783
+
784
+
785
def get_llm_config_sources(session: Session) -> dict[str, str]:
    """Report, per LLM and query-runtime field, where its value comes from.

    A field maps to ``"metadata"`` when an ``AdminSetting`` row exists for
    it and to ``"default"`` otherwise.
    """
    field_keys = (
        ("provider", LLM_SETTING_PROVIDER),
        ("base_url", LLM_SETTING_BASE_URL),
        ("api_key", LLM_SETTING_API_KEY),
        ("model", LLM_SETTING_MODEL),
        ("timeout_seconds", LLM_SETTING_TIMEOUT_SECONDS),
        ("extra_body", LLM_SETTING_EXTRA_BODY),
        ("intent_classification_mode", INTENT_CLASSIFICATION_MODE_SETTING),
        ("sql_generation_mode", SQL_GENERATION_MODE_SETTING),
        ("result_interpretation_mode", RESULT_INTERPRETATION_MODE_SETTING),
        ("output_classification_mode", OUTPUT_CLASSIFICATION_MODE_SETTING),
        ("investigation_mode", INVESTIGATION_MODE_SETTING),
        ("investigation_ambiguity_mode", INVESTIGATION_AMBIGUITY_MODE_SETTING),
        ("query_max_rows", QUERY_MAX_ROWS_SETTING),
        ("query_timeout_seconds", QUERY_TIMEOUT_SECONDS_SETTING),
    )

    sources: dict[str, str] = {}
    for field_name, key in field_keys:
        stored = session.get(AdminSetting, key)
        sources[field_name] = "default" if stored is None else "metadata"
    return sources
807
+
808
+
809
def set_llm_runtime_config(
    session: Session,
    provider: str,
    base_url: str,
    api_key: str | None,
    model: str,
    timeout_seconds: int,
    extra_body: dict[str, Any],
    actor: str,
) -> LlmRuntimeConfig:
    """Stage the LLM settings on *session* (not committed) and return the re-read config.

    An *api_key* of ``None`` leaves the stored key untouched.
    """
    pending = [
        (LLM_SETTING_PROVIDER, provider),
        (LLM_SETTING_BASE_URL, base_url),
    ]
    if api_key is not None:
        pending.append((LLM_SETTING_API_KEY, api_key))
    pending.append((LLM_SETTING_MODEL, model))
    pending.append((LLM_SETTING_TIMEOUT_SECONDS, str(timeout_seconds)))

    for key, stored_value in pending:
        set_setting(session, key, stored_value, actor)
    set_setting(session, LLM_SETTING_EXTRA_BODY, json_dumps(extra_body), actor)

    return get_llm_runtime_config(session)
828
+
829
+
830
def get_data_query_audit_retention_days(session: Session) -> int:
    """Return how many days of data-query audit history to keep (always >= 1).

    Reads the ``data_query_audit_retention_days`` admin setting and falls back
    to ``settings.gaard_audit_retention_days`` when it is absent.

    This runs on every audit write and every audit listing (via
    ``apply_data_query_audit_retention``). A malformed stored value such as
    ``"30d"`` therefore must not raise: it now yields the configured default
    instead of ``ValueError``. The outer ``max(1, ...)`` also clamps the
    default, so a zero/negative default can never turn the cutoff into "now"
    and delete the whole history.
    """
    return max(
        1,
        get_int_setting(
            session,
            "data_query_audit_retention_days",
            settings.gaard_audit_retention_days,
            minimum=1,
        ),
    )
838
+
839
+
840
def apply_data_query_audit_retention(session: Session) -> None:
    """Delete audit rows older than the retention window (the caller commits).

    The cutoff is computed from the current UTC time.
    """
    keep_days = get_data_query_audit_retention_days(session)
    oldest_kept = datetime.now(UTC) - timedelta(days=keep_days)
    purge_stale = delete(DataQueryAuditLog).where(
        DataQueryAuditLog.occurred_at < oldest_kept
    )
    session.execute(purge_stale)
847
+
848
+
849
def list_data_query_audit_logs(
    session: Session,
    limit: int = 100,
    audit_type: DataQueryAuditType | str | None = None,
    output_classification: OutputClassification | str | None = None,
    sql_contains: str | None = None,
) -> list[DataQueryAuditLog]:
    """Return up to *limit* data-query audit logs, newest first.

    Rows past the retention window are deleted on *session* first; that
    deletion is not committed here.

    Args:
        limit: maximum number of rows to return.
        audit_type: optional exact audit-type filter.
        output_classification: optional filter. Unrecognised strings coerce
            to ``OutputClassification.UNKNOWN``.
        sql_contains: optional substring filter on the stored SQL. It is
            stripped, and ignored when blank. It is matched literally: ``%``
            and ``_`` are escaped instead of acting as LIKE wildcards.

    Raises:
        ValueError: if *audit_type* is a string that is not a known type.
    """
    apply_data_query_audit_retention(session)

    query = select(DataQueryAuditLog).order_by(desc(DataQueryAuditLog.occurred_at))

    if audit_type is not None:
        query = query.where(DataQueryAuditLog.type == coerce_data_query_audit_type(audit_type))

    if output_classification is not None:
        query = query.where(
            DataQueryAuditLog.output_classification
            == coerce_output_classification(output_classification)
        )

    needle = sql_contains.strip() if sql_contains is not None else ""
    if needle:
        # Without autoescape, a search for e.g. "50%" or "user_id" would treat
        # "%"/"_" as wildcards and return unrelated rows.
        query = query.where(DataQueryAuditLog.sql.contains(needle, autoescape=True))

    return list(session.scalars(query.limit(limit)))
873
+
874
+
875
def get_data_query_audit_type(log: DataQueryAuditLog) -> str:
    """Return the audit type of *log* as a plain string (see ``data_query_audit_type_value``)."""
    stored_type = log.type
    return data_query_audit_type_value(stored_type)
877
+
878
+
879
def data_query_audit_type_value(value: DataQueryAuditType | str | None) -> str:
    """Convert an audit type to its canonical string form.

    * An enum member returns its ``.value``.
    * A non-empty string returns the canonical value when it is recognised,
      and the string itself otherwise.
    * Anything else (``None``, ``""``) returns ``DATA_QUERY_AUDIT_INFO``.
    """
    if isinstance(value, DataQueryAuditType):
        return value.value

    if not isinstance(value, str) or not value:
        return DATA_QUERY_AUDIT_INFO

    try:
        return coerce_data_query_audit_type(value).value
    except ValueError:
        return value
890
+
891
+
892
def coerce_data_query_audit_type(value: DataQueryAuditType | str) -> DataQueryAuditType:
    """Map *value* to a ``DataQueryAuditType`` member.

    Strings are normalized before matching member values: trimmed,
    lower-cased, and with spaces and hyphens turned into underscores.

    Raises:
        ValueError: if the normalized string matches no member.
    """
    if isinstance(value, DataQueryAuditType):
        return value

    wanted = value.strip().lower().replace(" ", "_").replace("-", "_")
    for member in DataQueryAuditType:
        if member.value == wanted:
            return member

    raise ValueError("Unsupported data query audit type.")
903
+
904
+
905
def coerce_output_classification(value: object) -> OutputClassification:
    """Map *value* to an ``OutputClassification`` member, never raising.

    Strings are normalized the same way as in
    ``coerce_data_query_audit_type``. Non-strings, blank strings and
    unmatched strings all become ``OutputClassification.UNKNOWN``.
    """
    if isinstance(value, OutputClassification):
        return value

    if isinstance(value, str) and value.strip():
        wanted = value.strip().lower().replace(" ", "_").replace("-", "_")
        return next(
            (member for member in OutputClassification if member.value == wanted),
            OutputClassification.UNKNOWN,
        )

    return OutputClassification.UNKNOWN
916
+
917
+
918
def list_business_logic_suggestions(
    session: Session,
    connector_id: int,
) -> list[BusinessLogicSuggestion]:
    """Return the connector's suggestions.

    Results are ordered with enabled suggestions first, then most recently
    updated.
    """
    statement = (
        select(BusinessLogicSuggestion)
        .where(BusinessLogicSuggestion.connector_id == connector_id)
        .order_by(
            BusinessLogicSuggestion.enabled.desc(),
            desc(BusinessLogicSuggestion.updated_at),
        )
    )
    rows = session.scalars(statement)
    return list(rows)
932
+
933
+
934
def get_business_logic_suggestion(
    session: Session,
    suggestion_id: int,
) -> BusinessLogicSuggestion | None:
    """Fetch a suggestion by primary key, or ``None`` when it does not exist."""
    found = session.get(BusinessLogicSuggestion, suggestion_id)
    return found
939
+
940
+
941
def set_business_logic_suggestion_enabled(
    session: Session,
    suggestion: BusinessLogicSuggestion,
    enabled: bool,
    actor: str,
) -> BusinessLogicSuggestion:
    """Enable or disable *suggestion* and keep its status in sync.

    Enabled suggestions become ``"active"`` and disabled ones revert to
    ``"pending"``. *session* is accepted for API symmetry but not used; the
    caller commits.
    """
    new_status = BUSINESS_LOGIC_STATUS_ACTIVE if enabled else BUSINESS_LOGIC_STATUS_PENDING

    suggestion.enabled = enabled
    suggestion.status = new_status
    suggestion.updated_by = actor
    return suggestion
956
+
957
+
958
def update_business_logic_suggestion_content(
    suggestion: BusinessLogicSuggestion,
    title: str | None,
    rule_text: str | None,
    actor: str,
) -> BusinessLogicSuggestion:
    """Apply an edit to *suggestion*; ``None`` fields are left unchanged.

    The title is stripped and truncated to 255 characters via
    ``truncate_text``. The rule text is stripped. ``updated_by`` is always
    set to *actor*.
    """
    if title is not None:
        cleaned_title = title.strip()
        suggestion.title = truncate_text(cleaned_title, 255)

    if rule_text is not None:
        cleaned_rule = rule_text.strip()
        suggestion.rule_text = cleaned_rule

    suggestion.updated_by = actor
    return suggestion
973
+
974
+
975
def delete_business_logic_suggestion(
    session: Session,
    suggestion: BusinessLogicSuggestion,
) -> None:
    """Mark ``suggestion`` for deletion in ``session``.

    This function does not flush or commit; the caller owns the transaction.
    """
    session.delete(suggestion)
    return None
980
+
981
+
982
def learn_business_logic_from_sql_error(
    connector_id: int | None,
    audit_id: int | None,
    actor: str = "system",
) -> BusinessLogicSuggestion | None:
    """Ask the LLM to turn a failed SQL audit entry into a pending suggestion.

    Best-effort: every early exit returns ``None``. When the reason for not
    learning is known, it is written into the audit log's metadata through
    ``mark_business_logic_learning_skipped`` (which commits). On success the
    suggestion is upserted, the audit metadata is annotated, and the session
    is committed.

    Args:
        connector_id: Datasource connector the failed query ran against.
        audit_id: Primary key of the ``DataQueryAuditLog`` row to learn from.
        actor: Value recorded as ``updated_by`` on the suggestion.

    Returns:
        The created or updated suggestion, or ``None`` when learning was
        skipped or any exception occurred.
    """
    if audit_id is None:
        return None

    try:
        session = create_session()
    except SQLAlchemyError:
        # Metadata store unavailable: nothing can be recorded.
        return None

    try:
        audit_log = session.get(DataQueryAuditLog, audit_id)
        if audit_log is None:
            return None

        metadata = safe_json_object(audit_log.metadata_json)
        # Some routes/phases/categories never yield a durable SQL lesson.
        exclusion_reason = business_logic_learning_exclusion_reason(metadata)
        if exclusion_reason:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason=exclusion_reason,
            )
            return None

        if connector_id is None:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason="No active datasource connector was available for this SQL error.",
            )
            return None

        connector = session.get(DatasourceConnector, connector_id)
        if connector is None:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason="Datasource connector was not found for this SQL error.",
            )
            return None

        # The cached schema is part of the LLM prompt; without it we cannot learn.
        cache = get_datasource_schema_cache(session, connector_id)
        if cache is None:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason="Datasource schema cache was not available for LLM learning.",
            )
            return None

        try:
            llm_config = get_llm_runtime_config(session)
        except (SQLAlchemyError, ValueError, TypeError) as exc:
            # NOTE(review): after a SQLAlchemyError the session may need a
            # rollback before this commit can succeed; if the commit fails,
            # the outer catch-all handles it.
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason=f"LLM configuration could not be loaded: {exc}",
            )
            return None

        skip_reason = validate_business_logic_learning_llm_config(llm_config)
        if skip_reason:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason=skip_reason,
            )
            return None

        try:
            # Network call to the LLM; the session stays open meanwhile.
            lesson = request_business_logic_lesson(
                llm_config=llm_config,
                connector=connector,
                audit_log=audit_log,
                schema_cache=cache,
            )
        except (LlmProviderError, ValueError, TypeError) as exc:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason=f"LLM learning failed: {exc}",
            )
            return None

        if not lesson.create_suggestion:
            # The LLM judged there is no durable lesson; keep its raw answer.
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason=lesson.skip_reason
                or "LLM did not find a durable SQL-generation lesson.",
                llm_response=lesson.raw,
            )
            return None

        if not lesson.rule_text:
            mark_business_logic_learning_skipped(
                session=session,
                audit_log=audit_log,
                reason="LLM did not return a rule_text for business logic learning.",
                llm_response=lesson.raw,
            )
            return None

        suggestion = upsert_llm_business_logic_suggestion(
            session=session,
            connector=connector,
            audit_log=audit_log,
            lesson=lesson,
            actor=actor,
        )
        record_business_logic_learning_suggestion(
            audit_log=audit_log,
            suggestion=suggestion,
            lesson=lesson,
        )
        # Suggestion and audit annotation are committed together.
        session.commit()

        return suggestion
    except Exception:
        # Deliberate catch-all: learning must never break the caller. Such
        # failures are neither logged nor recorded in the audit metadata.
        session.rollback()
        return None
    finally:
        session.close()
1109
+
1110
+
1111
def business_logic_learning_exclusion_reason(metadata: dict[str, Any]) -> str:
    """Explain why an audit entry must not feed business logic learning.

    Checks, in order: the execution route, the required evidence type, the
    pipeline phase, and the error categories (``error_categories`` plus
    ``primary_error_category``).

    Args:
        metadata: Decoded audit-log metadata. ``error_categories`` may be a
            list/tuple/set of categories, a single category string, ``None``,
            or missing.

    Returns:
        A human-readable reason, or ``""`` when learning may proceed.
    """
    pipeline_phase = str(metadata.get("pipeline_phase") or "")
    route = str(metadata.get("route") or metadata.get("execution_route") or "")
    required_evidence_type = str(metadata.get("required_evidence_type") or "")
    primary_error_category = str(metadata.get("primary_error_category") or "")

    raw_categories = metadata.get("error_categories")
    if isinstance(raw_categories, str):
        # A bare string is one category, not a sequence of characters.
        raw_categories = [raw_categories]
    elif not isinstance(raw_categories, (list, tuple, set, frozenset)):
        # None or any other unexpected shape would otherwise raise or
        # produce meaningless categories.
        raw_categories = []
    error_categories = {
        str(category)
        for category in raw_categories
        if str(category).strip()
    }
    if primary_error_category:
        error_categories.add(primary_error_category)

    non_sql_routes = {
        "answer_from_schema_summary",
        "answer_from_policy_or_governance",
        "ask_clarification",
        "answer_from_reasoning",
        "cannot_answer_safely",
    }
    non_record_evidence = {
        "schema_metadata",
        "governance_policy",
        "clarification",
        "reasoning_only",
    }
    pipeline_design_phases = {
        "intent_classification",
    }
    non_business_categories = {
        "schema_metadata.unavailable",
        "governance_policy.unavailable",
        "clarification.unavailable",
        "reasoning.unavailable",
        "intent.ambiguous_requires_clarification",
    }

    if route in non_sql_routes:
        return "Business logic learning is skipped for non-SQL routes."
    if required_evidence_type in non_record_evidence:
        return "Business logic learning is skipped for non-record evidence outcomes."
    if pipeline_phase in pipeline_design_phases:
        return (
            "Business logic learning is skipped for query routing, modeling, or "
            "preflight failures."
        )
    if error_categories & non_business_categories:
        return (
            "Business logic learning is skipped because the failure is not a durable "
            "business-logic gap."
        )
    return ""
1164
+
1165
+
1166
@dataclass(frozen=True)
class BusinessLogicLesson:
    """Immutable result of asking the LLM to learn from a failed SQL query.

    Produced by ``parse_business_logic_lesson_response``; ``raw`` keeps the
    decoded JSON object the LLM returned so it can be stored for auditing.
    """

    # True when the LLM asked for a business logic suggestion to be created.
    create_suggestion: bool
    # Admin-facing title proposed by the LLM (may be empty).
    title: str
    # Instruction for future SQL generation.
    rule_text: str
    # Short category label for the error.
    error_category: str
    # Table/column/function/concept reported as the cause (may be empty).
    failed_identifier: str
    # Preferred replacement identifier (may be empty).
    repaired_identifier: str
    # Confidence reported by the LLM.
    confidence: float
    # Search terms associated with the lesson.
    terms: list[str]
    # Join or alias hints associated with the lesson.
    join_hints: list[str]
    # Why no suggestion should be created, when applicable.
    skip_reason: str
    # Decoded JSON payload returned by the LLM.
    raw: dict[str, Any]
1179
+
1180
+
1181
def validate_business_logic_learning_llm_config(config: LlmRuntimeConfig) -> str:
    """Return why ``config`` cannot be used for learning, or ``""`` if it can.

    Only the ``openai-compatible`` provider is supported, and a base URL, a
    model and a real API key (not the ``change-me`` placeholder) are required.
    """
    if config.provider != "openai-compatible":
        return f"Unsupported LLM provider for business logic learning: {config.provider}."

    required_settings = (
        (config.base_url, "LLM base URL is not configured for business logic learning."),
        (config.model, "LLM model is not configured for business logic learning."),
    )
    for setting, message in required_settings:
        if not setting:
            return message

    api_key = config.api_key
    if api_key and api_key != "change-me":
        return ""
    return "LLM API key is not configured for business logic learning."
1195
+
1196
+
1197
def request_business_logic_lesson(
    llm_config: LlmRuntimeConfig,
    connector: DatasourceConnector,
    audit_log: DataQueryAuditLog,
    schema_cache: DatasourceSchemaCache,
) -> BusinessLogicLesson:
    """Send the learning prompt to the configured LLM and parse its answer.

    Uses a deterministic (temperature 0.0) chat completion. Errors from the
    client and ``ValueError`` from parsing propagate to the caller.
    """
    client = OpenAICompatibleClient(
        base_url=llm_config.base_url,
        api_key=llm_config.api_key,
        timeout_seconds=llm_config.timeout_seconds,
    )
    system_prompt, user_prompt = build_business_logic_learning_prompt(
        connector=connector,
        audit_log=audit_log,
        schema_cache=schema_cache,
    )

    conversation = [
        ChatMessage(role="system", content=system_prompt),
        ChatMessage(role="user", content=user_prompt),
    ]
    completion_request = ChatCompletionRequest(
        model=llm_config.model,
        temperature=0.0,
        extra_body=llm_config.extra_body,
        messages=conversation,
    )
    completion = client.create_chat_completion(completion_request)
    return parse_business_logic_lesson_response(completion.content)
1226
+
1227
+
1228
def build_business_logic_learning_prompt(
    connector: DatasourceConnector,
    audit_log: DataQueryAuditLog,
    schema_cache: DatasourceSchemaCache,
) -> tuple[str, str]:
    """Build the (system, user) prompt pair for SQL-error lesson extraction.

    The system prompt fixes the JSON answer shape expected by
    ``parse_business_logic_lesson_response``. The user prompt carries the
    datasource key/dialect, the cached schema, the question, the generated
    SQL, the error text/detail and the full audit metadata.

    Returns:
        ``(system_prompt, user_prompt)``.
    """
    metadata = safe_json_object(audit_log.metadata_json)
    # Prefer the structured error message; fall back to the stored answer.
    error_message = str(metadata.get("error_message") or audit_log.answer)
    error_detail = str(metadata.get("error_detail") or "")
    # Use the prompt-formatted schema when present, else the raw schema JSON.
    formatted_schema = schema_cache.formatted_schema.strip() or schema_cache.schema_json

    system_prompt = """You diagnose failed generated SQL and turn the diagnosis into durable business logic for future SQL generation.

Analyze every SQL execution error type yourself. Do not rely on pre-classified metadata.
Create a business logic suggestion when a future SQL generator can avoid this error by following a durable rule about schema usage, joins, aliases, dialect, functions, grouping, filtering, or business terminology.
Do not create a suggestion for transient infrastructure failures, permissions that require administrator action, missing privileges, unavailable databases, timeouts, or cases where there is no durable SQL-generation lesson.
The lesson must be a direct instruction for future SQL generation, not a postmortem.

Return JSON only. Use this exact shape:
{
"create_suggestion": true,
"error_category": "short.category",
"title": "short title for an admin",
"rule_text": "durable instruction for future SQL generation",
"failed_identifier": "optional table, column, function or concept that caused the error",
"repaired_identifier": "optional preferred table, column, function or concept",
"confidence": 0.0,
"terms": ["optional", "search", "terms"],
"join_hints": ["optional join or alias hints"],
"skip_reason": ""
}

If there is no durable lesson, return the same JSON shape with "create_suggestion": false, an empty rule_text, and a clear skip_reason."""

    # NOTE(review): the audit metadata is embedded verbatim as pretty-printed JSON.
    user_prompt = f"""Datasource:
- key: {connector.connector_key}
- dialect: {connector.sql_dialect}

Database schema and approved business logic:
{formatted_schema}

User question:
{audit_log.question}

Generated SQL:
{audit_log.sql}

SQL execution error:
{error_message}

SQL error detail:
{error_detail}

Audit metadata:
{json.dumps(metadata, ensure_ascii=False, indent=2, sort_keys=True)}

Return the JSON lesson only."""

    return system_prompt, user_prompt
1286
+
1287
+
1288
def _strip_markdown_code_fence(text_value: str) -> str:
    """Remove a leading ```lang fence (any or no language tag) and a trailing ```."""
    if text_value.startswith("```"):
        text_value = re.sub(r"^```[A-Za-z0-9_+-]*", "", text_value, count=1).strip()
    if text_value.endswith("```"):
        text_value = text_value.removesuffix("```").strip()
    return text_value


def _coerce_llm_bool(value: object) -> bool:
    """Interpret an LLM-supplied flag; strings only count as True for yes-like words."""
    if isinstance(value, str):
        # bool("false") is True, so strings must be interpreted explicitly.
        return value.strip().lower() in {"true", "yes", "y", "1"}
    return bool(value)


def parse_business_logic_lesson_response(value: str) -> BusinessLogicLesson:
    """Parse the LLM's JSON answer into a ``BusinessLogicLesson``.

    Reasoning blocks are removed and a surrounding Markdown code fence (with
    any language tag, e.g. ``json`` or ``JSON``) is stripped before decoding.
    ``lesson`` is accepted as a fallback key for ``rule_text``. The string
    values ``"false"``/``"no"``/``"0"`` for ``create_suggestion`` are treated
    as False.

    Raises:
        ValueError: if the text is not JSON or does not decode to an object.
    """
    cleaned = _strip_markdown_code_fence(remove_thinking_blocks(value).strip())

    try:
        payload = json.loads(cleaned)
    except json.JSONDecodeError as exc:
        raise ValueError("LLM returned non-JSON business logic learning output.") from exc

    if not isinstance(payload, dict):
        raise ValueError("LLM returned invalid business logic learning output.")

    rule_text = str(payload.get("rule_text") or payload.get("lesson") or "").strip()

    return BusinessLogicLesson(
        create_suggestion=_coerce_llm_bool(payload.get("create_suggestion")),
        title=str(payload.get("title") or "").strip(),
        rule_text=rule_text,
        error_category=coerce_short_text(payload.get("error_category"), SQL_ERROR_UNKNOWN, 100),
        failed_identifier=coerce_short_text(payload.get("failed_identifier"), "", 255),
        repaired_identifier=coerce_short_text(payload.get("repaired_identifier"), "", 255),
        confidence=coerce_confidence(payload.get("confidence")),
        terms=coerce_text_list(payload.get("terms")),
        join_hints=coerce_text_list(payload.get("join_hints")),
        skip_reason=str(payload.get("skip_reason") or "").strip(),
        raw=payload,
    )
1324
+
1325
+
1326
def upsert_llm_business_logic_suggestion(
    session: Session,
    connector: DatasourceConnector,
    audit_log: DataQueryAuditLog,
    lesson: BusinessLogicLesson,
    actor: str,
) -> BusinessLogicSuggestion:
    """Create or refresh the pending suggestion for an LLM-learned lesson.

    Matching order:
      1. a suggestion already created from the same audit entry;
      2. otherwise, when a failed or repaired identifier is known, a
         suggestion for the same connector, error category and identifier pair.

    A matched suggestion is overwritten and reset to pending review and
    disabled. NOTE: this also applies to a suggestion an admin had enabled.
    A new suggestion is added and flushed so its ``id`` is available.
    Nothing is committed here.

    Raises:
        ValueError: when the lesson carries no ``rule_text``.
    """
    if not lesson.rule_text:
        raise ValueError("LLM did not return a rule_text for business logic learning.")

    metadata = safe_json_object(audit_log.metadata_json)
    title = truncate_text(
        lesson.title
        or f"SQL lesson for {lesson.error_category.replace('_', ' ')}",
        255,
    )
    # Fall back to the identifier detected by the query pipeline, trimmed and
    # capped exactly like the LLM-provided value so it fits the same column.
    failed_identifier = lesson.failed_identifier or coerce_short_text(
        metadata.get("failed_identifier"), "", 255
    )
    repaired_identifier = lesson.repaired_identifier
    terms = lesson.terms or build_business_logic_terms(
        audit_log.question,
        failed_identifier,
        repaired_identifier or title,
    )

    existing = session.scalar(
        select(BusinessLogicSuggestion).where(
            BusinessLogicSuggestion.connector_id == connector.id,
            BusinessLogicSuggestion.source_audit_id == audit_log.id,
        )
    )

    if existing is None and (failed_identifier or repaired_identifier):
        existing = session.scalar(
            select(BusinessLogicSuggestion).where(
                BusinessLogicSuggestion.connector_id == connector.id,
                BusinessLogicSuggestion.error_category == lesson.error_category,
                BusinessLogicSuggestion.failed_identifier == failed_identifier,
                BusinessLogicSuggestion.repaired_identifier == repaired_identifier,
            )
        )

    # Single source of truth for the columns written on create and update.
    field_values: dict[str, Any] = {
        "source_audit_id": audit_log.id,
        "status": BUSINESS_LOGIC_STATUS_PENDING,
        "safety": BUSINESS_LOGIC_SAFETY_REVIEW,
        "enabled": False,
        "error_category": lesson.error_category,
        "title": title,
        "rule_text": lesson.rule_text,
        "terms_json": json_dumps(terms),
        "join_hints_json": json_dumps(lesson.join_hints),
        "failed_identifier": failed_identifier,
        "repaired_identifier": repaired_identifier,
        "confidence": lesson.confidence,
        "updated_by": actor,
    }

    if existing is None:
        suggestion = BusinessLogicSuggestion(connector_id=connector.id, **field_values)
        session.add(suggestion)
        session.flush()
        return suggestion

    for attribute, new_value in field_values.items():
        setattr(existing, attribute, new_value)

    return existing
1403
+
1404
+
1405
def record_business_logic_learning_suggestion(
    audit_log: DataQueryAuditLog,
    suggestion: BusinessLogicSuggestion,
    lesson: BusinessLogicLesson,
) -> None:
    """Annotate the audit entry with the outcome of successful learning.

    Stores the lesson's category and identifiers plus a
    ``business_logic_learning`` block that points at the pending suggestion.
    Only the ORM object is updated; this function does not commit.
    """
    metadata = safe_json_object(audit_log.metadata_json)
    metadata["error_category"] = lesson.error_category
    # Keep the pipeline-detected identifier when the LLM did not name one,
    # instead of erasing it with an empty string.
    metadata["failed_identifier"] = lesson.failed_identifier or str(
        metadata.get("failed_identifier") or ""
    )
    metadata["repaired_identifier"] = lesson.repaired_identifier
    metadata["business_logic_learning"] = {
        "status": "pending_approval",
        "suggestion_id": suggestion.id,
        # User-facing Polish text: "I learned a proposed fix for this error,
        # but you must approve it in Business Logic Suggestions."
        "message": (
            "Nauczyłem się propozycji rozwiązania tego błędu, ale musisz ją "
            "zatwierdzić w Sugestiach logiki biznesowej."
        ),
        "admin_section": "business-logic",
        "error_category": lesson.error_category,
        "confidence": lesson.confidence,
    }
    audit_log.metadata_json = json_dumps(metadata)
1426
+
1427
+
1428
def mark_business_logic_learning_skipped(
    session: Session,
    audit_log: DataQueryAuditLog,
    reason: str,
    llm_response: dict[str, Any] | None = None,
) -> None:
    """Record in the audit metadata that learning was skipped, then commit.

    The ``business_logic_learning`` metadata entry is replaced with a
    "skipped" record carrying ``reason``. A non-empty ``llm_response`` is
    included so admins can see what the LLM answered.
    """
    metadata = safe_json_object(audit_log.metadata_json)
    optional_response = {"llm_response": llm_response} if llm_response else {}
    metadata["business_logic_learning"] = {
        "status": BUSINESS_LOGIC_LEARNING_STATUS_SKIPPED,
        "reason": reason,
        # User-facing Polish text: "Business logic learning was skipped: <reason>"
        "message": f"Nauka logiki biznesowej została pominięta: {reason}",
        "admin_section": "business-logic",
        **optional_response,
    }
    audit_log.metadata_json = json_dumps(metadata)
    session.commit()
1448
+
1449
+
1450
def safe_json_object(value: str) -> dict[str, Any]:
    """Decode ``value`` as JSON, returning it only when it is an object.

    Invalid JSON, unsupported input types, and JSON values that are not
    objects all produce an empty dict instead of raising.
    """
    try:
        decoded = json_loads(value)
    except (json.JSONDecodeError, TypeError):
        return {}

    if not isinstance(decoded, dict):
        return {}
    return decoded
1457
+
1458
+
1459
def coerce_short_text(value: object, default: str, max_length: int) -> str:
    """Turn an optional (typically LLM-provided) field into bounded text.

    Missing, falsy, or whitespace-only values fall back to ``default``. The
    result is stripped and capped at ``max_length`` characters.
    """
    # Whitespace-only values previously became "" instead of the default.
    text_value = str(value or "").strip() or default.strip()
    return truncate_text(text_value, max_length)
1462
+
1463
+
1464
def truncate_text(value: str, max_length: int) -> str:
    """Shorten ``value`` to at most ``max_length`` characters.

    Longer text is cut, trailing whitespace at the cut is dropped, and
    ``"..."`` is appended so the result still fits the limit. When
    ``max_length`` is too small to hold the ellipsis (< 3), the text is
    simply cut; a non-positive limit yields ``""``.
    """
    if len(value) <= max_length:
        return value

    if max_length < 3:
        # max_length - 3 would be negative, and a negative slice bound keeps
        # almost the whole string, overshooting the limit.
        return value[: max(max_length, 0)]

    return value[: max_length - 3].rstrip() + "..."
1469
+
1470
+
1471
def coerce_confidence(value: object) -> float:
    """Convert a confidence value to a float clamped to ``[0.0, 1.0]``.

    Numbers and numeric strings are parsed and clamped. Other types,
    unparsable strings, and NaN map to ``0.0``.
    """
    import math

    if not isinstance(value, (int, float, str)):
        return 0.0

    try:
        confidence = float(value)
    except (TypeError, ValueError):
        return 0.0

    if math.isnan(confidence):
        # NaN slips through min/max clamping (min(1.0, nan) returns 1.0),
        # which would turn "nan" into maximum confidence.
        return 0.0

    return max(0.0, min(1.0, confidence))
1481
+
1482
+
1483
def coerce_text_list(value: object) -> list[str]:
    """Normalise a JSON-ish list into unique, non-empty, stripped strings.

    Anything that is not a ``list`` yields ``[]``. Items are stringified
    (falsy items count as empty), stripped, and de-duplicated while keeping
    the order of first appearance.
    """
    if not isinstance(value, list):
        return []

    stripped = (str(entry or "").strip() for entry in value)
    # dict preserves insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(entry for entry in stripped if entry))
1494
+
1495
+
1496
def build_business_logic_terms(
    question: str,
    failed_identifier: str,
    repaired_identifier: str,
) -> list[str]:
    """Collect sorted, unique search terms for a learned suggestion.

    Terms are the full failed identifier, the underscore-separated parts of
    the failed and repaired identifiers, and every word of 4+ characters
    (letters incl. Latin-extended, digits, underscore) in the lower-cased
    question. Empty terms are dropped.
    """
    candidates = [failed_identifier]
    candidates.extend(failed_identifier.split("_"))
    candidates.extend(repaired_identifier.split("_"))
    candidates.extend(re.findall(r"[A-Za-zÀ-ž0-9_]{4,}", question.lower()))
    return sorted({term for term in candidates if term})
1511
+
1512
+
1513
def get_active_business_logic_prompt_safe(connector_id: int) -> str:
    """Render the enabled business logic of a connector as prompt text.

    Returns ``""`` when the metadata store cannot be opened or a SQLAlchemy
    error occurs while reading; the session is always closed.
    """
    try:
        session = create_session()
    except SQLAlchemyError:
        return ""

    try:
        stored = list_business_logic_suggestions(session, connector_id)
        enabled_only = [item for item in stored if item.enabled]
        return format_business_logic_prompt(enabled_only)
    except SQLAlchemyError:
        return ""
    finally:
        session.close()
1531
+
1532
+
1533
def format_business_logic_prompt(
    suggestions: list[BusinessLogicSuggestion],
) -> str:
    """Format enabled suggestions as a ``Business logic:`` bullet list.

    Disabled suggestions are ignored; ``""`` is returned when none remain.
    """
    bullets = [f"- {item.rule_text}" for item in suggestions if item.enabled]
    if not bullets:
        return ""
    return "\n".join(["Business logic:", *bullets])
1547
+
1548
+
1549
# Allowed categories for investigation-analysis suggestions; any other input
# is mapped to "unknown" by normalize_investigation_analysis_category.
INVESTIGATION_ANALYSIS_CATEGORIES = {
    "aggregation_logic",
    "dictionary_value",
    "entity_mapping",
    "filter_logic",
    "relationship_logic",
    "unknown",
}
1557
+
1558
+
1559
def upsert_investigation_analysis_business_logic_suggestion(
    connector_id: int | None,
    source_audit_id: int | None,
    missing_information: str,
    required_analysis: str,
    category: str,
    analysis_response: QueryResponse,
    actor: str = "system",
) -> dict[str, Any]:
    """Store an investigation analysis result as a pending suggestion.

    The analysis is de-duplicated by a SHA-256 fingerprint of the normalised
    category, missing information, required analysis and result signature.
    The fingerprint is stored in ``repaired_identifier`` and the normalised
    missing information (capped at 255 chars) in ``failed_identifier``.

    Returns:
        A status dict:
        - ``{"status": "skipped", "reason": ...}`` (plus ``fingerprint`` once
          computed) when there is no connector, the metadata store is
          unavailable, or a SQLAlchemy error occurs;
        - ``"existing"`` with the matching suggestion id when the same
          fingerprint is already stored;
        - ``"created"`` with the new suggestion id otherwise.
        Both ``existing`` and ``created`` list ids of suggestions with the
        same category and failed identifier but a different fingerprint.
    """
    if connector_id is None:
        return {
            "status": "skipped",
            "reason": "Datasource connector is unavailable.",
        }

    normalized_category = normalize_investigation_analysis_category(category)
    normalized_missing = normalize_fingerprint_text(missing_information)
    normalized_analysis = normalize_fingerprint_text(required_analysis)
    result_signature = investigation_analysis_result_signature(analysis_response)
    fingerprint = investigation_analysis_fingerprint(
        category=normalized_category,
        missing_information=normalized_missing,
        required_analysis=normalized_analysis,
        result_signature=result_signature,
    )

    try:
        session = create_session()
    except SQLAlchemyError:
        return {
            "status": "skipped",
            "reason": "Metadata store is unavailable.",
            "fingerprint": fingerprint,
        }

    error_category = f"investigation.analysis.{normalized_category}"
    failed_identifier = truncate_text(normalized_missing, 255)

    try:
        # Exact duplicate: same subject and same fingerprint.
        existing = session.scalar(
            select(BusinessLogicSuggestion).where(
                BusinessLogicSuggestion.connector_id == connector_id,
                BusinessLogicSuggestion.error_category == error_category,
                BusinessLogicSuggestion.failed_identifier == failed_identifier,
                BusinessLogicSuggestion.repaired_identifier == fingerprint,
            )
        )
        # Same subject but a different fingerprint (e.g. different result rows).
        similar_existing = session.scalars(
            select(BusinessLogicSuggestion).where(
                BusinessLogicSuggestion.connector_id == connector_id,
                BusinessLogicSuggestion.error_category == error_category,
                BusinessLogicSuggestion.failed_identifier == failed_identifier,
                BusinessLogicSuggestion.repaired_identifier != fingerprint,
            )
        ).all()

        if existing is not None:
            return {
                "status": "existing",
                "suggestion_id": existing.id,
                "fingerprint": fingerprint,
                "similar_existing_suggestion_ids": [
                    item.id for item in similar_existing
                ],
            }

        compact_result = compact_investigation_analysis_result(analysis_response)
        rule_text = (
            f"[{normalized_category}] {missing_information.strip()} => {compact_result}"
        )
        # New suggestions always start disabled and pending admin review.
        suggestion = BusinessLogicSuggestion(
            connector_id=connector_id,
            source_audit_id=source_audit_id,
            status=BUSINESS_LOGIC_STATUS_PENDING,
            safety=BUSINESS_LOGIC_SAFETY_REVIEW,
            enabled=False,
            error_category=error_category,
            title=truncate_text(
                f"Investigation analysis: {missing_information.strip()}",
                255,
            ),
            rule_text=rule_text,
            terms_json=json_dumps(
                build_investigation_analysis_terms(
                    missing_information,
                    required_analysis,
                    compact_result,
                )
            ),
            join_hints_json=json_dumps([]),
            failed_identifier=failed_identifier,
            repaired_identifier=fingerprint,
            confidence=coerce_confidence(
                analysis_response.metadata.get("confidence")
            ),
            updated_by=actor,
        )
        session.add(suggestion)
        session.commit()
        # NOTE(review): ids are read after commit while the session is still
        # open; with expire-on-commit this may trigger a refresh query.
        return {
            "status": "created",
            "suggestion_id": suggestion.id,
            "fingerprint": fingerprint,
            "similar_existing_suggestion_ids": [
                item.id for item in similar_existing
            ],
        }
    except SQLAlchemyError:
        session.rollback()
        return {
            "status": "skipped",
            "reason": "Could not store investigation analysis business logic.",
            "fingerprint": fingerprint,
        }
    finally:
        session.close()
1675
+
1676
+
1677
def normalize_investigation_analysis_category(value: str) -> str:
    """Canonicalise a category label, mapping unknown labels to ``"unknown"``.

    The label is trimmed, lower-cased, and has hyphens/spaces replaced by
    underscores before being checked against the allowed categories.
    """
    candidate = value.strip().lower().replace("-", "_").replace(" ", "_")
    if candidate not in INVESTIGATION_ANALYSIS_CATEGORIES:
        return "unknown"
    return candidate
1680
+
1681
+
1682
def normalize_fingerprint_text(value: str) -> str:
    """Trim, lower-case and collapse every whitespace run to a single space."""
    lowered = value.strip().lower()
    return " ".join(lowered.split())
1684
+
1685
+
1686
def investigation_analysis_result_signature(response: QueryResponse) -> dict[str, Any]:
    """Build an order-independent summary of an analysis result.

    Rows are sorted by their canonical JSON form and capped at 50; the answer
    text is only included when there are no rows.
    """
    rows = response.rows
    ordered_rows = sorted(rows, key=json_dumps)
    return {
        "row_count": len(rows),
        "rows": ordered_rows[:50],
        "answer": "" if rows else response.answer,
    }
1696
+
1697
+
1698
def investigation_analysis_fingerprint(
    category: str,
    missing_information: str,
    required_analysis: str,
    result_signature: dict[str, Any],
) -> str:
    """Return the SHA-256 hex digest of the canonical JSON of all inputs."""
    canonical = json_dumps(
        {
            "result_signature": result_signature,
            "required_analysis": required_analysis,
            "missing_information": missing_information,
            "category": category,
        }
    )
    digest = hashlib.sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
1711
+
1712
+
1713
def compact_investigation_analysis_result(response: QueryResponse) -> str:
    """Summarise an analysis result as short text for a rule.

    Uses the first 20 rows as JSON, otherwise the answer text (either capped
    at 2000 characters), and ``"no rows"`` when both are empty.
    """
    rows = response.rows
    if rows:
        summary = json_dumps(rows[:20])
    elif response.answer:
        summary = response.answer
    else:
        return "no rows"
    return truncate_text(summary, 2_000)
1721
+
1722
+
1723
def build_investigation_analysis_terms(
    missing_information: str,
    required_analysis: str,
    compact_result: str,
) -> list[str]:
    """Return sorted unique words of 4+ characters from the three texts.

    Matching is done on lower-cased text; words consist of letters
    (including Latin-extended), digits and underscores.
    """
    found = {
        token
        for source_text in (missing_information, required_analysis, compact_result)
        for token in re.findall(r"[A-Za-zÀ-ž0-9_]{4,}", source_text.lower())
    }
    return sorted(found)
1734
+
1735
+
1736
def _candidate_knowledge_claim(
    connector_id: int,
    item: object,
    actor: str,
) -> "BusinessKnowledgeClaim | None":
    """Build an unsaved claim row from one pipeline item, or ``None`` to skip it."""
    if not isinstance(item, dict):
        return None

    claim = str(item.get("claim") or "").strip()
    if not claim:
        return None

    subject = {
        "datasource_id": item.get("datasource_id"),
        "tables": item.get("tables") or [],
        "columns": item.get("columns") or [],
        "values": item.get("values") or [],
    }
    return BusinessKnowledgeClaim(
        connector_id=connector_id,
        knowledge_type=str(item.get("knowledge_type") or "business_semantic"),
        status=str(item.get("status") or "candidate"),
        claim_text=claim,
        subject_json=json_dumps(subject),
        evidence_json=json_dumps(item.get("evidence") or []),
        confidence=coerce_confidence(item.get("confidence")),
        source=str(item.get("source") or "query_pipeline"),
        request_id=str(item.get("request_id") or ""),
        audit_reference=str(item.get("audit_event_id") or ""),
        requires_approval=bool(item.get("requires_approval", True)),
        updated_by=actor,
    )


def record_candidate_business_knowledge(
    connector_id: int | None,
    knowledge_items: list[dict[str, Any]],
    actor: str = "system",
) -> list[int]:
    """Persist pipeline-proposed business knowledge claims for a connector.

    Items that are not dicts or have an empty ``claim`` are skipped. All rows
    are committed together.

    Returns:
        Ids of the stored claims in input order; ``[]`` when there is nothing
        to store, the metadata store is unavailable, or a SQLAlchemy error
        occurs (the transaction is then rolled back).
    """
    if connector_id is None or not knowledge_items:
        return []

    try:
        session = create_session()
    except SQLAlchemyError:
        return []

    try:
        stored_ids: list[int] = []
        for item in knowledge_items:
            row = _candidate_knowledge_claim(connector_id, item, actor)
            if row is None:
                continue
            session.add(row)
            # Flush per row so row.id is populated before commit.
            session.flush()
            stored_ids.append(row.id)

        session.commit()
        return stored_ids
    except SQLAlchemyError:
        session.rollback()
        return []
    finally:
        session.close()
1790
+
1791
+
1792
def list_admin_audit_logs(session: Session, limit: int = 100) -> list[AdminAuditLog]:
    """Return at most ``limit`` admin audit entries, newest first."""
    newest_first = (
        select(AdminAuditLog)
        .order_by(desc(AdminAuditLog.occurred_at))
        .limit(limit)
    )
    entries = session.scalars(newest_first)
    return list(entries)
1798
+
1799
+
1800
def list_prompt_templates(session: Session) -> list[PromptTemplate]:
    """Return every prompt template ordered by ``prompt_key``."""
    by_key = select(PromptTemplate).order_by(PromptTemplate.prompt_key)
    templates = session.scalars(by_key)
    return list(templates)
1806
+
1807
+
1808
def get_prompt_template(session: Session, prompt_key: str) -> PromptTemplate | None:
    """Return the template with ``prompt_key`` (active or not), or ``None``."""
    lookup = select(PromptTemplate).where(PromptTemplate.prompt_key == prompt_key)
    return session.scalar(lookup)
1812
+
1813
+
1814
def get_active_prompt_template_safe(prompt_key: str) -> PromptTemplate | None:
    """Return the active template for ``prompt_key`` using a private session.

    Returns ``None`` when no active template exists, the metadata store is
    unavailable, or a SQLAlchemy error occurs. The session is always closed.
    """
    try:
        session = create_session()
    except SQLAlchemyError:
        return None

    try:
        active_lookup = select(PromptTemplate).where(
            PromptTemplate.prompt_key == prompt_key,
            PromptTemplate.active.is_(True),
        )
        return session.scalar(active_lookup)
    except SQLAlchemyError:
        return None
    finally:
        session.close()
1831
+
1832
+
1833
def list_datasource_connectors(session: Session) -> list[DatasourceConnector]:
    """Return all connectors, active ones first, then ordered by name."""
    ordering = (
        DatasourceConnector.active.desc(),
        DatasourceConnector.name,
    )
    statement = select(DatasourceConnector).order_by(*ordering)
    return list(session.scalars(statement))
1842
+
1843
+
1844
def is_system_datasource_connector(connector: DatasourceConnector) -> bool:
    """Tell whether the connector's key is one of the reserved system keys."""
    key = connector.connector_key
    return key in SYSTEM_DATASOURCE_CONNECTOR_KEYS
1846
+
1847
+
1848
def get_datasource_connector(
    session: Session,
    connector_id: int,
) -> DatasourceConnector | None:
    """Fetch a connector by primary key, or ``None`` if it does not exist."""
    connector = session.get(DatasourceConnector, connector_id)
    return connector
1853
+
1854
+
1855
def get_datasource_connector_by_key(
    session: Session,
    connector_key: str,
) -> DatasourceConnector | None:
    """Fetch the connector whose ``connector_key`` matches, or ``None``."""
    by_key = select(DatasourceConnector).where(
        DatasourceConnector.connector_key == connector_key
    )
    return session.scalar(by_key)
1862
+
1863
+
1864
def get_active_datasource_connector(session: Session) -> DatasourceConnector | None:
    """Return a connector flagged active, or ``None`` if none is.

    If several rows are flagged active, ``session.scalar`` returns the first
    one produced by the query.
    """
    active_only = select(DatasourceConnector).where(
        DatasourceConnector.active.is_(True)
    )
    return session.scalar(active_only)
1868
+
1869
+
1870
def get_active_datasource_connector_safe() -> DatasourceConnector | None:
    """Look up the active connector with a private session, never raising DB errors.

    Returns ``None`` when the metadata store is unavailable, no connector is
    active, or a SQLAlchemy error occurs. The session is always closed.
    """
    try:
        session = create_session()
    except SQLAlchemyError:
        return None

    try:
        active_connector = get_active_datasource_connector(session)
    except SQLAlchemyError:
        active_connector = None
    finally:
        session.close()

    return active_connector
1882
+
1883
+
1884
def set_active_datasource_connector(
    session: Session,
    connector: DatasourceConnector,
    actor: str,
) -> None:
    """Make ``connector`` the only active non-system connector.

    System connectors are always left inactive (so selecting one leaves no
    connector active). Only the newly activated connector gets
    ``updated_by = actor``. Changes are not committed here.
    """
    target_id = connector.id
    for candidate in list_datasource_connectors(session):
        if is_system_datasource_connector(candidate):
            candidate.active = False
        elif candidate.id == target_id:
            candidate.active = True
            candidate.updated_by = actor
        else:
            candidate.active = False
1896
+
1897
+
1898
def test_datasource_connection(connector: DatasourceConnector) -> None:
    """Open a connection to the connector's database and run ``SELECT 1``.

    Exceptions from engine creation, connecting, or the query propagate to
    the caller. The temporary engine is always disposed.
    """
    connect_args: dict[str, Any] = {}
    if connector.database_url.startswith("sqlite"):
        # Let the SQLite connection be used from a thread other than its creator.
        connect_args["check_same_thread"] = False

    engine = create_engine(connector.database_url, connect_args=connect_args)
    try:
        with engine.connect() as connection:
            connection.execute(text("SELECT 1"))
    finally:
        engine.dispose()
1911
+
1912
+
1913
def get_datasource_schema_cache(
    session: Session,
    connector_id: int,
) -> DatasourceSchemaCache | None:
    """Fetch the schema cache row keyed by ``connector_id``, or ``None``."""
    cache = session.get(DatasourceSchemaCache, connector_id)
    return cache
1918
+
1919
+
1920
def get_or_create_datasource_schema_cache(
    session: Session,
    connector: DatasourceConnector,
    actor: str,
) -> DatasourceSchemaCache:
    """Return the connector's schema cache, introspecting the database if missing."""
    cached = get_datasource_schema_cache(session, connector.id)
    return (
        cached
        if cached is not None
        else introspect_datasource_connector(session, connector, actor)
    )
1931
+
1932
+
1933
def introspect_datasource_connector(
    session: Session,
    connector: DatasourceConnector,
    actor: str,
) -> DatasourceSchemaCache:
    """Re-read the connector's live schema and store it in the schema cache.

    Per-table settings from an existing cache row are carried over through
    ``build_table_settings``. A missing cache row is created and added to
    ``session``; an existing one is updated and gets a fresh
    ``introspected_at``. Nothing is committed here.
    """
    schema = SQLAlchemySchemaIntrospector(connector.database_url).introspect()
    existing = get_datasource_schema_cache(session, connector.id)
    # A corrupt or non-object table_settings_json used to make every
    # re-introspection (the usual way to repair the cache) raise; treat it
    # as "no saved settings" instead.
    existing_settings = (
        safe_json_object(existing.table_settings_json) if existing is not None else {}
    )
    table_settings = build_table_settings(schema, existing_settings)
    formatted_schema = format_schema_for_prompt(schema, table_settings)
    schema_json = json_dumps(schema.model_dump())
    table_settings_json = json_dumps(table_settings)

    if existing is None:
        existing = DatasourceSchemaCache(
            connector_id=connector.id,
            schema_json=schema_json,
            table_settings_json=table_settings_json,
            formatted_schema=formatted_schema,
            updated_by=actor,
        )
        session.add(existing)
    else:
        existing.schema_json = schema_json
        existing.table_settings_json = table_settings_json
        existing.formatted_schema = formatted_schema
        existing.introspected_at = datetime.now(UTC)
        existing.updated_by = actor

    return existing
1964
+
1965
+
1966
def build_table_settings(
    schema: DatabaseSchema,
    existing_settings: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build normalized per-table prompt settings for every table in ``schema``.

    For each table in ``schema.tables``, the entry found under
    ``existing_settings["tables"][table.name]`` is carried over. Tables with
    no entry get the defaults: selected, with empty text fields. Entries for
    tables that are not in ``schema`` are dropped.

    Args:
        schema: Schema whose tables define the keys of the result.
        existing_settings: Prior settings shaped ``{"tables": {name: {...}}}``.
            The following all mean "no prior settings" for the affected tables:
            ``None``, a missing or null ``"tables"`` value, or a null
            per-table entry.

    Returns:
        ``{"tables": {name: {...}}}``. Each entry has a bool ``selected`` and
        the string fields ``description``, ``primary_key_prompt``,
        ``foreign_key_prompt`` and ``join_logic``.
    """

    def as_text(config: dict[str, Any], key: str) -> str:
        value = config.get(key)
        # A JSON null must become "" rather than the literal string "None".
        # Otherwise "None" would be stored and later rendered into the prompt.
        return "" if value is None else str(value)

    # ``or {}`` tolerates ``"tables": null``, which previously raised AttributeError.
    existing_tables = (existing_settings or {}).get("tables") or {}
    tables: dict[str, Any] = {}

    for table in schema.tables:
        # A null per-table entry is treated like a missing one.
        existing = existing_tables.get(table.name) or {}
        tables[table.name] = {
            "selected": bool(existing.get("selected", True)),
            "description": as_text(existing, "description"),
            "primary_key_prompt": as_text(existing, "primary_key_prompt"),
            "foreign_key_prompt": as_text(existing, "foreign_key_prompt"),
            "join_logic": as_text(existing, "join_logic"),
        }

    return {"tables": tables}
1984
+
1985
+
1986
def parse_schema_cache(cache: DatasourceSchemaCache) -> DatabaseSchema:
    """Deserialize the full, unfiltered schema stored in ``cache.schema_json``."""
    raw_schema = json_loads(cache.schema_json)
    return DatabaseSchema.model_validate(raw_schema)
1988
+
1989
+
1990
def selected_schema_from_cache(cache: DatasourceSchemaCache) -> DatabaseSchema:
    """Return the cached schema restricted to tables that are selected.

    A table is kept when its ``selected`` setting is truthy. Tables with no
    entry in ``cache.table_settings_json`` count as selected. Only the
    ``tables`` field is carried into the returned ``DatabaseSchema``.
    """
    schema = parse_schema_cache(cache)
    settings_by_table = json_loads(cache.table_settings_json).get("tables", {})

    kept = []
    for table in schema.tables:
        if settings_by_table.get(table.name, {}).get("selected", True):
            kept.append(table)

    return DatabaseSchema(tables=kept)
2002
+
2003
+
2004
def format_schema_for_prompt(
    schema: DatabaseSchema,
    table_settings: dict[str, Any],
) -> str:
    """Render the selected tables of ``schema`` as prompt text.

    Tables are processed in name order. Each one is rendered with
    ``format_table_for_prompt`` and sections are separated by a blank line.

    A table is skipped when its ``selected`` setting is falsy; tables with
    no settings entry are included. If nothing remains, a fixed placeholder
    sentence is returned instead.
    """
    configs = table_settings.get("tables", {})
    rendered: list[str] = []

    for table in sorted(schema.tables, key=lambda candidate: candidate.name):
        config = configs.get(table.name, {})
        if config.get("selected", True):
            rendered.append(format_table_for_prompt(table, config))

    return "\n\n".join(rendered) if rendered else "No tables or views available."
2023
+
2024
+
2025
def format_table_for_prompt(table: TableInfo, table_config: dict[str, Any]) -> str:
    """Render one table or view, plus its prompt settings, as newline-joined text.

    Output order:

    1. ``View:``/``Table:`` header and the optional description.
    2. Column bullets.
    3. Optional primary-key guidance.
    4. Introspected foreign keys.
    5. Optional foreign-key guidance and join logic.

    Text settings are stripped, and empty or missing ones are omitted.

    Args:
        table: Introspected table; reads ``name``, ``object_type``,
            ``columns`` and ``foreign_keys``.
        table_config: Per-table settings. ``description``,
            ``primary_key_prompt``, ``foreign_key_prompt`` and ``join_logic``
            are read. A null value is treated as empty.
    """

    def config_text(key: str) -> str:
        value = table_config.get(key)
        # A JSON null must be omitted, not rendered as the literal text "None".
        return "" if value is None else str(value).strip()

    object_label = "View" if table.object_type == "view" else "Table"
    lines: list[str] = [f"{object_label}: {table.name}"]

    description = config_text("description")
    if description:
        lines.append(f"Description: {description}")

    lines.append("Columns:")
    if table.columns:
        lines.extend(format_column_for_prompt(column) for column in table.columns)
    else:
        lines.append("- No columns available.")

    primary_key_prompt = config_text("primary_key_prompt")
    if primary_key_prompt:
        lines.extend(("Primary key guidance:", primary_key_prompt))

    if table.foreign_keys:
        lines.append("Foreign keys:")
        for foreign_key in table.foreign_keys:
            constrained = ", ".join(foreign_key.constrained_columns)
            referred = ", ".join(foreign_key.referred_columns)
            lines.append(f"- {constrained} -> {foreign_key.referred_table}.{referred}")

    foreign_key_prompt = config_text("foreign_key_prompt")
    if foreign_key_prompt:
        lines.extend(("Foreign key guidance:", foreign_key_prompt))

    join_logic = config_text("join_logic")
    if join_logic:
        lines.extend(("Join logic:", join_logic))

    return "\n".join(lines)
2064
+
2065
+
2066
def format_column_for_prompt(column: ColumnInfo) -> str:
    """Render one column as a ``- name: type`` prompt bullet.

    Constraint flags are appended in parentheses when present, primary key
    first, e.g. ``- id: INTEGER (primary key, not null)``.
    """
    flags = [
        label
        for label, applies in (
            ("primary key", column.primary_key),
            ("not null", not column.nullable),
        )
        if applies
    ]
    suffix = " (" + ", ".join(flags) + ")" if flags else ""
    return f"- {column.name}: {column.type}{suffix}"
2077
+
2078
+
2079
def update_schema_table_settings(
    session: Session,
    cache: DatasourceSchemaCache,
    table_settings: dict[str, Any],
    actor: str,
) -> DatasourceSchemaCache:
    """Apply ``table_settings`` to ``cache`` and regenerate its prompt schema.

    The incoming settings are normalized against the cached schema by
    ``build_table_settings``: tables not in the schema are dropped and
    missing ones get defaults. Only attributes on ``cache`` are modified.
    ``session`` is not used by this function, and nothing is committed.
    """
    parsed_schema = parse_schema_cache(cache)
    normalized = build_table_settings(parsed_schema, table_settings)

    cache.table_settings_json = json_dumps(normalized)
    cache.formatted_schema = format_schema_for_prompt(parsed_schema, normalized)
    cache.updated_by = actor
    return cache
2092
+
2093
+
2094
def list_overview_widgets(session: Session) -> list[OverviewWidget]:
    """Return the active overview widgets, ordered by ``position`` then ``id`` (ascending)."""
    statement = (
        select(OverviewWidget)
        .where(OverviewWidget.active.is_(True))
        .order_by(OverviewWidget.position.asc(), OverviewWidget.id.asc())
    )
    return list(session.scalars(statement))
2102
+
2103
+
2104
def list_all_overview_widgets(session: Session) -> list[OverviewWidget]:
    """Return every overview widget, active or not, ordered by ``position`` then ``id``."""
    ordering = (OverviewWidget.position.asc(), OverviewWidget.id.asc())
    statement = select(OverviewWidget).order_by(*ordering)
    return list(session.scalars(statement))
2111
+
2112
+
2113
def get_overview_widget(session: Session, widget_key: str) -> OverviewWidget | None:
    """Fetch the overview widget whose ``widget_key`` matches, or ``None`` if none does."""
    by_key = select(OverviewWidget).where(OverviewWidget.widget_key == widget_key)
    return session.scalar(by_key)
2117
+
2118
+
2119
def get_datasource_schema_context_safe() -> tuple[DatasourceConnector, DatasourceSchemaCache] | None:
    """Return the active connector and its schema cache, or ``None``.

    If the active connector has no cached schema yet, it is introspected with
    actor ``"system"`` and the new cache row is committed.

    ``None`` is returned when:

    * there is no active connector, or
    * a ``SQLAlchemyError`` is raised while opening the session, querying,
      introspecting, committing or reloading; the transaction is rolled back.

    Other exception types propagate. The session is always closed, so the
    returned objects are detached; their column attributes are loaded before
    returning.
    """
    try:
        session = create_session()
    except SQLAlchemyError:
        return None

    try:
        connector = get_active_datasource_connector(session)

        if connector is None:
            return None

        cache = get_datasource_schema_cache(session, connector.id)

        if cache is None:
            cache = introspect_datasource_connector(session, connector, "system")
            session.commit()
            # With SQLAlchemy's default ``expire_on_commit=True`` the commit
            # expires every loaded attribute. The session is closed below, so
            # the caller would get DetachedInstanceError on first access.
            # Reload both objects while the session is still open. This is
            # harmless if the session factory disables expiry.
            session.refresh(connector)
            session.refresh(cache)

        return connector, cache
    except SQLAlchemyError:
        session.rollback()
        return None
    finally:
        session.close()