dao-ai 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +29 -0
- dao_ai/genie/__init__.py +59 -0
- dao_ai/genie/cache/__init__.py +44 -0
- dao_ai/genie/cache/base.py +122 -0
- dao_ai/genie/cache/lru.py +306 -0
- dao_ai/genie/cache/semantic.py +638 -0
- dao_ai/tools/__init__.py +3 -0
- dao_ai/tools/genie/__init__.py +236 -0
- dao_ai/tools/genie.py +65 -15
- dao_ai-0.0.36.dist-info/METADATA +951 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/RECORD +14 -8
- dao_ai-0.0.35.dist-info/METADATA +0 -1169
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.0.36.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -28,8 +28,10 @@ from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
@@ -408,6 +410,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
 
         return chat_client
 
+    def as_embeddings_model(self) -> Embeddings:
+        return DatabricksEmbeddings(endpoint=self.name)
+
 
 class VectorSearchEndpointType(str, Enum):
     STANDARD = "STANDARD"
@@ -977,6 +982,30 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         provider.create_lakebase_instance_role(self)
 
 
+class GenieLRUCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    capacity: int = 1000
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    warehouse: WarehouseModel
+
+
+class GenieSemanticCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    similarity_threshold: float = (
+        0.85  # Minimum similarity for cache hit (L2 distance converted to 0-1 scale)
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    database: DatabaseModel
+    warehouse: WarehouseModel
+    table_name: str = "genie_semantic_cache"
+
+
 class SearchParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     num_results: Optional[int] = 10
dao_ai/genie/__init__.py
ADDED
@@ -0,0 +1,59 @@
+"""
+Genie service implementations and caching layers.
+
+This package provides core Genie functionality that can be used across
+different contexts (tools, direct integration, etc.).
+
+Main exports:
+- GenieService: Core service implementation wrapping Databricks Genie SDK
+- GenieServiceBase: Abstract base class for service implementations
+
+Cache implementations are available in the cache subpackage:
+- dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
+- dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
+
+Example usage:
+    from dao_ai.genie import GenieService
+    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
+"""
+
+import mlflow
+from databricks_ai_bridge.genie import Genie, GenieResponse
+
+from dao_ai.genie.cache import (
+    CacheResult,
+    GenieServiceBase,
+    LRUCacheService,
+    SemanticCacheService,
+    SQLCacheEntry,
+)
+
+
+class GenieService(GenieServiceBase):
+    """Concrete implementation of GenieServiceBase using the Genie SDK."""
+
+    genie: Genie
+
+    def __init__(self, genie: Genie) -> None:
+        self.genie = genie
+
+    @mlflow.trace(name="genie_ask_question")
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> GenieResponse:
+        response: GenieResponse = self.genie.ask_question(
+            question, conversation_id=conversation_id
+        )
+        return response
+
+
+__all__ = [
+    # Service classes
+    "GenieService",
+    "GenieServiceBase",
+    # Cache types (from cache subpackage)
+    "CacheResult",
+    "LRUCacheService",
+    "SemanticCacheService",
+    "SQLCacheEntry",
+]
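Based on the file above, direct (uncached) usage might look like the following sketch; the space ID is a placeholder:

    from databricks_ai_bridge.genie import Genie
    from dao_ai.genie import GenieService

    service = GenieService(Genie(space_id="my-space-id"))  # hypothetical space ID
    response = service.ask_question("What were total sales last month?")
    print(response.query)   # the SQL Genie generated
    print(response.result)  # the query result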
dao_ai/genie/cache/__init__.py
ADDED
@@ -0,0 +1,44 @@
+"""
+Genie cache implementations.
+
+This package provides caching layers for Genie SQL queries that can be
+chained together using the decorator pattern.
+
+Available cache implementations:
+- LRUCacheService: In-memory LRU cache with O(1) exact match lookup
+- SemanticCacheService: PostgreSQL pg_vector-based semantic similarity cache
+
+Example usage:
+    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
+
+    # Chain caches: LRU (checked first) -> Semantic (checked second) -> Genie
+    genie_service = SemanticCacheService(
+        impl=GenieService(genie),
+        parameters=semantic_params,
+        genie_space_id=space_id,
+    )
+    genie_service = LRUCacheService(
+        impl=genie_service,
+        parameters=lru_params,
+    )
+"""
+
+from dao_ai.genie.cache.base import (
+    CacheResult,
+    GenieServiceBase,
+    SQLCacheEntry,
+    execute_sql_via_warehouse,
+)
+from dao_ai.genie.cache.lru import LRUCacheService
+from dao_ai.genie.cache.semantic import SemanticCacheService
+
+__all__ = [
+    # Base types
+    "CacheResult",
+    "GenieServiceBase",
+    "SQLCacheEntry",
+    "execute_sql_via_warehouse",
+    # Cache implementations
+    "LRUCacheService",
+    "SemanticCacheService",
+]
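Filling in the docstring's placeholder names with the hypothetical parameter models from the config sketch earlier, a complete chain might be assembled like this:

    from databricks_ai_bridge.genie import Genie

    from dao_ai.genie import GenieService
    from dao_ai.genie.cache import LRUCacheService, SemanticCacheService

    # Innermost layer: the real Genie service ("my-space-id" is a placeholder).
    service = GenieService(Genie(space_id="my-space-id"))

    # Wrap with the semantic cache (checked second on lookup) ...
    service = SemanticCacheService(
        impl=service,
        parameters=semantic_params,  # hypothetical GenieSemanticCacheParametersModel
        genie_space_id="my-space-id",
    )

    # ... then with the LRU cache (checked first).
    service = LRUCacheService(impl=service, parameters=lru_params)

    response = service.ask_question("How many orders shipped late this quarter?")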
dao_ai/genie/cache/base.py
ADDED
@@ -0,0 +1,122 @@
+"""
+Base classes and types for Genie cache implementations.
+
+This module provides the foundational types used across different cache
+implementations (LRU, Semantic, etc.).
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any
+
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import StatementResponse, StatementState
+from databricks_ai_bridge.genie import GenieResponse
+from loguru import logger
+
+from dao_ai.config import WarehouseModel
+
+
+class GenieServiceBase(ABC):
+    """Abstract base class for Genie service implementations."""
+
+    @abstractmethod
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> GenieResponse:
+        """Ask a question to Genie and return the response."""
+        pass
+
+
+@dataclass
+class SQLCacheEntry:
+    """
+    A cache entry storing the SQL query metadata for re-execution.
+
+    Instead of caching the full result, we cache the SQL query so that
+    on cache hit we can re-execute it to get fresh data.
+    """
+
+    query: str
+    description: str
+    conversation_id: str
+    created_at: datetime
+
+
+@dataclass
+class CacheResult:
+    """
+    Result of a cache-aware query with metadata about cache behavior.
+
+    Attributes:
+        response: The GenieResponse (fresh data, possibly from cached SQL)
+        cache_hit: Whether the SQL query came from cache
+        served_by: Name of the layer that served the cached SQL (None if from origin)
+    """
+
+    response: GenieResponse
+    cache_hit: bool
+    served_by: str | None = None
+
+
+def execute_sql_via_warehouse(
+    warehouse: WarehouseModel,
+    sql: str,
+    layer_name: str = "cache",
+) -> pd.DataFrame | str:
+    """
+    Execute SQL using a Databricks warehouse and return results as DataFrame.
+
+    This is a shared utility for cache implementations that need to re-execute
+    cached SQL queries.
+
+    Args:
+        warehouse: The warehouse configuration for SQL execution
+        sql: The SQL query to execute
+        layer_name: Name of the cache layer (for logging)
+
+    Returns:
+        DataFrame with results, or error message string
+    """
+    w: WorkspaceClient = warehouse.workspace_client
+    warehouse_id: str = str(warehouse.warehouse_id)
+
+    logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")
+
+    statement_response: StatementResponse = w.statement_execution.execute_statement(
+        statement=sql,
+        warehouse_id=warehouse_id,
+        wait_timeout="30s",
+    )
+
+    # Poll for completion if still running
+    while statement_response.status.state in [
+        StatementState.PENDING,
+        StatementState.RUNNING,
+    ]:
+        statement_response = w.statement_execution.get_statement(
+            statement_response.statement_id
+        )
+
+    if statement_response.status.state != StatementState.SUCCEEDED:
+        error_msg: str = f"SQL execution failed: {statement_response.status}"
+        logger.error(f"[{layer_name}] {error_msg}")
+        return error_msg
+
+    # Convert to DataFrame
+    if statement_response.result and statement_response.result.data_array:
+        columns: list[str] = []
+        if statement_response.manifest and statement_response.manifest.schema:
+            columns = [col.name for col in statement_response.manifest.schema.columns]
+        elif hasattr(statement_response.result, "schema"):
+            columns = [col.name for col in statement_response.result.schema.columns]
+
+        data: list[list[Any]] = statement_response.result.data_array
+        if columns:
+            return pd.DataFrame(data, columns=columns)
+        else:
+            return pd.DataFrame(data)
+
+    return pd.DataFrame()
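Since every layer shares the GenieServiceBase interface, custom decorators can be slotted into the chain alongside the built-in caches. A minimal sketch of a hypothetical logging passthrough (not part of the package):

    from databricks_ai_bridge.genie import GenieResponse
    from loguru import logger

    from dao_ai.genie.cache import GenieServiceBase

    class LoggingGenieService(GenieServiceBase):
        """Hypothetical decorator that logs each question before delegating."""

        def __init__(self, impl: GenieServiceBase) -> None:
            self.impl = impl

        def ask_question(
            self, question: str, conversation_id: str | None = None
        ) -> GenieResponse:
            logger.info(f"Genie question: {question[:80]}")
            return self.impl.ask_question(question, conversation_id=conversation_id)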
dao_ai/genie/cache/lru.py
ADDED
@@ -0,0 +1,306 @@
+"""
+LRU (Least Recently Used) cache implementation for Genie SQL queries.
+
+This module provides an in-memory LRU cache that stores SQL queries generated
+by Genie. On cache hit, the cached SQL is re-executed against the warehouse
+to return fresh data while avoiding the Genie NL-to-SQL translation cost.
+"""
+
+from collections import OrderedDict
+from datetime import datetime, timedelta
+from threading import Lock
+from typing import Any
+
+import mlflow
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import StatementResponse, StatementState
+from databricks_ai_bridge.genie import GenieResponse
+from loguru import logger
+
+from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
+from dao_ai.genie.cache.base import (
+    CacheResult,
+    GenieServiceBase,
+    SQLCacheEntry,
+)
+
+
+class LRUCacheService(GenieServiceBase):
+    """
+    LRU caching decorator that caches SQL queries and re-executes them.
+
+    This service caches the SQL query generated by Genie (not the result data).
+    On cache hit, it re-executes the cached SQL using the provided warehouse
+    to return fresh data while avoiding the Genie NL-to-SQL translation cost.
+
+    Example:
+        from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
+        from dao_ai.genie.cache import LRUCacheService
+
+        cache_params = GenieLRUCacheParametersModel(
+            warehouse=warehouse_model,
+            capacity=100,
+            time_to_live_seconds=86400  # 24 hours
+        )
+        genie = LRUCacheService(
+            impl=GenieService(Genie(space_id="my-space")),
+            parameters=cache_params
+        )
+
+    Thread-safe: Uses a lock to protect cache operations.
+    """
+
+    impl: GenieServiceBase
+    parameters: GenieLRUCacheParametersModel
+    name: str
+    _cache: OrderedDict[str, SQLCacheEntry]
+    _lock: Lock
+
+    def __init__(
+        self,
+        impl: GenieServiceBase,
+        parameters: GenieLRUCacheParametersModel,
+        name: str | None = None,
+    ) -> None:
+        """
+        Initialize the SQL cache service.
+
+        Args:
+            impl: The underlying GenieServiceBase to delegate to on cache miss
+            parameters: Cache configuration including warehouse, capacity, and TTL
+            name: Name for this cache layer (for logging). Defaults to class name.
+        """
+        self.impl = impl
+        self.parameters = parameters
+        self.name = name if name is not None else self.__class__.__name__
+        self._cache = OrderedDict()
+        self._lock = Lock()
+
+    @property
+    def warehouse(self) -> WarehouseModel:
+        """The warehouse used for executing cached SQL queries."""
+        return self.parameters.warehouse
+
+    @property
+    def capacity(self) -> int:
+        """Maximum number of SQL queries to cache."""
+        return self.parameters.capacity
+
+    @property
+    def time_to_live(self) -> timedelta | None:
+        """Duration after which cached queries expire. None means never expires."""
+        ttl = self.parameters.time_to_live_seconds
+        if ttl is None or ttl < 0:
+            return None
+        return timedelta(seconds=ttl)
+
+    @staticmethod
+    def _normalize_key(question: str) -> str:
+        """Normalize the question to create a consistent cache key."""
+        return question.strip().lower()
+
+    def _is_expired(self, entry: SQLCacheEntry) -> bool:
+        """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
+        if self.time_to_live is None:
+            return False
+        age: timedelta = datetime.now() - entry.created_at
+        return age > self.time_to_live
+
+    def _evict_oldest(self) -> None:
+        """Remove the oldest (least recently used) entry."""
+        if self._cache:
+            oldest_key: str = next(iter(self._cache))
+            del self._cache[oldest_key]
+            logger.debug(f"[{self.name}] Evicted: {oldest_key[:50]}...")
+
+    def _get(self, key: str) -> SQLCacheEntry | None:
+        """Get from cache, returning None if not found or expired."""
+        if key not in self._cache:
+            return None
+
+        entry: SQLCacheEntry = self._cache[key]
+
+        if self._is_expired(entry):
+            del self._cache[key]
+            logger.debug(f"[{self.name}] Expired: {key[:50]}...")
+            return None
+
+        self._cache.move_to_end(key)
+        return entry
+
+    def _put(self, key: str, response: GenieResponse) -> None:
+        """Store SQL query in cache, evicting if at capacity."""
+        if key in self._cache:
+            del self._cache[key]
+
+        while len(self._cache) >= self.capacity:
+            self._evict_oldest()
+
+        self._cache[key] = SQLCacheEntry(
+            query=response.query,
+            description=response.description,
+            conversation_id=response.conversation_id,
+            created_at=datetime.now(),
+        )
+        logger.info(
+            f"[{self.name}] Stored cache entry: key='{key[:50]}...' "
+            f"sql='{response.query[:50] if response.query else 'None'}...' "
+            f"(cache_size={len(self._cache)}/{self.capacity})"
+        )
+
+    @mlflow.trace(name="execute_cached_sql")
+    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
+        """
+        Execute SQL using the warehouse and return results as DataFrame.
+
+        Args:
+            sql: The SQL query to execute
+
+        Returns:
+            DataFrame with results, or error message string
+        """
+        w: WorkspaceClient = self.warehouse.workspace_client
+        warehouse_id: str = str(self.warehouse.warehouse_id)
+
+        logger.debug(f"[{self.name}] Executing cached SQL: {sql[:100]}...")
+
+        statement_response: StatementResponse = w.statement_execution.execute_statement(
+            statement=sql,
+            warehouse_id=warehouse_id,
+            wait_timeout="30s",
+        )
+
+        # Poll for completion if still running
+        while statement_response.status.state in [
+            StatementState.PENDING,
+            StatementState.RUNNING,
+        ]:
+            statement_response = w.statement_execution.get_statement(
+                statement_response.statement_id
+            )
+
+        if statement_response.status.state != StatementState.SUCCEEDED:
+            error_msg: str = f"SQL execution failed: {statement_response.status}"
+            logger.error(f"[{self.name}] {error_msg}")
+            return error_msg
+
+        # Convert to DataFrame
+        if statement_response.result and statement_response.result.data_array:
+            columns: list[str] = []
+            if statement_response.manifest and statement_response.manifest.schema:
+                columns = [
+                    col.name for col in statement_response.manifest.schema.columns
+                ]
+            elif hasattr(statement_response.result, "schema"):
+                columns = [col.name for col in statement_response.result.schema.columns]
+
+            data: list[list[Any]] = statement_response.result.data_array
+            if columns:
+                return pd.DataFrame(data, columns=columns)
+            else:
+                return pd.DataFrame(data)
+
+        return pd.DataFrame()
+
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> GenieResponse:
+        """
+        Ask a question, using cached SQL query if available.
+
+        On cache hit, re-executes the cached SQL to get fresh data.
+        Implements GenieServiceBase for seamless chaining.
+        """
+        result: CacheResult = self.ask_question_with_cache_info(
+            question, conversation_id
+        )
+        return result.response
+
+    @mlflow.trace(name="genie_lru_cache_lookup")
+    def ask_question_with_cache_info(
+        self,
+        question: str,
+        conversation_id: str | None = None,
+    ) -> CacheResult:
+        """
+        Ask a question with detailed cache hit information.
+
+        On cache hit, the cached SQL is re-executed to return fresh data.
+
+        Args:
+            question: The question to ask
+            conversation_id: Optional conversation ID
+
+        Returns:
+            CacheResult with fresh response and cache metadata
+        """
+        key: str = self._normalize_key(question)
+
+        # Check cache
+        with self._lock:
+            cached: SQLCacheEntry | None = self._get(key)
+
+        if cached is not None:
+            logger.info(
+                f"[{self.name}] Cache HIT: '{question[:50]}...' "
+                f"(cache_size={self.size}/{self.capacity})"
+            )
+
+            # Re-execute the cached SQL to get fresh data
+            result: pd.DataFrame | str = self._execute_sql(cached.query)
+
+            response: GenieResponse = GenieResponse(
+                result=result,
+                query=cached.query,
+                description=cached.description,
+                conversation_id=cached.conversation_id,
+            )
+
+            return CacheResult(response=response, cache_hit=True, served_by=self.name)
+
+        # Cache miss - delegate to wrapped service
+        logger.info(
+            f"[{self.name}] Cache MISS: '{question[:50]}...' "
+            f"(cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
+        )
+
+        response = self.impl.ask_question(question, conversation_id)
+        with self._lock:
+            self._put(key, response)
+        return CacheResult(response=response, cache_hit=False, served_by=None)
+
+    def invalidate(self, question: str) -> bool:
+        """Remove a specific entry from the cache."""
+        key: str = self._normalize_key(question)
+        with self._lock:
+            if key in self._cache:
+                del self._cache[key]
+                return True
+            return False
+
+    def clear(self) -> int:
+        """Clear all entries from the cache."""
+        with self._lock:
+            count: int = len(self._cache)
+            self._cache.clear()
+            return count
+
+    @property
+    def size(self) -> int:
+        """Current number of entries in the cache."""
+        with self._lock:
+            return len(self._cache)
+
+    def stats(self) -> dict[str, int | float | None]:
+        """Return cache statistics."""
+        with self._lock:
+            expired: int = sum(1 for e in self._cache.values() if self._is_expired(e))
+            ttl = self.time_to_live
+            return {
+                "size": len(self._cache),
+                "capacity": self.capacity,
+                "ttl_seconds": ttl.total_seconds() if ttl else None,
+                "expired_entries": expired,
+                "valid_entries": len(self._cache) - expired,
+            }
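A short usage sketch for the operational methods above, reusing the hypothetical `lru_params` from the earlier config example; the space ID is a placeholder:

    from databricks_ai_bridge.genie import Genie

    from dao_ai.genie import GenieService
    from dao_ai.genie.cache import CacheResult, LRUCacheService

    service = LRUCacheService(
        impl=GenieService(Genie(space_id="my-space-id")),  # placeholder space ID
        parameters=lru_params,
    )

    # First call misses and caches the generated SQL; repeating the same
    # (normalized) question hits and re-executes the cached SQL for fresh data.
    info: CacheResult = service.ask_question_with_cache_info("Top 10 stores by revenue")
    print(info.cache_hit, info.served_by)  # False, None on the first call

    print(service.stats())                          # size, capacity, TTL, entry counts
    service.invalidate("Top 10 stores by revenue")  # drop one entry
    service.clear()                                 # drop everything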