haystack-experimental 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/chat_message_stores/__init__.py +1 -1
- haystack_experimental/chat_message_stores/in_memory.py +176 -31
- haystack_experimental/chat_message_stores/types.py +33 -21
- haystack_experimental/components/agents/agent.py +184 -35
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +220 -3
- haystack_experimental/components/agents/human_in_the_loop/types.py +36 -1
- haystack_experimental/components/embedders/types/protocol.py +2 -2
- haystack_experimental/components/preprocessors/__init__.py +2 -0
- haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +16 -16
- haystack_experimental/components/preprocessors/md_header_level_inferrer.py +2 -2
- haystack_experimental/components/retrievers/__init__.py +1 -3
- haystack_experimental/components/retrievers/chat_message_retriever.py +57 -26
- haystack_experimental/components/writers/__init__.py +1 -1
- haystack_experimental/components/writers/chat_message_writer.py +25 -22
- haystack_experimental/core/pipeline/breakpoint.py +5 -3
- {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/METADATA +24 -31
- {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/RECORD +20 -27
- {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/WHEEL +1 -1
- haystack_experimental/components/query/__init__.py +0 -18
- haystack_experimental/components/query/query_expander.py +0 -299
- haystack_experimental/components/retrievers/multi_query_embedding_retriever.py +0 -180
- haystack_experimental/components/retrievers/multi_query_text_retriever.py +0 -158
- haystack_experimental/super_components/__init__.py +0 -3
- haystack_experimental/super_components/indexers/__init__.py +0 -11
- haystack_experimental/super_components/indexers/sentence_transformers_document_indexer.py +0 -199
- {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
@@ -47,7 +47,13 @@ class BlockingConfirmationStrategy:
|
|
|
47
47
|
self.confirmation_ui = confirmation_ui
|
|
48
48
|
|
|
49
49
|
def run(
|
|
50
|
-
self,
|
|
50
|
+
self,
|
|
51
|
+
*,
|
|
52
|
+
tool_name: str,
|
|
53
|
+
tool_description: str,
|
|
54
|
+
tool_params: dict[str, Any],
|
|
55
|
+
tool_call_id: Optional[str] = None,
|
|
56
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
51
57
|
) -> ToolExecutionDecision:
|
|
52
58
|
"""
|
|
53
59
|
Run the human-in-the-loop strategy for a given tool and its parameters.
|
|
@@ -61,6 +67,10 @@ class BlockingConfirmationStrategy:
|
|
|
61
67
|
:param tool_call_id:
|
|
62
68
|
Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
|
|
63
69
|
specific tool invocation.
|
|
70
|
+
:param confirmation_strategy_context:
|
|
71
|
+
Optional dictionary for passing request-scoped resources. Useful in web/server environments
|
|
72
|
+
to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients)
|
|
73
|
+
that strategies can use for non-blocking user interaction.
|
|
64
74
|
|
|
65
75
|
:returns:
|
|
66
76
|
A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
|
|
@@ -109,6 +119,40 @@ class BlockingConfirmationStrategy:
|
|
|
109
119
|
tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
|
|
110
120
|
)
|
|
111
121
|
|
|
122
|
+
async def run_async(
|
|
123
|
+
self,
|
|
124
|
+
*,
|
|
125
|
+
tool_name: str,
|
|
126
|
+
tool_description: str,
|
|
127
|
+
tool_params: dict[str, Any],
|
|
128
|
+
tool_call_id: Optional[str] = None,
|
|
129
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
130
|
+
) -> ToolExecutionDecision:
|
|
131
|
+
"""
|
|
132
|
+
Async version of run. Calls the sync run() method by default.
|
|
133
|
+
|
|
134
|
+
:param tool_name:
|
|
135
|
+
The name of the tool to be executed.
|
|
136
|
+
:param tool_description:
|
|
137
|
+
The description of the tool.
|
|
138
|
+
:param tool_params:
|
|
139
|
+
The parameters to be passed to the tool.
|
|
140
|
+
:param tool_call_id:
|
|
141
|
+
Optional unique identifier for the tool call.
|
|
142
|
+
:param confirmation_strategy_context:
|
|
143
|
+
Optional dictionary for passing request-scoped resources.
|
|
144
|
+
|
|
145
|
+
:returns:
|
|
146
|
+
A ToolExecutionDecision indicating whether to execute the tool with the given parameters.
|
|
147
|
+
"""
|
|
148
|
+
return self.run(
|
|
149
|
+
tool_name=tool_name,
|
|
150
|
+
tool_description=tool_description,
|
|
151
|
+
tool_params=tool_params,
|
|
152
|
+
tool_call_id=tool_call_id,
|
|
153
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
154
|
+
)
|
|
155
|
+
|
|
112
156
|
def to_dict(self) -> dict[str, Any]:
|
|
113
157
|
"""
|
|
114
158
|
Serializes the BlockingConfirmationStrategy to a dictionary.
|
|
@@ -161,7 +205,13 @@ class BreakpointConfirmationStrategy:
|
|
|
161
205
|
self.snapshot_file_path = snapshot_file_path
|
|
162
206
|
|
|
163
207
|
def run(
|
|
164
|
-
self,
|
|
208
|
+
self,
|
|
209
|
+
*,
|
|
210
|
+
tool_name: str,
|
|
211
|
+
tool_description: str,
|
|
212
|
+
tool_params: dict[str, Any],
|
|
213
|
+
tool_call_id: Optional[str] = None,
|
|
214
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
165
215
|
) -> ToolExecutionDecision:
|
|
166
216
|
"""
|
|
167
217
|
Run the breakpoint confirmation strategy for a given tool and its parameters.
|
|
@@ -175,6 +225,9 @@ class BreakpointConfirmationStrategy:
|
|
|
175
225
|
:param tool_call_id:
|
|
176
226
|
Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
|
|
177
227
|
specific tool invocation.
|
|
228
|
+
:param confirmation_strategy_context:
|
|
229
|
+
Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
|
|
230
|
+
interface compatibility.
|
|
178
231
|
|
|
179
232
|
:raises HITLBreakpointException:
|
|
180
233
|
Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
|
|
@@ -189,6 +242,43 @@ class BreakpointConfirmationStrategy:
|
|
|
189
242
|
snapshot_file_path=self.snapshot_file_path,
|
|
190
243
|
)
|
|
191
244
|
|
|
245
|
+
async def run_async(
|
|
246
|
+
self,
|
|
247
|
+
*,
|
|
248
|
+
tool_name: str,
|
|
249
|
+
tool_description: str,
|
|
250
|
+
tool_params: dict[str, Any],
|
|
251
|
+
tool_call_id: Optional[str] = None,
|
|
252
|
+
confirmation_strategy_context: Optional[dict[str, Any]] = None,
|
|
253
|
+
) -> ToolExecutionDecision:
|
|
254
|
+
"""
|
|
255
|
+
Async version of run. Calls the sync run() method.
|
|
256
|
+
|
|
257
|
+
:param tool_name:
|
|
258
|
+
The name of the tool to be executed.
|
|
259
|
+
:param tool_description:
|
|
260
|
+
The description of the tool.
|
|
261
|
+
:param tool_params:
|
|
262
|
+
The parameters to be passed to the tool.
|
|
263
|
+
:param tool_call_id:
|
|
264
|
+
Optional unique identifier for the tool call.
|
|
265
|
+
:param confirmation_strategy_context:
|
|
266
|
+
Optional dictionary for passing request-scoped resources.
|
|
267
|
+
|
|
268
|
+
:raises HITLBreakpointException:
|
|
269
|
+
Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
|
|
270
|
+
|
|
271
|
+
:returns:
|
|
272
|
+
This method does not return; it always raises an exception.
|
|
273
|
+
"""
|
|
274
|
+
return self.run(
|
|
275
|
+
tool_name=tool_name,
|
|
276
|
+
tool_description=tool_description,
|
|
277
|
+
tool_params=tool_params,
|
|
278
|
+
tool_call_id=tool_call_id,
|
|
279
|
+
confirmation_strategy_context=confirmation_strategy_context,
|
|
280
|
+
)
|
|
281
|
+
|
|
192
282
|
def to_dict(self) -> dict[str, Any]:
|
|
193
283
|
"""
|
|
194
284
|
Serializes the BreakpointConfirmationStrategy to a dictionary.
|
|
@@ -285,6 +375,46 @@ def _process_confirmation_strategies(
|
|
|
285
375
|
return modified_tool_call_messages, new_chat_history
|
|
286
376
|
|
|
287
377
|
|
|
378
|
+
async def _process_confirmation_strategies_async(
|
|
379
|
+
*,
|
|
380
|
+
confirmation_strategies: dict[str, ConfirmationStrategy],
|
|
381
|
+
messages_with_tool_calls: list[ChatMessage],
|
|
382
|
+
execution_context: "_ExecutionContext",
|
|
383
|
+
) -> tuple[list[ChatMessage], list[ChatMessage]]:
|
|
384
|
+
"""
|
|
385
|
+
Async version of _process_confirmation_strategies.
|
|
386
|
+
|
|
387
|
+
Run the confirmation strategies and return modified tool call messages and updated chat history.
|
|
388
|
+
|
|
389
|
+
:param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
|
|
390
|
+
:param messages_with_tool_calls: Chat messages containing tool calls
|
|
391
|
+
:param execution_context: The current execution context of the agent
|
|
392
|
+
:returns:
|
|
393
|
+
Tuple of modified messages with confirmed tool calls and updated chat history
|
|
394
|
+
"""
|
|
395
|
+
# Run confirmation strategies and get tool execution decisions (async version)
|
|
396
|
+
teds = await _run_confirmation_strategies_async(
|
|
397
|
+
confirmation_strategies=confirmation_strategies,
|
|
398
|
+
messages_with_tool_calls=messages_with_tool_calls,
|
|
399
|
+
execution_context=execution_context,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Apply tool execution decisions to messages_with_tool_calls
|
|
403
|
+
rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions(
|
|
404
|
+
tool_call_messages=messages_with_tool_calls,
|
|
405
|
+
tool_execution_decisions=teds,
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
# Update the chat history with rejection messages and new tool call messages
|
|
409
|
+
new_chat_history = _update_chat_history(
|
|
410
|
+
chat_history=execution_context.state.get("messages"),
|
|
411
|
+
rejection_messages=rejection_messages,
|
|
412
|
+
tool_call_and_explanation_messages=modified_tool_call_messages,
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
return modified_tool_call_messages, new_chat_history
|
|
416
|
+
|
|
417
|
+
|
|
288
418
|
def _run_confirmation_strategies(
|
|
289
419
|
confirmation_strategies: dict[str, ConfirmationStrategy],
|
|
290
420
|
messages_with_tool_calls: list[ChatMessage],
|
|
@@ -344,13 +474,100 @@ def _run_confirmation_strategies(
|
|
|
344
474
|
# If not, run the confirmation strategy
|
|
345
475
|
if not ted:
|
|
346
476
|
ted = confirmation_strategies[tool_name].run(
|
|
347
|
-
tool_name=tool_name,
|
|
477
|
+
tool_name=tool_name,
|
|
478
|
+
tool_description=tool_to_invoke.description,
|
|
479
|
+
tool_params=final_args,
|
|
480
|
+
tool_call_id=tool_call.id,
|
|
481
|
+
confirmation_strategy_context=execution_context.confirmation_strategy_context,
|
|
348
482
|
)
|
|
349
483
|
teds.append(ted)
|
|
350
484
|
|
|
351
485
|
return teds
|
|
352
486
|
|
|
353
487
|
|
|
488
|
+
async def _run_confirmation_strategies_async(
|
|
489
|
+
confirmation_strategies: dict[str, ConfirmationStrategy],
|
|
490
|
+
messages_with_tool_calls: list[ChatMessage],
|
|
491
|
+
execution_context: "_ExecutionContext",
|
|
492
|
+
) -> list[ToolExecutionDecision]:
|
|
493
|
+
"""
|
|
494
|
+
Async version of _run_confirmation_strategies.
|
|
495
|
+
|
|
496
|
+
Run confirmation strategies for tool calls in the provided chat messages.
|
|
497
|
+
|
|
498
|
+
:param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
|
|
499
|
+
:param messages_with_tool_calls: Messages containing tool calls to process
|
|
500
|
+
:param execution_context: The current execution context containing state and inputs
|
|
501
|
+
:returns:
|
|
502
|
+
A list of ToolExecutionDecision objects representing the decisions made for each tool call.
|
|
503
|
+
"""
|
|
504
|
+
state = execution_context.state
|
|
505
|
+
tools_with_names = {tool.name: tool for tool in execution_context.tool_invoker_inputs["tools"]}
|
|
506
|
+
existing_teds = execution_context.tool_execution_decisions if execution_context.tool_execution_decisions else []
|
|
507
|
+
existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name}
|
|
508
|
+
existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id}
|
|
509
|
+
|
|
510
|
+
teds = []
|
|
511
|
+
for message in messages_with_tool_calls:
|
|
512
|
+
if not message.tool_calls:
|
|
513
|
+
continue
|
|
514
|
+
|
|
515
|
+
for tool_call in message.tool_calls:
|
|
516
|
+
tool_name = tool_call.tool_name
|
|
517
|
+
tool_to_invoke = tools_with_names[tool_name]
|
|
518
|
+
|
|
519
|
+
# Prepare final tool args
|
|
520
|
+
final_args = _prepare_tool_args(
|
|
521
|
+
tool=tool_to_invoke,
|
|
522
|
+
tool_call_arguments=tool_call.arguments,
|
|
523
|
+
state=state,
|
|
524
|
+
streaming_callback=execution_context.tool_invoker_inputs.get("streaming_callback"),
|
|
525
|
+
enable_streaming_passthrough=execution_context.tool_invoker_inputs.get(
|
|
526
|
+
"enable_streaming_passthrough", False
|
|
527
|
+
),
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
# Get tool execution decisions from confirmation strategies
|
|
531
|
+
# If no confirmation strategy is defined for this tool, proceed with execution
|
|
532
|
+
if tool_name not in confirmation_strategies:
|
|
533
|
+
teds.append(
|
|
534
|
+
ToolExecutionDecision(
|
|
535
|
+
tool_call_id=tool_call.id,
|
|
536
|
+
tool_name=tool_name,
|
|
537
|
+
execute=True,
|
|
538
|
+
final_tool_params=final_args,
|
|
539
|
+
)
|
|
540
|
+
)
|
|
541
|
+
continue
|
|
542
|
+
|
|
543
|
+
# Check if there's already a decision for this tool call in the execution context
|
|
544
|
+
ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name)
|
|
545
|
+
|
|
546
|
+
# If not, run the confirmation strategy (async version)
|
|
547
|
+
if not ted:
|
|
548
|
+
strategy = confirmation_strategies[tool_name]
|
|
549
|
+
# Use run_async if available, otherwise fall back to sync run
|
|
550
|
+
if hasattr(strategy, "run_async"):
|
|
551
|
+
ted = await strategy.run_async(
|
|
552
|
+
tool_name=tool_name,
|
|
553
|
+
tool_description=tool_to_invoke.description,
|
|
554
|
+
tool_params=final_args,
|
|
555
|
+
tool_call_id=tool_call.id,
|
|
556
|
+
confirmation_strategy_context=execution_context.confirmation_strategy_context,
|
|
557
|
+
)
|
|
558
|
+
else:
|
|
559
|
+
ted = strategy.run(
|
|
560
|
+
tool_name=tool_name,
|
|
561
|
+
tool_description=tool_to_invoke.description,
|
|
562
|
+
tool_params=final_args,
|
|
563
|
+
tool_call_id=tool_call.id,
|
|
564
|
+
confirmation_strategy_context=execution_context.confirmation_strategy_context,
|
|
565
|
+
)
|
|
566
|
+
teds.append(ted)
|
|
567
|
+
|
|
568
|
+
return teds
|
|
569
|
+
|
|
570
|
+
|
|
354
571
|
def _apply_tool_execution_decisions(
|
|
355
572
|
tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
|
|
356
573
|
) -> tuple[list[ChatMessage], list[ChatMessage]]:
|
|
@@ -63,7 +63,12 @@ class ConfirmationPolicy(Protocol):
|
|
|
63
63
|
|
|
64
64
|
class ConfirmationStrategy(Protocol):
|
|
65
65
|
def run(
|
|
66
|
-
self,
|
|
66
|
+
self,
|
|
67
|
+
tool_name: str,
|
|
68
|
+
tool_description: str,
|
|
69
|
+
tool_params: dict[str, Any],
|
|
70
|
+
tool_call_id: Optional[str] = None,
|
|
71
|
+
**kwargs: Optional[dict[str, Any]],
|
|
67
72
|
) -> ToolExecutionDecision:
|
|
68
73
|
"""
|
|
69
74
|
Run the confirmation strategy for a given tool and its parameters.
|
|
@@ -73,6 +78,36 @@ class ConfirmationStrategy(Protocol):
|
|
|
73
78
|
:param tool_params: The parameters to be passed to the tool.
|
|
74
79
|
:param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
|
|
75
80
|
the decision with a specific tool invocation.
|
|
81
|
+
:param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context`
|
|
82
|
+
for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server
|
|
83
|
+
environments.
|
|
84
|
+
|
|
85
|
+
:returns:
|
|
86
|
+
The result of the confirmation strategy (e.g., tool output, rejection message, etc.).
|
|
87
|
+
"""
|
|
88
|
+
...
|
|
89
|
+
|
|
90
|
+
async def run_async(
|
|
91
|
+
self,
|
|
92
|
+
tool_name: str,
|
|
93
|
+
tool_description: str,
|
|
94
|
+
tool_params: dict[str, Any],
|
|
95
|
+
tool_call_id: Optional[str] = None,
|
|
96
|
+
**kwargs: Optional[dict[str, Any]],
|
|
97
|
+
) -> ToolExecutionDecision:
|
|
98
|
+
"""
|
|
99
|
+
Async version of run. Run the confirmation strategy for a given tool and its parameters.
|
|
100
|
+
|
|
101
|
+
Default implementation calls the sync run() method. Override for true async behavior.
|
|
102
|
+
|
|
103
|
+
:param tool_name: The name of the tool to be executed.
|
|
104
|
+
:param tool_description: The description of the tool.
|
|
105
|
+
:param tool_params: The parameters to be passed to the tool.
|
|
106
|
+
:param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
|
|
107
|
+
the decision with a specific tool invocation.
|
|
108
|
+
:param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context`
|
|
109
|
+
for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server
|
|
110
|
+
environments.
|
|
76
111
|
|
|
77
112
|
:returns:
|
|
78
113
|
The result of the confirmation strategy (e.g., tool output, rejection message, etc.).
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Protocol
|
|
6
6
|
|
|
7
7
|
from haystack import Document
|
|
8
8
|
|
|
@@ -15,7 +15,7 @@ class DocumentEmbedder(Protocol):
|
|
|
15
15
|
Protocol for Document Embedders.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
-
def run(self, documents:
|
|
18
|
+
def run(self, documents: list[Document]) -> dict[str, Any]:
|
|
19
19
|
"""
|
|
20
20
|
Generate embeddings for the input documents.
|
|
21
21
|
|
|
@@ -9,10 +9,12 @@ from lazy_imports import LazyImporter
|
|
|
9
9
|
|
|
10
10
|
_import_structure = {
|
|
11
11
|
"embedding_based_document_splitter": ["EmbeddingBasedDocumentSplitter"],
|
|
12
|
+
"md_header_level_inferrer": ["MarkdownHeaderLevelInferrer"],
|
|
12
13
|
}
|
|
13
14
|
|
|
14
15
|
if TYPE_CHECKING:
|
|
15
16
|
from .embedding_based_document_splitter import EmbeddingBasedDocumentSplitter
|
|
17
|
+
from .md_header_level_inferrer import MarkdownHeaderLevelInferrer
|
|
16
18
|
|
|
17
19
|
else:
|
|
18
20
|
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
from copy import deepcopy
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Optional
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from haystack import Document, component, logging
|
|
@@ -136,8 +136,8 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
136
136
|
self.document_embedder.warm_up()
|
|
137
137
|
self._is_warmed_up = True
|
|
138
138
|
|
|
139
|
-
@component.output_types(documents=
|
|
140
|
-
def run(self, documents:
|
|
139
|
+
@component.output_types(documents=list[Document])
|
|
140
|
+
def run(self, documents: list[Document]) -> dict[str, list[Document]]:
|
|
141
141
|
"""
|
|
142
142
|
Split documents based on embedding similarity.
|
|
143
143
|
|
|
@@ -162,7 +162,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
162
162
|
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
|
|
163
163
|
raise TypeError("EmbeddingBasedDocumentSplitter expects a List of Documents as input.")
|
|
164
164
|
|
|
165
|
-
split_docs:
|
|
165
|
+
split_docs: list[Document] = []
|
|
166
166
|
for doc in documents:
|
|
167
167
|
if doc.content is None:
|
|
168
168
|
raise ValueError(
|
|
@@ -178,7 +178,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
178
178
|
|
|
179
179
|
return {"documents": split_docs}
|
|
180
180
|
|
|
181
|
-
def _split_document(self, doc: Document) ->
|
|
181
|
+
def _split_document(self, doc: Document) -> list[Document]:
|
|
182
182
|
"""
|
|
183
183
|
Split a single document based on embedding similarity.
|
|
184
184
|
"""
|
|
@@ -194,7 +194,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
194
194
|
# Create Document objects from the final splits
|
|
195
195
|
return EmbeddingBasedDocumentSplitter._create_documents_from_splits(splits=final_splits, original_doc=doc)
|
|
196
196
|
|
|
197
|
-
def _split_text(self, text: str) ->
|
|
197
|
+
def _split_text(self, text: str) -> list[str]:
|
|
198
198
|
"""
|
|
199
199
|
Split a text into smaller chunks based on embedding similarity.
|
|
200
200
|
"""
|
|
@@ -221,7 +221,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
221
221
|
|
|
222
222
|
return sub_splits
|
|
223
223
|
|
|
224
|
-
def _group_sentences(self, sentences:
|
|
224
|
+
def _group_sentences(self, sentences: list[str]) -> list[str]:
|
|
225
225
|
"""
|
|
226
226
|
Group sentences into groups of sentences_per_group.
|
|
227
227
|
"""
|
|
@@ -235,7 +235,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
235
235
|
|
|
236
236
|
return groups
|
|
237
237
|
|
|
238
|
-
def _calculate_embeddings(self, sentence_groups:
|
|
238
|
+
def _calculate_embeddings(self, sentence_groups: list[str]) -> list[list[float]]:
|
|
239
239
|
"""
|
|
240
240
|
Calculate embeddings for each sentence group using the DocumentEmbedder.
|
|
241
241
|
"""
|
|
@@ -246,7 +246,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
246
246
|
embeddings = [doc.embedding for doc in embedded_docs]
|
|
247
247
|
return embeddings
|
|
248
248
|
|
|
249
|
-
def _find_split_points(self, embeddings:
|
|
249
|
+
def _find_split_points(self, embeddings: list[list[float]]) -> list[int]:
|
|
250
250
|
"""
|
|
251
251
|
Find split points based on cosine distances between sequential embeddings.
|
|
252
252
|
"""
|
|
@@ -273,7 +273,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
273
273
|
return split_points
|
|
274
274
|
|
|
275
275
|
@staticmethod
|
|
276
|
-
def _cosine_distance(embedding1:
|
|
276
|
+
def _cosine_distance(embedding1: list[float], embedding2: list[float]) -> float:
|
|
277
277
|
"""
|
|
278
278
|
Calculate cosine distance between two embeddings.
|
|
279
279
|
"""
|
|
@@ -291,7 +291,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
291
291
|
return 1.0 - cosine_sim
|
|
292
292
|
|
|
293
293
|
@staticmethod
|
|
294
|
-
def _create_splits_from_points(sentence_groups:
|
|
294
|
+
def _create_splits_from_points(sentence_groups: list[str], split_points: list[int]) -> list[str]:
|
|
295
295
|
"""
|
|
296
296
|
Create splits based on split points.
|
|
297
297
|
"""
|
|
@@ -315,7 +315,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
315
315
|
|
|
316
316
|
return splits
|
|
317
317
|
|
|
318
|
-
def _merge_small_splits(self, splits:
|
|
318
|
+
def _merge_small_splits(self, splits: list[str]) -> list[str]:
|
|
319
319
|
"""
|
|
320
320
|
Merge splits that are below min_length.
|
|
321
321
|
"""
|
|
@@ -341,7 +341,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
341
341
|
|
|
342
342
|
return merged
|
|
343
343
|
|
|
344
|
-
def _split_large_splits(self, splits:
|
|
344
|
+
def _split_large_splits(self, splits: list[str]) -> list[str]:
|
|
345
345
|
"""
|
|
346
346
|
Recursively split splits that are above max_length.
|
|
347
347
|
|
|
@@ -375,7 +375,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
375
375
|
return final_splits
|
|
376
376
|
|
|
377
377
|
@staticmethod
|
|
378
|
-
def _create_documents_from_splits(splits:
|
|
378
|
+
def _create_documents_from_splits(splits: list[str], original_doc: Document) -> list[Document]:
|
|
379
379
|
"""
|
|
380
380
|
Create Document objects from splits.
|
|
381
381
|
"""
|
|
@@ -405,7 +405,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
405
405
|
|
|
406
406
|
return documents
|
|
407
407
|
|
|
408
|
-
def to_dict(self) ->
|
|
408
|
+
def to_dict(self) -> dict[str, Any]:
|
|
409
409
|
"""
|
|
410
410
|
Serializes the component to a dictionary.
|
|
411
411
|
"""
|
|
@@ -422,7 +422,7 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
422
422
|
)
|
|
423
423
|
|
|
424
424
|
@classmethod
|
|
425
|
-
def from_dict(cls, data:
|
|
425
|
+
def from_dict(cls, data: dict[str, Any]) -> "EmbeddingBasedDocumentSplitter":
|
|
426
426
|
"""
|
|
427
427
|
Deserializes the component from a dictionary.
|
|
428
428
|
"""
|
|
@@ -24,7 +24,7 @@ class MarkdownHeaderLevelInferrer:
|
|
|
24
24
|
from haystack_experimental.components.preprocessors import MarkdownHeaderLevelInferrer
|
|
25
25
|
|
|
26
26
|
# Create a document with uniform header levels
|
|
27
|
-
text = "## Title\
|
|
27
|
+
text = "## Title\n## Subheader\nSection\n## Subheader\nMore Content"
|
|
28
28
|
doc = Document(content=text)
|
|
29
29
|
|
|
30
30
|
# Initialize the inferrer and process the document
|
|
@@ -33,7 +33,7 @@ class MarkdownHeaderLevelInferrer:
|
|
|
33
33
|
|
|
34
34
|
# The headers are now normalized with proper hierarchy
|
|
35
35
|
print(result["documents"][0].content)
|
|
36
|
-
> # Title\
|
|
36
|
+
> # Title\n## Subheader\nSection\n## Subheader\nMore Content
|
|
37
37
|
```
|
|
38
38
|
"""
|
|
39
39
|
|
|
@@ -3,7 +3,5 @@
|
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
from haystack_experimental.components.retrievers.chat_message_retriever import ChatMessageRetriever
|
|
6
|
-
from haystack_experimental.components.retrievers.multi_query_embedding_retriever import MultiQueryEmbeddingRetriever
|
|
7
|
-
from haystack_experimental.components.retrievers.multi_query_text_retriever import MultiQueryTextRetriever
|
|
8
6
|
|
|
9
|
-
_all_ = ["ChatMessageRetriever"
|
|
7
|
+
_all_ = ["ChatMessageRetriever"]
|
|
@@ -2,11 +2,11 @@
|
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Optional
|
|
6
6
|
|
|
7
7
|
from haystack import DeserializationError, component, default_from_dict, default_to_dict, logging
|
|
8
8
|
from haystack.core.serialization import import_class_by_name
|
|
9
|
-
from haystack.dataclasses import ChatMessage
|
|
9
|
+
from haystack.dataclasses import ChatMessage, ChatRole
|
|
10
10
|
|
|
11
11
|
from haystack_experimental.chat_message_stores.types import ChatMessageStore
|
|
12
12
|
|
|
@@ -30,41 +30,40 @@ class ChatMessageRetriever:
|
|
|
30
30
|
]
|
|
31
31
|
|
|
32
32
|
message_store = InMemoryChatMessageStore()
|
|
33
|
-
message_store.write_messages(messages)
|
|
33
|
+
message_store.write_messages(chat_history_id="user_456_session_123", messages=messages)
|
|
34
34
|
retriever = ChatMessageRetriever(message_store)
|
|
35
35
|
|
|
36
|
-
result = retriever.run()
|
|
36
|
+
result = retriever.run(chat_history_id="user_456_session_123")
|
|
37
37
|
|
|
38
38
|
print(result["messages"])
|
|
39
39
|
```
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
-
def __init__(self,
|
|
42
|
+
def __init__(self, chat_message_store: ChatMessageStore, last_k: Optional[int] = 10):
|
|
43
43
|
"""
|
|
44
44
|
Create the ChatMessageRetriever component.
|
|
45
45
|
|
|
46
|
-
:param
|
|
46
|
+
:param chat_message_store:
|
|
47
47
|
An instance of a ChatMessageStore.
|
|
48
48
|
:param last_k:
|
|
49
49
|
The number of last messages to retrieve. Defaults to 10 messages if not specified.
|
|
50
50
|
"""
|
|
51
|
-
self.
|
|
52
|
-
if last_k <= 0:
|
|
53
|
-
raise ValueError(f"last_k must be greater than 0. Currently,
|
|
51
|
+
self.chat_message_store = chat_message_store
|
|
52
|
+
if last_k and last_k <= 0:
|
|
53
|
+
raise ValueError(f"last_k must be greater than 0. Currently, last_k is {last_k}")
|
|
54
54
|
self.last_k = last_k
|
|
55
55
|
|
|
56
|
-
def to_dict(self) ->
|
|
56
|
+
def to_dict(self) -> dict[str, Any]:
|
|
57
57
|
"""
|
|
58
58
|
Serializes the component to a dictionary.
|
|
59
59
|
|
|
60
60
|
:returns:
|
|
61
61
|
Dictionary with serialized data.
|
|
62
62
|
"""
|
|
63
|
-
|
|
64
|
-
return default_to_dict(self, message_store=message_store, last_k=self.last_k)
|
|
63
|
+
return default_to_dict(self, chat_message_store=self.chat_message_store.to_dict(), last_k=self.last_k)
|
|
65
64
|
|
|
66
65
|
@classmethod
|
|
67
|
-
def from_dict(cls, data:
|
|
66
|
+
def from_dict(cls, data: dict[str, Any]) -> "ChatMessageRetriever":
|
|
68
67
|
"""
|
|
69
68
|
Deserializes the component from a dictionary.
|
|
70
69
|
|
|
@@ -74,35 +73,67 @@ class ChatMessageRetriever:
|
|
|
74
73
|
The deserialized component.
|
|
75
74
|
"""
|
|
76
75
|
init_params = data.get("init_parameters", {})
|
|
77
|
-
if "
|
|
78
|
-
raise DeserializationError("Missing '
|
|
79
|
-
if "type" not in init_params["
|
|
76
|
+
if "chat_message_store" not in init_params:
|
|
77
|
+
raise DeserializationError("Missing 'chat_message_store' in serialization data")
|
|
78
|
+
if "type" not in init_params["chat_message_store"]:
|
|
80
79
|
raise DeserializationError("Missing 'type' in message store's serialization data")
|
|
81
80
|
|
|
82
|
-
message_store_data = init_params["
|
|
81
|
+
message_store_data = init_params["chat_message_store"]
|
|
83
82
|
try:
|
|
84
83
|
message_store_class = import_class_by_name(message_store_data["type"])
|
|
85
84
|
except ImportError as e:
|
|
86
85
|
raise DeserializationError(f"Class '{message_store_data['type']}' not correctly imported") from e
|
|
86
|
+
if not hasattr(message_store_class, "from_dict"):
|
|
87
|
+
raise DeserializationError(f"{message_store_class} does not have from_dict method implemented.")
|
|
88
|
+
init_params["chat_message_store"] = message_store_class.from_dict(message_store_data)
|
|
87
89
|
|
|
88
|
-
data["init_parameters"]["message_store"] = default_from_dict(message_store_class, message_store_data)
|
|
89
90
|
return default_from_dict(cls, data)
|
|
90
91
|
|
|
91
|
-
@component.output_types(messages=
|
|
92
|
-
def run(
|
|
92
|
+
@component.output_types(messages=list[ChatMessage])
|
|
93
|
+
def run(
|
|
94
|
+
self,
|
|
95
|
+
chat_history_id: str,
|
|
96
|
+
*,
|
|
97
|
+
last_k: Optional[int] = None,
|
|
98
|
+
current_messages: Optional[list[ChatMessage]] = None,
|
|
99
|
+
) -> dict[str, list[ChatMessage]]:
|
|
93
100
|
"""
|
|
94
101
|
Run the ChatMessageRetriever
|
|
95
102
|
|
|
103
|
+
:param chat_history_id:
|
|
104
|
+
A unique identifier for the chat session or conversation whose messages should be retrieved.
|
|
105
|
+
Each `chat_history_id` corresponds to a distinct chat history stored in the underlying ChatMessageStore.
|
|
106
|
+
For example, use a session ID or conversation ID to isolate messages from different chat sessions.
|
|
96
107
|
:param last_k: The number of last messages to retrieve. This parameter takes precedence over the last_k
|
|
97
108
|
parameter passed to the ChatMessageRetriever constructor. If unspecified, the last_k parameter passed
|
|
98
109
|
to the constructor will be used.
|
|
110
|
+
:param current_messages:
|
|
111
|
+
A list of incoming chat messages to combine with the retrieved messages. System messages from this list
|
|
112
|
+
are prepended before the retrieved history, while all other messages (e.g., user messages) are appended
|
|
113
|
+
after. This is useful for including new conversational context alongside stored history so the output
|
|
114
|
+
can be directly used as input to a ChatGenerator or an Agent. If not provided, only the stored messages
|
|
115
|
+
will be returned.
|
|
116
|
+
|
|
99
117
|
:returns:
|
|
100
|
-
|
|
101
|
-
|
|
118
|
+
A dictionary with the following key:
|
|
119
|
+
- `messages` - The retrieved chat messages combined with any provided current messages.
|
|
120
|
+
:raises ValueError: If last_k is not None and is less than 0.
|
|
102
121
|
"""
|
|
103
|
-
if last_k is not None and last_k
|
|
104
|
-
raise ValueError("last_k must be
|
|
122
|
+
if last_k is not None and last_k < 0:
|
|
123
|
+
raise ValueError("last_k must be 0 or greater")
|
|
124
|
+
|
|
125
|
+
resolved_last_k = last_k or self.last_k
|
|
126
|
+
if resolved_last_k == 0:
|
|
127
|
+
return {"messages": current_messages or []}
|
|
128
|
+
|
|
129
|
+
retrieved_messages = self.chat_message_store.retrieve_messages(
|
|
130
|
+
chat_history_id=chat_history_id, last_k=last_k or self.last_k
|
|
131
|
+
)
|
|
105
132
|
|
|
106
|
-
|
|
133
|
+
if not current_messages:
|
|
134
|
+
return {"messages": retrieved_messages}
|
|
107
135
|
|
|
108
|
-
|
|
136
|
+
# We maintain the order: system messages first, then stored messages, then new user messages
|
|
137
|
+
system_messages = [msg for msg in current_messages if msg.is_from(ChatRole.SYSTEM)]
|
|
138
|
+
other_messages = [msg for msg in current_messages if not msg.is_from(ChatRole.SYSTEM)]
|
|
139
|
+
return {"messages": system_messages + retrieved_messages + other_messages}
|