dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,919 @@
1
+ """
2
+ Semantic cache implementation for Genie SQL queries using PostgreSQL pg_vector.
3
+
4
+ This module provides a semantic cache that uses embeddings and similarity search
5
+ to find cached queries that match the intent of new questions. Cache entries are
6
+ partitioned by genie_space_id to ensure proper isolation between Genie spaces.
7
+
8
+ The cache supports conversation-aware embedding using a rolling window approach
9
+ to capture context from recent conversation turns, improving accuracy for
10
+ multi-turn conversations with anaphoric references.
11
+ """
12
+
13
+ from datetime import timedelta
14
+ from typing import Any
15
+
16
+ import mlflow
17
+ import pandas as pd
18
+ from databricks.sdk import WorkspaceClient
19
+ from databricks.sdk.service.dashboards import (
20
+ GenieListConversationMessagesResponse,
21
+ GenieMessage,
22
+ )
23
+ from databricks.sdk.service.sql import StatementResponse, StatementState
24
+ from databricks_ai_bridge.genie import GenieResponse
25
+ from loguru import logger
26
+
27
+ from dao_ai.config import (
28
+ DatabaseModel,
29
+ GenieSemanticCacheParametersModel,
30
+ LLMModel,
31
+ WarehouseModel,
32
+ )
33
+ from dao_ai.genie.cache.base import (
34
+ CacheResult,
35
+ GenieServiceBase,
36
+ SQLCacheEntry,
37
+ )
38
+
39
# Type alias for database row (dict due to row_factory=dict_row)
# NOTE(review): assumes PostgresPoolManager configures psycopg with
# row_factory=dict_row so cursors yield dicts keyed by column name — confirm.
DbRow = dict[str, Any]
41
+
42
+
43
def get_conversation_history(
    workspace_client: WorkspaceClient,
    space_id: str,
    conversation_id: str,
    max_messages: int = 10,
) -> list[GenieMessage]:
    """
    Fetch the most recent messages of a Genie conversation.

    Args:
        workspace_client: The Databricks workspace client
        space_id: The Genie space ID
        conversation_id: The conversation ID to retrieve
        max_messages: Maximum number of messages to retrieve

    Returns:
        Up to ``max_messages`` of the newest GenieMessage objects, oldest
        first; an empty list when the API call fails or returns no messages.
    """
    try:
        response: GenieListConversationMessagesResponse = (
            workspace_client.genie.list_conversation_messages(
                space_id=space_id,
                conversation_id=conversation_id,
            )
        )
        if response.messages is None:
            return []
        # Keep only the tail of the conversation; a negative slice start
        # naturally returns the whole list when it is shorter than the window.
        messages: list[GenieMessage] = list(response.messages)
        return messages[-max_messages:]
    except Exception as e:
        # Best-effort: missing history degrades cache context, never the request.
        logger.warning(
            f"Failed to retrieve conversation history for conversation_id={conversation_id}: {e}"
        )
        return []
84
+
85
+
86
def build_context_string(
    question: str,
    conversation_messages: list[GenieMessage],
    window_size: int,
    max_tokens: int = 2000,
) -> str:
    """
    Build a context-aware question string using a rolling window.

    This function creates a concatenated string that includes recent conversation
    turns to provide context for semantic similarity matching.

    Args:
        question: The current question
        conversation_messages: List of previous conversation messages
        window_size: Number of previous turns to include
        max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)

    Returns:
        Context-aware question string formatted for embedding
    """
    if window_size <= 0 or not conversation_messages:
        return question

    # Take the last window_size messages (most recent)
    recent_messages = (
        conversation_messages[-window_size:]
        if len(conversation_messages) > window_size
        else conversation_messages
    )

    # Build context parts
    context_parts: list[str] = []

    for msg in recent_messages:
        # Only include messages with content from the history
        if msg.content:
            # Limit message length to prevent token overflow
            content: str = msg.content
            if len(content) > 500:  # Truncate very long messages
                content = content[:500] + "..."
            context_parts.append(f"Previous: {content}")

    # Add current question
    context_parts.append(f"Current: {question}")

    # Join with newlines
    context_string = "\n".join(context_parts)

    # Rough token limit check (4 chars ≈ 1 token)
    estimated_tokens = len(context_string) / 4
    if estimated_tokens > max_tokens:
        # BUGFIX: capture the length BEFORE truncating; the original logged
        # len(context_string) after reassignment, reporting the truncated
        # length as the "from" value.
        original_chars = len(context_string)
        target_chars = max_tokens * 4
        context_string = context_string[:target_chars] + "..."
        logger.debug(
            f"Truncated context string from {original_chars} to {target_chars} chars "
            f"(estimated {max_tokens} tokens)"
        )

    return context_string
147
+
148
+
149
class SemanticCacheService(GenieServiceBase):
    """
    Semantic caching decorator that uses PostgreSQL pg_vector for similarity lookup.

    This service caches the SQL query generated by Genie along with an embedding
    of the original question. On subsequent queries, it performs a semantic similarity
    search to find cached queries that match the intent of the new question.

    Cache entries are partitioned by genie_space_id to ensure queries from different
    Genie spaces don't return incorrect cache hits.

    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieSemanticCacheParametersModel, DatabaseModel
        from dao_ai.genie.cache import SemanticCacheService

        cache_params = GenieSemanticCacheParametersModel(
            database=database_model,
            warehouse=warehouse_model,
            embedding_model="databricks-gte-large-en",
            time_to_live_seconds=86400,  # 24 hours
            similarity_threshold=0.85
        )
        genie = SemanticCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: Uses connection pooling from psycopg_pool.
    """

    impl: GenieServiceBase
    parameters: GenieSemanticCacheParametersModel
    workspace_client: WorkspaceClient | None
    name: str
    _embeddings: Any  # DatabricksEmbeddings
    _pool: Any  # ConnectionPool
    _embedding_dims: int | None
    _setup_complete: bool

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieSemanticCacheParametersModel,
        workspace_client: WorkspaceClient | None = None,
        name: str | None = None,
    ) -> None:
        """
        Initialize the semantic cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss.
                The space_id will be obtained from impl.space_id.
            parameters: Cache configuration including database, warehouse, embedding model
            workspace_client: Optional WorkspaceClient for retrieving conversation history.
                If None, conversation context will not be used.
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.workspace_client = workspace_client
        self.name = name if name is not None else self.__class__.__name__
        # Heavy resources (embeddings client, connection pool) are created
        # lazily in _setup() so construction stays cheap and side-effect free.
        self._embeddings = None
        self._pool = None
        self._embedding_dims = None
        self._setup_complete = False

    def initialize(self) -> "SemanticCacheService":
        """
        Eagerly initialize the cache service.

        Call this during tool creation to:
        - Validate configuration early (fail fast)
        - Create the database table before any requests
        - Avoid first-request latency from lazy initialization

        Returns:
            self for method chaining
        """
        self._setup()
        return self

    def _setup(self) -> None:
        """Initialize embeddings and database connection pool lazily (idempotent)."""
        if self._setup_complete:
            return

        # Imported here to avoid a hard dependency at module import time.
        from dao_ai.memory.postgres import PostgresPoolManager

        # Initialize embeddings.
        # Convert embedding_model to LLMModel if it's a string.
        embedding_model: LLMModel = (
            LLMModel(name=self.parameters.embedding_model)
            if isinstance(self.parameters.embedding_model, str)
            else self.parameters.embedding_model
        )
        self._embeddings = embedding_model.as_embeddings_model()

        # Auto-detect embedding dimensions if not provided (one cheap probe call).
        if self.parameters.embedding_dims is None:
            sample_embedding: list[float] = self._embeddings.embed_query("test")
            self._embedding_dims = len(sample_embedding)
            logger.debug(
                f"[{self.name}] Auto-detected embedding dimensions: {self._embedding_dims}"
            )
        else:
            self._embedding_dims = self.parameters.embedding_dims

        # Get connection pool
        self._pool = PostgresPoolManager.get_pool(self.parameters.database)

        # Ensure table exists
        self._create_table_if_not_exists()

        self._setup_complete = True
        logger.debug(
            f"[{self.name}] Semantic cache initialized for space '{self.space_id}' "
            f"with table '{self.table_name}' (dims={self._embedding_dims})"
        )

    @property
    def database(self) -> DatabaseModel:
        """The database used for storing cache entries."""
        return self.parameters.database

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def time_to_live(self) -> timedelta | None:
        """Time-to-live for cache entries. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @property
    def similarity_threshold(self) -> float:
        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
        return self.parameters.similarity_threshold

    @property
    def embedding_dims(self) -> int:
        """Dimension size for embeddings (auto-detected if not configured)."""
        if self._embedding_dims is None:
            raise RuntimeError(
                "Embedding dimensions not yet initialized. Call _setup() first."
            )
        return self._embedding_dims

    @property
    def table_name(self) -> str:
        """Name of the cache table."""
        return self.parameters.table_name

    def _create_table_if_not_exists(self) -> None:
        """Create the cache table with pg_vector extension if it doesn't exist.

        If the table exists but has a different embedding dimension, it will be
        dropped and recreated with the new dimension size.
        """
        create_extension_sql: str = "CREATE EXTENSION IF NOT EXISTS vector"

        # Check whether the table exists and, if so, its current embedding
        # dimensions. BUGFIX: use to_regclass(%s) instead of %s::regclass —
        # the cast RAISES for a missing table, aborting the transaction so
        # the CREATE TABLE below would fail on first run; to_regclass simply
        # returns NULL and the query yields no rows.
        check_dims_sql: str = """
            SELECT atttypmod
            FROM pg_attribute
            WHERE attrelid = to_regclass(%s)
            AND attname = 'question_embedding'
        """

        create_table_sql: str = f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id SERIAL PRIMARY KEY,
                genie_space_id TEXT NOT NULL,
                question TEXT NOT NULL,
                conversation_context TEXT,
                context_string TEXT,
                question_embedding vector({self.embedding_dims}),
                context_embedding vector({self.embedding_dims}),
                sql_query TEXT NOT NULL,
                description TEXT,
                conversation_id TEXT,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
        """
        # Index for efficient similarity search partitioned by genie_space_id
        # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
        create_question_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
            ON {self.table_name}
            USING ivfflat (question_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        create_context_embedding_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
            ON {self.table_name}
            USING ivfflat (context_embedding vector_l2_ops)
            WITH (lists = 100)
        """
        # Index for filtering by genie_space_id
        create_space_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
            ON {self.table_name} (genie_space_id)
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(create_extension_sql)

                # Verify embedding dimensions of an existing table; drop and
                # recreate on mismatch (cache entries are rebuildable).
                cur.execute(check_dims_sql, (self.table_name,))
                row: DbRow | None = cur.fetchone()
                if row is not None:
                    # NOTE(review): for pgvector columns atttypmod holds the
                    # declared dimension directly — confirm against the
                    # installed pgvector version.
                    current_dims = row.get("atttypmod", 0)
                    if current_dims != self.embedding_dims:
                        logger.warning(
                            f"[{self.name}] Embedding dimension mismatch: "
                            f"table has {current_dims}, expected {self.embedding_dims}. "
                            f"Dropping and recreating table '{self.table_name}'."
                        )
                        cur.execute(f"DROP TABLE {self.table_name}")

                cur.execute(create_table_sql)
                cur.execute(create_space_index_sql)
                cur.execute(create_question_embedding_index_sql)
                cur.execute(create_context_embedding_index_sql)

    def _embed_question(
        self, question: str, conversation_id: str | None = None
    ) -> tuple[list[float], list[float], str]:
        """
        Generate dual embeddings: one for the question, one for the conversation context.

        This enables separate matching of question similarity vs context similarity,
        improving precision by ensuring both the question AND the conversation context
        are semantically similar before returning a cached result.

        Args:
            question: The question to embed
            conversation_id: Optional conversation ID for retrieving context

        Returns:
            Tuple of (question_embedding, context_embedding, conversation_context_string)
            - question_embedding: Vector embedding of just the question
            - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
            - conversation_context_string: The conversation context string (empty if no context)
        """
        conversation_context = ""

        # If conversation context is enabled and available
        if (
            self.workspace_client is not None
            and conversation_id is not None
            and self.parameters.context_window_size > 0
        ):
            try:
                # Retrieve conversation history
                conversation_messages = get_conversation_history(
                    workspace_client=self.workspace_client,
                    space_id=self.space_id,
                    conversation_id=conversation_id,
                    max_messages=self.parameters.context_window_size
                    * 2,  # Get extra for safety
                )

                # Build context string (just the "Previous:" messages, not the current question)
                if conversation_messages:
                    recent_messages = (
                        conversation_messages[-self.parameters.context_window_size :]
                        if len(conversation_messages)
                        > self.parameters.context_window_size
                        else conversation_messages
                    )

                    context_parts: list[str] = []
                    for msg in recent_messages:
                        if msg.content:
                            content: str = msg.content
                            if len(content) > 500:
                                content = content[:500] + "..."
                            context_parts.append(f"Previous: {content}")

                    conversation_context = "\n".join(context_parts)

                    # Truncate if too long (rough heuristic: 4 chars ≈ 1 token)
                    estimated_tokens = len(conversation_context) / 4
                    if estimated_tokens > self.parameters.max_context_tokens:
                        target_chars = self.parameters.max_context_tokens * 4
                        conversation_context = (
                            conversation_context[:target_chars] + "..."
                        )

                    logger.debug(
                        f"[{self.name}] Using conversation context: {len(conversation_messages)} messages "
                        f"(window_size={self.parameters.context_window_size})"
                    )
            except Exception as e:
                # Context is best-effort; fall back to question-only embedding.
                logger.warning(
                    f"[{self.name}] Failed to build conversation context, using question only: {e}"
                )
                conversation_context = ""

        # Generate dual embeddings
        if conversation_context:
            # Embed both question and context in a single batched call
            embeddings: list[list[float]] = self._embeddings.embed_documents(
                [question, conversation_context]
            )
            question_embedding = embeddings[0]
            context_embedding = embeddings[1]
        else:
            # Only embed question, use zero vector for context
            embeddings = self._embeddings.embed_documents([question])
            question_embedding = embeddings[0]
            context_embedding = [0.0] * len(question_embedding)  # Zero vector

        return question_embedding, context_embedding, conversation_context

    @mlflow.trace(name="semantic_search")
    def _find_similar(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
    ) -> tuple[SQLCacheEntry, float] | None:
        """
        Find a semantically similar cached entry using dual embedding matching.

        This method matches BOTH the question AND the conversation context separately,
        ensuring high precision by requiring both to be semantically similar.

        Args:
            question: The original question (for logging)
            conversation_context: The conversation context string
            question_embedding: The embedding vector of just the question
            context_embedding: The embedding vector of the conversation context

        Returns:
            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
        """
        # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
        # pg_vector's <-> operator returns L2 distance (0 = identical)
        # Convert to similarity: 1 / (1 + distance) gives range [0, 1]
        #
        # Dual embedding strategy:
        # 1. Calculate separate similarities for question and context
        # 2. BOTH must exceed their respective thresholds
        # 3. Combined score is weighted average
        # 4. Refresh-on-hit: check TTL after similarity check
        ttl_seconds = self.parameters.time_to_live_seconds
        ttl_disabled = ttl_seconds is None or ttl_seconds < 0

        # When TTL is disabled, all entries are always valid
        if ttl_disabled:
            is_valid_expr = "TRUE"
        else:
            is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"

        # Weighted combined similarity for ordering
        question_weight: float = self.parameters.question_weight
        context_weight: float = self.parameters.context_weight

        search_sql: str = f"""
            SELECT
                id,
                question,
                conversation_context,
                sql_query,
                description,
                conversation_id,
                created_at,
                1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
                1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
                ({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
                ({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
                {is_valid_expr} as is_valid
            FROM {self.table_name}
            WHERE genie_space_id = %s
            ORDER BY combined_similarity DESC
            LIMIT 1
        """

        # pgvector accepts the '[x,y,...]' text form cast to ::vector
        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    search_sql,
                    (
                        question_emb_str,
                        context_emb_str,
                        question_emb_str,
                        context_emb_str,
                        self.space_id,
                    ),
                )
                row: DbRow | None = cur.fetchone()

                if row is None:
                    logger.info(
                        f"[{self.name}] MISS (no entries): "
                        f"question='{question[:50]}...' space='{self.space_id}'"
                    )
                    return None

                # Extract values from dict row
                entry_id: Any = row.get("id")
                cached_question: str = row.get("question", "")
                cached_context: str = row.get("conversation_context", "")
                sql_query: str = row["sql_query"]
                description: str = row.get("description", "")
                conversation_id_cached: str = row.get("conversation_id", "")
                created_at: Any = row["created_at"]
                question_similarity: float = row["question_similarity"]
                context_similarity: float = row["context_similarity"]
                combined_similarity: float = row["combined_similarity"]
                is_valid: bool = row.get("is_valid", False)

                # Log best match info
                logger.info(
                    f"[{self.name}] Best match: "
                    f"question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
                    f"combined_sim={combined_similarity:.4f}, is_valid={is_valid}, "
                    f"question='{cached_question[:50]}...', context='{cached_context[:80]}...'"
                )

                # Check BOTH similarity thresholds (dual embedding precision check)
                if question_similarity < self.parameters.similarity_threshold:
                    logger.info(
                        f"[{self.name}] MISS (question similarity too low): "
                        f"question_sim={question_similarity:.4f} < threshold={self.parameters.similarity_threshold}"
                    )
                    return None

                if context_similarity < self.parameters.context_similarity_threshold:
                    logger.info(
                        f"[{self.name}] MISS (context similarity too low): "
                        f"context_sim={context_similarity:.4f} < threshold={self.parameters.context_similarity_threshold}"
                    )
                    return None

                # Check TTL - refresh on hit strategy
                if not is_valid:
                    # Entry is expired - delete it and return miss to trigger refresh
                    delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
                    cur.execute(delete_sql, (entry_id,))
                    logger.info(
                        f"[{self.name}] MISS (expired, deleted for refresh): "
                        f"combined_sim={combined_similarity:.4f}, ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
                    )
                    return None

                logger.info(
                    f"[{self.name}] HIT: question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
                    f"combined_sim={combined_similarity:.4f} (cached_question='{cached_question[:50]}...')"
                )

                entry = SQLCacheEntry(
                    query=sql_query,
                    description=description,
                    conversation_id=conversation_id_cached,
                    created_at=created_at,
                )
                return entry, combined_similarity

    def _store_entry(
        self,
        question: str,
        conversation_context: str,
        question_embedding: list[float],
        context_embedding: list[float],
        response: GenieResponse,
    ) -> None:
        """Store a new cache entry with dual embeddings for this Genie space."""
        insert_sql: str = f"""
            INSERT INTO {self.table_name}
            (genie_space_id, question, conversation_context, context_string,
             question_embedding, context_embedding, sql_query, description, conversation_id)
            VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
        """
        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"

        # Build full context string for backward compatibility (used in logging)
        if conversation_context:
            full_context_string = f"{conversation_context}\nCurrent: {question}"
        else:
            full_context_string = question

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    insert_sql,
                    (
                        self.space_id,
                        question,
                        conversation_context,
                        full_context_string,
                        question_emb_str,
                        context_emb_str,
                        response.query,
                        response.description,
                        response.conversation_id,
                    ),
                )
        logger.info(
            f"[{self.name}] Stored cache entry: question='{question[:50]}...' "
            f"context='{conversation_context[:80]}...' "
            f"sql='{response.query[:50]}...' (space={self.space_id}, table={self.table_name})"
        )

    @mlflow.trace(name="execute_cached_sql_semantic")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """Execute SQL using the warehouse and return results.

        Returns a DataFrame on success or an error-message string on failure.
        """
        client: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = str(self.warehouse.warehouse_id)

        statement_response: StatementResponse = (
            client.statement_execution.execute_statement(
                warehouse_id=warehouse_id,
                statement=sql,
                wait_timeout="30s",
            )
        )

        # NOTE(review): any non-SUCCEEDED state (including PENDING/RUNNING if
        # the 30s wait elapses) is treated as a failure here — confirm this is
        # the intended behavior for long-running statements.
        if (
            statement_response.status is not None
            and statement_response.status.state != StatementState.SUCCEEDED
        ):
            error_msg: str = (
                f"SQL execution failed: {statement_response.status.error.message}"
                if statement_response.status.error is not None
                else f"SQL execution failed with state: {statement_response.status.state}"
            )
            logger.error(f"[{self.name}] {error_msg}")
            return error_msg

        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            if (
                statement_response.manifest
                and statement_response.manifest.schema
                and statement_response.manifest.schema.columns
            ):
                columns = [
                    col.name
                    for col in statement_response.manifest.schema.columns
                    if col.name is not None
                ]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            else:
                return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> CacheResult:
        """
        Ask a question, using semantic cache if a similar query exists.

        On cache hit, re-executes the cached SQL to get fresh data.
        Returns CacheResult with cache metadata.
        """
        return self.ask_question_with_cache_info(question, conversation_id)

    @mlflow.trace(name="genie_semantic_cache_lookup")
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data, but the
        conversation_id returned is the current conversation_id (not the cached one).

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID for context and continuation

        Returns:
            CacheResult with fresh response and cache metadata
        """
        # Ensure initialization (lazy init if initialize() wasn't called)
        self._setup()

        # Generate dual embeddings for the question and conversation context
        question_embedding: list[float]
        context_embedding: list[float]
        conversation_context: str
        question_embedding, context_embedding, conversation_context = (
            self._embed_question(question, conversation_id)
        )

        # Check cache using dual embedding similarity
        cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
            question, conversation_context, question_embedding, context_embedding
        )

        if cache_result is not None:
            cached, combined_similarity = cache_result
            logger.debug(
                f"[{self.name}] Semantic cache hit (combined_similarity={combined_similarity:.3f}): {question[:50]}..."
            )

            # Re-execute the cached SQL to get fresh data
            fresh_data: pd.DataFrame | str = self._execute_sql(cached.query)

            # IMPORTANT: Use the current conversation_id (from the request), not the cached one
            # This ensures the conversation continues properly
            response: GenieResponse = GenieResponse(
                result=fresh_data,
                query=cached.query,
                description=cached.description,
                conversation_id=conversation_id
                if conversation_id
                else cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
        logger.debug(f"[{self.name}] Miss: {question[:50]}...")

        result: CacheResult = self.impl.ask_question(question, conversation_id)

        # Store in cache only if we got a SQL query
        if result.response.query:
            logger.info(
                f"[{self.name}] Storing new cache entry for question: '{question[:50]}...' "
                f"(space={self.space_id})"
            )
            self._store_entry(
                question,
                conversation_context,
                question_embedding,
                context_embedding,
                result.response,
            )
        else:
            # Redundant `elif not result.response.query` replaced with `else`
            logger.warning(
                f"[{self.name}] Not caching: response has no SQL query "
                f"(question='{question[:50]}...')"
            )

        return CacheResult(response=result.response, cache_hit=False, served_by=None)

    @property
    def space_id(self) -> str:
        """The Genie space ID of the wrapped service (partitions cache entries)."""
        return self.impl.space_id

    def invalidate_expired(self) -> int:
        """Remove expired entries from the cache for this Genie space.

        Returns 0 if TTL is disabled (entries never expire).
        """
        self._setup()
        ttl_seconds = self.parameters.time_to_live_seconds

        # If TTL is disabled, nothing can expire
        if ttl_seconds is None or ttl_seconds < 0:
            logger.debug(
                f"[{self.name}] TTL disabled, no entries to expire for space {self.space_id}"
            )
            return 0

        # BUGFIX: the TTL must be interpolated into the INTERVAL literal.
        # A `%s` placeholder inside the quoted string 'N seconds' is not a
        # bind parameter, so passing ttl_seconds as a query parameter fails.
        # This matches how _find_similar builds its TTL expression; ttl_seconds
        # is a numeric config value, not user input.
        delete_sql: str = f"""
            DELETE FROM {self.table_name}
            WHERE genie_space_id = %s
            AND created_at < NOW() - INTERVAL '{ttl_seconds} seconds'
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id,))
                deleted: int = cur.rowcount
                logger.debug(
                    f"[{self.name}] Deleted {deleted} expired entries for space {self.space_id}"
                )
                return deleted

    def clear(self) -> int:
        """Clear all entries from the cache for this Genie space."""
        self._setup()
        delete_sql: str = f"DELETE FROM {self.table_name} WHERE genie_space_id = %s"

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id,))
                deleted: int = cur.rowcount
                logger.debug(
                    f"[{self.name}] Cleared {deleted} entries for space {self.space_id}"
                )
                return deleted

    @property
    def size(self) -> int:
        """Current number of entries in the cache for this Genie space."""
        self._setup()
        count_sql: str = (
            f"SELECT COUNT(*) as count FROM {self.table_name} WHERE genie_space_id = %s"
        )

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(count_sql, (self.space_id,))
                row: DbRow | None = cur.fetchone()
                return row.get("count", 0) if row else 0

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics for this Genie space."""
        self._setup()
        ttl_seconds = self.parameters.time_to_live_seconds
        ttl = self.time_to_live

        # If TTL is disabled, all entries are valid
        if ttl_seconds is None or ttl_seconds < 0:
            count_sql: str = f"""
                SELECT COUNT(*) as total FROM {self.table_name}
                WHERE genie_space_id = %s
            """
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(count_sql, (self.space_id,))
                    row: DbRow | None = cur.fetchone()
                    total = row.get("total", 0) if row else 0
                    return {
                        "size": total,
                        "ttl_seconds": None,
                        "similarity_threshold": self.similarity_threshold,
                        "expired_entries": 0,
                        "valid_entries": total,
                    }

        # BUGFIX: interpolate the TTL into the INTERVAL literals (same issue
        # as invalidate_expired — `%s` inside a quoted literal is not a bind
        # parameter). ttl_seconds comes from validated config, not user input.
        stats_sql: str = f"""
            SELECT
                COUNT(*) as total,
                COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '{ttl_seconds} seconds') as valid,
                COUNT(*) FILTER (WHERE created_at <= NOW() - INTERVAL '{ttl_seconds} seconds') as expired
            FROM {self.table_name}
            WHERE genie_space_id = %s
        """

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(stats_sql, (self.space_id,))
                stats_row: DbRow | None = cur.fetchone()
                return {
                    "size": stats_row.get("total", 0) if stats_row else 0,
                    "ttl_seconds": ttl.total_seconds() if ttl else None,
                    "similarity_threshold": self.similarity_threshold,
                    "expired_entries": stats_row.get("expired", 0) if stats_row else 0,
                    "valid_entries": stats_row.get("valid", 0) if stats_row else 0,
                }