dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +797 -242
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +329 -0
- dao_ai/genie/cache/semantic.py +919 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +11 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +108 -35
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.35.dist-info/METADATA +0 -1169
- dao_ai-0.0.35.dist-info/RECORD +0 -41
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/__init__.py
ADDED
@@ -0,0 +1,38 @@

```python
"""
Genie service implementations and caching layers.

This package provides core Genie functionality that can be used across
different contexts (tools, direct integration, etc.).

Main exports:
- GenieService: Core service implementation wrapping Databricks Genie SDK
- GenieServiceBase: Abstract base class for service implementations

Cache implementations are available in the cache subpackage:
- dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
- dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector

Example usage:
    from dao_ai.genie import GenieService
    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
"""

from dao_ai.genie.cache import (
    CacheResult,
    GenieServiceBase,
    LRUCacheService,
    SemanticCacheService,
    SQLCacheEntry,
)
from dao_ai.genie.core import GenieService

__all__ = [
    # Service classes
    "GenieService",
    "GenieServiceBase",
    # Cache types (from cache subpackage)
    "CacheResult",
    "LRUCacheService",
    "SemanticCacheService",
    "SQLCacheEntry",
]
```

dao_ai/genie/cache/__init__.py
ADDED
@@ -0,0 +1,43 @@

```python
"""
Genie cache implementations.

This package provides caching layers for Genie SQL queries that can be
chained together using the decorator pattern.

Available cache implementations:
- LRUCacheService: In-memory LRU cache with O(1) exact match lookup
- SemanticCacheService: PostgreSQL pg_vector-based semantic similarity cache

Example usage:
    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService

    # Chain caches: LRU (checked first) -> Semantic (checked second) -> Genie
    genie_service = SemanticCacheService(
        impl=GenieService(genie),
        parameters=semantic_params,
    )
    genie_service = LRUCacheService(
        impl=genie_service,
        parameters=lru_params,
    )
"""

from dao_ai.genie.cache.base import (
    CacheResult,
    GenieServiceBase,
    SQLCacheEntry,
)
from dao_ai.genie.cache.core import execute_sql_via_warehouse
from dao_ai.genie.cache.lru import LRUCacheService
from dao_ai.genie.cache.semantic import SemanticCacheService

__all__ = [
    # Base types
    "CacheResult",
    "GenieServiceBase",
    "SQLCacheEntry",
    "execute_sql_via_warehouse",
    # Cache implementations
    "LRUCacheService",
    "SemanticCacheService",
]
```

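Editor's note: the docstring example above leaves `genie`, `semantic_params`, and `lru_params` undefined. Below is a minimal self-contained sketch of the LRU layer alone, using only names that appear in this diff; the space ID, warehouse, and parameter values are placeholders, and the `WarehouseModel` construction is an assumption, since `dao_ai/config.py` is not shown here.

```python
# A minimal sketch of wiring an LRU layer in front of Genie. All configuration
# values below are illustrative placeholders, not values from this package.
from databricks_ai_bridge.genie import Genie

from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
from dao_ai.genie import GenieService
from dao_ai.genie.cache import LRUCacheService

# Hypothetical construction; the real WarehouseModel fields live in
# dao_ai/config.py, which this diff does not display.
warehouse = WarehouseModel(warehouse_id="abc123")

genie_service = LRUCacheService(
    impl=GenieService(Genie(space_id="my-space")),  # origin: Genie NL-to-SQL
    parameters=GenieLRUCacheParametersModel(
        warehouse=warehouse,
        capacity=100,                # keep at most 100 cached SQL entries
        time_to_live_seconds=86400,  # entries expire after 24 hours
    ),
)

result = genie_service.ask_question("How many orders shipped last week?")
print(result.cache_hit, result.served_by)  # False, None on a first call
```
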
dao_ai/genie/cache/base.py
ADDED
@@ -0,0 +1,72 @@

```python
"""
Base classes and types for Genie cache implementations.

This module provides the foundational types used across different cache
implementations (LRU, Semantic, etc.). It contains only abstract base classes
and data structures - no concrete implementations.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING

from databricks_ai_bridge.genie import GenieResponse

if TYPE_CHECKING:
    from dao_ai.genie.cache.base import CacheResult


class GenieServiceBase(ABC):
    """Abstract base class for Genie service implementations."""

    @abstractmethod
    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> "CacheResult":
        """
        Ask a question to Genie and return the response with cache metadata.

        All implementations return CacheResult to provide consistent cache information,
        even when caching is disabled (cache_hit=False, served_by=None).
        """
        pass

    @property
    @abstractmethod
    def space_id(self) -> str:
        """The space ID for the Genie service."""
        pass


@dataclass
class SQLCacheEntry:
    """
    A cache entry storing the SQL query metadata for re-execution.

    Instead of caching the full result, we cache the SQL query so that
    on cache hit we can re-execute it to get fresh data.
    """

    query: str
    description: str
    conversation_id: str
    created_at: datetime


@dataclass
class CacheResult:
    """
    Result of a cache-aware query with metadata about cache behavior.

    Attributes:
        response: The GenieResponse (fresh data, possibly from cached SQL)
        cache_hit: Whether the SQL query came from cache
        served_by: Name of the layer that served the cached SQL (None if from origin)
    """

    response: GenieResponse
    cache_hit: bool
    served_by: str | None = None
```

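Editor's note: a minimal concrete implementation of this contract might look like the following. `StaticSQLService` is a hypothetical stub, not part of dao-ai; it illustrates the invariant documented above, namely that every implementation returns a `CacheResult`, with `cache_hit=False` and `served_by=None` at the origin.

```python
# Hypothetical stub implementing the GenieServiceBase contract; useful as a
# test double when exercising cache layers without a live Genie space.
from databricks_ai_bridge.genie import GenieResponse

from dao_ai.genie.cache.base import CacheResult, GenieServiceBase


class StaticSQLService(GenieServiceBase):
    """Returns a fixed response for every question."""

    def __init__(self, space_id: str) -> None:
        self._space_id = space_id

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> CacheResult:
        # A real service would attach a DataFrame result; a plain string
        # stands in here, mirroring the DataFrame-or-string convention used
        # elsewhere in this diff.
        response = GenieResponse(
            result="no data",
            query="SELECT 1",
            description=f"Canned answer for: {question}",
            conversation_id=conversation_id or "conv-0",
        )
        return CacheResult(response=response, cache_hit=False, served_by=None)

    @property
    def space_id(self) -> str:
        return self._space_id
```
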
dao_ai/genie/cache/core.py
ADDED
@@ -0,0 +1,75 @@

```python
"""
Core utilities for Genie cache implementations.

This module provides shared utility functions used by different cache
implementations (LRU, Semantic, etc.). These are concrete implementations
of common operations needed across cache types.
"""

from typing import Any

import pandas as pd
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import StatementResponse, StatementState
from loguru import logger

from dao_ai.config import WarehouseModel


def execute_sql_via_warehouse(
    warehouse: WarehouseModel,
    sql: str,
    layer_name: str = "cache",
) -> pd.DataFrame | str:
    """
    Execute SQL using a Databricks warehouse and return results as DataFrame.

    This is a shared utility for cache implementations that need to re-execute
    cached SQL queries.

    Args:
        warehouse: The warehouse configuration for SQL execution
        sql: The SQL query to execute
        layer_name: Name of the cache layer (for logging)

    Returns:
        DataFrame with results, or error message string
    """
    w: WorkspaceClient = warehouse.workspace_client
    warehouse_id: str = str(warehouse.warehouse_id)

    logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")

    statement_response: StatementResponse = w.statement_execution.execute_statement(
        statement=sql,
        warehouse_id=warehouse_id,
        wait_timeout="30s",
    )

    # Poll for completion if still running
    while statement_response.status.state in [
        StatementState.PENDING,
        StatementState.RUNNING,
    ]:
        statement_response = w.statement_execution.get_statement(
            statement_response.statement_id
        )

    if statement_response.status.state != StatementState.SUCCEEDED:
        error_msg: str = f"SQL execution failed: {statement_response.status}"
        logger.error(f"[{layer_name}] {error_msg}")
        return error_msg

    # Convert to DataFrame
    if statement_response.result and statement_response.result.data_array:
        columns: list[str] = []
        if statement_response.manifest and statement_response.manifest.schema:
            columns = [col.name for col in statement_response.manifest.schema.columns]

        data: list[list[Any]] = statement_response.result.data_array
        if columns:
            return pd.DataFrame(data, columns=columns)
        else:
            return pd.DataFrame(data)

    return pd.DataFrame()
```

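Editor's note: because the helper returns either a `DataFrame` or an error string, callers need an `isinstance` check. A minimal sketch, assuming a `WarehouseModel` configured elsewhere (its construction is in `dao_ai/config.py`, not shown in this diff), with a placeholder query:

```python
# Sketch of calling the shared helper; `warehouse` is an assumed
# dao_ai.config.WarehouseModel instance, and the SQL is a placeholder.
import pandas as pd

from dao_ai.genie.cache import execute_sql_via_warehouse

outcome = execute_sql_via_warehouse(
    warehouse=warehouse,
    sql="SELECT order_id, total FROM sales.orders LIMIT 10",
    layer_name="lru",  # only used as a log prefix
)

if isinstance(outcome, pd.DataFrame):
    print(outcome.head())
else:
    # On failure the helper returns the error message as a plain string.
    print(f"query failed: {outcome}")
```
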
dao_ai/genie/cache/lru.py
ADDED
@@ -0,0 +1,329 @@

```python
"""
LRU (Least Recently Used) cache implementation for Genie SQL queries.

This module provides an in-memory LRU cache that stores SQL queries generated
by Genie. On cache hit, the cached SQL is re-executed against the warehouse
to return fresh data while avoiding the Genie NL-to-SQL translation cost.
"""

from collections import OrderedDict
from datetime import datetime, timedelta
from threading import Lock
from typing import Any

import mlflow
import pandas as pd
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import StatementResponse, StatementState
from databricks_ai_bridge.genie import GenieResponse
from loguru import logger

from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
from dao_ai.genie.cache.base import (
    CacheResult,
    GenieServiceBase,
    SQLCacheEntry,
)


class LRUCacheService(GenieServiceBase):
    """
    LRU caching decorator that caches SQL queries and re-executes them.

    This service caches the SQL query generated by Genie (not the result data).
    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
        from dao_ai.genie.cache import LRUCacheService

        cache_params = GenieLRUCacheParametersModel(
            warehouse=warehouse_model,
            capacity=100,
            time_to_live_seconds=86400  # 24 hours
        )
        genie = LRUCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: Uses a lock to protect cache operations.
    """

    impl: GenieServiceBase
    parameters: GenieLRUCacheParametersModel
    name: str
    _cache: OrderedDict[str, SQLCacheEntry]
    _lock: Lock

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieLRUCacheParametersModel,
        name: str | None = None,
    ) -> None:
        """
        Initialize the SQL cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss
            parameters: Cache configuration including warehouse, capacity, and TTL
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.name = name if name is not None else self.__class__.__name__
        self._cache = OrderedDict()
        self._lock = Lock()

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def capacity(self) -> int:
        """Maximum number of SQL queries to cache."""
        return self.parameters.capacity

    @property
    def time_to_live(self) -> timedelta | None:
        """Duration after which cached queries expire. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @staticmethod
    def _normalize_key(question: str, conversation_id: str | None = None) -> str:
        """
        Normalize the question and conversation_id to create a consistent cache key.

        Args:
            question: The question text
            conversation_id: Optional conversation ID to include in the key

        Returns:
            A normalized cache key combining question and conversation_id
        """
        normalized_question = question.strip().lower()
        if conversation_id:
            return f"{conversation_id}::{normalized_question}"
        return normalized_question

    def _is_expired(self, entry: SQLCacheEntry) -> bool:
        """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
        if self.time_to_live is None:
            return False
        age: timedelta = datetime.now() - entry.created_at
        return age > self.time_to_live

    def _evict_oldest(self) -> None:
        """Remove the oldest (least recently used) entry."""
        if self._cache:
            oldest_key: str = next(iter(self._cache))
            del self._cache[oldest_key]
            logger.debug(f"[{self.name}] Evicted: {oldest_key[:50]}...")

    def _get(self, key: str) -> SQLCacheEntry | None:
        """Get from cache, returning None if not found or expired."""
        if key not in self._cache:
            return None

        entry: SQLCacheEntry = self._cache[key]

        if self._is_expired(entry):
            del self._cache[key]
            logger.debug(f"[{self.name}] Expired: {key[:50]}...")
            return None

        self._cache.move_to_end(key)
        return entry

    def _put(self, key: str, response: GenieResponse) -> None:
        """Store SQL query in cache, evicting if at capacity."""
        if key in self._cache:
            del self._cache[key]

        while len(self._cache) >= self.capacity:
            self._evict_oldest()

        self._cache[key] = SQLCacheEntry(
            query=response.query,
            description=response.description,
            conversation_id=response.conversation_id,
            created_at=datetime.now(),
        )
        logger.info(
            f"[{self.name}] Stored cache entry: key='{key[:50]}...' "
            f"sql='{response.query[:50] if response.query else 'None'}...' "
            f"(cache_size={len(self._cache)}/{self.capacity})"
        )

    @mlflow.trace(name="execute_cached_sql")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """
        Execute SQL using the warehouse and return results as DataFrame.

        Args:
            sql: The SQL query to execute

        Returns:
            DataFrame with results, or error message string
        """
        w: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = str(self.warehouse.warehouse_id)

        logger.debug(f"[{self.name}] Executing cached SQL: {sql[:100]}...")

        statement_response: StatementResponse = w.statement_execution.execute_statement(
            statement=sql,
            warehouse_id=warehouse_id,
            wait_timeout="30s",
        )

        # Poll for completion if still running
        while statement_response.status.state in [
            StatementState.PENDING,
            StatementState.RUNNING,
        ]:
            statement_response = w.statement_execution.get_statement(
                statement_response.statement_id
            )

        if statement_response.status.state != StatementState.SUCCEEDED:
            error_msg: str = f"SQL execution failed: {statement_response.status}"
            logger.error(f"[{self.name}] {error_msg}")
            return error_msg

        # Convert to DataFrame
        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            if statement_response.manifest and statement_response.manifest.schema:
                columns = [
                    col.name for col in statement_response.manifest.schema.columns
                ]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            else:
                return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> CacheResult:
        """
        Ask a question, using cached SQL query if available.

        On cache hit, re-executes the cached SQL to get fresh data.
        Returns CacheResult with cache metadata.
        """
        return self.ask_question_with_cache_info(question, conversation_id)

    @mlflow.trace(name="genie_lru_cache_lookup")
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data.

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID

        Returns:
            CacheResult with fresh response and cache metadata
        """
        key: str = self._normalize_key(question, conversation_id)

        # Check cache
        with self._lock:
            cached: SQLCacheEntry | None = self._get(key)

        if cached is not None:
            logger.info(
                f"[{self.name}] Cache HIT: '{question[:50]}...' "
                f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity})"
            )

            # Re-execute the cached SQL to get fresh data
            result: pd.DataFrame | str = self._execute_sql(cached.query)

            # Use current conversation_id, not the cached one
            response: GenieResponse = GenieResponse(
                result=result,
                query=cached.query,
                description=cached.description,
                conversation_id=conversation_id
                if conversation_id
                else cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
        logger.info(
            f"[{self.name}] Cache MISS: '{question[:50]}...' "
            f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
        )

        result: CacheResult = self.impl.ask_question(question, conversation_id)
        with self._lock:
            self._put(key, result.response)
        return CacheResult(response=result.response, cache_hit=False, served_by=None)

    @property
    def space_id(self) -> str:
        return self.impl.space_id

    def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
        """
        Remove a specific entry from the cache.

        Args:
            question: The question text
            conversation_id: Optional conversation ID to match

        Returns:
            True if the entry was found and removed, False otherwise
        """
        key: str = self._normalize_key(question, conversation_id)
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> int:
        """Clear all entries from the cache."""
        with self._lock:
            count: int = len(self._cache)
            self._cache.clear()
            return count

    @property
    def size(self) -> int:
        """Current number of entries in the cache."""
        with self._lock:
            return len(self._cache)

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics."""
        with self._lock:
            expired: int = sum(1 for e in self._cache.values() if self._is_expired(e))
            ttl = self.time_to_live
            return {
                "size": len(self._cache),
                "capacity": self.capacity,
                "ttl_seconds": ttl.total_seconds() if ttl else None,
                "expired_entries": expired,
                "valid_entries": len(self._cache) - expired,
            }
```

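Editor's note: a short sketch of the hit/miss flow and the maintenance API, reusing the hypothetical `StaticSQLService` stub from the note after base.py; the warehouse and all configuration values are placeholders, and on a cache hit the cached SQL really is re-executed against the warehouse.

```python
# Illustrative only, not part of the package. Two near-identical questions:
# the second is served from cache because keys are normalized (strip + lower).
from dao_ai.config import GenieLRUCacheParametersModel
from dao_ai.genie.cache import LRUCacheService

service = LRUCacheService(
    impl=StaticSQLService(space_id="my-space"),  # hypothetical stub from above
    parameters=GenieLRUCacheParametersModel(
        warehouse=warehouse,        # assumed WarehouseModel; used on cache hits
        capacity=2,
        time_to_live_seconds=3600,
    ),
)

first = service.ask_question("total revenue this month?")
print(first.cache_hit, first.served_by)    # False, None -> delegated to impl

second = service.ask_question("Total revenue this month?  ")
print(second.cache_hit, second.served_by)  # True, "LRUCacheService"

print(service.stats())                     # {'size': 1, 'capacity': 2, ...}
service.invalidate("total revenue this month?")
print(service.size)                        # 0
```
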