dao-ai 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -28,8 +28,10 @@ from databricks.sdk.service.database import DatabaseInstance
28
28
  from databricks.vector_search.client import VectorSearchClient
29
29
  from databricks.vector_search.index import VectorSearchIndex
30
30
  from databricks_langchain import (
31
+ DatabricksEmbeddings,
31
32
  DatabricksFunctionClient,
32
33
  )
34
+ from langchain_core.embeddings import Embeddings
33
35
  from langchain_core.language_models import LanguageModelLike
34
36
  from langchain_core.messages import BaseMessage, messages_from_dict
35
37
  from langchain_core.runnables.base import RunnableLike
@@ -408,6 +410,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
408
410
 
409
411
  return chat_client
410
412
 
413
def as_embeddings_model(self) -> Embeddings:
    """Return a LangChain embeddings client backed by this model's serving endpoint."""
    endpoint_name: str = self.name
    return DatabricksEmbeddings(endpoint=endpoint_name)
415
+
411
416
 
412
417
  class VectorSearchEndpointType(str, Enum):
413
418
  STANDARD = "STANDARD"
@@ -977,6 +982,30 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
977
982
  provider.create_lakebase_instance_role(self)
978
983
 
979
984
 
985
class GenieLRUCacheParametersModel(BaseModel):
    """Configuration for the in-memory LRU Genie SQL cache."""

    model_config = ConfigDict(use_enum_values=True, extra="forbid")

    # Maximum number of cached SQL queries held before eviction.
    capacity: int = 1000
    # 1 day default; None or a negative value means entries never expire.
    time_to_live_seconds: int | None = 60 * 60 * 24
    # Warehouse used to re-execute cached SQL on a cache hit.
    warehouse: WarehouseModel
992
+
993
+
994
class GenieSemanticCacheParametersModel(BaseModel):
    """Configuration for the pg_vector-backed semantic Genie SQL cache."""

    model_config = ConfigDict(use_enum_values=True, extra="forbid")

    # 1 day default; None or a negative value means entries never expire.
    time_to_live_seconds: int | None = 60 * 60 * 24
    # Minimum similarity for a cache hit (L2 distance converted to a 0-1 scale).
    similarity_threshold: float = 0.85
    # Embedding endpoint name, or a full LLMModel definition.
    embedding_model: str | LLMModel = "databricks-gte-large-en"
    # Embedding vector dimensionality; auto-detected when None.
    embedding_dims: int | None = None
    # Database instance hosting the pg_vector cache table.
    database: DatabaseModel
    # Warehouse used to re-execute cached SQL on a cache hit.
    warehouse: WarehouseModel
    # Table in which cached entries are stored.
    table_name: str = "genie_semantic_cache"
1007
+
1008
+
980
1009
  class SearchParametersModel(BaseModel):
981
1010
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
982
1011
  num_results: Optional[int] = 10
@@ -0,0 +1,59 @@
1
+ """
2
+ Genie service implementations and caching layers.
3
+
4
+ This package provides core Genie functionality that can be used across
5
+ different contexts (tools, direct integration, etc.).
6
+
7
+ Main exports:
8
+ - GenieService: Core service implementation wrapping Databricks Genie SDK
9
+ - GenieServiceBase: Abstract base class for service implementations
10
+
11
+ Cache implementations are available in the cache subpackage:
12
+ - dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
13
+ - dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
14
+
15
+ Example usage:
16
+ from dao_ai.genie import GenieService
17
+ from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
18
+ """
19
+
20
+ import mlflow
21
+ from databricks_ai_bridge.genie import Genie, GenieResponse
22
+
23
+ from dao_ai.genie.cache import (
24
+ CacheResult,
25
+ GenieServiceBase,
26
+ LRUCacheService,
27
+ SemanticCacheService,
28
+ SQLCacheEntry,
29
+ )
30
+
31
+
32
class GenieService(GenieServiceBase):
    """Concrete GenieServiceBase implementation that wraps the Genie SDK client."""

    genie: Genie

    def __init__(self, genie: Genie) -> None:
        self.genie = genie

    @mlflow.trace(name="genie_ask_question")
    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> GenieResponse:
        """Forward the question to the underlying Genie SDK client and return its response."""
        return self.genie.ask_question(question, conversation_id=conversation_id)
48
+
49
+
50
+ __all__ = [
51
+ # Service classes
52
+ "GenieService",
53
+ "GenieServiceBase",
54
+ # Cache types (from cache subpackage)
55
+ "CacheResult",
56
+ "LRUCacheService",
57
+ "SemanticCacheService",
58
+ "SQLCacheEntry",
59
+ ]
@@ -0,0 +1,44 @@
1
+ """
2
+ Genie cache implementations.
3
+
4
+ This package provides caching layers for Genie SQL queries that can be
5
+ chained together using the decorator pattern.
6
+
7
+ Available cache implementations:
8
+ - LRUCacheService: In-memory LRU cache with O(1) exact match lookup
9
+ - SemanticCacheService: PostgreSQL pg_vector-based semantic similarity cache
10
+
11
+ Example usage:
12
+ from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
13
+
14
+ # Chain caches: LRU (checked first) -> Semantic (checked second) -> Genie
15
+ genie_service = SemanticCacheService(
16
+ impl=GenieService(genie),
17
+ parameters=semantic_params,
18
+ genie_space_id=space_id,
19
+ )
20
+ genie_service = LRUCacheService(
21
+ impl=genie_service,
22
+ parameters=lru_params,
23
+ )
24
+ """
25
+
26
+ from dao_ai.genie.cache.base import (
27
+ CacheResult,
28
+ GenieServiceBase,
29
+ SQLCacheEntry,
30
+ execute_sql_via_warehouse,
31
+ )
32
+ from dao_ai.genie.cache.lru import LRUCacheService
33
+ from dao_ai.genie.cache.semantic import SemanticCacheService
34
+
35
+ __all__ = [
36
+ # Base types
37
+ "CacheResult",
38
+ "GenieServiceBase",
39
+ "SQLCacheEntry",
40
+ "execute_sql_via_warehouse",
41
+ # Cache implementations
42
+ "LRUCacheService",
43
+ "SemanticCacheService",
44
+ ]
@@ -0,0 +1,122 @@
1
+ """
2
+ Base classes and types for Genie cache implementations.
3
+
4
+ This module provides the foundational types used across different cache
5
+ implementations (LRU, Semantic, etc.).
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass
10
+ from datetime import datetime
11
+ from typing import Any
12
+
13
+ import pandas as pd
14
+ from databricks.sdk import WorkspaceClient
15
+ from databricks.sdk.service.sql import StatementResponse, StatementState
16
+ from databricks_ai_bridge.genie import GenieResponse
17
+ from loguru import logger
18
+
19
+ from dao_ai.config import WarehouseModel
20
+
21
+
22
class GenieServiceBase(ABC):
    """Abstract base class for Genie service implementations (origin or cache layer)."""

    @abstractmethod
    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> GenieResponse:
        """Ask a question to Genie and return the response."""
        ...
31
+
32
+
33
@dataclass
class SQLCacheEntry:
    """
    A cache entry storing SQL query metadata for later re-execution.

    The full result set is deliberately not cached; caching the SQL lets a
    cache hit re-run the query and return fresh data.
    """

    # The Genie-generated SQL statement.
    query: str
    # Human-readable description of the query.
    description: str
    # Conversation this SQL was generated in.
    conversation_id: str
    # Timestamp used for TTL expiry checks.
    created_at: datetime
46
+
47
+
48
@dataclass
class CacheResult:
    """
    Result of a cache-aware query, with metadata about how it was served.

    Attributes:
        response: The GenieResponse (fresh data, possibly from cached SQL)
        cache_hit: Whether the SQL query came from cache
        served_by: Name of the layer that served the cached SQL (None if from origin)
    """

    response: GenieResponse
    cache_hit: bool
    served_by: str | None = None
62
+
63
+
64
def execute_sql_via_warehouse(
    warehouse: WarehouseModel,
    sql: str,
    layer_name: str = "cache",
    poll_interval_seconds: float = 1.0,
) -> pd.DataFrame | str:
    """
    Execute SQL using a Databricks warehouse and return results as a DataFrame.

    This is a shared utility for cache implementations that need to re-execute
    cached SQL queries.

    Args:
        warehouse: The warehouse configuration for SQL execution
        sql: The SQL query to execute
        layer_name: Name of the cache layer (for logging)
        poll_interval_seconds: Delay between status polls while the statement
            is still running; avoids busy-waiting against the API

    Returns:
        DataFrame with results, or an error message string on failure
    """
    import time  # local import: only needed for the polling delay

    w: WorkspaceClient = warehouse.workspace_client
    warehouse_id: str = str(warehouse.warehouse_id)

    logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")

    statement_response: StatementResponse = w.statement_execution.execute_statement(
        statement=sql,
        warehouse_id=warehouse_id,
        wait_timeout="30s",
    )

    # Poll for completion if still running. Sleep between polls so we do not
    # hammer the statement-execution API in a tight loop.
    while statement_response.status.state in (
        StatementState.PENDING,
        StatementState.RUNNING,
    ):
        time.sleep(poll_interval_seconds)
        statement_response = w.statement_execution.get_statement(
            statement_response.statement_id
        )

    if statement_response.status.state != StatementState.SUCCEEDED:
        error_msg: str = f"SQL execution failed: {statement_response.status}"
        logger.error(f"[{layer_name}] {error_msg}")
        return error_msg

    # Convert the raw rows into a DataFrame, resolving column names from the
    # manifest when available (falls back to the result schema, then none).
    if statement_response.result and statement_response.result.data_array:
        columns: list[str] = []
        if statement_response.manifest and statement_response.manifest.schema:
            columns = [col.name for col in statement_response.manifest.schema.columns]
        elif hasattr(statement_response.result, "schema"):
            columns = [col.name for col in statement_response.result.schema.columns]

        data: list[list[Any]] = statement_response.result.data_array
        if columns:
            return pd.DataFrame(data, columns=columns)
        return pd.DataFrame(data)

    return pd.DataFrame()
@@ -0,0 +1,306 @@
1
+ """
2
+ LRU (Least Recently Used) cache implementation for Genie SQL queries.
3
+
4
+ This module provides an in-memory LRU cache that stores SQL queries generated
5
+ by Genie. On cache hit, the cached SQL is re-executed against the warehouse
6
+ to return fresh data while avoiding the Genie NL-to-SQL translation cost.
7
+ """
8
+
9
+ from collections import OrderedDict
10
+ from datetime import datetime, timedelta
11
+ from threading import Lock
12
+ from typing import Any
13
+
14
+ import mlflow
15
+ import pandas as pd
16
+ from databricks.sdk import WorkspaceClient
17
+ from databricks.sdk.service.sql import StatementResponse, StatementState
18
+ from databricks_ai_bridge.genie import GenieResponse
19
+ from loguru import logger
20
+
21
+ from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
22
+ from dao_ai.genie.cache.base import (
23
+ CacheResult,
24
+ GenieServiceBase,
25
+ SQLCacheEntry,
26
+ )
27
+
28
+
29
class LRUCacheService(GenieServiceBase):
    """
    LRU caching decorator that caches SQL queries and re-executes them.

    This service caches the SQL query generated by Genie (not the result data).
    On cache hit, it re-executes the cached SQL using the provided warehouse
    to return fresh data while avoiding the Genie NL-to-SQL translation cost.

    Example:
        from dao_ai.config import GenieLRUCacheParametersModel, WarehouseModel
        from dao_ai.genie.cache import LRUCacheService

        cache_params = GenieLRUCacheParametersModel(
            warehouse=warehouse_model,
            capacity=100,
            time_to_live_seconds=86400  # 24 hours
        )
        genie = LRUCacheService(
            impl=GenieService(Genie(space_id="my-space")),
            parameters=cache_params
        )

    Thread-safe: Uses a lock to protect cache operations.
    """

    impl: GenieServiceBase
    parameters: GenieLRUCacheParametersModel
    name: str
    _cache: OrderedDict[str, SQLCacheEntry]
    _lock: Lock

    def __init__(
        self,
        impl: GenieServiceBase,
        parameters: GenieLRUCacheParametersModel,
        name: str | None = None,
    ) -> None:
        """
        Initialize the SQL cache service.

        Args:
            impl: The underlying GenieServiceBase to delegate to on cache miss
            parameters: Cache configuration including warehouse, capacity, and TTL
            name: Name for this cache layer (for logging). Defaults to class name.
        """
        self.impl = impl
        self.parameters = parameters
        self.name = name if name is not None else self.__class__.__name__
        self._cache = OrderedDict()
        self._lock = Lock()

    @property
    def warehouse(self) -> WarehouseModel:
        """The warehouse used for executing cached SQL queries."""
        return self.parameters.warehouse

    @property
    def capacity(self) -> int:
        """Maximum number of SQL queries to cache."""
        return self.parameters.capacity

    @property
    def time_to_live(self) -> timedelta | None:
        """Duration after which cached queries expire. None means never expires."""
        ttl = self.parameters.time_to_live_seconds
        if ttl is None or ttl < 0:
            return None
        return timedelta(seconds=ttl)

    @staticmethod
    def _normalize_key(question: str) -> str:
        """Normalize the question to create a consistent cache key."""
        return question.strip().lower()

    def _is_expired(self, entry: SQLCacheEntry) -> bool:
        """Check if a cache entry has exceeded its TTL. Returns False if TTL is disabled."""
        ttl: timedelta | None = self.time_to_live  # hoist: property computes a timedelta
        if ttl is None:
            return False
        return datetime.now() - entry.created_at > ttl

    def _evict_oldest(self) -> None:
        """Remove the oldest (least recently used) entry. Caller must hold the lock."""
        if self._cache:
            # OrderedDict front is the least recently used entry.
            oldest_key, _ = self._cache.popitem(last=False)
            logger.debug(f"[{self.name}] Evicted: {oldest_key[:50]}...")

    def _get(self, key: str) -> SQLCacheEntry | None:
        """Get from cache, returning None if not found or expired. Caller must hold the lock."""
        if key not in self._cache:
            return None

        entry: SQLCacheEntry = self._cache[key]

        if self._is_expired(entry):
            del self._cache[key]
            logger.debug(f"[{self.name}] Expired: {key[:50]}...")
            return None

        # Mark as most recently used.
        self._cache.move_to_end(key)
        return entry

    def _put(self, key: str, response: GenieResponse) -> None:
        """
        Store the SQL from a response, evicting if at capacity. Caller must hold the lock.

        Responses that carry no SQL query (e.g. purely conversational answers)
        are not cached: there is nothing to re-execute on a future hit, and
        replaying a missing query would fail.
        """
        if not response.query:
            logger.debug(f"[{self.name}] Skipping cache store (no SQL): {key[:50]}...")
            return

        if key in self._cache:
            del self._cache[key]

        while len(self._cache) >= self.capacity:
            self._evict_oldest()

        self._cache[key] = SQLCacheEntry(
            query=response.query,
            description=response.description,
            conversation_id=response.conversation_id,
            created_at=datetime.now(),
        )
        logger.info(
            f"[{self.name}] Stored cache entry: key='{key[:50]}...' "
            f"sql='{response.query[:50]}...' "
            f"(cache_size={len(self._cache)}/{self.capacity})"
        )

    @mlflow.trace(name="execute_cached_sql")
    def _execute_sql(self, sql: str) -> pd.DataFrame | str:
        """
        Execute SQL using the warehouse and return results as DataFrame.

        Args:
            sql: The SQL query to execute

        Returns:
            DataFrame with results, or error message string
        """
        import time  # local import: only needed for the polling delay

        w: WorkspaceClient = self.warehouse.workspace_client
        warehouse_id: str = str(self.warehouse.warehouse_id)

        logger.debug(f"[{self.name}] Executing cached SQL: {sql[:100]}...")

        statement_response: StatementResponse = w.statement_execution.execute_statement(
            statement=sql,
            warehouse_id=warehouse_id,
            wait_timeout="30s",
        )

        # Poll for completion if still running. Sleep between polls so we do
        # not busy-wait against the statement-execution API.
        while statement_response.status.state in (
            StatementState.PENDING,
            StatementState.RUNNING,
        ):
            time.sleep(1.0)
            statement_response = w.statement_execution.get_statement(
                statement_response.statement_id
            )

        if statement_response.status.state != StatementState.SUCCEEDED:
            error_msg: str = f"SQL execution failed: {statement_response.status}"
            logger.error(f"[{self.name}] {error_msg}")
            return error_msg

        # Convert the raw rows into a DataFrame, resolving column names from
        # the manifest when available (falls back to the result schema).
        if statement_response.result and statement_response.result.data_array:
            columns: list[str] = []
            if statement_response.manifest and statement_response.manifest.schema:
                columns = [
                    col.name for col in statement_response.manifest.schema.columns
                ]
            elif hasattr(statement_response.result, "schema"):
                columns = [col.name for col in statement_response.result.schema.columns]

            data: list[list[Any]] = statement_response.result.data_array
            if columns:
                return pd.DataFrame(data, columns=columns)
            return pd.DataFrame(data)

        return pd.DataFrame()

    def ask_question(
        self, question: str, conversation_id: str | None = None
    ) -> GenieResponse:
        """
        Ask a question, using cached SQL query if available.

        On cache hit, re-executes the cached SQL to get fresh data.
        Implements GenieServiceBase for seamless chaining.
        """
        result: CacheResult = self.ask_question_with_cache_info(
            question, conversation_id
        )
        return result.response

    @mlflow.trace(name="genie_lru_cache_lookup")
    def ask_question_with_cache_info(
        self,
        question: str,
        conversation_id: str | None = None,
    ) -> CacheResult:
        """
        Ask a question with detailed cache hit information.

        On cache hit, the cached SQL is re-executed to return fresh data.

        Args:
            question: The question to ask
            conversation_id: Optional conversation ID

        Returns:
            CacheResult with fresh response and cache metadata
        """
        key: str = self._normalize_key(question)

        # Check cache. Hold the lock only for the lookup itself; `self.size`
        # acquires the lock too, so it must only be read outside this block.
        with self._lock:
            cached: SQLCacheEntry | None = self._get(key)

        if cached is not None:
            logger.info(
                f"[{self.name}] Cache HIT: '{question[:50]}...' "
                f"(cache_size={self.size}/{self.capacity})"
            )

            # Re-execute the cached SQL to get fresh data
            result: pd.DataFrame | str = self._execute_sql(cached.query)

            response: GenieResponse = GenieResponse(
                result=result,
                query=cached.query,
                description=cached.description,
                conversation_id=cached.conversation_id,
            )

            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
        logger.info(
            f"[{self.name}] Cache MISS: '{question[:50]}...' "
            f"(cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
        )

        response = self.impl.ask_question(question, conversation_id)
        with self._lock:
            self._put(key, response)
        return CacheResult(response=response, cache_hit=False, served_by=None)

    def invalidate(self, question: str) -> bool:
        """Remove a specific entry from the cache. Returns True if it existed."""
        key: str = self._normalize_key(question)
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> int:
        """Clear all entries from the cache and return how many were removed."""
        with self._lock:
            count: int = len(self._cache)
            self._cache.clear()
            return count

    @property
    def size(self) -> int:
        """Current number of entries in the cache."""
        with self._lock:
            return len(self._cache)

    def stats(self) -> dict[str, int | float | None]:
        """Return cache statistics (size, capacity, TTL, expired/valid entry counts)."""
        with self._lock:
            expired: int = sum(1 for e in self._cache.values() if self._is_expired(e))
            ttl = self.time_to_live
            return {
                "size": len(self._cache),
                "capacity": self.capacity,
                "ttl_seconds": ttl.total_seconds() if ttl else None,
                "expired_entries": expired,
                "valid_entries": len(self._cache) - expired,
            }