haystack-experimental 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/chat_message_stores/in_memory.py +3 -3
- haystack_experimental/chat_message_stores/types.py +2 -2
- haystack_experimental/components/agents/agent.py +174 -119
- haystack_experimental/components/agents/human_in_the_loop/breakpoint.py +3 -1
- haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +6 -6
- haystack_experimental/components/agents/human_in_the_loop/errors.py +1 -5
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +10 -10
- haystack_experimental/components/agents/human_in_the_loop/types.py +5 -5
- haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +2 -2
- haystack_experimental/components/generators/chat/openai.py +11 -11
- haystack_experimental/components/preprocessors/__init__.py +1 -3
- haystack_experimental/components/retrievers/chat_message_retriever.py +4 -4
- haystack_experimental/components/retrievers/types/protocol.py +3 -3
- haystack_experimental/components/summarizers/llm_summarizer.py +7 -7
- haystack_experimental/core/pipeline/breakpoint.py +6 -6
- haystack_experimental/dataclasses/breakpoints.py +2 -2
- haystack_experimental/utils/hallucination_risk_calculator/dataclasses.py +9 -9
- haystack_experimental/utils/hallucination_risk_calculator/openai_planner.py +4 -4
- haystack_experimental/utils/hallucination_risk_calculator/skeletonization.py +5 -5
- {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/METADATA +6 -10
- {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/RECORD +24 -25
- haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +0 -430
- {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/WHEEL +0 -0
- {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.15.1.dist-info → haystack_experimental-0.16.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
from dataclasses import replace
|
|
6
|
-
from typing import Any, Iterable
|
|
6
|
+
from typing import Any, Iterable
|
|
7
7
|
|
|
8
8
|
from haystack import default_from_dict, default_to_dict
|
|
9
9
|
from haystack.dataclasses import ChatMessage, ChatRole
|
|
@@ -42,7 +42,7 @@ class InMemoryChatMessageStore:
|
|
|
42
42
|
```
|
|
43
43
|
"""
|
|
44
44
|
|
|
45
|
-
def __init__(self, skip_system_messages: bool = True, last_k:
|
|
45
|
+
def __init__(self, skip_system_messages: bool = True, last_k: int | None = 10) -> None:
|
|
46
46
|
"""
|
|
47
47
|
Create an InMemoryChatMessageStore.
|
|
48
48
|
|
|
@@ -135,7 +135,7 @@ class InMemoryChatMessageStore:
|
|
|
135
135
|
|
|
136
136
|
return len(messages_to_write)
|
|
137
137
|
|
|
138
|
-
def retrieve_messages(self, chat_history_id: str, last_k:
|
|
138
|
+
def retrieve_messages(self, chat_history_id: str, last_k: int | None = None) -> list[ChatMessage]:
|
|
139
139
|
"""
|
|
140
140
|
Retrieves all stored chat messages.
|
|
141
141
|
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Protocol
|
|
6
6
|
|
|
7
7
|
from haystack.dataclasses import ChatMessage
|
|
8
8
|
|
|
@@ -74,7 +74,7 @@ class ChatMessageStore(Protocol):
|
|
|
74
74
|
"""
|
|
75
75
|
...
|
|
76
76
|
|
|
77
|
-
def retrieve_messages(self, chat_history_id: str, last_k:
|
|
77
|
+
def retrieve_messages(self, chat_history_id: str, last_k: int | None = None) -> list[ChatMessage]:
|
|
78
78
|
"""
|
|
79
79
|
Retrieves chat messages from the ChatMessageStore.
|
|
80
80
|
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
import inspect
|
|
9
9
|
from dataclasses import dataclass
|
|
10
|
-
from typing import Any
|
|
10
|
+
from typing import Any
|
|
11
11
|
|
|
12
12
|
# Monkey patch Haystack's AgentSnapshot with our extended version
|
|
13
13
|
import haystack.dataclasses.breakpoints as hdb
|
|
@@ -28,11 +28,13 @@ from haystack.components.agents.agent import _ExecutionContext as Haystack_Execu
|
|
|
28
28
|
from haystack.components.agents.agent import _schema_from_dict
|
|
29
29
|
from haystack.components.agents.state import replace_values, State
|
|
30
30
|
from haystack.components.generators.chat.types import ChatGenerator
|
|
31
|
-
from haystack.core.errors import PipelineRuntimeError
|
|
31
|
+
from haystack.core.errors import BreakpointException, PipelineRuntimeError
|
|
32
32
|
from haystack.core.pipeline import AsyncPipeline, Pipeline
|
|
33
33
|
from haystack.core.pipeline.breakpoint import (
|
|
34
34
|
_create_pipeline_snapshot_from_chat_generator,
|
|
35
35
|
_create_pipeline_snapshot_from_tool_invoker,
|
|
36
|
+
_save_pipeline_snapshot,
|
|
37
|
+
_should_trigger_tool_invoker_breakpoint,
|
|
36
38
|
)
|
|
37
39
|
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
|
|
38
40
|
from haystack.core.serialization import default_from_dict, import_class_by_name
|
|
@@ -75,8 +77,8 @@ class _ExecutionContext(Haystack_ExecutionContext):
|
|
|
75
77
|
parameter in their `run()` and `run_async()` methods.
|
|
76
78
|
"""
|
|
77
79
|
|
|
78
|
-
tool_execution_decisions:
|
|
79
|
-
confirmation_strategy_context:
|
|
80
|
+
tool_execution_decisions: list[ToolExecutionDecision] | None = None
|
|
81
|
+
confirmation_strategy_context: dict[str, Any] | None = None
|
|
80
82
|
|
|
81
83
|
|
|
82
84
|
class Agent(HaystackAgent):
|
|
@@ -134,16 +136,16 @@ class Agent(HaystackAgent):
|
|
|
134
136
|
self,
|
|
135
137
|
*,
|
|
136
138
|
chat_generator: ChatGenerator,
|
|
137
|
-
tools:
|
|
138
|
-
system_prompt:
|
|
139
|
-
exit_conditions:
|
|
140
|
-
state_schema:
|
|
139
|
+
tools: ToolsType | None = None,
|
|
140
|
+
system_prompt: str | None = None,
|
|
141
|
+
exit_conditions: list[str] | None = None,
|
|
142
|
+
state_schema: dict[str, Any] | None = None,
|
|
141
143
|
max_agent_steps: int = 100,
|
|
142
|
-
streaming_callback:
|
|
144
|
+
streaming_callback: StreamingCallbackT | None = None,
|
|
143
145
|
raise_on_tool_invocation_failure: bool = False,
|
|
144
|
-
confirmation_strategies:
|
|
145
|
-
tool_invoker_kwargs:
|
|
146
|
-
chat_message_store:
|
|
146
|
+
confirmation_strategies: dict[str, ConfirmationStrategy] | None = None,
|
|
147
|
+
tool_invoker_kwargs: dict[str, Any] | None = None,
|
|
148
|
+
chat_message_store: ChatMessageStore | None = None,
|
|
147
149
|
) -> None:
|
|
148
150
|
"""
|
|
149
151
|
Initialize the agent component.
|
|
@@ -188,14 +190,14 @@ class Agent(HaystackAgent):
|
|
|
188
190
|
def _initialize_fresh_execution(
|
|
189
191
|
self,
|
|
190
192
|
messages: list[ChatMessage],
|
|
191
|
-
streaming_callback:
|
|
193
|
+
streaming_callback: StreamingCallbackT | None,
|
|
192
194
|
requires_async: bool,
|
|
193
195
|
*,
|
|
194
|
-
system_prompt:
|
|
195
|
-
generation_kwargs:
|
|
196
|
-
tools:
|
|
197
|
-
confirmation_strategy_context:
|
|
198
|
-
chat_message_store_kwargs:
|
|
196
|
+
system_prompt: str | None = None,
|
|
197
|
+
generation_kwargs: dict[str, Any] | None = None,
|
|
198
|
+
tools: ToolsType | list[str] | None = None,
|
|
199
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
200
|
+
chat_message_store_kwargs: dict[str, Any] | None = None,
|
|
199
201
|
**kwargs: dict[str, Any],
|
|
200
202
|
) -> _ExecutionContext:
|
|
201
203
|
"""
|
|
@@ -262,12 +264,12 @@ class Agent(HaystackAgent):
|
|
|
262
264
|
def _initialize_from_snapshot( # type: ignore[override]
|
|
263
265
|
self,
|
|
264
266
|
snapshot: AgentSnapshot,
|
|
265
|
-
streaming_callback:
|
|
267
|
+
streaming_callback: StreamingCallbackT | None,
|
|
266
268
|
requires_async: bool,
|
|
267
269
|
*,
|
|
268
|
-
generation_kwargs:
|
|
269
|
-
tools:
|
|
270
|
-
confirmation_strategy_context:
|
|
270
|
+
generation_kwargs: dict[str, Any] | None = None,
|
|
271
|
+
tools: ToolsType | list[str] | None = None,
|
|
272
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
271
273
|
) -> _ExecutionContext:
|
|
272
274
|
"""
|
|
273
275
|
Initialize execution context from an AgentSnapshot.
|
|
@@ -315,18 +317,18 @@ class Agent(HaystackAgent):
|
|
|
315
317
|
confirmation_strategy_context=confirmation_strategy_context,
|
|
316
318
|
)
|
|
317
319
|
|
|
318
|
-
def run( # type: ignore[override] # noqa: PLR0915
|
|
320
|
+
def run( # type: ignore[override] # noqa: PLR0915 PLR0912
|
|
319
321
|
self,
|
|
320
322
|
messages: list[ChatMessage],
|
|
321
|
-
streaming_callback:
|
|
323
|
+
streaming_callback: StreamingCallbackT | None = None,
|
|
322
324
|
*,
|
|
323
|
-
generation_kwargs:
|
|
324
|
-
break_point:
|
|
325
|
-
snapshot:
|
|
326
|
-
system_prompt:
|
|
327
|
-
tools:
|
|
328
|
-
confirmation_strategy_context:
|
|
329
|
-
chat_message_store_kwargs:
|
|
325
|
+
generation_kwargs: dict[str, Any] | None = None,
|
|
326
|
+
break_point: AgentBreakpoint | None = None,
|
|
327
|
+
snapshot: AgentSnapshot | None = None,
|
|
328
|
+
system_prompt: str | None = None,
|
|
329
|
+
tools: ToolsType | list[str] | None = None,
|
|
330
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
331
|
+
chat_message_store_kwargs: dict[str, Any] | None = None,
|
|
330
332
|
**kwargs: Any,
|
|
331
333
|
) -> dict[str, Any]:
|
|
332
334
|
"""
|
|
@@ -360,8 +362,6 @@ class Agent(HaystackAgent):
|
|
|
360
362
|
:raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
|
|
361
363
|
:raises BreakpointException: If an agent breakpoint is triggered.
|
|
362
364
|
"""
|
|
363
|
-
# We pop parent_snapshot from kwargs to avoid passing it into State.
|
|
364
|
-
parent_snapshot = kwargs.pop("parent_snapshot", None)
|
|
365
365
|
agent_inputs = {
|
|
366
366
|
"messages": messages,
|
|
367
367
|
"streaming_callback": streaming_callback,
|
|
@@ -369,13 +369,9 @@ class Agent(HaystackAgent):
|
|
|
369
369
|
"snapshot": snapshot,
|
|
370
370
|
**kwargs,
|
|
371
371
|
}
|
|
372
|
-
#
|
|
373
|
-
#
|
|
374
|
-
|
|
375
|
-
if len(inspect.signature(self._runtime_checks).parameters) == 2:
|
|
376
|
-
self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args
|
|
377
|
-
else:
|
|
378
|
-
self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter
|
|
372
|
+
# TODO Probably good to add a warning in runtime checks that BreakpointConfirmationStrategy will take
|
|
373
|
+
# precedence over passing a ToolBreakpoint
|
|
374
|
+
self._runtime_checks(break_point)
|
|
379
375
|
|
|
380
376
|
if snapshot:
|
|
381
377
|
exe_context = self._initialize_from_snapshot(
|
|
@@ -403,10 +399,6 @@ class Agent(HaystackAgent):
|
|
|
403
399
|
span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
|
|
404
400
|
|
|
405
401
|
while exe_context.counter < self.max_agent_steps:
|
|
406
|
-
# Handle breakpoint and ChatGenerator call
|
|
407
|
-
Agent._check_chat_generator_breakpoint(
|
|
408
|
-
execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
|
|
409
|
-
)
|
|
410
402
|
# We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
|
|
411
403
|
if exe_context.skip_chat_generator:
|
|
412
404
|
llm_messages = exe_context.state.get("messages", [])[-1:]
|
|
@@ -423,14 +415,26 @@ class Agent(HaystackAgent):
|
|
|
423
415
|
},
|
|
424
416
|
component_visits=exe_context.component_visits,
|
|
425
417
|
parent_span=span,
|
|
418
|
+
break_point=break_point.break_point if isinstance(break_point, AgentBreakpoint) else None,
|
|
426
419
|
)
|
|
427
|
-
except PipelineRuntimeError as e:
|
|
428
|
-
|
|
429
|
-
agent_name=
|
|
430
|
-
|
|
431
|
-
|
|
420
|
+
except (BreakpointException, PipelineRuntimeError) as e:
|
|
421
|
+
if isinstance(e, BreakpointException):
|
|
422
|
+
agent_name = break_point.agent_name if break_point else None
|
|
423
|
+
saved_bp = break_point
|
|
424
|
+
else:
|
|
425
|
+
agent_name = getattr(self, "__component_name__", None)
|
|
426
|
+
saved_bp = None
|
|
427
|
+
|
|
428
|
+
e.pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
|
|
429
|
+
agent_name=agent_name, execution_context=exe_context, break_point=saved_bp
|
|
432
430
|
)
|
|
433
|
-
e
|
|
431
|
+
if isinstance(e, BreakpointException):
|
|
432
|
+
e._break_point = e.pipeline_snapshot.break_point
|
|
433
|
+
# If Agent is not in a pipeline, we save the snapshot to a file.
|
|
434
|
+
# Checked by __component_name__ not being set.
|
|
435
|
+
if getattr(self, "__component_name__", None) is None:
|
|
436
|
+
full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
|
|
437
|
+
e.pipeline_snapshot_file_path = full_file_path
|
|
434
438
|
raise e
|
|
435
439
|
|
|
436
440
|
llm_messages = result["replies"]
|
|
@@ -441,6 +445,19 @@ class Agent(HaystackAgent):
|
|
|
441
445
|
exe_context.counter += 1
|
|
442
446
|
break
|
|
443
447
|
|
|
448
|
+
# We only pass down the breakpoint if the tool name matches the tool call in the LLM messages
|
|
449
|
+
resolved_break_point = None
|
|
450
|
+
break_point_to_pass = None
|
|
451
|
+
if (
|
|
452
|
+
break_point
|
|
453
|
+
and isinstance(break_point.break_point, ToolBreakpoint)
|
|
454
|
+
and _should_trigger_tool_invoker_breakpoint(
|
|
455
|
+
break_point=break_point.break_point, llm_messages=llm_messages
|
|
456
|
+
)
|
|
457
|
+
):
|
|
458
|
+
resolved_break_point = break_point
|
|
459
|
+
break_point_to_pass = resolved_break_point.break_point
|
|
460
|
+
|
|
444
461
|
# Apply confirmation strategies and update State and messages sent to ToolInvoker
|
|
445
462
|
try:
|
|
446
463
|
# Run confirmation strategies to get updated tool call messages and modified chat history
|
|
@@ -452,8 +469,8 @@ class Agent(HaystackAgent):
|
|
|
452
469
|
# Replace the chat history in state with the modified one
|
|
453
470
|
exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
|
|
454
471
|
except HITLBreakpointException as tbp_error:
|
|
455
|
-
# We create a break_point to pass
|
|
456
|
-
|
|
472
|
+
# We create a break_point to pass to Pipeline._run_component
|
|
473
|
+
resolved_break_point = AgentBreakpoint(
|
|
457
474
|
agent_name=getattr(self, "__component_name__", ""),
|
|
458
475
|
break_point=ToolBreakpoint(
|
|
459
476
|
component_name="tool_invoker",
|
|
@@ -462,11 +479,9 @@ class Agent(HaystackAgent):
|
|
|
462
479
|
snapshot_file_path=tbp_error.snapshot_file_path,
|
|
463
480
|
),
|
|
464
481
|
)
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
|
|
469
|
-
)
|
|
482
|
+
break_point_to_pass = resolved_break_point.break_point
|
|
483
|
+
# If we hit a HITL breakpoint, we skip passing modified messages to ToolInvoker
|
|
484
|
+
modified_tool_call_messages = llm_messages
|
|
470
485
|
|
|
471
486
|
# Run ToolInvoker
|
|
472
487
|
try:
|
|
@@ -481,19 +496,28 @@ class Agent(HaystackAgent):
|
|
|
481
496
|
},
|
|
482
497
|
component_visits=exe_context.component_visits,
|
|
483
498
|
parent_span=span,
|
|
499
|
+
break_point=break_point_to_pass,
|
|
484
500
|
)
|
|
485
|
-
except PipelineRuntimeError as e:
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
501
|
+
except (BreakpointException, PipelineRuntimeError) as e:
|
|
502
|
+
if isinstance(e, BreakpointException):
|
|
503
|
+
agent_name = resolved_break_point.agent_name if resolved_break_point else None
|
|
504
|
+
tool_name = e.break_point.tool_name if isinstance(e.break_point, ToolBreakpoint) else None
|
|
505
|
+
saved_bp = resolved_break_point
|
|
506
|
+
else:
|
|
507
|
+
agent_name = getattr(self, "__component_name__", None)
|
|
508
|
+
tool_name = getattr(e.__cause__, "tool_name", None)
|
|
509
|
+
saved_bp = None
|
|
510
|
+
|
|
511
|
+
e.pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
|
|
512
|
+
tool_name=tool_name, agent_name=agent_name, execution_context=exe_context, break_point=saved_bp
|
|
495
513
|
)
|
|
496
|
-
e
|
|
514
|
+
if isinstance(e, BreakpointException):
|
|
515
|
+
e._break_point = e.pipeline_snapshot.break_point
|
|
516
|
+
# If Agent is not in a pipeline, we save the snapshot to a file.
|
|
517
|
+
# Checked by __component_name__ not being set.
|
|
518
|
+
if getattr(self, "__component_name__", None) is None:
|
|
519
|
+
full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
|
|
520
|
+
e.pipeline_snapshot_file_path = full_file_path
|
|
497
521
|
raise e
|
|
498
522
|
|
|
499
523
|
# Set execution context tool execution decisions to empty after applying them b/c they should only
|
|
@@ -531,18 +555,18 @@ class Agent(HaystackAgent):
|
|
|
531
555
|
|
|
532
556
|
return result
|
|
533
557
|
|
|
534
|
-
async def run_async( # type: ignore[override]
|
|
558
|
+
async def run_async( # type: ignore[override] # noqa: PLR0915
|
|
535
559
|
self,
|
|
536
560
|
messages: list[ChatMessage],
|
|
537
|
-
streaming_callback:
|
|
561
|
+
streaming_callback: StreamingCallbackT | None = None,
|
|
538
562
|
*,
|
|
539
|
-
generation_kwargs:
|
|
540
|
-
break_point:
|
|
541
|
-
snapshot:
|
|
542
|
-
system_prompt:
|
|
543
|
-
tools:
|
|
544
|
-
confirmation_strategy_context:
|
|
545
|
-
chat_message_store_kwargs:
|
|
563
|
+
generation_kwargs: dict[str, Any] | None = None,
|
|
564
|
+
break_point: AgentBreakpoint | None = None,
|
|
565
|
+
snapshot: AgentSnapshot | None = None,
|
|
566
|
+
system_prompt: str | None = None,
|
|
567
|
+
tools: ToolsType | list[str] | None = None,
|
|
568
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
569
|
+
chat_message_store_kwargs: dict[str, Any] | None = None,
|
|
546
570
|
**kwargs: Any,
|
|
547
571
|
) -> dict[str, Any]:
|
|
548
572
|
"""
|
|
@@ -579,8 +603,6 @@ class Agent(HaystackAgent):
|
|
|
579
603
|
:raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
|
|
580
604
|
:raises BreakpointException: If an agent breakpoint is triggered.
|
|
581
605
|
"""
|
|
582
|
-
# We pop parent_snapshot from kwargs to avoid passing it into State.
|
|
583
|
-
parent_snapshot = kwargs.pop("parent_snapshot", None)
|
|
584
606
|
agent_inputs = {
|
|
585
607
|
"messages": messages,
|
|
586
608
|
"streaming_callback": streaming_callback,
|
|
@@ -588,13 +610,7 @@ class Agent(HaystackAgent):
|
|
|
588
610
|
"snapshot": snapshot,
|
|
589
611
|
**kwargs,
|
|
590
612
|
}
|
|
591
|
-
|
|
592
|
-
# _runtime_checks. This change will be released in Haystack 2.20.0.
|
|
593
|
-
# To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
|
|
594
|
-
if len(inspect.signature(self._runtime_checks).parameters) == 2:
|
|
595
|
-
self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args
|
|
596
|
-
else:
|
|
597
|
-
self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter
|
|
613
|
+
self._runtime_checks(break_point)
|
|
598
614
|
|
|
599
615
|
if snapshot:
|
|
600
616
|
exe_context = self._initialize_from_snapshot(
|
|
@@ -622,26 +638,39 @@ class Agent(HaystackAgent):
|
|
|
622
638
|
span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
|
|
623
639
|
|
|
624
640
|
while exe_context.counter < self.max_agent_steps:
|
|
625
|
-
# Handle breakpoint and ChatGenerator call
|
|
626
|
-
self._check_chat_generator_breakpoint(
|
|
627
|
-
execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
|
|
628
|
-
)
|
|
629
641
|
# We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
|
|
630
642
|
if exe_context.skip_chat_generator:
|
|
631
643
|
llm_messages = exe_context.state.get("messages", [])[-1:]
|
|
632
644
|
# Set to False so the next iteration will call the chat generator
|
|
633
645
|
exe_context.skip_chat_generator = False
|
|
634
646
|
else:
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
647
|
+
try:
|
|
648
|
+
result = await AsyncPipeline._run_component_async(
|
|
649
|
+
component_name="chat_generator",
|
|
650
|
+
component={"instance": self.chat_generator},
|
|
651
|
+
component_inputs={
|
|
652
|
+
"messages": exe_context.state.data["messages"],
|
|
653
|
+
**exe_context.chat_generator_inputs,
|
|
654
|
+
},
|
|
655
|
+
component_visits=exe_context.component_visits,
|
|
656
|
+
parent_span=span,
|
|
657
|
+
break_point=break_point.break_point if isinstance(break_point, AgentBreakpoint) else None,
|
|
658
|
+
)
|
|
659
|
+
except BreakpointException as e:
|
|
660
|
+
e.pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
|
|
661
|
+
agent_name=break_point.agent_name if break_point else None,
|
|
662
|
+
execution_context=exe_context,
|
|
663
|
+
break_point=break_point,
|
|
664
|
+
)
|
|
665
|
+
e._break_point = e.pipeline_snapshot.break_point
|
|
666
|
+
# We check if the agent is part of a pipeline by checking for __component_name__
|
|
667
|
+
# If it is not in a pipeline, we save the snapshot to a file.
|
|
668
|
+
in_pipeline = getattr(self, "__component_name__", None) is not None
|
|
669
|
+
if not in_pipeline:
|
|
670
|
+
full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
|
|
671
|
+
e.pipeline_snapshot_file_path = full_file_path
|
|
672
|
+
raise e
|
|
673
|
+
|
|
645
674
|
llm_messages = result["replies"]
|
|
646
675
|
exe_context.state.set("messages", llm_messages)
|
|
647
676
|
|
|
@@ -650,6 +679,19 @@ class Agent(HaystackAgent):
|
|
|
650
679
|
exe_context.counter += 1
|
|
651
680
|
break
|
|
652
681
|
|
|
682
|
+
# We only pass down the breakpoint if the tool name matches the tool call in the LLM messages
|
|
683
|
+
resolved_break_point = None
|
|
684
|
+
break_point_to_pass = None
|
|
685
|
+
if (
|
|
686
|
+
break_point
|
|
687
|
+
and isinstance(break_point.break_point, ToolBreakpoint)
|
|
688
|
+
and _should_trigger_tool_invoker_breakpoint(
|
|
689
|
+
break_point=break_point.break_point, llm_messages=llm_messages
|
|
690
|
+
)
|
|
691
|
+
):
|
|
692
|
+
resolved_break_point = break_point
|
|
693
|
+
break_point_to_pass = resolved_break_point.break_point
|
|
694
|
+
|
|
653
695
|
# Apply confirmation strategies and update State and messages sent to ToolInvoker
|
|
654
696
|
try:
|
|
655
697
|
# Run confirmation strategies to get updated tool call messages and modified chat history (async)
|
|
@@ -662,7 +704,7 @@ class Agent(HaystackAgent):
|
|
|
662
704
|
exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
|
|
663
705
|
except HITLBreakpointException as tbp_error:
|
|
664
706
|
# We create a break_point to pass into _check_tool_invoker_breakpoint
|
|
665
|
-
|
|
707
|
+
resolved_break_point = AgentBreakpoint(
|
|
666
708
|
agent_name=getattr(self, "__component_name__", ""),
|
|
667
709
|
break_point=ToolBreakpoint(
|
|
668
710
|
component_name="tool_invoker",
|
|
@@ -671,25 +713,38 @@ class Agent(HaystackAgent):
|
|
|
671
713
|
snapshot_file_path=tbp_error.snapshot_file_path,
|
|
672
714
|
),
|
|
673
715
|
)
|
|
716
|
+
break_point_to_pass = resolved_break_point.break_point
|
|
717
|
+
# If we hit a HITL breakpoint, we skip passing modified messages to ToolInvoker
|
|
718
|
+
modified_tool_call_messages = llm_messages
|
|
674
719
|
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
720
|
+
try:
|
|
721
|
+
# We only send the messages from the LLM to the tool invoker
|
|
722
|
+
tool_invoker_result = await AsyncPipeline._run_component_async(
|
|
723
|
+
component_name="tool_invoker",
|
|
724
|
+
component={"instance": self._tool_invoker},
|
|
725
|
+
component_inputs={
|
|
726
|
+
"messages": modified_tool_call_messages,
|
|
727
|
+
"state": exe_context.state,
|
|
728
|
+
**exe_context.tool_invoker_inputs,
|
|
729
|
+
},
|
|
730
|
+
component_visits=exe_context.component_visits,
|
|
731
|
+
parent_span=span,
|
|
732
|
+
break_point=break_point_to_pass,
|
|
733
|
+
)
|
|
734
|
+
except BreakpointException as e:
|
|
735
|
+
e.pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
|
|
736
|
+
tool_name=e.break_point.tool_name if isinstance(e.break_point, ToolBreakpoint) else None,
|
|
737
|
+
agent_name=resolved_break_point.agent_name if resolved_break_point else None,
|
|
738
|
+
execution_context=exe_context,
|
|
739
|
+
break_point=resolved_break_point,
|
|
740
|
+
)
|
|
741
|
+
e._break_point = e.pipeline_snapshot.break_point
|
|
742
|
+
# If Agent is not in a pipeline, we save the snapshot to a file.
|
|
743
|
+
# Checked by __component_name__ not being set.
|
|
744
|
+
if getattr(self, "__component_name__", None) is None:
|
|
745
|
+
full_file_path = _save_pipeline_snapshot(pipeline_snapshot=e.pipeline_snapshot)
|
|
746
|
+
e.pipeline_snapshot_file_path = full_file_path
|
|
747
|
+
raise e
|
|
693
748
|
|
|
694
749
|
# Set execution context tool execution decisions to empty after applying them b/c they should only
|
|
695
750
|
# be used once for the current tool calls
|
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
|
|
5
7
|
from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint
|
|
6
8
|
from haystack.utils import _deserialize_value_with_schema
|
|
7
9
|
|
|
@@ -31,7 +33,7 @@ def get_tool_calls_and_descriptions_from_snapshot(
|
|
|
31
33
|
tool_caused_break_point = break_point.tool_name
|
|
32
34
|
|
|
33
35
|
# Deserialize the tool invoker inputs from the snapshot
|
|
34
|
-
tool_invoker_inputs = _deserialize_value_with_schema(agent_snapshot.component_inputs["tool_invoker"])
|
|
36
|
+
tool_invoker_inputs = _deserialize_value_with_schema(deepcopy(agent_snapshot.component_inputs["tool_invoker"]))
|
|
35
37
|
tool_call_messages = tool_invoker_inputs["messages"]
|
|
36
38
|
state = tool_invoker_inputs["state"]
|
|
37
39
|
tool_name_to_tool = {t.name: t for t in tool_invoker_inputs["tools"]}
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
from dataclasses import asdict, dataclass
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
@dataclass
|
|
@@ -23,8 +23,8 @@ class ConfirmationUIResult:
|
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
25
|
action: str # "confirm", "reject", "modify"
|
|
26
|
-
feedback:
|
|
27
|
-
new_tool_params:
|
|
26
|
+
feedback: str | None = None
|
|
27
|
+
new_tool_params: dict[str, Any] | None = None
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@dataclass
|
|
@@ -49,9 +49,9 @@ class ToolExecutionDecision:
|
|
|
49
49
|
|
|
50
50
|
tool_name: str
|
|
51
51
|
execute: bool
|
|
52
|
-
tool_call_id:
|
|
53
|
-
feedback:
|
|
54
|
-
final_tool_params:
|
|
52
|
+
tool_call_id: str | None = None
|
|
53
|
+
feedback: str | None = None
|
|
54
|
+
final_tool_params: dict[str, Any] | None = None
|
|
55
55
|
|
|
56
56
|
def to_dict(self) -> dict[str, Any]:
|
|
57
57
|
"""
|
|
@@ -2,17 +2,13 @@
|
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
-
from typing import Optional
|
|
6
|
-
|
|
7
5
|
|
|
8
6
|
class HITLBreakpointException(Exception):
|
|
9
7
|
"""
|
|
10
8
|
Exception raised when a tool execution is paused by a ConfirmationStrategy (e.g. BreakpointConfirmationStrategy).
|
|
11
9
|
"""
|
|
12
10
|
|
|
13
|
-
def __init__(
|
|
14
|
-
self, message: str, tool_name: str, snapshot_file_path: str, tool_call_id: Optional[str] = None
|
|
15
|
-
) -> None:
|
|
11
|
+
def __init__(self, message: str, tool_name: str, snapshot_file_path: str, tool_call_id: str | None = None) -> None:
|
|
16
12
|
"""
|
|
17
13
|
Initialize the HITLBreakpointException.
|
|
18
14
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
from dataclasses import replace
|
|
6
|
-
from typing import TYPE_CHECKING, Any
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from haystack.components.agents.state import State
|
|
9
9
|
from haystack.components.tools.tool_invoker import ToolInvoker
|
|
@@ -52,8 +52,8 @@ class BlockingConfirmationStrategy:
|
|
|
52
52
|
tool_name: str,
|
|
53
53
|
tool_description: str,
|
|
54
54
|
tool_params: dict[str, Any],
|
|
55
|
-
tool_call_id:
|
|
56
|
-
confirmation_strategy_context:
|
|
55
|
+
tool_call_id: str | None = None,
|
|
56
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
57
57
|
) -> ToolExecutionDecision:
|
|
58
58
|
"""
|
|
59
59
|
Run the human-in-the-loop strategy for a given tool and its parameters.
|
|
@@ -125,8 +125,8 @@ class BlockingConfirmationStrategy:
|
|
|
125
125
|
tool_name: str,
|
|
126
126
|
tool_description: str,
|
|
127
127
|
tool_params: dict[str, Any],
|
|
128
|
-
tool_call_id:
|
|
129
|
-
confirmation_strategy_context:
|
|
128
|
+
tool_call_id: str | None = None,
|
|
129
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
130
130
|
) -> ToolExecutionDecision:
|
|
131
131
|
"""
|
|
132
132
|
Async version of run. Calls the sync run() method by default.
|
|
@@ -210,8 +210,8 @@ class BreakpointConfirmationStrategy:
|
|
|
210
210
|
tool_name: str,
|
|
211
211
|
tool_description: str,
|
|
212
212
|
tool_params: dict[str, Any],
|
|
213
|
-
tool_call_id:
|
|
214
|
-
confirmation_strategy_context:
|
|
213
|
+
tool_call_id: str | None = None,
|
|
214
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
215
215
|
) -> ToolExecutionDecision:
|
|
216
216
|
"""
|
|
217
217
|
Run the breakpoint confirmation strategy for a given tool and its parameters.
|
|
@@ -248,8 +248,8 @@ class BreakpointConfirmationStrategy:
|
|
|
248
248
|
tool_name: str,
|
|
249
249
|
tool_description: str,
|
|
250
250
|
tool_params: dict[str, Any],
|
|
251
|
-
tool_call_id:
|
|
252
|
-
confirmation_strategy_context:
|
|
251
|
+
tool_call_id: str | None = None,
|
|
252
|
+
confirmation_strategy_context: dict[str, Any] | None = None,
|
|
253
253
|
) -> ToolExecutionDecision:
|
|
254
254
|
"""
|
|
255
255
|
Async version of run. Calls the sync run() method.
|
|
@@ -304,7 +304,7 @@ def _prepare_tool_args(
|
|
|
304
304
|
tool: Tool,
|
|
305
305
|
tool_call_arguments: dict[str, Any],
|
|
306
306
|
state: State,
|
|
307
|
-
streaming_callback:
|
|
307
|
+
streaming_callback: StreamingCallbackT | None = None,
|
|
308
308
|
enable_streaming_passthrough: bool = False,
|
|
309
309
|
) -> dict[str, Any]:
|
|
310
310
|
"""
|