dao-ai 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
  """
- Semantic cache threshold optimization using Optuna Bayesian optimization.
+ Context-aware semantic cache threshold optimization using Optuna Bayesian optimization.
 
- This module provides optimization for Genie semantic cache thresholds using
+ This module provides optimization for context-aware Genie cache thresholds using
  Optuna's Tree-structured Parzen Estimator (TPE) algorithm with LLM-as-Judge
  evaluation for semantic match validation.
 
@@ -11,10 +11,23 @@ The optimizer tunes these thresholds:
  - question_weight: Weight for question similarity in combined score
 
  Usage:
-     from dao_ai.genie.cache.optimization import optimize_semantic_cache_thresholds
+     from dao_ai.genie.cache.context_aware.optimization import (
+         optimize_context_aware_cache_thresholds,
+         generate_eval_dataset_from_cache,
+     )
+
+     # Get entries from your cache
+     entries = cache_service.get_entries(include_embeddings=True, limit=100)
+
+     # Generate evaluation dataset
+     eval_dataset = generate_eval_dataset_from_cache(
+         cache_entries=entries,
+         dataset_name="my_cache_eval",
+     )
 
-     result = optimize_semantic_cache_thresholds(
-         dataset=my_eval_dataset,
+     # Optimize thresholds
+     result = optimize_context_aware_cache_thresholds(
+         dataset=eval_dataset,
          judge_model="databricks-meta-llama-3-3-70b-instruct",
          n_trials=50,
          metric="f1",
@@ -29,37 +42,31 @@ import hashlib
  import math
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
- from typing import Any, Callable, Literal, Sequence
+ from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Sequence
 
  import mlflow
- import optuna
  from loguru import logger
- from optuna.samplers import TPESampler
-
- # Optional MLflow integration - requires optuna-integration[mlflow]
- try:
-     from optuna.integration import MLflowCallback
-
-     MLFLOW_CALLBACK_AVAILABLE = True
- except ModuleNotFoundError:
-     MLFLOW_CALLBACK_AVAILABLE = False
-     MLflowCallback = None  # type: ignore
 
  from dao_ai.config import GenieContextAwareCacheParametersModel, LLMModel
  from dao_ai.utils import dao_ai_version
 
+ # Type-only import for optuna.Trial to support type hints without runtime dependency
+ if TYPE_CHECKING:
+     import optuna
+
  __all__ = [
-     "SemanticCacheEvalEntry",
-     "SemanticCacheEvalDataset",
+     "ContextAwareCacheEvalEntry",
+     "ContextAwareCacheEvalDataset",
      "ThresholdOptimizationResult",
-     "optimize_semantic_cache_thresholds",
+     "optimize_context_aware_cache_thresholds",
      "generate_eval_dataset_from_cache",
      "semantic_match_judge",
+     "clear_judge_cache",
  ]
 
 
  @dataclass
- class SemanticCacheEvalEntry:
+ class ContextAwareCacheEvalEntry:
      """Single evaluation entry for threshold optimization.
 
      Represents a pair of question/context combinations to evaluate
@@ -90,7 +97,7 @@ class SemanticCacheEvalEntry:
 
 
  @dataclass
- class SemanticCacheEvalDataset:
+ class ContextAwareCacheEvalDataset:
      """Dataset for semantic cache threshold optimization.
 
      Attributes:
@@ -100,13 +107,13 @@ class SemanticCacheEvalDataset:
      """
 
      name: str
-     entries: list[SemanticCacheEvalEntry]
+     entries: list[ContextAwareCacheEvalEntry]
      description: str = ""
 
      def __len__(self) -> int:
          return len(self.entries)
 
-     def __iter__(self):
+     def __iter__(self) -> Iterator[ContextAwareCacheEvalEntry]:
          return iter(self.entries)
 
 
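The renamed dataset keeps its container behavior: __len__ counts entries, and __iter__ now advertises that it yields ContextAwareCacheEvalEntry values. A minimal sketch of consuming a dataset; the summarize helper is illustrative, while name, entries, and expected_match are fields visible elsewhere in this diff:

    from dao_ai.genie.cache.context_aware.optimization import (
        ContextAwareCacheEvalDataset,
    )

    def summarize(dataset: ContextAwareCacheEvalDataset) -> None:
        # __len__ and __iter__ delegate to the entries list
        print(f"{dataset.name}: {len(dataset)} entries")
        positives = sum(1 for entry in dataset if entry.expected_match is True)
        negatives = sum(1 for entry in dataset if entry.expected_match is False)
        print(f"positive pairs: {positives}, negative pairs: {negatives}")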
@@ -272,7 +279,7 @@ def _compute_l2_similarity(embedding1: list[float], embedding2: list[float]) ->
 
 
  def _evaluate_thresholds(
-     dataset: SemanticCacheEvalDataset,
+     dataset: ContextAwareCacheEvalDataset,
      similarity_threshold: float,
      context_similarity_threshold: float,
      question_weight: float,
@@ -370,14 +377,14 @@ def _evaluate_thresholds(
 
 
  def _create_objective(
-     dataset: SemanticCacheEvalDataset,
+     dataset: ContextAwareCacheEvalDataset,
      judge_model: LLMModel | str | None,
      metric: Literal["f1", "precision", "recall", "fbeta"],
      beta: float = 1.0,
- ) -> Callable[[optuna.Trial], float]:
+ ) -> Callable[["optuna.Trial"], float]:
      """Create the Optuna objective function."""
 
-     def objective(trial: optuna.Trial) -> float:
+     def objective(trial: "optuna.Trial") -> float:
          # Sample parameters
          similarity_threshold = trial.suggest_float(
              "similarity_threshold", 0.5, 0.99, log=False
@@ -423,8 +430,8 @@ def _create_objective(
      return objective
 
 
- def optimize_semantic_cache_thresholds(
-     dataset: SemanticCacheEvalDataset,
+ def optimize_context_aware_cache_thresholds(
+     dataset: ContextAwareCacheEvalDataset,
      original_thresholds: dict[str, float]
      | GenieContextAwareCacheParametersModel
      | None = None,
@@ -461,12 +468,12 @@ def optimize_semantic_cache_thresholds(
          ThresholdOptimizationResult with optimized thresholds and metrics
 
      Example:
-         from dao_ai.genie.cache.optimization import (
-             optimize_semantic_cache_thresholds,
-             SemanticCacheEvalDataset,
+         from dao_ai.genie.cache.context_aware.optimization import (
+             optimize_context_aware_cache_thresholds,
+             ContextAwareCacheEvalDataset,
          )
 
-         result = optimize_semantic_cache_thresholds(
+         result = optimize_context_aware_cache_thresholds(
              dataset=my_dataset,
              judge_model="databricks-meta-llama-3-3-70b-instruct",
              n_trials=50,
@@ -476,6 +483,26 @@ def optimize_semantic_cache_thresholds(
          if result.improved:
              print(f"New thresholds: {result.optimized_thresholds}")
      """
+     # Lazy import optuna - only loaded when optimization is actually called
+     # This allows the cache module to be imported without optuna installed
+     try:
+         import optuna
+         from optuna.samplers import TPESampler
+     except ImportError as e:
+         raise ImportError(
+             "optuna is required for cache threshold optimization. "
+             "Install it with: pip install optuna"
+         ) from e
+
+     # Optional MLflow integration - requires optuna-integration[mlflow]
+     try:
+         from optuna.integration import MLflowCallback
+
+         mlflow_callback_available = True
+     except ModuleNotFoundError:
+         mlflow_callback_available = False
+         MLflowCallback = None  # type: ignore
+
      logger.info(
          "Starting semantic cache threshold optimization",
          dataset_name=dataset.name,
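Net effect of this hunk plus the TYPE_CHECKING block earlier: optuna becomes an optional dependency that is imported only when optimization runs, with a pointed error message otherwise. The same pattern in isolation, where heavy_lib is a hypothetical optional package:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import heavy_lib  # hypothetical; seen only by type checkers

    def run(n_trials: int) -> "heavy_lib.Result":
        try:
            import heavy_lib  # deferred until the feature is actually used
        except ImportError as e:
            raise ImportError("heavy_lib is required: pip install heavy_lib") from e
        return heavy_lib.optimize(n_trials)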
@@ -539,7 +566,7 @@ def optimize_semantic_cache_thresholds(
      # Create study name if not provided
      if study_name is None:
          timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
-         study_name = f"semantic_cache_threshold_optimization_{timestamp}"
+         study_name = f"context_aware_cache_threshold_optimization_{timestamp}"
 
      # Create Optuna study
      sampler = TPESampler(seed=seed)
@@ -562,7 +589,7 @@ def optimize_semantic_cache_thresholds(
 
      # Set up MLflow callback if available
      callbacks = []
-     if MLFLOW_CALLBACK_AVAILABLE and MLflowCallback is not None:
+     if mlflow_callback_available and MLflowCallback is not None:
          try:
              mlflow_callback = MLflowCallback(
                  tracking_uri=mlflow.get_tracking_uri(),
@@ -743,7 +770,7 @@ def generate_eval_dataset_from_cache(
      num_negative_pairs: int = 50,
      paraphrase_model: LLMModel | str | None = None,
      dataset_name: str = "generated_eval_dataset",
- ) -> SemanticCacheEvalDataset:
+ ) -> ContextAwareCacheEvalDataset:
      """
      Generate an evaluation dataset from existing cache entries.
 
@@ -752,7 +779,8 @@ def generate_eval_dataset_from_cache(
      Args:
          cache_entries: List of cache entries with 'question', 'conversation_context',
-             'question_embedding', and 'context_embedding' keys
+             'question_embedding', and 'context_embedding' keys. Use cache.get_entries()
+             with include_embeddings=True to retrieve these.
          embedding_model: Model for generating embeddings for paraphrased questions
          num_positive_pairs: Number of positive (matching) pairs to generate
          num_negative_pairs: Number of negative (non-matching) pairs to generate
@@ -760,7 +788,19 @@ def generate_eval_dataset_from_cache(
          dataset_name: Name for the generated dataset
 
      Returns:
-         SemanticCacheEvalDataset with generated entries
+         ContextAwareCacheEvalDataset with generated entries
+
+     Example:
+         # Get entries from cache with embeddings
+         entries = cache_service.get_entries(include_embeddings=True, limit=100)
+
+         # Generate evaluation dataset
+         eval_dataset = generate_eval_dataset_from_cache(
+             cache_entries=entries,
+             num_positive_pairs=50,
+             num_negative_pairs=50,
+             dataset_name="my_cache_eval",
+         )
      """
      import random
 
@@ -787,7 +827,7 @@ def generate_eval_dataset_from_cache(
      )
      chat = para_model.as_chat_model()
 
-     entries: list[SemanticCacheEvalEntry] = []
+     entries: list[ContextAwareCacheEvalEntry] = []
 
      # Generate positive pairs (paraphrases)
      logger.info(
@@ -824,7 +864,7 @@ Rephrased question:"""
          )
 
          entries.append(
-             SemanticCacheEvalEntry(
+             ContextAwareCacheEvalEntry(
                  question=paraphrased_question,
                  question_embedding=para_q_emb,
                  context=original_context,
@@ -855,7 +895,7 @@ Rephrased question:"""
 
          # Use entry1 as the "question" and entry2 as the "cached" entry
          entries.append(
-             SemanticCacheEvalEntry(
+             ContextAwareCacheEvalEntry(
                  question=entry1.get("question", ""),
                  question_embedding=entry1.get("question_embedding", []),
                  context=entry1.get("conversation_context", ""),
@@ -876,7 +916,7 @@ Rephrased question:"""
          negative_pairs=sum(1 for e in entries if e.expected_match is False),
      )
 
-     return SemanticCacheEvalDataset(
+     return ContextAwareCacheEvalDataset(
          name=dataset_name,
          entries=entries,
          description=f"Generated from {len(cache_entries)} cache entries",
@@ -849,6 +849,144 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
          }
          return {}
 
+     def get_entries(
+         self,
+         limit: int | None = None,
+         offset: int | None = None,
+         include_embeddings: bool = False,
+         conversation_id: str | None = None,
+         created_after: datetime | None = None,
+         created_before: datetime | None = None,
+         question_contains: str | None = None,
+     ) -> list[dict[str, Any]]:
+         """
+         Get cache entries with optional filtering.
+
+         This method retrieves cache entries for inspection, debugging, or
+         generating evaluation datasets for threshold optimization.
+
+         Args:
+             limit: Maximum number of entries to return (None = no limit)
+             offset: Number of entries to skip for pagination (None = 0)
+             include_embeddings: Whether to include embedding vectors in results.
+                 Embeddings are large, so set False for general inspection.
+             conversation_id: Filter by conversation ID (None = all conversations)
+             created_after: Only entries created after this time (None = no filter)
+             created_before: Only entries created before this time (None = no filter)
+             question_contains: Case-insensitive text search on question field
+
+         Returns:
+             List of cache entry dicts. See base class for full key documentation.
+
+         Example:
+             # Get entries with embeddings for evaluation dataset generation
+             entries = cache.get_entries(include_embeddings=True, limit=100)
+             eval_dataset = generate_eval_dataset_from_cache(entries)
+         """
+         self._setup()
+
+         # Build column list
+         base_columns = [
+             "id",
+             "question",
+             "conversation_context",
+             "sql_query",
+             "description",
+             "conversation_id",
+             "created_at",
+         ]
+
+         if include_embeddings:
+             columns = base_columns + ["question_embedding", "context_embedding"]
+         else:
+             columns = base_columns
+
+         columns_str = ", ".join(columns)
+
+         # Build WHERE clause with parameters
+         where_clauses = ["genie_space_id = %s"]
+         params: list[Any] = [self.space_id]
+
+         if conversation_id is not None:
+             where_clauses.append("conversation_id = %s")
+             params.append(conversation_id)
+
+         if created_after is not None:
+             where_clauses.append("created_at > %s")
+             params.append(created_after)
+
+         if created_before is not None:
+             where_clauses.append("created_at < %s")
+             params.append(created_before)
+
+         if question_contains is not None:
+             where_clauses.append("question ILIKE %s")
+             params.append(f"%{question_contains}%")
+
+         where_str = " AND ".join(where_clauses)
+
+         # Build full query
+         query = f"""
+             SELECT {columns_str}
+             FROM {self.table_name}
+             WHERE {where_str}
+             ORDER BY created_at DESC
+         """
+
+         if limit is not None:
+             query += f" LIMIT {int(limit)}"
+
+         if offset is not None:
+             query += f" OFFSET {int(offset)}"
+
+         # Execute query
+         with self._pool.connection() as conn:
+             with conn.cursor() as cur:
+                 cur.execute(query, params)
+                 rows = cur.fetchall()
+
+         entries: list[dict[str, Any]] = []
+         for row in rows:
+             entry: dict[str, Any] = {
+                 "id": row.get("id"),
+                 "question": row.get("question"),
+                 "conversation_context": row.get("conversation_context"),
+                 "sql_query": row.get("sql_query"),
+                 "description": row.get("description"),
+                 "conversation_id": row.get("conversation_id"),
+                 "created_at": row.get("created_at"),
+             }
+
+             if include_embeddings:
+                 # Convert pgvector to list
+                 q_emb = row.get("question_embedding")
+                 c_emb = row.get("context_embedding")
+                 entry["question_embedding"] = (
+                     list(q_emb) if q_emb is not None else []
+                 )
+                 entry["context_embedding"] = (
+                     list(c_emb) if c_emb is not None else []
+                 )
+
+             entries.append(entry)
+
+         logger.debug(
+             "Retrieved cache entries",
+             layer=self.name,
+             count=len(entries),
+             include_embeddings=include_embeddings,
+             filters={
+                 "conversation_id": conversation_id,
+                 "created_after": str(created_after) if created_after else None,
+                 "created_before": (
+                     str(created_before) if created_before else None
+                 ),
+                 "question_contains": question_contains,
+             },
+         )
+
+         return entries
+
      def from_space(
          self,
          space_id: str | None = None,
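With get_entries in place, this release supports an end-to-end tuning loop from a live cache. A sketch using only names and parameters shown in this diff; cache_service stands in for a configured PostgresContextAwareGenieService instance:

    from dao_ai.genie.cache.context_aware.optimization import (
        generate_eval_dataset_from_cache,
        optimize_context_aware_cache_thresholds,
    )

    # Pull recent entries with embeddings included
    entries = cache_service.get_entries(include_embeddings=True, limit=100)

    eval_dataset = generate_eval_dataset_from_cache(
        cache_entries=entries,
        num_positive_pairs=50,
        num_negative_pairs=50,
        dataset_name="my_cache_eval",
    )

    result = optimize_context_aware_cache_thresholds(
        dataset=eval_dataset,
        judge_model="databricks-meta-llama-3-3-70b-instruct",
        n_trials=50,
        metric="f1",
    )
    if result.improved:
        print(f"New thresholds: {result.optimized_thresholds}")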
@@ -857,6 +995,7 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
          from_datetime: datetime | None = None,
          to_datetime: datetime | None = None,
          max_messages: int | None = None,
+         max_conversations: int | None = None,
      ) -> Self:
          """Populate cache from existing Genie space conversations.
 
@@ -872,6 +1011,7 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
              from_datetime: Only include messages after this time
              to_datetime: Only include messages before this time
              max_messages: Limit to last N messages (most recent first)
+             max_conversations: Limit to N conversations (stops pagination after reaching limit)
 
          Returns:
              self for method chaining
@@ -916,8 +1056,21 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
                  break
 
              if response.conversations is None:
+                 logger.debug(
+                     "No conversations in response",
+                     layer=self.name,
+                     space_id=target_space_id,
+                 )
                  break
 
+             logger.debug(
+                 "Fetched conversations page",
+                 layer=self.name,
+                 conversations_in_page=len(response.conversations),
+                 total_conversations_so_far=stats["conversations_processed"],
+                 has_next_page=response.next_page_token is not None,
+             )
+
              for conversation in response.conversations:
                  if conversation.conversation_id is None:
                      continue
@@ -945,11 +1098,35 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
                      if max_messages and len(all_messages) >= max_messages:
                          break
 
+                 if (
+                     max_conversations
+                     and stats["conversations_processed"] >= max_conversations
+                 ):
+                     break
+
              if max_messages and len(all_messages) >= max_messages:
                  break
 
+             if (
+                 max_conversations
+                 and stats["conversations_processed"] >= max_conversations
+             ):
+                 logger.debug(
+                     "Reached max_conversations limit",
+                     layer=self.name,
+                     max_conversations=max_conversations,
+                     total_conversations=stats["conversations_processed"],
+                 )
+                 break
+
              page_token = response.next_page_token
              if page_token is None:
+                 logger.debug(
+                     "No more pages to fetch",
+                     layer=self.name,
+                     total_conversations=stats["conversations_processed"],
+                     total_messages=len(all_messages),
+                 )
                  break
 
          # Sort and limit
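The new max_conversations limit bounds cache warm-up independently of max_messages: pagination stops at whichever cap is reached first, and the added debug logs record why. A sketch of the call using the from_space keywords shown in this diff; the space_id value is illustrative:

    from datetime import datetime, timedelta, timezone

    cache_service.from_space(
        space_id="my-genie-space-id",  # illustrative value
        from_datetime=datetime.now(timezone.utc) - timedelta(days=30),
        max_messages=500,       # cap on total messages ingested
        max_conversations=50,   # new in 0.1.21: cap on conversations paged through
    )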
@@ -1,5 +1,5 @@
  # DAO AI Middleware Module
- # This module provides middleware implementations compatible with LangChain v1's create_agent
+ # Middleware implementations compatible with LangChain v1's create_agent
 
  # Re-export LangChain built-in middleware
  from langchain.agents.middleware import (
@@ -82,6 +82,10 @@ from dao_ai.middleware.summarization import (
      create_summarization_middleware,
  )
  from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
+ from dao_ai.middleware.tool_call_observability import (
+     ToolCallObservabilityMiddleware,
+     create_tool_call_observability_middleware,
+ )
  from dao_ai.middleware.tool_retry import create_tool_retry_middleware
  from dao_ai.middleware.tool_selector import create_llm_tool_selector_middleware
 
@@ -160,4 +164,7 @@ __all__ = [
      "create_clear_tool_uses_edit",
      # PII middleware factory functions
      "create_pii_middleware",
+     # Tool call observability middleware
+     "ToolCallObservabilityMiddleware",
+     "create_tool_call_observability_middleware",
  ]
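The new names are re-exported from dao_ai.middleware, following the create_*_middleware factory convention used by the other entries in this module. A minimal import sketch; the factory's parameters are not shown in this diff, so the no-argument call is an assumption:

    from dao_ai.middleware import (
        ToolCallObservabilityMiddleware,
        create_tool_call_observability_middleware,
    )

    # Assumed default construction; check the factory signature for real options
    middleware = create_tool_call_observability_middleware()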