gaard-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gaard_api/__init__.py ADDED
File without changes
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,513 @@
1
+ from collections.abc import Iterator
2
+ import json
3
+
4
+ from sqlalchemy import create_engine, select
5
+ from sqlalchemy.engine import Engine
6
+ from sqlalchemy import inspect, text
7
+ from sqlalchemy.orm import Session, sessionmaker
8
+
9
+ from gaard_api.admin.defaults import DEFAULT_GOVERNANCE_POLICY_CONFIG, DEFAULT_PROMPTS
10
+ from gaard_api.admin.models import (
11
+ AdminSetting,
12
+ AdminUser,
13
+ Base,
14
+ DataQueryAuditLog,
15
+ DataQueryAuditType,
16
+ DatasourceConnector,
17
+ OverviewWidget,
18
+ PromptTemplate,
19
+ )
20
+ from gaard_api.admin.security import hash_password
21
+ from gaard_api.core.settings import settings
22
+
23
+
24
# Module-level cache shared by every caller in the process.
# Engine for the metadata database; created lazily by get_engine().
_engine: Engine | None = None
# Session factory bound to ``_engine``; replaced whenever the engine is rebuilt.
_session_factory: sessionmaker[Session] | None = None
# Database URL that ``_engine`` was created for; get_engine() compares it with
# the configured URL to decide whether the cached engine can be reused.
_engine_url: str | None = None
27
+
28
+
29
def get_engine() -> Engine:
    """Return the cached metadata-store engine, rebuilding it when the URL changes.

    The engine and its session factory live in module globals and are keyed on
    ``settings.gaard_metadata_database_url``. If the configured URL differs
    from the one the cached engine was created for, a fresh engine and session
    factory replace the cached pair and the superseded engine is disposed.

    Returns:
        The engine bound to the currently configured metadata database URL.
    """
    global _engine, _engine_url, _session_factory

    database_url = settings.gaard_metadata_database_url

    if _engine is not None and _engine_url == database_url:
        return _engine

    # Remember the engine being replaced so its pool can be released once the
    # new engine has been created successfully.
    superseded_engine = _engine

    # SQLite URLs get the driver's same-thread check turned off; every other
    # backend is created without extra connect arguments.
    connect_args = {"check_same_thread": False} if database_url.startswith("sqlite") else {}
    _engine = create_engine(database_url, connect_args=connect_args)
    _session_factory = sessionmaker(bind=_engine, autoflush=False, expire_on_commit=False)
    _engine_url = database_url

    # Previously the old engine was simply rebound, leaving its pooled
    # connections open until garbage collection.
    if superseded_engine is not None:
        superseded_engine.dispose()

    return _engine
43
+
44
+
45
def get_session() -> Iterator[Session]:
    """Yield one metadata-store session and close it when the consumer finishes.

    The metadata store is initialised first on every call.

    Yields:
        A session produced by the module-level session factory.

    Raises:
        RuntimeError: If the session factory is still unset after initialisation.
    """
    init_metadata_store()

    factory = _session_factory
    if factory is None:
        raise RuntimeError("Admin metadata session factory is not initialized.")

    # Leaving the ``with`` block closes the session, including when the
    # consumer raises or the generator is closed early.
    with factory() as session:
        yield session
56
+
57
+
58
def create_session() -> Session:
    """Initialise the metadata store and return a new session.

    The caller owns the returned session and is responsible for closing it.

    Raises:
        RuntimeError: If the session factory is still unset after initialisation.
    """
    init_metadata_store()

    if _session_factory is not None:
        return _session_factory()

    raise RuntimeError("Admin metadata session factory is not initialized.")
65
+
66
+
67
def init_metadata_store() -> None:
    """Create the metadata schema, patch older tables and seed default rows.

    All seeding steps share one session and are committed together at the end.

    Raises:
        RuntimeError: If the session factory is unset after the engine is obtained.
    """
    metadata_engine = get_engine()
    Base.metadata.create_all(metadata_engine)

    # Add columns/indexes that tables created by earlier versions may lack.
    for ensure_schema in (ensure_data_query_audit_schema, ensure_overview_widget_schema):
        ensure_schema(metadata_engine)

    if _session_factory is None:
        raise RuntimeError("Admin metadata session factory is not initialized.")

    # Steps run in this exact order.
    seeding_steps = (
        seed_admin_user,
        seed_settings,
        apply_runtime_settings,
        seed_prompts,
        seed_datasource_connectors,
        seed_overview_widgets,
        backfill_data_query_audit_types,
    )

    with _session_factory() as session:
        for run_step in seeding_steps:
            run_step(session)
        session.commit()
85
+
86
+
87
def seed_admin_user(session: Session) -> None:
    """Add the bootstrap ``admin`` account when no user with that name exists.

    The account is created with ``must_change_password`` set, and nothing is
    touched if an ``admin`` row is already present.

    Args:
        session: Session used for the lookup and the pending insert; the
            caller commits.
    """
    admin_lookup = select(AdminUser).where(AdminUser.username == "admin")

    if session.scalar(admin_lookup) is None:
        bootstrap_admin = AdminUser(
            username="admin",
            password_hash=hash_password("admin"),
            must_change_password=True,
        )
        session.add(bootstrap_admin)
100
+
101
+
102
def seed_settings(session: Session) -> None:
    """Insert missing admin settings and refresh system-owned ones.

    Each default is stored as a string. A missing key is inserted; an existing
    row is overwritten only when its ``updated_by`` is ``"system"`` and its
    value differs from the current default.

    Args:
        session: Session used for lookups and pending changes; the caller commits.
    """

    def as_json(payload: object) -> str:
        # Stable JSON text so unchanged configuration compares equal across runs.
        return json.dumps(payload, ensure_ascii=False, sort_keys=True)

    seeded_values = {
        "gaard_intent_classification_mode": settings.gaard_intent_classification_mode,
        "gaard_sql_generation_mode": settings.gaard_sql_generation_mode,
        "gaard_result_interpretation_mode": settings.gaard_result_interpretation_mode,
        "gaard_output_classification_mode": settings.gaard_output_classification_mode,
        "gaard_investigation_mode": settings.gaard_investigation_mode,
        "gaard_investigation_ambiguity_mode": settings.gaard_investigation_ambiguity_mode,
        "gaard_query_max_rows": str(settings.gaard_query_max_rows),
        "gaard_query_timeout_seconds": str(settings.gaard_query_timeout_seconds),
        "gaard_llm_provider": settings.gaard_llm_provider,
        "gaard_llm_base_url": settings.gaard_llm_base_url,
        "gaard_llm_api_key": settings.gaard_llm_api_key,
        "gaard_llm_model": settings.gaard_llm_model,
        "gaard_llm_timeout_seconds": str(settings.gaard_llm_timeout_seconds),
        "gaard_llm_extra_body": as_json(settings.gaard_llm_extra_body),
        "gaard_governance_policy": as_json(DEFAULT_GOVERNANCE_POLICY_CONFIG),
        "data_query_audit_retention_days": str(settings.gaard_audit_retention_days),
        "schema_cache_ttl_seconds": str(settings.gaard_schema_cache_ttl_seconds),
        "license_edition": "community",
    }

    for setting_key, default_value in seeded_values.items():
        stored = session.get(AdminSetting, setting_key)

        if stored is None:
            session.add(AdminSetting(key=setting_key, value=default_value))
            continue

        # Rows whose updated_by is anything other than "system" are left as they are.
        if stored.updated_by == "system" and stored.value != default_value:
            stored.value = default_value
138
+
139
+
140
def apply_runtime_settings(session: Session) -> None:
    """Apply the stored ``schema_cache_ttl_seconds`` setting to the schema cache.

    Nothing happens when the setting row is missing or its value cannot be
    parsed as an integer. Parsed values below 1 are raised to 1.

    Args:
        session: Session used to read the setting.
    """
    ttl_setting = session.get(AdminSetting, "schema_cache_ttl_seconds")

    if ttl_setting is not None:
        try:
            parsed_ttl = int(ttl_setting.value)
        except (TypeError, ValueError):
            return

        # Imported here rather than at module level, as in the original code.
        from gaard_api.core.schema_cache import schema_context_cache

        schema_context_cache.ttl_seconds = max(1, parsed_ttl)
154
+
155
+
156
def seed_prompts(session: Session) -> None:
    """Insert missing default prompt templates and refresh system-owned ones.

    For each entry in ``DEFAULT_PROMPTS``: a missing template is inserted as
    is; an existing template whose ``updated_by`` is not ``"system"`` is left
    alone; a system-owned template whose text fields differ from the default
    is overwritten, reactivated and has its version bumped by one.

    Args:
        session: Session used for lookups and pending changes; the caller commits.
    """
    synced_fields = ("name", "description", "system_prompt", "user_prompt_template")

    for default_prompt in DEFAULT_PROMPTS:
        lookup = select(PromptTemplate).where(
            PromptTemplate.prompt_key == default_prompt["prompt_key"]
        )
        stored = session.scalar(lookup)

        if stored is None:
            session.add(PromptTemplate(**default_prompt))
            continue

        if stored.updated_by != "system":
            continue

        is_outdated = any(
            getattr(stored, field_name) != default_prompt[field_name]
            for field_name in synced_fields
        )
        if is_outdated:
            for field_name in synced_fields:
                setattr(stored, field_name, default_prompt[field_name])
            stored.active = True
            stored.version += 1
            stored.updated_by = "system"
188
+
189
+
190
def seed_datasource_connectors(session: Session) -> None:
    """Ensure the ``default`` and ``metadata-db`` datasource connectors exist.

    Steps, in order:
      1. Rewrite legacy ``postgresql`` dialect values.
      2. Add the ``default`` connector from settings if it is missing.
      3. Add the ``metadata-db`` connector, or overwrite its fields if it
         exists; it is always stored inactive and marked system-updated.
      4. If no connector other than ``metadata-db`` is found active, activate
         the ``default`` connector when a lookup finds it.

    Args:
        session: Session used for lookups and pending changes; the caller commits.
    """
    migrate_postgres_sql_dialect(session)

    default_lookup = select(DatasourceConnector).where(
        DatasourceConnector.connector_key == "default"
    )
    metadata_lookup = select(DatasourceConnector).where(
        DatasourceConnector.connector_key == "metadata-db"
    )

    if session.scalar(default_lookup) is None:
        seeded_default = DatasourceConnector(
            connector_key="default",
            name="Medical POC SQLite",
            database_type="sqlite",
            database_url=settings.gaard_datasource_url,
            sql_dialect=settings.gaard_sql_dialect,
            active=True,
        )
        session.add(seeded_default)

    metadata_row = session.scalar(metadata_lookup)
    database_type, sql_dialect = infer_datasource_type(settings.gaard_metadata_database_url)

    # Field values shared by the insert and the overwrite path.
    metadata_fields = {
        "name": "GAARD Metadata DB",
        "database_type": database_type,
        "database_url": settings.gaard_metadata_database_url,
        "sql_dialect": sql_dialect,
        "active": False,
    }

    if metadata_row is None:
        session.add(DatasourceConnector(connector_key="metadata-db", **metadata_fields))
    else:
        for attribute, field_value in metadata_fields.items():
            setattr(metadata_row, attribute, field_value)
        metadata_row.updated_by = "system"

    active_lookup = select(DatasourceConnector).where(
        DatasourceConnector.connector_key != "metadata-db",
        DatasourceConnector.active.is_(True),
    )
    active_user_connector = session.scalar(active_lookup)
    # The default connector is looked up again here, as in the original flow.
    default_row = session.scalar(default_lookup)

    if default_row is not None and active_user_connector is None:
        default_row.active = True
245
+
246
+
247
def infer_datasource_type(database_url: str) -> tuple[str, str]:
    """Map a database URL to a ``(database_type, sql_dialect)`` pair.

    The decision is made on the URL prefix only. URLs that start with none of
    the known prefixes fall back to the PostgreSQL pair.

    Args:
        database_url: Database URL whose leading scheme text is inspected.

    Returns:
        ``("sqlite", "sqlite")``, ``("postgresql", "postgres")`` or
        ``("mysql", "mysql")``.
    """
    known_prefixes = (
        ("sqlite", ("sqlite", "sqlite")),
        ("postgresql", ("postgresql", "postgres")),
        ("mysql", ("mysql", "mysql")),
    )

    for prefix, type_and_dialect in known_prefixes:
        if database_url.startswith(prefix):
            return type_and_dialect

    # Unrecognised schemes are treated as PostgreSQL.
    return "postgresql", "postgres"
258
+
259
+
260
def migrate_postgres_sql_dialect(session: Session) -> None:
    """Rename the ``postgresql`` SQL dialect to ``postgres`` on stored connectors.

    Args:
        session: Session used to load and update the connectors; the caller commits.
    """
    legacy_rows = select(DatasourceConnector).where(
        DatasourceConnector.sql_dialect == "postgresql"
    )

    for legacy_connector in session.scalars(legacy_rows):
        legacy_connector.sql_dialect = "postgres"
265
+
266
+
267
def seed_overview_widgets(session: Session) -> None:
    """Insert the default overview widgets and re-align system-owned ones.

    A missing widget is inserted from its default definition. An existing
    widget is only touched when its ``updated_by`` is ``"system"``: an empty
    ``sql`` is filled from the default, ``runtime_daily_queries`` is forced
    inactive, and ``position``, ``grid_width`` and ``result_mode`` are reset to
    the defaults.

    Args:
        session: Session used for lookups and pending changes; the caller commits.
    """
    _database_type, metadata_sql_dialect = infer_datasource_type(settings.gaard_metadata_database_url)

    # The ``::date`` cast is PostgreSQL-only syntax. It used to be emitted for
    # every non-SQLite dialect, which produced invalid SQL on MySQL; it is now
    # used for PostgreSQL only and ``DATE()`` everywhere else. The generated
    # text for sqlite and postgres is unchanged.
    # NOTE(review): the other default statements use ``key`` as a bare column
    # name and ``CAST(... AS INTEGER)``; these may also need MySQL-specific
    # forms -- confirm before relying on a MySQL metadata database.
    if metadata_sql_dialect == "postgres":
        day_expression = "occurred_at::date"
    else:
        day_expression = "DATE(occurred_at)"

    runtime_sql = (
        f"SELECT {day_expression} AS day, datasource_id, COUNT(*) AS query_count "
        "FROM data_query_audit_logs "
        f"GROUP BY {day_expression}, datasource_id "
        "ORDER BY day, datasource_id"
    )

    defaults = [
        {
            "widget_key": "prompts_count",
            "label": "Prompts",
            "widget_type": "scalar",
            "datasource_key": "metadata-db",
            "question": (
                "How many prompt templates are configured in GAARD metadata? "
                "Return exactly one numeric value."
            ),
            "sql": "SELECT COUNT(*) AS value FROM prompt_templates",
            "result_mode": "data",
            "position": 10,
            "grid_width": 1,
        },
        {
            "widget_key": "audit_retention",
            "label": "Audit retention",
            "widget_type": "scalar",
            "datasource_key": "metadata-db",
            "question": (
                "What is the value of the admin setting named "
                "data_query_audit_retention_days? Return exactly one numeric value."
            ),
            "sql": (
                "SELECT CAST(value AS INTEGER) AS value "
                "FROM admin_settings "
                "WHERE key = 'data_query_audit_retention_days'"
            ),
            "result_mode": "data",
            "position": 20,
            "grid_width": 1,
        },
        {
            "widget_key": "schema_cache_ttl",
            "label": "Schema cache TTL",
            "widget_type": "scalar",
            "datasource_key": "metadata-db",
            "question": (
                "What is the value of the admin setting named schema_cache_ttl_seconds? "
                "Return exactly one numeric value."
            ),
            "sql": (
                "SELECT CAST(value AS INTEGER) AS value "
                "FROM admin_settings "
                "WHERE key = 'schema_cache_ttl_seconds'"
            ),
            "result_mode": "data",
            "position": 30,
            "grid_width": 1,
        },
        {
            "widget_key": "license_edition",
            "label": "License",
            "widget_type": "scalar",
            "datasource_key": "metadata-db",
            "question": (
                "What is the value of the admin setting named license_edition? "
                "Return exactly one text value."
            ),
            "sql": "SELECT value AS value FROM admin_settings WHERE key = 'license_edition'",
            "result_mode": "data",
            "position": 40,
            "grid_width": 1,
        },
        {
            "widget_key": "runtime_daily_queries",
            "label": "Runtime",
            "widget_type": "timeseries",
            "datasource_key": "metadata-db",
            "question": (
                "For each day and datasource_id in data_query_audit_logs, how many "
                "query records exist? Return columns day, datasource_id, query_count "
                "ordered by day and datasource_id."
            ),
            "sql": runtime_sql,
            "result_mode": "data",
            "position": 100,
            "grid_width": 4,
            "active": False,
        },
        {
            "widget_key": "prompt_templates_table",
            "label": "Prompt templates",
            "widget_type": "table",
            "datasource_key": "metadata-db",
            "question": (
                "List configured prompt templates with prompt_key, name, version and "
                "active status, ordered by prompt_key."
            ),
            "sql": (
                "SELECT prompt_key, name, version, active "
                "FROM prompt_templates "
                "ORDER BY prompt_key"
            ),
            "result_mode": "data",
            "position": 50,
            "grid_width": 4,
        },
    ]

    for item in defaults:
        existing = session.scalar(
            select(OverviewWidget).where(OverviewWidget.widget_key == item["widget_key"])
        )

        if existing is None:
            session.add(OverviewWidget(**item))
            continue

        # Widgets whose updated_by is not "system" are never overwritten.
        if existing.updated_by != "system":
            continue

        if not existing.sql:
            existing.sql = str(item["sql"])

        if item["widget_key"] == "runtime_daily_queries":
            existing.active = False

        existing.position = int(item["position"])
        existing.grid_width = int(item["grid_width"])
        existing.result_mode = str(item["result_mode"])
402
+
403
+
404
def ensure_data_query_audit_schema(engine: Engine) -> None:
    """Add missing columns and indexes to an existing ``data_query_audit_logs`` table.

    Returns without doing anything when the table does not exist. Each missing
    column is added in its own transaction; afterwards every index declared on
    the ``DataQueryAuditLog`` table is created if it is not already present.

    Args:
        engine: Engine connected to the metadata database.
    """
    inspector = inspect(engine)

    if "data_query_audit_logs" not in inspector.get_table_names():
        return

    present_columns = {
        column_info["name"]
        for column_info in inspector.get_columns("data_query_audit_logs")
    }

    # (column name, statement that adds it), applied in this order.
    column_statements = (
        (
            "type",
            "ALTER TABLE data_query_audit_logs "
            "ADD COLUMN type VARCHAR(50) NOT NULL DEFAULT 'info'",
        ),
        (
            "output_classification",
            "ALTER TABLE data_query_audit_logs "
            "ADD COLUMN output_classification VARCHAR(50) "
            "NOT NULL DEFAULT 'unknown'",
        ),
    )

    for column_name, statement in column_statements:
        if column_name in present_columns:
            continue
        with engine.begin() as connection:
            connection.execute(text(statement))

    with engine.begin() as connection:
        for table_index in DataQueryAuditLog.__table__.indexes:
            table_index.create(bind=connection, checkfirst=True)
434
+
435
+
436
def backfill_data_query_audit_types(session: Session) -> None:
    """Move a legacy ``audit_type`` entry from ``metadata_json`` into ``type``.

    Only rows whose ``metadata_json`` text contains ``"audit_type"`` are
    loaded. A row is skipped when the JSON is invalid, is not an object, or
    holds an ``audit_type`` that cannot be coerced to ``DataQueryAuditType``.
    For the remaining rows the key is removed from the JSON and the coerced
    value is written to ``type``.

    Args:
        session: Session used to load and update the rows; the caller commits.
    """
    candidate_rows = select(DataQueryAuditLog).where(
        DataQueryAuditLog.metadata_json.like('%"audit_type"%')
    )

    for entry in session.scalars(candidate_rows):
        try:
            payload = json.loads(entry.metadata_json or "{}")
        except json.JSONDecodeError:
            continue

        if isinstance(payload, dict):
            promoted_type = coerce_legacy_data_query_audit_type(payload.get("audit_type"))

            if promoted_type is not None:
                payload.pop("audit_type", None)
                entry.type = promoted_type
                entry.metadata_json = json.dumps(payload, ensure_ascii=False, sort_keys=True)
460
+
461
+
462
def coerce_legacy_data_query_audit_type(value: object) -> DataQueryAuditType | None:
    """Convert a legacy audit-type value into a ``DataQueryAuditType`` member.

    Enum members pass through unchanged. Strings are stripped, lower-cased and
    have spaces and hyphens replaced by underscores before being matched
    against the enum values.

    Args:
        value: Raw value read from legacy metadata.

    Returns:
        The matching enum member, or ``None`` for non-string input or a string
        that matches no member value.
    """
    if isinstance(value, DataQueryAuditType):
        return value

    if isinstance(value, str):
        wanted = value.strip().lower().replace(" ", "_").replace("-", "_")
        for member in DataQueryAuditType:
            if member.value == wanted:
                return member

    return None
473
+
474
+
475
def ensure_overview_widget_schema(engine: Engine) -> None:
    """Add missing ``sql``, ``grid_width`` and ``result_mode`` columns to ``overview_widgets``.

    Returns without doing anything when the table does not exist. Each missing
    column is added in its own transaction.

    Args:
        engine: Engine connected to the metadata database.
    """
    inspector = inspect(engine)

    if "overview_widgets" not in inspector.get_table_names():
        return

    present_columns = {
        column_info["name"] for column_info in inspector.get_columns("overview_widgets")
    }

    # (column name, statement that adds it), applied in this order.
    column_statements = (
        ("sql", "ALTER TABLE overview_widgets ADD COLUMN sql TEXT DEFAULT ''"),
        ("grid_width", "ALTER TABLE overview_widgets ADD COLUMN grid_width INTEGER DEFAULT 1"),
        (
            "result_mode",
            "ALTER TABLE overview_widgets "
            "ADD COLUMN result_mode VARCHAR(50) DEFAULT 'data'",
        ),
    )

    for column_name, statement in column_statements:
        if column_name in present_columns:
            continue
        with engine.begin() as connection:
            connection.execute(text(statement))
503
+
504
+
505
def reset_metadata_store_for_tests() -> None:
    """Dispose the cached engine, if any, and clear the module-level cache.

    After this call the next ``get_engine()`` builds a fresh engine and
    session factory.
    """
    global _engine, _engine_url, _session_factory

    cached_engine = _engine
    if cached_engine is not None:
        cached_engine.dispose()

    _engine, _engine_url, _session_factory = None, None, None