dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,970 @@
+ """
+ Semantic cache implementation for Genie SQL queries using PostgreSQL pgvector.
+
+ This module provides a semantic cache that uses embeddings and similarity search
+ to find cached queries that match the intent of new questions. Cache entries are
+ partitioned by genie_space_id to ensure proper isolation between Genie spaces.
+
+ The cache supports conversation-aware embedding using a rolling window approach
+ to capture context from recent conversation turns, improving accuracy for
+ multi-turn conversations with anaphoric references.
+ """
+
+ from datetime import timedelta
+ from typing import Any
+
+ import mlflow
+ import pandas as pd
+ from databricks.sdk import WorkspaceClient
+ from databricks.sdk.service.dashboards import (
+     GenieListConversationMessagesResponse,
+     GenieMessage,
+ )
+ from databricks.sdk.service.sql import StatementResponse, StatementState
+ from databricks_ai_bridge.genie import GenieResponse
+ from loguru import logger
+
+ from dao_ai.config import (
+     DatabaseModel,
+     GenieSemanticCacheParametersModel,
+     LLMModel,
+     WarehouseModel,
+ )
+ from dao_ai.genie.cache.base import (
+     CacheResult,
+     GenieServiceBase,
+     SQLCacheEntry,
+ )
+
+ # Type alias for database rows (dicts due to row_factory=dict_row)
+ DbRow = dict[str, Any]
+
+
+ def get_conversation_history(
+     workspace_client: WorkspaceClient,
+     space_id: str,
+     conversation_id: str,
+     max_messages: int = 10,
+ ) -> list[GenieMessage]:
+     """
+     Retrieve conversation history from Genie.
+
+     Args:
+         workspace_client: The Databricks workspace client
+         space_id: The Genie space ID
+         conversation_id: The conversation ID to retrieve
+         max_messages: Maximum number of messages to retrieve
+
+     Returns:
+         List of GenieMessage objects representing the conversation history
+     """
+     try:
+         # Use the Genie API to retrieve conversation messages
+         response: GenieListConversationMessagesResponse = (
+             workspace_client.genie.list_conversation_messages(
+                 space_id=space_id,
+                 conversation_id=conversation_id,
+             )
+         )
+
+         # Return the most recent messages, up to max_messages
+         if response.messages is not None:
+             all_messages: list[GenieMessage] = list(response.messages)
+             return (
+                 all_messages[-max_messages:]
+                 if len(all_messages) > max_messages
+                 else all_messages
+             )
+         return []
+     except Exception as e:
+         logger.warning(
+             "Failed to retrieve conversation history",
+             conversation_id=conversation_id,
+             error=str(e),
+         )
+         return []
+
+
+ def build_context_string(
+     question: str,
+     conversation_messages: list[GenieMessage],
+     window_size: int,
+     max_tokens: int = 2000,
+ ) -> str:
+     """
+     Build a context-aware question string using a rolling window.
+
+     This function creates a concatenated string that includes recent conversation
+     turns to provide context for semantic similarity matching.
+
+     Args:
+         question: The current question
+         conversation_messages: List of previous conversation messages
+         window_size: Number of previous turns to include
+         max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)
+
+     Returns:
+         Context-aware question string formatted for embedding
+     """
+     if window_size <= 0 or not conversation_messages:
+         return question
+
+     # Take the last window_size messages (most recent)
+     recent_messages = (
+         conversation_messages[-window_size:]
+         if len(conversation_messages) > window_size
+         else conversation_messages
+     )
+
+     # Build context parts
+     context_parts: list[str] = []
+
+     for msg in recent_messages:
+         # Only include messages with content from the history
+         if msg.content:
+             # Limit message length to prevent token overflow
+             content: str = msg.content
+             if len(content) > 500:  # Truncate very long messages
+                 content = content[:500] + "..."
+             context_parts.append(f"Previous: {content}")
+
+     # Add the current question
+     context_parts.append(f"Current: {question}")
+
+     # Join with newlines
+     context_string = "\n".join(context_parts)
+
+     # Rough token limit check (4 chars ≈ 1 token)
+     estimated_tokens = len(context_string) / 4
+     if estimated_tokens > max_tokens:
+         # Truncate to fit max_tokens
+         target_chars = max_tokens * 4
+         original_length = len(context_string)
+         context_string = context_string[:target_chars] + "..."
+         logger.trace(
+             "Truncated context string",
+             original_chars=original_length,
+             target_chars=target_chars,
+             max_tokens=max_tokens,
+         )
+
+     return context_string
+
+
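For a concrete picture of what this helper emits, here is a small sketch; the SimpleNamespace stand-ins only mimic the .content attribute of real GenieMessage objects:

    from types import SimpleNamespace

    from dao_ai.genie.cache.semantic import build_context_string

    history = [
        SimpleNamespace(content="Show me total sales by region"),
        SimpleNamespace(content="Filter to 2024 only"),
    ]
    print(build_context_string("Break that down by month", history, window_size=2))
    # Previous: Show me total sales by region
    # Previous: Filter to 2024 only
    # Current: Break that down by month
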
+ class SemanticCacheService(GenieServiceBase):
+     """
+     Semantic caching decorator that uses PostgreSQL pgvector for similarity lookup.
+
+     This service caches the SQL query generated by Genie along with an embedding
+     of the original question. On subsequent queries, it performs a semantic
+     similarity search to find cached queries that match the intent of the new
+     question.
+
+     Cache entries are partitioned by genie_space_id to ensure queries from different
+     Genie spaces don't return incorrect cache hits.
+
+     On a cache hit, it re-executes the cached SQL using the provided warehouse
+     to return fresh data while avoiding the Genie NL-to-SQL translation cost.
+
+     Example:
+         from dao_ai.config import GenieSemanticCacheParametersModel, DatabaseModel
+         from dao_ai.genie.cache import SemanticCacheService
+
+         cache_params = GenieSemanticCacheParametersModel(
+             database=database_model,
+             warehouse=warehouse_model,
+             embedding_model="databricks-gte-large-en",
+             time_to_live_seconds=86400,  # 24 hours
+             similarity_threshold=0.85
+         )
+         genie = SemanticCacheService(
+             impl=GenieService(Genie(space_id="my-space")),
+             parameters=cache_params
+         )
+
+     Thread-safe: uses connection pooling from psycopg_pool.
+     """
+
+     impl: GenieServiceBase
+     parameters: GenieSemanticCacheParametersModel
+     workspace_client: WorkspaceClient | None
+     name: str
+     _embeddings: Any  # DatabricksEmbeddings
+     _pool: Any  # ConnectionPool
+     _embedding_dims: int | None
+     _setup_complete: bool
+
+     def __init__(
+         self,
+         impl: GenieServiceBase,
+         parameters: GenieSemanticCacheParametersModel,
+         workspace_client: WorkspaceClient | None = None,
+         name: str | None = None,
+     ) -> None:
+         """
+         Initialize the semantic cache service.
+
+         Args:
+             impl: The underlying GenieServiceBase to delegate to on cache miss.
+                 The space_id will be obtained from impl.space_id.
+             parameters: Cache configuration including database, warehouse, and
+                 embedding model
+             workspace_client: Optional WorkspaceClient for retrieving conversation
+                 history. If None, conversation context will not be used.
+             name: Name for this cache layer (for logging). Defaults to class name.
+         """
+         self.impl = impl
+         self.parameters = parameters
+         self.workspace_client = workspace_client
+         self.name = name if name is not None else self.__class__.__name__
+         self._embeddings = None
+         self._pool = None
+         self._embedding_dims = None
+         self._setup_complete = False
+
+     def initialize(self) -> "SemanticCacheService":
+         """
+         Eagerly initialize the cache service.
+
+         Call this during tool creation to:
+         - Validate configuration early (fail fast)
+         - Create the database table before any requests
+         - Avoid first-request latency from lazy initialization
+
+         Returns:
+             self, for method chaining
+         """
+         self._setup()
+         return self
+
+     def _setup(self) -> None:
+         """Initialize embeddings and the database connection pool lazily."""
+         if self._setup_complete:
+             return
+
+         from dao_ai.memory.postgres import PostgresPoolManager
+
+         # Initialize embeddings
+         # Convert embedding_model to LLMModel if it's a string
+         embedding_model: LLMModel = (
+             LLMModel(name=self.parameters.embedding_model)
+             if isinstance(self.parameters.embedding_model, str)
+             else self.parameters.embedding_model
+         )
+         self._embeddings = embedding_model.as_embeddings_model()
+
+         # Auto-detect embedding dimensions if not provided
+         if self.parameters.embedding_dims is None:
+             sample_embedding: list[float] = self._embeddings.embed_query("test")
+             self._embedding_dims = len(sample_embedding)
+             logger.debug(
+                 "Auto-detected embedding dimensions",
+                 layer=self.name,
+                 dims=self._embedding_dims,
+             )
+         else:
+             self._embedding_dims = self.parameters.embedding_dims
+
+         # Get connection pool
+         self._pool = PostgresPoolManager.get_pool(self.parameters.database)
+
+         # Ensure table exists
+         self._create_table_if_not_exists()
+
+         self._setup_complete = True
+         logger.debug(
+             "Semantic cache initialized",
+             layer=self.name,
+             space_id=self.space_id,
+             table_name=self.table_name,
+             dims=self._embedding_dims,
+         )
+
+     @property
+     def database(self) -> DatabaseModel:
+         """The database used for storing cache entries."""
+         return self.parameters.database
+
+     @property
+     def warehouse(self) -> WarehouseModel:
+         """The warehouse used for executing cached SQL queries."""
+         return self.parameters.warehouse
+
+     @property
+     def time_to_live(self) -> timedelta | None:
+         """Time-to-live for cache entries. None means entries never expire."""
+         ttl = self.parameters.time_to_live_seconds
+         if ttl is None or ttl < 0:
+             return None
+         return timedelta(seconds=ttl)
+
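The TTL convention here is easy to misread, so a standalone restatement of the rule with hypothetical values (mirroring the property above rather than importing it):

    from datetime import timedelta

    def ttl_from_seconds(ttl_seconds: int | None) -> timedelta | None:
        # None or a negative value disables expiry entirely
        if ttl_seconds is None or ttl_seconds < 0:
            return None
        return timedelta(seconds=ttl_seconds)

    assert ttl_from_seconds(None) is None            # never expires
    assert ttl_from_seconds(-1) is None              # never expires
    assert ttl_from_seconds(86400) == timedelta(days=1)
    # Note: 0 is not "disabled"; it yields timedelta(0), so entries expire immediately.
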
+     @property
+     def similarity_threshold(self) -> float:
+         """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
+         return self.parameters.similarity_threshold
+
+     @property
+     def embedding_dims(self) -> int:
+         """Dimension size for embeddings (auto-detected if not configured)."""
+         if self._embedding_dims is None:
+             raise RuntimeError(
+                 "Embedding dimensions not yet initialized. Call _setup() first."
+             )
+         return self._embedding_dims
+
+     @property
+     def table_name(self) -> str:
+         """Name of the cache table."""
+         return self.parameters.table_name
+
+     def _create_table_if_not_exists(self) -> None:
+         """Create the cache table with the pgvector extension if it doesn't exist.
+
+         If the table exists but has a different embedding dimension, it will be
+         dropped and recreated with the new dimension size.
+         """
+         create_extension_sql: str = "CREATE EXTENSION IF NOT EXISTS vector"
+
+         # Check if the table exists and get its current embedding dimensions.
+         # to_regclass() returns NULL (no error) for a missing table, so a fresh
+         # database does not abort the transaction here.
+         check_dims_sql: str = """
+             SELECT atttypmod
+             FROM pg_attribute
+             WHERE attrelid = to_regclass(%s)
+             AND attname = 'question_embedding'
+         """
+
+         create_table_sql: str = f"""
+             CREATE TABLE IF NOT EXISTS {self.table_name} (
+                 id SERIAL PRIMARY KEY,
+                 genie_space_id TEXT NOT NULL,
+                 question TEXT NOT NULL,
+                 conversation_context TEXT,
+                 context_string TEXT,
+                 question_embedding vector({self.embedding_dims}),
+                 context_embedding vector({self.embedding_dims}),
+                 sql_query TEXT NOT NULL,
+                 description TEXT,
+                 conversation_id TEXT,
+                 created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+         # Indexes for efficient similarity search partitioned by genie_space_id.
+         # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings.
+         create_question_embedding_index_sql: str = f"""
+             CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
+             ON {self.table_name}
+             USING ivfflat (question_embedding vector_l2_ops)
+             WITH (lists = 100)
+         """
+         create_context_embedding_index_sql: str = f"""
+             CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
+             ON {self.table_name}
+             USING ivfflat (context_embedding vector_l2_ops)
+             WITH (lists = 100)
+         """
+         # Index for filtering by genie_space_id
+         create_space_index_sql: str = f"""
+             CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
+             ON {self.table_name} (genie_space_id)
+         """
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(create_extension_sql)
+
+                 # Check if the table exists and verify embedding dimensions
+                 try:
+                     cur.execute(check_dims_sql, (self.table_name,))
+                     row: DbRow | None = cur.fetchone()
+                     if row is not None:
+                         # atttypmod for the vector type contains the dimension
+                         current_dims = row.get("atttypmod", 0)
+                         if current_dims != self.embedding_dims:
+                             logger.warning(
+                                 "Embedding dimension mismatch, dropping and recreating table",
+                                 layer=self.name,
+                                 table_dims=current_dims,
+                                 expected_dims=self.embedding_dims,
+                                 table_name=self.table_name,
+                             )
+                             cur.execute(f"DROP TABLE {self.table_name}")
+                 except Exception:
+                     # Table doesn't exist (or the probe failed), which is fine
+                     pass
+
+                 cur.execute(create_table_sql)
+                 cur.execute(create_space_index_sql)
+                 cur.execute(create_question_embedding_index_sql)
+                 cur.execute(create_context_embedding_index_sql)
+
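If you want to verify the dimension probe by hand, a sketch using psycopg 3 with dict_row; the DSN and table name below are placeholders:

    import psycopg
    from psycopg.rows import dict_row

    with psycopg.connect("dbname=app", row_factory=dict_row) as conn:
        row = conn.execute(
            "SELECT atttypmod FROM pg_attribute "
            "WHERE attrelid = to_regclass(%s) AND attname = 'question_embedding'",
            ("genie_semantic_cache",),
        ).fetchone()
    # pgvector records the declared dimension in atttypmod,
    # e.g. {'atttypmod': 1024}; None means the table does not exist yet.
    print(row)
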
+     def _embed_question(
+         self, question: str, conversation_id: str | None = None
+     ) -> tuple[list[float], list[float], str]:
+         """
+         Generate dual embeddings: one for the question, one for the conversation context.
+
+         This enables separate matching of question similarity vs context similarity,
+         improving precision by ensuring both the question AND the conversation context
+         are semantically similar before returning a cached result.
+
+         Args:
+             question: The question to embed
+             conversation_id: Optional conversation ID for retrieving context
+
+         Returns:
+             Tuple of (question_embedding, context_embedding, conversation_context_string)
+             - question_embedding: Vector embedding of just the question
+             - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
+             - conversation_context_string: The conversation context string (empty if no context)
+         """
+         conversation_context = ""
+
+         # If conversation context is enabled and available
+         if (
+             self.workspace_client is not None
+             and conversation_id is not None
+             and self.parameters.context_window_size > 0
+         ):
+             try:
+                 # Retrieve conversation history
+                 conversation_messages = get_conversation_history(
+                     workspace_client=self.workspace_client,
+                     space_id=self.space_id,
+                     conversation_id=conversation_id,
+                     max_messages=self.parameters.context_window_size
+                     * 2,  # Get extra for safety
+                 )
+
+                 # Build context string (just the "Previous:" messages, not the current question)
+                 if conversation_messages:
+                     recent_messages = (
+                         conversation_messages[-self.parameters.context_window_size :]
+                         if len(conversation_messages)
+                         > self.parameters.context_window_size
+                         else conversation_messages
+                     )
+
+                     context_parts: list[str] = []
+                     for msg in recent_messages:
+                         if msg.content:
+                             content: str = msg.content
+                             if len(content) > 500:
+                                 content = content[:500] + "..."
+                             context_parts.append(f"Previous: {content}")
+
+                     conversation_context = "\n".join(context_parts)
+
+                     # Truncate if too long
+                     estimated_tokens = len(conversation_context) / 4
+                     if estimated_tokens > self.parameters.max_context_tokens:
+                         target_chars = self.parameters.max_context_tokens * 4
+                         conversation_context = (
+                             conversation_context[:target_chars] + "..."
+                         )
+
+                     logger.trace(
+                         "Using conversation context",
+                         layer=self.name,
+                         messages_count=len(conversation_messages),
+                         window_size=self.parameters.context_window_size,
+                     )
+             except Exception as e:
+                 logger.warning(
+                     "Failed to build conversation context, using question only",
+                     layer=self.name,
+                     error=str(e),
+                 )
+                 conversation_context = ""
+
+         # Generate dual embeddings
+         if conversation_context:
+             # Embed both question and context
+             embeddings: list[list[float]] = self._embeddings.embed_documents(
+                 [question, conversation_context]
+             )
+             question_embedding = embeddings[0]
+             context_embedding = embeddings[1]
+         else:
+             # Only embed question, use zero vector for context
+             embeddings = self._embeddings.embed_documents([question])
+             question_embedding = embeddings[0]
+             context_embedding = [0.0] * len(question_embedding)  # Zero vector
+
+         return question_embedding, context_embedding, conversation_context
+
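A note on the zero-vector fallback: under the 1 / (1 + L2 distance) mapping used in the lookup below, two zero vectors are identical, so the context check passes automatically when neither the query nor the cached entry carried any context. An illustrative sketch (not part of the package):

    import math

    def l2_similarity(a: list[float], b: list[float]) -> float:
        # pgvector's <-> operator returns L2 distance; 1/(1+d) maps it into (0, 1]
        return 1.0 / (1.0 + math.dist(a, b))

    no_context_query = [0.0] * 4
    no_context_cached = [0.0] * 4
    assert l2_similarity(no_context_query, no_context_cached) == 1.0
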
+     @mlflow.trace(name="semantic_search")
+     def _find_similar(
+         self,
+         question: str,
+         conversation_context: str,
+         question_embedding: list[float],
+         context_embedding: list[float],
+     ) -> tuple[SQLCacheEntry, float] | None:
+         """
+         Find a semantically similar cached entry using dual embedding matching.
+
+         This method matches BOTH the question AND the conversation context separately,
+         ensuring high precision by requiring both to be semantically similar.
+
+         Args:
+             question: The original question (for logging)
+             conversation_context: The conversation context string
+             question_embedding: The embedding vector of just the question
+             context_embedding: The embedding vector of the conversation context
+
+         Returns:
+             Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
+         """
+         # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings.
+         # pgvector's <-> operator returns L2 distance (0 = identical).
+         # Convert to similarity: 1 / (1 + distance) gives the range (0, 1].
+         #
+         # Dual embedding strategy:
+         # 1. Calculate separate similarities for question and context
+         # 2. BOTH must exceed their respective thresholds
+         # 3. Combined score is a weighted average
+         # 4. Refresh-on-hit: check TTL after the similarity check
+         ttl_seconds = self.parameters.time_to_live_seconds
+         ttl_disabled = ttl_seconds is None or ttl_seconds < 0
+
+         # When TTL is disabled, all entries are always valid
+         if ttl_disabled:
+             is_valid_expr = "TRUE"
+         else:
+             is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"
+
+         # Weighted combined similarity for ordering
+         question_weight: float = self.parameters.question_weight
+         context_weight: float = self.parameters.context_weight
+
+         search_sql: str = f"""
+             SELECT
+                 id,
+                 question,
+                 conversation_context,
+                 sql_query,
+                 description,
+                 conversation_id,
+                 created_at,
+                 1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
+                 1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
+                 ({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
+                 ({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
+                 {is_valid_expr} as is_valid
+             FROM {self.table_name}
+             WHERE genie_space_id = %s
+             ORDER BY combined_similarity DESC
+             LIMIT 1
+         """
+
+         question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
+         context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(
+                     search_sql,
+                     (
+                         question_emb_str,
+                         context_emb_str,
+                         question_emb_str,
+                         context_emb_str,
+                         self.space_id,
+                     ),
+                 )
+                 row: DbRow | None = cur.fetchone()
+
+                 if row is None:
+                     logger.info(
+                         "Cache MISS (no entries)",
+                         layer=self.name,
+                         question_prefix=question[:50],
+                         space=self.space_id,
+                     )
+                     return None
+
+                 # Extract values from the dict row
+                 entry_id: Any = row.get("id")
+                 cached_question: str = row.get("question", "")
+                 cached_context: str = row.get("conversation_context", "")
+                 sql_query: str = row["sql_query"]
+                 description: str = row.get("description", "")
+                 conversation_id_cached: str = row.get("conversation_id", "")
+                 created_at: Any = row["created_at"]
+                 question_similarity: float = row["question_similarity"]
+                 context_similarity: float = row["context_similarity"]
+                 combined_similarity: float = row["combined_similarity"]
+                 is_valid: bool = row.get("is_valid", False)
+
+                 # Log best-match info
+                 logger.debug(
+                     "Best match found",
+                     layer=self.name,
+                     question_sim=f"{question_similarity:.4f}",
+                     context_sim=f"{context_similarity:.4f}",
+                     combined_sim=f"{combined_similarity:.4f}",
+                     is_valid=is_valid,
+                     cached_question_prefix=cached_question[:50],
+                     cached_context_prefix=cached_context[:80],
+                 )
+
+                 # Check BOTH similarity thresholds (dual embedding precision check)
+                 if question_similarity < self.parameters.similarity_threshold:
+                     logger.info(
+                         "Cache MISS (question similarity too low)",
+                         layer=self.name,
+                         question_sim=f"{question_similarity:.4f}",
+                         threshold=self.parameters.similarity_threshold,
+                     )
+                     return None
+
+                 if context_similarity < self.parameters.context_similarity_threshold:
+                     logger.info(
+                         "Cache MISS (context similarity too low)",
+                         layer=self.name,
+                         context_sim=f"{context_similarity:.4f}",
+                         threshold=self.parameters.context_similarity_threshold,
+                     )
+                     return None
+
+                 # Check TTL - refresh-on-hit strategy
+                 if not is_valid:
+                     # Entry is expired - delete it and return a miss to trigger a refresh
+                     delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
+                     cur.execute(delete_sql, (entry_id,))
+                     logger.info(
+                         "Cache MISS (expired, deleted for refresh)",
+                         layer=self.name,
+                         combined_sim=f"{combined_similarity:.4f}",
+                         ttl_seconds=ttl_seconds,
+                         cached_question_prefix=cached_question[:50],
+                     )
+                     return None
+
+                 logger.info(
+                     "Cache HIT",
+                     layer=self.name,
+                     question_sim=f"{question_similarity:.4f}",
+                     context_sim=f"{context_similarity:.4f}",
+                     combined_sim=f"{combined_similarity:.4f}",
+                     cached_question_prefix=cached_question[:50],
+                 )
+
+                 entry = SQLCacheEntry(
+                     query=sql_query,
+                     description=description,
+                     conversation_id=conversation_id_cached,
+                     created_at=created_at,
+                 )
+                 return entry, combined_similarity
+
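To make the gating concrete, a worked example with hypothetical weights and thresholds (the real values come from GenieSemanticCacheParametersModel): both component similarities must clear their own thresholds; the weighted score only decides which row wins the LIMIT 1 ordering.

    question_weight, context_weight = 0.7, 0.3   # hypothetical configuration
    question_sim, context_sim = 0.90, 0.60
    combined = question_weight * question_sim + context_weight * context_sim
    print(f"{combined:.2f}")  # 0.81
    # Even at combined = 0.81 this row is still a MISS if, say,
    # context_similarity_threshold = 0.70, because 0.60 < 0.70.
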
+     def _store_entry(
+         self,
+         question: str,
+         conversation_context: str,
+         question_embedding: list[float],
+         context_embedding: list[float],
+         response: GenieResponse,
+     ) -> None:
+         """Store a new cache entry with dual embeddings for this Genie space."""
+         insert_sql: str = f"""
+             INSERT INTO {self.table_name}
+             (genie_space_id, question, conversation_context, context_string,
+              question_embedding, context_embedding, sql_query, description, conversation_id)
+             VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
+         """
+         question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
+         context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"
+
+         # Build full context string for backward compatibility (used in logging)
+         if conversation_context:
+             full_context_string = f"{conversation_context}\nCurrent: {question}"
+         else:
+             full_context_string = question
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(
+                     insert_sql,
+                     (
+                         self.space_id,
+                         question,
+                         conversation_context,
+                         full_context_string,
+                         question_emb_str,
+                         context_emb_str,
+                         response.query,
+                         response.description,
+                         response.conversation_id,
+                     ),
+                 )
+                 logger.info(
+                     "Stored cache entry",
+                     layer=self.name,
+                     question_prefix=question[:50],
+                     context_prefix=conversation_context[:80],
+                     sql_prefix=response.query[:50] if response.query else None,
+                     space=self.space_id,
+                     table=self.table_name,
+                 )
+
+     @mlflow.trace(name="execute_cached_sql_semantic")
+     def _execute_sql(self, sql: str) -> pd.DataFrame | str:
+         """Execute SQL using the warehouse and return results."""
+         client: WorkspaceClient = self.warehouse.workspace_client
+         warehouse_id: str = str(self.warehouse.warehouse_id)
+
+         statement_response: StatementResponse = (
+             client.statement_execution.execute_statement(
+                 warehouse_id=warehouse_id,
+                 statement=sql,
+                 wait_timeout="30s",
+             )
+         )
+
+         if (
+             statement_response.status is not None
+             and statement_response.status.state != StatementState.SUCCEEDED
+         ):
+             error_msg: str = (
+                 f"SQL execution failed: {statement_response.status.error.message}"
+                 if statement_response.status.error is not None
+                 else f"SQL execution failed with state: {statement_response.status.state}"
+             )
+             logger.error("SQL execution failed", layer=self.name, error=error_msg)
+             return error_msg
+
+         if statement_response.result and statement_response.result.data_array:
+             columns: list[str] = []
+             if (
+                 statement_response.manifest
+                 and statement_response.manifest.schema
+                 and statement_response.manifest.schema.columns
+             ):
+                 columns = [
+                     col.name
+                     for col in statement_response.manifest.schema.columns
+                     if col.name is not None
+                 ]
+
+             data: list[list[Any]] = statement_response.result.data_array
+             if columns:
+                 return pd.DataFrame(data, columns=columns)
+             else:
+                 return pd.DataFrame(data)
+
+         return pd.DataFrame()
+
+     def ask_question(
+         self, question: str, conversation_id: str | None = None
+     ) -> CacheResult:
+         """
+         Ask a question, using the semantic cache if a similar query exists.
+
+         On a cache hit, re-executes the cached SQL to get fresh data.
+         Returns a CacheResult with cache metadata.
+         """
+         return self.ask_question_with_cache_info(question, conversation_id)
+
+     @mlflow.trace(name="genie_semantic_cache_lookup")
+     def ask_question_with_cache_info(
+         self,
+         question: str,
+         conversation_id: str | None = None,
+     ) -> CacheResult:
+         """
+         Ask a question with detailed cache-hit information.
+
+         On a cache hit, the cached SQL is re-executed to return fresh data, but the
+         conversation_id returned is the current conversation_id (not the cached one).
+
+         Args:
+             question: The question to ask
+             conversation_id: Optional conversation ID for context and continuation
+
+         Returns:
+             CacheResult with a fresh response and cache metadata
+         """
+         # Ensure initialization (lazy init if initialize() wasn't called)
+         self._setup()
+
+         # Generate dual embeddings for the question and conversation context
+         question_embedding: list[float]
+         context_embedding: list[float]
+         conversation_context: str
+         question_embedding, context_embedding, conversation_context = (
+             self._embed_question(question, conversation_id)
+         )
+
+         # Check the cache using dual embedding similarity
+         cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
+             question, conversation_context, question_embedding, context_embedding
+         )
+
+         if cache_result is not None:
+             cached, combined_similarity = cache_result
+             logger.debug(
+                 "Semantic cache hit",
+                 layer=self.name,
+                 combined_similarity=f"{combined_similarity:.3f}",
+                 question_prefix=question[:50],
+             )
+
+             # Re-execute the cached SQL to get fresh data
+             fresh_data: pd.DataFrame | str = self._execute_sql(cached.query)
+
+             # IMPORTANT: Use the current conversation_id (from the request), not the
+             # cached one. This ensures the conversation continues properly.
+             response: GenieResponse = GenieResponse(
+                 result=fresh_data,
+                 query=cached.query,
+                 description=cached.description,
+                 conversation_id=conversation_id
+                 if conversation_id
+                 else cached.conversation_id,
+             )
+
+             return CacheResult(response=response, cache_hit=True, served_by=self.name)
+
+         # Cache miss - delegate to the wrapped service
+         logger.trace("Cache miss", layer=self.name, question_prefix=question[:50])
+
+         result: CacheResult = self.impl.ask_question(question, conversation_id)
+
+         # Store in the cache if we got a SQL query
+         if result.response.query:
+             logger.info(
+                 "Storing new cache entry",
+                 layer=self.name,
+                 question_prefix=question[:50],
+                 space=self.space_id,
+             )
+             self._store_entry(
+                 question,
+                 conversation_context,
+                 question_embedding,
+                 context_embedding,
+                 result.response,
+             )
+         else:
+             logger.warning(
+                 "Not caching: response has no SQL query",
+                 layer=self.name,
+                 question_prefix=question[:50],
+             )
+
+         return CacheResult(response=result.response, cache_hit=False, served_by=None)
+
+     @property
+     def space_id(self) -> str:
+         return self.impl.space_id
+
+     def invalidate_expired(self) -> int:
+         """Remove expired entries from the cache for this Genie space.
+
+         Returns 0 if TTL is disabled (entries never expire).
+         """
+         self._setup()
+         ttl_seconds = self.parameters.time_to_live_seconds
+
+         # If TTL is disabled, nothing can expire
+         if ttl_seconds is None or ttl_seconds < 0:
+             logger.trace(
+                 "TTL disabled, no entries to expire",
+                 layer=self.name,
+                 space=self.space_id,
+             )
+             return 0
+
+         # Bind the TTL via interval multiplication: a placeholder cannot sit
+         # inside a quoted literal such as INTERVAL '%s seconds'.
+         delete_sql: str = f"""
+             DELETE FROM {self.table_name}
+             WHERE genie_space_id = %s
+             AND created_at < NOW() - (%s * INTERVAL '1 second')
+         """
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(delete_sql, (self.space_id, ttl_seconds))
+                 deleted: int = cur.rowcount
+                 logger.trace(
+                     "Deleted expired entries",
+                     layer=self.name,
+                     deleted_count=deleted,
+                     space=self.space_id,
+                 )
+                 return deleted
+
+     def clear(self) -> int:
+         """Clear all entries from the cache for this Genie space."""
+         self._setup()
+         delete_sql: str = f"DELETE FROM {self.table_name} WHERE genie_space_id = %s"
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(delete_sql, (self.space_id,))
+                 deleted: int = cur.rowcount
+                 logger.debug(
+                     "Cleared cache entries",
+                     layer=self.name,
+                     deleted_count=deleted,
+                     space=self.space_id,
+                 )
+                 return deleted
+
+     @property
+     def size(self) -> int:
+         """Current number of entries in the cache for this Genie space."""
+         self._setup()
+         count_sql: str = (
+             f"SELECT COUNT(*) as count FROM {self.table_name} WHERE genie_space_id = %s"
+         )
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(count_sql, (self.space_id,))
+                 row: DbRow | None = cur.fetchone()
+                 return row.get("count", 0) if row else 0
+
+     def stats(self) -> dict[str, int | float | None]:
+         """Return cache statistics for this Genie space."""
+         self._setup()
+         ttl_seconds = self.parameters.time_to_live_seconds
+         ttl = self.time_to_live
+
+         # If TTL is disabled, all entries are valid
+         if ttl_seconds is None or ttl_seconds < 0:
+             count_sql: str = f"""
+                 SELECT COUNT(*) as total FROM {self.table_name}
+                 WHERE genie_space_id = %s
+             """
+             with self._pool.connection() as conn:
+                 with conn.cursor() as cur:
+                     cur.execute(count_sql, (self.space_id,))
+                     row: DbRow | None = cur.fetchone()
+                     total = row.get("total", 0) if row else 0
+                     return {
+                         "size": total,
+                         "ttl_seconds": None,
+                         "similarity_threshold": self.similarity_threshold,
+                         "expired_entries": 0,
+                         "valid_entries": total,
+                     }
+
+         # As above, bind the TTL through interval multiplication rather than a
+         # placeholder inside a quoted INTERVAL literal.
+         stats_sql: str = f"""
+             SELECT
+                 COUNT(*) as total,
+                 COUNT(*) FILTER (WHERE created_at > NOW() - (%s * INTERVAL '1 second')) as valid,
+                 COUNT(*) FILTER (WHERE created_at <= NOW() - (%s * INTERVAL '1 second')) as expired
+             FROM {self.table_name}
+             WHERE genie_space_id = %s
+         """
+
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.space_id))
+                 stats_row: DbRow | None = cur.fetchone()
+                 return {
+                     "size": stats_row.get("total", 0) if stats_row else 0,
+                     "ttl_seconds": ttl.total_seconds() if ttl else None,
+                     "similarity_threshold": self.similarity_threshold,
+                     "expired_entries": stats_row.get("expired", 0) if stats_row else 0,
+                     "valid_entries": stats_row.get("valid", 0) if stats_row else 0,
+                 }
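
Putting it together, a minimal end-to-end sketch; it assumes the GenieService/Genie wiring from the class docstring example (whose import paths are likewise not shown there) and a cache_params built as in that example:

    from dao_ai.genie.cache import SemanticCacheService

    cache = SemanticCacheService(
        impl=GenieService(Genie(space_id="my-space")),
        parameters=cache_params,
    ).initialize()  # fail fast: validate config and create the table up front

    first = cache.ask_question("What were sales last quarter?")   # miss: Genie runs, entry stored
    second = cache.ask_question("Show me last quarter's sales")   # near-duplicate: likely a hit
    print(second.cache_hit, second.served_by)
    print(cache.stats())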