dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
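
Of these changes, only the rewritten `dao_ai/tools/genie.py` is shown in full below. Its headline change is a cache-aware `create_genie_tool` factory. As a rough sketch of how the new signature might be called (the `space_id` placeholder and the `capacity` field are illustrative assumptions, not documented values):

```python
# Sketch only: exercises the new create_genie_tool signature from the diff below.
# The space_id placeholder and the cache parameter fields are assumptions.
from dao_ai.tools.genie import create_genie_tool

genie_tool = create_genie_tool(
    genie_room={"space_id": "YOUR_GENIE_SPACE_ID"},  # GenieRoomModel or plain dict
    name="sales_genie",                  # optional; defaults to "genie_tool"
    persist_conversation=True,           # conversation IDs now live in session state
    truncate_results=False,
    lru_cache_parameters={"capacity": 128},  # hypothetical field; dicts are coerced to models
    semantic_cache_parameters=None,      # optional pg_vector-backed semantic cache
)
```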
dao_ai/tools/genie.py
CHANGED
```diff
@@ -1,284 +1,73 @@
-
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
 import os
-import time
-from dataclasses import asdict, dataclass
-from datetime import datetime
 from textwrap import dedent
-from typing import Annotated, Any, Callable
+from typing import Annotated, Any, Callable

-import mlflow
 import pandas as pd
-from
+from databricks_ai_bridge.genie import Genie, GenieResponse
+from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel
-
-from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
-
-MAX_TOKENS_OF_DATA: int = 20000
-MAX_ITERATIONS: int = 50
-DEFAULT_POLLING_INTERVAL_SECS: int = 2
-
+from pydantic import BaseModel

-
-
-
-
-
-
-
-
-
-
-
-    query: Optional[str] = ""
-    description: Optional[str] = ""
-
-    def to_json(self):
-        return json.dumps(asdict(self))
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.state import AgentState, Context, SessionState


 class GenieToolInput(BaseModel):
-    """Input schema for
-
-    question: str = Field(
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    """Input schema for Genie tool - only includes user-facing parameters."""

+    question: str

-def _truncate_result(dataframe: pd.DataFrame) -> str:
-    query_result = dataframe.to_markdown()
-    tokens_used = _count_tokens(query_result)
-
-    # If the full result fits, return it
-    if tokens_used <= MAX_TOKENS_OF_DATA:
-        return query_result.strip()
-
-    def is_too_big(n):
-        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
-
-    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
-    # Passing True, as this is the target value we are looking for when _is_too_big returns
-    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
-
-    # Slice to the found limit
-    truncated_df = dataframe.iloc[:cutoff]
-
-    # Edge case: Cannot return any rows because of tokens so return an empty string
-    if len(truncated_df) == 0:
-        return ""
-
-    truncated_result = truncated_df.to_markdown()
-
-    # Double-check edge case if we overshot by one
-    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
-        truncated_result = truncated_df.iloc[:-1].to_markdown()
-    return truncated_result
-
-
-@mlflow.trace(span_type="PARSER")
-def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
-    output = resp["result"]
-    if not output:
-        return "EMPTY"
-
-    columns = resp["manifest"]["schema"]["columns"]
-    header = [str(col["name"]) for col in columns]
-    rows = []
-
-    for item in output["data_array"]:
-        row = []
-        for column, value in zip(columns, item):
-            type_name = column["type_name"]
-            if value is None:
-                row.append(None)
-                continue
-
-            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
-                row.append(int(value))
-            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
-                row.append(float(value))
-            elif type_name == "BOOLEAN":
-                row.append(value.lower() == "true")
-            elif type_name == "DATE" or type_name == "TIMESTAMP":
-                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
-            elif type_name == "BINARY":
-                row.append(bytes(value, "utf-8"))
-            else:
-                row.append(value)
-
-        rows.append(row)
-
-    dataframe = pd.DataFrame(rows, columns=header)
-
-    if truncate_results:
-        query_result = _truncate_result(dataframe)
-    else:
-        query_result = dataframe.to_markdown()
-
-    return query_result.strip()
-
-
-class Genie:
-    def __init__(
-        self,
-        space_id,
-        client: WorkspaceClient | None = None,
-        truncate_results: bool = False,
-        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-    ):
-        self.space_id = space_id
-        workspace_client = client or WorkspaceClient()
-        self.genie = workspace_client.genie
-        self.description = self.genie.get_space(space_id).description
-        self.headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        self.truncate_results = truncate_results
-        if polling_interval < 1 or polling_interval > 30:
-            raise ValueError("poll_interval must be between 1 and 30 seconds")
-        self.poll_interval = polling_interval
-
-    @mlflow.trace()
-    def start_conversation(self, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def create_message(self, conversation_id, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def poll_for_result(self, conversation_id, message_id):
-        @mlflow.trace()
-        def poll_query_results(attachment_id, query_str, description):
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
-                    headers=self.headers,
-                )["statement_response"]
-                state = resp["status"]["state"]
-                if state == "SUCCEEDED":
-                    result = _parse_query_result(resp, self.truncate_results)
-                    return GenieResponse(
-                        conversation_id, result, query_str, description
-                    )
-                elif state in ["RUNNING", "PENDING"]:
-                    logger.debug("Waiting for query result...")
-                    time.sleep(self.poll_interval)
-                else:
-                    return GenieResponse(
-                        conversation_id,
-                        f"No query result: {resp['state']}",
-                        query_str,
-                        description,
-                    )
-            return GenieResponse(
-                conversation_id,
-                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-                query_str,
-                description,
-            )

-
-
-
-
-
-
-            "GET",
-            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
-            headers=self.headers,
-        )
-        if resp["status"] == "COMPLETED":
-            # Check if attachments key exists in response
-            attachments = resp.get("attachments", [])
-            if not attachments:
-                # Handle case where response has no attachments
-                return GenieResponse(
-                    conversation_id,
-                    result=f"Genie query completed but no attachments found. Response: {resp}",
-                )
-
-            attachment = next((r for r in attachments if "query" in r), None)
-            if attachment:
-                query_obj = attachment["query"]
-                description = query_obj.get("description", "")
-                query_str = query_obj.get("query", "")
-                attachment_id = attachment["attachment_id"]
-                return poll_query_results(attachment_id, query_str, description)
-            if resp["status"] == "COMPLETED":
-                text_content = next(
-                    (r for r in attachments if "text" in r), None
-                )
-                if text_content:
-                    return GenieResponse(
-                        conversation_id, result=text_content["text"]["content"]
-                    )
-                return GenieResponse(
-                    conversation_id,
-                    result="Genie query completed but no text content found in attachments.",
-                )
-        elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
-            return GenieResponse(
-                conversation_id, result=f"Genie query {resp['status'].lower()}."
-            )
-        elif resp["status"] == "FAILED":
-            return GenieResponse(
-                conversation_id,
-                result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
-            )
-        # includes EXECUTING_QUERY, Genie can retry after this status
-        else:
-            logger.debug(f"Waiting...: {resp['status']}")
-            time.sleep(self.poll_interval)
-        return GenieResponse(
-            conversation_id,
-            f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-        )
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()

-
-
-
-
-
-
-
-        if conversation_id:
-            resp = self.create_message(conversation_id, question)
-        else:
-            resp = self.start_conversation(question)
-        logger.debug(f"ask_question response: {resp}")
-        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)


 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
-    name:
-    description:
-    persist_conversation: bool =
+    name: str | None = None,
+    description: str | None = None,
+    persist_conversation: bool = True,
     truncate_results: bool = False,
-
-
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
+) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.

@@ -290,17 +79,37 @@ def create_genie_tool(
         genie_room: GenieRoomModel or dict containing Genie configuration
         name: Optional custom name for the tool. If None, uses default "genie_tool"
         description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching

     Returns:
         A LangGraph tool that processes natural language queries through Genie
     """
+    logger.debug(
+        "Creating Genie tool",
+        genie_room_type=type(genie_room).__name__,
+        persist_conversation=persist_conversation,
+        truncate_results=truncate_results,
+        name=name,
+        has_lru_cache=lru_cache_parameters is not None,
+        has_semantic_cache=semantic_cache_parameters is not None,
+    )

     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)

-
-
-
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -308,13 +117,6 @@ def create_genie_tool(
         space_id = CompositeVariableModel(**space_id)
     space_id = value_of(space_id)

-    # genie: Genie = Genie(
-    #     space_id=space_id,
-    #     client=genie_room.workspace_client,
-    #     truncate_results=truncate_results,
-    #     polling_interval=poll_interval,
-    # )
-
     default_description: str = dedent("""
         This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
         questions about the data and the tool will try to answer them.
@@ -337,51 +139,99 @@ Returns:
         GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs

+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
     )
     def genie_tool(
         question: Annotated[str, "The question to ask Genie about your data"],
-
-        tool_call_id: Annotated[str, InjectedToolCallId],
+        runtime: ToolRuntime[Context, AgentState],
     ) -> Command:
-
-            space_id=space_id,
-            client=genie_room.workspace_client,
-            truncate_results=truncate_results,
-            polling_interval=poll_interval,
-        )
+        """Process a natural language question through Databricks Genie.

-
-
-
-
-
-
+        Uses ToolRuntime to access state and context in a type-safe way.
+        """
+        # Access state through runtime
+        state: AgentState = runtime.state
+        tool_call_id: str = runtime.tool_call_id
+
+        # Ensure space_id is a string for state keys
+        space_id_str: str = str(space_id)
+
+        # Get session state (or create new one)
+        session: SessionState = state.get("session", SessionState())
+
+        # Get existing conversation ID from session
+        existing_conversation_id: str | None = session.genie.get_conversation_id(
+            space_id_str
+        )
+        logger.trace(
+            "Using existing conversation ID",
+            space_id=space_id_str,
+            conversation_id=existing_conversation_id,
        )

-
+        # Call ask_question which always returns CacheResult with cache metadata
+        cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
         )
+        genie_response: GenieResponse = cache_result.response
+        cache_hit: bool = cache_result.cache_hit
+        cache_key: str | None = cache_result.served_by

-        current_conversation_id: str =
+        current_conversation_id: str = genie_response.conversation_id
         logger.debug(
-
+            "Genie question answered",
+            space_id=space_id_str,
+            conversation_id=current_conversation_id,
+            cache_hit=cache_hit,
+            cache_key=cache_key,
         )

-        # Update
+        # Update session state with cache information
+        if persist_conversation:
+            session.genie.update_space(
+                space_id=space_id_str,
+                conversation_id=current_conversation_id,
+                cache_hit=cache_hit,
+                cache_key=cache_key,
+                last_query=question,
+            )

+        # Build update dict with response and session
         update: dict[str, Any] = {
-            "messages": [
+            "messages": [
+                ToolMessage(
+                    _response_to_json(genie_response), tool_call_id=tool_call_id
+                )
+            ],
         }

         if persist_conversation:
-
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
-
-            logger.debug(f"State update: {update}")
+            update["session"] = session

         return Command(update=update)

```
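
The rewrite replaces the hand-rolled `Genie` client and its polling loop with `databricks_ai_bridge.genie.Genie` wrapped in a `GenieService`, and layers caches with a decorator pattern: the LRU wrapper is outermost, so exact-match repeats are answered in O(1) before the semantic cache or Genie itself is consulted, and `ask_question` now returns a `CacheResult` carrying cache metadata alongside the `GenieResponse`. A minimal sketch of that layering outside the tool factory, assuming the parameter models construct with defaults:

```python
# Minimal sketch of the cache layering create_genie_tool builds (see diff above).
# Assumes GenieLRUCacheParametersModel / GenieSemanticCacheParametersModel can be
# built with defaults; real deployments would populate their fields.
from databricks_ai_bridge.genie import Genie
from dao_ai.config import (
    GenieLRUCacheParametersModel,
    GenieSemanticCacheParametersModel,
)
from dao_ai.genie import GenieService, GenieServiceBase
from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService

genie = Genie(space_id="YOUR_GENIE_SPACE_ID")
service: GenieServiceBase = GenieService(genie)

# Inner layer: pg_vector-backed similarity matching (checked second).
service = SemanticCacheService(
    impl=service,
    parameters=GenieSemanticCacheParametersModel(),
).initialize()

# Outer layer: fast O(1) exact-match cache (checked first).
service = LRUCacheService(
    impl=service,
    parameters=GenieLRUCacheParametersModel(),
)

result: CacheResult = service.ask_question(
    "What was total revenue by region?", conversation_id=None
)
print(result.cache_hit, result.served_by)  # cache metadata
print(result.response.conversation_id)     # underlying GenieResponse
```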