dao-ai 0.1.5__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +446 -16
  7. dao_ai/config.py +1034 -103
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +5 -0
  23. dao_ai/middleware/tool_selector.py +129 -0
  24. dao_ai/models.py +327 -370
  25. dao_ai/nodes.py +4 -4
  26. dao_ai/orchestration/core.py +33 -9
  27. dao_ai/orchestration/supervisor.py +23 -8
  28. dao_ai/orchestration/swarm.py +6 -1
  29. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  30. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  31. dao_ai/prompts/instruction_reranker.yaml +14 -0
  32. dao_ai/prompts/router.yaml +37 -0
  33. dao_ai/prompts/verifier.yaml +46 -0
  34. dao_ai/providers/base.py +28 -2
  35. dao_ai/providers/databricks.py +352 -33
  36. dao_ai/state.py +1 -0
  37. dao_ai/tools/__init__.py +5 -3
  38. dao_ai/tools/genie.py +103 -26
  39. dao_ai/tools/instructed_retriever.py +366 -0
  40. dao_ai/tools/instruction_reranker.py +202 -0
  41. dao_ai/tools/mcp.py +539 -97
  42. dao_ai/tools/router.py +89 -0
  43. dao_ai/tools/slack.py +13 -2
  44. dao_ai/tools/sql.py +7 -3
  45. dao_ai/tools/unity_catalog.py +32 -10
  46. dao_ai/tools/vector_search.py +493 -160
  47. dao_ai/tools/verifier.py +159 -0
  48. dao_ai/utils.py +182 -2
  49. dao_ai/vector_search.py +9 -1
  50. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/METADATA +10 -8
  51. dao_ai-0.1.20.dist-info/RECORD +89 -0
  52. dao_ai/agent_as_code.py +0 -22
  53. dao_ai/genie/cache/semantic.py +0 -970
  54. dao_ai-0.1.5.dist-info/RECORD +0 -70
  55. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  56. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  57. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/context_aware/postgres.py (new file)
@@ -0,0 +1,1166 @@
"""
PostgreSQL pg_vector-based context-aware Genie cache implementation.

This module provides a context-aware cache that uses PostgreSQL with pg_vector
for semantic similarity search. It supports both standard PostgreSQL and
Databricks Lakebase connections via the DatabaseModel abstraction.

Features:
- Dual embedding matching (question + conversation context)
- pg_vector similarity search with L2 distance
- Prompt history tracking for conversation context
- TTL-based expiration with refresh-on-hit
- Space-partitioned cache entries
"""

from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any, Self

import mlflow
from databricks.sdk import WorkspaceClient
from databricks_ai_bridge.genie import GenieResponse
from loguru import logger

from dao_ai.config import (
    DatabaseModel,
    GenieContextAwareCacheParametersModel,
    WarehouseModel,
)
from dao_ai.genie.cache.base import (
    CacheResult,
    GenieServiceBase,
    SQLCacheEntry,
)
from dao_ai.genie.cache.context_aware.persistent import (
    DbRow,
    PersistentContextAwareGenieCacheService,
)


class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
    """
    PostgreSQL pg_vector-based context-aware caching decorator.

    This service caches the SQL query generated by Genie along with dual embeddings
    (question + conversation context) for high-precision semantic matching. On
    subsequent queries, it performs similarity search using pg_vector to find
    cached queries that match both the question intent AND conversation context.

    Supports both standard PostgreSQL and Databricks Lakebase via DatabaseModel.

    Cache entries are partitioned by genie_space_id to ensure queries from different
    Genie spaces don't return incorrect cache hits.

    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieContextAwareCacheParametersModel, DatabaseModel
        from dao_ai.genie.cache.context_aware import PostgresContextAwareGenieService

        cache_params = GenieContextAwareCacheParametersModel(
            database=database_model,
            warehouse=warehouse_model,
            embedding_model="databricks-gte-large-en",
            time_to_live_seconds=86400,  # 24 hours
            similarity_threshold=0.85
        )
        genie = PostgresContextAwareGenieService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: Uses connection pooling from psycopg_pool.
    """

    impl: GenieServiceBase
    parameters: GenieContextAwareCacheParametersModel
    _workspace_client: WorkspaceClient | None
    name: str
    _embeddings: Any  # DatabricksEmbeddings
    _pool: Any  # ConnectionPool
    _embedding_dims: int | None
    _setup_complete: bool

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieContextAwareCacheParametersModel,
        workspace_client: WorkspaceClient | None = None,
        name: str | None = None,
    ) -> None:
        """
        Initialize the PostgreSQL context-aware cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss.
                The space_id will be obtained from impl.space_id.
            parameters: Cache configuration including database, warehouse, embedding model
            workspace_client: Optional WorkspaceClient for retrieving conversation history.
                If None, conversation context will not be used.
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self._workspace_client = workspace_client
        self.name = name if name is not None else self.__class__.__name__
        self._embeddings = None
        self._pool = None
        self._embedding_dims = None
        self._setup_complete = False
        self._prompt_stored_for_current_request = False

    def _setup(self) -> None:
        """Initialize embeddings and database connection pool lazily."""
        if self._setup_complete:
            return

        from dao_ai.memory.postgres import PostgresPoolManager

        # Initialize embeddings using base class helper
        self._initialize_embeddings(
            self.parameters.embedding_model,
            self.parameters.embedding_dims,
        )

        # Get connection pool
        self._pool = PostgresPoolManager.get_pool(self.parameters.database)

        # Ensure table exists
        self._create_table_if_not_exists()

        self._setup_complete = True
        logger.debug(
            "PostgreSQL context-aware cache initialized",
            layer=self.name,
            space_id=self.space_id,
            table_name=self.table_name,
            dims=self._embedding_dims,
        )

    # Property implementations
    @property
    def database(self) -> DatabaseModel:
        """The database used for storing cache entries."""
        return self.parameters.database

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def time_to_live(self) -> timedelta | None:
        """Time-to-live for cache entries. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

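    # Example: time_to_live_seconds=86400 expires entries after 24 hours.
    # Setting it to None, or to any negative value such as -1, disables
    # expiration entirely (both here and in the SQL validity check in
    # _find_similar below).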
    @property
    def time_to_live_seconds(self) -> int | None:
        """TTL in seconds (None or negative = never expires)."""
        return self.parameters.time_to_live_seconds

    @property
    def similarity_threshold(self) -> float:
        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
        return self.parameters.similarity_threshold

    @property
    def context_similarity_threshold(self) -> float:
        """Minimum similarity for context matching."""
        return self.parameters.context_similarity_threshold

    @property
    def question_weight(self) -> float:
        """Weight for question similarity in combined score."""
        return self.parameters.question_weight

    @property
    def context_weight(self) -> float:
        """Weight for context similarity in combined score."""
        return self.parameters.context_weight

    @property
    def table_name(self) -> str:
        """Name of the cache table."""
        return self.parameters.table_name

    @property
    def prompt_history_table(self) -> str:
        """Name of the prompt history table."""
        return self.parameters.prompt_history_table

    @property
    def context_window_size(self) -> int:
        """Number of previous prompts to include in context."""
        return self.parameters.context_window_size

    @property
    def max_context_tokens(self) -> int:
        """Maximum tokens for context string."""
        return self.parameters.max_context_tokens

    @property
    def max_prompt_history_length(self) -> int:
        """Maximum number of prompts to keep per conversation."""
        return self.parameters.max_prompt_history_length

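    # The DDL below creates ivfflat indexes on both embedding columns. ivfflat
    # clusters vectors into a fixed number of lists (100 here) and probes only
    # a subset of them at query time, trading some recall for speed; pgvector's
    # general guidance is roughly lists = rows / 1000, with the index ideally
    # built once the table already holds representative data.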
    def _create_table_if_not_exists(self) -> None:
        """Create the cache table and prompt history table with pg_vector extension."""
        create_extension_sql: str = "CREATE EXTENSION IF NOT EXISTS vector"

        # Check if table exists and get current embedding dimensions
        check_dims_sql: str = """
            SELECT atttypmod
            FROM pg_attribute
            WHERE attrelid = %s::regclass
            AND attname = 'question_embedding'
        """

        create_table_sql: str = f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id SERIAL PRIMARY KEY,
                genie_space_id TEXT NOT NULL,
                question TEXT NOT NULL,
                conversation_context TEXT,
                context_string TEXT,
                question_embedding vector({self.embedding_dims}),
                context_embedding vector({self.embedding_dims}),
                sql_query TEXT NOT NULL,
                description TEXT,
                conversation_id TEXT,
                message_id TEXT,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
        """

        # Migration: Add message_id column if it doesn't exist
        add_message_id_sql: str = f"""
            ALTER TABLE {self.table_name}
            ADD COLUMN IF NOT EXISTS message_id TEXT
        """
        # Index for efficient similarity search partitioned by genie_space_id
        create_question_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
            ON {self.table_name}
            USING ivfflat (question_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        create_context_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
            ON {self.table_name}
            USING ivfflat (context_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        create_space_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
            ON {self.table_name} (genie_space_id)
        """
        create_unique_question_index_sql: str = f"""
            CREATE UNIQUE INDEX IF NOT EXISTS {self.table_name}_unique_question_idx
            ON {self.table_name} (genie_space_id, question)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(create_extension_sql)

                # Check if table exists and verify embedding dimensions
                try:
                    cur.execute(check_dims_sql, (self.table_name,))
                    row: DbRow | None = cur.fetchone()
                    if row is not None:
                        current_dims = row.get("atttypmod", 0)
                        if current_dims != self.embedding_dims:
                            logger.warning(
                                "Embedding dimension mismatch, dropping and recreating table",
                                layer=self.name,
                                table_dims=current_dims,
                                expected_dims=self.embedding_dims,
                                table_name=self.table_name,
                            )
                            cur.execute(f"DROP TABLE {self.table_name}")
                except Exception:
                    pass

                try:
                    cur.execute(create_table_sql)
                except Exception as e:
                    logger.debug(
                        f"Table creation skipped (may already exist): {e}",
                        layer=self.name,
                    )

                # Migration: Add message_id column if it doesn't exist (for existing tables)
                try:
                    cur.execute(add_message_id_sql)
                    logger.debug(
                        "Added message_id column (or already exists)",
                        layer=self.name,
                        table_name=self.table_name,
                    )
                except Exception as e:
                    # Column might already exist or other error
                    logger.debug(
                        f"message_id column migration skipped: {e}",
                        layer=self.name,
                    )

                # Create indexes
                for idx_name, idx_sql in [
                    (f"{self.table_name}_space_idx", create_space_index_sql),
                    (
                        f"{self.table_name}_question_embedding_idx",
                        create_question_embedding_index_sql,
                    ),
                    (
                        f"{self.table_name}_context_embedding_idx",
                        create_context_embedding_index_sql,
                    ),
                    (
                        f"{self.table_name}_unique_question_idx",
                        create_unique_question_index_sql,
                    ),
                ]:
                    if self._index_exists(cur, idx_name):
                        logger.debug(
                            f"Index {idx_name} already exists", layer=self.name
                        )
                        continue
                    try:
                        cur.execute(idx_sql)
                    except Exception as e:
                        logger.warning(
                            f"Could not create {idx_name}: {e}", layer=self.name
                        )

                # Create prompt history table
                try:
                    self._create_prompt_history_table(cur)
                except Exception as e:
                    logger.error(
                        f"Failed to create prompt history table: {e}",
                        layer=self.name,
                        exc_info=True,
                    )

    def _create_prompt_history_table(self, cur: Any) -> None:
        """Create the prompt history table for tracking user prompts."""
        prompt_table_name = self.prompt_history_table

        create_prompt_table_sql: str = f"""
            CREATE TABLE IF NOT EXISTS {prompt_table_name} (
                id SERIAL PRIMARY KEY,
                genie_space_id TEXT NOT NULL,
                conversation_id TEXT NOT NULL,
                prompt TEXT NOT NULL,
                cache_hit BOOLEAN DEFAULT FALSE,
                cache_entry_id INTEGER,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
        """

        # Migration: Add cache_entry_id column if it doesn't exist
        add_cache_entry_id_sql: str = f"""
            ALTER TABLE {prompt_table_name}
            ADD COLUMN IF NOT EXISTS cache_entry_id INTEGER
        """

        create_conversation_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {prompt_table_name}_conversation_idx
            ON {prompt_table_name} (genie_space_id, conversation_id, created_at DESC)
        """
        create_space_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {prompt_table_name}_space_idx
            ON {prompt_table_name} (genie_space_id, created_at DESC)
        """
        create_unique_prompt_index_sql: str = f"""
            CREATE UNIQUE INDEX IF NOT EXISTS {prompt_table_name}_unique_prompt_idx
            ON {prompt_table_name} (genie_space_id, conversation_id, prompt)
        """
        create_cache_entry_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {prompt_table_name}_cache_entry_idx
            ON {prompt_table_name} (cache_entry_id)
            WHERE cache_entry_id IS NOT NULL
        """

        try:
            cur.execute(create_prompt_table_sql)
        except Exception as e:
            if "duplicate key" in str(e) or "already exists" in str(e):
                logger.debug("Prompt history table already exists", layer=self.name)
            else:
                raise

        # Migration: Add cache_entry_id column if it doesn't exist (for existing tables)
        try:
            cur.execute(add_cache_entry_id_sql)
            logger.debug(
                "Added cache_entry_id column (or already exists)",
                layer=self.name,
                table_name=prompt_table_name,
            )
        except Exception as e:
            logger.debug(
                f"cache_entry_id column migration skipped: {e}",
                layer=self.name,
            )

        for idx_name, idx_sql in [
            (f"{prompt_table_name}_conversation_idx", create_conversation_index_sql),
            (f"{prompt_table_name}_space_idx", create_space_index_sql),
            (f"{prompt_table_name}_unique_prompt_idx", create_unique_prompt_index_sql),
            (f"{prompt_table_name}_cache_entry_idx", create_cache_entry_index_sql),
        ]:
            if self._index_exists(cur, idx_name):
                continue
            try:
                cur.execute(idx_sql)
            except Exception as e:
                if "duplicate key" not in str(e) and "already exists" not in str(e):
                    logger.warning(
                        f"Could not create index {idx_name}: {e}", layer=self.name
                    )

        logger.info(
            "Prompt history table ready", layer=self.name, table=prompt_table_name
        )

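    # Worked example of the scoring in _find_similar: pg_vector's <-> operator
    # returns L2 distance d, which is mapped to a (0, 1] similarity as
    # 1 / (1 + d), so d = 0.0 -> 1.0, d = 0.5 -> ~0.667, d = 1.0 -> 0.5.
    # With illustrative weights question_weight=0.7 and context_weight=0.3
    # (the real defaults live in GenieContextAwareCacheParametersModel),
    # question and context similarities of 0.90 and 0.80 combine to
    # 0.7 * 0.90 + 0.3 * 0.80 = 0.87. The per-component thresholds are still
    # checked separately, so a strong question match cannot mask a weak
    # context match, or vice versa.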
    @mlflow.trace(name="semantic_search_postgres")
    def _find_similar(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        conversation_id: str | None = None,
    ) -> tuple[SQLCacheEntry, float] | None:
        """Find a semantically similar cached entry using pg_vector."""
        ttl_seconds = self.time_to_live_seconds
        ttl_disabled = ttl_seconds is None or ttl_seconds < 0

        if ttl_disabled:
            is_valid_expr = "TRUE"
        else:
            is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"

        question_weight = self.question_weight
        context_weight = self.context_weight

        search_sql: str = f"""
            SELECT
                id,
                question,
                conversation_context,
                sql_query,
                description,
                conversation_id,
                message_id,
                created_at,
                1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
                1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
                ({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
                ({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
                {is_valid_expr} as is_valid
            FROM {self.table_name}
            WHERE genie_space_id = %s
            ORDER BY combined_similarity DESC
            LIMIT 1
        """

        question_emb_str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str = f"[{','.join(str(x) for x in context_embedding)}]"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    search_sql,
                    (
                        question_emb_str,
                        context_emb_str,
                        question_emb_str,
                        context_emb_str,
                        self.space_id,
                    ),
                )
                row: DbRow | None = cur.fetchone()

                if row is None:
                    logger.info(
                        "Cache MISS (no entries)",
                        layer=self.name,
                        question=question[:50],
                        space=self.space_id,
                    )
                    return None

                entry_id = row.get("id")
                cached_question = row.get("question", "")
                sql_query = row["sql_query"]
                description = row.get("description", "")
                conversation_id_cached = row.get("conversation_id", "")
                created_at = row["created_at"]
                question_similarity = row["question_similarity"]
                context_similarity = row["context_similarity"]
                combined_similarity = row["combined_similarity"]
                is_valid = row.get("is_valid", False)

                logger.debug(
                    "Best match found",
                    layer=self.name,
                    question_sim=f"{question_similarity:.4f}",
                    context_sim=f"{context_similarity:.4f}",
                    combined_sim=f"{combined_similarity:.4f}",
                    is_valid=is_valid,
                )

                if question_similarity < self.similarity_threshold:
                    logger.info(
                        "Cache MISS (question similarity too low)",
                        layer=self.name,
                        question_sim=f"{question_similarity:.4f}",
                        threshold=self.similarity_threshold,
                    )
                    return None

                if context_similarity < self.context_similarity_threshold:
                    logger.info(
                        "Cache MISS (context similarity too low)",
                        layer=self.name,
                        context_sim=f"{context_similarity:.4f}",
                        threshold=self.context_similarity_threshold,
                    )
                    return None

                if not is_valid:
                    cur.execute(
                        f"DELETE FROM {self.table_name} WHERE id = %s", (entry_id,)
                    )
                    logger.info("Cache MISS (expired, deleted)", layer=self.name)
                    return None

                cache_age_seconds = None
                if created_at:
                    cache_age_seconds = (
                        datetime.now(created_at.tzinfo) - created_at
                    ).total_seconds()

                logger.info(
                    "Cache HIT",
                    layer=self.name,
                    question=question[:80],
                    matched_question=cached_question[:80],
                    cache_age_seconds=round(cache_age_seconds, 1)
                    if cache_age_seconds is not None
                    else None,
                    question_similarity=f"{question_similarity:.4f}",
                    context_similarity=f"{context_similarity:.4f}",
                    combined_similarity=f"{combined_similarity:.4f}",
                )

                message_id_cached = row.get("message_id")

                entry = SQLCacheEntry(
                    query=sql_query,
                    description=description,
                    conversation_id=conversation_id_cached,
                    created_at=created_at,
                    message_id=message_id_cached,
                    cache_entry_id=entry_id,
                )
                return entry, combined_similarity

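    # pg_vector accepts vector input as a text literal of the form
    # '[0.12,0.34,...]' cast with ::vector, which is why the embedding lists
    # are serialized to strings before being bound as query parameters below.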
    def _store_entry(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        response: GenieResponse,
        message_id: str | None = None,
    ) -> None:
        """Store a new cache entry with dual embeddings and message_id."""
        insert_sql: str = f"""
            INSERT INTO {self.table_name}
            (genie_space_id, question, conversation_context, context_string,
             question_embedding, context_embedding, sql_query, description,
             conversation_id, message_id)
            VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s, %s)
        """
        question_emb_str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str = f"[{','.join(str(x) for x in context_embedding)}]"

        if conversation_context:
            full_context_string = f"{conversation_context}\nCurrent: {question}"
        else:
            full_context_string = question

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    insert_sql,
                    (
                        self.space_id,
                        question,
                        conversation_context,
                        full_context_string,
                        question_emb_str,
                        context_emb_str,
                        response.query,
                        response.description,
                        response.conversation_id,
                        message_id,
                    ),
                )
                logger.debug(
                    "Stored cache entry",
                    layer=self.name,
                    question=question[:50],
                    space=self.space_id,
                    message_id=message_id,
                )

    def _on_stale_cache_entry(self, question: str) -> None:
        """Delete stale cache entry from database."""
        delete_sql = (
            f"DELETE FROM {self.table_name} WHERE genie_space_id = %s AND question = %s"
        )
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id, question))
                deleted_rows = cur.rowcount
                logger.info(
                    "Deleted stale cache entry",
                    layer=self.name,
                    deleted_rows=deleted_rows,
                    space_id=self.space_id,
                )

    def _invalidate_by_question(self, question: str) -> bool:
        """
        Invalidate cache entries matching a specific question.

        This method is called when negative feedback is received to remove
        the corresponding cache entry from the PostgreSQL database.

        Args:
            question: The question text to match and invalidate

        Returns:
            True if an entry was found and invalidated, False otherwise
        """
        delete_sql = (
            f"DELETE FROM {self.table_name} WHERE genie_space_id = %s AND question = %s"
        )
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id, question))
                deleted_rows = cur.rowcount if isinstance(cur.rowcount, int) else 0
                if deleted_rows > 0:
                    logger.info(
                        "Invalidated cache entry by question",
                        layer=self.name,
                        question=question[:50],
                        deleted_rows=deleted_rows,
                        space_id=self.space_id,
                    )
                    return True
                return False

    # Template Method hook implementations

    def _before_cache_lookup(self, question: str, conversation_id: str | None) -> None:
        """Store prompt before cache lookup."""
        if conversation_id:
            self._store_user_prompt(
                prompt=question,
                conversation_id=conversation_id,
                cache_hit=False,
            )
            # Track that we stored the prompt
            self._prompt_stored_for_current_request = True
        else:
            self._prompt_stored_for_current_request = False

    def _after_cache_hit(
        self,
        question: str,
        conversation_id: str | None,
        result: CacheResult,
    ) -> None:
        """Update cache_hit flag and cache_entry_id after a cache hit."""
        if result.cache_hit and self._prompt_stored_for_current_request:
            actual_conv_id = result.response.conversation_id or conversation_id
            if actual_conv_id:
                self._update_prompt_cache_hit(
                    conversation_id=actual_conv_id,
                    prompt=question,
                    cache_hit=True,
                    cache_entry_id=result.cache_entry_id,
                )

    def _after_cache_miss(
        self,
        question: str,
        conversation_id: str | None,
        result: CacheResult,
    ) -> None:
        """Store prompt if not done earlier (when conversation_id comes from response)."""
        if (
            not self._prompt_stored_for_current_request
            and result.response.conversation_id
        ):
            self._store_user_prompt(
                prompt=question,
                conversation_id=result.response.conversation_id,
                cache_hit=False,
            )

    # Template Method implementations for invalidate_expired() and clear()

    def _get_empty_expiration_result(self) -> dict[str, int]:
        """Return empty dict for PostgresContextAwareGenieService."""
        return {"cache": 0, "prompt_history": 0}

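    # NOTE: the deletes below build intervals with make_interval(secs => %s) so
    # the TTL can be bound as an ordinary query parameter; psycopg placeholders
    # are not recognized inside quoted SQL literals such as INTERVAL '%s seconds'.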
    def _delete_expired_entries(self, ttl_seconds: int) -> dict[str, int]:
        """Delete expired entries from cache and prompt history."""
        prompt_ttl_seconds = self.parameters.prompt_history_ttl_seconds
        if prompt_ttl_seconds is None:
            prompt_ttl_seconds = ttl_seconds

        result: dict[str, int] = {"cache": 0, "prompt_history": 0}

        # Delete expired cache entries
        delete_cache_sql = f"""
            DELETE FROM {self.table_name}
            WHERE genie_space_id = %s
            AND created_at < NOW() - make_interval(secs => %s)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_cache_sql, (self.space_id, ttl_seconds))
                deleted = cur.rowcount if isinstance(cur.rowcount, int) else 0
                result["cache"] = deleted
                logger.debug(
                    "Deleted expired cache entries",
                    layer=self.name,
                    deleted_count=deleted,
                )

        # Delete expired prompt history
        if prompt_ttl_seconds is not None and prompt_ttl_seconds >= 0:
            try:
                delete_prompt_sql = f"""
                    DELETE FROM {self.prompt_history_table}
                    WHERE genie_space_id = %s
                    AND created_at < NOW() - make_interval(secs => %s)
                """
                with self._pool.connection() as conn:
                    with conn.cursor() as cur:
                        cur.execute(
                            delete_prompt_sql, (self.space_id, prompt_ttl_seconds)
                        )
                        deleted = cur.rowcount if isinstance(cur.rowcount, int) else 0
                        result["prompt_history"] = deleted
            except Exception as e:
                logger.warning(
                    f"Failed to clean up prompt history: {e}", layer=self.name
                )

        return result

    def _delete_all_entries(self) -> int:
        """Delete all cache entries for this Genie space."""
        delete_sql = f"DELETE FROM {self.table_name} WHERE genie_space_id = %s"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id,))
                deleted: int = cur.rowcount
                logger.debug(
                    "Cleared cache entries", layer=self.name, deleted_count=deleted
                )
                return deleted

    # Template Method implementations for stats()

    def _count_all_entries(self) -> int:
        """Count all cache entries for this Genie space."""
        count_sql = (
            f"SELECT COUNT(*) as total FROM {self.table_name} WHERE genie_space_id = %s"
        )
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(count_sql, (self.space_id,))
                row = cur.fetchone()
                return row.get("total", 0) if row else 0

    def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
        """Count total and expired entries for this Genie space."""
        stats_sql = f"""
            SELECT
                COUNT(*) as total,
                COUNT(*) FILTER (WHERE created_at <= NOW() - make_interval(secs => %s)) as expired
            FROM {self.table_name}
            WHERE genie_space_id = %s
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(stats_sql, (ttl_seconds, self.space_id))
                row = cur.fetchone()
                if row:
                    return row.get("total", 0), row.get("expired", 0)
                return 0, 0

    def _get_additional_stats(self) -> dict[str, Any]:
        """Add prompt history stats."""
        prompt_stats_sql = f"""
            SELECT
                COUNT(*) as total_prompts,
                COUNT(*) FILTER (WHERE cache_hit = true) as cache_hit_prompts,
                COUNT(*) FILTER (WHERE cache_hit = false) as cache_miss_prompts,
                COUNT(DISTINCT conversation_id) as total_conversations
            FROM {self.prompt_history_table}
            WHERE genie_space_id = %s
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(prompt_stats_sql, (self.space_id,))
                row = cur.fetchone()
                if row:
                    total_prompts = row.get("total_prompts", 0)
                    return {
                        "prompt_history": {
                            "total_prompts": total_prompts,
                            "cache_hit_prompts": row.get("cache_hit_prompts", 0),
                            "cache_miss_prompts": row.get("cache_miss_prompts", 0),
                            "total_conversations": row.get("total_conversations", 0),
                            "cache_hit_rate": (
                                row.get("cache_hit_prompts", 0) / total_prompts
                                if total_prompts > 0
                                else 0.0
                            ),
                        }
                    }
                return {}

    def from_space(
        self,
        space_id: str | None = None,
        *,
        include_all_messages: bool = True,
        from_datetime: datetime | None = None,
        to_datetime: datetime | None = None,
        max_messages: int | None = None,
    ) -> Self:
        """Populate cache from existing Genie space conversations.

        Fetches all conversations from a Genie space and populates:
        1. Prompt history table - all user messages
        2. Cache embeddings table - messages with SQL query attachments

        Uses ON CONFLICT DO NOTHING to avoid duplicate entries.

        Args:
            space_id: Genie space ID to import from (defaults to self.space_id)
            include_all_messages: If True, fetch all users' conversations
            from_datetime: Only include messages after this time
            to_datetime: Only include messages before this time
            max_messages: Limit to last N messages (most recent first)

        Returns:
            self for method chaining
        """
        if self.workspace_client is None:
            raise ValueError("workspace_client is required for from_space()")

        self._setup()
        target_space_id = space_id or self.space_id

        logger.info(
            "Starting from_space import",
            layer=self.name,
            space_id=target_space_id,
            include_all_messages=include_all_messages,
        )

        stats = {
            "conversations_processed": 0,
            "prompts_imported": 0,
            "prompts_skipped": 0,
            "cache_entries_imported": 0,
            "cache_entries_skipped": 0,
            "errors": 0,
        }

        from databricks.sdk.service.dashboards import GenieMessage

        all_messages: list[tuple[str, GenieMessage]] = []
        page_token: str | None = None

        while True:
            try:
                response = self.workspace_client.genie.list_conversations(
                    space_id=target_space_id,
                    include_all=include_all_messages,
                    page_token=page_token,
                )
            except Exception as e:
                logger.error(f"Failed to list conversations: {e}", layer=self.name)
                stats["errors"] += 1
                break

            if response.conversations is None:
                break

            for conversation in response.conversations:
                if conversation.conversation_id is None:
                    continue

                stats["conversations_processed"] += 1

                try:
                    messages_response = (
                        self.workspace_client.genie.list_conversation_messages(
                            space_id=target_space_id,
                            conversation_id=conversation.conversation_id,
                        )
                    )
                except Exception as e:
                    logger.warning(f"Failed to fetch messages: {e}", layer=self.name)
                    stats["errors"] += 1
                    continue

                if messages_response.messages is None:
                    continue

                for message in messages_response.messages:
                    all_messages.append((conversation.conversation_id, message))

                if max_messages and len(all_messages) >= max_messages:
                    break

            if max_messages and len(all_messages) >= max_messages:
                break

            page_token = response.next_page_token
            if page_token is None:
                break

        # Sort and limit
        all_messages.sort(
            key=lambda x: x[1].created_timestamp if x[1].created_timestamp else 0,
            reverse=True,
        )
        if max_messages:
            all_messages = all_messages[:max_messages]

        # Group messages by conversation_id for context building
        from collections import defaultdict

        messages_by_conversation: dict[str, list[tuple[str, GenieMessage]]] = (
            defaultdict(list)
        )
        for conv_id, msg in all_messages:
            messages_by_conversation[conv_id].append((conv_id, msg))

        # Sort each conversation's messages by timestamp (oldest first for context building)
        for conv_id in messages_by_conversation:
            messages_by_conversation[conv_id].sort(
                key=lambda x: x[1].created_timestamp if x[1].created_timestamp else 0
            )

        # Process messages
        for conversation_id, message in all_messages:
            if message.content is None:
                continue

            message_created_at = None
            if message.created_timestamp:
                message_created_at = datetime.fromtimestamp(
                    message.created_timestamp / 1000.0,
                    tz=from_datetime.tzinfo if from_datetime else None,
                )

            if message_created_at:
                if from_datetime and message_created_at < from_datetime:
                    continue
                if to_datetime and message_created_at > to_datetime:
                    continue

            # Store prompt
            prompt_stored = self._store_prompt_if_not_exists(
                prompt=message.content,
                conversation_id=conversation_id,
                space_id=target_space_id,
                cache_hit=False,
                created_at=message_created_at,
            )

            if prompt_stored:
                stats["prompts_imported"] += 1
            else:
                stats["prompts_skipped"] += 1

            # Check for SQL attachments
            if message.attachments:
                for attachment in message.attachments:
                    if attachment.query and attachment.query.query:
                        try:
                            # Build conversation context from prior messages
                            # Uses same "Previous: {content}" format as normal operations
                            prior_messages: list[str] = []
                            conv_messages = messages_by_conversation.get(
                                conversation_id, []
                            )
                            for _, prior_msg in conv_messages:
                                if (
                                    prior_msg.created_timestamp
                                    and message.created_timestamp
                                ):
                                    if (
                                        prior_msg.created_timestamp
                                        < message.created_timestamp
                                    ):
                                        if prior_msg.content:
                                            content = prior_msg.content
                                            if len(content) > 500:
                                                content = content[:500] + "..."
                                            prior_messages.append(
                                                f"Previous: {content}"
                                            )

                            # Limit to context_window_size (most recent N messages)
                            context_window = self.context_window_size
                            if len(prior_messages) > context_window:
                                prior_messages = prior_messages[-context_window:]

                            conversation_context = "\n".join(prior_messages)

                            # Generate embeddings
                            question_embedding = self._embeddings.embed_query(
                                message.content
                            )
                            if conversation_context:
                                context_embedding = self._embeddings.embed_query(
                                    conversation_context
                                )
                            else:
                                # Zero vector when no prior context (first message)
                                context_embedding = [0.0] * len(question_embedding)

                            cache_stored = self._store_cache_entry_if_not_exists(
                                question=message.content,
                                conversation_context=conversation_context,
                                question_embedding=question_embedding,
                                context_embedding=context_embedding,
                                sql_query=attachment.query.query,
                                description=attachment.query.description or "",
                                conversation_id=conversation_id,
                                space_id=target_space_id,
                            )

                            if cache_stored:
                                stats["cache_entries_imported"] += 1
                            else:
                                stats["cache_entries_skipped"] += 1
                        except Exception as e:
                            logger.warning(
                                f"Failed to generate embeddings: {e}", layer=self.name
                            )
                            stats["errors"] += 1

        logger.info("Completed from_space import", layer=self.name, **stats)
        return self

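    # The ON CONFLICT DO NOTHING clauses below target the unique indexes created
    # at setup: (genie_space_id, conversation_id, prompt) for prompt history and
    # (genie_space_id, question) for cache entries. When the conflict path is
    # taken, rowcount is 0, which is how these helpers report a skipped row.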
    def _store_prompt_if_not_exists(
        self,
        prompt: str,
        conversation_id: str,
        space_id: str | None = None,
        cache_hit: bool = False,
        created_at: datetime | None = None,
    ) -> bool:
        """Store prompt with ON CONFLICT DO NOTHING."""
        target_space_id = space_id or self.space_id
        prompt_table_name = self.prompt_history_table

        if created_at:
            insert_sql = f"""
                INSERT INTO {prompt_table_name}
                (genie_space_id, conversation_id, prompt, cache_hit, created_at)
                VALUES (%s, %s, %s, %s, %s)
                ON CONFLICT (genie_space_id, conversation_id, prompt) DO NOTHING
            """
            params = (target_space_id, conversation_id, prompt, cache_hit, created_at)
        else:
            insert_sql = f"""
                INSERT INTO {prompt_table_name}
                (genie_space_id, conversation_id, prompt, cache_hit)
                VALUES (%s, %s, %s, %s)
                ON CONFLICT (genie_space_id, conversation_id, prompt) DO NOTHING
            """
            params = (target_space_id, conversation_id, prompt, cache_hit)

        try:
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(insert_sql, params)
                    return cur.rowcount > 0 if isinstance(cur.rowcount, int) else False
        except Exception:
            return False

    def _store_cache_entry_if_not_exists(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        sql_query: str,
        description: str | None = None,
        conversation_id: str | None = None,
        space_id: str | None = None,
    ) -> bool:
        """Store cache entry with ON CONFLICT DO NOTHING."""
        target_space_id = space_id or self.space_id

        insert_sql = f"""
            INSERT INTO {self.table_name}
            (genie_space_id, question, conversation_context, context_string,
             question_embedding, context_embedding, sql_query, description, conversation_id)
            VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
            ON CONFLICT (genie_space_id, question) DO NOTHING
        """
        question_emb_str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str = f"[{','.join(str(x) for x in context_embedding)}]"
        full_context = (
            f"{conversation_context}\nCurrent: {question}"
            if conversation_context
            else question
        )

        try:
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        insert_sql,
                        (
                            target_space_id,
                            question,
                            conversation_context,
                            full_context,
                            question_emb_str,
                            context_emb_str,
                            sql_query,
                            description or "",
                            conversation_id or "",
                        ),
                    )
                    return cur.rowcount > 0 if isinstance(cur.rowcount, int) else False
        except Exception:
            return False
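
Taken together, the new service can be constructed and pre-warmed from an
existing space's conversation history. A minimal sketch, assuming the
GenieService/Genie wiring from the class docstring above and already-configured
database_model and warehouse_model objects (those two names are placeholders,
not values shipped with the package):

from databricks.sdk import WorkspaceClient
from dao_ai.config import GenieContextAwareCacheParametersModel
from dao_ai.genie.cache.context_aware import PostgresContextAwareGenieService

cache_params = GenieContextAwareCacheParametersModel(
    database=database_model,    # placeholder DatabaseModel
    warehouse=warehouse_model,  # placeholder WarehouseModel
    embedding_model="databricks-gte-large-en",
    time_to_live_seconds=86400,
    similarity_threshold=0.85,
)
genie = PostgresContextAwareGenieService(
    impl=GenieService(Genie(space_id="my-space")),
    parameters=cache_params,
    workspace_client=WorkspaceClient(),  # required by from_space()
)

# Import prior prompts and SQL-bearing messages, then drop anything
# already past its TTL.
genie.from_space(include_all_messages=True, max_messages=500)
genie.invalidate_expired()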