dao-ai 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/nodes.py CHANGED
@@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
10
10
  from langgraph.graph import StateGraph
11
11
  from langgraph.graph.state import CompiledStateGraph
12
12
  from langgraph.prebuilt import create_react_agent
13
+ from langgraph.runtime import Runtime
13
14
  from langmem import create_manage_memory_tool, create_search_memory_tool
14
15
  from langmem.short_term import SummarizationNode
15
16
  from langmem.short_term.summarization import TokenCounter
@@ -26,7 +27,7 @@ from dao_ai.config import (
26
27
  from dao_ai.guardrails import reflection_guardrail, with_guardrails
27
28
  from dao_ai.hooks.core import create_hooks
28
29
  from dao_ai.prompts import make_prompt
29
- from dao_ai.state import IncomingState, SharedState
30
+ from dao_ai.state import Context, IncomingState, SharedState
30
31
  from dao_ai.tools import create_tools
31
32
 
32
33
 
@@ -53,6 +54,7 @@ def summarization_node(app_model: AppModel) -> RunnableLike:
53
54
  )
54
55
 
55
56
  summarization_model: LanguageModelLike = chat_history.model.as_chat_model()
57
+
56
58
  node: RunnableLike = SummarizationNode(
57
59
  model=summarization_model,
58
60
  max_tokens=max_tokens,
@@ -67,7 +69,7 @@ def summarization_node(app_model: AppModel) -> RunnableLike:
67
69
 
68
70
 
69
71
  def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
70
- def call_agent(state: SharedState, config: RunnableConfig) -> SharedState:
72
+ def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
71
73
  logger.debug(f"Calling agent {agent.name} with summarized messages")
72
74
 
73
75
  # Get the summarized messages from the summarization node
@@ -79,7 +81,7 @@ def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLi
79
81
  "messages": messages,
80
82
  }
81
83
 
82
- response: dict[str, Any] = agent.invoke(input=input, config=config)
84
+ response: dict[str, Any] = agent.invoke(input=input, context=runtime.context)
83
85
  response_messages = response.get("messages", [])
84
86
  logger.debug(f"Agent returned {len(response_messages)} messages")
85
87
 
@@ -147,9 +149,9 @@ def create_agent_node(
147
149
  prompt=make_prompt(agent.prompt),
148
150
  tools=tools,
149
151
  store=True,
150
- state_schema=SharedState,
151
- config_schema=RunnableConfig,
152
152
  checkpointer=True,
153
+ state_schema=SharedState,
154
+ context_schema=Context,
153
155
  pre_model_hook=pre_agent_hook,
154
156
  post_model_hook=post_agent_hook,
155
157
  )
@@ -165,17 +167,17 @@ def create_agent_node(
165
167
  chat_history: ChatHistoryModel = app.chat_history
166
168
 
167
169
  if chat_history is None:
170
+ logger.debug("No chat history configured, using compiled agent directly")
168
171
  agent_node = compiled_agent
169
172
  else:
173
+ logger.debug("Creating agent node with chat history summarization")
170
174
  workflow: StateGraph = StateGraph(
171
175
  SharedState,
172
176
  config_schema=RunnableConfig,
173
177
  input=SharedState,
174
178
  output=SharedState,
175
179
  )
176
- workflow.add_node(
177
- "summarization", summarization_node(chat_history=chat_history)
178
- )
180
+ workflow.add_node("summarization", summarization_node(app))
179
181
  workflow.add_node(
180
182
  "agent",
181
183
  call_agent_with_summarized_messages(agent=compiled_agent),
@@ -191,7 +193,7 @@ def message_hook_node(config: AppConfig) -> RunnableLike:
191
193
  message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
192
194
 
193
195
  @mlflow.trace()
194
- def message_hook(state: IncomingState, config: RunnableConfig) -> SharedState:
196
+ def message_hook(state: IncomingState, runtime: Runtime[Context]) -> SharedState:
195
197
  logger.debug("Running message validation")
196
198
  response: dict[str, Any] = {"is_valid": True, "message_error": None}
197
199
 
@@ -201,7 +203,7 @@ def message_hook_node(config: AppConfig) -> RunnableLike:
201
203
  try:
202
204
  hook_response: dict[str, Any] = message_hook(
203
205
  state=state,
204
- config=config,
206
+ runtime=runtime,
205
207
  )
206
208
  response.update(hook_response)
207
209
  logger.debug(f"Hook response: {hook_response}")
@@ -355,6 +355,20 @@ class DatabricksProvider(ServiceProvider):
355
355
 
356
356
  latest_version: int = get_latest_model_version(registered_model_name)
357
357
 
358
+ # Check if endpoint exists to determine deployment strategy
359
+ endpoint_exists: bool = False
360
+ try:
361
+ agents.get_deployments(endpoint_name)
362
+ endpoint_exists = True
363
+ logger.debug(
364
+ f"Endpoint {endpoint_name} already exists, updating without tags to avoid conflicts..."
365
+ )
366
+ except Exception:
367
+ logger.debug(
368
+ f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
369
+ )
370
+
371
+ # Deploy - skip tags for existing endpoints to avoid conflicts
358
372
  agents.deploy(
359
373
  endpoint_name=endpoint_name,
360
374
  model_name=registered_model_name,
@@ -362,7 +376,7 @@ class DatabricksProvider(ServiceProvider):
362
376
  scale_to_zero=scale_to_zero,
363
377
  environment_vars=environment_vars,
364
378
  workload_size=workload_size,
365
- tags=tags,
379
+ tags=tags if not endpoint_exists else None,
366
380
  )
367
381
 
368
382
  registered_model_name: str = config.app.registered_model.full_name
@@ -526,9 +540,75 @@ class DatabricksProvider(ServiceProvider):
526
540
  columns_to_sync=vector_store.columns,
527
541
  )
528
542
  else:
529
- self.vsc.get_index(
543
+ logger.debug(
544
+ f"Index {vector_store.index.full_name} already exists, checking status and syncing..."
545
+ )
546
+ index = self.vsc.get_index(
530
547
  vector_store.endpoint.name, vector_store.index.full_name
531
- ).sync()
548
+ )
549
+
550
+ # Wait for index to be in a syncable state
551
+ import time
552
+
553
+ max_wait_time = 600 # 10 minutes
554
+ wait_interval = 10 # 10 seconds
555
+ elapsed = 0
556
+
557
+ while elapsed < max_wait_time:
558
+ try:
559
+ index_status = index.describe()
560
+ pipeline_status = index_status.get("status", {}).get(
561
+ "detailed_state", "UNKNOWN"
562
+ )
563
+ logger.debug(f"Index pipeline status: {pipeline_status}")
564
+
565
+ if pipeline_status in [
566
+ "COMPLETED",
567
+ "FAILED",
568
+ "CANCELED",
569
+ "ONLINE_PIPELINE_FAILED",
570
+ ]:
571
+ logger.debug(
572
+ f"Index is ready to sync (status: {pipeline_status})"
573
+ )
574
+ break
575
+ elif pipeline_status in [
576
+ "WAITING_FOR_RESOURCES",
577
+ "PROVISIONING",
578
+ "INITIALIZING",
579
+ "INDEXING",
580
+ "ONLINE",
581
+ ]:
582
+ logger.debug(
583
+ f"Index not ready yet (status: {pipeline_status}), waiting {wait_interval} seconds..."
584
+ )
585
+ time.sleep(wait_interval)
586
+ elapsed += wait_interval
587
+ else:
588
+ logger.warning(
589
+ f"Unknown pipeline status: {pipeline_status}, attempting sync anyway"
590
+ )
591
+ break
592
+ except Exception as status_error:
593
+ logger.warning(
594
+ f"Could not check index status: {status_error}, attempting sync anyway"
595
+ )
596
+ break
597
+
598
+ if elapsed >= max_wait_time:
599
+ logger.warning(
600
+ f"Timed out waiting for index to be ready after {max_wait_time} seconds"
601
+ )
602
+
603
+ # Now attempt to sync
604
+ try:
605
+ index.sync()
606
+ logger.debug("Index sync completed successfully")
607
+ except Exception as sync_error:
608
+ if "not ready to sync yet" in str(sync_error).lower():
609
+ logger.warning(f"Index still not ready to sync: {sync_error}")
610
+ else:
611
+ raise sync_error
532
612
 
533
613
  logger.debug(
534
614
  f"index {vector_store.index.full_name} on table {vector_store.source_table.full_name} is ready"
dao_ai/state.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from langchain_core.messages import AnyMessage
2
2
  from langgraph.graph import MessagesState
3
3
  from langgraph.managed import RemainingSteps
4
+ from pydantic import BaseModel
4
5
 
5
6
 
6
7
  class IncomingState(MessagesState): ...
@@ -29,3 +30,9 @@ class SharedState(MessagesState):
29
30
 
30
31
  is_valid: bool # message validation node
31
32
  message_error: str
33
+
34
+
35
+ class Context(BaseModel):
36
+ user_id: str | None = None
37
+ thread_id: str | None = None
38
+ store_num: int | None = None
dao_ai/tools/__init__.py CHANGED
@@ -1,14 +1,12 @@
1
1
  from dao_ai.hooks.core import create_hooks
2
2
  from dao_ai.tools.agent import create_agent_endpoint_tool
3
3
  from dao_ai.tools.core import (
4
- create_factory_tool,
5
- create_mcp_tools,
6
- create_python_tool,
7
4
  create_tools,
8
- create_uc_tools,
9
5
  search_tool,
10
6
  )
11
7
  from dao_ai.tools.genie import create_genie_tool
8
+ from dao_ai.tools.mcp import create_mcp_tools
9
+ from dao_ai.tools.python import create_factory_tool, create_python_tool
12
10
  from dao_ai.tools.time import (
13
11
  add_time_tool,
14
12
  current_time_tool,
@@ -18,6 +16,7 @@ from dao_ai.tools.time import (
18
16
  time_in_timezone_tool,
19
17
  time_until_tool,
20
18
  )
19
+ from dao_ai.tools.unity_catalog import create_uc_tools
21
20
  from dao_ai.tools.vector_search import create_vector_search_tool
22
21
 
23
22
  __all__ = [
dao_ai/tools/core.py CHANGED
@@ -1,118 +1,15 @@
1
- import asyncio
2
1
  from collections import OrderedDict
3
- from typing import Any, Callable, Optional, Sequence
2
+ from typing import Sequence
4
3
 
5
- from databricks_langchain import (
6
- DatabricksFunctionClient,
7
- UCFunctionToolkit,
8
- )
9
4
  from langchain_community.tools import DuckDuckGoSearchRun
10
- from langchain_core.runnables import RunnableConfig
11
5
  from langchain_core.runnables.base import RunnableLike
12
- from langchain_core.tools import BaseTool
13
- from langchain_core.tools import tool as create_tool
14
- from langchain_mcp_adapters.client import MultiServerMCPClient
15
- from langgraph.prebuilt.interrupt import HumanInterrupt, HumanInterruptConfig
16
- from langgraph.types import interrupt
17
6
  from loguru import logger
18
- from mcp.types import ListToolsResult, Tool
19
7
 
20
8
  from dao_ai.config import (
21
9
  AnyTool,
22
- BaseFunctionModel,
23
- FactoryFunctionModel,
24
- HumanInTheLoopModel,
25
- McpFunctionModel,
26
- PythonFunctionModel,
27
10
  ToolModel,
28
- TransportType,
29
- UnityCatalogFunctionModel,
30
11
  )
31
12
  from dao_ai.hooks.core import create_hooks
32
- from dao_ai.utils import load_function
33
-
34
-
35
- def add_human_in_the_loop(
36
- tool: RunnableLike,
37
- *,
38
- interrupt_config: HumanInterruptConfig | None = None,
39
- review_prompt: Optional[str] = "Please review the tool call",
40
- ) -> BaseTool:
41
- """
42
- Wrap a tool with human-in-the-loop functionality.
43
- This function takes a tool (either a callable or a BaseTool instance) and wraps it
44
- with a human-in-the-loop mechanism. When the tool is invoked, it will first
45
- request human review before executing the tool's logic. The human can choose to
46
- accept, edit the input, or provide a custom response.
47
-
48
- Args:
49
- tool (Callable[..., Any] | BaseTool): _description_
50
- interrupt_config (HumanInterruptConfig | None, optional): _description_. Defaults to None.
51
-
52
- Raises:
53
- ValueError: _description_
54
-
55
- Returns:
56
- BaseTool: _description_
57
- """
58
- if not isinstance(tool, BaseTool):
59
- tool = create_tool(tool)
60
-
61
- if interrupt_config is None:
62
- interrupt_config = {
63
- "allow_accept": True,
64
- "allow_edit": True,
65
- "allow_respond": True,
66
- }
67
-
68
- logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")
69
-
70
- @create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
71
- def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
72
- logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
73
- request: HumanInterrupt = {
74
- "action_request": {
75
- "action": tool.name,
76
- "args": tool_input,
77
- },
78
- "config": interrupt_config,
79
- "description": review_prompt,
80
- }
81
-
82
- logger.debug(f"Human interrupt request: {request}")
83
- response: dict[str, Any] = interrupt([request])[0]
84
- logger.debug(f"Human interrupt response: {response}")
85
-
86
- if response["type"] == "accept":
87
- tool_response = tool.invoke(tool_input, config=config)
88
- elif response["type"] == "edit":
89
- tool_input = response["args"]["args"]
90
- tool_response = tool.invoke(tool_input, config=config)
91
- elif response["type"] == "response":
92
- user_feedback = response["args"]
93
- tool_response = user_feedback
94
- else:
95
- raise ValueError(f"Unknown interrupt response type: {response['type']}")
96
-
97
- return tool_response
98
-
99
- return call_tool_with_interrupt
100
-
101
-
102
- def as_human_in_the_loop(
103
- tool: RunnableLike, function: BaseFunctionModel | str
104
- ) -> RunnableLike:
105
- if isinstance(function, BaseFunctionModel):
106
- human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
107
- if human_in_the_loop:
108
- logger.debug(f"Adding human-in-the-loop to tool: {tool.name}")
109
- tool = add_human_in_the_loop(
110
- tool=tool,
111
- interrupt_config=human_in_the_loop.interupt_config,
112
- review_prompt=human_in_the_loop.review_prompt,
113
- )
114
- return tool
115
-
116
13
 
117
14
  tool_registry: dict[str, Sequence[RunnableLike]] = {}
118
15
 
@@ -157,196 +54,6 @@ def create_tools(tool_models: Sequence[ToolModel]) -> Sequence[RunnableLike]:
157
54
  return all_tools
158
55
 
159
56
 
160
- def create_mcp_tools(
161
- function: McpFunctionModel,
162
- ) -> Sequence[RunnableLike]:
163
- """
164
- Create tools for invoking Databricks MCP functions.
165
-
166
- Uses session-based approach to handle authentication token expiration properly.
167
- """
168
- logger.debug(f"create_mcp_tools: {function}")
169
-
170
- def _create_fresh_connection() -> dict[str, Any]:
171
- logger.debug("Creating fresh connection...")
172
- """Create connection config with fresh authentication headers."""
173
- if function.transport == TransportType.STDIO:
174
- return {
175
- "command": function.command,
176
- "args": function.args,
177
- "transport": function.transport,
178
- }
179
-
180
- # For HTTP transport, generate fresh headers
181
- headers = function.headers.copy() if function.headers else {}
182
-
183
- if "Authorization" not in headers:
184
- logger.debug("Generating fresh authentication token for MCP function")
185
-
186
- from dao_ai.config import value_of
187
- from dao_ai.providers.databricks import DatabricksProvider
188
-
189
- try:
190
- provider = DatabricksProvider(
191
- workspace_host=value_of(function.workspace_host),
192
- client_id=value_of(function.client_id),
193
- client_secret=value_of(function.client_secret),
194
- pat=value_of(function.pat),
195
- )
196
- headers["Authorization"] = f"Bearer {provider.create_token()}"
197
- logger.debug("Generated fresh authentication token")
198
- except Exception as e:
199
- logger.error(f"Failed to create fresh token: {e}")
200
- else:
201
- logger.debug("Using existing authentication token")
202
-
203
- response = {
204
- "url": function.url,
205
- "transport": function.transport,
206
- "headers": headers,
207
- }
208
-
209
- return response
210
-
211
- # Get available tools from MCP server
212
- async def _list_mcp_tools():
213
- connection = _create_fresh_connection()
214
- client = MultiServerMCPClient({function.name: connection})
215
-
216
- try:
217
- async with client.session(function.name) as session:
218
- return await session.list_tools()
219
- except Exception as e:
220
- logger.error(f"Failed to list MCP tools: {e}")
221
- return []
222
-
223
- try:
224
- mcp_tools: list | ListToolsResult = asyncio.run(_list_mcp_tools())
225
- if isinstance(mcp_tools, ListToolsResult):
226
- mcp_tools = mcp_tools.tools
227
-
228
- logger.debug(f"Retrieved {len(mcp_tools)} MCP tools")
229
- except Exception as e:
230
- logger.error(f"Failed to get tools from MCP server: {e}")
231
- raise RuntimeError(
232
- f"Failed to list MCP tools for function '{function.name}' with transport '{function.transport}' and URL '{function.url}': {e}"
233
- )
234
-
235
- # Create wrapper tools with fresh session per invocation
236
- def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
237
- @create_tool(
238
- mcp_tool.name,
239
- description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
240
- args_schema=mcp_tool.inputSchema,
241
- )
242
- def tool_wrapper(**kwargs):
243
- """Execute MCP tool with fresh session and authentication."""
244
- logger.debug(f"Invoking MCP tool {mcp_tool.name} with fresh session")
245
-
246
- async def _invoke():
247
- connection = _create_fresh_connection()
248
- client = MultiServerMCPClient({function.name: connection})
249
-
250
- try:
251
- async with client.session(function.name) as session:
252
- return await session.call_tool(mcp_tool.name, kwargs)
253
- except Exception as e:
254
- logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
255
- raise
256
-
257
- return asyncio.run(_invoke())
258
-
259
- return as_human_in_the_loop(tool_wrapper, function)
260
-
261
- return [_create_tool_wrapper(tool) for tool in mcp_tools]
262
-
263
-
264
- def create_factory_tool(
265
- function: FactoryFunctionModel,
266
- ) -> RunnableLike:
267
- """
268
- Create a factory tool from a FactoryFunctionModel.
269
- This factory function dynamically loads a Python function and returns it as a callable tool.
270
- Args:
271
- function: FactoryFunctionModel instance containing the function details
272
- Returns:
273
- A callable tool function that wraps the specified factory function
274
- """
275
- logger.debug(f"create_factory_tool: {function}")
276
-
277
- factory: Callable[..., Any] = load_function(function_name=function.full_name)
278
- tool: Callable[..., Any] = factory(**function.args)
279
- tool = as_human_in_the_loop(
280
- tool=tool,
281
- function=function,
282
- )
283
- return tool
284
-
285
-
286
- def create_python_tool(
287
- function: PythonFunctionModel | str,
288
- ) -> RunnableLike:
289
- """
290
- Create a Python tool from a Python function model.
291
- This factory function wraps a Python function as a callable tool that can be
292
- invoked by agents during reasoning.
293
- Args:
294
- function: PythonFunctionModel instance containing the function details
295
- Returns:
296
- A callable tool function that wraps the specified Python function
297
- """
298
- logger.debug(f"create_python_tool: {function}")
299
-
300
- if isinstance(function, PythonFunctionModel):
301
- function = function.full_name
302
-
303
- # Load the Python function dynamically
304
- tool: Callable[..., Any] = load_function(function_name=function)
305
-
306
- tool = as_human_in_the_loop(
307
- tool=tool,
308
- function=function,
309
- )
310
- return tool
311
-
312
-
313
- def create_uc_tools(
314
- function: UnityCatalogFunctionModel | str,
315
- ) -> Sequence[RunnableLike]:
316
- """
317
- Create LangChain tools from Unity Catalog functions.
318
-
319
- This factory function wraps Unity Catalog functions as LangChain tools,
320
- making them available for use by agents. Each UC function becomes a callable
321
- tool that can be invoked by the agent during reasoning.
322
-
323
- Args:
324
- function: UnityCatalogFunctionModel instance containing the function details
325
-
326
- Returns:
327
- A sequence of BaseTool objects that wrap the specified UC functions
328
- """
329
-
330
- logger.debug(f"create_uc_tools: {function}")
331
-
332
- if isinstance(function, UnityCatalogFunctionModel):
333
- function = function.full_name
334
-
335
- client: DatabricksFunctionClient = DatabricksFunctionClient()
336
-
337
- toolkit: UCFunctionToolkit = UCFunctionToolkit(
338
- function_names=[function], client=client
339
- )
340
-
341
- tools = toolkit.tools or []
342
-
343
- logger.debug(f"Retrieved tools: {tools}")
344
-
345
- tools = [as_human_in_the_loop(tool=tool, function=function) for tool in tools]
346
-
347
- return tools
348
-
349
-
350
57
  def search_tool() -> RunnableLike:
351
58
  logger.debug("search_tool")
352
59
  return DuckDuckGoSearchRun(output_format="list")
@@ -0,0 +1,96 @@
1
+ from typing import Any, Optional
2
+
3
+ from langchain_core.runnables import RunnableConfig
4
+ from langchain_core.runnables.base import RunnableLike
5
+ from langchain_core.tools import BaseTool
6
+ from langchain_core.tools import tool as create_tool
7
+ from langgraph.prebuilt.interrupt import HumanInterrupt, HumanInterruptConfig
8
+ from langgraph.types import interrupt
9
+ from loguru import logger
10
+
11
+ from dao_ai.config import (
12
+ BaseFunctionModel,
13
+ HumanInTheLoopModel,
14
+ )
15
+
16
+
17
def add_human_in_the_loop(
    tool: RunnableLike,
    *,
    interrupt_config: HumanInterruptConfig | None = None,
    review_prompt: Optional[str] = "Please review the tool call",
) -> BaseTool:
    """
    Wrap a tool so that each invocation is paused for human review first.

    If ``tool`` is not already a ``BaseTool`` it is converted with
    ``langchain_core.tools.tool``. The returned tool keeps the wrapped tool's
    name, description and argument schema. When invoked, it raises a LangGraph
    ``interrupt`` carrying a ``HumanInterrupt`` request. The value the graph is
    resumed with decides what happens next:

    - ``{"type": "accept"}``: run the wrapped tool with the original input.
    - ``{"type": "edit", "args": {"args": {...}}}``: run the wrapped tool with
      the edited input.
    - ``{"type": "response", "args": ...}``: do not run the tool; return
      ``args`` as the tool result.

    The resume value may be a list of responses (only the first is used) or a
    single response dict.

    Args:
        tool: A callable or ``BaseTool`` to wrap.
        interrupt_config: Actions the reviewer is allowed to take. Defaults to
            allowing accept, edit and respond.
        review_prompt: Description sent to the reviewer with the request.

    Returns:
        A new ``BaseTool`` that requests human review before running ``tool``.

    Raises:
        ValueError: Raised by the returned tool when it is invoked, if the
            resumed response has a missing or unrecognized ``type``.
    """
    if not isinstance(tool, BaseTool):
        tool = create_tool(tool)

    if interrupt_config is None:
        interrupt_config = {
            "allow_accept": True,
            "allow_edit": True,
            "allow_respond": True,
        }

    logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")

    @create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
    def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
        logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
        request: HumanInterrupt = {
            "action_request": {
                "action": tool.name,
                "args": tool_input,
            },
            "config": interrupt_config,
            "description": review_prompt,
        }

        logger.debug(f"Human interrupt request: {request}")
        resumed: Any = interrupt([request])
        # The request is sent as a list, and list-style clients resume with a list
        # of responses. Also accept a bare response dict so that resuming with a
        # single value does not fail with KeyError(0).
        response: dict[str, Any] = (
            resumed[0] if isinstance(resumed, (list, tuple)) else resumed
        )
        logger.debug(f"Human interrupt response: {response}")

        response_type = response.get("type")
        if response_type == "accept":
            return tool.invoke(tool_input, config=config)
        if response_type == "edit":
            edited_input = response["args"]["args"]
            return tool.invoke(edited_input, config=config)
        if response_type == "response":
            # The reviewer answered on the tool's behalf; the tool itself is not run.
            return response["args"]
        raise ValueError(f"Unknown interrupt response type: {response_type}")

    return call_tool_with_interrupt
82
+
83
+
84
def as_human_in_the_loop(
    tool: RunnableLike, function: BaseFunctionModel | str
) -> RunnableLike:
    """
    Wrap ``tool`` with human review when its function model requests it.

    Args:
        tool: The tool (a callable or ``BaseTool``) built from ``function``.
        function: The function model the tool was created from, or a plain
            function name. Only a ``BaseFunctionModel`` with a truthy
            ``human_in_the_loop`` causes wrapping; a string never does.

    Returns:
        The wrapper produced by ``add_human_in_the_loop`` when human review is
        configured, otherwise ``tool`` unchanged.
    """
    if not isinstance(function, BaseFunctionModel):
        return tool

    human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
    if not human_in_the_loop:
        return tool

    # A plain callable (e.g. returned by a factory) has no ``.name`` until
    # add_human_in_the_loop converts it to a BaseTool, so fall back to
    # ``__name__`` for logging instead of raising AttributeError.
    tool_name = getattr(tool, "name", None) or getattr(tool, "__name__", repr(tool))
    logger.debug(f"Adding human-in-the-loop to tool: {tool_name}")
    return add_human_in_the_loop(
        tool=tool,
        # NOTE(review): ``interupt_config`` (sic) presumably mirrors a field spelled
        # this way on HumanInTheLoopModel; confirm before renaming either side.
        interrupt_config=human_in_the_loop.interupt_config,
        review_prompt=human_in_the_loop.review_prompt,
    )