dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/context_aware/base.py (new file)
@@ -0,0 +1,1151 @@
+"""
+Abstract base class for context-aware Genie cache implementations.
+
+This module provides the foundational abstract base class for all context-aware
+cache implementations. It extracts common code for:
+- Dual embedding generation (question + conversation context)
+- Ask question flow with error handling and graceful fallback
+- SQL execution with retry logic
+- Common properties and initialization patterns
+
+Subclasses must implement storage-specific methods:
+- _find_similar(): Find semantically similar cached entry
+- _store_entry(): Store new cache entry
+- _setup(): Initialize resources (embeddings, storage)
+- invalidate_expired(): Remove expired entries
+- clear(): Clear all entries for space
+- stats(): Return cache statistics
+"""
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from datetime import timedelta
+from typing import Any, Self, TypeVar
+
+import mlflow
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.dashboards import (
+    GenieFeedbackRating,
+    GenieListConversationMessagesResponse,
+    GenieMessage,
+)
+from databricks_ai_bridge.genie import GenieResponse
+from loguru import logger
+
+from dao_ai.config import LLMModel, WarehouseModel
+from dao_ai.genie.cache.base import (
+    CacheResult,
+    GenieServiceBase,
+    SQLCacheEntry,
+)
+from dao_ai.genie.cache.core import execute_sql_via_warehouse
+
+# Type variable for subclass return types
+T = TypeVar("T", bound="ContextAwareGenieService")
+
+
+def get_conversation_history(
+    workspace_client: WorkspaceClient,
+    space_id: str,
+    conversation_id: str,
+    max_messages: int = 10,
+) -> list[GenieMessage]:
+    """
+    Retrieve conversation history from Genie.
+
+    Args:
+        workspace_client: The Databricks workspace client
+        space_id: The Genie space ID
+        conversation_id: The conversation ID to retrieve
+        max_messages: Maximum number of messages to retrieve
+
+    Returns:
+        List of GenieMessage objects representing the conversation history
+    """
+    try:
+        # Use the Genie API to retrieve conversation messages
+        response: GenieListConversationMessagesResponse = (
+            workspace_client.genie.list_conversation_messages(
+                space_id=space_id,
+                conversation_id=conversation_id,
+            )
+        )
+
+        # Return the most recent messages up to max_messages
+        if response.messages is not None:
+            all_messages: list[GenieMessage] = list(response.messages)
+            return (
+                all_messages[-max_messages:]
+                if len(all_messages) > max_messages
+                else all_messages
+            )
+        return []
+    except Exception as e:
+        logger.warning(
+            "Failed to retrieve conversation history",
+            conversation_id=conversation_id,
+            error=str(e),
+        )
+        return []
+
+
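For orientation, a minimal usage sketch of this helper (the IDs are placeholders, and WorkspaceClient is assumed to resolve credentials from the environment):

    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()
    history = get_conversation_history(
        workspace_client=w,
        space_id="<genie-space-id>",          # placeholder
        conversation_id="<conversation-id>",  # placeholder
        max_messages=5,
    )
    # Returns an empty list on failure or when the conversation is empty
    for msg in history:
        print(msg.content)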
+def build_context_string(
+    question: str,
+    conversation_messages: list[GenieMessage],
+    window_size: int,
+    max_tokens: int = 2000,
+) -> str:
+    """
+    Build a context-aware question string using a rolling window.
+
+    This function creates a concatenated string that includes recent conversation
+    turns to provide context for semantic similarity matching.
+
+    Args:
+        question: The current question
+        conversation_messages: List of previous conversation messages
+        window_size: Number of previous turns to include
+        max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)
+
+    Returns:
+        Context-aware question string formatted for embedding
+    """
+    if window_size <= 0 or not conversation_messages:
+        return question
+
+    # Take the last window_size messages (most recent)
+    recent_messages = (
+        conversation_messages[-window_size:]
+        if len(conversation_messages) > window_size
+        else conversation_messages
+    )
+
+    # Build context parts
+    context_parts: list[str] = []
+
+    for msg in recent_messages:
+        # Only include messages with content from the history
+        if msg.content:
+            # Limit message length to prevent token overflow
+            content: str = msg.content
+            if len(content) > 500:  # Truncate very long messages
+                content = content[:500] + "..."
+            context_parts.append(f"Previous: {content}")
+
+    # Add current question
+    context_parts.append(f"Current: {question}")
+
+    # Join with newlines
+    context_string = "\n".join(context_parts)
+
+    # Rough token limit check (4 chars ≈ 1 token)
+    estimated_tokens = len(context_string) / 4
+    if estimated_tokens > max_tokens:
+        # Truncate to fit max_tokens
+        target_chars = max_tokens * 4
+        original_length = len(context_string)
+        context_string = context_string[:target_chars] + "..."
+        logger.trace(
+            "Truncated context string",
+            original_chars=original_length,
+            target_chars=target_chars,
+            max_tokens=max_tokens,
+        )
+
+    return context_string
+
+
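As a worked illustration (message text invented): with window_size=2 and two prior messages, build_context_string returns a string of the form

    Previous: How many orders did we ship last month?
    Previous: Break that down by region
    Current: What about EMEA only?

Under the 4-characters-per-token heuristic, the default max_tokens=2000 caps the string at roughly 8,000 characters before the "..." truncation applies.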
+class ContextAwareGenieService(GenieServiceBase):
+    """
+    Abstract base class for context-aware Genie cache implementations.
+
+    This class provides shared implementation for:
+    - Dual embedding generation (question + conversation context)
+    - Main ask_question flow with error handling
+    - SQL execution with warehouse
+    - Common properties (time_to_live, similarity_threshold, etc.)
+
+    Subclasses must implement storage-specific methods for finding similar
+    entries, storing new entries, and managing cache lifecycle.
+
+    Error Handling:
+        All cache operations are wrapped in try/except to ensure graceful
+        degradation. If any cache operation fails, the request is delegated
+        to the underlying service without caching.
+
+    Thread Safety:
+        Subclasses are responsible for thread safety of storage operations.
+        This base class does not provide synchronization primitives.
+    """
+
+    # Common attributes - subclasses should define these
+    impl: GenieServiceBase
+    _workspace_client: WorkspaceClient | None
+    name: str
+    _embeddings: Any  # DatabricksEmbeddings
+    _embedding_dims: int | None
+    _setup_complete: bool
+
+    # Abstract methods that subclasses must implement
+    @abstractmethod
+    def _setup(self) -> None:
+        """
+        Initialize resources required by the cache implementation.
+
+        This method is called lazily before first use. Implementations should:
+        - Initialize embedding model
+        - Set up storage (database connection, in-memory structures, etc.)
+        - Create necessary tables/indexes if applicable
+
+        This method should be idempotent (safe to call multiple times).
+        """
+        pass
+
+    @abstractmethod
+    def _find_similar(
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+        conversation_id: str | None = None,
+    ) -> tuple[SQLCacheEntry, float] | None:
+        """
+        Find a semantically similar cached entry using dual embedding matching.
+
+        Args:
+            question: The original question (for logging)
+            conversation_context: The conversation context string
+            question_embedding: The embedding vector of just the question
+            context_embedding: The embedding vector of the conversation context
+            conversation_id: Optional conversation ID (for logging)
+
+        Returns:
+            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
+        """
+        pass
+
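The scoring policy behind combined_similarity_score is left to the storage backends (in_memory.py, persistent.py, postgres.py). Purely as a sketch of what a dual-embedding score could look like, assuming cosine similarity and an invented 70/30 question/context weighting:

    import math


    def cosine_similarity(a: list[float], b: list[float]) -> float:
        # A zero vector (no conversation context) contributes 0.0
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


    def combined_score(question_sim: float, context_sim: float) -> float:
        # Hypothetical blend; the actual subclasses may weight or
        # threshold differently (see similarity_threshold below)
        return 0.7 * question_sim + 0.3 * context_sim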
+    @abstractmethod
+    def _store_entry(
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+        response: GenieResponse,
+        message_id: str | None = None,
+    ) -> None:
+        """
+        Store a new cache entry with dual embeddings.
+
+        Args:
+            question: The user's question
+            conversation_context: Previous conversation context string
+            question_embedding: Embedding of the question
+            context_embedding: Embedding of the conversation context
+            response: The GenieResponse containing query, description, etc.
+            message_id: The Genie message ID from the original API response.
+                Stored with the cache entry to enable feedback on cache hits.
+        """
+        pass
+
+    def invalidate_expired(self) -> int | dict[str, int]:
+        """
+        Template method for removing expired entries from the cache.
+
+        This method implements the TTL check and delegates to
+        _delete_expired_entries() for the actual deletion.
+
+        Returns:
+            Number of entries deleted, or dict with counts by category
+        """
+        self._setup()
+        ttl_seconds = self.time_to_live_seconds
+
+        if ttl_seconds is None or ttl_seconds < 0:
+            return self._get_empty_expiration_result()
+
+        return self._delete_expired_entries(ttl_seconds)
+
+    @abstractmethod
+    def _delete_expired_entries(self, ttl_seconds: int) -> int | dict[str, int]:
+        """
+        Delete expired entries from storage.
+
+        Args:
+            ttl_seconds: TTL in seconds for determining expiration
+
+        Returns:
+            Number of entries deleted, or dict with counts by category
+        """
+        pass
+
+    def _get_empty_expiration_result(self) -> int | dict[str, int]:
+        """
+        Return the empty result for invalidate_expired when TTL is disabled.
+
+        Override this in subclasses whose invalidate_expired() returns a dict,
+        so that an appropriately shaped empty dict is returned.
+
+        Returns:
+            0 by default, or an empty dict for subclasses that return dicts
+        """
+        return 0
+
+    def clear(self) -> int:
+        """
+        Template method for clearing all entries from the cache.
+
+        This method calls _setup() and delegates to _delete_all_entries().
+
+        Returns:
+            Number of entries deleted
+        """
+        self._setup()
+        return self._delete_all_entries()
+
+    @abstractmethod
+    def _delete_all_entries(self) -> int:
+        """
+        Delete all entries for this Genie space from storage.
+
+        Returns:
+            Number of entries deleted
+        """
+        pass
+
+    def stats(self) -> dict[str, Any]:
+        """
+        Template method for returning cache statistics.
+
+        This method uses the Template Method pattern to consolidate the common
+        stats calculation algorithm. Subclasses provide counting implementations
+        via abstract methods and can add additional stats via hook methods.
+
+        Returns:
+            Dict with cache statistics (size, ttl, thresholds, etc.)
+        """
+        self._setup()
+        ttl_seconds = self.time_to_live_seconds
+        ttl = self.time_to_live
+
+        # Calculate base stats using abstract counting methods
+        if ttl_seconds is None or ttl_seconds < 0:
+            total = self._count_all_entries()
+            base_stats: dict[str, Any] = {
+                "size": total,
+                "ttl_seconds": None,
+                "similarity_threshold": self.similarity_threshold,
+                "expired_entries": 0,
+                "valid_entries": total,
+            }
+        else:
+            total, expired = self._count_entries_with_ttl(ttl_seconds)
+            base_stats = {
+                "size": total,
+                "ttl_seconds": ttl.total_seconds() if ttl else None,
+                "similarity_threshold": self.similarity_threshold,
+                "expired_entries": expired,
+                "valid_entries": total - expired,
+            }
+
+        # Add any additional stats from subclasses
+        additional_stats = self._get_additional_stats()
+        base_stats.update(additional_stats)
+
+        return base_stats
+
+    @abstractmethod
+    def _count_all_entries(self) -> int:
+        """
+        Count all cache entries for this Genie space.
+
+        Returns:
+            Total number of cache entries
+        """
+        pass
+
+    @abstractmethod
+    def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
+        """
+        Count total and expired entries for this Genie space.
+
+        Args:
+            ttl_seconds: TTL in seconds for determining expiration
+
+        Returns:
+            Tuple of (total_entries, expired_entries)
+        """
+        pass
+
+    def _get_additional_stats(self) -> dict[str, Any]:
+        """
+        Hook method to add additional stats from subclasses.
+
+        Override this method to add subclass-specific statistics like
+        capacity (in-memory) or prompt history stats (postgres).
+
+        Returns:
+            Dict with additional stats to merge into base stats
+        """
+        return {}
+
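For illustration, a subclass with a one-hour TTL might report stats shaped like this (values invented; the last key is the kind of extra that a _get_additional_stats() override can contribute):

    {
        "size": 42,
        "ttl_seconds": 3600.0,
        "similarity_threshold": 0.85,
        "expired_entries": 5,
        "valid_entries": 37,
        "capacity": 1000,  # hypothetical in-memory extra stat
    }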
+    # Properties that subclasses should implement or inherit
+    @property
+    @abstractmethod
+    def warehouse(self) -> WarehouseModel:
+        """The warehouse used for executing cached SQL queries."""
+        pass
+
+    @property
+    @abstractmethod
+    def time_to_live(self) -> timedelta | None:
+        """Time-to-live for cache entries. None means never expires."""
+        pass
+
+    @property
+    @abstractmethod
+    def similarity_threshold(self) -> float:
+        """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
+        pass
+
+    @property
+    def embedding_dims(self) -> int:
+        """Dimension size for embeddings (auto-detected if not configured)."""
+        if self._embedding_dims is None:
+            raise RuntimeError(
+                "Embedding dimensions not yet initialized. Call _setup() first."
+            )
+        return self._embedding_dims
+
+    @property
+    def space_id(self) -> str:
+        """The Genie space ID from the underlying service."""
+        return self.impl.space_id
+
+    @property
+    def workspace_client(self) -> WorkspaceClient | None:
+        """Get workspace client, delegating to impl if not set."""
+        if self._workspace_client is not None:
+            return self._workspace_client
+        return self.impl.workspace_client
+
+    @property
+    def time_to_live_seconds(self) -> int | None:
+        """TTL in seconds (None or negative = never expires)."""
+        ttl = self.time_to_live
+        if ttl is None:
+            return None
+        return int(ttl.total_seconds())
+
+    # Abstract method for embedding - subclasses must implement
+    @abstractmethod
+    def _embed_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> tuple[list[float], list[float], str]:
+        """
+        Generate dual embeddings for a question with conversation context.
+
+        Args:
+            question: The question to embed
+            conversation_id: Optional conversation ID for retrieving context
+
+        Returns:
+            Tuple of (question_embedding, context_embedding, conversation_context_string)
+        """
+        pass
+
+    # Shared implementation methods
+    def initialize(self) -> Self:
+        """
+        Eagerly initialize the cache service.
+
+        Call this during tool creation to:
+        - Validate configuration early (fail fast)
+        - Initialize resources before any requests
+        - Avoid first-request latency from lazy initialization
+
+        Returns:
+            self for method chaining
+        """
+        self._setup()
+        return self
+
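A usage sketch of the eager-initialization pattern (ConcreteCache stands in for one of the concrete subclasses in this package; constructor arguments are elided):

    # Fail fast at tool-creation time rather than on the first request
    service = ConcreteCache(...).initialize()

    # Chaining works because initialize() returns self
    result = service.ask_question("How many stores are open?")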
+    def _initialize_embeddings(
+        self,
+        embedding_model: str | LLMModel,
+        embedding_dims: int | None = None,
+    ) -> None:
+        """
+        Initialize the embeddings model and detect dimensions.
+
+        This helper method handles embedding model initialization for subclasses.
+
+        Args:
+            embedding_model: The embedding model name or LLMModel instance
+            embedding_dims: Optional pre-configured embedding dimensions
+        """
+        # Convert embedding_model to LLMModel if it's a string
+        model: LLMModel = (
+            LLMModel(name=embedding_model)
+            if isinstance(embedding_model, str)
+            else embedding_model
+        )
+        self._embeddings = model.as_embeddings_model()
+
+        # Auto-detect embedding dimensions if not provided
+        if embedding_dims is None:
+            sample_embedding: list[float] = self._embeddings.embed_query("test")
+            self._embedding_dims = len(sample_embedding)
+            logger.debug(
+                "Auto-detected embedding dimensions",
+                layer=self.name,
+                dims=self._embedding_dims,
+            )
+        else:
+            self._embedding_dims = embedding_dims
+
+    def _embed_question_with_genie_history(
+        self,
+        question: str,
+        conversation_id: str | None,
+        context_window_size: int,
+        max_context_tokens: int,
+    ) -> tuple[list[float], list[float], str]:
+        """
+        Generate dual embeddings using Genie API for conversation history.
+
+        This method retrieves conversation history from the Genie API and
+        generates dual embeddings for semantic matching.
+
+        Args:
+            question: The question to embed
+            conversation_id: Optional conversation ID for retrieving context
+            context_window_size: Number of previous messages to include
+            max_context_tokens: Maximum tokens for context string
+
+        Returns:
+            Tuple of (question_embedding, context_embedding, conversation_context_string)
+        """
+        conversation_context = ""
+
+        # If conversation context is enabled and available
+        if (
+            self.workspace_client is not None
+            and conversation_id is not None
+            and context_window_size > 0
+        ):
+            try:
+                # Retrieve conversation history from Genie API
+                conversation_messages = get_conversation_history(
+                    workspace_client=self.workspace_client,
+                    space_id=self.space_id,
+                    conversation_id=conversation_id,
+                    max_messages=context_window_size * 2,  # Get extra for safety
+                )
+
+                # Build context string
+                if conversation_messages:
+                    recent_messages = (
+                        conversation_messages[-context_window_size:]
+                        if len(conversation_messages) > context_window_size
+                        else conversation_messages
+                    )
+
+                    context_parts: list[str] = []
+                    for msg in recent_messages:
+                        if msg.content:
+                            content: str = msg.content
+                            if len(content) > 500:
+                                content = content[:500] + "..."
+                            context_parts.append(f"Previous: {content}")
+
+                    conversation_context = "\n".join(context_parts)
+
+                    # Truncate if too long
+                    estimated_tokens = len(conversation_context) / 4
+                    if estimated_tokens > max_context_tokens:
+                        target_chars = max_context_tokens * 4
+                        conversation_context = (
+                            conversation_context[:target_chars] + "..."
+                        )
+
+                    logger.trace(
+                        "Using conversation context from Genie API",
+                        layer=self.name,
+                        messages_count=len(conversation_messages),
+                        window_size=context_window_size,
+                    )
+            except Exception as e:
+                logger.warning(
+                    "Failed to build conversation context, using question only",
+                    layer=self.name,
+                    error=str(e),
+                )
+                conversation_context = ""
+
+        return self._generate_dual_embeddings(question, conversation_context)
+
+    def _generate_dual_embeddings(
+        self, question: str, conversation_context: str
+    ) -> tuple[list[float], list[float], str]:
+        """
+        Generate dual embeddings for question and conversation context.
+
+        Args:
+            question: The question to embed
+            conversation_context: The conversation context string
+
+        Returns:
+            Tuple of (question_embedding, context_embedding, conversation_context)
+        """
+        if conversation_context:
+            # Embed both question and context
+            embeddings: list[list[float]] = self._embeddings.embed_documents(
+                [question, conversation_context]
+            )
+            question_embedding = embeddings[0]
+            context_embedding = embeddings[1]
+        else:
+            # Only embed question, use zero vector for context
+            embeddings = self._embeddings.embed_documents([question])
+            question_embedding = embeddings[0]
+            context_embedding = [0.0] * len(question_embedding)  # Zero vector
+
+        return question_embedding, context_embedding, conversation_context
+
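To make the zero-vector fallback concrete, a toy stand-in for the embeddings model (real vectors come from the endpoint configured via LLMModel.as_embeddings_model(); FakeEmbeddings is invented for illustration):

    class FakeEmbeddings:
        def embed_documents(self, texts: list[str]) -> list[list[float]]:
            # Toy 3-dim vectors; production models return hundreds of dims
            return [[float(len(t)), 1.0, 0.0] for t in texts]

    # With history, question and context are embedded in one batched call
    q_emb, c_emb = FakeEmbeddings().embed_documents(
        ["What about EMEA?", "Previous: How many orders last month?"]
    )

    # With no history, the context embedding is all zeros of matching size,
    # so context similarity contributes nothing to the match score
    zero_context = [0.0] * len(q_emb)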
+    @mlflow.trace(name="execute_cached_sql")
+    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
+        """
+        Execute SQL using the warehouse and return results.
+
+        Args:
+            sql: The SQL query to execute
+
+        Returns:
+            DataFrame with results, or error message string if execution failed
+        """
+        return execute_sql_via_warehouse(
+            warehouse=self.warehouse,
+            sql=sql,
+            layer_name=self.name,
+        )
+
+    def _build_cache_hit_response(
+        self,
+        cached: SQLCacheEntry,
+        result: pd.DataFrame,
+        conversation_id: str | None,
+    ) -> CacheResult:
+        """
+        Build a CacheResult for a cache hit.
+
+        Args:
+            cached: The cached SQL entry
+            result: The fresh DataFrame from SQL execution
+            conversation_id: The current conversation ID
+
+        Returns:
+            CacheResult with cache_hit=True, including message_id and cache_entry_id
+            from the original cached entry for traceability and feedback support.
+        """
+        # IMPORTANT: Use the current conversation_id (from the request), not the
+        # cached one. This ensures the conversation continues properly.
+        response = GenieResponse(
+            result=result,
+            query=cached.query,
+            description=cached.description,
+            conversation_id=conversation_id
+            if conversation_id
+            else cached.conversation_id,
+        )
+        # Cache hit - include message_id from original response for feedback support
+        # and cache_entry_id for traceability to genie_prompt_history
+        return CacheResult(
+            response=response,
+            cache_hit=True,
+            served_by=self.name,
+            message_id=cached.message_id,
+            cache_entry_id=cached.cache_entry_id,
+        )
+
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> CacheResult:
+        """
+        Ask a question, using the semantic cache if a similar query exists.
+
+        On cache hit, re-executes the cached SQL to get fresh data.
+        Returns CacheResult with cache metadata.
+
+        This method wraps ask_question_with_cache_info with error handling
+        to ensure graceful degradation on cache failures.
+
+        Args:
+            question: The question to ask
+            conversation_id: Optional conversation ID for context
+
+        Returns:
+            CacheResult with fresh response and cache metadata
+        """
+        try:
+            return self.ask_question_with_cache_info(question, conversation_id)
+        except Exception as e:
+            # Log with traceback; loguru attaches it via opt(exception=True)
+            # rather than a stdlib-style exc_info kwarg
+            logger.opt(exception=True).warning(
+                "Context-aware cache operation failed, delegating to underlying service",
+                layer=self.name,
+                error=str(e),
+            )
+            # Graceful degradation: fall back to underlying service
+            return self.impl.ask_question(question, conversation_id)
+
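A usage sketch of the cached ask flow (service is any concrete subclass; the fields read below follow CacheResult as used in this file):

    result = service.ask_question(
        "What were sales by region last quarter?",
        conversation_id="<conversation-id>",  # placeholder
    )
    print(result.cache_hit)        # True if served from the semantic cache
    print(result.served_by)        # cache layer name on a hit, else None
    print(result.response.query)   # the SQL behind the answer
    print(result.response.result)  # fresh DataFrame (SQL is re-run on hits)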
+    def ask_question_with_cache_info(
+        self,
+        question: str,
+        conversation_id: str | None = None,
+    ) -> CacheResult:
+        """
+        Template method for asking a question with cache lookup.
+
+        This method implements the cache lookup algorithm using the Template Method
+        pattern. Subclasses can customize behavior by overriding hook methods:
+        - _before_cache_lookup(): Called before cache search (e.g., store prompt)
+        - _after_cache_hit(): Called after a cache hit (e.g., update prompt flags)
+        - _after_cache_miss(): Called after a cache miss (e.g., store prompt)
+
+        Args:
+            question: The question to ask
+            conversation_id: Optional conversation ID for context and continuation
+
+        Returns:
+            CacheResult with fresh response and cache metadata
+        """
+        self._setup()
+
+        # Step 1: Generate dual embeddings
+        question_embedding, context_embedding, conversation_context = (
+            self._embed_question(question, conversation_id)
+        )
+
+        # Step 2: Hook for pre-lookup actions (e.g., store prompt in history)
+        self._before_cache_lookup(question, conversation_id)
+
+        # Step 3: Search for similar cached entry
+        cache_result = self._find_similar(
+            question,
+            conversation_context,
+            question_embedding,
+            context_embedding,
+            conversation_id,
+        )
+
+        # Step 4: Handle cache hit or miss
+        if cache_result is not None:
+            cached, combined_similarity = cache_result
+
+            result = self._handle_cache_hit(
+                question,
+                conversation_id,
+                cached,
+                combined_similarity,
+                conversation_context,
+                question_embedding,
+                context_embedding,
+            )
+
+            # Hook for post-cache-hit actions (e.g., update prompt cache_hit flag)
+            self._after_cache_hit(question, conversation_id, result)
+
+            return result
+
+        # Handle cache miss
+        result = self._handle_cache_miss(
+            question,
+            conversation_id,
+            conversation_context,
+            question_embedding,
+            context_embedding,
+        )
+
+        # Hook for post-cache-miss actions (e.g., store prompt if not done earlier)
+        self._after_cache_miss(question, conversation_id, result)
+
+        return result
+
+    def _before_cache_lookup(self, question: str, conversation_id: str | None) -> None:
+        """
+        Hook method called before cache lookup.
+
+        Override this method to perform actions before searching the cache,
+        such as storing the prompt in history.
+
+        Args:
+            question: The question being asked
+            conversation_id: Optional conversation ID
+        """
+        pass
+
+    def _after_cache_hit(
+        self,
+        question: str,
+        conversation_id: str | None,
+        result: CacheResult,
+    ) -> None:
+        """
+        Hook method called after a cache hit.
+
+        Override this method to perform actions after a successful cache hit,
+        such as updating prompt history flags.
+
+        Args:
+            question: The question that was asked
+            conversation_id: Optional conversation ID
+            result: The cache result
+        """
+        pass
+
+    def _after_cache_miss(
+        self,
+        question: str,
+        conversation_id: str | None,
+        result: CacheResult,
+    ) -> None:
+        """
+        Hook method called after a cache miss.
+
+        Override this method to perform actions after a cache miss,
+        such as storing prompt history if not done earlier.
+
+        Args:
+            question: The question that was asked
+            conversation_id: Optional conversation ID
+            result: The cache result
+        """
+        pass
+
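A sketch of the hook mechanism: a hypothetical auditing subclass (invented for illustration; the storage-specific abstract methods are omitted, so the class stays abstract):

    class AuditingCache(ContextAwareGenieService):
        """Sketch only: _setup, _find_similar, _store_entry, etc. omitted."""

        def _before_cache_lookup(
            self, question: str, conversation_id: str | None
        ) -> None:
            # Hypothetical side effect: record every prompt before lookup
            print(f"prompt: {question!r} (conversation={conversation_id})")

        def _after_cache_hit(
            self, question: str, conversation_id: str | None, result: CacheResult
        ) -> None:
            # Hypothetical side effect: note which stored entry answered it
            print(f"hit: entry={result.cache_entry_id} served_by={result.served_by}")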
+    def _handle_cache_hit(
+        self,
+        question: str,
+        conversation_id: str | None,
+        cached: SQLCacheEntry,
+        combined_similarity: float,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+    ) -> CacheResult:
+        """
+        Handle a cache hit - execute cached SQL and return response.
+
+        This method handles the common cache hit logic including SQL execution,
+        stale cache fallback, and response building.
+
+        Args:
+            question: The original question
+            conversation_id: The conversation ID
+            cached: The cached SQL entry
+            combined_similarity: The similarity score
+            conversation_context: The conversation context string
+            question_embedding: The question embedding
+            context_embedding: The context embedding
+
+        Returns:
+            CacheResult with the response
+        """
+        logger.debug(
+            "Cache hit",
+            layer=self.name,
+            combined_similarity=f"{combined_similarity:.3f}",
+            question=question[:50],
+            conversation_id=conversation_id,
+        )
+
+        # Re-execute the cached SQL to get fresh data
+        result: pd.DataFrame | str = self._execute_sql(cached.query)
+
+        # Check if SQL execution failed (returns error string instead of DataFrame)
+        if isinstance(result, str):
+            logger.warning(
+                "Cached SQL execution failed, falling back to Genie",
+                layer=self.name,
+                question=question[:80],
+                conversation_id=conversation_id,
+                cached_sql=cached.query[:80],
+                error=result[:200],
+                space_id=self.space_id,
+            )
+
+            # Subclass should handle stale entry cleanup
+            self._on_stale_cache_entry(question)
+
+            # Fall back to Genie to get fresh SQL
+            logger.info(
+                "Delegating to Genie for fresh SQL",
+                layer=self.name,
+                question=question[:80],
+                conversation_id=conversation_id,
+                space_id=self.space_id,
+                delegating_to=type(self.impl).__name__,
+            )
+            fallback_result: CacheResult = self.impl.ask_question(
+                question, conversation_id
+            )
+
+            # Store the fresh SQL in cache
+            if fallback_result.response.query:
+                self._store_entry(
+                    question,
+                    conversation_context,
+                    question_embedding,
+                    context_embedding,
+                    fallback_result.response,
+                    message_id=fallback_result.message_id,
+                )
+                logger.info(
+                    "Stored fresh SQL from fallback",
+                    layer=self.name,
+                    fresh_sql=fallback_result.response.query[:80],
+                    space_id=self.space_id,
+                    message_id=fallback_result.message_id,
+                )
+            else:
+                logger.warning(
+                    "Fallback response has no SQL query to cache",
+                    layer=self.name,
+                    question=question[:80],
+                    space_id=self.space_id,
+                )
+
+            # Return as cache miss (fallback scenario)
+            # Propagate message_id from fallback result
+            return CacheResult(
+                response=fallback_result.response,
+                cache_hit=False,
+                served_by=None,
+                message_id=fallback_result.message_id,
+            )
+
+        # Build and return cache hit response
+        return self._build_cache_hit_response(cached, result, conversation_id)
+
+    def _on_stale_cache_entry(self, question: str) -> None:
+        """
+        Called when a stale cache entry is detected (SQL execution failed).
+
+        Subclasses can override this to clean up the stale entry from storage.
+
+        Args:
+            question: The question that had a stale cache entry
+        """
+        # Default implementation does nothing - subclasses should override
+        pass
+
+    def _handle_cache_miss(
+        self,
+        question: str,
+        conversation_id: str | None,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+    ) -> CacheResult:
+        """
+        Handle a cache miss - delegate to underlying service and store result.
+
+        Args:
+            question: The original question
+            conversation_id: The conversation ID
+            conversation_context: The conversation context string
+            question_embedding: The question embedding
+            context_embedding: The context embedding
+
+        Returns:
+            CacheResult from the underlying service
+        """
+        logger.info(
+            "Cache MISS",
+            layer=self.name,
+            question=question[:80],
+            conversation_id=conversation_id,
+            space_id=self.space_id,
+            similarity_threshold=self.similarity_threshold,
+            delegating_to=type(self.impl).__name__,
+        )
+
+        result: CacheResult = self.impl.ask_question(question, conversation_id)
+
+        # Store in cache if we got a SQL query
+        if result.response.query:
+            logger.debug(
+                "Storing new cache entry",
+                layer=self.name,
+                question=question[:50],
+                conversation_id=conversation_id,
+                space=self.space_id,
+                message_id=result.message_id,
+            )
+            self._store_entry(
+                question,
+                conversation_context,
+                question_embedding,
+                context_embedding,
+                result.response,
+                message_id=result.message_id,
+            )
+        else:
+            logger.warning(
+                "Not caching: response has no SQL query",
+                layer=self.name,
+                question=question[:50],
+            )
+
+        # Propagate message_id from underlying service result
+        return CacheResult(
+            response=result.response,
+            cache_hit=False,
+            served_by=None,
+            message_id=result.message_id,
+        )
+
+    @abstractmethod
+    def _invalidate_by_question(self, question: str) -> bool:
+        """
+        Invalidate cache entries matching a specific question.
+
+        This method is called when negative feedback is received to remove
+        the corresponding cache entry.
+
+        Args:
+            question: The question text to match and invalidate
+
+        Returns:
+            True if an entry was found and invalidated, False otherwise
+        """
+        pass
+
+    @mlflow.trace(name="genie_context_aware_send_feedback")
+    def send_feedback(
+        self,
+        conversation_id: str,
+        rating: GenieFeedbackRating,
+        message_id: str | None = None,
+        was_cache_hit: bool = False,
+    ) -> None:
+        """
+        Send feedback for a Genie message, with cache invalidation.
+
+        For context-aware caches, this method:
+        1. If rating is NEGATIVE: invalidates any matching cache entries
+        2. If was_cache_hit is False: forwards the feedback to the underlying service
+
+        Args:
+            conversation_id: The conversation containing the message
+            rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
+            message_id: Optional message ID. If None, looks up the most recent message.
+            was_cache_hit: Whether the response being rated was served from cache.
+
+        Note:
+            For cached responses (was_cache_hit=True), only cache invalidation is
+            performed. No feedback is sent to the Genie API because the cached
+            response was not produced by a new Genie message.
+
+            Future Enhancement: Cache entries already record the original Genie
+            message_id (see _store_entry and _build_cache_hit_response), so full
+            Genie feedback for cached responses could be enabled by forwarding
+            the rating to that stored message via the WorkspaceClient directly,
+            since databricks_ai_bridge.genie.GenieResponse doesn't expose the
+            message ID.
+        """
+        invalidated = False
+
+        # Handle cache invalidation on negative feedback
+        if rating == GenieFeedbackRating.NEGATIVE:
+            # Need to look up the message content to find matching cache entries
+            if self.workspace_client is not None:
+                from dao_ai.genie.cache.base import (
+                    get_latest_message_id,
+                    get_message_content,
+                )
+
+                # Get message_id if not provided
+                target_message_id = message_id
+                if target_message_id is None:
+                    target_message_id = get_latest_message_id(
+                        workspace_client=self.workspace_client,
+                        space_id=self.space_id,
+                        conversation_id=conversation_id,
+                    )
+
+                # Get the message content (question) to find matching cache entries
+                if target_message_id:
+                    question = get_message_content(
+                        workspace_client=self.workspace_client,
+                        space_id=self.space_id,
+                        conversation_id=conversation_id,
+                        message_id=target_message_id,
+                    )
+
+                    if question:
+                        invalidated = self._invalidate_by_question(question)
+                        if invalidated:
+                            logger.info(
+                                "Invalidated cache entry due to negative feedback",
+                                layer=self.name,
+                                question=question[:80],
+                                conversation_id=conversation_id,
+                                message_id=target_message_id,
+                            )
+                        else:
+                            logger.debug(
+                                "No cache entry found to invalidate for negative feedback",
+                                layer=self.name,
+                                question=question[:80],
+                                conversation_id=conversation_id,
+                            )
+                    else:
+                        logger.warning(
+                            "Could not retrieve message content for cache invalidation",
+                            layer=self.name,
+                            conversation_id=conversation_id,
+                            message_id=target_message_id,
+                        )
+                else:
+                    logger.warning(
+                        "Could not find message_id for cache invalidation",
+                        layer=self.name,
+                        conversation_id=conversation_id,
+                    )
+            else:
+                logger.warning(
+                    "No workspace_client available for cache invalidation",
+                    layer=self.name,
+                    conversation_id=conversation_id,
+                )
+
+        # Forward feedback to underlying service if not a cache hit
+        # For cache hits, there's no Genie message to provide feedback on
+        if was_cache_hit:
+            logger.info(
+                "Skipping Genie API feedback - response was served from cache",
+                layer=self.name,
+                conversation_id=conversation_id,
+                rating=rating.value if rating else None,
+                cache_invalidated=invalidated,
+            )
+            return
+
+        # Forward to underlying service
+        logger.debug(
+            "Forwarding feedback to underlying service",
+            layer=self.name,
+            conversation_id=conversation_id,
+            rating=rating.value if rating else None,
+            delegating_to=type(self.impl).__name__,
+        )
+        self.impl.send_feedback(
+            conversation_id=conversation_id,
+            rating=rating,
+            message_id=message_id,
+            was_cache_hit=False,  # Already handled, so pass False
+        )
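
Finally, a usage sketch wiring a CacheResult back into the feedback path (service is any concrete subclass; the IDs are placeholders):

    from databricks.sdk.service.dashboards import GenieFeedbackRating

    result = service.ask_question("What about EMEA?", conversation_id="<id>")

    # A thumbs-down invalidates the matching cache entry; for responses that
    # did not come from the cache, the rating is also forwarded to Genie.
    service.send_feedback(
        conversation_id="<id>",  # placeholder
        rating=GenieFeedbackRating.NEGATIVE,
        message_id=result.message_id,
        was_cache_hit=result.cache_hit,
    )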