dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +37 -7
- dao_ai/config.py +265 -10
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +36 -9
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +52 -0
- dao_ai/genie/cache/context_aware/base.py +1204 -0
- dao_ai/genie/cache/{in_memory_semantic.py → context_aware/in_memory.py} +233 -383
- dao_ai/genie/cache/context_aware/optimization.py +930 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1343 -0
- dao_ai/genie/cache/lru.py +248 -70
- dao_ai/genie/core.py +235 -11
- dao_ai/middleware/__init__.py +8 -1
- dao_ai/middleware/tool_call_observability.py +227 -0
- dao_ai/nodes.py +4 -4
- dao_ai/tools/__init__.py +2 -2
- dao_ai/tools/genie.py +10 -10
- dao_ai/utils.py +7 -3
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/METADATA +1 -1
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/RECORD +24 -19
- dao_ai/genie/cache/semantic.py +0 -1004
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.19.dist-info → dao_ai-0.1.21.dist-info}/licenses/LICENSE +0 -0
dao_ai/cli.py
CHANGED
@@ -2,6 +2,7 @@ import argparse
 import getpass
 import json
 import os
+import signal
 import subprocess
 import sys
 import traceback
@@ -454,6 +455,18 @@ def handle_chat_command(options: Namespace) -> None:
     """Interactive chat REPL with the DAO AI system with Human-in-the-Loop support."""
     logger.debug("Starting chat session with DAO AI system...")

+    # Set up signal handler for clean Ctrl+C handling
+    def signal_handler(sig: int, frame: Any) -> None:
+        try:
+            print("\n\n👋 Chat session interrupted. Goodbye!")
+            sys.stdout.flush()
+        except Exception:
+            pass
+        sys.exit(0)
+
+    # Store original handler and set our handler
+    original_handler = signal.signal(signal.SIGINT, signal_handler)
+
     try:
         # Set default user_id if not provided
         if options.user_id is None:
@@ -667,6 +680,12 @@ def handle_chat_command(options: Namespace) -> None:

             try:
                 result = loop.run_until_complete(_invoke_with_hitl())
+            except KeyboardInterrupt:
+                # Re-raise to be caught by outer handler
+                raise
+            except asyncio.CancelledError:
+                # Treat cancellation like KeyboardInterrupt
+                raise KeyboardInterrupt
             except Exception as e:
                 logger.error(f"Error invoking graph: {e}")
                 print(f"\n❌ Error: {e}")
@@ -732,23 +751,34 @@ def handle_chat_command(options: Namespace) -> None:
                 logger.error(f"Response processing error: {e}")
                 logger.error(f"Stack trace: {traceback.format_exc()}")

-            except EOFError:
-                # Handle Ctrl-D
-                print
-
-
-
-
+            except (EOFError, KeyboardInterrupt):
+                # Handle Ctrl-D (EOF) or Ctrl-C (interrupt)
+                # Use try/except for print in case stdout is closed
+                try:
+                    print("\n\n👋 Goodbye! Chat session ended.")
+                    sys.stdout.flush()
+                except Exception:
+                    pass
                 break
             except Exception as e:
                 print(f"\n❌ Error: {e}")
                 logger.error(f"Chat error: {e}")
                 traceback.print_exc()

+    except (EOFError, KeyboardInterrupt):
+        # Handle interrupts during initialization
+        try:
+            print("\n\n👋 Chat session interrupted. Goodbye!")
+            sys.stdout.flush()
+        except Exception:
+            pass
     except Exception as e:
         logger.error(f"Failed to initialize chat session: {e}")
         print(f"❌ Failed to start chat session: {e}")
         sys.exit(1)
+    finally:
+        # Restore original signal handler
+        signal.signal(signal.SIGINT, original_handler)


 def handle_schema_command(options: Namespace) -> None:
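The cli.py changes wrap the chat REPL in a SIGINT handler that prints a farewell message, exits cleanly, and restores the previous handler when the session ends. A minimal standalone sketch of that pattern (not dao-ai code; the prompt string and echo loop are placeholders):

    import signal
    import sys
    from types import FrameType
    from typing import Optional

    def _goodbye(sig: int, frame: Optional[FrameType]) -> None:
        # Mirror the diff: print a farewell, flush, and exit without a traceback
        try:
            print("\n\n👋 Chat session interrupted. Goodbye!")
            sys.stdout.flush()
        except Exception:
            pass
        sys.exit(0)

    original_handler = signal.signal(signal.SIGINT, _goodbye)
    try:
        while True:
            line = input("> ")      # Ctrl+C anywhere in here triggers _goodbye
            print(f"echo: {line}")
    except EOFError:                # Ctrl+D still ends the loop cleanly
        print("\n👋 Goodbye! Chat session ended.")
    finally:
        # Restore whatever handler was installed before the session
        signal.signal(signal.SIGINT, original_handler)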
dao_ai/config.py
CHANGED
@@ -20,6 +20,10 @@ from typing import (
 )

 if TYPE_CHECKING:
+    from dao_ai.genie.cache.context_aware.optimization import (
+        ContextAwareCacheEvalDataset,
+        ThresholdOptimizationResult,
+    )
     from dao_ai.state import Context

 from databricks.sdk import WorkspaceClient
@@ -1710,7 +1714,7 @@ class GenieLRUCacheParametersModel(BaseModel):
     warehouse: WarehouseModel


-class GenieSemanticCacheParametersModel(BaseModel):
+class GenieContextAwareCacheParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     time_to_live_seconds: int | None = (
         60 * 60 * 24
@@ -1728,10 +1732,21 @@ class GenieSemanticCacheParametersModel(BaseModel):
     database: DatabaseModel
     warehouse: WarehouseModel
     table_name: str = "genie_semantic_cache"
-    context_window_size: int =
+    context_window_size: int = 2 # Number of previous turns to include for context
     max_context_tokens: int = (
         2000 # Maximum context length to prevent extremely long embeddings
     )
+    # Prompt history configuration
+    # Prompt history is always enabled - it stores all user prompts to maintain
+    # conversation context for accurate semantic matching even when cache hits occur
+    prompt_history_table: str = "genie_prompt_history" # Table name for prompt history
+    max_prompt_history_length: int = 50 # Maximum prompts to keep per conversation
+    use_genie_api_for_history: bool = (
+        False # Fallback to Genie API if local history empty
+    )
+    prompt_history_ttl_seconds: int | None = (
+        None # TTL for prompts (None = use cache TTL)
+    )

     @model_validator(mode="after")
     def compute_and_validate_weights(self) -> Self:
@@ -1805,7 +1820,7 @@ class GenieInMemorySemanticCacheParametersModel(BaseModel):
     - Cache persistence across restarts is not required
     - Cache sizes are moderate (hundreds to low thousands of entries)

-    For multi-instance deployments or large cache sizes, use
+    For multi-instance deployments or large cache sizes, use GenieContextAwareCacheParametersModel
     with PostgreSQL backend instead.
     """

@@ -2301,6 +2316,7 @@ class FunctionType(str, Enum):
     FACTORY = "factory"
     UNITY_CATALOG = "unity_catalog"
     MCP = "mcp"
+    INLINE = "inline"


 class HumanInTheLoopModel(BaseModel):
@@ -2402,6 +2418,72 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
         return self


+class InlineFunctionModel(BaseFunctionModel):
+    """
+    Inline function model for defining tool code directly in YAML configuration.
+
+    This allows you to define simple tools without creating separate Python files.
+    The code should define a function decorated with @tool from langchain.tools.
+
+    Example YAML:
+        tools:
+          calculator:
+            name: calculator
+            function:
+              type: inline
+              code: |
+                from langchain.tools import tool
+
+                @tool
+                def calculator(expression: str) -> str:
+                    '''Evaluate a mathematical expression.'''
+                    return str(eval(expression))
+
+    The code block must:
+    - Import @tool from langchain.tools
+    - Define exactly one function decorated with @tool
+    - The function name becomes the tool name
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    type: Literal[FunctionType.INLINE] = FunctionType.INLINE
+    code: str = Field(
+        ...,
+        description="Python code defining a tool function decorated with @tool",
+    )
+
+    def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
+        """Execute the inline code and return the tool(s) defined in it."""
+        from langchain_core.tools import BaseTool
+
+        # Create a namespace for executing the code
+        namespace: dict[str, Any] = {}
+
+        # Execute the code in the namespace
+        try:
+            exec(self.code, namespace)
+        except Exception as e:
+            raise ValueError(f"Failed to execute inline tool code: {e}") from e
+
+        # Find all tools (functions decorated with @tool) in the namespace
+        tools: list[RunnableLike] = []
+        for name, obj in namespace.items():
+            if isinstance(obj, BaseTool):
+                tools.append(obj)
+
+        if not tools:
+            raise ValueError(
+                "Inline code must define at least one function decorated with @tool. "
+                "Make sure to import and use: from langchain.tools import tool"
+            )
+
+        logger.debug(
+            "Created inline tools",
+            tool_names=[t.name for t in tools if hasattr(t, "name")],
+        )
+        return tools
+
+
 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
     STDIO = "stdio"
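The new InlineFunctionModel loads tools by exec-ing the YAML-supplied code string into a fresh namespace and collecting every BaseTool instance it finds. A standalone sketch of that mechanism (assumes langchain and langchain-core are installed; the calculator code string is the docstring's own example, not a recommended production tool):

    import textwrap

    from langchain_core.tools import BaseTool

    # The YAML `code:` block from the InlineFunctionModel docstring, as a Python string.
    code = textwrap.dedent("""
        from langchain.tools import tool

        @tool
        def calculator(expression: str) -> str:
            '''Evaluate a mathematical expression.'''
            return str(eval(expression))
    """)

    namespace: dict = {}
    exec(code, namespace)  # run the supplied code in an isolated namespace
    tools = [obj for obj in namespace.values() if isinstance(obj, BaseTool)]
    print([t.name for t in tools])  # -> ['calculator']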
@@ -2722,6 +2804,7 @@ AnyTool: TypeAlias = (
     Union[
         PythonFunctionModel,
         FactoryFunctionModel,
+        InlineFunctionModel,
         UnityCatalogFunctionModel,
         McpFunctionModel,
     ]
@@ -3644,20 +3727,25 @@ class OptimizationsModel(BaseModel):
     prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
         default_factory=dict
     )
+    cache_threshold_optimizations: dict[str, "ContextAwareCacheOptimizationModel"] = (
+        Field(default_factory=dict)
+    )

-    def optimize(self, w: WorkspaceClient | None = None) -> dict[str,
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, Any]:
         """
-        Optimize all prompts in this configuration.
+        Optimize all prompts and cache thresholds in this configuration.

         This method:
         1. Ensures all training datasets are created/registered in MLflow
         2. Runs each prompt optimization
+        3. Runs each cache threshold optimization

         Args:
             w: Optional WorkspaceClient for Databricks operations

         Returns:
-            dict[str,
+            dict[str, Any]: Dictionary with 'prompts' and 'cache_thresholds' keys
+            containing the respective optimization results
         """
         # First, ensure all training datasets are created/registered in MLflow
         logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
@@ -3665,11 +3753,178 @@ class OptimizationsModel(BaseModel):
             logger.debug(f"Creating/updating dataset: {dataset_name}")
             dataset_model.as_dataset()

-        # Run optimizations
-
+        # Run prompt optimizations
+        prompt_results: dict[str, PromptModel] = {}
         for name, optimization in self.prompt_optimizations.items():
-
-
+            prompt_results[name] = optimization.optimize(w)
+
+        # Run cache threshold optimizations
+        cache_results: dict[str, Any] = {}
+        for name, optimization in self.cache_threshold_optimizations.items():
+            cache_results[name] = optimization.optimize(w)
+
+        return {"prompts": prompt_results, "cache_thresholds": cache_results}
+
+
+class ContextAwareCacheEvalEntryModel(BaseModel):
+    """Single evaluation entry for context-aware cache threshold optimization.
+
+    Represents a pair of question/context combinations to evaluate
+    whether the cache should return a hit or miss.
+
+    Example:
+        entry:
+          question: "What are total sales?"
+          question_embedding: [0.1, 0.2, ...] # Pre-computed
+          context: "Previous: Show me revenue"
+          context_embedding: [0.1, 0.2, ...]
+          cached_question: "Show total sales"
+          cached_question_embedding: [0.1, 0.2, ...]
+          cached_context: "Previous: Show me revenue"
+          cached_context_embedding: [0.1, 0.2, ...]
+          expected_match: true
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    question: str
+    question_embedding: list[float]
+    context: str = ""
+    context_embedding: list[float] = Field(default_factory=list)
+    cached_question: str
+    cached_question_embedding: list[float]
+    cached_context: str = ""
+    cached_context_embedding: list[float] = Field(default_factory=list)
+    expected_match: Optional[bool] = None # None = use LLM judge
+
+
+class ContextAwareCacheEvalDatasetModel(BaseModel):
+    """Dataset for context-aware cache threshold optimization.
+
+    Contains pairs of questions/contexts to evaluate whether thresholds
+    correctly identify semantic matches.
+
+    Example:
+        dataset:
+          name: my_cache_eval_dataset
+          description: "Evaluation data for cache tuning"
+          entries:
+            - question: "What are total sales?"
+              # ... entry fields
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    description: str = ""
+    entries: list[ContextAwareCacheEvalEntryModel] = Field(default_factory=list)
+
+    def as_eval_dataset(self) -> "ContextAwareCacheEvalDataset":
+        """Convert to internal evaluation dataset format."""
+        from dao_ai.genie.cache.context_aware.optimization import (
+            ContextAwareCacheEvalDataset,
+            ContextAwareCacheEvalEntry,
+        )
+
+        entries = [
+            ContextAwareCacheEvalEntry(
+                question=e.question,
+                question_embedding=e.question_embedding,
+                context=e.context,
+                context_embedding=e.context_embedding,
+                cached_question=e.cached_question,
+                cached_question_embedding=e.cached_question_embedding,
+                cached_context=e.cached_context,
+                cached_context_embedding=e.cached_context_embedding,
+                expected_match=e.expected_match,
+            )
+            for e in self.entries
+        ]
+
+        return ContextAwareCacheEvalDataset(
+            name=self.name,
+            entries=entries,
+            description=self.description,
+        )
+
+
+class ContextAwareCacheOptimizationModel(BaseModel):
+    """Configuration for context-aware cache threshold optimization.
+
+    Uses Optuna Bayesian optimization to find optimal threshold values
+    that maximize cache hit accuracy (F1 score by default).
+
+    Example:
+        optimizations:
+          cache_threshold_optimizations:
+            my_optimization:
+              name: optimize_cache_thresholds
+              cache_parameters: *my_cache_params
+              dataset: *my_eval_dataset
+              judge_model: databricks-meta-llama-3-3-70b-instruct
+              n_trials: 50
+              metric: f1
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    cache_parameters: Optional[GenieContextAwareCacheParametersModel] = None
+    dataset: ContextAwareCacheEvalDatasetModel
+    judge_model: Optional[LLMModel | str] = "databricks-meta-llama-3-3-70b-instruct"
+    n_trials: int = 50
+    metric: Literal["f1", "precision", "recall", "fbeta"] = "f1"
+    beta: float = 1.0 # For fbeta metric
+    seed: Optional[int] = None
+
+    def optimize(
+        self, w: WorkspaceClient | None = None
+    ) -> "ThresholdOptimizationResult":
+        """
+        Optimize semantic cache thresholds.
+
+        Args:
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
+
+        Returns:
+            ThresholdOptimizationResult with optimized thresholds
+        """
+        from dao_ai.genie.cache.context_aware.optimization import (
+            ThresholdOptimizationResult,
+            optimize_context_aware_cache_thresholds,
+        )
+
+        # Convert dataset
+        eval_dataset = self.dataset.as_eval_dataset()
+
+        # Get original thresholds from cache_parameters
+        original_thresholds: dict[str, float] | None = None
+        if self.cache_parameters:
+            original_thresholds = {
+                "similarity_threshold": self.cache_parameters.similarity_threshold,
+                "context_similarity_threshold": self.cache_parameters.context_similarity_threshold,
+                "question_weight": self.cache_parameters.question_weight or 0.6,
+            }
+
+        # Get judge model
+        judge_model_name: str
+        if isinstance(self.judge_model, str):
+            judge_model_name = self.judge_model
+        elif self.judge_model:
+            judge_model_name = self.judge_model.uri
+        else:
+            judge_model_name = "databricks-meta-llama-3-3-70b-instruct"
+
+        result: ThresholdOptimizationResult = optimize_context_aware_cache_thresholds(
+            dataset=eval_dataset,
+            original_thresholds=original_thresholds,
+            judge_model=judge_model_name,
+            n_trials=self.n_trials,
+            metric=self.metric,
+            beta=self.beta,
+            register_if_improved=True,
+            study_name=self.name,
+            seed=self.seed,
+        )
+
+        return result


 class DatasetFormat(str, Enum):
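The eval-dataset models added above can also be built directly in Python instead of YAML. A hedged sketch using only fields shown in the diff (the three-element embeddings are toy placeholders for real embedding vectors, and converting with as_eval_dataset assumes the optional optimization dependencies are installed):

    from dao_ai.config import (
        ContextAwareCacheEvalDatasetModel,
        ContextAwareCacheEvalEntryModel,
    )

    entry = ContextAwareCacheEvalEntryModel(
        question="What are total sales?",
        question_embedding=[0.1, 0.2, 0.3],
        cached_question="Show total sales",
        cached_question_embedding=[0.1, 0.2, 0.3],
        expected_match=True,  # None would defer to the LLM judge instead
    )
    dataset = ContextAwareCacheEvalDatasetModel(
        name="my_cache_eval_dataset",
        description="Evaluation data for cache tuning",
        entries=[entry],
    )

    # Convert to the internal format consumed by the threshold optimizer
    eval_dataset = dataset.as_eval_dataset()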
dao_ai/genie/__init__.py
CHANGED
@@ -5,34 +5,82 @@ This package provides core Genie functionality that can be used across
 different contexts (tools, direct integration, etc.).

 Main exports:
--
+- Genie: Extended Genie class that captures message_id in responses
+- GenieResponse: Extended response class with message_id field
+- GenieService: Service implementation wrapping Genie
 - GenieServiceBase: Abstract base class for service implementations
+- GenieFeedbackRating: Enum for feedback ratings (POSITIVE, NEGATIVE, NONE)
+
+Original databricks_ai_bridge classes (aliased):
+- DatabricksGenie: Original Genie from databricks_ai_bridge
+- DatabricksGenieResponse: Original GenieResponse from databricks_ai_bridge

 Cache implementations are available in the cache subpackage:
 - dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
-- dao_ai.genie.cache.
+- dao_ai.genie.cache.context_aware.postgres: PostgreSQL context-aware cache
+- dao_ai.genie.cache.context_aware.in_memory: In-memory context-aware cache

 Example usage:
-    from dao_ai.genie import GenieService
-
+    from dao_ai.genie import Genie, GenieService, GenieFeedbackRating
+
+    # Create Genie with message_id support
+    genie = Genie(space_id="my-space")
+    response = genie.ask_question("What are total sales?")
+    print(response.message_id) # Now available!
+
+    # Use with GenieService
+    service = GenieService(genie)
+    result = service.ask_question("What are total sales?")
+
+    # Send feedback using captured message_id
+    service.send_feedback(
+        conversation_id=result.response.conversation_id,
+        rating=GenieFeedbackRating.POSITIVE,
+        message_id=result.message_id, # Available from CacheResult
+        was_cache_hit=result.cache_hit,
+    )
 """

+from databricks.sdk.service.dashboards import GenieFeedbackRating
+
 from dao_ai.genie.cache import (
     CacheResult,
+    ContextAwareGenieService,
     GenieServiceBase,
+    InMemoryContextAwareGenieService,
     LRUCacheService,
-
+    PostgresContextAwareGenieService,
     SQLCacheEntry,
 )
-from dao_ai.genie.
+from dao_ai.genie.cache.base import get_latest_message_id, get_message_content
+from dao_ai.genie.core import (
+    DatabricksGenie,
+    DatabricksGenieResponse,
+    Genie,
+    GenieResponse,
+    GenieService,
+)

 __all__ = [
+    # Extended Genie classes (primary - use these)
+    "Genie",
+    "GenieResponse",
+    # Original databricks_ai_bridge classes (aliased)
+    "DatabricksGenie",
+    "DatabricksGenieResponse",
     # Service classes
     "GenieService",
     "GenieServiceBase",
+    # Feedback
+    "GenieFeedbackRating",
+    # Helper functions
+    "get_latest_message_id",
+    "get_message_content",
     # Cache types (from cache subpackage)
     "CacheResult",
+    "ContextAwareGenieService",
+    "InMemoryContextAwareGenieService",
     "LRUCacheService",
-    "
+    "PostgresContextAwareGenieService",
     "SQLCacheEntry",
 ]
dao_ai/genie/cache/__init__.py
CHANGED
@@ -6,15 +6,16 @@ chained together using the decorator pattern.

 Available cache implementations:
 - LRUCacheService: In-memory LRU cache with O(1) exact match lookup
--
+- PostgresContextAwareGenieService: PostgreSQL pg_vector-based context-aware cache
+- InMemoryContextAwareGenieService: In-memory context-aware cache

 Example usage:
-    from dao_ai.genie.cache import LRUCacheService,
+    from dao_ai.genie.cache import LRUCacheService, PostgresContextAwareGenieService

-    # Chain caches: LRU (checked first) ->
-    genie_service =
+    # Chain caches: LRU (checked first) -> Context-aware (checked second) -> Genie
+    genie_service = PostgresContextAwareGenieService(
         impl=GenieService(genie),
-        parameters=
+        parameters=context_aware_params,
     )
     genie_service = LRUCacheService(
         impl=genie_service,
@@ -27,10 +28,25 @@ from dao_ai.genie.cache.base import (
     GenieServiceBase,
     SQLCacheEntry,
 )
+from dao_ai.genie.cache.context_aware import (
+    ContextAwareGenieService,
+    InMemoryContextAwareGenieService,
+    PersistentContextAwareGenieCacheService,
+    PostgresContextAwareGenieService,
+)
+
+# Re-export optimization from context_aware for backwards compatibility
+from dao_ai.genie.cache.context_aware.optimization import (
+    ContextAwareCacheEvalDataset,
+    ContextAwareCacheEvalEntry,
+    ThresholdOptimizationResult,
+    clear_judge_cache,
+    generate_eval_dataset_from_cache,
+    optimize_context_aware_cache_thresholds,
+    semantic_match_judge,
+)
 from dao_ai.genie.cache.core import execute_sql_via_warehouse
-from dao_ai.genie.cache.in_memory_semantic import InMemorySemanticCacheService
 from dao_ai.genie.cache.lru import LRUCacheService
-from dao_ai.genie.cache.semantic import SemanticCacheService

 __all__ = [
     # Base types
@@ -38,8 +54,19 @@ __all__ = [
     "GenieServiceBase",
     "SQLCacheEntry",
     "execute_sql_via_warehouse",
+    # Context-aware base classes
+    "ContextAwareGenieService",
+    "PersistentContextAwareGenieCacheService",
     # Cache implementations
-    "
+    "InMemoryContextAwareGenieService",
     "LRUCacheService",
-    "
+    "PostgresContextAwareGenieService",
+    # Optimization
+    "ContextAwareCacheEvalDataset",
+    "ContextAwareCacheEvalEntry",
+    "ThresholdOptimizationResult",
+    "clear_judge_cache",
+    "generate_eval_dataset_from_cache",
+    "optimize_context_aware_cache_thresholds",
+    "semantic_match_judge",
 ]
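The optimization helpers re-exported here can also be driven without the YAML config layer. A sketch whose keyword arguments mirror the ContextAwareCacheOptimizationModel.optimize() call in dao_ai/config.py (it assumes a Databricks environment for the LLM judge and that the optional Optuna dependency is installed; the threshold values and toy embeddings are illustrative only):

    from dao_ai.genie.cache import (
        ContextAwareCacheEvalDataset,
        ContextAwareCacheEvalEntry,
        optimize_context_aware_cache_thresholds,
    )

    dataset = ContextAwareCacheEvalDataset(
        name="cache_eval",
        description="Evaluation data for cache tuning",
        entries=[
            ContextAwareCacheEvalEntry(
                question="What are total sales?",
                question_embedding=[0.1, 0.2, 0.3],
                context="Previous: Show me revenue",
                context_embedding=[0.1, 0.2, 0.3],
                cached_question="Show total sales",
                cached_question_embedding=[0.1, 0.2, 0.3],
                cached_context="Previous: Show me revenue",
                cached_context_embedding=[0.1, 0.2, 0.3],
                expected_match=True,  # labelled; None would fall back to the judge model
            ),
        ],
    )

    result = optimize_context_aware_cache_thresholds(
        dataset=dataset,
        original_thresholds={
            "similarity_threshold": 0.85,
            "context_similarity_threshold": 0.80,
            "question_weight": 0.6,
        },
        judge_model="databricks-meta-llama-3-3-70b-instruct",
        n_trials=50,
        metric="f1",
        beta=1.0,
        register_if_improved=False,
        study_name="optimize_cache_thresholds",
        seed=42,
    )
    print(result)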