haystack-experimental 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/components/agents/agent.py +135 -124
- haystack_experimental/components/agents/human_in_the_loop/__init__.py +2 -15
- haystack_experimental/components/agents/human_in_the_loop/breakpoint.py +1 -2
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +4 -548
- haystack_experimental/dataclasses/breakpoints.py +3 -4
- haystack_experimental/memory_stores/__init__.py +7 -0
- haystack_experimental/memory_stores/mem0/__init__.py +16 -0
- haystack_experimental/memory_stores/mem0/memory_store.py +323 -0
- haystack_experimental/memory_stores/types/__init__.py +7 -0
- haystack_experimental/memory_stores/types/protocol.py +94 -0
- {haystack_experimental-0.16.0.dist-info → haystack_experimental-0.18.0.dist-info}/METADATA +25 -26
- {haystack_experimental-0.16.0.dist-info → haystack_experimental-0.18.0.dist-info}/RECORD +15 -14
- haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +0 -72
- haystack_experimental/components/agents/human_in_the_loop/policies.py +0 -78
- haystack_experimental/components/agents/human_in_the_loop/types.py +0 -124
- haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +0 -209
- {haystack_experimental-0.16.0.dist-info → haystack_experimental-0.18.0.dist-info}/WHEEL +0 -0
- {haystack_experimental-0.16.0.dist-info → haystack_experimental-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.16.0.dist-info → haystack_experimental-0.18.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
@@ -6,7 +6,6 @@
|
|
|
6
6
|
# ruff: noqa: I001
|
|
7
7
|
|
|
8
8
|
import inspect
|
|
9
|
-
from dataclasses import dataclass
|
|
10
9
|
from typing import Any
|
|
11
10
|
|
|
12
11
|
# Monkey patch Haystack's AgentSnapshot with our extended version
|
|
@@ -22,11 +21,14 @@ import haystack_experimental.core.pipeline.breakpoint as exp_breakpoint
|
|
|
22
21
|
hs_breakpoint._create_agent_snapshot = exp_breakpoint._create_agent_snapshot
|
|
23
22
|
hs_breakpoint._create_pipeline_snapshot_from_tool_invoker = exp_breakpoint._create_pipeline_snapshot_from_tool_invoker # type: ignore[assignment]
|
|
24
23
|
|
|
25
|
-
from haystack import
|
|
26
|
-
from haystack.components.agents.agent import Agent as HaystackAgent
|
|
27
|
-
from haystack.
|
|
28
|
-
|
|
29
|
-
|
|
24
|
+
from haystack import component, logging
|
|
25
|
+
from haystack.components.agents.agent import Agent as HaystackAgent, _ExecutionContext, _schema_from_dict
|
|
26
|
+
from haystack.human_in_the_loop.strategies import (
|
|
27
|
+
ConfirmationStrategy,
|
|
28
|
+
_process_confirmation_strategies,
|
|
29
|
+
_process_confirmation_strategies_async,
|
|
30
|
+
)
|
|
31
|
+
from haystack.components.agents.state import replace_values
|
|
30
32
|
from haystack.components.generators.chat.types import ChatGenerator
|
|
31
33
|
from haystack.core.errors import BreakpointException, PipelineRuntimeError
|
|
32
34
|
from haystack.core.pipeline import AsyncPipeline, Pipeline
|
|
@@ -37,50 +39,24 @@ from haystack.core.pipeline.breakpoint import (
|
|
|
37
39
|
_should_trigger_tool_invoker_breakpoint,
|
|
38
40
|
)
|
|
39
41
|
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
|
|
40
|
-
from haystack.core.serialization import default_from_dict
|
|
41
|
-
from haystack.dataclasses import ChatMessage
|
|
42
|
+
from haystack.core.serialization import default_from_dict
|
|
43
|
+
from haystack.dataclasses import ChatMessage
|
|
42
44
|
from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint
|
|
43
|
-
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
|
45
|
+
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
|
44
46
|
from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace
|
|
45
47
|
from haystack.utils.callable_serialization import deserialize_callable
|
|
46
|
-
from haystack.utils.deserialization import
|
|
48
|
+
from haystack.utils.deserialization import deserialize_component_inplace
|
|
47
49
|
|
|
48
50
|
from haystack_experimental.chat_message_stores.types import ChatMessageStore
|
|
49
|
-
from haystack_experimental.components.agents.human_in_the_loop import
|
|
50
|
-
ConfirmationStrategy,
|
|
51
|
-
ToolExecutionDecision,
|
|
52
|
-
HITLBreakpointException,
|
|
53
|
-
)
|
|
54
|
-
from haystack_experimental.components.agents.human_in_the_loop.strategies import (
|
|
55
|
-
_process_confirmation_strategies,
|
|
56
|
-
_process_confirmation_strategies_async,
|
|
57
|
-
)
|
|
51
|
+
from haystack_experimental.components.agents.human_in_the_loop import HITLBreakpointException
|
|
58
52
|
from haystack_experimental.components.retrievers import ChatMessageRetriever
|
|
59
53
|
from haystack_experimental.components.writers import ChatMessageWriter
|
|
54
|
+
from haystack_experimental.memory_stores.types import MemoryStore
|
|
60
55
|
|
|
61
56
|
logger = logging.getLogger(__name__)
|
|
62
57
|
|
|
63
58
|
|
|
64
|
-
@
|
|
65
|
-
class _ExecutionContext(Haystack_ExecutionContext):
|
|
66
|
-
"""
|
|
67
|
-
Execution context for the Agent component
|
|
68
|
-
|
|
69
|
-
Extends Haystack's _ExecutionContext to include tool execution decisions for human-in-the-loop strategies.
|
|
70
|
-
|
|
71
|
-
:param tool_execution_decisions: Optional list of ToolExecutionDecision objects to use instead of prompting
|
|
72
|
-
the user. This is useful when restarting from a snapshot where tool execution decisions were already made.
|
|
73
|
-
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
74
|
-
to confirmation strategies. In web/server environments, this enables passing per-request
|
|
75
|
-
objects (e.g., WebSocket connections, async queues, or pub/sub clients) that strategies can use for
|
|
76
|
-
non-blocking user interaction. This is passed directly to strategies via the `confirmation_strategy_context`
|
|
77
|
-
parameter in their `run()` and `run_async()` methods.
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
tool_execution_decisions: list[ToolExecutionDecision] | None = None
|
|
81
|
-
confirmation_strategy_context: dict[str, Any] | None = None
|
|
82
|
-
|
|
83
|
-
|
|
59
|
+
@component
|
|
84
60
|
class Agent(HaystackAgent):
|
|
85
61
|
"""
|
|
86
62
|
A Haystack component that implements a tool-using agent with provider-agnostic chat model support.
|
|
@@ -146,6 +122,7 @@ class Agent(HaystackAgent):
|
|
|
146
122
|
confirmation_strategies: dict[str, ConfirmationStrategy] | None = None,
|
|
147
123
|
tool_invoker_kwargs: dict[str, Any] | None = None,
|
|
148
124
|
chat_message_store: ChatMessageStore | None = None,
|
|
125
|
+
memory_store: MemoryStore | None = None,
|
|
149
126
|
) -> None:
|
|
150
127
|
"""
|
|
151
128
|
Initialize the agent component.
|
|
@@ -164,6 +141,9 @@ class Agent(HaystackAgent):
|
|
|
164
141
|
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
|
|
165
142
|
If set to False, the exception will be turned into a chat message and passed to the LLM.
|
|
166
143
|
:param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
|
|
144
|
+
:param chat_message_store: The ChatMessageStore that the agent can use to store
|
|
145
|
+
and retrieve chat messages history.
|
|
146
|
+
:param memory_store: The memory store that the agent can use to store and retrieve memories.
|
|
167
147
|
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
|
|
168
148
|
:raises ValueError: If the exit_conditions are not valid.
|
|
169
149
|
"""
|
|
@@ -177,8 +157,8 @@ class Agent(HaystackAgent):
|
|
|
177
157
|
streaming_callback=streaming_callback,
|
|
178
158
|
raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
|
|
179
159
|
tool_invoker_kwargs=tool_invoker_kwargs,
|
|
160
|
+
confirmation_strategies=confirmation_strategies,
|
|
180
161
|
)
|
|
181
|
-
self._confirmation_strategies = confirmation_strategies or {}
|
|
182
162
|
self._chat_message_store = chat_message_store
|
|
183
163
|
self._chat_message_retriever = (
|
|
184
164
|
ChatMessageRetriever(chat_message_store=chat_message_store) if chat_message_store else None
|
|
@@ -186,6 +166,7 @@ class Agent(HaystackAgent):
|
|
|
186
166
|
self._chat_message_writer = (
|
|
187
167
|
ChatMessageWriter(chat_message_store=chat_message_store) if chat_message_store else None
|
|
188
168
|
)
|
|
169
|
+
self._memory_store = memory_store
|
|
189
170
|
|
|
190
171
|
def _initialize_fresh_execution(
|
|
191
172
|
self,
|
|
@@ -198,6 +179,7 @@ class Agent(HaystackAgent):
|
|
|
198
179
|
tools: ToolsType | list[str] | None = None,
|
|
199
180
|
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
200
181
|
chat_message_store_kwargs: dict[str, Any] | None = None,
|
|
182
|
+
memory_store_kwargs: dict[str, Any] | None = None,
|
|
201
183
|
**kwargs: dict[str, Any],
|
|
202
184
|
) -> _ExecutionContext:
|
|
203
185
|
"""
|
|
@@ -209,56 +191,67 @@ class Agent(HaystackAgent):
|
|
|
209
191
|
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
|
|
210
192
|
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
211
193
|
When passing tool names, tools are selected from the Agent's originally configured tools.
|
|
194
|
+
|
|
195
|
+
:param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
|
|
196
|
+
For example, it can include the `user_id`, `run_id`, and `agent_id` parameters
|
|
197
|
+
for storing and retrieving memories.
|
|
212
198
|
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
213
199
|
to confirmation strategies.
|
|
214
200
|
:param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
|
|
201
|
+
For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
|
|
215
202
|
:param kwargs: Additional data to pass to the State used by the Agent.
|
|
216
203
|
"""
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
204
|
+
exe_context = super(Agent, self)._initialize_fresh_execution(
|
|
205
|
+
messages=messages,
|
|
206
|
+
streaming_callback=streaming_callback,
|
|
207
|
+
requires_async=requires_async,
|
|
208
|
+
system_prompt=system_prompt,
|
|
209
|
+
generation_kwargs=generation_kwargs,
|
|
210
|
+
tools=tools,
|
|
211
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
212
|
+
chat_message_store_kwargs=chat_message_store_kwargs,
|
|
213
|
+
**kwargs,
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
# NOTE: difference with parent method to add memory retrieval
|
|
217
|
+
if self._memory_store:
|
|
218
|
+
retrieved_memories = self._memory_store.search_memories(
|
|
219
|
+
query=messages[-1].text, **memory_store_kwargs if memory_store_kwargs else {}
|
|
220
|
+
)
|
|
221
|
+
# we combine the memories into a single string
|
|
222
|
+
combined_memory = "\n".join(
|
|
223
|
+
f"- MEMORY #{idx + 1}: {memory.text}" for idx, memory in enumerate(retrieved_memories)
|
|
224
|
+
)
|
|
225
|
+
retrieved_memory = ChatMessage.from_system(text=combined_memory)
|
|
226
|
+
memory_instruction = (
|
|
227
|
+
"\n\nWhen messages start with `[MEMORY]`, treat them as long-term context and use them to guide the "
|
|
228
|
+
"response if relevant."
|
|
229
|
+
)
|
|
230
|
+
new_system_message = ChatMessage.from_system(text=f"{system_prompt}{memory_instruction}")
|
|
231
|
+
memory_system_message = ChatMessage.from_system(
|
|
232
|
+
text=f"Here are the relevant memories for the user's query: {retrieved_memory.text}"
|
|
233
|
+
)
|
|
234
|
+
new_chat_history = [new_system_message] + messages + [memory_system_message]
|
|
235
|
+
# We replace the messages in state with the new chat history including memories
|
|
236
|
+
exe_context.state.set("messages", new_chat_history, handler_override=replace_values)
|
|
220
237
|
|
|
221
238
|
# NOTE: difference with parent method to add chat message retrieval
|
|
222
239
|
if self._chat_message_retriever:
|
|
223
240
|
retriever_kwargs = _select_kwargs(self._chat_message_retriever, chat_message_store_kwargs or {})
|
|
224
241
|
if "chat_history_id" in retriever_kwargs:
|
|
225
|
-
|
|
226
|
-
current_messages=messages,
|
|
242
|
+
updated_messages = self._chat_message_retriever.run(
|
|
243
|
+
current_messages=exe_context.state.get("messages", []),
|
|
227
244
|
**retriever_kwargs,
|
|
228
245
|
)["messages"]
|
|
246
|
+
# We replace the messages in state with the updated messages including chat history
|
|
247
|
+
exe_context.state.set("messages", updated_messages, handler_override=replace_values)
|
|
229
248
|
|
|
230
|
-
if all(m.is_from(ChatRole.SYSTEM) for m in messages):
|
|
231
|
-
logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")
|
|
232
|
-
|
|
233
|
-
state = State(schema=self.state_schema, data=kwargs)
|
|
234
|
-
state.set("messages", messages)
|
|
235
|
-
|
|
236
|
-
streaming_callback = select_streaming_callback( # type: ignore[call-overload]
|
|
237
|
-
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
|
|
238
|
-
)
|
|
239
|
-
|
|
240
|
-
selected_tools = self._select_tools(tools)
|
|
241
|
-
tool_invoker_inputs: dict[str, Any] = {"tools": selected_tools}
|
|
242
|
-
generator_inputs: dict[str, Any] = {"tools": selected_tools}
|
|
243
|
-
if streaming_callback is not None:
|
|
244
|
-
tool_invoker_inputs["streaming_callback"] = streaming_callback
|
|
245
|
-
generator_inputs["streaming_callback"] = streaming_callback
|
|
246
|
-
if generation_kwargs is not None:
|
|
247
|
-
generator_inputs["generation_kwargs"] = generation_kwargs
|
|
248
|
-
|
|
249
|
-
# NOTE: difference with parent method to add this to tool_invoker_inputs
|
|
250
|
-
if self._tool_invoker:
|
|
251
|
-
tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
|
|
252
|
-
self._tool_invoker.enable_streaming_callback_passthrough
|
|
253
|
-
)
|
|
254
|
-
|
|
255
|
-
# NOTE: difference is to use the extended _ExecutionContext with confirmation_strategy_context
|
|
256
249
|
return _ExecutionContext(
|
|
257
|
-
state=state,
|
|
258
|
-
component_visits=
|
|
259
|
-
chat_generator_inputs=
|
|
260
|
-
tool_invoker_inputs=tool_invoker_inputs,
|
|
261
|
-
confirmation_strategy_context=confirmation_strategy_context,
|
|
250
|
+
state=exe_context.state,
|
|
251
|
+
component_visits=exe_context.component_visits,
|
|
252
|
+
chat_generator_inputs=exe_context.chat_generator_inputs,
|
|
253
|
+
tool_invoker_inputs=exe_context.tool_invoker_inputs,
|
|
254
|
+
confirmation_strategy_context=exe_context.confirmation_strategy_context,
|
|
262
255
|
)
|
|
263
256
|
|
|
264
257
|
def _initialize_from_snapshot( # type: ignore[override]
|
|
@@ -284,28 +277,15 @@ class Agent(HaystackAgent):
|
|
|
284
277
|
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
285
278
|
to confirmation strategies.
|
|
286
279
|
"""
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
tools=tools,
|
|
297
|
-
)
|
|
298
|
-
else:
|
|
299
|
-
exe_context = super(Agent, self)._initialize_from_snapshot(
|
|
300
|
-
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools
|
|
301
|
-
)
|
|
302
|
-
# NOTE: 1st difference with parent method to add this to tool_invoker_inputs
|
|
303
|
-
if self._tool_invoker:
|
|
304
|
-
exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
|
|
305
|
-
self._tool_invoker.enable_streaming_callback_passthrough
|
|
306
|
-
)
|
|
307
|
-
# NOTE: 2nd difference is to use the extended _ExecutionContext
|
|
308
|
-
# and add tool_execution_decisions + confirmation_strategy_context
|
|
280
|
+
exe_context = super(Agent, self)._initialize_from_snapshot(
|
|
281
|
+
snapshot=snapshot,
|
|
282
|
+
streaming_callback=streaming_callback,
|
|
283
|
+
requires_async=requires_async,
|
|
284
|
+
generation_kwargs=generation_kwargs,
|
|
285
|
+
tools=tools,
|
|
286
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
287
|
+
)
|
|
288
|
+
# NOTE: Only difference is to use pass tool_execution_decisions to _ExecutionContext
|
|
309
289
|
return _ExecutionContext(
|
|
310
290
|
state=exe_context.state,
|
|
311
291
|
component_visits=exe_context.component_visits,
|
|
@@ -313,8 +293,8 @@ class Agent(HaystackAgent):
|
|
|
313
293
|
tool_invoker_inputs=exe_context.tool_invoker_inputs,
|
|
314
294
|
counter=exe_context.counter,
|
|
315
295
|
skip_chat_generator=exe_context.skip_chat_generator,
|
|
296
|
+
confirmation_strategy_context=exe_context.confirmation_strategy_context,
|
|
316
297
|
tool_execution_decisions=snapshot.tool_execution_decisions,
|
|
317
|
-
confirmation_strategy_context=confirmation_strategy_context,
|
|
318
298
|
)
|
|
319
299
|
|
|
320
300
|
def run( # type: ignore[override] # noqa: PLR0915 PLR0912
|
|
@@ -329,6 +309,7 @@ class Agent(HaystackAgent):
|
|
|
329
309
|
tools: ToolsType | list[str] | None = None,
|
|
330
310
|
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
331
311
|
chat_message_store_kwargs: dict[str, Any] | None = None,
|
|
312
|
+
memory_store_kwargs: dict[str, Any] | None = None,
|
|
332
313
|
**kwargs: Any,
|
|
333
314
|
) -> dict[str, Any]:
|
|
334
315
|
"""
|
|
@@ -352,6 +333,19 @@ class Agent(HaystackAgent):
|
|
|
352
333
|
can use for non-blocking user interaction.
|
|
353
334
|
:param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
|
|
354
335
|
For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
|
|
336
|
+
:param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
|
|
337
|
+
It can include:
|
|
338
|
+
- `user_id`: The user ID to search and add memories from.
|
|
339
|
+
- `run_id`: The run ID to search and add memories from.
|
|
340
|
+
- `agent_id`: The agent ID to search and add memories from.
|
|
341
|
+
- `search_criteria`: A dictionary of containing kwargs for the `search_memories` method.
|
|
342
|
+
This can include:
|
|
343
|
+
- `filters`: A dictionary of filters to search for memories.
|
|
344
|
+
- `query`: The query to search for memories.
|
|
345
|
+
Note: If you pass this, the user query passed to the agent will be
|
|
346
|
+
ignored for memory retrieval.
|
|
347
|
+
- `top_k`: The number of memories to return.
|
|
348
|
+
- `include_memory_metadata`: Whether to include the memory metadata in the ChatMessage.
|
|
355
349
|
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
|
356
350
|
The keys must match the schema defined in the Agent's `state_schema`.
|
|
357
351
|
:returns:
|
|
@@ -362,6 +356,8 @@ class Agent(HaystackAgent):
|
|
|
362
356
|
:raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
|
|
363
357
|
:raises BreakpointException: If an agent breakpoint is triggered.
|
|
364
358
|
"""
|
|
359
|
+
memory_store_kwargs = memory_store_kwargs or {}
|
|
360
|
+
|
|
365
361
|
agent_inputs = {
|
|
366
362
|
"messages": messages,
|
|
367
363
|
"streaming_callback": streaming_callback,
|
|
@@ -392,6 +388,7 @@ class Agent(HaystackAgent):
|
|
|
392
388
|
tools=tools,
|
|
393
389
|
confirmation_strategy_context=confirmation_strategy_context,
|
|
394
390
|
chat_message_store_kwargs=chat_message_store_kwargs,
|
|
391
|
+
memory_store_kwargs=memory_store_kwargs,
|
|
395
392
|
**kwargs,
|
|
396
393
|
)
|
|
397
394
|
|
|
@@ -458,8 +455,9 @@ class Agent(HaystackAgent):
|
|
|
458
455
|
resolved_break_point = break_point
|
|
459
456
|
break_point_to_pass = resolved_break_point.break_point
|
|
460
457
|
|
|
461
|
-
#
|
|
458
|
+
# NOTE: difference with parent method to add support HITLBreakpointException
|
|
462
459
|
try:
|
|
460
|
+
# Apply confirmation strategies and update State and messages sent to ToolInvoker
|
|
463
461
|
# Run confirmation strategies to get updated tool call messages and modified chat history
|
|
464
462
|
modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
|
|
465
463
|
confirmation_strategies=self._confirmation_strategies,
|
|
@@ -547,6 +545,11 @@ class Agent(HaystackAgent):
|
|
|
547
545
|
if msgs := result.get("messages"):
|
|
548
546
|
result["last_message"] = msgs[-1]
|
|
549
547
|
|
|
548
|
+
# Add the new conversation as memories to the memory store
|
|
549
|
+
if self._memory_store:
|
|
550
|
+
new_memories = [message for message in msgs if message.role.value != "system"]
|
|
551
|
+
self._memory_store.add_memories(messages=new_memories, **memory_store_kwargs)
|
|
552
|
+
|
|
550
553
|
# Write messages to ChatMessageStore if configured
|
|
551
554
|
if self._chat_message_writer:
|
|
552
555
|
writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
|
|
@@ -567,6 +570,7 @@ class Agent(HaystackAgent):
|
|
|
567
570
|
tools: ToolsType | list[str] | None = None,
|
|
568
571
|
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
569
572
|
chat_message_store_kwargs: dict[str, Any] | None = None,
|
|
573
|
+
memory_store_kwargs: dict[str, Any] | None = None,
|
|
570
574
|
**kwargs: Any,
|
|
571
575
|
) -> dict[str, Any]:
|
|
572
576
|
"""
|
|
@@ -593,6 +597,20 @@ class Agent(HaystackAgent):
|
|
|
593
597
|
can use for non-blocking user interaction.
|
|
594
598
|
:param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
|
|
595
599
|
For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
|
|
600
|
+
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
|
601
|
+
:param memory_store_kwargs: Optional dictionary of keyword arguments to pass to the MemoryStore.
|
|
602
|
+
It can include:
|
|
603
|
+
- `user_id`: The user ID to search and add memories from.
|
|
604
|
+
- `run_id`: The run ID to search and add memories from.
|
|
605
|
+
- `agent_id`: The agent ID to search and add memories from.
|
|
606
|
+
- `search_criteria`: A dictionary of containing kwargs for the `search_memories` method.
|
|
607
|
+
This can include:
|
|
608
|
+
- `filters`: A dictionary of filters to search for memories.
|
|
609
|
+
- `query`: The query to search for memories.
|
|
610
|
+
Note: If you pass this, the user query passed to the agent will be
|
|
611
|
+
ignored for memory retrieval.
|
|
612
|
+
- `top_k`: The number of memories to return.
|
|
613
|
+
- `include_memory_metadata`: Whether to include the memory metadata in the ChatMessage.
|
|
596
614
|
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
|
597
615
|
The keys must match the schema defined in the Agent's `state_schema`.
|
|
598
616
|
:returns:
|
|
@@ -603,6 +621,8 @@ class Agent(HaystackAgent):
|
|
|
603
621
|
:raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
|
|
604
622
|
:raises BreakpointException: If an agent breakpoint is triggered.
|
|
605
623
|
"""
|
|
624
|
+
memory_store_kwargs = memory_store_kwargs or {}
|
|
625
|
+
|
|
606
626
|
agent_inputs = {
|
|
607
627
|
"messages": messages,
|
|
608
628
|
"streaming_callback": streaming_callback,
|
|
@@ -631,6 +651,7 @@ class Agent(HaystackAgent):
|
|
|
631
651
|
tools=tools,
|
|
632
652
|
confirmation_strategy_context=confirmation_strategy_context,
|
|
633
653
|
chat_message_store_kwargs=chat_message_store_kwargs,
|
|
654
|
+
memory_store_kwargs=memory_store_kwargs,
|
|
634
655
|
**kwargs,
|
|
635
656
|
)
|
|
636
657
|
|
|
@@ -692,8 +713,9 @@ class Agent(HaystackAgent):
|
|
|
692
713
|
resolved_break_point = break_point
|
|
693
714
|
break_point_to_pass = resolved_break_point.break_point
|
|
694
715
|
|
|
695
|
-
#
|
|
716
|
+
# NOTE: difference with parent method to add support HITLBreakpointException
|
|
696
717
|
try:
|
|
718
|
+
# Apply confirmation strategies and update State and messages sent to ToolInvoker
|
|
697
719
|
# Run confirmation strategies to get updated tool call messages and modified chat history (async)
|
|
698
720
|
modified_tool_call_messages, new_chat_history = await _process_confirmation_strategies_async(
|
|
699
721
|
confirmation_strategies=self._confirmation_strategies,
|
|
@@ -773,6 +795,11 @@ class Agent(HaystackAgent):
|
|
|
773
795
|
if msgs := result.get("messages"):
|
|
774
796
|
result["last_message"] = msgs[-1]
|
|
775
797
|
|
|
798
|
+
# Add the new conversation as memories to the memory store
|
|
799
|
+
if self._memory_store:
|
|
800
|
+
new_memories = [message for message in msgs if message.role.value != "system"]
|
|
801
|
+
self._memory_store.add_memories(messages=new_memories, **memory_store_kwargs)
|
|
802
|
+
|
|
776
803
|
# Write messages to ChatMessageStore if configured
|
|
777
804
|
if self._chat_message_writer:
|
|
778
805
|
writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
|
|
@@ -788,11 +815,7 @@ class Agent(HaystackAgent):
|
|
|
788
815
|
:return: Dictionary with serialized data
|
|
789
816
|
"""
|
|
790
817
|
data = super(Agent, self).to_dict()
|
|
791
|
-
|
|
792
|
-
{name: strategy.to_dict() for name, strategy in self._confirmation_strategies.items()}
|
|
793
|
-
if self._confirmation_strategies
|
|
794
|
-
else None
|
|
795
|
-
)
|
|
818
|
+
# NOTE: This is different from the base Agent class to handle ChatMessageStore serialization
|
|
796
819
|
data["init_parameters"]["chat_message_store"] = (
|
|
797
820
|
self._chat_message_store.to_dict() if self._chat_message_store is not None else None
|
|
798
821
|
)
|
|
@@ -808,9 +831,9 @@ class Agent(HaystackAgent):
|
|
|
808
831
|
"""
|
|
809
832
|
init_params = data.get("init_parameters", {})
|
|
810
833
|
|
|
811
|
-
|
|
834
|
+
deserialize_component_inplace(init_params, key="chat_generator")
|
|
812
835
|
|
|
813
|
-
if "state_schema"
|
|
836
|
+
if init_params.get("state_schema") is not None:
|
|
814
837
|
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
|
|
815
838
|
|
|
816
839
|
if init_params.get("streaming_callback") is not None:
|
|
@@ -819,24 +842,12 @@ class Agent(HaystackAgent):
|
|
|
819
842
|
deserialize_tools_or_toolset_inplace(init_params, key="tools")
|
|
820
843
|
|
|
821
844
|
if "confirmation_strategies" in init_params and init_params["confirmation_strategies"] is not None:
|
|
822
|
-
for name
|
|
823
|
-
|
|
824
|
-
strategy_class = import_class_by_name(strategy_dict["type"])
|
|
825
|
-
except ImportError as e:
|
|
826
|
-
raise DeserializationError(f"Class '{strategy_dict['type']}' not correctly imported") from e
|
|
827
|
-
if not hasattr(strategy_class, "from_dict"):
|
|
828
|
-
raise DeserializationError(f"{strategy_class} does not have from_dict method implemented.")
|
|
829
|
-
init_params["confirmation_strategies"][name] = strategy_class.from_dict(strategy_dict)
|
|
845
|
+
for name in init_params["confirmation_strategies"]:
|
|
846
|
+
deserialize_component_inplace(init_params["confirmation_strategies"], key=name)
|
|
830
847
|
|
|
848
|
+
# NOTE: This is different from the base Agent class to handle ChatMessageStore deserialization
|
|
831
849
|
if "chat_message_store" in init_params and init_params["chat_message_store"] is not None:
|
|
832
|
-
|
|
833
|
-
try:
|
|
834
|
-
cms_class = import_class_by_name(cms_data["type"])
|
|
835
|
-
except ImportError as e:
|
|
836
|
-
raise DeserializationError(f"Class '{cms_data['type']}' not correctly imported") from e
|
|
837
|
-
if not hasattr(cms_class, "from_dict"):
|
|
838
|
-
raise DeserializationError(f"{cms_class} does not have from_dict method implemented.")
|
|
839
|
-
init_params["chat_message_store"] = cms_class.from_dict(cms_data)
|
|
850
|
+
deserialize_component_inplace(init_params, key="chat_message_store")
|
|
840
851
|
|
|
841
852
|
return default_from_dict(cls, data)
|
|
842
853
|
|
|
@@ -8,28 +8,15 @@ from typing import TYPE_CHECKING
|
|
|
8
8
|
from lazy_imports import LazyImporter
|
|
9
9
|
|
|
10
10
|
_import_structure = {
|
|
11
|
-
"dataclasses": ["
|
|
11
|
+
"dataclasses": ["ToolExecutionDecision"],
|
|
12
12
|
"errors": ["HITLBreakpointException"],
|
|
13
|
-
"
|
|
14
|
-
"strategies": ["BlockingConfirmationStrategy", "BreakpointConfirmationStrategy"],
|
|
15
|
-
"types": ["ConfirmationPolicy", "ConfirmationUI", "ConfirmationStrategy"],
|
|
16
|
-
"user_interfaces": ["RichConsoleUI", "SimpleConsoleUI"],
|
|
13
|
+
"strategies": ["BreakpointConfirmationStrategy"],
|
|
17
14
|
}
|
|
18
15
|
|
|
19
16
|
if TYPE_CHECKING:
|
|
20
|
-
from .dataclasses import ConfirmationUIResult as ConfirmationUIResult
|
|
21
17
|
from .dataclasses import ToolExecutionDecision as ToolExecutionDecision
|
|
22
18
|
from .errors import HITLBreakpointException as HITLBreakpointException
|
|
23
|
-
from .policies import AlwaysAskPolicy as AlwaysAskPolicy
|
|
24
|
-
from .policies import AskOncePolicy as AskOncePolicy
|
|
25
|
-
from .policies import NeverAskPolicy as NeverAskPolicy
|
|
26
|
-
from .strategies import BlockingConfirmationStrategy as BlockingConfirmationStrategy
|
|
27
19
|
from .strategies import BreakpointConfirmationStrategy as BreakpointConfirmationStrategy
|
|
28
|
-
from .types import ConfirmationPolicy as ConfirmationPolicy
|
|
29
|
-
from .types import ConfirmationStrategy as ConfirmationStrategy
|
|
30
|
-
from .types import ConfirmationUI as ConfirmationUI
|
|
31
|
-
from .user_interfaces import RichConsoleUI as RichConsoleUI
|
|
32
|
-
from .user_interfaces import SimpleConsoleUI as SimpleConsoleUI
|
|
33
20
|
|
|
34
21
|
else:
|
|
35
22
|
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
|
|
@@ -5,10 +5,9 @@
|
|
|
5
5
|
from copy import deepcopy
|
|
6
6
|
|
|
7
7
|
from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint
|
|
8
|
+
from haystack.human_in_the_loop.strategies import _prepare_tool_args
|
|
8
9
|
from haystack.utils import _deserialize_value_with_schema
|
|
9
10
|
|
|
10
|
-
from haystack_experimental.components.agents.human_in_the_loop.strategies import _prepare_tool_args
|
|
11
|
-
|
|
12
11
|
|
|
13
12
|
def get_tool_calls_and_descriptions_from_snapshot(
|
|
14
13
|
agent_snapshot: AgentSnapshot, breakpoint_tool_only: bool = True
|