haystack-experimental 0.13.0__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/components/agents/__init__.py +16 -0
- haystack_experimental/components/agents/agent.py +634 -0
- haystack_experimental/components/agents/human_in_the_loop/__init__.py +35 -0
- haystack_experimental/components/agents/human_in_the_loop/breakpoint.py +63 -0
- haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +72 -0
- haystack_experimental/components/agents/human_in_the_loop/errors.py +28 -0
- haystack_experimental/components/agents/human_in_the_loop/policies.py +78 -0
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +455 -0
- haystack_experimental/components/agents/human_in_the_loop/types.py +89 -0
- haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +209 -0
- haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +18 -6
- haystack_experimental/components/preprocessors/md_header_level_inferrer.py +146 -0
- haystack_experimental/components/summarizers/__init__.py +7 -0
- haystack_experimental/components/summarizers/llm_summarizer.py +317 -0
- haystack_experimental/core/__init__.py +3 -0
- haystack_experimental/core/pipeline/__init__.py +3 -0
- haystack_experimental/core/pipeline/breakpoint.py +174 -0
- haystack_experimental/dataclasses/__init__.py +3 -0
- haystack_experimental/dataclasses/breakpoints.py +53 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/METADATA +29 -14
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/RECORD +24 -6
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/WHEEL +0 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import TYPE_CHECKING

from lazy_imports import LazyImporter

# Maps submodule name -> public names it exports. LazyImporter uses this to
# defer importing the `agent` submodule (and its heavy haystack dependencies)
# until one of these attributes is first accessed.
_import_structure = {"agent": ["Agent"]}

if TYPE_CHECKING:
    # Static type checkers need the real import so `Agent` resolves to a type.
    from .agent import Agent as Agent

else:
    # At runtime, replace this module object with a lazy proxy so that merely
    # importing the package does not eagerly import the agent module.
    sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

# pylint: disable=wrong-import-order,wrong-import-position,ungrouped-imports
# ruff: noqa: I001

from dataclasses import dataclass
from typing import Any, Optional, Union

# NOTE: Import order in this module is deliberate (hence the lint disables
# above): the monkey patches below must run BEFORE the `haystack` imports
# that follow, so those imports resolve to the patched objects.

# Monkey patch Haystack's AgentSnapshot with our extended version
import haystack.dataclasses.breakpoints as hdb
from haystack_experimental.dataclasses.breakpoints import AgentSnapshot

hdb.AgentSnapshot = AgentSnapshot  # type: ignore[misc]

# Monkey patch Haystack's breakpoint functions with our extended versions
import haystack.core.pipeline.breakpoint as hs_breakpoint
import haystack_experimental.core.pipeline.breakpoint as exp_breakpoint

hs_breakpoint._create_agent_snapshot = exp_breakpoint._create_agent_snapshot
hs_breakpoint._create_pipeline_snapshot_from_tool_invoker = exp_breakpoint._create_pipeline_snapshot_from_tool_invoker  # type: ignore[assignment]
hs_breakpoint._trigger_tool_invoker_breakpoint = exp_breakpoint._trigger_tool_invoker_breakpoint

from haystack import logging
from haystack.components.agents.agent import Agent as HaystackAgent
from haystack.components.agents.agent import _ExecutionContext as Haystack_ExecutionContext
from haystack.components.agents.agent import _schema_from_dict
from haystack.components.agents.state import replace_values
from haystack.components.generators.chat.types import ChatGenerator
from haystack.core.errors import PipelineRuntimeError
from haystack.core.pipeline import AsyncPipeline, Pipeline
# Because the patch above ran first, this import picks up the experimental
# _create_pipeline_snapshot_from_tool_invoker, not the stock one.
from haystack.core.pipeline.breakpoint import (
    _create_pipeline_snapshot_from_chat_generator,
    _create_pipeline_snapshot_from_tool_invoker,
)
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
from haystack.core.serialization import default_from_dict, import_class_by_name
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace
from haystack.utils.callable_serialization import deserialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace

from haystack_experimental.components.agents.human_in_the_loop import (
    ConfirmationStrategy,
    ToolExecutionDecision,
    HITLBreakpointException,
)
from haystack_experimental.components.agents.human_in_the_loop.strategies import _process_confirmation_strategies

logger = logging.getLogger(__name__)
56
|
+
@dataclass
class _ExecutionContext(Haystack_ExecutionContext):
    """
    Execution context for the Agent component

    Extends Haystack's _ExecutionContext to include tool execution decisions for human-in-the-loop strategies.

    :param tool_execution_decisions: Optional list of ToolExecutionDecision objects to use instead of prompting
        the user. This is useful when restarting from a snapshot where tool execution decisions were already made.
    """

    # Consumed once per batch of tool calls; the Agent resets this to None
    # after the ToolInvoker runs so stale decisions are never reapplied.
    tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
|
|
69
|
+
|
|
70
|
+
class Agent(HaystackAgent):
    """
    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.

    NOTE: This class extends Haystack's Agent component to add support for human-in-the-loop confirmation strategies.

    The component processes messages and executes tools until an exit condition is met.
    The exit condition can be triggered either by a direct text response or by invoking a specific designated tool.
    Multiple exit conditions can be specified.

    When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits.

    ### Usage example
    ```python
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    from haystack_experimental.components.agents import Agent
    from haystack_experimental.components.agents.human_in_the_loop import (
        HumanInTheLoopStrategy,
        AlwaysAskPolicy,
        NeverAskPolicy,
        SimpleConsoleUI,
    )

    calculator_tool = Tool(name="calculator", description="A tool for performing mathematical calculations.", ...)
    search_tool = Tool(name="search", description="A tool for searching the web.", ...)

    agent = Agent(
        chat_generator=OpenAIChatGenerator(),
        tools=[calculator_tool, search_tool],
        confirmation_strategies={
            calculator_tool.name: HumanInTheLoopStrategy(
                confirmation_policy=NeverAskPolicy(), confirmation_ui=SimpleConsoleUI()
            ),
            search_tool.name: HumanInTheLoopStrategy(
                confirmation_policy=AlwaysAskPolicy(), confirmation_ui=SimpleConsoleUI()
            ),
        },
    )

    # Run the agent
    result = agent.run(
        messages=[ChatMessage.from_user("Find information about Haystack")]
    )

    assert "messages" in result  # Contains conversation history
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[list[str]] = None,
        state_schema: Optional[dict[str, Any]] = None,
        max_agent_steps: int = 100,
        streaming_callback: Optional[StreamingCallbackT] = None,
        raise_on_tool_invocation_failure: bool = False,
        confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None,
        tool_invoker_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: List of Tool objects or a Toolset that the agent can use.
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it will stop and return the current state.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
            If set to False, the exception will be turned into a chat message and passed to the LLM.
        :param confirmation_strategies: Optional mapping from tool name to the ConfirmationStrategy applied
            before that tool is executed (e.g. asking a human for approval). Tools without an entry run
            without a confirmation step.
        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
        :raises TypeError: If the chat_generator does not support tools parameter in its run method.
        :raises ValueError: If the exit_conditions are not valid.
        """
        super(Agent, self).__init__(
            chat_generator=chat_generator,
            tools=tools,
            system_prompt=system_prompt,
            exit_conditions=exit_conditions,
            state_schema=state_schema,
            max_agent_steps=max_agent_steps,
            streaming_callback=streaming_callback,
            raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
            tool_invoker_kwargs=tool_invoker_kwargs,
        )
        # Kept private; serialized/deserialized via to_dict/from_dict.
        self._confirmation_strategies = confirmation_strategies or {}
+
|
|
168
|
+
def _initialize_fresh_execution(
|
|
169
|
+
self,
|
|
170
|
+
messages: list[ChatMessage],
|
|
171
|
+
streaming_callback: Optional[StreamingCallbackT],
|
|
172
|
+
requires_async: bool,
|
|
173
|
+
*,
|
|
174
|
+
system_prompt: Optional[str] = None,
|
|
175
|
+
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
|
|
176
|
+
**kwargs: dict[str, Any],
|
|
177
|
+
) -> _ExecutionContext:
|
|
178
|
+
"""
|
|
179
|
+
Initialize execution context for a fresh run of the agent.
|
|
180
|
+
|
|
181
|
+
:param messages: List of ChatMessage objects to start the agent with.
|
|
182
|
+
:param streaming_callback: Optional callback for streaming responses.
|
|
183
|
+
:param requires_async: Whether the agent run requires asynchronous execution.
|
|
184
|
+
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
|
|
185
|
+
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
186
|
+
When passing tool names, tools are selected from the Agent's originally configured tools.
|
|
187
|
+
:param kwargs: Additional data to pass to the State used by the Agent.
|
|
188
|
+
"""
|
|
189
|
+
exe_context = super(Agent, self)._initialize_fresh_execution(
|
|
190
|
+
messages=messages,
|
|
191
|
+
streaming_callback=streaming_callback,
|
|
192
|
+
requires_async=requires_async,
|
|
193
|
+
system_prompt=system_prompt,
|
|
194
|
+
tools=tools,
|
|
195
|
+
**kwargs,
|
|
196
|
+
)
|
|
197
|
+
# NOTE: 1st difference with parent method to add this to tool_invoker_inputs
|
|
198
|
+
if self._tool_invoker:
|
|
199
|
+
exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
|
|
200
|
+
self._tool_invoker.enable_streaming_callback_passthrough
|
|
201
|
+
)
|
|
202
|
+
# NOTE: 2nd difference is to use the extended _ExecutionContext
|
|
203
|
+
return _ExecutionContext(
|
|
204
|
+
state=exe_context.state,
|
|
205
|
+
component_visits=exe_context.component_visits,
|
|
206
|
+
chat_generator_inputs=exe_context.chat_generator_inputs,
|
|
207
|
+
tool_invoker_inputs=exe_context.tool_invoker_inputs,
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
def _initialize_from_snapshot( # type: ignore[override]
|
|
211
|
+
self,
|
|
212
|
+
snapshot: AgentSnapshot,
|
|
213
|
+
streaming_callback: Optional[StreamingCallbackT],
|
|
214
|
+
requires_async: bool,
|
|
215
|
+
*,
|
|
216
|
+
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
|
|
217
|
+
) -> _ExecutionContext:
|
|
218
|
+
"""
|
|
219
|
+
Initialize execution context from an AgentSnapshot.
|
|
220
|
+
|
|
221
|
+
:param snapshot: An AgentSnapshot containing the state of a previously saved agent execution.
|
|
222
|
+
:param streaming_callback: Optional callback for streaming responses.
|
|
223
|
+
:param requires_async: Whether the agent run requires asynchronous execution.
|
|
224
|
+
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
225
|
+
When passing tool names, tools are selected from the Agent's originally configured tools.
|
|
226
|
+
"""
|
|
227
|
+
exe_context = super(Agent, self)._initialize_from_snapshot(
|
|
228
|
+
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools
|
|
229
|
+
)
|
|
230
|
+
# NOTE: 1st difference with parent method to add this to tool_invoker_inputs
|
|
231
|
+
if self._tool_invoker:
|
|
232
|
+
exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
|
|
233
|
+
self._tool_invoker.enable_streaming_callback_passthrough
|
|
234
|
+
)
|
|
235
|
+
# NOTE: 2nd difference is to use the extended _ExecutionContext and add tool_execution_decisions
|
|
236
|
+
return _ExecutionContext(
|
|
237
|
+
state=exe_context.state,
|
|
238
|
+
component_visits=exe_context.component_visits,
|
|
239
|
+
chat_generator_inputs=exe_context.chat_generator_inputs,
|
|
240
|
+
tool_invoker_inputs=exe_context.tool_invoker_inputs,
|
|
241
|
+
counter=exe_context.counter,
|
|
242
|
+
skip_chat_generator=exe_context.skip_chat_generator,
|
|
243
|
+
tool_execution_decisions=snapshot.tool_execution_decisions,
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
    def run(  # noqa: PLR0915
        self,
        messages: list[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        break_point: Optional[AgentBreakpoint] = None,
        snapshot: Optional[AgentSnapshot] = None,  # type: ignore[override]
        system_prompt: Optional[str] = None,
        tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Process messages and execute tools until an exit condition is met.

        :param messages: List of Haystack ChatMessage objects to process.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
            for "tool_invoker".
        :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
            the relevant information to restart the Agent execution from where it left off.
        :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
            When passing tool names, tools are selected from the Agent's originally configured tools.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :returns:
            A dictionary with the following keys:
            - "messages": List of all messages exchanged during the agent's run.
            - "last_message": The last message exchanged during the agent's run.
            - Any additional keys defined in the `state_schema`.
        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
        :raises BreakpointException: If an agent breakpoint is triggered.
        """
        # We pop parent_snapshot from kwargs to avoid passing it into State.
        parent_snapshot = kwargs.pop("parent_snapshot", None)
        # Captured up front for tracing, before execution mutates any state.
        agent_inputs = {
            "messages": messages,
            "streaming_callback": streaming_callback,
            "break_point": break_point,
            "snapshot": snapshot,
            **kwargs,
        }
        self._runtime_checks(break_point=break_point, snapshot=snapshot)

        # A snapshot (if given) takes precedence over a fresh start.
        if snapshot:
            exe_context = self._initialize_from_snapshot(
                snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools
            )
        else:
            exe_context = self._initialize_fresh_execution(
                messages=messages,
                streaming_callback=streaming_callback,
                requires_async=False,
                system_prompt=system_prompt,
                tools=tools,
                **kwargs,
            )

        with self._create_agent_span() as span:
            span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))

            while exe_context.counter < self.max_agent_steps:
                # Handle breakpoint and ChatGenerator call
                Agent._check_chat_generator_breakpoint(
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
                )
                # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
                if exe_context.skip_chat_generator:
                    llm_messages = exe_context.state.get("messages", [])[-1:]
                    # Set to False so the next iteration will call the chat generator
                    exe_context.skip_chat_generator = False
                else:
                    try:
                        result = Pipeline._run_component(
                            component_name="chat_generator",
                            component={"instance": self.chat_generator},
                            inputs={
                                "messages": exe_context.state.data["messages"],
                                **exe_context.chat_generator_inputs,
                            },
                            component_visits=exe_context.component_visits,
                            parent_span=span,
                        )
                    except PipelineRuntimeError as e:
                        # Attach a snapshot to the error so the run can be resumed later.
                        pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
                            agent_name=getattr(self, "__component_name__", None),
                            execution_context=exe_context,
                            parent_snapshot=parent_snapshot,
                        )
                        e.pipeline_snapshot = pipeline_snapshot
                        raise e

                    llm_messages = result["replies"]
                    exe_context.state.set("messages", llm_messages)

                # Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    exe_context.counter += 1
                    break

                # Apply confirmation strategies and update State and messages sent to ToolInvoker
                try:
                    # Run confirmation strategies to get updated tool call messages and modified chat history
                    modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
                        confirmation_strategies=self._confirmation_strategies,
                        messages_with_tool_calls=llm_messages,
                        execution_context=exe_context,
                    )
                    # Replace the chat history in state with the modified one
                    exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
                except HITLBreakpointException as tbp_error:
                    # We create a break_point to pass into _check_tool_invoker_breakpoint
                    break_point = AgentBreakpoint(
                        agent_name=getattr(self, "__component_name__", ""),
                        break_point=ToolBreakpoint(
                            component_name="tool_invoker",
                            tool_name=tbp_error.tool_name,
                            visit_count=exe_context.component_visits["tool_invoker"],
                            snapshot_file_path=tbp_error.snapshot_file_path,
                        ),
                    )

                # Handle breakpoint
                Agent._check_tool_invoker_breakpoint(
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
                )

                # Run ToolInvoker
                try:
                    # We only send the messages from the LLM to the tool invoker
                    tool_invoker_result = Pipeline._run_component(
                        component_name="tool_invoker",
                        component={"instance": self._tool_invoker},
                        inputs={
                            "messages": modified_tool_call_messages,
                            "state": exe_context.state,
                            **exe_context.tool_invoker_inputs,
                        },
                        component_visits=exe_context.component_visits,
                        parent_span=span,
                    )
                except PipelineRuntimeError as e:
                    # Access the original Tool Invoker exception
                    original_error = e.__cause__
                    tool_name = getattr(original_error, "tool_name", None)

                    # Attach a snapshot to the error so the run can be resumed later.
                    pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
                        tool_name=tool_name,
                        agent_name=getattr(self, "__component_name__", None),
                        execution_context=exe_context,
                        parent_snapshot=parent_snapshot,
                    )
                    e.pipeline_snapshot = pipeline_snapshot
                    raise e

                # Set execution context tool execution decisions to empty after applying them b/c they should only
                # be used once for the current tool calls
                exe_context.tool_execution_decisions = None
                tool_messages = tool_invoker_result["tool_messages"]
                exe_context.state = tool_invoker_result["state"]
                exe_context.state.set("messages", tool_messages)

                # Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    exe_context.counter += 1
                    break

                # Increment the step counter
                exe_context.counter += 1

            if exe_context.counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", exe_context.state.data)
            span.set_tag("haystack.agent.steps_taken", exe_context.counter)

        result = {**exe_context.state.data}
        if msgs := result.get("messages"):
            result["last_message"] = msgs[-1]
        return result
|
+
|
|
430
|
+
async def run_async(
|
|
431
|
+
self,
|
|
432
|
+
messages: list[ChatMessage],
|
|
433
|
+
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
434
|
+
*,
|
|
435
|
+
break_point: Optional[AgentBreakpoint] = None,
|
|
436
|
+
snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
|
|
437
|
+
system_prompt: Optional[str] = None,
|
|
438
|
+
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
|
|
439
|
+
**kwargs: Any,
|
|
440
|
+
) -> dict[str, Any]:
|
|
441
|
+
"""
|
|
442
|
+
Asynchronously process messages and execute tools until the exit condition is met.
|
|
443
|
+
|
|
444
|
+
This is the asynchronous version of the `run` method. It follows the same logic but uses
|
|
445
|
+
asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
|
|
446
|
+
if available.
|
|
447
|
+
|
|
448
|
+
:param messages: List of Haystack ChatMessage objects to process.
|
|
449
|
+
:param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the
|
|
450
|
+
LLM. The same callback can be configured to emit tool results when a tool is called.
|
|
451
|
+
:param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
|
|
452
|
+
for "tool_invoker".
|
|
453
|
+
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
|
|
454
|
+
the relevant information to restart the Agent execution from where it left off.
|
|
455
|
+
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
|
|
456
|
+
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
457
|
+
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
|
458
|
+
The keys must match the schema defined in the Agent's `state_schema`.
|
|
459
|
+
:returns:
|
|
460
|
+
A dictionary with the following keys:
|
|
461
|
+
- "messages": List of all messages exchanged during the agent's run.
|
|
462
|
+
- "last_message": The last message exchanged during the agent's run.
|
|
463
|
+
- Any additional keys defined in the `state_schema`.
|
|
464
|
+
:raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
|
|
465
|
+
:raises BreakpointException: If an agent breakpoint is triggered.
|
|
466
|
+
"""
|
|
467
|
+
# We pop parent_snapshot from kwargs to avoid passing it into State.
|
|
468
|
+
parent_snapshot = kwargs.pop("parent_snapshot", None)
|
|
469
|
+
agent_inputs = {
|
|
470
|
+
"messages": messages,
|
|
471
|
+
"streaming_callback": streaming_callback,
|
|
472
|
+
"break_point": break_point,
|
|
473
|
+
"snapshot": snapshot,
|
|
474
|
+
**kwargs,
|
|
475
|
+
}
|
|
476
|
+
self._runtime_checks(break_point=break_point, snapshot=snapshot)
|
|
477
|
+
|
|
478
|
+
if snapshot:
|
|
479
|
+
exe_context = self._initialize_from_snapshot(
|
|
480
|
+
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools
|
|
481
|
+
)
|
|
482
|
+
else:
|
|
483
|
+
exe_context = self._initialize_fresh_execution(
|
|
484
|
+
messages=messages,
|
|
485
|
+
streaming_callback=streaming_callback,
|
|
486
|
+
requires_async=True,
|
|
487
|
+
system_prompt=system_prompt,
|
|
488
|
+
tools=tools,
|
|
489
|
+
**kwargs,
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
with self._create_agent_span() as span:
|
|
493
|
+
span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
|
|
494
|
+
|
|
495
|
+
while exe_context.counter < self.max_agent_steps:
|
|
496
|
+
# Handle breakpoint and ChatGenerator call
|
|
497
|
+
self._check_chat_generator_breakpoint(
|
|
498
|
+
execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
|
|
499
|
+
)
|
|
500
|
+
# We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
|
|
501
|
+
if exe_context.skip_chat_generator:
|
|
502
|
+
llm_messages = exe_context.state.get("messages", [])[-1:]
|
|
503
|
+
# Set to False so the next iteration will call the chat generator
|
|
504
|
+
exe_context.skip_chat_generator = False
|
|
505
|
+
else:
|
|
506
|
+
result = await AsyncPipeline._run_component_async(
|
|
507
|
+
component_name="chat_generator",
|
|
508
|
+
component={"instance": self.chat_generator},
|
|
509
|
+
component_inputs={
|
|
510
|
+
"messages": exe_context.state.data["messages"],
|
|
511
|
+
**exe_context.chat_generator_inputs,
|
|
512
|
+
},
|
|
513
|
+
component_visits=exe_context.component_visits,
|
|
514
|
+
parent_span=span,
|
|
515
|
+
)
|
|
516
|
+
llm_messages = result["replies"]
|
|
517
|
+
exe_context.state.set("messages", llm_messages)
|
|
518
|
+
|
|
519
|
+
# Check if any of the LLM responses contain a tool call or if the LLM is not using tools
|
|
520
|
+
if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
|
|
521
|
+
exe_context.counter += 1
|
|
522
|
+
break
|
|
523
|
+
|
|
524
|
+
# Apply confirmation strategies and update State and messages sent to ToolInvoker
|
|
525
|
+
try:
|
|
526
|
+
# Run confirmation strategies to get updated tool call messages and modified chat history
|
|
527
|
+
modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
|
|
528
|
+
confirmation_strategies=self._confirmation_strategies,
|
|
529
|
+
messages_with_tool_calls=llm_messages,
|
|
530
|
+
execution_context=exe_context,
|
|
531
|
+
)
|
|
532
|
+
# Replace the chat history in state with the modified one
|
|
533
|
+
exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
|
|
534
|
+
except HITLBreakpointException as tbp_error:
|
|
535
|
+
# We create a break_point to pass into _check_tool_invoker_breakpoint
|
|
536
|
+
break_point = AgentBreakpoint(
|
|
537
|
+
agent_name=getattr(self, "__component_name__", ""),
|
|
538
|
+
break_point=ToolBreakpoint(
|
|
539
|
+
component_name="tool_invoker",
|
|
540
|
+
tool_name=tbp_error.tool_name,
|
|
541
|
+
visit_count=exe_context.component_visits["tool_invoker"],
|
|
542
|
+
snapshot_file_path=tbp_error.snapshot_file_path,
|
|
543
|
+
),
|
|
544
|
+
)
|
|
545
|
+
|
|
546
|
+
# Handle breakpoint
|
|
547
|
+
Agent._check_tool_invoker_breakpoint(
|
|
548
|
+
execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
|
|
549
|
+
)
|
|
550
|
+
|
|
551
|
+
# Run ToolInvoker
|
|
552
|
+
# We only send the messages from the LLM to the tool invoker
|
|
553
|
+
tool_invoker_result = await AsyncPipeline._run_component_async(
|
|
554
|
+
component_name="tool_invoker",
|
|
555
|
+
component={"instance": self._tool_invoker},
|
|
556
|
+
component_inputs={
|
|
557
|
+
"messages": modified_tool_call_messages,
|
|
558
|
+
"state": exe_context.state,
|
|
559
|
+
**exe_context.tool_invoker_inputs,
|
|
560
|
+
},
|
|
561
|
+
component_visits=exe_context.component_visits,
|
|
562
|
+
parent_span=span,
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
# Set execution context tool execution decisions to empty after applying them b/c they should only
|
|
566
|
+
# be used once for the current tool calls
|
|
567
|
+
exe_context.tool_execution_decisions = None
|
|
568
|
+
tool_messages = tool_invoker_result["tool_messages"]
|
|
569
|
+
exe_context.state = tool_invoker_result["state"]
|
|
570
|
+
exe_context.state.set("messages", tool_messages)
|
|
571
|
+
|
|
572
|
+
# Check if any LLM message's tool call name matches an exit condition
|
|
573
|
+
if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
|
|
574
|
+
exe_context.counter += 1
|
|
575
|
+
break
|
|
576
|
+
|
|
577
|
+
# Increment the step counter
|
|
578
|
+
exe_context.counter += 1
|
|
579
|
+
|
|
580
|
+
if exe_context.counter >= self.max_agent_steps:
|
|
581
|
+
logger.warning(
|
|
582
|
+
"Agent reached maximum agent steps of {max_agent_steps}, stopping.",
|
|
583
|
+
max_agent_steps=self.max_agent_steps,
|
|
584
|
+
)
|
|
585
|
+
span.set_content_tag("haystack.agent.output", exe_context.state.data)
|
|
586
|
+
span.set_tag("haystack.agent.steps_taken", exe_context.counter)
|
|
587
|
+
|
|
588
|
+
result = {**exe_context.state.data}
|
|
589
|
+
if msgs := result.get("messages"):
|
|
590
|
+
result["last_message"] = msgs[-1]
|
|
591
|
+
return result
|
|
592
|
+
|
|
593
|
+
def to_dict(self) -> dict[str, Any]:
|
|
594
|
+
"""
|
|
595
|
+
Serialize the component to a dictionary.
|
|
596
|
+
|
|
597
|
+
:return: Dictionary with serialized data
|
|
598
|
+
"""
|
|
599
|
+
data = super(Agent, self).to_dict()
|
|
600
|
+
data["init_parameters"]["confirmation_strategies"] = (
|
|
601
|
+
{name: strategy.to_dict() for name, strategy in self._confirmation_strategies.items()}
|
|
602
|
+
if self._confirmation_strategies
|
|
603
|
+
else None
|
|
604
|
+
)
|
|
605
|
+
return data
|
|
606
|
+
|
|
607
|
+
@classmethod
|
|
608
|
+
def from_dict(cls, data: dict[str, Any]) -> "Agent":
|
|
609
|
+
"""
|
|
610
|
+
Deserialize the agent from a dictionary.
|
|
611
|
+
|
|
612
|
+
:param data: Dictionary to deserialize from
|
|
613
|
+
:return: Deserialized agent
|
|
614
|
+
"""
|
|
615
|
+
init_params = data.get("init_parameters", {})
|
|
616
|
+
|
|
617
|
+
deserialize_chatgenerator_inplace(init_params, key="chat_generator")
|
|
618
|
+
|
|
619
|
+
if "state_schema" in init_params:
|
|
620
|
+
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
|
|
621
|
+
|
|
622
|
+
if init_params.get("streaming_callback") is not None:
|
|
623
|
+
init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])
|
|
624
|
+
|
|
625
|
+
deserialize_tools_or_toolset_inplace(init_params, key="tools")
|
|
626
|
+
|
|
627
|
+
if "confirmation_strategies" in init_params and init_params["confirmation_strategies"] is not None:
|
|
628
|
+
for name, strategy_dict in init_params["confirmation_strategies"].items():
|
|
629
|
+
strategy_class = import_class_by_name(strategy_dict["type"])
|
|
630
|
+
if not hasattr(strategy_class, "from_dict"):
|
|
631
|
+
raise TypeError(f"{strategy_class} does not have from_dict method implemented.")
|
|
632
|
+
init_params["confirmation_strategies"][name] = strategy_class.from_dict(strategy_dict)
|
|
633
|
+
|
|
634
|
+
return default_from_dict(cls, data)
|