dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (59)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +770 -244
  4. dao_ai/genie/__init__.py +1 -22
  5. dao_ai/genie/cache/__init__.py +1 -2
  6. dao_ai/genie/cache/base.py +20 -70
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +44 -21
  9. dao_ai/genie/cache/semantic.py +390 -109
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +8 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +47 -24
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/genie/__init__.py +0 -236
  54. dao_ai/tools/human_in_the_loop.py +0 -100
  55. dao_ai-0.0.36.dist-info/METADATA +0 -951
  56. dao_ai-0.0.36.dist-info/RECORD +0 -47
  57. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  58. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  59. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/semantic.py
@@ -4,6 +4,10 @@ Semantic cache implementation for Genie SQL queries using PostgreSQL pg_vector.
 This module provides a semantic cache that uses embeddings and similarity search
 to find cached queries that match the intent of new questions. Cache entries are
 partitioned by genie_space_id to ensure proper isolation between Genie spaces.
+
+The cache supports conversation-aware embedding using a rolling window approach
+to capture context from recent conversation turns, improving accuracy for
+multi-turn conversations with anaphoric references.
 """
 
 from datetime import timedelta
@@ -12,14 +16,18 @@ from typing import Any
 import mlflow
 import pandas as pd
 from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.dashboards import (
+    GenieListConversationMessagesResponse,
+    GenieMessage,
+)
 from databricks.sdk.service.sql import StatementResponse, StatementState
 from databricks_ai_bridge.genie import GenieResponse
 from loguru import logger
-from mlflow.entities import SpanType
 
 from dao_ai.config import (
     DatabaseModel,
     GenieSemanticCacheParametersModel,
+    LLMModel,
     WarehouseModel,
 )
 from dao_ai.genie.cache.base import (
@@ -32,6 +40,112 @@ from dao_ai.genie.cache.base import (
 DbRow = dict[str, Any]
 
 
+def get_conversation_history(
+    workspace_client: WorkspaceClient,
+    space_id: str,
+    conversation_id: str,
+    max_messages: int = 10,
+) -> list[GenieMessage]:
+    """
+    Retrieve conversation history from Genie.
+
+    Args:
+        workspace_client: The Databricks workspace client
+        space_id: The Genie space ID
+        conversation_id: The conversation ID to retrieve
+        max_messages: Maximum number of messages to retrieve
+
+    Returns:
+        List of GenieMessage objects representing the conversation history
+    """
+    try:
+        # Use the Genie API to retrieve conversation messages
+        response: GenieListConversationMessagesResponse = (
+            workspace_client.genie.list_conversation_messages(
+                space_id=space_id,
+                conversation_id=conversation_id,
+            )
+        )
+
+        # Return the most recent messages up to max_messages
+        if response.messages is not None:
+            all_messages: list[GenieMessage] = list(response.messages)
+            return (
+                all_messages[-max_messages:]
+                if len(all_messages) > max_messages
+                else all_messages
+            )
+        return []
+    except Exception as e:
+        logger.warning(
+            f"Failed to retrieve conversation history for conversation_id={conversation_id}: {e}"
+        )
+        return []
+
+
+def build_context_string(
+    question: str,
+    conversation_messages: list[GenieMessage],
+    window_size: int,
+    max_tokens: int = 2000,
+) -> str:
+    """
+    Build a context-aware question string using a rolling window.
+
+    This function creates a concatenated string that includes recent conversation
+    turns to provide context for semantic similarity matching.
+
+    Args:
+        question: The current question
+        conversation_messages: List of previous conversation messages
+        window_size: Number of previous turns to include
+        max_tokens: Maximum estimated tokens (rough approximation: 4 chars = 1 token)
+
+    Returns:
+        Context-aware question string formatted for embedding
+    """
+    if window_size <= 0 or not conversation_messages:
+        return question
+
+    # Take the last window_size messages (most recent)
+    recent_messages = (
+        conversation_messages[-window_size:]
+        if len(conversation_messages) > window_size
+        else conversation_messages
+    )
+
+    # Build context parts
+    context_parts: list[str] = []
+
+    for msg in recent_messages:
+        # Only include messages with content from the history
+        if msg.content:
+            # Limit message length to prevent token overflow
+            content: str = msg.content
+            if len(content) > 500:  # Truncate very long messages
+                content = content[:500] + "..."
+            context_parts.append(f"Previous: {content}")
+
+    # Add current question
+    context_parts.append(f"Current: {question}")
+
+    # Join with newlines
+    context_string = "\n".join(context_parts)
+
+    # Rough token limit check (4 chars ≈ 1 token)
+    estimated_tokens = len(context_string) / 4
+    if estimated_tokens > max_tokens:
+        # Truncate to fit max_tokens
+        target_chars = max_tokens * 4
+        context_string = context_string[:target_chars] + "..."
+        logger.debug(
+            f"Truncated context string from {len(context_string)} to {target_chars} chars "
+            f"(estimated {max_tokens} tokens)"
+        )
+
+    return context_string
+
+
 class SemanticCacheService(GenieServiceBase):
     """
     Semantic caching decorator that uses PostgreSQL pg_vector for similarity lookup.
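The rolling-window format produced by the new build_context_string helper is easiest to see on a concrete input. A minimal sketch (the messages and question here are hypothetical, not from the package):

    # Sketch: embedding text for a two-turn window, per build_context_string above.
    # `messages` stands in for GenieMessage objects whose .content fields are
    # "Show me total sales" and "Break that down by region" (hypothetical).
    context = build_context_string(
        question="What about last month?",
        conversation_messages=messages,
        window_size=2,
    )
    # context is now:
    # Previous: Show me total sales
    # Previous: Break that down by region
    # Current: What about last month?

Embedding this concatenated string rather than the bare question is what lets an anaphoric follow-up like "What about last month?" match a cached entry from the same conversational context.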
@@ -59,8 +173,7 @@ class SemanticCacheService(GenieServiceBase):
         )
         genie = SemanticCacheService(
             impl=GenieService(Genie(space_id="my-space")),
-            parameters=cache_params,
-            genie_space_id="my-space"
+            parameters=cache_params
         )
 
     Thread-safe: Uses connection pooling from psycopg_pool.
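Under the new signature, enabling conversation-aware caching only requires passing a WorkspaceClient; the space ID is now taken from the wrapped service. A minimal sketch building on the docstring example above (cache_params as configured there; the client value is illustrative):

    from databricks.sdk import WorkspaceClient

    # workspace_client is optional: without it the cache still works, but
    # conversation history is never fetched and context embeddings fall
    # back to zero vectors.
    genie = SemanticCacheService(
        impl=GenieService(Genie(space_id="my-space")),
        parameters=cache_params,
        workspace_client=WorkspaceClient(),
    )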
@@ -68,7 +181,7 @@ class SemanticCacheService(GenieServiceBase):
 
     impl: GenieServiceBase
     parameters: GenieSemanticCacheParametersModel
-    genie_space_id: str
+    workspace_client: WorkspaceClient | None
     name: str
     _embeddings: Any  # DatabricksEmbeddings
     _pool: Any  # ConnectionPool
@@ -79,21 +192,23 @@ class SemanticCacheService(GenieServiceBase):
         self,
         impl: GenieServiceBase,
         parameters: GenieSemanticCacheParametersModel,
-        genie_space_id: str,
+        workspace_client: WorkspaceClient | None = None,
         name: str | None = None,
     ) -> None:
         """
         Initialize the semantic cache service.
 
         Args:
-            impl: The underlying GenieServiceBase to delegate to on cache miss
+            impl: The underlying GenieServiceBase to delegate to on cache miss.
+                The space_id will be obtained from impl.space_id.
             parameters: Cache configuration including database, warehouse, embedding model
-            genie_space_id: The Genie space ID for partitioning cache entries
+            workspace_client: Optional WorkspaceClient for retrieving conversation history.
+                If None, conversation context will not be used.
             name: Name for this cache layer (for logging). Defaults to class name.
         """
         self.impl = impl
         self.parameters = parameters
-        self.genie_space_id = genie_space_id
+        self.workspace_client = workspace_client
         self.name = name if name is not None else self.__class__.__name__
         self._embeddings = None
         self._pool = None
@@ -120,17 +235,16 @@ class SemanticCacheService(GenieServiceBase):
         if self._setup_complete:
             return
 
-        from databricks_langchain import DatabricksEmbeddings
-
         from dao_ai.memory.postgres import PostgresPoolManager
 
         # Initialize embeddings
-        embedding_model: str = (
-            self.parameters.embedding_model
+        # Convert embedding_model to LLMModel if it's a string
+        embedding_model: LLMModel = (
+            LLMModel(name=self.parameters.embedding_model)
             if isinstance(self.parameters.embedding_model, str)
-            else self.parameters.embedding_model.name
+            else self.parameters.embedding_model
         )
-        self._embeddings = DatabricksEmbeddings(endpoint=embedding_model)
+        self._embeddings = embedding_model.as_embeddings_model()
 
         # Auto-detect embedding dimensions if not provided
         if self.parameters.embedding_dims is None:
@@ -150,7 +264,7 @@ class SemanticCacheService(GenieServiceBase):
 
         self._setup_complete = True
         logger.debug(
-            f"[{self.name}] Semantic cache initialized for space '{self.genie_space_id}' "
+            f"[{self.name}] Semantic cache initialized for space '{self.space_id}' "
            f"with table '{self.table_name}' (dims={self._embedding_dims})"
         )
 
@@ -212,7 +326,10 @@ class SemanticCacheService(GenieServiceBase):
             id SERIAL PRIMARY KEY,
             genie_space_id TEXT NOT NULL,
             question TEXT NOT NULL,
+            conversation_context TEXT,
+            context_string TEXT,
             question_embedding vector({self.embedding_dims}),
+            context_embedding vector({self.embedding_dims}),
             sql_query TEXT NOT NULL,
             description TEXT,
             conversation_id TEXT,
@@ -221,12 +338,18 @@ class SemanticCacheService(GenieServiceBase):
         """
         # Index for efficient similarity search partitioned by genie_space_id
         # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
-        create_embedding_index_sql: str = f"""
-            CREATE INDEX IF NOT EXISTS {self.table_name}_embedding_idx
+        create_question_embedding_index_sql: str = f"""
+            CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
             ON {self.table_name}
             USING ivfflat (question_embedding vector_l2_ops)
             WITH (lists = 100)
         """
+        create_context_embedding_index_sql: str = f"""
+            CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
+            ON {self.table_name}
+            USING ivfflat (context_embedding vector_l2_ops)
+            WITH (lists = 100)
+        """
         # Index for filtering by genie_space_id
         create_space_index_sql: str = f"""
             CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
@@ -257,35 +380,132 @@ class SemanticCacheService(GenieServiceBase):
 
                 cur.execute(create_table_sql)
                 cur.execute(create_space_index_sql)
-                cur.execute(create_embedding_index_sql)
+                cur.execute(create_question_embedding_index_sql)
+                cur.execute(create_context_embedding_index_sql)
+
+    def _embed_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> tuple[list[float], list[float], str]:
+        """
+        Generate dual embeddings: one for the question, one for the conversation context.
+
+        This enables separate matching of question similarity vs context similarity,
+        improving precision by ensuring both the question AND the conversation context
+        are semantically similar before returning a cached result.
+
+        Args:
+            question: The question to embed
+            conversation_id: Optional conversation ID for retrieving context
+
+        Returns:
+            Tuple of (question_embedding, context_embedding, conversation_context_string)
+            - question_embedding: Vector embedding of just the question
+            - context_embedding: Vector embedding of the conversation context (or zero vector if no context)
+            - conversation_context_string: The conversation context string (empty if no context)
+        """
+        conversation_context = ""
+
+        # If conversation context is enabled and available
+        if (
+            self.workspace_client is not None
+            and conversation_id is not None
+            and self.parameters.context_window_size > 0
+        ):
+            try:
+                # Retrieve conversation history
+                conversation_messages = get_conversation_history(
+                    workspace_client=self.workspace_client,
+                    space_id=self.space_id,
+                    conversation_id=conversation_id,
+                    max_messages=self.parameters.context_window_size
+                    * 2,  # Get extra for safety
+                )
+
+                # Build context string (just the "Previous:" messages, not the current question)
+                if conversation_messages:
+                    recent_messages = (
+                        conversation_messages[-self.parameters.context_window_size :]
+                        if len(conversation_messages)
+                        > self.parameters.context_window_size
+                        else conversation_messages
+                    )
+
+                    context_parts: list[str] = []
+                    for msg in recent_messages:
+                        if msg.content:
+                            content: str = msg.content
+                            if len(content) > 500:
+                                content = content[:500] + "..."
+                            context_parts.append(f"Previous: {content}")
+
+                    conversation_context = "\n".join(context_parts)
 
-    def _embed_question(self, question: str) -> list[float]:
-        """Generate embedding for a question."""
-        embeddings: list[list[float]] = self._embeddings.embed_documents([question])
-        return embeddings[0]
+                    # Truncate if too long
+                    estimated_tokens = len(conversation_context) / 4
+                    if estimated_tokens > self.parameters.max_context_tokens:
+                        target_chars = self.parameters.max_context_tokens * 4
+                        conversation_context = (
+                            conversation_context[:target_chars] + "..."
+                        )
+
+                logger.debug(
+                    f"[{self.name}] Using conversation context: {len(conversation_messages)} messages "
+                    f"(window_size={self.parameters.context_window_size})"
+                )
+            except Exception as e:
+                logger.warning(
+                    f"[{self.name}] Failed to build conversation context, using question only: {e}"
+                )
+                conversation_context = ""
+
+        # Generate dual embeddings
+        if conversation_context:
+            # Embed both question and context
+            embeddings: list[list[float]] = self._embeddings.embed_documents(
+                [question, conversation_context]
+            )
+            question_embedding = embeddings[0]
+            context_embedding = embeddings[1]
+        else:
+            # Only embed question, use zero vector for context
+            embeddings = self._embeddings.embed_documents([question])
+            question_embedding = embeddings[0]
+            context_embedding = [0.0] * len(question_embedding)  # Zero vector
+
+        return question_embedding, context_embedding, conversation_context
 
     @mlflow.trace(name="semantic_search")
     def _find_similar(
-        self, question: str, embedding: list[float]
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
     ) -> tuple[SQLCacheEntry, float] | None:
         """
-        Find a semantically similar cached entry for this Genie space.
+        Find a semantically similar cached entry using dual embedding matching.
+
+        This method matches BOTH the question AND the conversation context separately,
+        ensuring high precision by requiring both to be semantically similar.
 
         Args:
-            question: The question to search for
-            embedding: The embedding vector of the question
+            question: The original question (for logging)
+            conversation_context: The conversation context string
+            question_embedding: The embedding vector of just the question
+            context_embedding: The embedding vector of the conversation context
 
         Returns:
-            Tuple of (SQLCacheEntry, similarity_score) if found, None otherwise
+            Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
         """
         # Use L2 (Euclidean) distance - optimal for Databricks GTE embeddings
         # pg_vector's <-> operator returns L2 distance (0 = identical)
         # Convert to similarity: 1 / (1 + distance) gives range [0, 1]
         #
-        # Refresh-on-hit strategy:
-        # 1. Search without TTL filter to find best semantic match
-        # 2. If match is within TTL (or TTL disabled) → cache hit
-        # 3. If match is expired → delete it, return miss (triggers refresh)
+        # Dual embedding strategy:
+        # 1. Calculate separate similarities for question and context
+        # 2. BOTH must exceed their respective thresholds
+        # 3. Combined score is weighted average
+        # 4. Refresh-on-hit: check TTL after similarity check
         ttl_seconds = self.parameters.time_to_live_seconds
         ttl_disabled = ttl_seconds is None or ttl_seconds < 0
 
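One behavior in _embed_question worth calling out: when no context is available, the context embedding is a zero vector rather than an embedding of an empty string. Under the L2-based similarity 1 / (1 + d) used below, two zero vectors sit at distance 0 and score a perfect 1.0, so context-free questions match best against entries that were also cached without context. A condensed sketch of the fallback (embeddings_model stands in for the service's embeddings instance):

    # Condensed from _embed_question: dual embedding with zero-vector fallback.
    if conversation_context:
        question_vec, context_vec = embeddings_model.embed_documents(
            [question, conversation_context]
        )
    else:
        (question_vec,) = embeddings_model.embed_documents([question])
        context_vec = [0.0] * len(question_vec)  # no context signal to match on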
@@ -295,63 +515,87 @@ class SemanticCacheService(GenieServiceBase):
         else:
             is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"
 
+        # Weighted combined similarity for ordering
+        question_weight: float = self.parameters.question_weight
+        context_weight: float = self.parameters.context_weight
+
         search_sql: str = f"""
             SELECT
                 id,
                 question,
+                conversation_context,
                 sql_query,
                 description,
                 conversation_id,
                 created_at,
-                1.0 / (1.0 + (question_embedding <-> %s::vector)) as similarity,
+                1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
+                1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
+                ({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
+                ({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
                 {is_valid_expr} as is_valid
             FROM {self.table_name}
             WHERE genie_space_id = %s
-            ORDER BY question_embedding <-> %s::vector
+            ORDER BY combined_similarity DESC
             LIMIT 1
         """
 
-        embedding_str: str = f"[{','.join(str(x) for x in embedding)}]"
+        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
+        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"
 
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
                 cur.execute(
                     search_sql,
-                    (embedding_str, self.genie_space_id, embedding_str),
+                    (
+                        question_emb_str,
+                        context_emb_str,
+                        question_emb_str,
+                        context_emb_str,
+                        self.space_id,
+                    ),
                 )
                 row: DbRow | None = cur.fetchone()
 
                 if row is None:
                     logger.info(
                         f"[{self.name}] MISS (no entries): "
-                        f"question='{question[:50]}...' space='{self.genie_space_id}'"
+                        f"question='{question[:50]}...' space='{self.space_id}'"
                     )
                     return None
 
                 # Extract values from dict row
-                entry_id = row.get("id")
-                cached_question = row.get("question", "")
-                sql_query = row["sql_query"]
-                description = row.get("description", "")
-                conversation_id = row.get("conversation_id", "")
-                created_at = row["created_at"]
-                similarity = row["similarity"]
-                is_valid = row.get("is_valid", False)
-
-                # Log best match info (L2 distance can be computed from similarity: d = 1/s - 1)
-                l2_distance = (
-                    (1.0 / similarity) - 1.0 if similarity > 0 else float("inf")
-                )
+                entry_id: Any = row.get("id")
+                cached_question: str = row.get("question", "")
+                cached_context: str = row.get("conversation_context", "")
+                sql_query: str = row["sql_query"]
+                description: str = row.get("description", "")
+                conversation_id_cached: str = row.get("conversation_id", "")
+                created_at: Any = row["created_at"]
+                question_similarity: float = row["question_similarity"]
+                context_similarity: float = row["context_similarity"]
+                combined_similarity: float = row["combined_similarity"]
+                is_valid: bool = row.get("is_valid", False)
+
+                # Log best match info
                 logger.info(
-                    f"[{self.name}] Best match: l2_distance={l2_distance:.4f}, similarity={similarity:.4f}, "
-                    f"is_valid={is_valid}, question='{cached_question[:50]}...'"
+                    f"[{self.name}] Best match: "
+                    f"question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
+                    f"combined_sim={combined_similarity:.4f}, is_valid={is_valid}, "
+                    f"question='{cached_question[:50]}...', context='{cached_context[:80]}...'"
                 )
 
-                # Check similarity threshold
-                if similarity < self.similarity_threshold:
+                # Check BOTH similarity thresholds (dual embedding precision check)
+                if question_similarity < self.parameters.similarity_threshold:
+                    logger.info(
+                        f"[{self.name}] MISS (question similarity too low): "
+                        f"question_sim={question_similarity:.4f} < threshold={self.parameters.similarity_threshold}"
+                    )
+                    return None
+
+                if context_similarity < self.parameters.context_similarity_threshold:
                     logger.info(
-                        f"[{self.name}] MISS (below threshold): similarity={similarity:.4f} < threshold={self.similarity_threshold} "
-                        f"(cached_question='{cached_question[:50]}...')"
+                        f"[{self.name}] MISS (context similarity too low): "
+                        f"context_sim={context_similarity:.4f} < threshold={self.parameters.context_similarity_threshold}"
                    )
                     return None
 
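The weighted score computed in the SQL above reduces to simple arithmetic. A worked example with illustrative weights (actual values come from GenieSemanticCacheParametersModel, not shown here):

    # sim = 1 / (1 + L2 distance); the weights below are illustrative, not defaults.
    question_weight, context_weight = 0.7, 0.3

    question_sim = 1.0 / (1.0 + 0.25)  # question distance 0.25 -> 0.8
    context_sim = 1.0 / (1.0 + 1.0)    # context distance 1.0   -> 0.5

    combined = question_weight * question_sim + context_weight * context_sim
    # 0.7 * 0.8 + 0.3 * 0.5 = 0.71

Note that combined_similarity only orders the candidates; as the checks above show, a hit additionally requires question_similarity and context_similarity to each clear their own thresholds.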
@@ -361,43 +605,59 @@ class SemanticCacheService(GenieServiceBase):
                     delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
                     cur.execute(delete_sql, (entry_id,))
                     logger.info(
-                        f"[{self.name}] MISS (expired, deleted for refresh): similarity={similarity:.4f}, "
-                        f"ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
+                        f"[{self.name}] MISS (expired, deleted for refresh): "
+                        f"combined_sim={combined_similarity:.4f}, ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
                     )
                     return None
 
                 logger.info(
-                    f"[{self.name}] HIT: similarity={similarity:.4f} >= threshold={self.similarity_threshold} "
-                    f"(cached_question='{cached_question[:50]}...')"
+                    f"[{self.name}] HIT: question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
+                    f"combined_sim={combined_similarity:.4f} (cached_question='{cached_question[:50]}...')"
                 )
 
                 entry = SQLCacheEntry(
                     query=sql_query,
                     description=description,
-                    conversation_id=conversation_id,
+                    conversation_id=conversation_id_cached,
                     created_at=created_at,
                 )
-                return entry, similarity
+                return entry, combined_similarity
 
     def _store_entry(
-        self, question: str, embedding: list[float], response: GenieResponse
+        self,
+        question: str,
+        conversation_context: str,
+        question_embedding: list[float],
+        context_embedding: list[float],
+        response: GenieResponse,
     ) -> None:
-        """Store a new cache entry for this Genie space."""
+        """Store a new cache entry with dual embeddings for this Genie space."""
         insert_sql: str = f"""
             INSERT INTO {self.table_name}
-            (genie_space_id, question, question_embedding, sql_query, description, conversation_id)
-            VALUES (%s, %s, %s::vector, %s, %s, %s)
+            (genie_space_id, question, conversation_context, context_string,
+             question_embedding, context_embedding, sql_query, description, conversation_id)
+            VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
         """
-        embedding_str: str = f"[{','.join(str(x) for x in embedding)}]"
+        question_emb_str: str = f"[{','.join(str(x) for x in question_embedding)}]"
+        context_emb_str: str = f"[{','.join(str(x) for x in context_embedding)}]"
+
+        # Build full context string for backward compatibility (used in logging)
+        if conversation_context:
+            full_context_string = f"{conversation_context}\nCurrent: {question}"
+        else:
+            full_context_string = question
 
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
                 cur.execute(
                     insert_sql,
                     (
-                        self.genie_space_id,
+                        self.space_id,
                         question,
-                        embedding_str,
+                        conversation_context,
+                        full_context_string,
+                        question_emb_str,
+                        context_emb_str,
                         response.query,
                         response.description,
                         response.conversation_id,
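Both the lookup and the insert bind embeddings as bracketed text literals that PostgreSQL casts server-side via %s::vector. The encoding is just the f-string seen above; as a standalone sketch:

    def to_pgvector_literal(embedding: list[float]) -> str:
        # Same encoding as question_emb_str / context_emb_str above.
        return f"[{','.join(str(x) for x in embedding)}]"

    assert to_pgvector_literal([0.1, 0.25, -0.3]) == "[0.1,0.25,-0.3]"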
@@ -405,14 +665,15 @@ class SemanticCacheService(GenieServiceBase):
                 )
                 logger.info(
                     f"[{self.name}] Stored cache entry: question='{question[:50]}...' "
-                    f"sql='{response.query[:50]}...' (space={self.genie_space_id}, table={self.table_name})"
+                    f"context='{conversation_context[:80]}...' "
+                    f"sql='{response.query[:50]}...' (space={self.space_id}, table={self.table_name})"
                 )
 
     @mlflow.trace(name="execute_cached_sql_semantic")
     def _execute_sql(self, sql: str) -> pd.DataFrame | str:
         """Execute SQL using the warehouse and return results."""
         client: WorkspaceClient = self.warehouse.workspace_client
-        warehouse_id: str = self.warehouse.warehouse_id
+        warehouse_id: str = str(self.warehouse.warehouse_id)
 
         statement_response: StatementResponse = (
             client.statement_execution.execute_statement(
@@ -422,10 +683,13 @@ class SemanticCacheService(GenieServiceBase):
             )
         )
 
-        if statement_response.status.state != StatementState.SUCCEEDED:
+        if (
+            statement_response.status is not None
+            and statement_response.status.state != StatementState.SUCCEEDED
+        ):
             error_msg: str = (
                 f"SQL execution failed: {statement_response.status.error.message}"
-                if statement_response.status.error
+                if statement_response.status.error is not None
                 else f"SQL execution failed with state: {statement_response.status.state}"
             )
             logger.error(f"[{self.name}] {error_msg}")
@@ -439,10 +703,10 @@ class SemanticCacheService(GenieServiceBase):
             and statement_response.manifest.schema.columns
         ):
             columns = [
-                col.name for col in statement_response.manifest.schema.columns
+                col.name
+                for col in statement_response.manifest.schema.columns
+                if col.name is not None
             ]
-        elif hasattr(statement_response.result, "schema"):
-            columns = [col.name for col in statement_response.result.schema.columns]
 
         data: list[list[Any]] = statement_response.result.data_array
         if columns:
@@ -454,19 +718,16 @@ class SemanticCacheService(GenieServiceBase):
 
     def ask_question(
         self, question: str, conversation_id: str | None = None
-    ) -> GenieResponse:
+    ) -> CacheResult:
         """
         Ask a question, using semantic cache if a similar query exists.
 
         On cache hit, re-executes the cached SQL to get fresh data.
-        Implements GenieServiceBase for seamless chaining.
+        Returns CacheResult with cache metadata.
         """
-        result: CacheResult = self.ask_question_with_cache_info(
-            question, conversation_id
-        )
-        return result.response
+        return self.ask_question_with_cache_info(question, conversation_id)
 
-    @mlflow.trace(name="genie_semantic_cache_lookup", span_type=SpanType.TOOL)
+    @mlflow.trace(name="genie_semantic_cache_lookup")
     def ask_question_with_cache_info(
         self,
         question: str,
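This is a breaking change in 0.1.0: ask_question now returns CacheResult instead of GenieResponse, so callers unwrap the response themselves. A minimal sketch (genie as constructed in the docstring example; field names as used throughout this file):

    result = genie.ask_question("What were total sales last week?")
    if result.cache_hit:
        print(f"served by cache layer: {result.served_by}")
    response = result.response  # GenieResponse: result, query, description, conversation_id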
@@ -475,11 +736,12 @@ class SemanticCacheService(GenieServiceBase):
         """
         Ask a question with detailed cache hit information.
 
-        On cache hit, the cached SQL is re-executed to return fresh data.
+        On cache hit, the cached SQL is re-executed to return fresh data, but the
+        conversation_id returned is the current conversation_id (not the cached one).
 
         Args:
             question: The question to ask
-            conversation_id: Optional conversation ID
+            conversation_id: Optional conversation ID for context and continuation
 
         Returns:
             CacheResult with fresh response and cache metadata
@@ -487,28 +749,37 @@ class SemanticCacheService(GenieServiceBase):
         # Ensure initialization (lazy init if initialize() wasn't called)
         self._setup()
 
-        # Generate embedding for the question
-        embedding: list[float] = self._embed_question(question)
+        # Generate dual embeddings for the question and conversation context
+        question_embedding: list[float]
+        context_embedding: list[float]
+        conversation_context: str
+        question_embedding, context_embedding, conversation_context = (
+            self._embed_question(question, conversation_id)
+        )
 
-        # Check cache
+        # Check cache using dual embedding similarity
         cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
-            question, embedding
+            question, conversation_context, question_embedding, context_embedding
         )
 
         if cache_result is not None:
-            cached, similarity = cache_result
+            cached, combined_similarity = cache_result
             logger.debug(
-                f"[{self.name}] Semantic cache hit (similarity={similarity:.3f}): {question[:50]}..."
+                f"[{self.name}] Semantic cache hit (combined_similarity={combined_similarity:.3f}): {question[:50]}..."
             )
 
             # Re-execute the cached SQL to get fresh data
             result: pd.DataFrame | str = self._execute_sql(cached.query)
 
+            # IMPORTANT: Use the current conversation_id (from the request), not the cached one
+            # This ensures the conversation continues properly
             response: GenieResponse = GenieResponse(
                 result=result,
                 query=cached.query,
                 description=cached.description,
-                conversation_id=cached.conversation_id,
+                conversation_id=conversation_id
+                if conversation_id
+                else cached.conversation_id,
             )
 
             return CacheResult(response=response, cache_hit=True, served_by=self.name)
@@ -516,22 +787,32 @@ class SemanticCacheService(GenieServiceBase):
         # Cache miss - delegate to wrapped service
         logger.debug(f"[{self.name}] Miss: {question[:50]}...")
 
-        response = self.impl.ask_question(question, conversation_id)
+        result: CacheResult = self.impl.ask_question(question, conversation_id)
 
         # Store in cache if we got a SQL query
-        if response.query:
+        if result.response.query:
             logger.info(
                 f"[{self.name}] Storing new cache entry for question: '{question[:50]}...' "
-                f"(space={self.genie_space_id})"
+                f"(space={self.space_id})"
             )
-            self._store_entry(question, embedding, response)
-        elif not response.query:
+            self._store_entry(
+                question,
+                conversation_context,
+                question_embedding,
+                context_embedding,
+                result.response,
+            )
+        elif not result.response.query:
             logger.warning(
                 f"[{self.name}] Not caching: response has no SQL query "
                 f"(question='{question[:50]}...')"
             )
 
-        return CacheResult(response=response, cache_hit=False, served_by=None)
+        return CacheResult(response=result.response, cache_hit=False, served_by=None)
+
+    @property
+    def space_id(self) -> str:
+        return self.impl.space_id
 
     def invalidate_expired(self) -> int:
         """Remove expired entries from the cache for this Genie space.
@@ -544,7 +825,7 @@ class SemanticCacheService(GenieServiceBase):
         # If TTL is disabled, nothing can expire
         if ttl_seconds is None or ttl_seconds < 0:
             logger.debug(
-                f"[{self.name}] TTL disabled, no entries to expire for space {self.genie_space_id}"
+                f"[{self.name}] TTL disabled, no entries to expire for space {self.space_id}"
             )
             return 0
 
@@ -556,10 +837,10 @@ class SemanticCacheService(GenieServiceBase):
 
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(delete_sql, (self.genie_space_id, ttl_seconds))
+                cur.execute(delete_sql, (self.space_id, ttl_seconds))
                 deleted: int = cur.rowcount
                 logger.debug(
-                    f"[{self.name}] Deleted {deleted} expired entries for space {self.genie_space_id}"
+                    f"[{self.name}] Deleted {deleted} expired entries for space {self.space_id}"
                 )
                 return deleted
 
@@ -570,10 +851,10 @@ class SemanticCacheService(GenieServiceBase):
 
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(delete_sql, (self.genie_space_id,))
+                cur.execute(delete_sql, (self.space_id,))
                 deleted: int = cur.rowcount
                 logger.debug(
-                    f"[{self.name}] Cleared {deleted} entries for space {self.genie_space_id}"
+                    f"[{self.name}] Cleared {deleted} entries for space {self.space_id}"
                 )
                 return deleted
 
@@ -587,7 +868,7 @@ class SemanticCacheService(GenieServiceBase):
 
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(count_sql, (self.genie_space_id,))
+                cur.execute(count_sql, (self.space_id,))
                 row: DbRow | None = cur.fetchone()
                 return row.get("count", 0) if row else 0
 
@@ -605,7 +886,7 @@ class SemanticCacheService(GenieServiceBase):
         """
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(count_sql, (self.genie_space_id,))
+                cur.execute(count_sql, (self.space_id,))
                 row: DbRow | None = cur.fetchone()
                 total = row.get("total", 0) if row else 0
                 return {
@@ -627,12 +908,12 @@ class SemanticCacheService(GenieServiceBase):
 
         with self._pool.connection() as conn:
             with conn.cursor() as cur:
-                cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.genie_space_id))
-                row: DbRow | None = cur.fetchone()
+                cur.execute(stats_sql, (ttl_seconds, ttl_seconds, self.space_id))
+                stats_row: DbRow | None = cur.fetchone()
                 return {
-                    "size": row.get("total", 0) if row else 0,
+                    "size": stats_row.get("total", 0) if stats_row else 0,
                     "ttl_seconds": ttl.total_seconds() if ttl else None,
                     "similarity_threshold": self.similarity_threshold,
-                    "expired_entries": row.get("expired", 0) if row else 0,
-                    "valid_entries": row.get("valid", 0) if row else 0,
+                    "expired_entries": stats_row.get("expired", 0) if stats_row else 0,
+                    "valid_entries": stats_row.get("valid", 0) if stats_row else 0,
                 }