haystack-experimental 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25):
  1. haystack_experimental/chat_message_stores/in_memory.py +3 -3
  2. haystack_experimental/chat_message_stores/types.py +2 -2
  3. haystack_experimental/components/agents/agent.py +174 -119
  4. haystack_experimental/components/agents/human_in_the_loop/breakpoint.py +3 -1
  5. haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +6 -6
  6. haystack_experimental/components/agents/human_in_the_loop/errors.py +1 -5
  7. haystack_experimental/components/agents/human_in_the_loop/strategies.py +10 -10
  8. haystack_experimental/components/agents/human_in_the_loop/types.py +5 -5
  9. haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +2 -2
  10. haystack_experimental/components/generators/chat/openai.py +11 -11
  11. haystack_experimental/components/preprocessors/__init__.py +1 -3
  12. haystack_experimental/components/retrievers/chat_message_retriever.py +4 -4
  13. haystack_experimental/components/retrievers/types/protocol.py +3 -3
  14. haystack_experimental/components/summarizers/llm_summarizer.py +7 -7
  15. haystack_experimental/core/pipeline/breakpoint.py +6 -6
  16. haystack_experimental/dataclasses/breakpoints.py +2 -2
  17. haystack_experimental/utils/hallucination_risk_calculator/dataclasses.py +9 -9
  18. haystack_experimental/utils/hallucination_risk_calculator/openai_planner.py +4 -4
  19. haystack_experimental/utils/hallucination_risk_calculator/skeletonization.py +5 -5
  20. {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/METADATA +6 -10
  21. {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/RECORD +24 -25
  22. haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +0 -430
  23. {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/WHEEL +0 -0
  24. {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/licenses/LICENSE +0 -0
  25. {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
@@ -3,7 +3,7 @@
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  from dataclasses import replace
6
- from typing import Any, Iterable, Optional
6
+ from typing import Any, Iterable
7
7
 
8
8
  from haystack import default_from_dict, default_to_dict
9
9
  from haystack.dataclasses import ChatMessage, ChatRole
@@ -42,7 +42,7 @@ class InMemoryChatMessageStore:
42
42
  ```
43
43
  """
44
44
 
45
- def __init__(self, skip_system_messages: bool = True, last_k: Optional[int] = 10) -> None:
45
+ def __init__(self, skip_system_messages: bool = True, last_k: int | None = 10) -> None:
46
46
  """
47
47
  Create an InMemoryChatMessageStore.
48
48
 
@@ -135,7 +135,7 @@ class InMemoryChatMessageStore:
135
135
 
136
136
  return len(messages_to_write)
137
137
 
138
- def retrieve_messages(self, chat_history_id: str, last_k: Optional[int] = None) -> list[ChatMessage]:
138
+ def retrieve_messages(self, chat_history_id: str, last_k: int | None = None) -> list[ChatMessage]:
139
139
  """
140
140
  Retrieves all stored chat messages.
141
141
 
@@ -2,7 +2,7 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from typing import Any, Optional, Protocol
5
+ from typing import Any, Protocol
6
6
 
7
7
  from haystack.dataclasses import ChatMessage
8
8
 
@@ -74,7 +74,7 @@ class ChatMessageStore(Protocol):
74
74
  """
75
75
  ...
76
76
 
77
- def retrieve_messages(self, chat_history_id: str, last_k: Optional[int] = None) -> list[ChatMessage]:
77
+ def retrieve_messages(self, chat_history_id: str, last_k: int | None = None) -> list[ChatMessage]:
78
78
  """
79
79
  Retrieves chat messages from the ChatMessageStore.
80
80
 
@@ -7,7 +7,7 @@
7
7
 
8
8
  import inspect
9
9
  from dataclasses import dataclass
10
- from typing import Any, Optional, Union
10
+ from typing import Any
11
11
 
12
12
  # Monkey patch Haystack's AgentSnapshot with our extended version
13
13
  import haystack.dataclasses.breakpoints as hdb
@@ -28,11 +28,13 @@ from haystack.components.agents.agent import _ExecutionContext as Haystack_Execu
28
28
  from haystack.components.agents.agent import _schema_from_dict
29
29
  from haystack.components.agents.state import replace_values, State
30
30
  from haystack.components.generators.chat.types import ChatGenerator
31
- from haystack.core.errors import PipelineRuntimeError
31
+ from haystack.core.errors import BreakpointException, PipelineRuntimeError
32
32
  from haystack.core.pipeline import AsyncPipeline, Pipeline
33
33
  from haystack.core.pipeline.breakpoint import (
34
34
  _create_pipeline_snapshot_from_chat_generator,
35
35
  _create_pipeline_snapshot_from_tool_invoker,
36
+ _save_pipeline_snapshot,
37
+ _should_trigger_tool_invoker_breakpoint,
36
38
  )
37
39
  from haystack.core.pipeline.utils import _deepcopy_with_exceptions
38
40
  from haystack.core.serialization import default_from_dict, import_class_by_name
@@ -75,8 +77,8 @@ class _ExecutionContext(Haystack_ExecutionContext):
75
77
  parameter in their `run()` and `run_async()` methods.
76
78
  """
77
79
 
78
- tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
79
- confirmation_strategy_context: Optional[dict[str, Any]] = None
80
+ tool_execution_decisions: list[ToolExecutionDecision] | None = None
81
+ confirmation_strategy_context: dict[str, Any] | None = None
80
82
 
81
83
 
82
84
  class Agent(HaystackAgent):
@@ -134,16 +136,16 @@ class Agent(HaystackAgent):
134
136
  self,
135
137
  *,
136
138
  chat_generator: ChatGenerator,
137
- tools: Optional[ToolsType] = None,
138
- system_prompt: Optional[str] = None,
139
- exit_conditions: Optional[list[str]] = None,
140
- state_schema: Optional[dict[str, Any]] = None,
139
+ tools: ToolsType | None = None,
140
+ system_prompt: str | None = None,
141
+ exit_conditions: list[str] | None = None,
142
+ state_schema: dict[str, Any] | None = None,
141
143
  max_agent_steps: int = 100,
142
- streaming_callback: Optional[StreamingCallbackT] = None,
144
+ streaming_callback: StreamingCallbackT | None = None,
143
145
  raise_on_tool_invocation_failure: bool = False,
144
- confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None,
145
- tool_invoker_kwargs: Optional[dict[str, Any]] = None,
146
- chat_message_store: Optional[ChatMessageStore] = None,
146
+ confirmation_strategies: dict[str, ConfirmationStrategy] | None = None,
147
+ tool_invoker_kwargs: dict[str, Any] | None = None,
148
+ chat_message_store: ChatMessageStore | None = None,
147
149
  ) -> None:
148
150
  """
149
151
  Initialize the agent component.
@@ -188,14 +190,14 @@ class Agent(HaystackAgent):
188
190
  def _initialize_fresh_execution(
189
191
  self,
190
192
  messages: list[ChatMessage],
191
- streaming_callback: Optional[StreamingCallbackT],
193
+ streaming_callback: StreamingCallbackT | None,
192
194
  requires_async: bool,
193
195
  *,
194
- system_prompt: Optional[str] = None,
195
- generation_kwargs: Optional[dict[str, Any]] = None,
196
- tools: Optional[Union[ToolsType, list[str]]] = None,
197
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
198
- chat_message_store_kwargs: Optional[dict[str, Any]] = None,
196
+ system_prompt: str | None = None,
197
+ generation_kwargs: dict[str, Any] | None = None,
198
+ tools: ToolsType | list[str] | None = None,
199
+ confirmation_strategy_context: dict[str, Any] | None = None,
200
+ chat_message_store_kwargs: dict[str, Any] | None = None,
199
201
  **kwargs: dict[str, Any],
200
202
  ) -> _ExecutionContext:
201
203
  """
@@ -262,12 +264,12 @@ class Agent(HaystackAgent):
262
264
  def _initialize_from_snapshot( # type: ignore[override]
263
265
  self,
264
266
  snapshot: AgentSnapshot,
265
- streaming_callback: Optional[StreamingCallbackT],
267
+ streaming_callback: StreamingCallbackT | None,
266
268
  requires_async: bool,
267
269
  *,
268
- generation_kwargs: Optional[dict[str, Any]] = None,
269
- tools: Optional[Union[ToolsType, list[str]]] = None,
270
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
270
+ generation_kwargs: dict[str, Any] | None = None,
271
+ tools: ToolsType | list[str] | None = None,
272
+ confirmation_strategy_context: dict[str, Any] | None = None,
271
273
  ) -> _ExecutionContext:
272
274
  """
273
275
  Initialize execution context from an AgentSnapshot.
@@ -315,18 +317,18 @@ class Agent(HaystackAgent):
315
317
  confirmation_strategy_context=confirmation_strategy_context,
316
318
  )
317
319
 
318
- def run( # type: ignore[override] # noqa: PLR0915
320
+ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
319
321
  self,
320
322
  messages: list[ChatMessage],
321
- streaming_callback: Optional[StreamingCallbackT] = None,
323
+ streaming_callback: StreamingCallbackT | None = None,
322
324
  *,
323
- generation_kwargs: Optional[dict[str, Any]] = None,
324
- break_point: Optional[AgentBreakpoint] = None,
325
- snapshot: Optional[AgentSnapshot] = None,
326
- system_prompt: Optional[str] = None,
327
- tools: Optional[Union[ToolsType, list[str]]] = None,
328
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
329
- chat_message_store_kwargs: Optional[dict[str, Any]] = None,
325
+ generation_kwargs: dict[str, Any] | None = None,
326
+ break_point: AgentBreakpoint | None = None,
327
+ snapshot: AgentSnapshot | None = None,
328
+ system_prompt: str | None = None,
329
+ tools: ToolsType | list[str] | None = None,
330
+ confirmation_strategy_context: dict[str, Any] | None = None,
331
+ chat_message_store_kwargs: dict[str, Any] | None = None,
330
332
  **kwargs: Any,
331
333
  ) -> dict[str, Any]:
332
334
  """
@@ -360,8 +362,6 @@ class Agent(HaystackAgent):
360
362
  :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
361
363
  :raises BreakpointException: If an agent breakpoint is triggered.
362
364
  """
363
- # We pop parent_snapshot from kwargs to avoid passing it into State.
364
- parent_snapshot = kwargs.pop("parent_snapshot", None)
365
365
  agent_inputs = {
366
366
  "messages": messages,
367
367
  "streaming_callback": streaming_callback,
@@ -369,13 +369,9 @@ class Agent(HaystackAgent):
369
369
  "snapshot": snapshot,
370
370
  **kwargs,
371
371
  }
372
- # The PR https://github.com/deepset-ai/haystack/pull/9987 removed the unused snapshot parameter from
373
- # _runtime_checks. This change will be released in Haystack 2.20.0.
374
- # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
375
- if len(inspect.signature(self._runtime_checks).parameters) == 2:
376
- self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args
377
- else:
378
- self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter
372
+ # TODO Probably good to add a warning in runtime checks that BreakpointConfirmationStrategy will take
373
+ # precedence over passing a ToolBreakpoint
374
+ self._runtime_checks(break_point)
379
375
 
380
376
  if snapshot:
381
377
  exe_context = self._initialize_from_snapshot(
@@ -403,10 +399,6 @@ class Agent(HaystackAgent):
403
399
  span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
404
400
 
405
401
  while exe_context.counter < self.max_agent_steps:
406
- # Handle breakpoint and ChatGenerator call
407
- Agent._check_chat_generator_breakpoint(
408
- execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
409
- )
410
402
  # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
411
403
  if exe_context.skip_chat_generator:
412
404
  llm_messages = exe_context.state.get("messages", [])[-1:]
@@ -423,14 +415,26 @@ class Agent(HaystackAgent):
423
415
  },
424
416
  component_visits=exe_context.component_visits,
425
417
  parent_span=span,
418
+ break_point=break_point.break_point if isinstance(break_point, AgentBreakpoint) else None,
426
419
  )
427
- except PipelineRuntimeError as e:
428
- pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
429
- agent_name=getattr(self, "__component_name__", None),
430
- execution_context=exe_context,
431
- parent_snapshot=parent_snapshot,
420
+ except (BreakpointException, PipelineRuntimeError) as e:
421
+ if isinstance(e, BreakpointException):
422
+ agent_name = break_point.agent_name if break_point else None
423
+ saved_bp = break_point
424
+ else:
425
+ agent_name = getattr(self, "__component_name__", None)
426
+ saved_bp = None
427
+
428
+ e.pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
429
+ agent_name=agent_name, execution_context=exe_context, break_point=saved_bp
432
430
  )
433
- e.pipeline_snapshot = pipeline_snapshot
431
+ if isinstance(e, BreakpointException):
432
+ e._break_point = e.pipeline_snapshot.break_point
433
+ # If Agent is not in a pipeline, we save the snapshot to a file.
434
+ # Checked by __component_name__ not being set.
435
+ if getattr(self, "__component_name__", None) is None:
436
+ full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
437
+ e.pipeline_snapshot_file_path = full_file_path
434
438
  raise e
435
439
 
436
440
  llm_messages = result["replies"]
@@ -441,6 +445,19 @@ class Agent(HaystackAgent):
441
445
  exe_context.counter += 1
442
446
  break
443
447
 
448
+ # We only pass down the breakpoint if the tool name matches the tool call in the LLM messages
449
+ resolved_break_point = None
450
+ break_point_to_pass = None
451
+ if (
452
+ break_point
453
+ and isinstance(break_point.break_point, ToolBreakpoint)
454
+ and _should_trigger_tool_invoker_breakpoint(
455
+ break_point=break_point.break_point, llm_messages=llm_messages
456
+ )
457
+ ):
458
+ resolved_break_point = break_point
459
+ break_point_to_pass = resolved_break_point.break_point
460
+
444
461
  # Apply confirmation strategies and update State and messages sent to ToolInvoker
445
462
  try:
446
463
  # Run confirmation strategies to get updated tool call messages and modified chat history
@@ -452,8 +469,8 @@ class Agent(HaystackAgent):
452
469
  # Replace the chat history in state with the modified one
453
470
  exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
454
471
  except HITLBreakpointException as tbp_error:
455
- # We create a break_point to pass into _check_tool_invoker_breakpoint
456
- break_point = AgentBreakpoint(
472
+ # We create a break_point to pass to Pipeline._run_component
473
+ resolved_break_point = AgentBreakpoint(
457
474
  agent_name=getattr(self, "__component_name__", ""),
458
475
  break_point=ToolBreakpoint(
459
476
  component_name="tool_invoker",
@@ -462,11 +479,9 @@ class Agent(HaystackAgent):
462
479
  snapshot_file_path=tbp_error.snapshot_file_path,
463
480
  ),
464
481
  )
465
-
466
- # Handle breakpoint
467
- Agent._check_tool_invoker_breakpoint(
468
- execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
469
- )
482
+ break_point_to_pass = resolved_break_point.break_point
483
+ # If we hit a HITL breakpoint, we skip passing modified messages to ToolInvoker
484
+ modified_tool_call_messages = llm_messages
470
485
 
471
486
  # Run ToolInvoker
472
487
  try:
@@ -481,19 +496,28 @@ class Agent(HaystackAgent):
481
496
  },
482
497
  component_visits=exe_context.component_visits,
483
498
  parent_span=span,
499
+ break_point=break_point_to_pass,
484
500
  )
485
- except PipelineRuntimeError as e:
486
- # Access the original Tool Invoker exception
487
- original_error = e.__cause__
488
- tool_name = getattr(original_error, "tool_name", None)
489
-
490
- pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
491
- tool_name=tool_name,
492
- agent_name=getattr(self, "__component_name__", None),
493
- execution_context=exe_context,
494
- parent_snapshot=parent_snapshot,
501
+ except (BreakpointException, PipelineRuntimeError) as e:
502
+ if isinstance(e, BreakpointException):
503
+ agent_name = resolved_break_point.agent_name if resolved_break_point else None
504
+ tool_name = e.break_point.tool_name if isinstance(e.break_point, ToolBreakpoint) else None
505
+ saved_bp = resolved_break_point
506
+ else:
507
+ agent_name = getattr(self, "__component_name__", None)
508
+ tool_name = getattr(e.__cause__, "tool_name", None)
509
+ saved_bp = None
510
+
511
+ e.pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
512
+ tool_name=tool_name, agent_name=agent_name, execution_context=exe_context, break_point=saved_bp
495
513
  )
496
- e.pipeline_snapshot = pipeline_snapshot
514
+ if isinstance(e, BreakpointException):
515
+ e._break_point = e.pipeline_snapshot.break_point
516
+ # If Agent is not in a pipeline, we save the snapshot to a file.
517
+ # Checked by __component_name__ not being set.
518
+ if getattr(self, "__component_name__", None) is None:
519
+ full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
520
+ e.pipeline_snapshot_file_path = full_file_path
497
521
  raise e
498
522
 
499
523
  # Set execution context tool execution decisions to empty after applying them b/c they should only
@@ -531,18 +555,18 @@ class Agent(HaystackAgent):
531
555
 
532
556
  return result
533
557
 
534
- async def run_async( # type: ignore[override]
558
+ async def run_async( # type: ignore[override] # noqa: PLR0915
535
559
  self,
536
560
  messages: list[ChatMessage],
537
- streaming_callback: Optional[StreamingCallbackT] = None,
561
+ streaming_callback: StreamingCallbackT | None = None,
538
562
  *,
539
- generation_kwargs: Optional[dict[str, Any]] = None,
540
- break_point: Optional[AgentBreakpoint] = None,
541
- snapshot: Optional[AgentSnapshot] = None,
542
- system_prompt: Optional[str] = None,
543
- tools: Optional[Union[ToolsType, list[str]]] = None,
544
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
545
- chat_message_store_kwargs: Optional[dict[str, Any]] = None,
563
+ generation_kwargs: dict[str, Any] | None = None,
564
+ break_point: AgentBreakpoint | None = None,
565
+ snapshot: AgentSnapshot | None = None,
566
+ system_prompt: str | None = None,
567
+ tools: ToolsType | list[str] | None = None,
568
+ confirmation_strategy_context: dict[str, Any] | None = None,
569
+ chat_message_store_kwargs: dict[str, Any] | None = None,
546
570
  **kwargs: Any,
547
571
  ) -> dict[str, Any]:
548
572
  """
@@ -579,8 +603,6 @@ class Agent(HaystackAgent):
579
603
  :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
580
604
  :raises BreakpointException: If an agent breakpoint is triggered.
581
605
  """
582
- # We pop parent_snapshot from kwargs to avoid passing it into State.
583
- parent_snapshot = kwargs.pop("parent_snapshot", None)
584
606
  agent_inputs = {
585
607
  "messages": messages,
586
608
  "streaming_callback": streaming_callback,
@@ -588,13 +610,7 @@ class Agent(HaystackAgent):
588
610
  "snapshot": snapshot,
589
611
  **kwargs,
590
612
  }
591
- # The PR https://github.com/deepset-ai/haystack/pull/9987 removed the unused snapshot parameter from
592
- # _runtime_checks. This change will be released in Haystack 2.20.0.
593
- # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
594
- if len(inspect.signature(self._runtime_checks).parameters) == 2:
595
- self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args
596
- else:
597
- self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter
613
+ self._runtime_checks(break_point)
598
614
 
599
615
  if snapshot:
600
616
  exe_context = self._initialize_from_snapshot(
@@ -622,26 +638,39 @@ class Agent(HaystackAgent):
622
638
  span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
623
639
 
624
640
  while exe_context.counter < self.max_agent_steps:
625
- # Handle breakpoint and ChatGenerator call
626
- self._check_chat_generator_breakpoint(
627
- execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
628
- )
629
641
  # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
630
642
  if exe_context.skip_chat_generator:
631
643
  llm_messages = exe_context.state.get("messages", [])[-1:]
632
644
  # Set to False so the next iteration will call the chat generator
633
645
  exe_context.skip_chat_generator = False
634
646
  else:
635
- result = await AsyncPipeline._run_component_async(
636
- component_name="chat_generator",
637
- component={"instance": self.chat_generator},
638
- component_inputs={
639
- "messages": exe_context.state.data["messages"],
640
- **exe_context.chat_generator_inputs,
641
- },
642
- component_visits=exe_context.component_visits,
643
- parent_span=span,
644
- )
647
+ try:
648
+ result = await AsyncPipeline._run_component_async(
649
+ component_name="chat_generator",
650
+ component={"instance": self.chat_generator},
651
+ component_inputs={
652
+ "messages": exe_context.state.data["messages"],
653
+ **exe_context.chat_generator_inputs,
654
+ },
655
+ component_visits=exe_context.component_visits,
656
+ parent_span=span,
657
+ break_point=break_point.break_point if isinstance(break_point, AgentBreakpoint) else None,
658
+ )
659
+ except BreakpointException as e:
660
+ e.pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
661
+ agent_name=break_point.agent_name if break_point else None,
662
+ execution_context=exe_context,
663
+ break_point=break_point,
664
+ )
665
+ e._break_point = e.pipeline_snapshot.break_point
666
+ # We check if the agent is part of a pipeline by checking for __component_name__
667
+ # If it is not in a pipeline, we save the snapshot to a file.
668
+ in_pipeline = getattr(self, "__component_name__", None) is not None
669
+ if not in_pipeline:
670
+ full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
671
+ e.pipeline_snapshot_file_path = full_file_path
672
+ raise e
673
+
645
674
  llm_messages = result["replies"]
646
675
  exe_context.state.set("messages", llm_messages)
647
676
 
@@ -650,6 +679,19 @@ class Agent(HaystackAgent):
650
679
  exe_context.counter += 1
651
680
  break
652
681
 
682
+ # We only pass down the breakpoint if the tool name matches the tool call in the LLM messages
683
+ resolved_break_point = None
684
+ break_point_to_pass = None
685
+ if (
686
+ break_point
687
+ and isinstance(break_point.break_point, ToolBreakpoint)
688
+ and _should_trigger_tool_invoker_breakpoint(
689
+ break_point=break_point.break_point, llm_messages=llm_messages
690
+ )
691
+ ):
692
+ resolved_break_point = break_point
693
+ break_point_to_pass = resolved_break_point.break_point
694
+
653
695
  # Apply confirmation strategies and update State and messages sent to ToolInvoker
654
696
  try:
655
697
  # Run confirmation strategies to get updated tool call messages and modified chat history (async)
@@ -662,7 +704,7 @@ class Agent(HaystackAgent):
662
704
  exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
663
705
  except HITLBreakpointException as tbp_error:
664
706
  # We create a break_point to pass into _check_tool_invoker_breakpoint
665
- break_point = AgentBreakpoint(
707
+ resolved_break_point = AgentBreakpoint(
666
708
  agent_name=getattr(self, "__component_name__", ""),
667
709
  break_point=ToolBreakpoint(
668
710
  component_name="tool_invoker",
@@ -671,25 +713,38 @@ class Agent(HaystackAgent):
671
713
  snapshot_file_path=tbp_error.snapshot_file_path,
672
714
  ),
673
715
  )
716
+ break_point_to_pass = resolved_break_point.break_point
717
+ # If we hit a HITL breakpoint, we skip passing modified messages to ToolInvoker
718
+ modified_tool_call_messages = llm_messages
674
719
 
675
- # Handle breakpoint
676
- Agent._check_tool_invoker_breakpoint(
677
- execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
678
- )
679
-
680
- # Run ToolInvoker
681
- # We only send the messages from the LLM to the tool invoker
682
- tool_invoker_result = await AsyncPipeline._run_component_async(
683
- component_name="tool_invoker",
684
- component={"instance": self._tool_invoker},
685
- component_inputs={
686
- "messages": modified_tool_call_messages,
687
- "state": exe_context.state,
688
- **exe_context.tool_invoker_inputs,
689
- },
690
- component_visits=exe_context.component_visits,
691
- parent_span=span,
692
- )
720
+ try:
721
+ # We only send the messages from the LLM to the tool invoker
722
+ tool_invoker_result = await AsyncPipeline._run_component_async(
723
+ component_name="tool_invoker",
724
+ component={"instance": self._tool_invoker},
725
+ component_inputs={
726
+ "messages": modified_tool_call_messages,
727
+ "state": exe_context.state,
728
+ **exe_context.tool_invoker_inputs,
729
+ },
730
+ component_visits=exe_context.component_visits,
731
+ parent_span=span,
732
+ break_point=break_point_to_pass,
733
+ )
734
+ except BreakpointException as e:
735
+ e.pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
736
+ tool_name=e.break_point.tool_name if isinstance(e.break_point, ToolBreakpoint) else None,
737
+ agent_name=resolved_break_point.agent_name if resolved_break_point else None,
738
+ execution_context=exe_context,
739
+ break_point=resolved_break_point,
740
+ )
741
+ e._break_point = e.pipeline_snapshot.break_point
742
+ # If Agent is not in a pipeline, we save the snapshot to a file.
743
+ # Checked by __component_name__ not being set.
744
+ if getattr(self, "__component_name__", None) is None:
745
+ full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
746
+ e.pipeline_snapshot_file_path = full_file_path
747
+ raise e
693
748
 
694
749
  # Set execution context tool execution decisions to empty after applying them b/c they should only
695
750
  # be used once for the current tool calls
@@ -2,6 +2,8 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ from copy import deepcopy
6
+
5
7
  from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint
6
8
  from haystack.utils import _deserialize_value_with_schema
7
9
 
@@ -31,7 +33,7 @@ def get_tool_calls_and_descriptions_from_snapshot(
31
33
  tool_caused_break_point = break_point.tool_name
32
34
 
33
35
  # Deserialize the tool invoker inputs from the snapshot
34
- tool_invoker_inputs = _deserialize_value_with_schema(agent_snapshot.component_inputs["tool_invoker"])
36
+ tool_invoker_inputs = _deserialize_value_with_schema(deepcopy(agent_snapshot.component_inputs["tool_invoker"]))
35
37
  tool_call_messages = tool_invoker_inputs["messages"]
36
38
  state = tool_invoker_inputs["state"]
37
39
  tool_name_to_tool = {t.name: t for t in tool_invoker_inputs["tools"]}
@@ -3,7 +3,7 @@
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  from dataclasses import asdict, dataclass
6
- from typing import Any, Optional
6
+ from typing import Any
7
7
 
8
8
 
9
9
  @dataclass
@@ -23,8 +23,8 @@ class ConfirmationUIResult:
23
23
  """
24
24
 
25
25
  action: str # "confirm", "reject", "modify"
26
- feedback: Optional[str] = None
27
- new_tool_params: Optional[dict[str, Any]] = None
26
+ feedback: str | None = None
27
+ new_tool_params: dict[str, Any] | None = None
28
28
 
29
29
 
30
30
  @dataclass
@@ -49,9 +49,9 @@ class ToolExecutionDecision:
49
49
 
50
50
  tool_name: str
51
51
  execute: bool
52
- tool_call_id: Optional[str] = None
53
- feedback: Optional[str] = None
54
- final_tool_params: Optional[dict[str, Any]] = None
52
+ tool_call_id: str | None = None
53
+ feedback: str | None = None
54
+ final_tool_params: dict[str, Any] | None = None
55
55
 
56
56
  def to_dict(self) -> dict[str, Any]:
57
57
  """
@@ -2,17 +2,13 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from typing import Optional
6
-
7
5
 
8
6
  class HITLBreakpointException(Exception):
9
7
  """
10
8
  Exception raised when a tool execution is paused by a ConfirmationStrategy (e.g. BreakpointConfirmationStrategy).
11
9
  """
12
10
 
13
- def __init__(
14
- self, message: str, tool_name: str, snapshot_file_path: str, tool_call_id: Optional[str] = None
15
- ) -> None:
11
+ def __init__(self, message: str, tool_name: str, snapshot_file_path: str, tool_call_id: str | None = None) -> None:
16
12
  """
17
13
  Initialize the HITLBreakpointException.
18
14
 
@@ -3,7 +3,7 @@
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  from dataclasses import replace
6
- from typing import TYPE_CHECKING, Any, Optional
6
+ from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from haystack.components.agents.state import State
9
9
  from haystack.components.tools.tool_invoker import ToolInvoker
@@ -52,8 +52,8 @@ class BlockingConfirmationStrategy:
52
52
  tool_name: str,
53
53
  tool_description: str,
54
54
  tool_params: dict[str, Any],
55
- tool_call_id: Optional[str] = None,
56
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
55
+ tool_call_id: str | None = None,
56
+ confirmation_strategy_context: dict[str, Any] | None = None,
57
57
  ) -> ToolExecutionDecision:
58
58
  """
59
59
  Run the human-in-the-loop strategy for a given tool and its parameters.
@@ -125,8 +125,8 @@ class BlockingConfirmationStrategy:
125
125
  tool_name: str,
126
126
  tool_description: str,
127
127
  tool_params: dict[str, Any],
128
- tool_call_id: Optional[str] = None,
129
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
128
+ tool_call_id: str | None = None,
129
+ confirmation_strategy_context: dict[str, Any] | None = None,
130
130
  ) -> ToolExecutionDecision:
131
131
  """
132
132
  Async version of run. Calls the sync run() method by default.
@@ -210,8 +210,8 @@ class BreakpointConfirmationStrategy:
210
210
  tool_name: str,
211
211
  tool_description: str,
212
212
  tool_params: dict[str, Any],
213
- tool_call_id: Optional[str] = None,
214
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
213
+ tool_call_id: str | None = None,
214
+ confirmation_strategy_context: dict[str, Any] | None = None,
215
215
  ) -> ToolExecutionDecision:
216
216
  """
217
217
  Run the breakpoint confirmation strategy for a given tool and its parameters.
@@ -248,8 +248,8 @@ class BreakpointConfirmationStrategy:
248
248
  tool_name: str,
249
249
  tool_description: str,
250
250
  tool_params: dict[str, Any],
251
- tool_call_id: Optional[str] = None,
252
- confirmation_strategy_context: Optional[dict[str, Any]] = None,
251
+ tool_call_id: str | None = None,
252
+ confirmation_strategy_context: dict[str, Any] | None = None,
253
253
  ) -> ToolExecutionDecision:
254
254
  """
255
255
  Async version of run. Calls the sync run() method.
@@ -304,7 +304,7 @@ def _prepare_tool_args(
304
304
  tool: Tool,
305
305
  tool_call_arguments: dict[str, Any],
306
306
  state: State,
307
- streaming_callback: Optional[StreamingCallbackT] = None,
307
+ streaming_callback: StreamingCallbackT | None = None,
308
308
  enable_streaming_passthrough: bool = False,
309
309
  ) -> dict[str, Any]:
310
310
  """