haystack-experimental 0.13.0__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/components/agents/__init__.py +16 -0
- haystack_experimental/components/agents/agent.py +634 -0
- haystack_experimental/components/agents/human_in_the_loop/__init__.py +35 -0
- haystack_experimental/components/agents/human_in_the_loop/breakpoint.py +63 -0
- haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +72 -0
- haystack_experimental/components/agents/human_in_the_loop/errors.py +28 -0
- haystack_experimental/components/agents/human_in_the_loop/policies.py +78 -0
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +455 -0
- haystack_experimental/components/agents/human_in_the_loop/types.py +89 -0
- haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +209 -0
- haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +18 -6
- haystack_experimental/components/preprocessors/md_header_level_inferrer.py +146 -0
- haystack_experimental/components/summarizers/__init__.py +7 -0
- haystack_experimental/components/summarizers/llm_summarizer.py +317 -0
- haystack_experimental/core/__init__.py +3 -0
- haystack_experimental/core/pipeline/__init__.py +3 -0
- haystack_experimental/core/pipeline/breakpoint.py +174 -0
- haystack_experimental/dataclasses/__init__.py +3 -0
- haystack_experimental/dataclasses/breakpoints.py +53 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/METADATA +29 -14
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/RECORD +24 -6
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/WHEEL +0 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from dataclasses import replace
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
7
|
+
|
|
8
|
+
from haystack.components.agents.state import State
|
|
9
|
+
from haystack.components.tools.tool_invoker import ToolInvoker
|
|
10
|
+
from haystack.core.serialization import default_from_dict, default_to_dict, import_class_by_name
|
|
11
|
+
from haystack.dataclasses import ChatMessage, StreamingCallbackT
|
|
12
|
+
from haystack.tools import Tool
|
|
13
|
+
|
|
14
|
+
from haystack_experimental.components.agents.human_in_the_loop import (
|
|
15
|
+
ConfirmationPolicy,
|
|
16
|
+
ConfirmationStrategy,
|
|
17
|
+
ConfirmationUI,
|
|
18
|
+
HITLBreakpointException,
|
|
19
|
+
ToolExecutionDecision,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from haystack_experimental.components.agents.agent import _ExecutionContext
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
_REJECTION_FEEDBACK_TEMPLATE = "Tool execution for '{tool_name}' was rejected by the user."
|
|
27
|
+
_MODIFICATION_FEEDBACK_TEMPLATE = (
|
|
28
|
+
"The parameters for tool '{tool_name}' were updated by the user to:\n{final_tool_params}"
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class BlockingConfirmationStrategy:
    """
    Confirmation strategy that waits synchronously for the user's decision before a tool runs.

    The confirmation policy decides whether a tool call needs confirmation at all. When it does, the confirmation UI is
    queried and its answer is translated into a `ToolExecutionDecision`.
    """

    def __init__(self, confirmation_policy: ConfirmationPolicy, confirmation_ui: ConfirmationUI) -> None:
        """
        Create the strategy from a confirmation policy and a confirmation UI.

        :param confirmation_policy:
            Decides, per tool call, whether the user has to be asked for confirmation.
        :param confirmation_ui:
            Used to obtain the user's answer when confirmation is required.
        """
        self.confirmation_policy = confirmation_policy
        self.confirmation_ui = confirmation_ui

    def run(
        self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
    ) -> ToolExecutionDecision:
        """
        Decide whether a tool call is executed, and with which parameters, possibly by asking the user.

        :param tool_name:
            Name of the tool that is about to be executed.
        :param tool_description:
            Description of that tool.
        :param tool_params:
            Parameters the tool would be called with.
        :param tool_call_id:
            Optional identifier of the tool call; it is copied onto the returned decision so the decision can be
            matched to the invocation it belongs to.

        :returns:
            A ToolExecutionDecision. If the user rejects, `execute` is False and `feedback` explains why. If the user
            supplies new parameters, they replace `tool_params` and `feedback` describes the change. Otherwise the
            call is executed with `tool_params` unchanged.
        """
        policy = self.confirmation_policy
        must_ask = policy.should_ask(tool_name=tool_name, tool_description=tool_description, tool_params=tool_params)
        if not must_ask:
            # The policy lets this call through without involving the user.
            return ToolExecutionDecision(
                tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
            )

        answer = self.confirmation_ui.get_user_confirmation(tool_name, tool_description, tool_params)
        # Let the policy update itself from the user's answer before the answer is acted upon.
        policy.update_after_confirmation(tool_name, tool_description, tool_params, answer)

        def _append_feedback(message: str) -> str:
            # Optional free-text feedback from the user is appended to the generated explanation.
            return f"{message} With feedback: {answer.feedback}" if answer.feedback else message

        if answer.action == "reject":
            return ToolExecutionDecision(
                tool_name=tool_name,
                execute=False,
                tool_call_id=tool_call_id,
                feedback=_append_feedback(_REJECTION_FEEDBACK_TEMPLATE.format(tool_name=tool_name)),
            )

        if answer.action == "modify" and answer.new_tool_params:
            # The user's parameters replace the original ones entirely; they are not merged with `tool_params`.
            updated_params = dict(answer.new_tool_params)
            return ToolExecutionDecision(
                tool_name=tool_name,
                tool_call_id=tool_call_id,
                execute=True,
                feedback=_append_feedback(
                    _MODIFICATION_FEEDBACK_TEMPLATE.format(tool_name=tool_name, final_tool_params=updated_params)
                ),
                final_tool_params=updated_params,
            )

        # "confirm" (and any other outcome, including "modify" without new parameters): run as requested.
        return ToolExecutionDecision(
            tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the strategy, including its nested policy and UI, into a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        policy_dict = self.confirmation_policy.to_dict()
        ui_dict = self.confirmation_ui.to_dict()
        return default_to_dict(self, confirmation_policy=policy_dict, confirmation_ui=ui_dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlockingConfirmationStrategy":
        """
        Rebuild a BlockingConfirmationStrategy from the output of `to_dict`.

        The nested policy and UI classes are imported by their fully qualified `type` name and must provide their own
        `from_dict`.

        :param data:
            Dictionary to deserialize from.

        :returns:
            Deserialized BlockingConfirmationStrategy.

        :raises ValueError:
            If the policy or UI class does not implement `from_dict`.
        """
        init_params = data["init_parameters"]

        # Resolve and validate both classes first, then deserialize.
        resolved = []
        for key in ("confirmation_policy", "confirmation_ui"):
            serialized = init_params[key]
            target_class = import_class_by_name(serialized["type"])
            if not hasattr(target_class, "from_dict"):
                raise ValueError(f"Class {target_class} does not implement from_dict method.")
            resolved.append((target_class, serialized))

        (policy_class, policy_data), (ui_class, ui_data) = resolved
        return cls(confirmation_policy=policy_class.from_dict(policy_data), confirmation_ui=ui_class.from_dict(ui_data))
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class BreakpointConfirmationStrategy:
    """
    Confirmation strategy that pauses execution with a breakpoint instead of waiting for the user.

    Use it when the user cannot respond immediately. Whenever a tool execution requires confirmation, `run` raises an
    `HITLBreakpointException`, which is caught by the Agent. The Agent then serializes its current state, including
    the tool call details, so that a user can be notified to review and confirm the tool execution later.
    """

    def __init__(self, snapshot_file_path: str) -> None:
        """
        Create the breakpoint strategy.

        :param snapshot_file_path:
            The path to the directory in which the snapshot should be saved. It is passed unchanged to the raised
            `HITLBreakpointException`.
        """
        self.snapshot_file_path = snapshot_file_path

    def run(
        self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
    ) -> ToolExecutionDecision:
        """
        Signal that the given tool call needs user confirmation by raising a breakpoint exception.

        :param tool_name:
            Name of the tool that is about to be executed.
        :param tool_description:
            Description of that tool (unused; accepted to match the strategy interface).
        :param tool_params:
            Parameters the tool would be called with (unused; accepted to match the strategy interface).
        :param tool_call_id:
            Optional identifier of the tool call, forwarded to the exception.

        :raises HITLBreakpointException:
            Always, carrying the tool name, tool call id and snapshot path.

        :returns:
            Never returns normally.
        """
        message = f"Tool execution for '{tool_name}' requires user confirmation."
        raise HITLBreakpointException(
            message=message,
            tool_name=tool_name,
            tool_call_id=tool_call_id,
            snapshot_file_path=self.snapshot_file_path,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the strategy into a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        init_params = {"snapshot_file_path": self.snapshot_file_path}
        return default_to_dict(self, **init_params)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BreakpointConfirmationStrategy":
        """
        Rebuild a BreakpointConfirmationStrategy from the output of `to_dict`.

        :param data:
            Dictionary to deserialize from.

        :returns:
            Deserialized BreakpointConfirmationStrategy.
        """
        return default_from_dict(cls, data)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _prepare_tool_args(
    *,
    tool: Tool,
    tool_call_arguments: dict[str, Any],
    state: State,
    streaming_callback: Optional[StreamingCallbackT] = None,
    enable_streaming_passthrough: bool = False,
) -> dict[str, Any]:
    """
    Build the arguments a tool would actually be invoked with.

    Starts from a copy of the LLM-provided arguments (the input dict is not mutated), lets
    `ToolInvoker._inject_state_args` add values taken from `state`, and finally adds the streaming callback. The
    callback is added only when passthrough is enabled, a callback is given, the arguments do not already contain a
    `streaming_callback` key, and the tool's function accepts a `streaming_callback` parameter.

    :param tool:
        The tool whose arguments are prepared.
    :param tool_call_arguments:
        Arguments provided for the tool call.
    :param state:
        Agent state used as a source of additional inputs.
    :param streaming_callback:
        Optional streaming callback to forward to the tool.
    :param enable_streaming_passthrough:
        Whether the streaming callback may be forwarded at all.

    :returns:
        A new dictionary with the final arguments for the tool invocation.
    """
    prepared = ToolInvoker._inject_state_args(tool, tool_call_arguments.copy(), state)

    # Each guard below mirrors one condition for forwarding the streaming callback, checked in order.
    if not enable_streaming_passthrough or streaming_callback is None:
        return prepared
    if "streaming_callback" in prepared:
        return prepared
    if "streaming_callback" not in ToolInvoker._get_func_params(tool):
        return prepared

    prepared["streaming_callback"] = streaming_callback
    return prepared
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _process_confirmation_strategies(
    *,
    confirmation_strategies: dict[str, ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> tuple[list[ChatMessage], list[ChatMessage]]:
    """
    Apply the configured confirmation strategies to pending tool calls and rebuild the chat history accordingly.

    The pipeline has three steps:
    1. collect one ToolExecutionDecision per tool call,
    2. turn the decisions into rejection messages and (possibly modified) tool call messages,
    3. splice those messages into the chat history taken from `execution_context.state`.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
    :param messages_with_tool_calls: Chat messages containing tool calls
    :param execution_context: The current execution context of the agent
    :returns:
        Tuple of (messages with the tool calls that should still be executed, updated chat history)
    """
    decisions = _run_confirmation_strategies(
        confirmation_strategies=confirmation_strategies,
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
    )

    rejected, approved_or_modified = _apply_tool_execution_decisions(
        tool_call_messages=messages_with_tool_calls,
        tool_execution_decisions=decisions,
    )

    current_history = execution_context.state.get("messages")
    updated_history = _update_chat_history(
        chat_history=current_history,
        rejection_messages=rejected,
        tool_call_and_explanation_messages=approved_or_modified,
    )
    return approved_or_modified, updated_history
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _run_confirmation_strategies(
    confirmation_strategies: dict[str, ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> list[ToolExecutionDecision]:
    """
    Run confirmation strategies for tool calls in the provided chat messages.

    Produces one ToolExecutionDecision per tool call, in message order:
    - Tools without a confirmation strategy are approved as-is, with a copy of the arguments the LLM provided.
    - For tools with a strategy, a decision already stored in `execution_context.tool_execution_decisions` is reused.
      It is looked up by tool call id first, then by tool name.
    - Otherwise the strategy is run with the prepared arguments (state inputs and, if enabled, the streaming
      callback injected) and the tool call id, so the decision can be matched back to its call.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
    :param messages_with_tool_calls: Messages containing tool calls to process
    :param execution_context: The current execution context containing state and inputs
    :returns:
        A list of ToolExecutionDecision objects representing the decisions made for each tool call.
    :raises KeyError:
        If a tool that has a confirmation strategy is not among the tools in `execution_context.tool_invoker_inputs`.
    """
    state = execution_context.state
    tool_invoker_inputs = execution_context.tool_invoker_inputs
    tools_with_names = {tool.name: tool for tool in tool_invoker_inputs["tools"]}
    existing_teds = execution_context.tool_execution_decisions or []
    # NOTE(review): the name-based fallback can match a stored decision that was made for a different call of the
    # same tool; it is kept for decisions created without a tool_call_id.
    existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name}
    existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id}

    teds: list[ToolExecutionDecision] = []
    for message in messages_with_tool_calls:
        if not message.tool_calls:
            continue

        for tool_call in message.tool_calls:
            tool_name = tool_call.tool_name

            # No confirmation strategy for this tool: approve it with the LLM's own arguments. State inputs and the
            # streaming callback are deliberately not injected here. Otherwise _apply_tool_execution_decisions would
            # see arguments that differ from the tool call, report a user modification that never happened, and
            # write injected values (possibly a callable) into the tool call message.
            # NOTE(review): assumes ToolInvoker injects state inputs / the streaming callback itself when the call is
            # executed — confirm against agent.py.
            if tool_name not in confirmation_strategies:
                teds.append(
                    ToolExecutionDecision(
                        tool_call_id=tool_call.id,
                        tool_name=tool_name,
                        execute=True,
                        final_tool_params=dict(tool_call.arguments),
                    )
                )
                continue

            # Reuse a decision that was already made for this tool call (e.g. supplied when resuming execution).
            ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name)

            # Otherwise ask the confirmation strategy.
            if not ted:
                tool_to_invoke = tools_with_names[tool_name]
                final_args = _prepare_tool_args(
                    tool=tool_to_invoke,
                    tool_call_arguments=tool_call.arguments,
                    state=state,
                    streaming_callback=tool_invoker_inputs.get("streaming_callback"),
                    enable_streaming_passthrough=tool_invoker_inputs.get("enable_streaming_passthrough", False),
                )
                # Pass the tool call id so the decision (or a breakpoint exception) can be tied to this exact call
                # rather than to any call of the same tool.
                ted = confirmation_strategies[tool_name].run(
                    tool_name=tool_name,
                    tool_description=tool_to_invoke.description,
                    tool_params=final_args,
                    tool_call_id=tool_call.id,
                )
            teds.append(ted)

    return teds
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _apply_tool_execution_decisions(
    tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
) -> tuple[list[ChatMessage], list[ChatMessage]]:
    """
    Turn tool execution decisions into chat messages.

    Each tool call is matched to a decision by tool call id, falling back to tool name. Tool calls without any
    matching decision are dropped.

    :param tool_call_messages: The tool call messages to apply the decisions to.
    :param tool_execution_decisions: The tool execution decisions to apply.
    :returns:
        A tuple containing:
        - Rejection messages: for every rejected tool call, an assistant message holding only that tool call followed
          by an error tool result carrying the decision's feedback (or a default rejection text).
        - Messages for confirmed or modified tool calls: per original message, an assistant message with the
          remaining tool calls (omitted if none remain). Each modified call is preceded by a user message explaining
          the modification.
    """
    by_call_id: dict[str, ToolExecutionDecision] = {}
    by_tool_name: dict[str, ToolExecutionDecision] = {}
    for decision in tool_execution_decisions:
        if decision.tool_call_id:
            by_call_id[decision.tool_call_id] = decision
        if decision.tool_name:
            by_tool_name[decision.tool_name] = decision

    def _rebuild_assistant_message(source: ChatMessage, tool_calls: list) -> ChatMessage:
        # Copy text, meta, name and reasoning from the source message, keeping only the given tool calls.
        return ChatMessage.from_assistant(
            text=source.text,
            meta=source.meta,
            name=source.name,
            tool_calls=tool_calls,
            reasoning=source.reasoning,
        )

    rejected_messages: list[ChatMessage] = []
    kept_messages: list[ChatMessage] = []

    for message in tool_call_messages:
        surviving_calls = []
        for call in message.tool_calls or []:
            decision = by_call_id.get(call.id or "") or by_tool_name.get(call.tool_name)
            if not decision:
                # Should not happen: _run_confirmation_strategies produces a decision for every tool call.
                continue

            if not decision.execute:
                reason = decision.feedback or _REJECTION_FEEDBACK_TEMPLATE.format(tool_name=call.tool_name)
                rejected_messages.append(_rebuild_assistant_message(message, [call]))
                rejected_messages.append(ChatMessage.from_tool(tool_result=reason, origin=call, error=True))
                continue

            # Confirmed or modified call.
            approved_args = decision.final_tool_params or {}
            # NOTE(review): final_tool_params produced via _prepare_tool_args can include state-injected values, which
            # makes this comparison report a "modification" the user never made — confirm this is intended.
            if call.arguments != approved_args:
                # Explain the change to the LLM; otherwise it will likely retry with the original parameters.
                explanation = decision.feedback or _MODIFICATION_FEEDBACK_TEMPLATE.format(
                    tool_name=call.tool_name, final_tool_params=approved_args
                )
                kept_messages.append(ChatMessage.from_user(text=explanation))
            surviving_calls.append(replace(call, arguments=approved_args))

        # Skip the assistant message entirely when every one of its tool calls was rejected.
        if surviving_calls:
            kept_messages.append(_rebuild_assistant_message(message, surviving_calls))

    return rejected_messages, kept_messages
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _update_chat_history(
    chat_history: list[ChatMessage],
    rejection_messages: list[ChatMessage],
    tool_call_and_explanation_messages: list[ChatMessage],
) -> list[ChatMessage]:
    """
    Rebuild the chat history with the outcome of the confirmation step.

    The history is cut right after the most recent message that comes from either the user or a tool; anything after
    that message is dropped. The rejection messages are appended next, followed by the tool call messages for
    confirmed or modified calls. If the history contains no user or tool message, only the new messages are returned.

    :param chat_history: The current chat history.
    :param rejection_messages: Chat messages to add for rejected tool calls (pairs of tool call and tool call result
        messages).
    :param tool_call_and_explanation_messages: Tool call messages for confirmed or modified tool calls, which may
        include user messages explaining modifications.
    :returns:
        The updated chat history as a new list.
    """
    cut_index = -1
    # Scan backwards: the first user/tool message found is the latest one.
    for position in range(len(chat_history) - 1, -1, -1):
        candidate = chat_history[position]
        if candidate.is_from("user") or candidate.is_from("tool"):
            cut_index = position
            break

    preserved = chat_history[: cut_index + 1]
    return preserved + rejection_messages + tool_call_and_explanation_messages
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from typing import Any, Optional, Protocol
|
|
6
|
+
|
|
7
|
+
from haystack.core.serialization import default_from_dict, default_to_dict
|
|
8
|
+
|
|
9
|
+
from haystack_experimental.components.agents.human_in_the_loop.dataclasses import (
|
|
10
|
+
ConfirmationUIResult,
|
|
11
|
+
ToolExecutionDecision,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Ellipsis are needed to define the Protocol but pylint complains. See https://github.com/pylint-dev/pylint/issues/9319.
|
|
15
|
+
# pylint: disable=unnecessary-ellipsis
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ConfirmationUI(Protocol):
    """
    Protocol for user interfaces that ask a user to decide about a tool execution.

    Implementations must provide `get_user_confirmation`; `to_dict` and `from_dict` have default implementations based
    on Haystack's default (de)serialization helpers.
    """

    def get_user_confirmation(
        self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
    ) -> ConfirmationUIResult:
        """
        Ask the user about executing `tool_name` with `tool_params` and return their answer.
        """
        ...

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the UI to a dictionary; by default no init parameters are included.
        """
        serialized = default_to_dict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConfirmationUI":
        """
        Recreate the ConfirmationUI from a dictionary produced by `to_dict`.
        """
        return default_from_dict(cls, data)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ConfirmationPolicy(Protocol):
    """
    Protocol for policies that decide when a user must confirm a tool execution.

    Implementations must provide `should_ask`. `update_after_confirmation` defaults to a no-op, and `to_dict` and
    `from_dict` default to Haystack's default (de)serialization helpers.
    """

    def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool:
        """
        Return True if the user must be asked before `tool_name` runs with `tool_params`.
        """
        ...

    def update_after_confirmation(
        self,
        tool_name: str,
        tool_description: str,
        tool_params: dict[str, Any],
        confirmation_result: ConfirmationUIResult,
    ) -> None:
        """
        Hook called with the user's answer so the policy can update itself; the default does nothing.
        """

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the policy to a dictionary; by default no init parameters are included.
        """
        serialized = default_to_dict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConfirmationPolicy":
        """
        Recreate the policy from a dictionary produced by `to_dict`.
        """
        return default_from_dict(cls, data)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ConfirmationStrategy(Protocol):
    """
    Protocol for strategies that decide whether, and with which parameters, a tool call is executed.

    Implementations must provide `run`, `to_dict` and `from_dict`.
    """

    def run(
        self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
    ) -> ToolExecutionDecision:
        """
        Run the confirmation strategy for a given tool and its parameters.

        :param tool_name: The name of the tool to be executed.
        :param tool_description: The description of the tool.
        :param tool_params: The parameters to be passed to the tool.
        :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
            the decision with a specific tool invocation.

        :returns:
            A ToolExecutionDecision stating whether the tool should be executed, the final parameters to use, and
            optional feedback (for example the reason for a rejection or a description of modified parameters).
            An implementation may raise instead of returning, for example to pause execution until a user responds.
        """
        ...

    def to_dict(self) -> dict[str, Any]:
        """Serialize the strategy to a dictionary."""
        ...

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConfirmationStrategy":
        """Deserialize the strategy from a dictionary."""
        ...
|