dao-ai 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +3 -0
- dao_ai/config.py +21 -3
- dao_ai/graph.py +31 -31
- dao_ai/hooks/__init__.py +2 -0
- dao_ai/hooks/core.py +96 -30
- dao_ai/memory/postgres.py +6 -6
- dao_ai/messages.py +6 -0
- dao_ai/models.py +66 -32
- dao_ai/nodes.py +12 -10
- dao_ai/providers/databricks.py +83 -3
- dao_ai/state.py +7 -0
- dao_ai/tools/__init__.py +3 -4
- dao_ai/tools/core.py +1 -294
- dao_ai/tools/human_in_the_loop.py +96 -0
- dao_ai/tools/mcp.py +118 -0
- dao_ai/tools/python.py +60 -0
- dao_ai/tools/unity_catalog.py +50 -0
- {dao_ai-0.0.6.dist-info → dao_ai-0.0.7.dist-info}/METADATA +10 -11
- dao_ai-0.0.7.dist-info/RECORD +40 -0
- dao_ai-0.0.6.dist-info/RECORD +0 -36
- {dao_ai-0.0.6.dist-info → dao_ai-0.0.7.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.6.dist-info → dao_ai-0.0.7.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.6.dist-info → dao_ai-0.0.7.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py
CHANGED
@@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
 from langgraph.graph import StateGraph
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.prebuilt import create_react_agent
+from langgraph.runtime import Runtime
 from langmem import create_manage_memory_tool, create_search_memory_tool
 from langmem.short_term import SummarizationNode
 from langmem.short_term.summarization import TokenCounter
@@ -26,7 +27,7 @@ from dao_ai.config import (
 from dao_ai.guardrails import reflection_guardrail, with_guardrails
 from dao_ai.hooks.core import create_hooks
 from dao_ai.prompts import make_prompt
-from dao_ai.state import IncomingState, SharedState
+from dao_ai.state import Context, IncomingState, SharedState
 from dao_ai.tools import create_tools


@@ -53,6 +54,7 @@ def summarization_node(app_model: AppModel) -> RunnableLike:
     )

     summarization_model: LanguageModelLike = chat_history.model.as_chat_model()
+
     node: RunnableLike = SummarizationNode(
         model=summarization_model,
         max_tokens=max_tokens,
@@ -67,7 +69,7 @@ def summarization_node(app_model: AppModel) -> RunnableLike:


 def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
-    def call_agent(state: SharedState,
+    def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
         logger.debug(f"Calling agent {agent.name} with summarized messages")

         # Get the summarized messages from the summarization node
@@ -79,7 +81,7 @@ def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
             "messages": messages,
         }

-        response: dict[str, Any] = agent.invoke(input=input,
+        response: dict[str, Any] = agent.invoke(input=input, context=runtime.context)
         response_messages = response.get("messages", [])
         logger.debug(f"Agent returned {len(response_messages)} messages")

@@ -147,9 +149,9 @@ def create_agent_node(
         prompt=make_prompt(agent.prompt),
         tools=tools,
         store=True,
-        state_schema=SharedState,
-        config_schema=RunnableConfig,
         checkpointer=True,
+        state_schema=SharedState,
+        context_schema=Context,
         pre_model_hook=pre_agent_hook,
         post_model_hook=post_agent_hook,
     )
@@ -165,17 +167,17 @@ def create_agent_node(
     chat_history: ChatHistoryModel = app.chat_history

     if chat_history is None:
+        logger.debug("No chat history configured, using compiled agent directly")
         agent_node = compiled_agent
     else:
+        logger.debug("Creating agent node with chat history summarization")
         workflow: StateGraph = StateGraph(
             SharedState,
             config_schema=RunnableConfig,
             input=SharedState,
             output=SharedState,
         )
-        workflow.add_node(
-            "summarization", summarization_node(chat_history=chat_history)
-        )
+        workflow.add_node("summarization", summarization_node(app))
         workflow.add_node(
             "agent",
             call_agent_with_summarized_messages(agent=compiled_agent),
@@ -191,7 +193,7 @@ def message_hook_node(config: AppConfig) -> RunnableLike:
     message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)

     @mlflow.trace()
-    def message_hook(state: IncomingState,
+    def message_hook(state: IncomingState, runtime: Runtime[Context]) -> SharedState:
         logger.debug("Running message validation")
         response: dict[str, Any] = {"is_valid": True, "message_error": None}

@@ -201,7 +203,7 @@ def message_hook_node(config: AppConfig) -> RunnableLike:
        try:
            hook_response: dict[str, Any] = message_hook(
                state=state,
-
+                runtime=runtime,
            )
            response.update(hook_response)
            logger.debug(f"Hook response: {hook_response}")
dao_ai/providers/databricks.py
CHANGED
@@ -355,6 +355,20 @@ class DatabricksProvider(ServiceProvider):

         latest_version: int = get_latest_model_version(registered_model_name)

+        # Check if endpoint exists to determine deployment strategy
+        endpoint_exists: bool = False
+        try:
+            agents.get_deployments(endpoint_name)
+            endpoint_exists = True
+            logger.debug(
+                f"Endpoint {endpoint_name} already exists, updating without tags to avoid conflicts..."
+            )
+        except Exception:
+            logger.debug(
+                f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
+            )
+
+        # Deploy - skip tags for existing endpoints to avoid conflicts
         agents.deploy(
             endpoint_name=endpoint_name,
             model_name=registered_model_name,
@@ -362,7 +376,7 @@ class DatabricksProvider(ServiceProvider):
             scale_to_zero=scale_to_zero,
             environment_vars=environment_vars,
             workload_size=workload_size,
-            tags=tags,
+            tags=tags if not endpoint_exists else None,
         )

         registered_model_name: str = config.app.registered_model.full_name
@@ -526,9 +540,75 @@ class DatabricksProvider(ServiceProvider):
                 columns_to_sync=vector_store.columns,
             )
         else:
-
+            logger.debug(
+                f"Index {vector_store.index.full_name} already exists, checking status and syncing..."
+            )
+            index = self.vsc.get_index(
                vector_store.endpoint.name, vector_store.index.full_name
-            )
+            )
+
+            # Wait for index to be in a syncable state
+            import time
+
+            max_wait_time = 600  # 10 minutes
+            wait_interval = 10  # 10 seconds
+            elapsed = 0
+
+            while elapsed < max_wait_time:
+                try:
+                    index_status = index.describe()
+                    pipeline_status = index_status.get("status", {}).get(
+                        "detailed_state", "UNKNOWN"
+                    )
+                    logger.debug(f"Index pipeline status: {pipeline_status}")
+
+                    if pipeline_status in [
+                        "COMPLETED",
+                        "FAILED",
+                        "CANCELED",
+                        "ONLINE_PIPELINE_FAILED",
+                    ]:
+                        logger.debug(
+                            f"Index is ready to sync (status: {pipeline_status})"
+                        )
+                        break
+                    elif pipeline_status in [
+                        "WAITING_FOR_RESOURCES",
+                        "PROVISIONING",
+                        "INITIALIZING",
+                        "INDEXING",
+                        "ONLINE",
+                    ]:
+                        logger.debug(
+                            f"Index not ready yet (status: {pipeline_status}), waiting {wait_interval} seconds..."
+                        )
+                        time.sleep(wait_interval)
+                        elapsed += wait_interval
+                    else:
+                        logger.warning(
+                            f"Unknown pipeline status: {pipeline_status}, attempting sync anyway"
+                        )
+                        break
+                except Exception as status_error:
+                    logger.warning(
+                        f"Could not check index status: {status_error}, attempting sync anyway"
+                    )
+                    break
+
+            if elapsed >= max_wait_time:
+                logger.warning(
+                    f"Timed out waiting for index to be ready after {max_wait_time} seconds"
+                )
+
+            # Now attempt to sync
+            try:
+                index.sync()
+                logger.debug("Index sync completed successfully")
+            except Exception as sync_error:
+                if "not ready to sync yet" in str(sync_error).lower():
+                    logger.warning(f"Index still not ready to sync: {sync_error}")
+                else:
+                    raise sync_error

         logger.debug(
             f"index {vector_store.index.full_name} on table {vector_store.source_table.full_name} is ready"
dao_ai/state.py
CHANGED
@@ -1,6 +1,7 @@
 from langchain_core.messages import AnyMessage
 from langgraph.graph import MessagesState
 from langgraph.managed import RemainingSteps
+from pydantic import BaseModel


 class IncomingState(MessagesState): ...
@@ -29,3 +30,9 @@ class SharedState(MessagesState):

     is_valid: bool # message validation node
     message_error: str
+
+
+class Context(BaseModel):
+    user_id: str | None = None
+    thread_id: str | None = None
+    store_num: int | None = None
dao_ai/tools/__init__.py
CHANGED
@@ -1,14 +1,12 @@
 from dao_ai.hooks.core import create_hooks
 from dao_ai.tools.agent import create_agent_endpoint_tool
 from dao_ai.tools.core import (
-    create_factory_tool,
-    create_mcp_tools,
-    create_python_tool,
     create_tools,
-    create_uc_tools,
     search_tool,
 )
 from dao_ai.tools.genie import create_genie_tool
+from dao_ai.tools.mcp import create_mcp_tools
+from dao_ai.tools.python import create_factory_tool, create_python_tool
 from dao_ai.tools.time import (
     add_time_tool,
     current_time_tool,
@@ -18,6 +16,7 @@ from dao_ai.tools.time import (
     time_in_timezone_tool,
     time_until_tool,
 )
+from dao_ai.tools.unity_catalog import create_uc_tools
 from dao_ai.tools.vector_search import create_vector_search_tool

 __all__ = [
dao_ai/tools/core.py
CHANGED
@@ -1,118 +1,15 @@
-import asyncio
 from collections import OrderedDict
-from typing import
+from typing import Sequence

-from databricks_langchain import (
-    DatabricksFunctionClient,
-    UCFunctionToolkit,
-)
 from langchain_community.tools import DuckDuckGoSearchRun
-from langchain_core.runnables import RunnableConfig
 from langchain_core.runnables.base import RunnableLike
-from langchain_core.tools import BaseTool
-from langchain_core.tools import tool as create_tool
-from langchain_mcp_adapters.client import MultiServerMCPClient
-from langgraph.prebuilt.interrupt import HumanInterrupt, HumanInterruptConfig
-from langgraph.types import interrupt
 from loguru import logger
-from mcp.types import ListToolsResult, Tool

 from dao_ai.config import (
     AnyTool,
-    BaseFunctionModel,
-    FactoryFunctionModel,
-    HumanInTheLoopModel,
-    McpFunctionModel,
-    PythonFunctionModel,
     ToolModel,
-    TransportType,
-    UnityCatalogFunctionModel,
 )
 from dao_ai.hooks.core import create_hooks
-from dao_ai.utils import load_function
-
-
-def add_human_in_the_loop(
-    tool: RunnableLike,
-    *,
-    interrupt_config: HumanInterruptConfig | None = None,
-    review_prompt: Optional[str] = "Please review the tool call",
-) -> BaseTool:
-    """
-    Wrap a tool with human-in-the-loop functionality.
-    This function takes a tool (either a callable or a BaseTool instance) and wraps it
-    with a human-in-the-loop mechanism. When the tool is invoked, it will first
-    request human review before executing the tool's logic. The human can choose to
-    accept, edit the input, or provide a custom response.
-
-    Args:
-        tool (Callable[..., Any] | BaseTool): _description_
-        interrupt_config (HumanInterruptConfig | None, optional): _description_. Defaults to None.
-
-    Raises:
-        ValueError: _description_
-
-    Returns:
-        BaseTool: _description_
-    """
-    if not isinstance(tool, BaseTool):
-        tool = create_tool(tool)
-
-    if interrupt_config is None:
-        interrupt_config = {
-            "allow_accept": True,
-            "allow_edit": True,
-            "allow_respond": True,
-        }
-
-    logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")
-
-    @create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
-    def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
-        logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
-        request: HumanInterrupt = {
-            "action_request": {
-                "action": tool.name,
-                "args": tool_input,
-            },
-            "config": interrupt_config,
-            "description": review_prompt,
-        }
-
-        logger.debug(f"Human interrupt request: {request}")
-        response: dict[str, Any] = interrupt([request])[0]
-        logger.debug(f"Human interrupt response: {response}")
-
-        if response["type"] == "accept":
-            tool_response = tool.invoke(tool_input, config=config)
-        elif response["type"] == "edit":
-            tool_input = response["args"]["args"]
-            tool_response = tool.invoke(tool_input, config=config)
-        elif response["type"] == "response":
-            user_feedback = response["args"]
-            tool_response = user_feedback
-        else:
-            raise ValueError(f"Unknown interrupt response type: {response['type']}")
-
-        return tool_response
-
-    return call_tool_with_interrupt
-
-
-def as_human_in_the_loop(
-    tool: RunnableLike, function: BaseFunctionModel | str
-) -> RunnableLike:
-    if isinstance(function, BaseFunctionModel):
-        human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
-        if human_in_the_loop:
-            logger.debug(f"Adding human-in-the-loop to tool: {tool.name}")
-            tool = add_human_in_the_loop(
-                tool=tool,
-                interrupt_config=human_in_the_loop.interupt_config,
-                review_prompt=human_in_the_loop.review_prompt,
-            )
-    return tool
-

 tool_registry: dict[str, Sequence[RunnableLike]] = {}

@@ -157,196 +54,6 @@ def create_tools(tool_models: Sequence[ToolModel]) -> Sequence[RunnableLike]:
     return all_tools


-def create_mcp_tools(
-    function: McpFunctionModel,
-) -> Sequence[RunnableLike]:
-    """
-    Create tools for invoking Databricks MCP functions.
-
-    Uses session-based approach to handle authentication token expiration properly.
-    """
-    logger.debug(f"create_mcp_tools: {function}")
-
-    def _create_fresh_connection() -> dict[str, Any]:
-        logger.debug("Creating fresh connection...")
-        """Create connection config with fresh authentication headers."""
-        if function.transport == TransportType.STDIO:
-            return {
-                "command": function.command,
-                "args": function.args,
-                "transport": function.transport,
-            }
-
-        # For HTTP transport, generate fresh headers
-        headers = function.headers.copy() if function.headers else {}
-
-        if "Authorization" not in headers:
-            logger.debug("Generating fresh authentication token for MCP function")
-
-            from dao_ai.config import value_of
-            from dao_ai.providers.databricks import DatabricksProvider
-
-            try:
-                provider = DatabricksProvider(
-                    workspace_host=value_of(function.workspace_host),
-                    client_id=value_of(function.client_id),
-                    client_secret=value_of(function.client_secret),
-                    pat=value_of(function.pat),
-                )
-                headers["Authorization"] = f"Bearer {provider.create_token()}"
-                logger.debug("Generated fresh authentication token")
-            except Exception as e:
-                logger.error(f"Failed to create fresh token: {e}")
-        else:
-            logger.debug("Using existing authentication token")
-
-        response = {
-            "url": function.url,
-            "transport": function.transport,
-            "headers": headers,
-        }
-
-        return response
-
-    # Get available tools from MCP server
-    async def _list_mcp_tools():
-        connection = _create_fresh_connection()
-        client = MultiServerMCPClient({function.name: connection})
-
-        try:
-            async with client.session(function.name) as session:
-                return await session.list_tools()
-        except Exception as e:
-            logger.error(f"Failed to list MCP tools: {e}")
-            return []
-
-    try:
-        mcp_tools: list | ListToolsResult = asyncio.run(_list_mcp_tools())
-        if isinstance(mcp_tools, ListToolsResult):
-            mcp_tools = mcp_tools.tools
-
-        logger.debug(f"Retrieved {len(mcp_tools)} MCP tools")
-    except Exception as e:
-        logger.error(f"Failed to get tools from MCP server: {e}")
-        raise RuntimeError(
-            f"Failed to list MCP tools for function '{function.name}' with transport '{function.transport}' and URL '{function.url}': {e}"
-        )
-
-    # Create wrapper tools with fresh session per invocation
-    def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
-        @create_tool(
-            mcp_tool.name,
-            description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
-            args_schema=mcp_tool.inputSchema,
-        )
-        def tool_wrapper(**kwargs):
-            """Execute MCP tool with fresh session and authentication."""
-            logger.debug(f"Invoking MCP tool {mcp_tool.name} with fresh session")
-
-            async def _invoke():
-                connection = _create_fresh_connection()
-                client = MultiServerMCPClient({function.name: connection})
-
-                try:
-                    async with client.session(function.name) as session:
-                        return await session.call_tool(mcp_tool.name, kwargs)
-                except Exception as e:
-                    logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
-                    raise
-
-            return asyncio.run(_invoke())
-
-        return as_human_in_the_loop(tool_wrapper, function)
-
-    return [_create_tool_wrapper(tool) for tool in mcp_tools]
-
-
-def create_factory_tool(
-    function: FactoryFunctionModel,
-) -> RunnableLike:
-    """
-    Create a factory tool from a FactoryFunctionModel.
-    This factory function dynamically loads a Python function and returns it as a callable tool.
-    Args:
-        function: FactoryFunctionModel instance containing the function details
-    Returns:
-        A callable tool function that wraps the specified factory function
-    """
-    logger.debug(f"create_factory_tool: {function}")
-
-    factory: Callable[..., Any] = load_function(function_name=function.full_name)
-    tool: Callable[..., Any] = factory(**function.args)
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
-    return tool
-
-
-def create_python_tool(
-    function: PythonFunctionModel | str,
-) -> RunnableLike:
-    """
-    Create a Python tool from a Python function model.
-    This factory function wraps a Python function as a callable tool that can be
-    invoked by agents during reasoning.
-    Args:
-        function: PythonFunctionModel instance containing the function details
-    Returns:
-        A callable tool function that wraps the specified Python function
-    """
-    logger.debug(f"create_python_tool: {function}")
-
-    if isinstance(function, PythonFunctionModel):
-        function = function.full_name
-
-    # Load the Python function dynamically
-    tool: Callable[..., Any] = load_function(function_name=function)
-
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
-    return tool
-
-
-def create_uc_tools(
-    function: UnityCatalogFunctionModel | str,
-) -> Sequence[RunnableLike]:
-    """
-    Create LangChain tools from Unity Catalog functions.
-
-    This factory function wraps Unity Catalog functions as LangChain tools,
-    making them available for use by agents. Each UC function becomes a callable
-    tool that can be invoked by the agent during reasoning.
-
-    Args:
-        function: UnityCatalogFunctionModel instance containing the function details
-
-    Returns:
-        A sequence of BaseTool objects that wrap the specified UC functions
-    """
-
-    logger.debug(f"create_uc_tools: {function}")
-
-    if isinstance(function, UnityCatalogFunctionModel):
-        function = function.full_name
-
-    client: DatabricksFunctionClient = DatabricksFunctionClient()
-
-    toolkit: UCFunctionToolkit = UCFunctionToolkit(
-        function_names=[function], client=client
-    )
-
-    tools = toolkit.tools or []
-
-    logger.debug(f"Retrieved tools: {tools}")
-
-    tools = [as_human_in_the_loop(tool=tool, function=function) for tool in tools]
-
-    return tools
-
-
 def search_tool() -> RunnableLike:
     logger.debug("search_tool")
     return DuckDuckGoSearchRun(output_format="list")
dao_ai/tools/human_in_the_loop.py
ADDED
@@ -0,0 +1,96 @@
+from typing import Any, Optional
+
+from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables.base import RunnableLike
+from langchain_core.tools import BaseTool
+from langchain_core.tools import tool as create_tool
+from langgraph.prebuilt.interrupt import HumanInterrupt, HumanInterruptConfig
+from langgraph.types import interrupt
+from loguru import logger
+
+from dao_ai.config import (
+    BaseFunctionModel,
+    HumanInTheLoopModel,
+)
+
+
+def add_human_in_the_loop(
+    tool: RunnableLike,
+    *,
+    interrupt_config: HumanInterruptConfig | None = None,
+    review_prompt: Optional[str] = "Please review the tool call",
+) -> BaseTool:
+    """
+    Wrap a tool with human-in-the-loop functionality.
+    This function takes a tool (either a callable or a BaseTool instance) and wraps it
+    with a human-in-the-loop mechanism. When the tool is invoked, it will first
+    request human review before executing the tool's logic. The human can choose to
+    accept, edit the input, or provide a custom response.
+
+    Args:
+        tool (Callable[..., Any] | BaseTool): _description_
+        interrupt_config (HumanInterruptConfig | None, optional): _description_. Defaults to None.
+
+    Raises:
+        ValueError: _description_
+
+    Returns:
+        BaseTool: _description_
+    """
+    if not isinstance(tool, BaseTool):
+        tool = create_tool(tool)
+
+    if interrupt_config is None:
+        interrupt_config = {
+            "allow_accept": True,
+            "allow_edit": True,
+            "allow_respond": True,
+        }
+
+    logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")
+
+    @create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
+    def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
+        logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
+        request: HumanInterrupt = {
+            "action_request": {
+                "action": tool.name,
+                "args": tool_input,
+            },
+            "config": interrupt_config,
+            "description": review_prompt,
+        }
+
+        logger.debug(f"Human interrupt request: {request}")
+        response: dict[str, Any] = interrupt([request])[0]
+        logger.debug(f"Human interrupt response: {response}")
+
+        if response["type"] == "accept":
+            tool_response = tool.invoke(tool_input, config=config)
+        elif response["type"] == "edit":
+            tool_input = response["args"]["args"]
+            tool_response = tool.invoke(tool_input, config=config)
+        elif response["type"] == "response":
+            user_feedback = response["args"]
+            tool_response = user_feedback
+        else:
+            raise ValueError(f"Unknown interrupt response type: {response['type']}")
+
+        return tool_response
+
+    return call_tool_with_interrupt
+
+
+def as_human_in_the_loop(
+    tool: RunnableLike, function: BaseFunctionModel | str
+) -> RunnableLike:
+    if isinstance(function, BaseFunctionModel):
+        human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
+        if human_in_the_loop:
+            logger.debug(f"Adding human-in-the-loop to tool: {tool.name}")
+            tool = add_human_in_the_loop(
+                tool=tool,
+                interrupt_config=human_in_the_loop.interupt_config,
+                review_prompt=human_in_the_loop.review_prompt,
+            )
+    return tool