dao-ai 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -20,8 +20,8 @@ from typing import (
 )
 
 if TYPE_CHECKING:
-    from dao_ai.genie.cache.optimization import (
-        SemanticCacheEvalDataset,
+    from dao_ai.genie.cache.context_aware.optimization import (
+        ContextAwareCacheEvalDataset,
         ThresholdOptimizationResult,
     )
     from dao_ai.state import Context
@@ -2316,6 +2316,7 @@ class FunctionType(str, Enum):
     FACTORY = "factory"
     UNITY_CATALOG = "unity_catalog"
     MCP = "mcp"
+    INLINE = "inline"
 
 
 class HumanInTheLoopModel(BaseModel):
@@ -2417,6 +2418,72 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
         return self
 
 
+class InlineFunctionModel(BaseFunctionModel):
+    """
+    Inline function model for defining tool code directly in YAML configuration.
+
+    This allows you to define simple tools without creating separate Python files.
+    The code should define a function decorated with @tool from langchain.tools.
+
+    Example YAML:
+        tools:
+          calculator:
+            name: calculator
+            function:
+              type: inline
+              code: |
+                from langchain.tools import tool
+
+                @tool
+                def calculator(expression: str) -> str:
+                    '''Evaluate a mathematical expression.'''
+                    return str(eval(expression))
+
+    The code block must:
+    - Import @tool from langchain.tools
+    - Define exactly one function decorated with @tool
+    - The function name becomes the tool name
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    type: Literal[FunctionType.INLINE] = FunctionType.INLINE
+    code: str = Field(
+        ...,
+        description="Python code defining a tool function decorated with @tool",
+    )
+
+    def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
+        """Execute the inline code and return the tool(s) defined in it."""
+        from langchain_core.tools import BaseTool
+
+        # Create a namespace for executing the code
+        namespace: dict[str, Any] = {}
+
+        # Execute the code in the namespace
+        try:
+            exec(self.code, namespace)
+        except Exception as e:
+            raise ValueError(f"Failed to execute inline tool code: {e}") from e
+
+        # Find all tools (functions decorated with @tool) in the namespace
+        tools: list[RunnableLike] = []
+        for name, obj in namespace.items():
+            if isinstance(obj, BaseTool):
+                tools.append(obj)
+
+        if not tools:
+            raise ValueError(
+                "Inline code must define at least one function decorated with @tool. "
+                "Make sure to import and use: from langchain.tools import tool"
+            )
+
+        logger.debug(
+            "Created inline tools",
+            tool_names=[t.name for t in tools if hasattr(t, "name")],
+        )
+        return tools
+
+
 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
     STDIO = "stdio"
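For orientation, here is a minimal sketch of exercising the new inline function type directly in Python rather than through YAML. It assumes BaseFunctionModel adds no required fields beyond what this diff shows; in practice the model is normally built by dao-ai's config loader.

    # Hedged sketch (not from the diff): constructing InlineFunctionModel directly.
    from dao_ai.config import InlineFunctionModel

    model = InlineFunctionModel(
        code=(
            "from langchain.tools import tool\n"
            "\n"
            "@tool\n"
            "def add(a: int, b: int) -> int:\n"
            "    '''Add two integers.'''\n"
            "    return a + b\n"
        )
    )

    tools = model.as_tools()
    print(tools[0].name)                      # "add"
    print(tools[0].invoke({"a": 2, "b": 3}))  # 5

Note that as_tools() collects every BaseTool instance left in the exec namespace, so the "exactly one function" rule in the docstring is guidance rather than an enforced constraint.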
@@ -2737,6 +2804,7 @@ AnyTool: TypeAlias = (
     Union[
         PythonFunctionModel,
         FactoryFunctionModel,
+        InlineFunctionModel,
         UnityCatalogFunctionModel,
         McpFunctionModel,
     ]
@@ -3659,20 +3727,25 @@ class OptimizationsModel(BaseModel):
     prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
         default_factory=dict
     )
+    cache_threshold_optimizations: dict[str, "ContextAwareCacheOptimizationModel"] = (
+        Field(default_factory=dict)
+    )
 
-    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, Any]:
         """
-        Optimize all prompts in this configuration.
+        Optimize all prompts and cache thresholds in this configuration.
 
         This method:
         1. Ensures all training datasets are created/registered in MLflow
         2. Runs each prompt optimization
+        3. Runs each cache threshold optimization
 
         Args:
            w: Optional WorkspaceClient for Databricks operations
 
         Returns:
-            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+            dict[str, Any]: Dictionary with 'prompts' and 'cache_thresholds' keys
+                containing the respective optimization results
         """
         # First, ensure all training datasets are created/registered in MLflow
         logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
@@ -3680,15 +3753,21 @@ class OptimizationsModel(BaseModel):
             logger.debug(f"Creating/updating dataset: {dataset_name}")
             dataset_model.as_dataset()
 
-        # Run optimizations
-        results: dict[str, PromptModel] = {}
+        # Run prompt optimizations
+        prompt_results: dict[str, PromptModel] = {}
         for name, optimization in self.prompt_optimizations.items():
-            results[name] = optimization.optimize(w)
-        return results
+            prompt_results[name] = optimization.optimize(w)
+
+        # Run cache threshold optimizations
+        cache_results: dict[str, Any] = {}
+        for name, optimization in self.cache_threshold_optimizations.items():
+            cache_results[name] = optimization.optimize(w)
+
+        return {"prompts": prompt_results, "cache_thresholds": cache_results}
 
 
-class SemanticCacheEvalEntryModel(BaseModel):
-    """Single evaluation entry for semantic cache threshold optimization.
+class ContextAwareCacheEvalEntryModel(BaseModel):
+    """Single evaluation entry for context-aware cache threshold optimization.
 
     Represents a pair of question/context combinations to evaluate
     whether the cache should return a hit or miss.
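Because optimize() (changed above) now returns a two-key dictionary instead of a flat prompt mapping, callers upgrading from 0.1.20 need to unpack it. A minimal sketch, assuming the parent config object exposes this OptimizationsModel as `config.optimizations` (an illustrative name, not confirmed by this diff):

    # Hedged sketch: consuming the new optimize() return shape in 0.1.21.
    results = config.optimizations.optimize()

    for name, prompt in results["prompts"].items():
        print(f"prompt optimization {name!r} -> {prompt}")

    for name, threshold_result in results["cache_thresholds"].items():
        print(f"cache threshold optimization {name!r} -> {threshold_result}")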
@@ -3718,8 +3797,8 @@ class SemanticCacheEvalEntryModel(BaseModel):
     expected_match: Optional[bool] = None  # None = use LLM judge
 
 
-class SemanticCacheEvalDatasetModel(BaseModel):
-    """Dataset for semantic cache threshold optimization.
+class ContextAwareCacheEvalDatasetModel(BaseModel):
+    """Dataset for context-aware cache threshold optimization.
 
     Contains pairs of questions/contexts to evaluate whether thresholds
     correctly identify semantic matches.
@@ -3736,17 +3815,17 @@ class SemanticCacheEvalDatasetModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: str = ""
-    entries: list[SemanticCacheEvalEntryModel] = Field(default_factory=list)
+    entries: list[ContextAwareCacheEvalEntryModel] = Field(default_factory=list)
 
-    def as_eval_dataset(self) -> "SemanticCacheEvalDataset":
+    def as_eval_dataset(self) -> "ContextAwareCacheEvalDataset":
         """Convert to internal evaluation dataset format."""
-        from dao_ai.genie.cache.optimization import (
-            SemanticCacheEvalDataset,
-            SemanticCacheEvalEntry,
+        from dao_ai.genie.cache.context_aware.optimization import (
+            ContextAwareCacheEvalDataset,
+            ContextAwareCacheEvalEntry,
         )
 
         entries = [
-            SemanticCacheEvalEntry(
+            ContextAwareCacheEvalEntry(
                 question=e.question,
                 question_embedding=e.question_embedding,
                 context=e.context,
@@ -3760,33 +3839,35 @@ class SemanticCacheEvalDatasetModel(BaseModel):
             for e in self.entries
         ]
 
-        return SemanticCacheEvalDataset(
+        return ContextAwareCacheEvalDataset(
             name=self.name,
             entries=entries,
             description=self.description,
         )
 
 
-class SemanticCacheThresholdOptimizationModel(BaseModel):
-    """Configuration for semantic cache threshold optimization.
+class ContextAwareCacheOptimizationModel(BaseModel):
+    """Configuration for context-aware cache threshold optimization.
 
     Uses Optuna Bayesian optimization to find optimal threshold values
     that maximize cache hit accuracy (F1 score by default).
 
     Example:
-        threshold_optimization:
-          name: optimize_cache_thresholds
-          cache_parameters: *my_cache_params
-          dataset: *my_eval_dataset
-          judge_model: databricks-meta-llama-3-3-70b-instruct
-          n_trials: 50
-          metric: f1
+        optimizations:
+          cache_threshold_optimizations:
+            my_optimization:
+              name: optimize_cache_thresholds
+              cache_parameters: *my_cache_params
+              dataset: *my_eval_dataset
+              judge_model: databricks-meta-llama-3-3-70b-instruct
+              n_trials: 50
+              metric: f1
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     cache_parameters: Optional[GenieContextAwareCacheParametersModel] = None
-    dataset: SemanticCacheEvalDatasetModel
+    dataset: ContextAwareCacheEvalDatasetModel
     judge_model: Optional[LLMModel | str] = "databricks-meta-llama-3-3-70b-instruct"
     n_trials: int = 50
     metric: Literal["f1", "precision", "recall", "fbeta"] = "f1"
@@ -3805,9 +3886,9 @@ class SemanticCacheThresholdOptimizationModel(BaseModel):
         Returns:
             ThresholdOptimizationResult with optimized thresholds
         """
-        from dao_ai.genie.cache.optimization import (
+        from dao_ai.genie.cache.context_aware.optimization import (
             ThresholdOptimizationResult,
-            optimize_semantic_cache_thresholds,
+            optimize_context_aware_cache_thresholds,
         )
 
         # Convert dataset
@@ -3831,7 +3912,7 @@ class SemanticCacheThresholdOptimizationModel(BaseModel):
         else:
             judge_model_name = "databricks-meta-llama-3-3-70b-instruct"
 
-        result: ThresholdOptimizationResult = optimize_semantic_cache_thresholds(
+        result: ThresholdOptimizationResult = optimize_context_aware_cache_thresholds(
             dataset=eval_dataset,
             original_thresholds=original_thresholds,
             judge_model=judge_model_name,
dao_ai/genie/cache/__init__.py CHANGED
@@ -34,17 +34,19 @@ from dao_ai.genie.cache.context_aware import (
     PersistentContextAwareGenieCacheService,
     PostgresContextAwareGenieService,
 )
-from dao_ai.genie.cache.core import execute_sql_via_warehouse
-from dao_ai.genie.cache.lru import LRUCacheService
-from dao_ai.genie.cache.optimization import (
-    SemanticCacheEvalDataset,
-    SemanticCacheEvalEntry,
+
+# Re-export optimization from context_aware for backwards compatibility
+from dao_ai.genie.cache.context_aware.optimization import (
+    ContextAwareCacheEvalDataset,
+    ContextAwareCacheEvalEntry,
     ThresholdOptimizationResult,
     clear_judge_cache,
     generate_eval_dataset_from_cache,
-    optimize_semantic_cache_thresholds,
+    optimize_context_aware_cache_thresholds,
     semantic_match_judge,
 )
+from dao_ai.genie.cache.core import execute_sql_via_warehouse
+from dao_ai.genie.cache.lru import LRUCacheService
 
 __all__ = [
     # Base types
@@ -60,11 +62,11 @@ __all__ = [
     "LRUCacheService",
     "PostgresContextAwareGenieService",
     # Optimization
-    "SemanticCacheEvalDataset",
-    "SemanticCacheEvalEntry",
+    "ContextAwareCacheEvalDataset",
+    "ContextAwareCacheEvalEntry",
     "ThresholdOptimizationResult",
     "clear_judge_cache",
     "generate_eval_dataset_from_cache",
-    "optimize_semantic_cache_thresholds",
+    "optimize_context_aware_cache_thresholds",
     "semantic_match_judge",
 ]
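The comment in the hunk above says the optimization symbols are re-exported from the package root for backwards compatibility, so both import paths should resolve to the same objects. A quick sanity-check sketch, assuming these are eager module-level imports (as the `__all__` entries suggest) rather than TYPE_CHECKING-only imports:

    # Hedged sketch: old-style (package root) and new-style (context_aware) imports.
    from dao_ai.genie.cache import optimize_context_aware_cache_thresholds as via_root
    from dao_ai.genie.cache.context_aware.optimization import (
        optimize_context_aware_cache_thresholds as via_module,
    )

    assert via_root is via_module  # same function object, just re-exported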
dao_ai/genie/cache/context_aware/__init__.py CHANGED
@@ -12,10 +12,23 @@ Available implementations:
 Base classes:
 - ContextAwareGenieService: Abstract base for all context-aware cache implementations
 - PersistentContextAwareGenieCacheService: Abstract base for database-backed implementations
+
+Optimization:
+- optimize_context_aware_cache_thresholds: Tune cache thresholds using Bayesian optimization
+- generate_eval_dataset_from_cache: Generate evaluation datasets from cache entries
 """
 
 from dao_ai.genie.cache.context_aware.base import ContextAwareGenieService
 from dao_ai.genie.cache.context_aware.in_memory import InMemoryContextAwareGenieService
+from dao_ai.genie.cache.context_aware.optimization import (
+    ContextAwareCacheEvalDataset,
+    ContextAwareCacheEvalEntry,
+    ThresholdOptimizationResult,
+    clear_judge_cache,
+    generate_eval_dataset_from_cache,
+    optimize_context_aware_cache_thresholds,
+    semantic_match_judge,
+)
 from dao_ai.genie.cache.context_aware.persistent import (
     PersistentContextAwareGenieCacheService,
 )
@@ -28,4 +41,12 @@ __all__ = [
     # Implementations
     "InMemoryContextAwareGenieService",
     "PostgresContextAwareGenieService",
+    # Optimization
+    "ContextAwareCacheEvalDataset",
+    "ContextAwareCacheEvalEntry",
+    "ThresholdOptimizationResult",
+    "clear_judge_cache",
+    "generate_eval_dataset_from_cache",
+    "optimize_context_aware_cache_thresholds",
+    "semantic_match_judge",
 ]
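Taken together, these exports support a workflow of harvesting cache entries into an evaluation dataset and tuning thresholds against it. A sketch under stated assumptions: `cache` is a concrete ContextAwareGenieService instance and `original_thresholds` holds the current threshold values (its exact type is not shown in this diff); the keyword arguments mirror the call made in config.py above.

    from dao_ai.genie.cache.context_aware import (
        generate_eval_dataset_from_cache,
        optimize_context_aware_cache_thresholds,
    )

    # Harvest cached question/context pairs, embeddings included ...
    entries = cache.get_entries(include_embeddings=True, limit=100)
    dataset = generate_eval_dataset_from_cache(entries)

    # ... then search for thresholds that maximize F1 against the dataset.
    result = optimize_context_aware_cache_thresholds(
        dataset=dataset,
        original_thresholds=original_thresholds,  # assumed: current thresholds
        judge_model="databricks-meta-llama-3-3-70b-instruct",
    )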
dao_ai/genie/cache/context_aware/base.py CHANGED
@@ -15,12 +15,13 @@ Subclasses must implement storage-specific methods:
 - invalidate_expired(): Remove expired entries
 - clear(): Clear all entries for space
 - stats(): Return cache statistics
+- get_entries(): Retrieve cache entries with filtering
 """
 
 from __future__ import annotations
 
 from abc import abstractmethod
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Any, Self, TypeVar
 
 import mlflow
@@ -315,6 +316,58 @@ class ContextAwareGenieService(GenieServiceBase):
         """
         pass
 
+    @abstractmethod
+    def get_entries(
+        self,
+        limit: int | None = None,
+        offset: int | None = None,
+        include_embeddings: bool = False,
+        conversation_id: str | None = None,
+        created_after: datetime | None = None,
+        created_before: datetime | None = None,
+        question_contains: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Get cache entries with optional filtering.
+
+        This method retrieves cache entries for inspection, debugging, or
+        generating evaluation datasets for threshold optimization.
+
+        Args:
+            limit: Maximum number of entries to return (None = no limit)
+            offset: Number of entries to skip for pagination (None = 0)
+            include_embeddings: Whether to include embedding vectors in results.
+                Embeddings are large, so set False for general inspection.
+            conversation_id: Filter by conversation ID (None = all conversations)
+            created_after: Only entries created after this time (None = no filter)
+            created_before: Only entries created before this time (None = no filter)
+            question_contains: Case-insensitive text search on question field
+
+        Returns:
+            List of cache entry dicts with keys:
+            - id: Cache entry ID (int for persistent caches, None for in-memory)
+            - question: The cached question text
+            - conversation_context: Prior conversation context string
+            - sql_query: The cached SQL query
+            - description: Query description
+            - conversation_id: The conversation ID
+            - created_at: Entry creation timestamp (datetime)
+            - question_embedding: (only if include_embeddings=True)
+            - context_embedding: (only if include_embeddings=True)
+
+        Example:
+            # Get recent entries for inspection
+            entries = cache.get_entries(limit=10)
+
+            # Get entries with embeddings for evaluation dataset
+            entries = cache.get_entries(include_embeddings=True, limit=100)
+            eval_dataset = generate_eval_dataset_from_cache(entries)
+
+            # Search for specific questions
+            entries = cache.get_entries(question_contains="sales")
+        """
+        pass
+
     def stats(self) -> dict[str, Any]:
         """
         Template method for returning cache statistics.
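Beyond the docstring examples, the filter parameters compose. A hedged sketch pulling one conversation's recent entries, assuming `cache` is any concrete subclass instance and that created_at timestamps are naive datetimes comparable with datetime.now():

    from datetime import datetime, timedelta

    recent = cache.get_entries(
        conversation_id="conv-123",  # hypothetical conversation ID
        created_after=datetime.now() - timedelta(days=1),
        question_contains="revenue",
        limit=50,
    )
    for e in recent:
        print(e["created_at"], e["question"], "->", e["sql_query"])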
dao_ai/genie/cache/context_aware/in_memory.py CHANGED
@@ -607,3 +607,115 @@ class InMemoryContextAwareGenieService(ContextAwareGenieService):
     def _get_additional_stats(self) -> dict[str, Any]:
         """Add capacity info to stats."""
         return {"capacity": self.parameters.capacity}
+
+    def get_entries(
+        self,
+        limit: int | None = None,
+        offset: int | None = None,
+        include_embeddings: bool = False,
+        conversation_id: str | None = None,
+        created_after: datetime | None = None,
+        created_before: datetime | None = None,
+        question_contains: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Get cache entries with optional filtering.
+
+        This method retrieves cache entries for inspection, debugging, or
+        generating evaluation datasets for threshold optimization.
+
+        Args:
+            limit: Maximum number of entries to return (None = no limit)
+            offset: Number of entries to skip for pagination (None = 0)
+            include_embeddings: Whether to include embedding vectors in results.
+                Embeddings are large, so set False for general inspection.
+            conversation_id: Filter by conversation ID (None = all conversations)
+            created_after: Only entries created after this time (None = no filter)
+            created_before: Only entries created before this time (None = no filter)
+            question_contains: Case-insensitive text search on question field
+
+        Returns:
+            List of cache entry dicts. See base class for full key documentation.
+
+        Example:
+            # Get entries with embeddings for evaluation dataset generation
+            entries = cache.get_entries(include_embeddings=True, limit=100)
+            eval_dataset = generate_eval_dataset_from_cache(entries)
+        """
+        self._setup()
+
+        with self._lock:
+            # Filter entries for this space
+            filtered_entries: list[InMemoryCacheEntry] = []
+
+            for entry in self._cache:
+                # Filter by space_id
+                if entry.genie_space_id != self.space_id:
+                    continue
+
+                # Filter by conversation_id
+                if (
+                    conversation_id is not None
+                    and entry.conversation_id != conversation_id
+                ):
+                    continue
+
+                # Filter by created_after
+                if created_after is not None and entry.created_at <= created_after:
+                    continue
+
+                # Filter by created_before
+                if created_before is not None and entry.created_at >= created_before:
+                    continue
+
+                # Filter by question_contains (case-insensitive)
+                if question_contains is not None:
+                    if question_contains.lower() not in entry.question.lower():
+                        continue
+
+                filtered_entries.append(entry)
+
+            # Sort by created_at descending (most recent first)
+            filtered_entries.sort(key=lambda e: e.created_at, reverse=True)
+
+            # Apply offset
+            if offset is not None and offset > 0:
+                filtered_entries = filtered_entries[offset:]
+
+            # Apply limit
+            if limit is not None:
+                filtered_entries = filtered_entries[:limit]
+
+            # Convert to dicts
+            entries: list[dict[str, Any]] = []
+            for entry in filtered_entries:
+                result: dict[str, Any] = {
+                    "id": None,  # In-memory caches don't have database IDs
+                    "question": entry.question,
+                    "conversation_context": entry.conversation_context,
+                    "sql_query": entry.sql_query,
+                    "description": entry.description,
+                    "conversation_id": entry.conversation_id,
+                    "created_at": entry.created_at,
+                }
+
+                if include_embeddings:
+                    result["question_embedding"] = entry.question_embedding
+                    result["context_embedding"] = entry.context_embedding
+
+                entries.append(result)
+
+            logger.debug(
+                "Retrieved cache entries",
+                layer=self.name,
+                count=len(entries),
+                include_embeddings=include_embeddings,
+                filters={
+                    "conversation_id": conversation_id,
+                    "created_after": str(created_after) if created_after else None,
+                    "created_before": str(created_before) if created_before else None,
+                    "question_contains": question_contains,
+                },
+            )
+
+        return entries
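Because this implementation sorts by created_at descending before applying offset and limit, the two parameters give simple most-recent-first pagination. A small sketch; note that inserts between calls can shift page boundaries, since each call re-snapshots the cache under the lock:

    # Hedged sketch: paging through entries 25 at a time.
    page_size = 25
    page = 0
    while True:
        batch = cache.get_entries(limit=page_size, offset=page * page_size)
        if not batch:
            break
        for e in batch:
            print(e["question"])
        page += 1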