dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1204 @@
+ """
+ Abstract base class for context-aware Genie cache implementations.
+
+ This module provides the foundational abstract base class for all context-aware
+ cache implementations. It extracts common code for:
+ - Dual embedding generation (question + conversation context)
+ - Ask question flow with error handling and graceful fallback
+ - SQL execution with retry logic
+ - Common properties and initialization patterns
+
+ Subclasses must implement storage-specific methods:
+ - _find_similar(): Find semantically similar cached entry
+ - _store_entry(): Store new cache entry
+ - _setup(): Initialize resources (embeddings, storage)
+ - invalidate_expired(): Remove expired entries
+ - clear(): Clear all entries for space
+ - stats(): Return cache statistics
+ - get_entries(): Retrieve cache entries with filtering
+ """
+
+ from __future__ import annotations
+
+ from abc import abstractmethod
+ from datetime import datetime, timedelta
+ from typing import Any, Self, TypeVar
+
+ import mlflow
+ import pandas as pd
+ from databricks.sdk import WorkspaceClient
+ from databricks.sdk.service.dashboards import (
+     GenieFeedbackRating,
+     GenieListConversationMessagesResponse,
+     GenieMessage,
+ )
+ from databricks_ai_bridge.genie import GenieResponse
+ from loguru import logger
+
+ from dao_ai.config import LLMModel, WarehouseModel
+ from dao_ai.genie.cache.base import (
+     CacheResult,
+     GenieServiceBase,
+     SQLCacheEntry,
+ )
+ from dao_ai.genie.cache.core import execute_sql_via_warehouse
+
+ # Type variable for subclass return types
+ T = TypeVar("T", bound="ContextAwareGenieService")
+
+
+ def get_conversation_history(
+     workspace_client: WorkspaceClient,
+     space_id: str,
+     conversation_id: str,
+     max_messages: int = 10,
+ ) -> list[GenieMessage]:
+     """
+     Retrieve conversation history from Genie.
+
+     Args:
+         workspace_client: The Databricks workspace client
+         space_id: The Genie space ID
+         conversation_id: The conversation ID to retrieve
+         max_messages: Maximum number of messages to retrieve
+
+     Returns:
+         List of GenieMessage objects representing the conversation history
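+
+     Example (illustrative; the client and IDs are placeholders):
+         messages = get_conversation_history(
+             workspace_client=w,
+             space_id="<space-id>",
+             conversation_id="<conversation-id>",
+             max_messages=5,
+         )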
+     """
+     try:
+         # Use the Genie API to retrieve conversation messages
+         response: GenieListConversationMessagesResponse = (
+             workspace_client.genie.list_conversation_messages(
+                 space_id=space_id,
+                 conversation_id=conversation_id,
+             )
+         )
+
+         # Return the most recent messages up to max_messages
+         if response.messages is not None:
+             all_messages: list[GenieMessage] = list(response.messages)
+             return (
+                 all_messages[-max_messages:]
+                 if len(all_messages) > max_messages
+                 else all_messages
+             )
+         return []
+     except Exception as e:
+         logger.warning(
+             "Failed to retrieve conversation history",
+             conversation_id=conversation_id,
+             error=str(e),
+         )
+         return []
+
+
+ def build_context_string(
+     question: str,
+     conversation_messages: list[GenieMessage],
+     window_size: int,
+     max_tokens: int = 2000,
+ ) -> str:
+     """
+     Build a context-aware question string using a rolling window.
+
+     This function creates a concatenated string that includes recent conversation
+     turns to provide context for semantic similarity matching.
+
+     Args:
+         question: The current question
+         conversation_messages: List of previous conversation messages
+         window_size: Number of previous turns to include
+         max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)
+
+     Returns:
+         Context-aware question string formatted for embedding
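+
+     Example (illustrative output for window_size=2; message text is invented):
+         Previous: What were total sales last quarter?
+         Previous: Filter that to the EMEA region.
+         Current: Now break it down by month.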
+     """
+     if window_size <= 0 or not conversation_messages:
+         return question
+
+     # Take the last window_size messages (most recent)
+     recent_messages = (
+         conversation_messages[-window_size:]
+         if len(conversation_messages) > window_size
+         else conversation_messages
+     )
+
+     # Build context parts
+     context_parts: list[str] = []
+
+     for msg in recent_messages:
+         # Only include messages with content from the history
+         if msg.content:
+             # Limit message length to prevent token overflow
+             content: str = msg.content
+             if len(content) > 500:  # Truncate very long messages
+                 content = content[:500] + "..."
+             context_parts.append(f"Previous: {content}")
+
+     # Add current question
+     context_parts.append(f"Current: {question}")
+
+     # Join with newlines
+     context_string = "\n".join(context_parts)
+
+     # Rough token limit check (4 chars ≈ 1 token)
+     estimated_tokens = len(context_string) / 4
+     if estimated_tokens > max_tokens:
+         # Truncate to fit max_tokens
+         target_chars = max_tokens * 4
+         original_length = len(context_string)
+         context_string = context_string[:target_chars] + "..."
+         logger.trace(
+             "Truncated context string",
+             original_chars=original_length,
+             target_chars=target_chars,
+             max_tokens=max_tokens,
+         )
+
+     return context_string
+
+
+ class ContextAwareGenieService(GenieServiceBase):
+     """
+     Abstract base class for context-aware Genie cache implementations.
+
+     This class provides shared implementation for:
+     - Dual embedding generation (question + conversation context)
+     - Main ask_question flow with error handling
+     - SQL execution with warehouse
+     - Common properties (time_to_live, similarity_threshold, etc.)
+
+     Subclasses must implement storage-specific methods for finding similar
+     entries, storing new entries, and managing cache lifecycle.
+
+     Error Handling:
+         All cache operations are wrapped in try/except to ensure graceful
+         degradation. If any cache operation fails, the request is delegated
+         to the underlying service without caching.
+
+     Thread Safety:
+         Subclasses are responsible for thread safety of storage operations.
+         This base class does not provide synchronization primitives.
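+
+     Example (hypothetical subclass; the name, storage, and endpoint below are
+     illustrative, not part of this module):
+         class InMemoryContextCache(ContextAwareGenieService):
+             def _setup(self) -> None:
+                 if not self._setup_complete:
+                     self._initialize_embeddings("databricks-gte-large-en")
+                     self._entries = []
+                     self._setup_complete = True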
+     """
+
+     # Common attributes - subclasses should define these
+     impl: GenieServiceBase
+     _workspace_client: WorkspaceClient | None
+     name: str
+     _embeddings: Any  # DatabricksEmbeddings
+     _embedding_dims: int | None
+     _setup_complete: bool
+
+     # Abstract methods that subclasses must implement
+     @abstractmethod
+     def _setup(self) -> None:
+         """
+         Initialize resources required by the cache implementation.
+
+         This method is called lazily before first use. Implementations should:
+         - Initialize embedding model
+         - Set up storage (database connection, in-memory structures, etc.)
+         - Create necessary tables/indexes if applicable
+
+         This method should be idempotent (safe to call multiple times).
+         """
+         pass
+
+     @abstractmethod
+     def _find_similar(
+         self,
+         question: str,
+         conversation_context: str,
+         question_embedding: list[float],
+         context_embedding: list[float],
+         conversation_id: str | None = None,
+     ) -> tuple[SQLCacheEntry, float] | None:
+         """
+         Find a semantically similar cached entry using dual embedding matching.
+
+         Args:
+             question: The original question (for logging)
+             conversation_context: The conversation context string
+             question_embedding: The embedding vector of just the question
+             context_embedding: The embedding vector of the conversation context
+             conversation_id: Optional conversation ID (for logging)
+
+         Returns:
+             Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
+         """
+         pass
+
+     @abstractmethod
+     def _store_entry(
+         self,
+         question: str,
+         conversation_context: str,
+         question_embedding: list[float],
+         context_embedding: list[float],
+         response: GenieResponse,
+         message_id: str | None = None,
+     ) -> None:
+         """
+         Store a new cache entry with dual embeddings.
+
+         Args:
+             question: The user's question
+             conversation_context: Previous conversation context string
+             question_embedding: Embedding of the question
+             context_embedding: Embedding of the conversation context
+             response: The GenieResponse containing query, description, etc.
+             message_id: The Genie message ID from the original API response.
+                 Stored with the cache entry to enable feedback on cache hits.
+         """
+         pass
+
+     def invalidate_expired(self) -> int | dict[str, int]:
+         """
+         Template method for removing expired entries from the cache.
+
+         This method implements the TTL check and delegates to
+         _delete_expired_entries() for the actual deletion.
+
+         Returns:
+             Number of entries deleted, or dict with counts by category
+         """
+         self._setup()
+         ttl_seconds = self.time_to_live_seconds
+
+         if ttl_seconds is None or ttl_seconds < 0:
+             return self._get_empty_expiration_result()
+
+         return self._delete_expired_entries(ttl_seconds)
+
+     @abstractmethod
+     def _delete_expired_entries(self, ttl_seconds: int) -> int | dict[str, int]:
+         """
+         Delete expired entries from storage.
+
+         Args:
+             ttl_seconds: TTL in seconds for determining expiration
+
+         Returns:
+             Number of entries deleted, or dict with counts by category
+         """
+         pass
+
+     def _get_empty_expiration_result(self) -> int | dict[str, int]:
+         """
+         Return the empty result for invalidate_expired when TTL is disabled.
+
+         Override this in subclasses that return a dict so that the appropriate
+         empty dict is returned.
+
+         Returns:
+             0 by default, or empty dict for subclasses that return dict
+         """
+         return 0
+
+     def clear(self) -> int:
+         """
+         Template method for clearing all entries from the cache.
+
+         This method calls _setup() and delegates to _delete_all_entries().
+
+         Returns:
+             Number of entries deleted
+         """
+         self._setup()
+         return self._delete_all_entries()
+
+     @abstractmethod
+     def _delete_all_entries(self) -> int:
+         """
+         Delete all entries for this Genie space from storage.
+
+         Returns:
+             Number of entries deleted
+         """
+         pass
+
+     @abstractmethod
+     def get_entries(
+         self,
+         limit: int | None = None,
+         offset: int | None = None,
+         include_embeddings: bool = False,
+         conversation_id: str | None = None,
+         created_after: datetime | None = None,
+         created_before: datetime | None = None,
+         question_contains: str | None = None,
+     ) -> list[dict[str, Any]]:
+         """
+         Get cache entries with optional filtering.
+
+         This method retrieves cache entries for inspection, debugging, or
+         generating evaluation datasets for threshold optimization.
+
+         Args:
+             limit: Maximum number of entries to return (None = no limit)
+             offset: Number of entries to skip for pagination (None = 0)
+             include_embeddings: Whether to include embedding vectors in results.
+                 Embeddings are large, so set False for general inspection.
+             conversation_id: Filter by conversation ID (None = all conversations)
+             created_after: Only entries created after this time (None = no filter)
+             created_before: Only entries created before this time (None = no filter)
+             question_contains: Case-insensitive text search on question field
+
+         Returns:
+             List of cache entry dicts with keys:
+             - id: Cache entry ID (int for persistent caches, None for in-memory)
+             - question: The cached question text
+             - conversation_context: Prior conversation context string
+             - sql_query: The cached SQL query
+             - description: Query description
+             - conversation_id: The conversation ID
+             - created_at: Entry creation timestamp (datetime)
+             - question_embedding: (only if include_embeddings=True)
+             - context_embedding: (only if include_embeddings=True)
+
+         Example:
+             # Get recent entries for inspection
+             entries = cache.get_entries(limit=10)
+
+             # Get entries with embeddings for evaluation dataset
+             entries = cache.get_entries(include_embeddings=True, limit=100)
+             eval_dataset = generate_eval_dataset_from_cache(entries)
+
+             # Search for specific questions
+             entries = cache.get_entries(question_contains="sales")
+         """
+         pass
+
+     def stats(self) -> dict[str, Any]:
+         """
+         Template method for returning cache statistics.
+
+         This method uses the Template Method pattern to consolidate the common
+         stats calculation algorithm. Subclasses provide counting implementations
+         via abstract methods and can add additional stats via hook methods.
+
+         Returns:
+             Dict with cache statistics (size, ttl, thresholds, etc.)
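+
+         Example (illustrative values):
+             {"size": 42, "ttl_seconds": 3600.0, "similarity_threshold": 0.85,
+              "expired_entries": 3, "valid_entries": 39}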
+         """
+         self._setup()
+         ttl_seconds = self.time_to_live_seconds
+         ttl = self.time_to_live
+
+         # Calculate base stats using abstract counting methods
+         if ttl_seconds is None or ttl_seconds < 0:
+             total = self._count_all_entries()
+             base_stats: dict[str, Any] = {
+                 "size": total,
+                 "ttl_seconds": None,
+                 "similarity_threshold": self.similarity_threshold,
+                 "expired_entries": 0,
+                 "valid_entries": total,
+             }
+         else:
+             total, expired = self._count_entries_with_ttl(ttl_seconds)
+             base_stats = {
+                 "size": total,
+                 "ttl_seconds": ttl.total_seconds() if ttl else None,
+                 "similarity_threshold": self.similarity_threshold,
+                 "expired_entries": expired,
+                 "valid_entries": total - expired,
+             }
+
+         # Add any additional stats from subclasses
+         additional_stats = self._get_additional_stats()
+         base_stats.update(additional_stats)
+
+         return base_stats
+
+     @abstractmethod
+     def _count_all_entries(self) -> int:
+         """
+         Count all cache entries for this Genie space.
+
+         Returns:
+             Total number of cache entries
+         """
+         pass
+
+     @abstractmethod
+     def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
+         """
+         Count total and expired entries for this Genie space.
+
+         Args:
+             ttl_seconds: TTL in seconds for determining expiration
+
+         Returns:
+             Tuple of (total_entries, expired_entries)
+         """
+         pass
+
+     def _get_additional_stats(self) -> dict[str, Any]:
+         """
+         Hook method to add additional stats from subclasses.
+
+         Override this method to add subclass-specific statistics like
+         capacity (in-memory) or prompt history stats (postgres).
+
+         Returns:
+             Dict with additional stats to merge into base stats
+         """
+         return {}
+
+     # Properties that subclasses should implement or inherit
+     @property
+     @abstractmethod
+     def warehouse(self) -> WarehouseModel:
+         """The warehouse used for executing cached SQL queries."""
+         pass
+
+     @property
+     @abstractmethod
+     def time_to_live(self) -> timedelta | None:
+         """Time-to-live for cache entries. None means never expires."""
+         pass
+
+     @property
+     @abstractmethod
+     def similarity_threshold(self) -> float:
+         """Minimum similarity for cache hit (using L2 distance converted to similarity)."""
+         pass
+
+     @property
+     def embedding_dims(self) -> int:
+         """Dimension size for embeddings (auto-detected if not configured)."""
+         if self._embedding_dims is None:
+             raise RuntimeError(
+                 "Embedding dimensions not yet initialized. Call _setup() first."
+             )
+         return self._embedding_dims
+
+     @property
+     def space_id(self) -> str:
+         """The Genie space ID from the underlying service."""
+         return self.impl.space_id
+
+     @property
+     def workspace_client(self) -> WorkspaceClient | None:
+         """Get workspace client, delegating to impl if not set."""
+         if self._workspace_client is not None:
+             return self._workspace_client
+         return self.impl.workspace_client
+
+     @property
+     def time_to_live_seconds(self) -> int | None:
+         """TTL in seconds (None or negative = never expires)."""
+         ttl = self.time_to_live
+         if ttl is None:
+             return None
+         return int(ttl.total_seconds())
+
+     # Abstract method for embedding - subclasses must implement
+     @abstractmethod
+     def _embed_question(
+         self, question: str, conversation_id: str | None = None
+     ) -> tuple[list[float], list[float], str]:
+         """
+         Generate dual embeddings for a question with conversation context.
+
+         Args:
+             question: The question to embed
+             conversation_id: Optional conversation ID for retrieving context
+
+         Returns:
+             Tuple of (question_embedding, context_embedding, conversation_context_string)
+         """
+         pass
+
+     # Shared implementation methods
+     def initialize(self) -> Self:
+         """
+         Eagerly initialize the cache service.
+
+         Call this during tool creation to:
+         - Validate configuration early (fail fast)
+         - Initialize resources before any requests
+         - Avoid first-request latency from lazy initialization
+
+         Returns:
+             self for method chaining
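+
+         Example (illustrative; the subclass name is hypothetical):
+             service = PostgresContextCache(impl=genie_service).initialize()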
+         """
+         self._setup()
+         return self
+
+     def _initialize_embeddings(
+         self,
+         embedding_model: str | LLMModel,
+         embedding_dims: int | None = None,
+     ) -> None:
+         """
+         Initialize the embeddings model and detect dimensions.
+
+         This helper method handles embedding model initialization for subclasses.
+
+         Args:
+             embedding_model: The embedding model name or LLMModel instance
+             embedding_dims: Optional pre-configured embedding dimensions
+         """
+         # Convert embedding_model to LLMModel if it's a string
+         model: LLMModel = (
+             LLMModel(name=embedding_model)
+             if isinstance(embedding_model, str)
+             else embedding_model
+         )
+         self._embeddings = model.as_embeddings_model()
+
+         # Auto-detect embedding dimensions if not provided
+         if embedding_dims is None:
+             sample_embedding: list[float] = self._embeddings.embed_query("test")
+             self._embedding_dims = len(sample_embedding)
+             logger.debug(
+                 "Auto-detected embedding dimensions",
+                 layer=self.name,
+                 dims=self._embedding_dims,
+             )
+         else:
+             self._embedding_dims = embedding_dims
+
+     def _embed_question_with_genie_history(
+         self,
+         question: str,
+         conversation_id: str | None,
+         context_window_size: int,
+         max_context_tokens: int,
+     ) -> tuple[list[float], list[float], str]:
+         """
+         Generate dual embeddings using the Genie API for conversation history.
+
+         This method retrieves conversation history from the Genie API and
+         generates dual embeddings for semantic matching.
+
+         Args:
+             question: The question to embed
+             conversation_id: Optional conversation ID for retrieving context
+             context_window_size: Number of previous messages to include
+             max_context_tokens: Maximum tokens for context string
+
+         Returns:
+             Tuple of (question_embedding, context_embedding, conversation_context_string)
+         """
+         conversation_context = ""
+
+         # If conversation context is enabled and available
+         if (
+             self.workspace_client is not None
+             and conversation_id is not None
+             and context_window_size > 0
+         ):
+             try:
+                 # Retrieve conversation history from Genie API
+                 conversation_messages = get_conversation_history(
+                     workspace_client=self.workspace_client,
+                     space_id=self.space_id,
+                     conversation_id=conversation_id,
+                     max_messages=context_window_size * 2,  # Get extra for safety
+                 )
+
+                 # Build context string
+                 if conversation_messages:
+                     recent_messages = (
+                         conversation_messages[-context_window_size:]
+                         if len(conversation_messages) > context_window_size
+                         else conversation_messages
+                     )
+
+                     context_parts: list[str] = []
+                     for msg in recent_messages:
+                         if msg.content:
+                             content: str = msg.content
+                             if len(content) > 500:
+                                 content = content[:500] + "..."
+                             context_parts.append(f"Previous: {content}")
+
+                     conversation_context = "\n".join(context_parts)
+
+                     # Truncate if too long
+                     estimated_tokens = len(conversation_context) / 4
+                     if estimated_tokens > max_context_tokens:
+                         target_chars = max_context_tokens * 4
+                         conversation_context = (
+                             conversation_context[:target_chars] + "..."
+                         )
+
+                     logger.trace(
+                         "Using conversation context from Genie API",
+                         layer=self.name,
+                         messages_count=len(conversation_messages),
+                         window_size=context_window_size,
+                     )
+             except Exception as e:
+                 logger.warning(
+                     "Failed to build conversation context, using question only",
+                     layer=self.name,
+                     error=str(e),
+                 )
+                 conversation_context = ""
+
+         return self._generate_dual_embeddings(question, conversation_context)
+
+     def _generate_dual_embeddings(
+         self, question: str, conversation_context: str
+     ) -> tuple[list[float], list[float], str]:
+         """
+         Generate dual embeddings for question and conversation context.
+
+         Args:
+             question: The question to embed
+             conversation_context: The conversation context string
+
+         Returns:
+             Tuple of (question_embedding, context_embedding, conversation_context)
+         """
+         if conversation_context:
+             # Embed both question and context
+             embeddings: list[list[float]] = self._embeddings.embed_documents(
+                 [question, conversation_context]
+             )
+             question_embedding = embeddings[0]
+             context_embedding = embeddings[1]
+         else:
+             # Only embed question, use zero vector for context
+             embeddings = self._embeddings.embed_documents([question])
+             question_embedding = embeddings[0]
+             context_embedding = [0.0] * len(question_embedding)  # Zero vector
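+             # How a zero context vector scores against stored context embeddings
+             # is left to each subclass's _find_similar() implementation.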
+
+         return question_embedding, context_embedding, conversation_context
+
+     @mlflow.trace(name="execute_cached_sql")
+     def _execute_sql(self, sql: str) -> pd.DataFrame | str:
+         """
+         Execute SQL using the warehouse and return results.
+
+         Args:
+             sql: The SQL query to execute
+
+         Returns:
+             DataFrame with results, or error message string if execution failed
+         """
+         return execute_sql_via_warehouse(
+             warehouse=self.warehouse,
+             sql=sql,
+             layer_name=self.name,
+         )
+
+     def _build_cache_hit_response(
+         self,
+         cached: SQLCacheEntry,
+         result: pd.DataFrame,
+         conversation_id: str | None,
+     ) -> CacheResult:
+         """
+         Build a CacheResult for a cache hit.
+
+         Args:
+             cached: The cached SQL entry
+             result: The fresh DataFrame from SQL execution
+             conversation_id: The current conversation ID
+
+         Returns:
+             CacheResult with cache_hit=True, including message_id and cache_entry_id
+             from the original cached entry for traceability and feedback support.
+         """
+         # IMPORTANT: Use the current conversation_id (from the request), not the cached one
+         # This ensures the conversation continues properly
+         response = GenieResponse(
+             result=result,
+             query=cached.query,
+             description=cached.description,
+             conversation_id=conversation_id
+             if conversation_id
+             else cached.conversation_id,
+         )
+         # Cache hit - include message_id from original response for feedback support
+         # and cache_entry_id for traceability to genie_prompt_history
+         return CacheResult(
+             response=response,
+             cache_hit=True,
+             served_by=self.name,
+             message_id=cached.message_id,
+             cache_entry_id=cached.cache_entry_id,
+         )
+
+     def ask_question(
+         self, question: str, conversation_id: str | None = None
+     ) -> CacheResult:
+         """
+         Ask a question, using the semantic cache if a similar query exists.
+
+         On cache hit, re-executes the cached SQL to get fresh data.
+         Returns CacheResult with cache metadata.
+
+         This method wraps ask_question_with_cache_info with error handling
+         to ensure graceful degradation on cache failures.
+
+         Args:
+             question: The question to ask
+             conversation_id: Optional conversation ID for context
+
+         Returns:
+             CacheResult with fresh response and cache metadata
+         """
+         try:
+             return self.ask_question_with_cache_info(question, conversation_id)
+         except Exception as e:
+             logger.warning(
+                 "Context-aware cache operation failed, delegating to underlying service",
+                 layer=self.name,
+                 error=str(e),
+                 exc_info=True,
+             )
+             # Graceful degradation: fall back to underlying service
+             return self.impl.ask_question(question, conversation_id)
+
+     def ask_question_with_cache_info(
+         self,
+         question: str,
+         conversation_id: str | None = None,
+     ) -> CacheResult:
+         """
+         Template method for asking a question with cache lookup.
+
+         This method implements the cache lookup algorithm using the Template Method
+         pattern. Subclasses can customize behavior by overriding hook methods:
+         - _before_cache_lookup(): Called before cache search (e.g., store prompt)
+         - _after_cache_hit(): Called after a cache hit (e.g., update prompt flags)
+         - _after_cache_miss(): Called after a cache miss (e.g., store prompt)
+
+         Args:
+             question: The question to ask
+             conversation_id: Optional conversation ID for context and continuation
+
+         Returns:
+             CacheResult with fresh response and cache metadata
+         """
+         self._setup()
+
+         # Step 1: Generate dual embeddings
+         question_embedding, context_embedding, conversation_context = (
+             self._embed_question(question, conversation_id)
+         )
+
+         # Step 2: Hook for pre-lookup actions (e.g., store prompt in history)
+         self._before_cache_lookup(question, conversation_id)
+
+         # Step 3: Search for similar cached entry
+         cache_result = self._find_similar(
+             question,
+             conversation_context,
+             question_embedding,
+             context_embedding,
+             conversation_id,
+         )
+
+         # Step 4: Handle cache hit or miss
+         if cache_result is not None:
+             cached, combined_similarity = cache_result
+
+             result = self._handle_cache_hit(
+                 question,
+                 conversation_id,
+                 cached,
+                 combined_similarity,
+                 conversation_context,
+                 question_embedding,
+                 context_embedding,
+             )
+
+             # Hook for post-cache-hit actions (e.g., update prompt cache_hit flag)
+             self._after_cache_hit(question, conversation_id, result)
+
+             return result
+
+         # Handle cache miss
+         result = self._handle_cache_miss(
+             question,
+             conversation_id,
+             conversation_context,
+             question_embedding,
+             context_embedding,
+         )
+
+         # Hook for post-cache-miss actions (e.g., store prompt if not done earlier)
+         self._after_cache_miss(question, conversation_id, result)
+
+         return result
+
+     def _before_cache_lookup(self, question: str, conversation_id: str | None) -> None:
+         """
+         Hook method called before cache lookup.
+
+         Override this method to perform actions before searching the cache,
+         such as storing the prompt in history.
+
+         Args:
+             question: The question being asked
+             conversation_id: Optional conversation ID
+         """
+         pass
+
+     def _after_cache_hit(
+         self,
+         question: str,
+         conversation_id: str | None,
+         result: CacheResult,
+     ) -> None:
+         """
+         Hook method called after a cache hit.
+
+         Override this method to perform actions after a successful cache hit,
+         such as updating prompt history flags.
+
+         Args:
+             question: The question that was asked
+             conversation_id: Optional conversation ID
+             result: The cache result
+         """
+         pass
+
+     def _after_cache_miss(
+         self,
+         question: str,
+         conversation_id: str | None,
+         result: CacheResult,
+     ) -> None:
+         """
+         Hook method called after a cache miss.
+
+         Override this method to perform actions after a cache miss,
+         such as storing prompt history if not done earlier.
+
+         Args:
+             question: The question that was asked
+             conversation_id: Optional conversation ID
+             result: The cache result
+         """
+         pass
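+
+     # Illustrative hook override (hypothetical subclass and helper method,
+     # shown for orientation only):
+     #
+     #     class PostgresContextCache(ContextAwareGenieService):
+     #         def _before_cache_lookup(self, question, conversation_id):
+     #             # Record the prompt in a history table before searching
+     #             self._record_prompt(question, conversation_id)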
+
+     def _handle_cache_hit(
+         self,
+         question: str,
+         conversation_id: str | None,
+         cached: SQLCacheEntry,
+         combined_similarity: float,
+         conversation_context: str,
+         question_embedding: list[float],
+         context_embedding: list[float],
+     ) -> CacheResult:
+         """
+         Handle a cache hit - execute cached SQL and return response.
+
+         This method handles the common cache hit logic including SQL execution,
+         stale cache fallback, and response building.
+
+         Args:
+             question: The original question
+             conversation_id: The conversation ID
+             cached: The cached SQL entry
+             combined_similarity: The similarity score
+             conversation_context: The conversation context string
+             question_embedding: The question embedding
+             context_embedding: The context embedding
+
+         Returns:
+             CacheResult with the response
+         """
+         logger.debug(
+             "Cache hit",
+             layer=self.name,
+             combined_similarity=f"{combined_similarity:.3f}",
+             question=question[:50],
+             conversation_id=conversation_id,
+         )
+
+         # Re-execute the cached SQL to get fresh data
+         result: pd.DataFrame | str = self._execute_sql(cached.query)
+
+         # Check if SQL execution failed (returns error string instead of DataFrame)
+         if isinstance(result, str):
+             logger.warning(
+                 "Cached SQL execution failed, falling back to Genie",
+                 layer=self.name,
+                 question=question[:80],
+                 conversation_id=conversation_id,
+                 cached_sql=cached.query[:80],
+                 error=result[:200],
+                 space_id=self.space_id,
+             )
+
+             # Subclass should handle stale entry cleanup
+             self._on_stale_cache_entry(question)
+
+             # Fall back to Genie to get fresh SQL
+             logger.info(
+                 "Delegating to Genie for fresh SQL",
+                 layer=self.name,
+                 question=question[:80],
+                 conversation_id=conversation_id,
+                 space_id=self.space_id,
+                 delegating_to=type(self.impl).__name__,
+             )
+             fallback_result: CacheResult = self.impl.ask_question(
+                 question, conversation_id
+             )
+
+             # Store the fresh SQL in cache
+             if fallback_result.response.query:
+                 self._store_entry(
+                     question,
+                     conversation_context,
+                     question_embedding,
+                     context_embedding,
+                     fallback_result.response,
+                     message_id=fallback_result.message_id,
+                 )
+                 logger.info(
+                     "Stored fresh SQL from fallback",
+                     layer=self.name,
+                     fresh_sql=fallback_result.response.query[:80],
+                     space_id=self.space_id,
+                     message_id=fallback_result.message_id,
+                 )
+             else:
+                 logger.warning(
+                     "Fallback response has no SQL query to cache",
+                     layer=self.name,
+                     question=question[:80],
+                     space_id=self.space_id,
+                 )
+
+             # Return as cache miss (fallback scenario)
+             # Propagate message_id from fallback result
+             return CacheResult(
+                 response=fallback_result.response,
+                 cache_hit=False,
+                 served_by=None,
+                 message_id=fallback_result.message_id,
+             )
+
+         # Build and return cache hit response
+         return self._build_cache_hit_response(cached, result, conversation_id)
+
+     def _on_stale_cache_entry(self, question: str) -> None:
+         """
+         Called when a stale cache entry is detected (SQL execution failed).
+
+         Subclasses can override this to clean up the stale entry from storage.
+
+         Args:
+             question: The question that had a stale cache entry
+         """
+         # Default implementation does nothing - subclasses should override
+         pass
+
+     def _handle_cache_miss(
+         self,
+         question: str,
+         conversation_id: str | None,
+         conversation_context: str,
+         question_embedding: list[float],
+         context_embedding: list[float],
+     ) -> CacheResult:
+         """
+         Handle a cache miss - delegate to underlying service and store result.
+
+         Args:
+             question: The original question
+             conversation_id: The conversation ID
+             conversation_context: The conversation context string
+             question_embedding: The question embedding
+             context_embedding: The context embedding
+
+         Returns:
+             CacheResult from the underlying service
+         """
+         logger.info(
+             "Cache MISS",
+             layer=self.name,
+             question=question[:80],
+             conversation_id=conversation_id,
+             space_id=self.space_id,
+             similarity_threshold=self.similarity_threshold,
+             delegating_to=type(self.impl).__name__,
+         )
+
+         result: CacheResult = self.impl.ask_question(question, conversation_id)
+
+         # Store in cache if we got a SQL query
+         if result.response.query:
+             logger.debug(
+                 "Storing new cache entry",
+                 layer=self.name,
+                 question=question[:50],
+                 conversation_id=conversation_id,
+                 space=self.space_id,
+                 message_id=result.message_id,
+             )
+             self._store_entry(
+                 question,
+                 conversation_context,
+                 question_embedding,
+                 context_embedding,
+                 result.response,
+                 message_id=result.message_id,
+             )
+         else:
+             logger.warning(
+                 "Not caching: response has no SQL query",
+                 layer=self.name,
+                 question=question[:50],
+             )
+
+         # Propagate message_id from underlying service result
+         return CacheResult(
+             response=result.response,
+             cache_hit=False,
+             served_by=None,
+             message_id=result.message_id,
+         )
+
+     @abstractmethod
+     def _invalidate_by_question(self, question: str) -> bool:
+         """
+         Invalidate cache entries matching a specific question.
+
+         This method is called when negative feedback is received to remove
+         the corresponding cache entry.
+
+         Args:
+             question: The question text to match and invalidate
+
+         Returns:
+             True if an entry was found and invalidated, False otherwise
+         """
+         pass
+
+     @mlflow.trace(name="genie_context_aware_send_feedback")
+     def send_feedback(
+         self,
+         conversation_id: str,
+         rating: GenieFeedbackRating,
+         message_id: str | None = None,
+         was_cache_hit: bool = False,
+     ) -> None:
+         """
+         Send feedback for a Genie message with cache invalidation.
+
+         For context-aware caches, this method:
+         1. If rating is NEGATIVE: invalidates any matching cache entries
+         2. If was_cache_hit is False: forwards feedback to the underlying service
+
+         Args:
+             conversation_id: The conversation containing the message
+             rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
+             message_id: Optional message ID. If None, looks up the most recent message.
+             was_cache_hit: Whether the response being rated was served from cache.
+
+         Note:
+             For cached responses (was_cache_hit=True), only cache invalidation is
+             performed; no feedback is forwarded to the Genie API.
+
+             The original Genie message_id is stored alongside each cache entry
+             (see _store_entry and _build_cache_hit_response), so forwarding
+             feedback on the original message for cache hits is a possible
+             future enhancement.
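+
+         Example (illustrative):
+             service.send_feedback(
+                 conversation_id=conversation_id,
+                 rating=GenieFeedbackRating.NEGATIVE,
+                 was_cache_hit=result.cache_hit,
+             )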
+         """
+         invalidated = False
+
+         # Handle cache invalidation on negative feedback
+         if rating == GenieFeedbackRating.NEGATIVE:
+             # Need to look up the message content to find matching cache entries
+             if self.workspace_client is not None:
+                 from dao_ai.genie.cache.base import (
+                     get_latest_message_id,
+                     get_message_content,
+                 )
+
+                 # Get message_id if not provided
+                 target_message_id = message_id
+                 if target_message_id is None:
+                     target_message_id = get_latest_message_id(
+                         workspace_client=self.workspace_client,
+                         space_id=self.space_id,
+                         conversation_id=conversation_id,
+                     )
+
+                 # Get the message content (question) to find matching cache entries
+                 if target_message_id:
+                     question = get_message_content(
+                         workspace_client=self.workspace_client,
+                         space_id=self.space_id,
+                         conversation_id=conversation_id,
+                         message_id=target_message_id,
+                     )
+
+                     if question:
+                         invalidated = self._invalidate_by_question(question)
+                         if invalidated:
+                             logger.info(
+                                 "Invalidated cache entry due to negative feedback",
+                                 layer=self.name,
+                                 question=question[:80],
+                                 conversation_id=conversation_id,
+                                 message_id=target_message_id,
+                             )
+                         else:
+                             logger.debug(
+                                 "No cache entry found to invalidate for negative feedback",
+                                 layer=self.name,
+                                 question=question[:80],
+                                 conversation_id=conversation_id,
+                             )
+                     else:
+                         logger.warning(
+                             "Could not retrieve message content for cache invalidation",
+                             layer=self.name,
+                             conversation_id=conversation_id,
+                             message_id=target_message_id,
+                         )
+                 else:
+                     logger.warning(
+                         "Could not find message_id for cache invalidation",
+                         layer=self.name,
+                         conversation_id=conversation_id,
+                     )
+             else:
+                 logger.warning(
+                     "No workspace_client available for cache invalidation",
+                     layer=self.name,
+                     conversation_id=conversation_id,
+                 )
+
+         # Forward feedback to underlying service if not a cache hit
+         # For cache hits, feedback is handled via cache invalidation only
+         if was_cache_hit:
+             logger.info(
+                 "Skipping Genie API feedback - response was served from cache",
+                 layer=self.name,
+                 conversation_id=conversation_id,
+                 rating=rating.value if rating else None,
+                 cache_invalidated=invalidated,
+             )
+             return
+
+         # Forward to underlying service
+         logger.debug(
+             "Forwarding feedback to underlying service",
+             layer=self.name,
+             conversation_id=conversation_id,
+             rating=rating.value if rating else None,
+             delegating_to=type(self.impl).__name__,
+         )
+         self.impl.send_feedback(
+             conversation_id=conversation_id,
+             rating=rating,
+             message_id=message_id,
+             was_cache_hit=False,  # Already handled, so pass False
+         )