dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/chat_models.py
DELETED
|
@@ -1,204 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
from typing import Any, Iterator, Optional, Sequence
|
|
3
|
-
|
|
4
|
-
from databricks_langchain import ChatDatabricks
|
|
5
|
-
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
6
|
-
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
|
7
|
-
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
8
|
-
from loguru import logger
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ChatDatabricksFiltered(ChatDatabricks):
|
|
12
|
-
def __init__(self, **kwargs):
|
|
13
|
-
super().__init__(**kwargs)
|
|
14
|
-
|
|
15
|
-
def _preprocess_messages(
|
|
16
|
-
self, messages: Sequence[BaseMessage]
|
|
17
|
-
) -> Sequence[BaseMessage]:
|
|
18
|
-
logger.debug(f"Preprocessing {len(messages)} messages for filtering")
|
|
19
|
-
|
|
20
|
-
logger.trace(
|
|
21
|
-
f"Original messages:\n{json.dumps([msg.model_dump() for msg in messages], indent=2)}"
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
# Diagnostic logging to understand what types of messages we're getting
|
|
25
|
-
message_types = {}
|
|
26
|
-
remove_message_count = 0
|
|
27
|
-
empty_content_count = 0
|
|
28
|
-
|
|
29
|
-
for msg in messages:
|
|
30
|
-
msg_type = msg.__class__.__name__
|
|
31
|
-
message_types[msg_type] = message_types.get(msg_type, 0) + 1
|
|
32
|
-
|
|
33
|
-
if msg_type == "RemoveMessage":
|
|
34
|
-
remove_message_count += 1
|
|
35
|
-
elif hasattr(msg, "content") and (msg.content == "" or msg.content is None):
|
|
36
|
-
empty_content_count += 1
|
|
37
|
-
|
|
38
|
-
logger.debug(f"Message type breakdown: {message_types}")
|
|
39
|
-
logger.debug(
|
|
40
|
-
f"RemoveMessage count: {remove_message_count}, Empty content count: {empty_content_count}"
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
filtered_messages = []
|
|
44
|
-
for i, msg in enumerate(messages):
|
|
45
|
-
# First, filter out RemoveMessage objects completely - they're LangGraph-specific
|
|
46
|
-
# and should never be sent to an LLM
|
|
47
|
-
if hasattr(msg, "__class__") and msg.__class__.__name__ == "RemoveMessage":
|
|
48
|
-
logger.debug(f"Filtering out RemoveMessage at index {i}")
|
|
49
|
-
continue
|
|
50
|
-
|
|
51
|
-
# Be very conservative with filtering - only filter out messages that are:
|
|
52
|
-
# 1. Have empty or None content AND
|
|
53
|
-
# 2. Are not tool-related messages AND
|
|
54
|
-
# 3. Don't break tool_use/tool_result pairing
|
|
55
|
-
# 4. Are not the only remaining message (to avoid filtering everything)
|
|
56
|
-
has_empty_content = hasattr(msg, "content") and (
|
|
57
|
-
msg.content == "" or msg.content is None
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
# Check if this message has tool calls (non-empty list)
|
|
61
|
-
has_tool_calls = (
|
|
62
|
-
hasattr(msg, "tool_calls")
|
|
63
|
-
and msg.tool_calls
|
|
64
|
-
and len(msg.tool_calls) > 0
|
|
65
|
-
)
|
|
66
|
-
|
|
67
|
-
# Check if this is a tool result message
|
|
68
|
-
is_tool_result = hasattr(msg, "tool_call_id") or isinstance(
|
|
69
|
-
msg, ToolMessage
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
# Check if the previous message had tool calls (this message might be a tool result)
|
|
73
|
-
prev_had_tool_calls = False
|
|
74
|
-
if i > 0:
|
|
75
|
-
prev_msg = messages[i - 1]
|
|
76
|
-
prev_had_tool_calls = (
|
|
77
|
-
hasattr(prev_msg, "tool_calls")
|
|
78
|
-
and prev_msg.tool_calls
|
|
79
|
-
and len(prev_msg.tool_calls) > 0
|
|
80
|
-
)
|
|
81
|
-
|
|
82
|
-
# Check if the next message is a tool result (this message might be a tool use)
|
|
83
|
-
next_is_tool_result = False
|
|
84
|
-
if i < len(messages) - 1:
|
|
85
|
-
next_msg = messages[i + 1]
|
|
86
|
-
next_is_tool_result = hasattr(next_msg, "tool_call_id") or isinstance(
|
|
87
|
-
next_msg, ToolMessage
|
|
88
|
-
)
|
|
89
|
-
|
|
90
|
-
# Special handling for empty AIMessages - they might be placeholders or incomplete responses
|
|
91
|
-
# Don't filter them if they're the only AI response or seem important to the conversation flow
|
|
92
|
-
is_empty_ai_message = has_empty_content and isinstance(msg, AIMessage)
|
|
93
|
-
|
|
94
|
-
# Only filter out messages with empty content that are definitely not needed
|
|
95
|
-
should_filter = (
|
|
96
|
-
has_empty_content
|
|
97
|
-
and not has_tool_calls
|
|
98
|
-
and not is_tool_result
|
|
99
|
-
and not prev_had_tool_calls # Don't filter if previous message had tool calls
|
|
100
|
-
and not next_is_tool_result # Don't filter if next message is a tool result
|
|
101
|
-
and not (
|
|
102
|
-
is_empty_ai_message and len(messages) <= 2
|
|
103
|
-
) # Don't filter empty AI messages in short conversations
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
if should_filter:
|
|
107
|
-
logger.debug(f"Filtering out message at index {i}: {msg.model_dump()}")
|
|
108
|
-
continue
|
|
109
|
-
else:
|
|
110
|
-
filtered_messages.append(msg)
|
|
111
|
-
|
|
112
|
-
logger.debug(
|
|
113
|
-
f"Filtered {len(messages)} messages down to {len(filtered_messages)} messages"
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
# Log diagnostic information if all messages were filtered out
|
|
117
|
-
if len(filtered_messages) == 0:
|
|
118
|
-
logger.warning(
|
|
119
|
-
f"All {len(messages)} messages were filtered out! This indicates a problem with the conversation state."
|
|
120
|
-
)
|
|
121
|
-
logger.debug(f"Original message types: {message_types}")
|
|
122
|
-
|
|
123
|
-
if remove_message_count == len(messages):
|
|
124
|
-
logger.warning(
|
|
125
|
-
"All messages were RemoveMessage objects - this suggests a bug in summarization logic"
|
|
126
|
-
)
|
|
127
|
-
elif empty_content_count > 0:
|
|
128
|
-
logger.debug(f"{empty_content_count} messages had empty content")
|
|
129
|
-
|
|
130
|
-
return filtered_messages
|
|
131
|
-
|
|
132
|
-
def _postprocess_message(self, message: BaseMessage) -> BaseMessage:
|
|
133
|
-
return message
|
|
134
|
-
|
|
135
|
-
def _generate(
|
|
136
|
-
self,
|
|
137
|
-
messages: Sequence[BaseMessage],
|
|
138
|
-
stop: Optional[Sequence[str]] = None,
|
|
139
|
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
140
|
-
**kwargs: Any,
|
|
141
|
-
) -> ChatResult:
|
|
142
|
-
"""Override _generate to apply message preprocessing and postprocessing."""
|
|
143
|
-
# Apply message preprocessing
|
|
144
|
-
processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
|
|
145
|
-
|
|
146
|
-
if len(processed_messages) == 0:
|
|
147
|
-
logger.error(
|
|
148
|
-
"All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
|
|
149
|
-
)
|
|
150
|
-
empty_generation = ChatGeneration(
|
|
151
|
-
message=AIMessage(content="", id="empty-response")
|
|
152
|
-
)
|
|
153
|
-
return ChatResult(generations=[empty_generation])
|
|
154
|
-
|
|
155
|
-
logger.trace(
|
|
156
|
-
f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
|
|
157
|
-
)
|
|
158
|
-
|
|
159
|
-
result: ChatResult = super()._generate(
|
|
160
|
-
processed_messages, stop, run_manager, **kwargs
|
|
161
|
-
)
|
|
162
|
-
|
|
163
|
-
if result.generations:
|
|
164
|
-
for generation in result.generations:
|
|
165
|
-
if isinstance(generation, ChatGeneration) and generation.message:
|
|
166
|
-
generation.message = self._postprocess_message(generation.message)
|
|
167
|
-
|
|
168
|
-
return result
|
|
169
|
-
|
|
170
|
-
def _stream(
|
|
171
|
-
self,
|
|
172
|
-
messages: Sequence[BaseMessage],
|
|
173
|
-
stop: Optional[Sequence[str]] = None,
|
|
174
|
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
175
|
-
**kwargs: Any,
|
|
176
|
-
) -> Iterator[ChatGeneration]:
|
|
177
|
-
"""Override _stream to apply message preprocessing and postprocessing."""
|
|
178
|
-
# Apply message preprocessing
|
|
179
|
-
processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
|
|
180
|
-
|
|
181
|
-
# Handle the edge case where all messages were filtered out
|
|
182
|
-
if len(processed_messages) == 0:
|
|
183
|
-
logger.error(
|
|
184
|
-
"All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
|
|
185
|
-
)
|
|
186
|
-
# Return an empty streaming result without calling the underlying API
|
|
187
|
-
# This prevents API errors while making the issue visible through an empty response
|
|
188
|
-
empty_chunk = ChatGenerationChunk(
|
|
189
|
-
message=AIMessage(content="", id="empty-response")
|
|
190
|
-
)
|
|
191
|
-
yield empty_chunk
|
|
192
|
-
return
|
|
193
|
-
|
|
194
|
-
logger.trace(
|
|
195
|
-
f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
|
|
196
|
-
)
|
|
197
|
-
|
|
198
|
-
# Call the parent ChatDatabricks implementation
|
|
199
|
-
for chunk in super()._stream(processed_messages, stop, run_manager, **kwargs):
|
|
200
|
-
chunk: ChatGenerationChunk
|
|
201
|
-
# Apply message postprocessing to each chunk
|
|
202
|
-
if isinstance(chunk, ChatGeneration) and chunk.message:
|
|
203
|
-
chunk.message = self._postprocess_message(chunk.message)
|
|
204
|
-
yield chunk
|
dao_ai/guardrails.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
from typing import Any, Literal, Optional, Type
|
|
2
|
-
|
|
3
|
-
from langchain_core.language_models import LanguageModelLike
|
|
4
|
-
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
|
5
|
-
from langchain_core.runnables import RunnableConfig
|
|
6
|
-
from langchain_core.runnables.base import RunnableLike
|
|
7
|
-
from langgraph.graph import END, START, MessagesState, StateGraph
|
|
8
|
-
from langgraph.graph.state import CompiledStateGraph
|
|
9
|
-
from langgraph.managed import RemainingSteps
|
|
10
|
-
from loguru import logger
|
|
11
|
-
from openevals.llm import create_llm_as_judge
|
|
12
|
-
|
|
13
|
-
from dao_ai.config import GuardrailModel
|
|
14
|
-
from dao_ai.messages import last_ai_message, last_human_message
|
|
15
|
-
from dao_ai.state import SharedState
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class MessagesWithSteps(MessagesState):
|
|
19
|
-
guardrails_remaining_steps: RemainingSteps
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def end_or_reflect(state: MessagesWithSteps) -> Literal[END, "graph"]:
|
|
23
|
-
if state["guardrails_remaining_steps"] < 2:
|
|
24
|
-
return END
|
|
25
|
-
if len(state["messages"]) == 0:
|
|
26
|
-
return END
|
|
27
|
-
last_message = state["messages"][-1]
|
|
28
|
-
if isinstance(last_message, HumanMessage):
|
|
29
|
-
return "graph"
|
|
30
|
-
else:
|
|
31
|
-
return END
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def create_reflection_graph(
|
|
35
|
-
graph: CompiledStateGraph,
|
|
36
|
-
reflection: CompiledStateGraph,
|
|
37
|
-
state_schema: Optional[Type[Any]] = None,
|
|
38
|
-
config_schema: Optional[Type[Any]] = None,
|
|
39
|
-
) -> StateGraph:
|
|
40
|
-
logger.debug("Creating reflection graph")
|
|
41
|
-
_state_schema = state_schema or graph.builder.schema
|
|
42
|
-
|
|
43
|
-
if "guardrails_remaining_steps" in _state_schema.__annotations__:
|
|
44
|
-
raise ValueError(
|
|
45
|
-
"Has key 'guardrails_remaining_steps' in state_schema, this shadows a built in key"
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
if "messages" not in _state_schema.__annotations__:
|
|
49
|
-
raise ValueError("Missing required key 'messages' in state_schema")
|
|
50
|
-
|
|
51
|
-
class StateSchema(_state_schema):
|
|
52
|
-
guardrails_remaining_steps: RemainingSteps
|
|
53
|
-
|
|
54
|
-
rgraph = StateGraph(StateSchema, config_schema=config_schema)
|
|
55
|
-
rgraph.add_node("graph", graph)
|
|
56
|
-
rgraph.add_node("reflection", reflection)
|
|
57
|
-
rgraph.add_edge(START, "graph")
|
|
58
|
-
rgraph.add_edge("graph", "reflection")
|
|
59
|
-
rgraph.add_conditional_edges("reflection", end_or_reflect)
|
|
60
|
-
return rgraph
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def with_guardrails(
|
|
64
|
-
graph: CompiledStateGraph, guardrail: CompiledStateGraph
|
|
65
|
-
) -> CompiledStateGraph:
|
|
66
|
-
logger.debug("Creating graph with guardrails")
|
|
67
|
-
return create_reflection_graph(
|
|
68
|
-
graph, guardrail, state_schema=SharedState, config_schema=RunnableConfig
|
|
69
|
-
).compile()
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def judge_node(guardrails: GuardrailModel) -> RunnableLike:
|
|
73
|
-
def judge(state: SharedState, config: RunnableConfig) -> dict[str, BaseMessage]:
|
|
74
|
-
llm: LanguageModelLike = guardrails.model.as_chat_model()
|
|
75
|
-
|
|
76
|
-
evaluator = create_llm_as_judge(
|
|
77
|
-
prompt=guardrails.prompt,
|
|
78
|
-
judge=llm,
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
ai_message: AIMessage = last_ai_message(state["messages"])
|
|
82
|
-
human_message: HumanMessage = last_human_message(state["messages"])
|
|
83
|
-
|
|
84
|
-
logger.debug(f"Evaluating response: {ai_message.content}")
|
|
85
|
-
eval_result = evaluator(
|
|
86
|
-
inputs=human_message.content, outputs=ai_message.content
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
if eval_result["score"]:
|
|
90
|
-
logger.debug("Response approved by judge")
|
|
91
|
-
logger.debug(f"Judge's comment: {eval_result['comment']}")
|
|
92
|
-
return
|
|
93
|
-
else:
|
|
94
|
-
# Otherwise, return the judge's critique as a new user message
|
|
95
|
-
logger.warning("Judge requested improvements")
|
|
96
|
-
comment: str = eval_result["comment"]
|
|
97
|
-
logger.warning(f"Judge's critique: {comment}")
|
|
98
|
-
content: str = "\n".join([human_message.content, comment])
|
|
99
|
-
return {"messages": [HumanMessage(content=content)]}
|
|
100
|
-
|
|
101
|
-
return judge
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def reflection_guardrail(guardrails: GuardrailModel) -> CompiledStateGraph:
|
|
105
|
-
judge: CompiledStateGraph = (
|
|
106
|
-
StateGraph(SharedState, config_schema=RunnableConfig)
|
|
107
|
-
.add_node("judge", judge_node(guardrails=guardrails))
|
|
108
|
-
.add_edge(START, "judge")
|
|
109
|
-
.add_edge("judge", END)
|
|
110
|
-
.compile()
|
|
111
|
-
)
|
|
112
|
-
return judge
|
dao_ai/tools/human_in_the_loop.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
from typing import Any, Optional
|
|
2
|
-
|
|
3
|
-
from langchain_core.runnables import RunnableConfig
|
|
4
|
-
from langchain_core.runnables.base import RunnableLike
|
|
5
|
-
from langchain_core.tools import BaseTool
|
|
6
|
-
from langchain_core.tools import tool as create_tool
|
|
7
|
-
from langgraph.prebuilt.interrupt import HumanInterrupt, HumanInterruptConfig
|
|
8
|
-
from langgraph.types import interrupt
|
|
9
|
-
from loguru import logger
|
|
10
|
-
|
|
11
|
-
from dao_ai.config import (
|
|
12
|
-
BaseFunctionModel,
|
|
13
|
-
HumanInTheLoopModel,
|
|
14
|
-
)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def add_human_in_the_loop(
|
|
18
|
-
tool: RunnableLike,
|
|
19
|
-
*,
|
|
20
|
-
interrupt_config: HumanInterruptConfig | None = None,
|
|
21
|
-
review_prompt: Optional[str] = "Please review the tool call",
|
|
22
|
-
) -> BaseTool:
|
|
23
|
-
"""
|
|
24
|
-
Wrap a tool with human-in-the-loop functionality.
|
|
25
|
-
This function takes a tool (either a callable or a BaseTool instance) and wraps it
|
|
26
|
-
with a human-in-the-loop mechanism. When the tool is invoked, it will first
|
|
27
|
-
request human review before executing the tool's logic. The human can choose to
|
|
28
|
-
accept, edit the input, or provide a custom response.
|
|
29
|
-
|
|
30
|
-
Args:
|
|
31
|
-
tool (Callable[..., Any] | BaseTool): _description_
|
|
32
|
-
interrupt_config (HumanInterruptConfig | None, optional): _description_. Defaults to None.
|
|
33
|
-
|
|
34
|
-
Raises:
|
|
35
|
-
ValueError: _description_
|
|
36
|
-
|
|
37
|
-
Returns:
|
|
38
|
-
BaseTool: _description_
|
|
39
|
-
"""
|
|
40
|
-
if not isinstance(tool, BaseTool):
|
|
41
|
-
tool = create_tool(tool)
|
|
42
|
-
|
|
43
|
-
if interrupt_config is None:
|
|
44
|
-
interrupt_config = {
|
|
45
|
-
"allow_accept": True,
|
|
46
|
-
"allow_edit": True,
|
|
47
|
-
"allow_respond": True,
|
|
48
|
-
}
|
|
49
|
-
|
|
50
|
-
logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")
|
|
51
|
-
|
|
52
|
-
@create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
|
|
53
|
-
async def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
|
|
54
|
-
logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
|
|
55
|
-
request: HumanInterrupt = {
|
|
56
|
-
"action_request": {
|
|
57
|
-
"action": tool.name,
|
|
58
|
-
"args": tool_input,
|
|
59
|
-
},
|
|
60
|
-
"config": interrupt_config,
|
|
61
|
-
"description": review_prompt,
|
|
62
|
-
}
|
|
63
|
-
|
|
64
|
-
logger.debug(f"Human interrupt request: {request}")
|
|
65
|
-
response: dict[str, Any] = interrupt([request])[0]
|
|
66
|
-
logger.debug(f"Human interrupt response: {response}")
|
|
67
|
-
|
|
68
|
-
if response["type"] == "accept":
|
|
69
|
-
tool_response = await tool.ainvoke(tool_input, config=config)
|
|
70
|
-
elif response["type"] == "edit":
|
|
71
|
-
tool_input = response["args"]["args"]
|
|
72
|
-
tool_response = await tool.ainvoke(tool_input, config=config)
|
|
73
|
-
elif response["type"] == "response":
|
|
74
|
-
user_feedback = response["args"]
|
|
75
|
-
tool_response = user_feedback
|
|
76
|
-
else:
|
|
77
|
-
raise ValueError(f"Unknown interrupt response type: {response['type']}")
|
|
78
|
-
|
|
79
|
-
return tool_response
|
|
80
|
-
|
|
81
|
-
return call_tool_with_interrupt
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def as_human_in_the_loop(
|
|
85
|
-
tool: RunnableLike, function: BaseFunctionModel | str
|
|
86
|
-
) -> RunnableLike:
|
|
87
|
-
if isinstance(function, BaseFunctionModel):
|
|
88
|
-
human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
|
|
89
|
-
if human_in_the_loop:
|
|
90
|
-
# Get tool name safely - handle RunnableBinding objects
|
|
91
|
-
tool_name = getattr(tool, "name", None) or getattr(
|
|
92
|
-
getattr(tool, "bound", None), "name", "unknown_tool"
|
|
93
|
-
)
|
|
94
|
-
logger.debug(f"Adding human-in-the-loop to tool: {tool_name}")
|
|
95
|
-
tool = add_human_in_the_loop(
|
|
96
|
-
tool=tool,
|
|
97
|
-
interrupt_config=human_in_the_loop.interupt_config,
|
|
98
|
-
review_prompt=human_in_the_loop.review_prompt,
|
|
99
|
-
)
|
|
100
|
-
return tool
|