haystack-experimental 0.14.3__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/chat_message_stores/__init__.py +1 -1
- haystack_experimental/chat_message_stores/in_memory.py +176 -31
- haystack_experimental/chat_message_stores/types.py +33 -21
- haystack_experimental/components/agents/agent.py +147 -44
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +220 -3
- haystack_experimental/components/agents/human_in_the_loop/types.py +36 -1
- haystack_experimental/components/embedders/types/protocol.py +2 -2
- haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +16 -16
- haystack_experimental/components/retrievers/__init__.py +1 -3
- haystack_experimental/components/retrievers/chat_message_retriever.py +57 -26
- haystack_experimental/components/writers/__init__.py +1 -1
- haystack_experimental/components/writers/chat_message_writer.py +25 -22
- {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.0.dist-info}/METADATA +24 -31
- {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.0.dist-info}/RECORD +17 -24
- {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.0.dist-info}/WHEEL +1 -1
- haystack_experimental/components/query/__init__.py +0 -18
- haystack_experimental/components/query/query_expander.py +0 -294
- haystack_experimental/components/retrievers/multi_query_embedding_retriever.py +0 -173
- haystack_experimental/components/retrievers/multi_query_text_retriever.py +0 -150
- haystack_experimental/super_components/__init__.py +0 -3
- haystack_experimental/super_components/indexers/__init__.py +0 -11
- haystack_experimental/super_components/indexers/sentence_transformers_document_indexer.py +0 -199
- {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
@@ -22,11 +22,11 @@ import haystack_experimental.core.pipeline.breakpoint as exp_breakpoint
|
|
|
22
22
|
hs_breakpoint._create_agent_snapshot = exp_breakpoint._create_agent_snapshot
|
|
23
23
|
hs_breakpoint._create_pipeline_snapshot_from_tool_invoker = exp_breakpoint._create_pipeline_snapshot_from_tool_invoker # type: ignore[assignment]
|
|
24
24
|
|
|
25
|
-
from haystack import logging
|
|
25
|
+
from haystack import DeserializationError, logging
|
|
26
26
|
from haystack.components.agents.agent import Agent as HaystackAgent
|
|
27
27
|
from haystack.components.agents.agent import _ExecutionContext as Haystack_ExecutionContext
|
|
28
28
|
from haystack.components.agents.agent import _schema_from_dict
|
|
29
|
-
from haystack.components.agents.state import replace_values
|
|
29
|
+
from haystack.components.agents.state import replace_values, State
|
|
30
30
|
from haystack.components.generators.chat.types import ChatGenerator
|
|
31
31
|
from haystack.core.errors import PipelineRuntimeError
|
|
32
32
|
from haystack.core.pipeline import AsyncPipeline, Pipeline
|
|
@@ -36,19 +36,25 @@ from haystack.core.pipeline.breakpoint import (
|
|
|
36
36
|
)
|
|
37
37
|
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
|
|
38
38
|
from haystack.core.serialization import default_from_dict, import_class_by_name
|
|
39
|
-
from haystack.dataclasses import ChatMessage
|
|
39
|
+
from haystack.dataclasses import ChatMessage, ChatRole
|
|
40
40
|
from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint
|
|
41
|
-
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
|
|
41
|
+
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
|
|
42
42
|
from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace
|
|
43
43
|
from haystack.utils.callable_serialization import deserialize_callable
|
|
44
44
|
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
|
|
45
45
|
|
|
46
|
+
from haystack_experimental.chat_message_stores.types import ChatMessageStore
|
|
46
47
|
from haystack_experimental.components.agents.human_in_the_loop import (
|
|
47
48
|
ConfirmationStrategy,
|
|
48
49
|
ToolExecutionDecision,
|
|
49
50
|
HITLBreakpointException,
|
|
50
51
|
)
|
|
51
|
-
from haystack_experimental.components.agents.human_in_the_loop.strategies import
|
|
52
|
+
from haystack_experimental.components.agents.human_in_the_loop.strategies import (
|
|
53
|
+
_process_confirmation_strategies,
|
|
54
|
+
_process_confirmation_strategies_async,
|
|
55
|
+
)
|
|
56
|
+
from haystack_experimental.components.retrievers import ChatMessageRetriever
|
|
57
|
+
from haystack_experimental.components.writers import ChatMessageWriter
|
|
52
58
|
|
|
53
59
|
logger = logging.getLogger(__name__)
|
|
54
60
|
|
|
@@ -62,9 +68,15 @@ class _ExecutionContext(Haystack_ExecutionContext):
|
|
|
62
68
|
|
|
63
69
|
:param tool_execution_decisions: Optional list of ToolExecutionDecision objects to use instead of prompting
|
|
64
70
|
the user. This is useful when restarting from a snapshot where tool execution decisions were already made.
|
|
71
|
+
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
72
|
+
to confirmation strategies. In web/server environments, this enables passing per-request
|
|
73
|
+
objects (e.g., WebSocket connections, async queues, or pub/sub clients) that strategies can use for
|
|
74
|
+
non-blocking user interaction. This is passed directly to strategies via the `confirmation_strategy_context`
|
|
75
|
+
parameter in their `run()` and `run_async()` methods.
|
|
65
76
|
"""
|
|
66
77
|
|
|
67
78
|
tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
|
|
79
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None
|
|
68
80
|
|
|
69
81
|
|
|
70
82
|
class Agent(HaystackAgent):
|
|
@@ -131,6 +143,7 @@ class Agent(HaystackAgent):
|
|
|
131
143
|
raise_on_tool_invocation_failure: bool = False,
|
|
132
144
|
confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None,
|
|
133
145
|
tool_invoker_kwargs: Optional[dict[str, Any]] = None,
|
|
146
|
+
chat_message_store: Optional[ChatMessageStore] = None,
|
|
134
147
|
) -> None:
|
|
135
148
|
"""
|
|
136
149
|
Initialize the agent component.
|
|
@@ -164,6 +177,13 @@ class Agent(HaystackAgent):
|
|
|
164
177
|
tool_invoker_kwargs=tool_invoker_kwargs,
|
|
165
178
|
)
|
|
166
179
|
self._confirmation_strategies = confirmation_strategies or {}
|
|
180
|
+
self._chat_message_store = chat_message_store
|
|
181
|
+
self._chat_message_retriever = (
|
|
182
|
+
ChatMessageRetriever(chat_message_store=chat_message_store) if chat_message_store else None
|
|
183
|
+
)
|
|
184
|
+
self._chat_message_writer = (
|
|
185
|
+
ChatMessageWriter(chat_message_store=chat_message_store) if chat_message_store else None
|
|
186
|
+
)
|
|
167
187
|
|
|
168
188
|
def _initialize_fresh_execution(
|
|
169
189
|
self,
|
|
@@ -174,6 +194,8 @@ class Agent(HaystackAgent):
|
|
|
174
194
|
system_prompt: Optional[str] = None,
|
|
175
195
|
generation_kwargs: Optional[dict[str, Any]] = None,
|
|
176
196
|
tools: Optional[Union[ToolsType, list[str]]] = None,
|
|
197
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
198
|
+
chat_message_store_kwargs: Optional[dict[str, Any]] = None,
|
|
177
199
|
**kwargs: dict[str, Any],
|
|
178
200
|
) -> _ExecutionContext:
|
|
179
201
|
"""
|
|
@@ -185,41 +207,56 @@ class Agent(HaystackAgent):
|
|
|
185
207
|
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
|
|
186
208
|
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
187
209
|
When passing tool names, tools are selected from the Agent's originally configured tools.
|
|
210
|
+
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
211
|
+
to confirmation strategies.
|
|
212
|
+
:param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
|
|
188
213
|
:param kwargs: Additional data to pass to the State used by the Agent.
|
|
189
214
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
215
|
+
system_prompt = system_prompt or self.system_prompt
|
|
216
|
+
if system_prompt is not None:
|
|
217
|
+
messages = [ChatMessage.from_system(system_prompt)] + messages
|
|
218
|
+
|
|
219
|
+
# NOTE: difference with parent method to add chat message retrieval
|
|
220
|
+
if self._chat_message_retriever:
|
|
221
|
+
retriever_kwargs = _select_kwargs(self._chat_message_retriever, chat_message_store_kwargs or {})
|
|
222
|
+
if "chat_history_id" in retriever_kwargs:
|
|
223
|
+
messages = self._chat_message_retriever.run(
|
|
224
|
+
current_messages=messages,
|
|
225
|
+
**retriever_kwargs,
|
|
226
|
+
)["messages"]
|
|
227
|
+
|
|
228
|
+
if all(m.is_from(ChatRole.SYSTEM) for m in messages):
|
|
229
|
+
logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")
|
|
230
|
+
|
|
231
|
+
state = State(schema=self.state_schema, data=kwargs)
|
|
232
|
+
state.set("messages", messages)
|
|
233
|
+
|
|
234
|
+
streaming_callback = select_streaming_callback( # type: ignore[call-overload]
|
|
235
|
+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
selected_tools = self._select_tools(tools)
|
|
239
|
+
tool_invoker_inputs: dict[str, Any] = {"tools": selected_tools}
|
|
240
|
+
generator_inputs: dict[str, Any] = {"tools": selected_tools}
|
|
241
|
+
if streaming_callback is not None:
|
|
242
|
+
tool_invoker_inputs["streaming_callback"] = streaming_callback
|
|
243
|
+
generator_inputs["streaming_callback"] = streaming_callback
|
|
244
|
+
if generation_kwargs is not None:
|
|
245
|
+
generator_inputs["generation_kwargs"] = generation_kwargs
|
|
246
|
+
|
|
247
|
+
# NOTE: difference with parent method to add this to tool_invoker_inputs
|
|
213
248
|
if self._tool_invoker:
|
|
214
|
-
|
|
249
|
+
tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
|
|
215
250
|
self._tool_invoker.enable_streaming_callback_passthrough
|
|
216
251
|
)
|
|
217
|
-
|
|
252
|
+
|
|
253
|
+
# NOTE: difference is to use the extended _ExecutionContext with confirmation_strategy_context
|
|
218
254
|
return _ExecutionContext(
|
|
219
|
-
state=
|
|
220
|
-
component_visits=
|
|
221
|
-
chat_generator_inputs=
|
|
222
|
-
tool_invoker_inputs=
|
|
255
|
+
state=state,
|
|
256
|
+
component_visits=dict.fromkeys(["chat_generator", "tool_invoker"], 0),
|
|
257
|
+
chat_generator_inputs=generator_inputs,
|
|
258
|
+
tool_invoker_inputs=tool_invoker_inputs,
|
|
259
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
223
260
|
)
|
|
224
261
|
|
|
225
262
|
def _initialize_from_snapshot( # type: ignore[override]
|
|
@@ -230,6 +267,7 @@ class Agent(HaystackAgent):
|
|
|
230
267
|
*,
|
|
231
268
|
generation_kwargs: Optional[dict[str, Any]] = None,
|
|
232
269
|
tools: Optional[Union[ToolsType, list[str]]] = None,
|
|
270
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
233
271
|
) -> _ExecutionContext:
|
|
234
272
|
"""
|
|
235
273
|
Initialize execution context from an AgentSnapshot.
|
|
@@ -241,12 +279,14 @@ class Agent(HaystackAgent):
|
|
|
241
279
|
override the parameters passed during component initialization.
|
|
242
280
|
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
243
281
|
When passing tool names, tools are selected from the Agent's originally configured tools.
|
|
282
|
+
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
283
|
+
to confirmation strategies.
|
|
244
284
|
"""
|
|
245
285
|
# The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to
|
|
246
286
|
# _initialize_from_snapshot. This change has been released in Haystack 2.20.0.
|
|
247
287
|
# To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
|
|
248
288
|
if inspect.signature(super(Agent, self)._initialize_from_snapshot).parameters.get("generation_kwargs"):
|
|
249
|
-
exe_context = super(Agent, self)._initialize_from_snapshot(
|
|
289
|
+
exe_context = super(Agent, self)._initialize_from_snapshot( # type: ignore[call-arg]
|
|
250
290
|
snapshot=snapshot,
|
|
251
291
|
streaming_callback=streaming_callback,
|
|
252
292
|
requires_async=requires_async,
|
|
@@ -262,7 +302,8 @@ class Agent(HaystackAgent):
|
|
|
262
302
|
exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
|
|
263
303
|
self._tool_invoker.enable_streaming_callback_passthrough
|
|
264
304
|
)
|
|
265
|
-
# NOTE: 2nd difference is to use the extended _ExecutionContext
|
|
305
|
+
# NOTE: 2nd difference is to use the extended _ExecutionContext
|
|
306
|
+
# and add tool_execution_decisions + confirmation_strategy_context
|
|
266
307
|
return _ExecutionContext(
|
|
267
308
|
state=exe_context.state,
|
|
268
309
|
component_visits=exe_context.component_visits,
|
|
@@ -271,18 +312,21 @@ class Agent(HaystackAgent):
|
|
|
271
312
|
counter=exe_context.counter,
|
|
272
313
|
skip_chat_generator=exe_context.skip_chat_generator,
|
|
273
314
|
tool_execution_decisions=snapshot.tool_execution_decisions,
|
|
315
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
274
316
|
)
|
|
275
317
|
|
|
276
|
-
def run( # noqa: PLR0915
|
|
318
|
+
def run( # type: ignore[override] # noqa: PLR0915
|
|
277
319
|
self,
|
|
278
320
|
messages: list[ChatMessage],
|
|
279
321
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
280
322
|
*,
|
|
281
323
|
generation_kwargs: Optional[dict[str, Any]] = None,
|
|
282
324
|
break_point: Optional[AgentBreakpoint] = None,
|
|
283
|
-
snapshot: Optional[AgentSnapshot] = None,
|
|
325
|
+
snapshot: Optional[AgentSnapshot] = None,
|
|
284
326
|
system_prompt: Optional[str] = None,
|
|
285
327
|
tools: Optional[Union[ToolsType, list[str]]] = None,
|
|
328
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
329
|
+
chat_message_store_kwargs: Optional[dict[str, Any]] = None,
|
|
286
330
|
**kwargs: Any,
|
|
287
331
|
) -> dict[str, Any]:
|
|
288
332
|
"""
|
|
@@ -300,6 +344,12 @@ class Agent(HaystackAgent):
|
|
|
300
344
|
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
|
|
301
345
|
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
302
346
|
When passing tool names, tools are selected from the Agent's originally configured tools.
|
|
347
|
+
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
348
|
+
to confirmation strategies. Useful in web/server environments to provide per-request
|
|
349
|
+
objects (e.g., WebSocket connections, async queues, Redis pub/sub clients) that strategies
|
|
350
|
+
can use for non-blocking user interaction.
|
|
351
|
+
:param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
|
|
352
|
+
For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
|
|
303
353
|
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
|
304
354
|
The keys must match the schema defined in the Agent's `state_schema`.
|
|
305
355
|
:returns:
|
|
@@ -334,6 +384,7 @@ class Agent(HaystackAgent):
|
|
|
334
384
|
requires_async=False,
|
|
335
385
|
generation_kwargs=generation_kwargs,
|
|
336
386
|
tools=tools,
|
|
387
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
337
388
|
)
|
|
338
389
|
else:
|
|
339
390
|
exe_context = self._initialize_fresh_execution(
|
|
@@ -343,6 +394,8 @@ class Agent(HaystackAgent):
|
|
|
343
394
|
system_prompt=system_prompt,
|
|
344
395
|
generation_kwargs=generation_kwargs,
|
|
345
396
|
tools=tools,
|
|
397
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
398
|
+
chat_message_store_kwargs=chat_message_store_kwargs,
|
|
346
399
|
**kwargs,
|
|
347
400
|
)
|
|
348
401
|
|
|
@@ -469,18 +522,27 @@ class Agent(HaystackAgent):
|
|
|
469
522
|
result = {**exe_context.state.data}
|
|
470
523
|
if msgs := result.get("messages"):
|
|
471
524
|
result["last_message"] = msgs[-1]
|
|
525
|
+
|
|
526
|
+
# Write messages to ChatMessageStore if configured
|
|
527
|
+
if self._chat_message_writer:
|
|
528
|
+
writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
|
|
529
|
+
if "chat_history_id" in writer_kwargs:
|
|
530
|
+
self._chat_message_writer.run(messages=result["messages"], **writer_kwargs)
|
|
531
|
+
|
|
472
532
|
return result
|
|
473
533
|
|
|
474
|
-
async def run_async(
|
|
534
|
+
async def run_async( # type: ignore[override]
|
|
475
535
|
self,
|
|
476
536
|
messages: list[ChatMessage],
|
|
477
537
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
478
538
|
*,
|
|
479
539
|
generation_kwargs: Optional[dict[str, Any]] = None,
|
|
480
540
|
break_point: Optional[AgentBreakpoint] = None,
|
|
481
|
-
snapshot: Optional[AgentSnapshot] = None,
|
|
541
|
+
snapshot: Optional[AgentSnapshot] = None,
|
|
482
542
|
system_prompt: Optional[str] = None,
|
|
483
543
|
tools: Optional[Union[ToolsType, list[str]]] = None,
|
|
544
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
545
|
+
chat_message_store_kwargs: Optional[dict[str, Any]] = None,
|
|
484
546
|
**kwargs: Any,
|
|
485
547
|
) -> dict[str, Any]:
|
|
486
548
|
"""
|
|
@@ -501,6 +563,12 @@ class Agent(HaystackAgent):
|
|
|
501
563
|
the relevant information to restart the Agent execution from where it left off.
|
|
502
564
|
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
|
|
503
565
|
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
|
|
566
|
+
:param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
|
|
567
|
+
to confirmation strategies. Useful in web/server environments to provide per-request
|
|
568
|
+
objects (e.g., WebSocket connections, async queues, Redis pub/sub clients) that strategies
|
|
569
|
+
can use for non-blocking user interaction.
|
|
570
|
+
:param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
|
|
571
|
+
For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
|
|
504
572
|
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
|
505
573
|
The keys must match the schema defined in the Agent's `state_schema`.
|
|
506
574
|
:returns:
|
|
@@ -535,6 +603,7 @@ class Agent(HaystackAgent):
|
|
|
535
603
|
requires_async=True,
|
|
536
604
|
generation_kwargs=generation_kwargs,
|
|
537
605
|
tools=tools,
|
|
606
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
538
607
|
)
|
|
539
608
|
else:
|
|
540
609
|
exe_context = self._initialize_fresh_execution(
|
|
@@ -544,6 +613,8 @@ class Agent(HaystackAgent):
|
|
|
544
613
|
system_prompt=system_prompt,
|
|
545
614
|
generation_kwargs=generation_kwargs,
|
|
546
615
|
tools=tools,
|
|
616
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
617
|
+
chat_message_store_kwargs=chat_message_store_kwargs,
|
|
547
618
|
**kwargs,
|
|
548
619
|
)
|
|
549
620
|
|
|
@@ -581,8 +652,8 @@ class Agent(HaystackAgent):
|
|
|
581
652
|
|
|
582
653
|
# Apply confirmation strategies and update State and messages sent to ToolInvoker
|
|
583
654
|
try:
|
|
584
|
-
# Run confirmation strategies to get updated tool call messages and modified chat history
|
|
585
|
-
modified_tool_call_messages, new_chat_history =
|
|
655
|
+
# Run confirmation strategies to get updated tool call messages and modified chat history (async)
|
|
656
|
+
modified_tool_call_messages, new_chat_history = await _process_confirmation_strategies_async(
|
|
586
657
|
confirmation_strategies=self._confirmation_strategies,
|
|
587
658
|
messages_with_tool_calls=llm_messages,
|
|
588
659
|
execution_context=exe_context,
|
|
@@ -646,6 +717,13 @@ class Agent(HaystackAgent):
|
|
|
646
717
|
result = {**exe_context.state.data}
|
|
647
718
|
if msgs := result.get("messages"):
|
|
648
719
|
result["last_message"] = msgs[-1]
|
|
720
|
+
|
|
721
|
+
# Write messages to ChatMessageStore if configured
|
|
722
|
+
if self._chat_message_writer:
|
|
723
|
+
writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
|
|
724
|
+
if "chat_history_id" in writer_kwargs:
|
|
725
|
+
self._chat_message_writer.run(messages=result["messages"], **writer_kwargs)
|
|
726
|
+
|
|
649
727
|
return result
|
|
650
728
|
|
|
651
729
|
def to_dict(self) -> dict[str, Any]:
|
|
@@ -660,6 +738,9 @@ class Agent(HaystackAgent):
|
|
|
660
738
|
if self._confirmation_strategies
|
|
661
739
|
else None
|
|
662
740
|
)
|
|
741
|
+
data["init_parameters"]["chat_message_store"] = (
|
|
742
|
+
self._chat_message_store.to_dict() if self._chat_message_store is not None else None
|
|
743
|
+
)
|
|
663
744
|
return data
|
|
664
745
|
|
|
665
746
|
@classmethod
|
|
@@ -684,9 +765,31 @@ class Agent(HaystackAgent):
|
|
|
684
765
|
|
|
685
766
|
if "confirmation_strategies" in init_params and init_params["confirmation_strategies"] is not None:
|
|
686
767
|
for name, strategy_dict in init_params["confirmation_strategies"].items():
|
|
687
|
-
|
|
768
|
+
try:
|
|
769
|
+
strategy_class = import_class_by_name(strategy_dict["type"])
|
|
770
|
+
except ImportError as e:
|
|
771
|
+
raise DeserializationError(f"Class '{strategy_dict['type']}' not correctly imported") from e
|
|
688
772
|
if not hasattr(strategy_class, "from_dict"):
|
|
689
|
-
raise
|
|
773
|
+
raise DeserializationError(f"{strategy_class} does not have from_dict method implemented.")
|
|
690
774
|
init_params["confirmation_strategies"][name] = strategy_class.from_dict(strategy_dict)
|
|
691
775
|
|
|
776
|
+
if "chat_message_store" in init_params and init_params["chat_message_store"] is not None:
|
|
777
|
+
cms_data = init_params["chat_message_store"]
|
|
778
|
+
try:
|
|
779
|
+
cms_class = import_class_by_name(cms_data["type"])
|
|
780
|
+
except ImportError as e:
|
|
781
|
+
raise DeserializationError(f"Class '{cms_data['type']}' not correctly imported") from e
|
|
782
|
+
if not hasattr(cms_class, "from_dict"):
|
|
783
|
+
raise DeserializationError(f"{cms_class} does not have from_dict method implemented.")
|
|
784
|
+
init_params["chat_message_store"] = cms_class.from_dict(cms_data)
|
|
785
|
+
|
|
692
786
|
return default_from_dict(cls, data)
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def _select_kwargs(obj: Any, source: dict) -> dict[str, Any]:
|
|
790
|
+
"""
|
|
791
|
+
Select only those key-value pairs from source dict that are valid parameters for obj.run() method.
|
|
792
|
+
"""
|
|
793
|
+
sig = inspect.signature(obj.run)
|
|
794
|
+
allowed = set(sig.parameters.keys())
|
|
795
|
+
return {k: v for k, v in source.items() if k in allowed}
|
|
@@ -47,7 +47,13 @@ class BlockingConfirmationStrategy:
|
|
|
47
47
|
self.confirmation_ui = confirmation_ui
|
|
48
48
|
|
|
49
49
|
def run(
|
|
50
|
-
self,
|
|
50
|
+
self,
|
|
51
|
+
*,
|
|
52
|
+
tool_name: str,
|
|
53
|
+
tool_description: str,
|
|
54
|
+
tool_params: dict[str, Any],
|
|
55
|
+
tool_call_id: Optional[str] = None,
|
|
56
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
51
57
|
) -> ToolExecutionDecision:
|
|
52
58
|
"""
|
|
53
59
|
Run the human-in-the-loop strategy for a given tool and its parameters.
|
|
@@ -61,6 +67,10 @@ class BlockingConfirmationStrategy:
|
|
|
61
67
|
:param tool_call_id:
|
|
62
68
|
Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
|
|
63
69
|
specific tool invocation.
|
|
70
|
+
:param confirmation_strategy_context:
|
|
71
|
+
Optional dictionary for passing request-scoped resources. Useful in web/server environments
|
|
72
|
+
to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients)
|
|
73
|
+
that strategies can use for non-blocking user interaction.
|
|
64
74
|
|
|
65
75
|
:returns:
|
|
66
76
|
A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
|
|
@@ -109,6 +119,40 @@ class BlockingConfirmationStrategy:
|
|
|
109
119
|
tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
|
|
110
120
|
)
|
|
111
121
|
|
|
122
|
+
async def run_async(
|
|
123
|
+
self,
|
|
124
|
+
*,
|
|
125
|
+
tool_name: str,
|
|
126
|
+
tool_description: str,
|
|
127
|
+
tool_params: dict[str, Any],
|
|
128
|
+
tool_call_id: Optional[str] = None,
|
|
129
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
130
|
+
) -> ToolExecutionDecision:
|
|
131
|
+
"""
|
|
132
|
+
Async version of run. Calls the sync run() method by default.
|
|
133
|
+
|
|
134
|
+
:param tool_name:
|
|
135
|
+
The name of the tool to be executed.
|
|
136
|
+
:param tool_description:
|
|
137
|
+
The description of the tool.
|
|
138
|
+
:param tool_params:
|
|
139
|
+
The parameters to be passed to the tool.
|
|
140
|
+
:param tool_call_id:
|
|
141
|
+
Optional unique identifier for the tool call.
|
|
142
|
+
:param confirmation_strategy_context:
|
|
143
|
+
Optional dictionary for passing request-scoped resources.
|
|
144
|
+
|
|
145
|
+
:returns:
|
|
146
|
+
A ToolExecutionDecision indicating whether to execute the tool with the given parameters.
|
|
147
|
+
"""
|
|
148
|
+
return self.run(
|
|
149
|
+
tool_name=tool_name,
|
|
150
|
+
tool_description=tool_description,
|
|
151
|
+
tool_params=tool_params,
|
|
152
|
+
tool_call_id=tool_call_id,
|
|
153
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
154
|
+
)
|
|
155
|
+
|
|
112
156
|
def to_dict(self) -> dict[str, Any]:
|
|
113
157
|
"""
|
|
114
158
|
Serializes the BlockingConfirmationStrategy to a dictionary.
|
|
@@ -161,7 +205,13 @@ class BreakpointConfirmationStrategy:
|
|
|
161
205
|
self.snapshot_file_path = snapshot_file_path
|
|
162
206
|
|
|
163
207
|
def run(
|
|
164
|
-
self,
|
|
208
|
+
self,
|
|
209
|
+
*,
|
|
210
|
+
tool_name: str,
|
|
211
|
+
tool_description: str,
|
|
212
|
+
tool_params: dict[str, Any],
|
|
213
|
+
tool_call_id: Optional[str] = None,
|
|
214
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
165
215
|
) -> ToolExecutionDecision:
|
|
166
216
|
"""
|
|
167
217
|
Run the breakpoint confirmation strategy for a given tool and its parameters.
|
|
@@ -175,6 +225,9 @@ class BreakpointConfirmationStrategy:
|
|
|
175
225
|
:param tool_call_id:
|
|
176
226
|
Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
|
|
177
227
|
specific tool invocation.
|
|
228
|
+
:param confirmation_strategy_context:
|
|
229
|
+
Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
|
|
230
|
+
interface compatibility.
|
|
178
231
|
|
|
179
232
|
:raises HITLBreakpointException:
|
|
180
233
|
Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
|
|
@@ -189,6 +242,43 @@ class BreakpointConfirmationStrategy:
|
|
|
189
242
|
snapshot_file_path=self.snapshot_file_path,
|
|
190
243
|
)
|
|
191
244
|
|
|
245
|
+
async def run_async(
|
|
246
|
+
self,
|
|
247
|
+
*,
|
|
248
|
+
tool_name: str,
|
|
249
|
+
tool_description: str,
|
|
250
|
+
tool_params: dict[str, Any],
|
|
251
|
+
tool_call_id: Optional[str] = None,
|
|
252
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
253
|
+
) -> ToolExecutionDecision:
|
|
254
|
+
"""
|
|
255
|
+
Async version of run. Calls the sync run() method.
|
|
256
|
+
|
|
257
|
+
:param tool_name:
|
|
258
|
+
The name of the tool to be executed.
|
|
259
|
+
:param tool_description:
|
|
260
|
+
The description of the tool.
|
|
261
|
+
:param tool_params:
|
|
262
|
+
The parameters to be passed to the tool.
|
|
263
|
+
:param tool_call_id:
|
|
264
|
+
Optional unique identifier for the tool call.
|
|
265
|
+
:param confirmation_strategy_context:
|
|
266
|
+
Optional dictionary for passing request-scoped resources.
|
|
267
|
+
|
|
268
|
+
:raises HITLBreakpointException:
|
|
269
|
+
Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
|
|
270
|
+
|
|
271
|
+
:returns:
|
|
272
|
+
This method does not return; it always raises an exception.
|
|
273
|
+
"""
|
|
274
|
+
return self.run(
|
|
275
|
+
tool_name=tool_name,
|
|
276
|
+
tool_description=tool_description,
|
|
277
|
+
tool_params=tool_params,
|
|
278
|
+
tool_call_id=tool_call_id,
|
|
279
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
280
|
+
)
|
|
281
|
+
|
|
192
282
|
def to_dict(self) -> dict[str, Any]:
|
|
193
283
|
"""
|
|
194
284
|
Serializes the BreakpointConfirmationStrategy to a dictionary.
|
|
@@ -285,6 +375,46 @@ def _process_confirmation_strategies(
|
|
|
285
375
|
return modified_tool_call_messages, new_chat_history
|
|
286
376
|
|
|
287
377
|
|
|
378
|
+
async def _process_confirmation_strategies_async(
|
|
379
|
+
*,
|
|
380
|
+
confirmation_strategies: dict[str, ConfirmationStrategy],
|
|
381
|
+
messages_with_tool_calls: list[ChatMessage],
|
|
382
|
+
execution_context: "_ExecutionContext",
|
|
383
|
+
) -> tuple[list[ChatMessage], list[ChatMessage]]:
|
|
384
|
+
"""
|
|
385
|
+
Async version of _process_confirmation_strategies.
|
|
386
|
+
|
|
387
|
+
Run the confirmation strategies and return modified tool call messages and updated chat history.
|
|
388
|
+
|
|
389
|
+
:param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
|
|
390
|
+
:param messages_with_tool_calls: Chat messages containing tool calls
|
|
391
|
+
:param execution_context: The current execution context of the agent
|
|
392
|
+
:returns:
|
|
393
|
+
Tuple of modified messages with confirmed tool calls and updated chat history
|
|
394
|
+
"""
|
|
395
|
+
# Run confirmation strategies and get tool execution decisions (async version)
|
|
396
|
+
teds = await _run_confirmation_strategies_async(
|
|
397
|
+
confirmation_strategies=confirmation_strategies,
|
|
398
|
+
messages_with_tool_calls=messages_with_tool_calls,
|
|
399
|
+
execution_context=execution_context,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Apply tool execution decisions to messages_with_tool_calls
|
|
403
|
+
rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions(
|
|
404
|
+
tool_call_messages=messages_with_tool_calls,
|
|
405
|
+
tool_execution_decisions=teds,
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
# Update the chat history with rejection messages and new tool call messages
|
|
409
|
+
new_chat_history = _update_chat_history(
|
|
410
|
+
chat_history=execution_context.state.get("messages"),
|
|
411
|
+
rejection_messages=rejection_messages,
|
|
412
|
+
tool_call_and_explanation_messages=modified_tool_call_messages,
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
return modified_tool_call_messages, new_chat_history
|
|
416
|
+
|
|
417
|
+
|
|
288
418
|
def _run_confirmation_strategies(
|
|
289
419
|
confirmation_strategies: dict[str, ConfirmationStrategy],
|
|
290
420
|
messages_with_tool_calls: list[ChatMessage],
|
|
@@ -344,13 +474,100 @@ def _run_confirmation_strategies(
|
|
|
344
474
|
# If not, run the confirmation strategy
|
|
345
475
|
if not ted:
|
|
346
476
|
ted = confirmation_strategies[tool_name].run(
|
|
347
|
-
tool_name=tool_name,
|
|
477
|
+
tool_name=tool_name,
|
|
478
|
+
tool_description=tool_to_invoke.description,
|
|
479
|
+
tool_params=final_args,
|
|
480
|
+
tool_call_id=tool_call.id,
|
|
481
|
+
confirmation_strategy_context=execution_context.confirmation_strategy_context,
|
|
348
482
|
)
|
|
349
483
|
teds.append(ted)
|
|
350
484
|
|
|
351
485
|
return teds
|
|
352
486
|
|
|
353
487
|
|
|
488
|
+
async def _run_confirmation_strategies_async(
    confirmation_strategies: dict[str, ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> list[ToolExecutionDecision]:
    """
    Async counterpart of _run_confirmation_strategies.

    Produces one ToolExecutionDecision per tool call found in the given messages:
    - tools without a confirmation strategy are approved for execution with their prepared arguments;
    - otherwise a decision already stored in the execution context is reused (matched by tool call id first,
      then by tool name);
    - failing that, the tool's strategy is consulted, awaiting ``run_async`` when the strategy has it and
      calling ``run`` otherwise.

    :param confirmation_strategies: Mapping from tool name to the confirmation strategy guarding that tool
    :param messages_with_tool_calls: Messages whose tool calls should be decided on
    :param execution_context: The current execution context, providing state, tool invoker inputs and any
        previously recorded decisions
    :returns:
        The decisions, in the order the tool calls appear across the messages.
    """
    state = execution_context.state
    invoker_inputs = execution_context.tool_invoker_inputs
    tools_by_name = {tool.name: tool for tool in invoker_inputs["tools"]}

    previous = execution_context.tool_execution_decisions or []
    previous_by_name = {decision.tool_name: decision for decision in previous if decision.tool_name}
    previous_by_id = {decision.tool_call_id: decision for decision in previous if decision.tool_call_id}

    decisions = []
    for message in messages_with_tool_calls:
        for tool_call in message.tool_calls or []:
            name = tool_call.tool_name
            tool = tools_by_name[name]

            # Arguments are resolved for every call, whether or not a strategy ends up being consulted
            resolved_args = _prepare_tool_args(
                tool=tool,
                tool_call_arguments=tool_call.arguments,
                state=state,
                streaming_callback=invoker_inputs.get("streaming_callback"),
                enable_streaming_passthrough=invoker_inputs.get("enable_streaming_passthrough", False),
            )

            if name not in confirmation_strategies:
                # Unguarded tool: approve execution with the resolved arguments
                decisions.append(
                    ToolExecutionDecision(
                        tool_call_id=tool_call.id,
                        tool_name=name,
                        execute=True,
                        final_tool_params=resolved_args,
                    )
                )
                continue

            # Prefer a decision already recorded in the execution context, matching by id before name
            decision = previous_by_id.get(tool_call.id or "") or previous_by_name.get(name)
            if not decision:
                strategy = confirmation_strategies[name]
                strategy_kwargs = {
                    "tool_name": name,
                    "tool_description": tool.description,
                    "tool_params": resolved_args,
                    "tool_call_id": tool_call.id,
                    "confirmation_strategy_context": execution_context.confirmation_strategy_context,
                }
                # Strategies without an async entry point fall back to their synchronous run()
                if hasattr(strategy, "run_async"):
                    decision = await strategy.run_async(**strategy_kwargs)
                else:
                    decision = strategy.run(**strategy_kwargs)
            decisions.append(decision)

    return decisions
|
|
569
|
+
|
|
570
|
+
|
|
354
571
|
def _apply_tool_execution_decisions(
|
|
355
572
|
tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
|
|
356
573
|
) -> tuple[list[ChatMessage], list[ChatMessage]]:
|