haystack-experimental 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. haystack_experimental/chat_message_stores/__init__.py +1 -1
  2. haystack_experimental/chat_message_stores/in_memory.py +176 -31
  3. haystack_experimental/chat_message_stores/types.py +33 -21
  4. haystack_experimental/components/agents/agent.py +147 -44
  5. haystack_experimental/components/agents/human_in_the_loop/strategies.py +220 -3
  6. haystack_experimental/components/agents/human_in_the_loop/types.py +36 -1
  7. haystack_experimental/components/embedders/types/protocol.py +2 -2
  8. haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +16 -16
  9. haystack_experimental/components/retrievers/__init__.py +1 -3
  10. haystack_experimental/components/retrievers/chat_message_retriever.py +57 -26
  11. haystack_experimental/components/writers/__init__.py +1 -1
  12. haystack_experimental/components/writers/chat_message_writer.py +25 -22
  13. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/METADATA +24 -31
  14. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/RECORD +17 -24
  15. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/WHEEL +1 -1
  16. haystack_experimental/components/query/__init__.py +0 -18
  17. haystack_experimental/components/query/query_expander.py +0 -294
  18. haystack_experimental/components/retrievers/multi_query_embedding_retriever.py +0 -173
  19. haystack_experimental/components/retrievers/multi_query_text_retriever.py +0 -150
  20. haystack_experimental/super_components/__init__.py +0 -3
  21. haystack_experimental/super_components/indexers/__init__.py +0 -11
  22. haystack_experimental/super_components/indexers/sentence_transformers_document_indexer.py +0 -199
  23. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/licenses/LICENSE +0 -0
  24. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/licenses/LICENSE-MIT.txt +0 -0
@@ -22,11 +22,11 @@ import haystack_experimental.core.pipeline.breakpoint as exp_breakpoint
22
22
  hs_breakpoint._create_agent_snapshot = exp_breakpoint._create_agent_snapshot
23
23
  hs_breakpoint._create_pipeline_snapshot_from_tool_invoker = exp_breakpoint._create_pipeline_snapshot_from_tool_invoker # type: ignore[assignment]
24
24
 
25
- from haystack import logging
25
+ from haystack import DeserializationError, logging
26
26
  from haystack.components.agents.agent import Agent as HaystackAgent
27
27
  from haystack.components.agents.agent import _ExecutionContext as Haystack_ExecutionContext
28
28
  from haystack.components.agents.agent import _schema_from_dict
29
- from haystack.components.agents.state import replace_values
29
+ from haystack.components.agents.state import replace_values, State
30
30
  from haystack.components.generators.chat.types import ChatGenerator
31
31
  from haystack.core.errors import PipelineRuntimeError
32
32
  from haystack.core.pipeline import AsyncPipeline, Pipeline
@@ -36,19 +36,25 @@ from haystack.core.pipeline.breakpoint import (
36
36
  )
37
37
  from haystack.core.pipeline.utils import _deepcopy_with_exceptions
38
38
  from haystack.core.serialization import default_from_dict, import_class_by_name
39
- from haystack.dataclasses import ChatMessage
39
+ from haystack.dataclasses import ChatMessage, ChatRole
40
40
  from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint
41
- from haystack.dataclasses.streaming_chunk import StreamingCallbackT
41
+ from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
42
42
  from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace
43
43
  from haystack.utils.callable_serialization import deserialize_callable
44
44
  from haystack.utils.deserialization import deserialize_chatgenerator_inplace
45
45
 
46
+ from haystack_experimental.chat_message_stores.types import ChatMessageStore
46
47
  from haystack_experimental.components.agents.human_in_the_loop import (
47
48
  ConfirmationStrategy,
48
49
  ToolExecutionDecision,
49
50
  HITLBreakpointException,
50
51
  )
51
- from haystack_experimental.components.agents.human_in_the_loop.strategies import _process_confirmation_strategies
52
+ from haystack_experimental.components.agents.human_in_the_loop.strategies import (
53
+ _process_confirmation_strategies,
54
+ _process_confirmation_strategies_async,
55
+ )
56
+ from haystack_experimental.components.retrievers import ChatMessageRetriever
57
+ from haystack_experimental.components.writers import ChatMessageWriter
52
58
 
53
59
  logger = logging.getLogger(__name__)
54
60
 
@@ -62,9 +68,15 @@ class _ExecutionContext(Haystack_ExecutionContext):
62
68
 
63
69
  :param tool_execution_decisions: Optional list of ToolExecutionDecision objects to use instead of prompting
64
70
  the user. This is useful when restarting from a snapshot where tool execution decisions were already made.
71
+ :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
72
+ to confirmation strategies. In web/server environments, this enables passing per-request
73
+ objects (e.g., WebSocket connections, async queues, or pub/sub clients) that strategies can use for
74
+ non-blocking user interaction. This is passed directly to strategies via the `confirmation_strategy_context`
75
+ parameter in their `run()` and `run_async()` methods.
65
76
  """
66
77
 
67
78
  tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
79
+ confirmation_strategy_context: Optional[dict[str, Any]] = None
68
80
 
69
81
 
70
82
  class Agent(HaystackAgent):
@@ -131,6 +143,7 @@ class Agent(HaystackAgent):
131
143
  raise_on_tool_invocation_failure: bool = False,
132
144
  confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None,
133
145
  tool_invoker_kwargs: Optional[dict[str, Any]] = None,
146
+ chat_message_store: Optional[ChatMessageStore] = None,
134
147
  ) -> None:
135
148
  """
136
149
  Initialize the agent component.
@@ -164,6 +177,13 @@ class Agent(HaystackAgent):
164
177
  tool_invoker_kwargs=tool_invoker_kwargs,
165
178
  )
166
179
  self._confirmation_strategies = confirmation_strategies or {}
180
+ self._chat_message_store = chat_message_store
181
+ self._chat_message_retriever = (
182
+ ChatMessageRetriever(chat_message_store=chat_message_store) if chat_message_store else None
183
+ )
184
+ self._chat_message_writer = (
185
+ ChatMessageWriter(chat_message_store=chat_message_store) if chat_message_store else None
186
+ )
167
187
 
168
188
  def _initialize_fresh_execution(
169
189
  self,
@@ -174,6 +194,8 @@ class Agent(HaystackAgent):
174
194
  system_prompt: Optional[str] = None,
175
195
  generation_kwargs: Optional[dict[str, Any]] = None,
176
196
  tools: Optional[Union[ToolsType, list[str]]] = None,
197
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
198
+ chat_message_store_kwargs: Optional[dict[str, Any]] = None,
177
199
  **kwargs: dict[str, Any],
178
200
  ) -> _ExecutionContext:
179
201
  """
@@ -185,41 +207,56 @@ class Agent(HaystackAgent):
185
207
  :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
186
208
  :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
187
209
  When passing tool names, tools are selected from the Agent's originally configured tools.
210
+ :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
211
+ to confirmation strategies.
212
+ :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
188
213
  :param kwargs: Additional data to pass to the State used by the Agent.
189
214
  """
190
- # The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to
191
- # _initialize_fresh_execution. This change has been released in Haystack 2.20.0.
192
- # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
193
- if inspect.signature(super(Agent, self)._initialize_fresh_execution).parameters.get("generation_kwargs"):
194
- exe_context = super(Agent, self)._initialize_fresh_execution(
195
- messages=messages,
196
- streaming_callback=streaming_callback,
197
- requires_async=requires_async,
198
- system_prompt=system_prompt,
199
- generation_kwargs=generation_kwargs,
200
- tools=tools,
201
- **kwargs,
202
- )
203
- else:
204
- exe_context = super(Agent, self)._initialize_fresh_execution(
205
- messages=messages,
206
- streaming_callback=streaming_callback,
207
- requires_async=requires_async,
208
- system_prompt=system_prompt,
209
- tools=tools,
210
- **kwargs,
211
- )
212
- # NOTE: 1st difference with parent method to add this to tool_invoker_inputs
215
+ system_prompt = system_prompt or self.system_prompt
216
+ if system_prompt is not None:
217
+ messages = [ChatMessage.from_system(system_prompt)] + messages
218
+
219
+ # NOTE: difference with parent method to add chat message retrieval
220
+ if self._chat_message_retriever:
221
+ retriever_kwargs = _select_kwargs(self._chat_message_retriever, chat_message_store_kwargs or {})
222
+ if "chat_history_id" in retriever_kwargs:
223
+ messages = self._chat_message_retriever.run(
224
+ current_messages=messages,
225
+ **retriever_kwargs,
226
+ )["messages"]
227
+
228
+ if all(m.is_from(ChatRole.SYSTEM) for m in messages):
229
+ logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")
230
+
231
+ state = State(schema=self.state_schema, data=kwargs)
232
+ state.set("messages", messages)
233
+
234
+ streaming_callback = select_streaming_callback( # type: ignore[call-overload]
235
+ init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
236
+ )
237
+
238
+ selected_tools = self._select_tools(tools)
239
+ tool_invoker_inputs: dict[str, Any] = {"tools": selected_tools}
240
+ generator_inputs: dict[str, Any] = {"tools": selected_tools}
241
+ if streaming_callback is not None:
242
+ tool_invoker_inputs["streaming_callback"] = streaming_callback
243
+ generator_inputs["streaming_callback"] = streaming_callback
244
+ if generation_kwargs is not None:
245
+ generator_inputs["generation_kwargs"] = generation_kwargs
246
+
247
+ # NOTE: difference with parent method to add this to tool_invoker_inputs
213
248
  if self._tool_invoker:
214
- exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
249
+ tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
215
250
  self._tool_invoker.enable_streaming_callback_passthrough
216
251
  )
217
- # NOTE: 2nd difference is to use the extended _ExecutionContext
252
+
253
+ # NOTE: difference is to use the extended _ExecutionContext with confirmation_strategy_context
218
254
  return _ExecutionContext(
219
- state=exe_context.state,
220
- component_visits=exe_context.component_visits,
221
- chat_generator_inputs=exe_context.chat_generator_inputs,
222
- tool_invoker_inputs=exe_context.tool_invoker_inputs,
255
+ state=state,
256
+ component_visits=dict.fromkeys(["chat_generator", "tool_invoker"], 0),
257
+ chat_generator_inputs=generator_inputs,
258
+ tool_invoker_inputs=tool_invoker_inputs,
259
+ confirmation_strategy_context=confirmation_strategy_context,
223
260
  )
224
261
 
225
262
  def _initialize_from_snapshot( # type: ignore[override]
@@ -230,6 +267,7 @@ class Agent(HaystackAgent):
230
267
  *,
231
268
  generation_kwargs: Optional[dict[str, Any]] = None,
232
269
  tools: Optional[Union[ToolsType, list[str]]] = None,
270
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
233
271
  ) -> _ExecutionContext:
234
272
  """
235
273
  Initialize execution context from an AgentSnapshot.
@@ -241,12 +279,14 @@ class Agent(HaystackAgent):
241
279
  override the parameters passed during component initialization.
242
280
  :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
243
281
  When passing tool names, tools are selected from the Agent's originally configured tools.
282
+ :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
283
+ to confirmation strategies.
244
284
  """
245
285
  # The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to
246
286
  # _initialize_from_snapshot. This change has been released in Haystack 2.20.0.
247
287
  # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
248
288
  if inspect.signature(super(Agent, self)._initialize_from_snapshot).parameters.get("generation_kwargs"):
249
- exe_context = super(Agent, self)._initialize_from_snapshot(
289
+ exe_context = super(Agent, self)._initialize_from_snapshot( # type: ignore[call-arg]
250
290
  snapshot=snapshot,
251
291
  streaming_callback=streaming_callback,
252
292
  requires_async=requires_async,
@@ -262,7 +302,8 @@ class Agent(HaystackAgent):
262
302
  exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
263
303
  self._tool_invoker.enable_streaming_callback_passthrough
264
304
  )
265
- # NOTE: 2nd difference is to use the extended _ExecutionContext and add tool_execution_decisions
305
+ # NOTE: 2nd difference is to use the extended _ExecutionContext
306
+ # and add tool_execution_decisions + confirmation_strategy_context
266
307
  return _ExecutionContext(
267
308
  state=exe_context.state,
268
309
  component_visits=exe_context.component_visits,
@@ -271,18 +312,21 @@ class Agent(HaystackAgent):
271
312
  counter=exe_context.counter,
272
313
  skip_chat_generator=exe_context.skip_chat_generator,
273
314
  tool_execution_decisions=snapshot.tool_execution_decisions,
315
+ confirmation_strategy_context=confirmation_strategy_context,
274
316
  )
275
317
 
276
- def run( # noqa: PLR0915
318
+ def run( # type: ignore[override] # noqa: PLR0915
277
319
  self,
278
320
  messages: list[ChatMessage],
279
321
  streaming_callback: Optional[StreamingCallbackT] = None,
280
322
  *,
281
323
  generation_kwargs: Optional[dict[str, Any]] = None,
282
324
  break_point: Optional[AgentBreakpoint] = None,
283
- snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
325
+ snapshot: Optional[AgentSnapshot] = None,
284
326
  system_prompt: Optional[str] = None,
285
327
  tools: Optional[Union[ToolsType, list[str]]] = None,
328
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
329
+ chat_message_store_kwargs: Optional[dict[str, Any]] = None,
286
330
  **kwargs: Any,
287
331
  ) -> dict[str, Any]:
288
332
  """
@@ -300,6 +344,12 @@ class Agent(HaystackAgent):
300
344
  :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
301
345
  :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
302
346
  When passing tool names, tools are selected from the Agent's originally configured tools.
347
+ :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
348
+ to confirmation strategies. Useful in web/server environments to provide per-request
349
+ objects (e.g., WebSocket connections, async queues, Redis pub/sub clients) that strategies
350
+ can use for non-blocking user interaction.
351
+ :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
352
+ For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
303
353
  :param kwargs: Additional data to pass to the State schema used by the Agent.
304
354
  The keys must match the schema defined in the Agent's `state_schema`.
305
355
  :returns:
@@ -334,6 +384,7 @@ class Agent(HaystackAgent):
334
384
  requires_async=False,
335
385
  generation_kwargs=generation_kwargs,
336
386
  tools=tools,
387
+ confirmation_strategy_context=confirmation_strategy_context,
337
388
  )
338
389
  else:
339
390
  exe_context = self._initialize_fresh_execution(
@@ -343,6 +394,8 @@ class Agent(HaystackAgent):
343
394
  system_prompt=system_prompt,
344
395
  generation_kwargs=generation_kwargs,
345
396
  tools=tools,
397
+ confirmation_strategy_context=confirmation_strategy_context,
398
+ chat_message_store_kwargs=chat_message_store_kwargs,
346
399
  **kwargs,
347
400
  )
348
401
 
@@ -469,18 +522,27 @@ class Agent(HaystackAgent):
469
522
  result = {**exe_context.state.data}
470
523
  if msgs := result.get("messages"):
471
524
  result["last_message"] = msgs[-1]
525
+
526
+ # Write messages to ChatMessageStore if configured
527
+ if self._chat_message_writer:
528
+ writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
529
+ if "chat_history_id" in writer_kwargs:
530
+ self._chat_message_writer.run(messages=result["messages"], **writer_kwargs)
531
+
472
532
  return result
473
533
 
474
- async def run_async(
534
+ async def run_async( # type: ignore[override]
475
535
  self,
476
536
  messages: list[ChatMessage],
477
537
  streaming_callback: Optional[StreamingCallbackT] = None,
478
538
  *,
479
539
  generation_kwargs: Optional[dict[str, Any]] = None,
480
540
  break_point: Optional[AgentBreakpoint] = None,
481
- snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
541
+ snapshot: Optional[AgentSnapshot] = None,
482
542
  system_prompt: Optional[str] = None,
483
543
  tools: Optional[Union[ToolsType, list[str]]] = None,
544
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
545
+ chat_message_store_kwargs: Optional[dict[str, Any]] = None,
484
546
  **kwargs: Any,
485
547
  ) -> dict[str, Any]:
486
548
  """
@@ -501,6 +563,12 @@ class Agent(HaystackAgent):
501
563
  the relevant information to restart the Agent execution from where it left off.
502
564
  :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
503
565
  :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
566
+ :param confirmation_strategy_context: Optional dictionary for passing request-scoped resources
567
+ to confirmation strategies. Useful in web/server environments to provide per-request
568
+ objects (e.g., WebSocket connections, async queues, Redis pub/sub clients) that strategies
569
+ can use for non-blocking user interaction.
570
+ :param chat_message_store_kwargs: Optional dictionary of keyword arguments to pass to the ChatMessageStore.
571
+ For example, it can include the `chat_history_id` and `last_k` parameters for retrieving chat history.
504
572
  :param kwargs: Additional data to pass to the State schema used by the Agent.
505
573
  The keys must match the schema defined in the Agent's `state_schema`.
506
574
  :returns:
@@ -535,6 +603,7 @@ class Agent(HaystackAgent):
535
603
  requires_async=True,
536
604
  generation_kwargs=generation_kwargs,
537
605
  tools=tools,
606
+ confirmation_strategy_context=confirmation_strategy_context,
538
607
  )
539
608
  else:
540
609
  exe_context = self._initialize_fresh_execution(
@@ -544,6 +613,8 @@ class Agent(HaystackAgent):
544
613
  system_prompt=system_prompt,
545
614
  generation_kwargs=generation_kwargs,
546
615
  tools=tools,
616
+ confirmation_strategy_context=confirmation_strategy_context,
617
+ chat_message_store_kwargs=chat_message_store_kwargs,
547
618
  **kwargs,
548
619
  )
549
620
 
@@ -581,8 +652,8 @@ class Agent(HaystackAgent):
581
652
 
582
653
  # Apply confirmation strategies and update State and messages sent to ToolInvoker
583
654
  try:
584
- # Run confirmation strategies to get updated tool call messages and modified chat history
585
- modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
655
+ # Run confirmation strategies to get updated tool call messages and modified chat history (async)
656
+ modified_tool_call_messages, new_chat_history = await _process_confirmation_strategies_async(
586
657
  confirmation_strategies=self._confirmation_strategies,
587
658
  messages_with_tool_calls=llm_messages,
588
659
  execution_context=exe_context,
@@ -646,6 +717,13 @@ class Agent(HaystackAgent):
646
717
  result = {**exe_context.state.data}
647
718
  if msgs := result.get("messages"):
648
719
  result["last_message"] = msgs[-1]
720
+
721
+ # Write messages to ChatMessageStore if configured
722
+ if self._chat_message_writer:
723
+ writer_kwargs = _select_kwargs(self._chat_message_writer, chat_message_store_kwargs or {})
724
+ if "chat_history_id" in writer_kwargs:
725
+ self._chat_message_writer.run(messages=result["messages"], **writer_kwargs)
726
+
649
727
  return result
650
728
 
651
729
  def to_dict(self) -> dict[str, Any]:
@@ -660,6 +738,9 @@ class Agent(HaystackAgent):
660
738
  if self._confirmation_strategies
661
739
  else None
662
740
  )
741
+ data["init_parameters"]["chat_message_store"] = (
742
+ self._chat_message_store.to_dict() if self._chat_message_store is not None else None
743
+ )
663
744
  return data
664
745
 
665
746
  @classmethod
@@ -684,9 +765,31 @@ class Agent(HaystackAgent):
684
765
 
685
766
  if "confirmation_strategies" in init_params and init_params["confirmation_strategies"] is not None:
686
767
  for name, strategy_dict in init_params["confirmation_strategies"].items():
687
- strategy_class = import_class_by_name(strategy_dict["type"])
768
+ try:
769
+ strategy_class = import_class_by_name(strategy_dict["type"])
770
+ except ImportError as e:
771
+ raise DeserializationError(f"Class '{strategy_dict['type']}' not correctly imported") from e
688
772
  if not hasattr(strategy_class, "from_dict"):
689
- raise TypeError(f"{strategy_class} does not have from_dict method implemented.")
773
+ raise DeserializationError(f"{strategy_class} does not have from_dict method implemented.")
690
774
  init_params["confirmation_strategies"][name] = strategy_class.from_dict(strategy_dict)
691
775
 
776
+ if "chat_message_store" in init_params and init_params["chat_message_store"] is not None:
777
+ cms_data = init_params["chat_message_store"]
778
+ try:
779
+ cms_class = import_class_by_name(cms_data["type"])
780
+ except ImportError as e:
781
+ raise DeserializationError(f"Class '{cms_data['type']}' not correctly imported") from e
782
+ if not hasattr(cms_class, "from_dict"):
783
+ raise DeserializationError(f"{cms_class} does not have from_dict method implemented.")
784
+ init_params["chat_message_store"] = cms_class.from_dict(cms_data)
785
+
692
786
  return default_from_dict(cls, data)
787
+
788
+
789
+ def _select_kwargs(obj: Any, source: dict) -> dict[str, Any]:
790
+ """
791
+ Select only those key-value pairs from source dict that are valid parameters for obj.run() method.
792
+ """
793
+ sig = inspect.signature(obj.run)
794
+ allowed = set(sig.parameters.keys())
795
+ return {k: v for k, v in source.items() if k in allowed}
@@ -47,7 +47,13 @@ class BlockingConfirmationStrategy:
47
47
  self.confirmation_ui = confirmation_ui
48
48
 
49
49
  def run(
50
- self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
50
+ self,
51
+ *,
52
+ tool_name: str,
53
+ tool_description: str,
54
+ tool_params: dict[str, Any],
55
+ tool_call_id: Optional[str] = None,
56
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
51
57
  ) -> ToolExecutionDecision:
52
58
  """
53
59
  Run the human-in-the-loop strategy for a given tool and its parameters.
@@ -61,6 +67,10 @@ class BlockingConfirmationStrategy:
61
67
  :param tool_call_id:
62
68
  Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
63
69
  specific tool invocation.
70
+ :param confirmation_strategy_context:
71
+ Optional dictionary for passing request-scoped resources. Useful in web/server environments
72
+ to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients)
73
+ that strategies can use for non-blocking user interaction.
64
74
 
65
75
  :returns:
66
76
  A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
@@ -109,6 +119,40 @@ class BlockingConfirmationStrategy:
109
119
  tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
110
120
  )
111
121
 
122
+ async def run_async(
123
+ self,
124
+ *,
125
+ tool_name: str,
126
+ tool_description: str,
127
+ tool_params: dict[str, Any],
128
+ tool_call_id: Optional[str] = None,
129
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
130
+ ) -> ToolExecutionDecision:
131
+ """
132
+ Async version of run. Calls the sync run() method by default.
133
+
134
+ :param tool_name:
135
+ The name of the tool to be executed.
136
+ :param tool_description:
137
+ The description of the tool.
138
+ :param tool_params:
139
+ The parameters to be passed to the tool.
140
+ :param tool_call_id:
141
+ Optional unique identifier for the tool call.
142
+ :param confirmation_strategy_context:
143
+ Optional dictionary for passing request-scoped resources.
144
+
145
+ :returns:
146
+ A ToolExecutionDecision indicating whether to execute the tool with the given parameters.
147
+ """
148
+ return self.run(
149
+ tool_name=tool_name,
150
+ tool_description=tool_description,
151
+ tool_params=tool_params,
152
+ tool_call_id=tool_call_id,
153
+ confirmation_strategy_context=confirmation_strategy_context,
154
+ )
155
+
112
156
  def to_dict(self) -> dict[str, Any]:
113
157
  """
114
158
  Serializes the BlockingConfirmationStrategy to a dictionary.
@@ -161,7 +205,13 @@ class BreakpointConfirmationStrategy:
161
205
  self.snapshot_file_path = snapshot_file_path
162
206
 
163
207
  def run(
164
- self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
208
+ self,
209
+ *,
210
+ tool_name: str,
211
+ tool_description: str,
212
+ tool_params: dict[str, Any],
213
+ tool_call_id: Optional[str] = None,
214
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
165
215
  ) -> ToolExecutionDecision:
166
216
  """
167
217
  Run the breakpoint confirmation strategy for a given tool and its parameters.
@@ -175,6 +225,9 @@ class BreakpointConfirmationStrategy:
175
225
  :param tool_call_id:
176
226
  Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
177
227
  specific tool invocation.
228
+ :param confirmation_strategy_context:
229
+ Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
230
+ interface compatibility.
178
231
 
179
232
  :raises HITLBreakpointException:
180
233
  Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
@@ -189,6 +242,43 @@ class BreakpointConfirmationStrategy:
189
242
  snapshot_file_path=self.snapshot_file_path,
190
243
  )
191
244
 
245
+ async def run_async(
246
+ self,
247
+ *,
248
+ tool_name: str,
249
+ tool_description: str,
250
+ tool_params: dict[str, Any],
251
+ tool_call_id: Optional[str] = None,
252
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
253
+ ) -> ToolExecutionDecision:
254
+ """
255
+ Async version of run. Calls the sync run() method.
256
+
257
+ :param tool_name:
258
+ The name of the tool to be executed.
259
+ :param tool_description:
260
+ The description of the tool.
261
+ :param tool_params:
262
+ The parameters to be passed to the tool.
263
+ :param tool_call_id:
264
+ Optional unique identifier for the tool call.
265
+ :param confirmation_strategy_context:
266
+ Optional dictionary for passing request-scoped resources.
267
+
268
+ :raises HITLBreakpointException:
269
+ Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
270
+
271
+ :returns:
272
+ This method does not return; it always raises an exception.
273
+ """
274
+ return self.run(
275
+ tool_name=tool_name,
276
+ tool_description=tool_description,
277
+ tool_params=tool_params,
278
+ tool_call_id=tool_call_id,
279
+ confirmation_strategy_context=confirmation_strategy_context,
280
+ )
281
+
192
282
  def to_dict(self) -> dict[str, Any]:
193
283
  """
194
284
  Serializes the BreakpointConfirmationStrategy to a dictionary.
@@ -285,6 +375,46 @@ def _process_confirmation_strategies(
285
375
  return modified_tool_call_messages, new_chat_history
286
376
 
287
377
 
378
+ async def _process_confirmation_strategies_async(
379
+ *,
380
+ confirmation_strategies: dict[str, ConfirmationStrategy],
381
+ messages_with_tool_calls: list[ChatMessage],
382
+ execution_context: "_ExecutionContext",
383
+ ) -> tuple[list[ChatMessage], list[ChatMessage]]:
384
+ """
385
+ Async version of _process_confirmation_strategies.
386
+
387
+ Run the confirmation strategies and return modified tool call messages and updated chat history.
388
+
389
+ :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
390
+ :param messages_with_tool_calls: Chat messages containing tool calls
391
+ :param execution_context: The current execution context of the agent
392
+ :returns:
393
+ Tuple of modified messages with confirmed tool calls and updated chat history
394
+ """
395
+ # Run confirmation strategies and get tool execution decisions (async version)
396
+ teds = await _run_confirmation_strategies_async(
397
+ confirmation_strategies=confirmation_strategies,
398
+ messages_with_tool_calls=messages_with_tool_calls,
399
+ execution_context=execution_context,
400
+ )
401
+
402
+ # Apply tool execution decisions to messages_with_tool_calls
403
+ rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions(
404
+ tool_call_messages=messages_with_tool_calls,
405
+ tool_execution_decisions=teds,
406
+ )
407
+
408
+ # Update the chat history with rejection messages and new tool call messages
409
+ new_chat_history = _update_chat_history(
410
+ chat_history=execution_context.state.get("messages"),
411
+ rejection_messages=rejection_messages,
412
+ tool_call_and_explanation_messages=modified_tool_call_messages,
413
+ )
414
+
415
+ return modified_tool_call_messages, new_chat_history
416
+
417
+
288
418
  def _run_confirmation_strategies(
289
419
  confirmation_strategies: dict[str, ConfirmationStrategy],
290
420
  messages_with_tool_calls: list[ChatMessage],
@@ -344,13 +474,100 @@ def _run_confirmation_strategies(
344
474
  # If not, run the confirmation strategy
345
475
  if not ted:
346
476
  ted = confirmation_strategies[tool_name].run(
347
- tool_name=tool_name, tool_description=tool_to_invoke.description, tool_params=final_args
477
+ tool_name=tool_name,
478
+ tool_description=tool_to_invoke.description,
479
+ tool_params=final_args,
480
+ tool_call_id=tool_call.id,
481
+ confirmation_strategy_context=execution_context.confirmation_strategy_context,
348
482
  )
349
483
  teds.append(ted)
350
484
 
351
485
  return teds
352
486
 
353
487
 
488
+ async def _run_confirmation_strategies_async(
489
+ confirmation_strategies: dict[str, ConfirmationStrategy],
490
+ messages_with_tool_calls: list[ChatMessage],
491
+ execution_context: "_ExecutionContext",
492
+ ) -> list[ToolExecutionDecision]:
493
+ """
494
+ Async version of _run_confirmation_strategies.
495
+
496
+ Run confirmation strategies for tool calls in the provided chat messages.
497
+
498
+ :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
499
+ :param messages_with_tool_calls: Messages containing tool calls to process
500
+ :param execution_context: The current execution context containing state and inputs
501
+ :returns:
502
+ A list of ToolExecutionDecision objects representing the decisions made for each tool call.
503
+ """
504
+ state = execution_context.state
505
+ tools_with_names = {tool.name: tool for tool in execution_context.tool_invoker_inputs["tools"]}
506
+ existing_teds = execution_context.tool_execution_decisions if execution_context.tool_execution_decisions else []
507
+ existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name}
508
+ existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id}
509
+
510
+ teds = []
511
+ for message in messages_with_tool_calls:
512
+ if not message.tool_calls:
513
+ continue
514
+
515
+ for tool_call in message.tool_calls:
516
+ tool_name = tool_call.tool_name
517
+ tool_to_invoke = tools_with_names[tool_name]
518
+
519
+ # Prepare final tool args
520
+ final_args = _prepare_tool_args(
521
+ tool=tool_to_invoke,
522
+ tool_call_arguments=tool_call.arguments,
523
+ state=state,
524
+ streaming_callback=execution_context.tool_invoker_inputs.get("streaming_callback"),
525
+ enable_streaming_passthrough=execution_context.tool_invoker_inputs.get(
526
+ "enable_streaming_passthrough", False
527
+ ),
528
+ )
529
+
530
+ # Get tool execution decisions from confirmation strategies
531
+ # If no confirmation strategy is defined for this tool, proceed with execution
532
+ if tool_name not in confirmation_strategies:
533
+ teds.append(
534
+ ToolExecutionDecision(
535
+ tool_call_id=tool_call.id,
536
+ tool_name=tool_name,
537
+ execute=True,
538
+ final_tool_params=final_args,
539
+ )
540
+ )
541
+ continue
542
+
543
+ # Check if there's already a decision for this tool call in the execution context
544
+ ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name)
545
+
546
+ # If not, run the confirmation strategy (async version)
547
+ if not ted:
548
+ strategy = confirmation_strategies[tool_name]
549
+ # Use run_async if available, otherwise fall back to sync run
550
+ if hasattr(strategy, "run_async"):
551
+ ted = await strategy.run_async(
552
+ tool_name=tool_name,
553
+ tool_description=tool_to_invoke.description,
554
+ tool_params=final_args,
555
+ tool_call_id=tool_call.id,
556
+ confirmation_strategy_context=execution_context.confirmation_strategy_context,
557
+ )
558
+ else:
559
+ ted = strategy.run(
560
+ tool_name=tool_name,
561
+ tool_description=tool_to_invoke.description,
562
+ tool_params=final_args,
563
+ tool_call_id=tool_call.id,
564
+ confirmation_strategy_context=execution_context.confirmation_strategy_context,
565
+ )
566
+ teds.append(ted)
567
+
568
+ return teds
569
+
570
+
354
571
  def _apply_tool_execution_decisions(
355
572
  tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
356
573
  ) -> tuple[list[ChatMessage], list[ChatMessage]]: