dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py
CHANGED
@@ -1,285 +1,73 @@
-
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
-import logging
 import os
-import time
-from dataclasses import asdict, dataclass
-from datetime import datetime
 from textwrap import dedent
-from typing import Annotated, Any, Callable
+from typing import Annotated, Any, Callable

-import mlflow
 import pandas as pd
-from
+from databricks_ai_bridge.genie import Genie, GenieResponse
+from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel
-
-from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
-
-MAX_TOKENS_OF_DATA: int = 20000
-MAX_ITERATIONS: int = 50
-DEFAULT_POLLING_INTERVAL_SECS: int = 2
-
+from pydantic import BaseModel

-
-
-
-
-
-
-
-
-
-
-    query: Optional[str] = ""
-    description: Optional[str] = ""
-
-    def to_json(self):
-        return json.dumps(asdict(self))
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.state import AgentState, Context, SessionState


 class GenieToolInput(BaseModel):
-    """Input schema for
-
-    question: str = Field(
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    """Input schema for Genie tool - only includes user-facing parameters."""

+    question: str

-def _truncate_result(dataframe: pd.DataFrame) -> str:
-    query_result = dataframe.to_markdown()
-    tokens_used = _count_tokens(query_result)
-
-    # If the full result fits, return it
-    if tokens_used <= MAX_TOKENS_OF_DATA:
-        return query_result.strip()
-
-    def is_too_big(n):
-        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
-
-    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
-    # Passing True, as this is the target value we are looking for when _is_too_big returns
-    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
-
-    # Slice to the found limit
-    truncated_df = dataframe.iloc[:cutoff]
-
-    # Edge case: Cannot return any rows because of tokens so return an empty string
-    if len(truncated_df) == 0:
-        return ""
-
-    truncated_result = truncated_df.to_markdown()
-
-    # Double-check edge case if we overshot by one
-    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
-        truncated_result = truncated_df.iloc[:-1].to_markdown()
-    return truncated_result
-
-
-@mlflow.trace(span_type="PARSER")
-def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
-    output = resp["result"]
-    if not output:
-        return "EMPTY"
-
-    columns = resp["manifest"]["schema"]["columns"]
-    header = [str(col["name"]) for col in columns]
-    rows = []
-
-    for item in output["data_array"]:
-        row = []
-        for column, value in zip(columns, item):
-            type_name = column["type_name"]
-            if value is None:
-                row.append(None)
-                continue
-
-            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
-                row.append(int(value))
-            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
-                row.append(float(value))
-            elif type_name == "BOOLEAN":
-                row.append(value.lower() == "true")
-            elif type_name == "DATE" or type_name == "TIMESTAMP":
-                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
-            elif type_name == "BINARY":
-                row.append(bytes(value, "utf-8"))
-            else:
-                row.append(value)
-
-        rows.append(row)
-
-    dataframe = pd.DataFrame(rows, columns=header)
-
-    if truncate_results:
-        query_result = _truncate_result(dataframe)
-    else:
-        query_result = dataframe.to_markdown()
-
-    return query_result.strip()
-
-
-class Genie:
-    def __init__(
-        self,
-        space_id,
-        client: WorkspaceClient | None = None,
-        truncate_results: bool = False,
-        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-    ):
-        self.space_id = space_id
-        workspace_client = client or WorkspaceClient()
-        self.genie = workspace_client.genie
-        self.description = self.genie.get_space(space_id).description
-        self.headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        self.truncate_results = truncate_results
-        if polling_interval < 1 or polling_interval > 30:
-            raise ValueError("poll_interval must be between 1 and 30 seconds")
-        self.poll_interval = polling_interval
-
-    @mlflow.trace()
-    def start_conversation(self, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def create_message(self, conversation_id, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def poll_for_result(self, conversation_id, message_id):
-        @mlflow.trace()
-        def poll_query_results(attachment_id, query_str, description):
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
-                    headers=self.headers,
-                )["statement_response"]
-                state = resp["status"]["state"]
-                if state == "SUCCEEDED":
-                    result = _parse_query_result(resp, self.truncate_results)
-                    return GenieResponse(
-                        conversation_id, result, query_str, description
-                    )
-                elif state in ["RUNNING", "PENDING"]:
-                    logging.debug("Waiting for query result...")
-                    time.sleep(self.poll_interval)
-                else:
-                    return GenieResponse(
-                        conversation_id,
-                        f"No query result: {resp['state']}",
-                        query_str,
-                        description,
-                    )
-            return GenieResponse(
-                conversation_id,
-                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-                query_str,
-                description,
-            )

-
-
-
-
-
-
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
-                    headers=self.headers,
-                )
-                if resp["status"] == "COMPLETED":
-                    # Check if attachments key exists in response
-                    attachments = resp.get("attachments", [])
-                    if not attachments:
-                        # Handle case where response has no attachments
-                        return GenieResponse(
-                            conversation_id,
-                            result=f"Genie query completed but no attachments found. Response: {resp}",
-                        )
-
-                    attachment = next((r for r in attachments if "query" in r), None)
-                    if attachment:
-                        query_obj = attachment["query"]
-                        description = query_obj.get("description", "")
-                        query_str = query_obj.get("query", "")
-                        attachment_id = attachment["attachment_id"]
-                        return poll_query_results(attachment_id, query_str, description)
-                    if resp["status"] == "COMPLETED":
-                        text_content = next(
-                            (r for r in attachments if "text" in r), None
-                        )
-                        if text_content:
-                            return GenieResponse(
-                                conversation_id, result=text_content["text"]["content"]
-                            )
-                        return GenieResponse(
-                            conversation_id,
-                            result="Genie query completed but no text content found in attachments.",
-                        )
-                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
-                    return GenieResponse(
-                        conversation_id, result=f"Genie query {resp['status'].lower()}."
-                    )
-                elif resp["status"] == "FAILED":
-                    return GenieResponse(
-                        conversation_id,
-                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
-                    )
-                # includes EXECUTING_QUERY, Genie can retry after this status
-                else:
-                    logging.debug(f"Waiting...: {resp['status']}")
-                    time.sleep(self.poll_interval)
-            return GenieResponse(
-                conversation_id,
-                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-            )
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()

-
-
-
-
-
-
-
-        if conversation_id:
-            resp = self.create_message(conversation_id, question)
-        else:
-            resp = self.start_conversation(question)
-        logger.debug(f"ask_question response: {resp}")
-        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)


 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
-    name:
-    description:
-    persist_conversation: bool =
+    name: str | None = None,
+    description: str | None = None,
+    persist_conversation: bool = True,
     truncate_results: bool = False,
-
-
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
+) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.

@@ -291,17 +79,37 @@ def create_genie_tool(
         genie_room: GenieRoomModel or dict containing Genie configuration
         name: Optional custom name for the tool. If None, uses default "genie_tool"
         description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching

     Returns:
         A LangGraph tool that processes natural language queries through Genie
     """
+    logger.debug(
+        "Creating Genie tool",
+        genie_room_type=type(genie_room).__name__,
+        persist_conversation=persist_conversation,
+        truncate_results=truncate_results,
+        name=name,
+        has_lru_cache=lru_cache_parameters is not None,
+        has_semantic_cache=semantic_cache_parameters is not None,
+    )

     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)

-
-
-
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -309,13 +117,6 @@ def create_genie_tool(
         space_id = CompositeVariableModel(**space_id)
     space_id = value_of(space_id)

-    # genie: Genie = Genie(
-    #     space_id=space_id,
-    #     client=genie_room.workspace_client,
-    #     truncate_results=truncate_results,
-    #     polling_interval=poll_interval,
-    # )
-
     default_description: str = dedent("""
         This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
         questions about the data and the tool will try to answer them.
@@ -338,51 +139,99 @@ Returns:
         GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs

+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
     )
     def genie_tool(
         question: Annotated[str, "The question to ask Genie about your data"],
-
-        tool_call_id: Annotated[str, InjectedToolCallId],
+        runtime: ToolRuntime[Context, AgentState],
     ) -> Command:
-
-            space_id=space_id,
-            client=genie_room.workspace_client,
-            truncate_results=truncate_results,
-            polling_interval=poll_interval,
-        )
+        """Process a natural language question through Databricks Genie.

-
-
-
-
-
-
+        Uses ToolRuntime to access state and context in a type-safe way.
+        """
+        # Access state through runtime
+        state: AgentState = runtime.state
+        tool_call_id: str = runtime.tool_call_id
+
+        # Ensure space_id is a string for state keys
+        space_id_str: str = str(space_id)
+
+        # Get session state (or create new one)
+        session: SessionState = state.get("session", SessionState())
+
+        # Get existing conversation ID from session
+        existing_conversation_id: str | None = session.genie.get_conversation_id(
+            space_id_str
+        )
+        logger.trace(
+            "Using existing conversation ID",
+            space_id=space_id_str,
+            conversation_id=existing_conversation_id,
         )

-
+        # Call ask_question which always returns CacheResult with cache metadata
+        cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
         )
+        genie_response: GenieResponse = cache_result.response
+        cache_hit: bool = cache_result.cache_hit
+        cache_key: str | None = cache_result.served_by

-        current_conversation_id: str =
+        current_conversation_id: str = genie_response.conversation_id
         logger.debug(
-
+            "Genie question answered",
+            space_id=space_id_str,
+            conversation_id=current_conversation_id,
+            cache_hit=cache_hit,
+            cache_key=cache_key,
         )

-        # Update
+        # Update session state with cache information
+        if persist_conversation:
+            session.genie.update_space(
+                space_id=space_id_str,
+                conversation_id=current_conversation_id,
+                cache_hit=cache_hit,
+                cache_key=cache_key,
+                last_query=question,
+            )

+        # Build update dict with response and session
         update: dict[str, Any] = {
-            "messages": [
+            "messages": [
+                ToolMessage(
+                    _response_to_json(genie_response), tool_call_id=tool_call_id
+                )
+            ],
         }

         if persist_conversation:
-
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
-
-        logger.debug(f"State update: {update}")
+            update["session"] = session

         return Command(update=update)

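For orientation, here is a minimal, hedged sketch of how the reworked `create_genie_tool` factory above might be wired up. The space ID is a placeholder, the room is passed as a plain dict (the factory expands dicts via `GenieRoomModel(**genie_room)`), and the cache parameters are only referenced in comments because their field names are defined by `GenieLRUCacheParametersModel` / `GenieSemanticCacheParametersModel` in `dao_ai.config` and are not visible in this diff.

```python
# Hedged usage sketch of the new factory signature; not taken from the package docs.
from dao_ai.tools.genie import create_genie_tool

genie_tool = create_genie_tool(
    # GenieRoomModel instance or plain dict; space_id may instead come from the
    # DATABRICKS_GENIE_SPACE_ID environment variable.
    genie_room={"space_id": "<your-genie-space-id>"},  # placeholder value
    persist_conversation=True,   # keep one Genie conversation per space in session state
    truncate_results=False,
    # Optional caching layers: pass GenieLRUCacheParametersModel /
    # GenieSemanticCacheParametersModel instances (or equivalent dicts) here.
    # lru_cache_parameters=...,
    # semantic_cache_parameters=...,
)
```

Per the comments in the diff, when both caches are configured the LRU cache is consulted first for exact-match hits and the semantic (pg_vector) cache second, falling through to Genie itself on a miss; the returned tool yields a `Command` whose update carries a `ToolMessage` with the JSON-serialized `GenieResponse` and, when `persist_conversation` is set, the refreshed session state.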