dao-ai 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- dao_ai/config.py +162 -34
- dao_ai/prompts.py +1 -1
- dao_ai/providers/databricks.py +204 -146
- dao_ai/tools/core.py +1 -1
- dao_ai/tools/genie.py +26 -262
- dao_ai/tools/unity_catalog.py +31 -2
- dao_ai/tools/vector_search.py +4 -2
- dao_ai/utils.py +60 -7
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/METADATA +15 -15
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/RECORD +13 -13
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py
CHANGED
@@ -1,15 +1,10 @@
-import bisect
 import json
 import os
-import time
-from dataclasses import asdict, dataclass
-from datetime import datetime
 from textwrap import dedent
-from typing import Annotated, Any, Callable
+from typing import Annotated, Any, Callable

-import mlflow
 import pandas as pd
-from
+from databricks_ai_bridge.genie import Genie, GenieResponse
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import InjectedToolCallId, tool
 from langgraph.prebuilt import InjectedState
@@ -19,28 +14,6 @@ from pydantic import BaseModel, Field

 from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of

-MAX_TOKENS_OF_DATA: int = 20000
-MAX_ITERATIONS: int = 50
-DEFAULT_POLLING_INTERVAL_SECS: int = 2
-
-
-def _count_tokens(text):
-    import tiktoken
-
-    encoding = tiktoken.encoding_for_model("gpt-4o")
-    return len(encoding.encode(text))
-
-
-@dataclass
-class GenieResponse:
-    conversation_id: str
-    result: Union[str, pd.DataFrame]
-    query: Optional[str] = ""
-    description: Optional[str] = ""
-
-    def to_json(self):
-        return json.dumps(asdict(self))
-

 class GenieToolInput(BaseModel):
     """Input schema for the Genie tool."""
@@ -50,235 +23,29 @@ class GenieToolInput(BaseModel):
 )


-def
-
-
-
-
-
-        return query_result.strip()
-
-    def is_too_big(n):
-        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
-
-    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
-    # Passing True, as this is the target value we are looking for when _is_too_big returns
-    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
-
-    # Slice to the found limit
-    truncated_df = dataframe.iloc[:cutoff]
-
-    # Edge case: Cannot return any rows because of tokens so return an empty string
-    if len(truncated_df) == 0:
-        return ""
-
-    truncated_result = truncated_df.to_markdown()
-
-    # Double-check edge case if we overshot by one
-    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
-        truncated_result = truncated_df.iloc[:-1].to_markdown()
-    return truncated_result
-
-
-@mlflow.trace(span_type="PARSER")
-def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
-    output = resp["result"]
-    if not output:
-        return "EMPTY"
-
-    columns = resp["manifest"]["schema"]["columns"]
-    header = [str(col["name"]) for col in columns]
-    rows = []
-
-    for item in output["data_array"]:
-        row = []
-        for column, value in zip(columns, item):
-            type_name = column["type_name"]
-            if value is None:
-                row.append(None)
-                continue
-
-            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
-                row.append(int(value))
-            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
-                row.append(float(value))
-            elif type_name == "BOOLEAN":
-                row.append(value.lower() == "true")
-            elif type_name == "DATE" or type_name == "TIMESTAMP":
-                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
-            elif type_name == "BINARY":
-                row.append(bytes(value, "utf-8"))
-            else:
-                row.append(value)
-
-        rows.append(row)
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()

-
-
-
-
-
-
-
-    return query_result.strip()
-
-
-class Genie:
-    def __init__(
-        self,
-        space_id,
-        client: WorkspaceClient | None = None,
-        truncate_results: bool = False,
-        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-    ):
-        self.space_id = space_id
-        workspace_client = client or WorkspaceClient()
-        self.genie = workspace_client.genie
-        self.description = self.genie.get_space(space_id).description
-        self.headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        self.truncate_results = truncate_results
-        if polling_interval < 1 or polling_interval > 30:
-            raise ValueError("poll_interval must be between 1 and 30 seconds")
-        self.poll_interval = polling_interval
-
-    @mlflow.trace()
-    def start_conversation(self, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def create_message(self, conversation_id, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def poll_for_result(self, conversation_id, message_id):
-        @mlflow.trace()
-        def poll_query_results(attachment_id, query_str, description):
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
-                    headers=self.headers,
-                )["statement_response"]
-                state = resp["status"]["state"]
-                if state == "SUCCEEDED":
-                    result = _parse_query_result(resp, self.truncate_results)
-                    return GenieResponse(
-                        conversation_id, result, query_str, description
-                    )
-                elif state in ["RUNNING", "PENDING"]:
-                    logger.debug("Waiting for query result...")
-                    time.sleep(self.poll_interval)
-                else:
-                    return GenieResponse(
-                        conversation_id,
-                        f"No query result: {resp['state']}",
-                        query_str,
-                        description,
-                    )
-            return GenieResponse(
-                conversation_id,
-                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-                query_str,
-                description,
-            )
-
-        @mlflow.trace()
-        def poll_result():
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
-                    headers=self.headers,
-                )
-                if resp["status"] == "COMPLETED":
-                    # Check if attachments key exists in response
-                    attachments = resp.get("attachments", [])
-                    if not attachments:
-                        # Handle case where response has no attachments
-                        return GenieResponse(
-                            conversation_id,
-                            result=f"Genie query completed but no attachments found. Response: {resp}",
-                        )
-
-                    attachment = next((r for r in attachments if "query" in r), None)
-                    if attachment:
-                        query_obj = attachment["query"]
-                        description = query_obj.get("description", "")
-                        query_str = query_obj.get("query", "")
-                        attachment_id = attachment["attachment_id"]
-                        return poll_query_results(attachment_id, query_str, description)
-                    if resp["status"] == "COMPLETED":
-                        text_content = next(
-                            (r for r in attachments if "text" in r), None
-                        )
-                        if text_content:
-                            return GenieResponse(
-                                conversation_id, result=text_content["text"]["content"]
-                            )
-                        return GenieResponse(
-                            conversation_id,
-                            result="Genie query completed but no text content found in attachments.",
-                        )
-                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
-                    return GenieResponse(
-                        conversation_id, result=f"Genie query {resp['status'].lower()}."
-                    )
-                elif resp["status"] == "FAILED":
-                    return GenieResponse(
-                        conversation_id,
-                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
-                    )
-                # includes EXECUTING_QUERY, Genie can retry after this status
-                else:
-                    logger.debug(f"Waiting...: {resp['status']}")
-                    time.sleep(self.poll_interval)
-            return GenieResponse(
-                conversation_id,
-                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-            )
-
-        return poll_result()
-
-    @mlflow.trace()
-    def ask_question(self, question: str, conversation_id: str | None = None):
-        logger.debug(
-            f"ask_question called with question: {question}, conversation_id: {conversation_id}"
-        )
-        if conversation_id:
-            resp = self.create_message(conversation_id, question)
-        else:
-            resp = self.start_conversation(question)
-        logger.debug(f"ask_question response: {resp}")
-        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)


 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
-    name:
-    description:
+    name: str | None = None,
+    description: str | None = None,
     persist_conversation: bool = False,
     truncate_results: bool = False,
-
-) -> Callable[[str], GenieResponse]:
+) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.

@@ -290,6 +57,9 @@ def create_genie_tool(
         genie_room: GenieRoomModel or dict containing Genie configuration
         name: Optional custom name for the tool. If None, uses default "genie_tool"
         description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits

     Returns:
         A LangGraph tool that processes natural language queries through Genie
@@ -305,13 +75,6 @@ def create_genie_tool(
         space_id = CompositeVariableModel(**space_id)
     space_id = value_of(space_id)

-    # genie: Genie = Genie(
-    #     space_id=space_id,
-    #     client=genie_room.workspace_client,
-    #     truncate_results=truncate_results,
-    #     polling_interval=poll_interval,
-    # )
-
     default_description: str = dedent("""
         This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
         questions about the data and the tool will try to answer them.
@@ -343,14 +106,14 @@ GenieResponse: A response object containing the conversation ID and result from
         state: Annotated[dict, InjectedState],
         tool_call_id: Annotated[str, InjectedToolCallId],
     ) -> Command:
+        """Process a natural language question through Databricks Genie."""
+        # Create Genie instance using databricks_langchain implementation
         genie: Genie = Genie(
             space_id=space_id,
             client=genie_room.workspace_client,
             truncate_results=truncate_results,
-            polling_interval=poll_interval,
         )

-        """Process a natural language question through Databricks Genie."""
         # Get existing conversation mapping and retrieve conversation ID for this space
         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
         existing_conversation_id: str | None = conversation_ids.get(space_id)
@@ -368,9 +131,10 @@ GenieResponse: A response object containing the conversation ID and result from
         )

         # Update the conversation mapping with the new conversation ID for this space
-
         update: dict[str, Any] = {
-            "messages": [
+            "messages": [
+                ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
+            ],
         }

         if persist_conversation:
dao_ai/tools/unity_catalog.py
CHANGED
@@ -265,23 +265,52 @@ def with_partial_args(

     Args:
         tool: ToolModel containing the Unity Catalog function configuration
-        partial_args: Dictionary of arguments to pre-fill in the tool
+        partial_args: Dictionary of arguments to pre-fill in the tool.
+            Supports:
+            - client_id, client_secret: OAuth credentials directly
+            - service_principal: ServicePrincipalModel with client_id and client_secret
+            - host or workspace_host: Databricks workspace host

     Returns:
         StructuredTool: A LangChain tool with partial arguments pre-filled
     """
     from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema

+    from dao_ai.config import ServicePrincipalModel
+
     logger.debug(f"with_partial_args: {tool}")

     # Convert dict-based variables to CompositeVariableModel and resolve their values
-    resolved_args = {}
+    resolved_args: dict[str, Any] = {}
     for k, v in partial_args.items():
         if isinstance(v, dict):
             resolved_args[k] = value_of(CompositeVariableModel(**v))
         else:
             resolved_args[k] = value_of(v)

+    # Handle service_principal - expand into client_id and client_secret
+    if "service_principal" in resolved_args:
+        sp = resolved_args.pop("service_principal")
+        if isinstance(sp, dict):
+            sp = ServicePrincipalModel(**sp)
+        if isinstance(sp, ServicePrincipalModel):
+            if "client_id" not in resolved_args:
+                resolved_args["client_id"] = value_of(sp.client_id)
+            if "client_secret" not in resolved_args:
+                resolved_args["client_secret"] = value_of(sp.client_secret)
+
+    # Normalize host/workspace_host - accept either key
+    if "workspace_host" in resolved_args and "host" not in resolved_args:
+        resolved_args["host"] = resolved_args.pop("workspace_host")
+
+    # Default host from WorkspaceClient if not provided
+    if "host" not in resolved_args:
+        from dao_ai.utils import get_default_databricks_host
+
+        host: str | None = get_default_databricks_host()
+        if host:
+            resolved_args["host"] = host
+
     logger.debug(f"Resolved partial args: {resolved_args.keys()}")

     if isinstance(tool, dict):
dao_ai/tools/vector_search.py
CHANGED
@@ -101,7 +101,7 @@ def create_vector_search_tool(
     # Initialize the vector store
     # Note: text_column is only required for self-managed embeddings
     # For Databricks-managed embeddings, it's automatically determined from the index
-
+
     # Build client_args for VectorSearchClient from environment variables
     # This is needed because during MLflow model validation, credentials must be
     # explicitly passed to VectorSearchClient via client_args.
@@ -121,7 +121,9 @@ def create_vector_search_tool(
         "DATABRICKS_CLIENT_SECRET"
     )

-    logger.debug(
+    logger.debug(
+        f"Creating DatabricksVectorSearch with client_args keys: {list(client_args.keys())}"
+    )

     # Pass both workspace_client (for model serving detection) and client_args (for credentials)
     vector_store: DatabricksVectorSearch = DatabricksVectorSearch(
dao_ai/utils.py
CHANGED
@@ -38,6 +38,32 @@ def normalize_name(name: str) -> str:
     return normalized.strip("_")


+def get_default_databricks_host() -> str | None:
+    """Get the default Databricks workspace host.
+
+    Attempts to get the host from:
+    1. DATABRICKS_HOST environment variable
+    2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)
+
+    Returns:
+        The Databricks workspace host URL, or None if not available.
+    """
+    # Try environment variable first
+    host: str | None = os.environ.get("DATABRICKS_HOST")
+    if host:
+        return host
+
+    # Fall back to WorkspaceClient
+    try:
+        from databricks.sdk import WorkspaceClient
+
+        w: WorkspaceClient = WorkspaceClient()
+        return w.config.host
+    except Exception:
+        logger.debug("Could not get default Databricks host from WorkspaceClient")
+        return None
+
+
 def dao_ai_version() -> str:
     """
     Get the dao-ai package version, with fallback for source installations.
@@ -99,7 +125,7 @@ def get_installed_packages() -> dict[str, str]:
         f"databricks-langchain=={version('databricks-langchain')}",
         f"databricks-mcp=={version('databricks-mcp')}",
         f"databricks-sdk[openai]=={version('databricks-sdk')}",
-        f"
+        f"ddgs=={version('ddgs')}",
         f"flashrank=={version('flashrank')}",
         f"langchain=={version('langchain')}",
         f"langchain-mcp-adapters=={version('langchain-mcp-adapters')}",
@@ -141,12 +167,12 @@ def load_function(function_name: str) -> Callable[..., Any]:
         "module.submodule.function_name"

     Returns:
-        The imported callable function
+        The imported callable function or langchain tool

     Raises:
         ImportError: If the module cannot be imported
         AttributeError: If the function doesn't exist in the module
-        TypeError: If the resolved object is not callable
+        TypeError: If the resolved object is not callable or invocable

     Example:
         >>> func = callable_from_fqn("dao_ai.models.get_latest_model_version")
@@ -164,9 +190,14 @@ def load_function(function_name: str) -> Callable[..., Any]:
         # Get the function from the module
         func = getattr(module, func_name)

-        # Verify that the resolved object is callable
-
-
+        # Verify that the resolved object is callable or is a langchain tool
+        # In langchain 1.x, StructuredTool objects are not directly callable
+        # but have an invoke() method
+        is_callable = callable(func)
+        is_langchain_tool = hasattr(func, "invoke") and hasattr(func, "name")
+
+        if not is_callable and not is_langchain_tool:
+            raise TypeError(f"Function {func_name} is not callable or invocable.")

         return func
     except (ImportError, AttributeError, TypeError) as e:
@@ -175,4 +206,26 @@ def load_function(function_name: str) -> Callable[..., Any]:


 def is_in_model_serving() -> bool:
-
+    """Check if running in Databricks Model Serving environment.
+
+    Detects Model Serving by checking for environment variables that are
+    typically set in that environment.
+    """
+    # Primary check - explicit Databricks Model Serving env var
+    if os.environ.get("IS_IN_DB_MODEL_SERVING_ENV", "false").lower() == "true":
+        return True
+
+    # Secondary check - Model Serving sets these environment variables
+    if os.environ.get("DATABRICKS_MODEL_SERVING_ENV"):
+        return True
+
+    # Check for cluster type indicator
+    cluster_type = os.environ.get("DATABRICKS_CLUSTER_TYPE", "")
+    if "model-serving" in cluster_type.lower():
+        return True
+
+    # Check for model serving specific paths
+    if os.path.exists("/opt/conda/envs/mlflow-env"):
+        return True
+
+    return False
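Two behavioral additions land in utils.py: get_default_databricks_host resolves a workspace host without raising, and load_function now accepts LangChain tools, which are not directly callable in langchain 1.x but expose invoke(). A short usage sketch; my_tools.search_tool is a hypothetical module path:

from dao_ai.utils import get_default_databricks_host, load_function

# Resolves DATABRICKS_HOST first, then the WorkspaceClient's ambient
# config (e.g. ~/.databrickscfg), and returns None instead of raising.
host = get_default_databricks_host()
print(f"default host: {host}")

# A StructuredTool passes the new duck-typed check (has .invoke and .name)
# even though callable(tool) is False in langchain 1.x.
tool = load_function("my_tools.search_tool")  # hypothetical module path
if not callable(tool):
    result = tool.invoke({"query": "quarterly sales"})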
{dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.31
+Version: 0.0.33
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -25,29 +25,29 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.11
-Requires-Dist: databricks-agents>=1.
-Requires-Dist: databricks-langchain>=0.
+Requires-Dist: databricks-agents>=1.8.2
+Requires-Dist: databricks-langchain>=0.11.0
 Requires-Dist: databricks-mcp>=0.3.0
 Requires-Dist: databricks-sdk[openai]>=0.67.0
-Requires-Dist:
+Requires-Dist: ddgs>=9.9.3
 Requires-Dist: flashrank>=0.2.8
 Requires-Dist: gepa>=0.0.17
 Requires-Dist: grandalf>=0.8
-Requires-Dist: langchain-mcp-adapters>=0.1
+Requires-Dist: langchain-mcp-adapters>=0.2.1
 Requires-Dist: langchain-tavily>=0.2.11
-Requires-Dist: langchain>=
-Requires-Dist: langgraph-checkpoint-postgres>=
-Requires-Dist: langgraph-supervisor>=0.0.
-Requires-Dist: langgraph-swarm>=0.0
-Requires-Dist: langgraph>=0.
-Requires-Dist: langmem>=0.0.
+Requires-Dist: langchain>=1.1.3
+Requires-Dist: langgraph-checkpoint-postgres>=3.0.2
+Requires-Dist: langgraph-supervisor>=0.0.31
+Requires-Dist: langgraph-swarm>=0.1.0
+Requires-Dist: langgraph>=1.0.4
+Requires-Dist: langmem>=0.0.30
 Requires-Dist: loguru>=0.7.3
-Requires-Dist: mcp>=1.
-Requires-Dist: mlflow>=3.
+Requires-Dist: mcp>=1.23.3
+Requires-Dist: mlflow>=3.7.0
 Requires-Dist: nest-asyncio>=1.6.0
 Requires-Dist: openevals>=0.0.19
 Requires-Dist: openpyxl>=3.1.5
-Requires-Dist: psycopg[binary,pool]>=3.2
+Requires-Dist: psycopg[binary,pool]>=3.3.2
 Requires-Dist: pydantic>=2.12.0
 Requires-Dist: python-dotenv>=1.1.0
 Requires-Dist: pyyaml>=6.0.2
@@ -55,7 +55,7 @@ Requires-Dist: rich>=14.0.0
 Requires-Dist: scipy<=1.15
 Requires-Dist: sqlparse>=0.5.3
 Requires-Dist: tomli>=2.3.0
-Requires-Dist: unitycatalog-ai[databricks]>=0.3.
+Requires-Dist: unitycatalog-ai[databricks]>=0.3.2
 Provides-Extra: databricks
 Requires-Dist: databricks-connect>=15.0.0; extra == 'databricks'
 Requires-Dist: databricks-vectorsearch>=0.63; extra == 'databricks'
{dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/RECORD
CHANGED

@@ -3,16 +3,16 @@ dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=gq-nsapWxDA1M6Jua3vajBvIwf0Oa6YLcB58lEtMKUo,22503
-dao_ai/config.py,sha256=
+dao_ai/config.py,sha256=Uj0FgOhjnYp0qEmY44mCnp3Ijafg-381FNXt8R_QuWw,78513
 dao_ai/graph.py,sha256=9kjJx0oFZKq5J9-Kpri4-0VCJILHYdYyhqQnj0_noxQ,8913
 dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
 dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
 dao_ai/nodes.py,sha256=iQ_5vL6mt1UcRnhwgz-l1D8Ww4CMQrSMVnP_Lu7fFjU,8781
-dao_ai/prompts.py,sha256=
+dao_ai/prompts.py,sha256=iA2Iaky7yzjwWT5cxg0cUIgwo1z1UVQua__8WPnvV6g,1633
 dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dao_ai/utils.py,sha256=
+dao_ai/utils.py,sha256=oIPmz02kZ3LMntbqxUajFXh4nswOhbvEjOTi4e5_cvI,8500
 dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
 dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
 dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
@@ -22,20 +22,20 @@ dao_ai/memory/core.py,sha256=DnEjQO3S7hXr3CDDd7C2eE7fQUmcCS_8q9BXEgjPH3U,4271
 dao_ai/memory/postgres.py,sha256=vvI3osjx1EoU5GBA6SCUstTBKillcmLl12hVgDMjfJY,15346
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=
+dao_ai/providers/databricks.py,sha256=rPBMdGcJvdGBRK9FZeBxkLfcTpXyxU1cs14YllyZKbY,67857
 dao_ai/tools/__init__.py,sha256=G5-5Yi6zpQOH53b5IzLdtsC6g0Ep6leI5GxgxOmgw7Q,1203
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
-dao_ai/tools/core.py,sha256=
-dao_ai/tools/genie.py,sha256=
+dao_ai/tools/core.py,sha256=kN77fWOzVY7qOs4NiW72cUxCsSTC0DnPp73s6VJEZOQ,1991
+dao_ai/tools/genie.py,sha256=BPM_1Sk5bf7QSCFPPboWWkZKYwBwDwbGhMVp5-QDd10,5956
 dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=5aQoRtx2z4xm6zgRslc78rSfEQe-mfhqov2NsiybYfc,8416
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/slack.py,sha256=SCvyVcD9Pv_XXPXePE_fSU1Pd8VLTEkKDLvoGTZWy2Y,4775
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
-dao_ai/tools/unity_catalog.py,sha256=
-dao_ai/tools/vector_search.py,sha256=
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
+dao_ai/tools/unity_catalog.py,sha256=K9t8M4spsbxbecWmV5yEZy16s_AG7AfaoxT-7IDW43I,14438
+dao_ai/tools/vector_search.py,sha256=3cdiUaFpox25GSRNec7FKceY3DuLp7dLVH8FRA0BgeY,12624
+dao_ai-0.0.33.dist-info/METADATA,sha256=aa4BvkiG1dEvLorpgADosf1LCKRVBg-n8LtReVYJNxc,42761
+dao_ai-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dao_ai-0.0.33.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.33.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.33.dist-info/RECORD,,
{dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/WHEEL
File without changes

{dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/entry_points.txt
File without changes

{dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/licenses/LICENSE
File without changes