haystack-experimental 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27):
  1. haystack_experimental/chat_message_stores/__init__.py +1 -1
  2. haystack_experimental/chat_message_stores/in_memory.py +176 -31
  3. haystack_experimental/chat_message_stores/types.py +33 -21
  4. haystack_experimental/components/agents/agent.py +184 -35
  5. haystack_experimental/components/agents/human_in_the_loop/strategies.py +220 -3
  6. haystack_experimental/components/agents/human_in_the_loop/types.py +36 -1
  7. haystack_experimental/components/embedders/types/protocol.py +2 -2
  8. haystack_experimental/components/preprocessors/__init__.py +2 -0
  9. haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +16 -16
  10. haystack_experimental/components/preprocessors/md_header_level_inferrer.py +2 -2
  11. haystack_experimental/components/retrievers/__init__.py +1 -3
  12. haystack_experimental/components/retrievers/chat_message_retriever.py +57 -26
  13. haystack_experimental/components/writers/__init__.py +1 -1
  14. haystack_experimental/components/writers/chat_message_writer.py +25 -22
  15. haystack_experimental/core/pipeline/breakpoint.py +5 -3
  16. {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/METADATA +24 -31
  17. {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/RECORD +20 -27
  18. {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/WHEEL +1 -1
  19. haystack_experimental/components/query/__init__.py +0 -18
  20. haystack_experimental/components/query/query_expander.py +0 -299
  21. haystack_experimental/components/retrievers/multi_query_embedding_retriever.py +0 -180
  22. haystack_experimental/components/retrievers/multi_query_text_retriever.py +0 -158
  23. haystack_experimental/super_components/__init__.py +0 -3
  24. haystack_experimental/super_components/indexers/__init__.py +0 -11
  25. haystack_experimental/super_components/indexers/sentence_transformers_document_indexer.py +0 -199
  26. {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/licenses/LICENSE +0 -0
  27. {haystack_experimental-0.14.2.dist-info → haystack_experimental-0.15.0.dist-info}/licenses/LICENSE-MIT.txt +0 -0
@@ -47,7 +47,13 @@ class BlockingConfirmationStrategy:
47
47
  self.confirmation_ui = confirmation_ui
48
48
 
49
49
  def run(
50
- self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
50
+ self,
51
+ *,
52
+ tool_name: str,
53
+ tool_description: str,
54
+ tool_params: dict[str, Any],
55
+ tool_call_id: Optional[str] = None,
56
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
51
57
  ) -> ToolExecutionDecision:
52
58
  """
53
59
  Run the human-in-the-loop strategy for a given tool and its parameters.
@@ -61,6 +67,10 @@ class BlockingConfirmationStrategy:
61
67
  :param tool_call_id:
62
68
  Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
63
69
  specific tool invocation.
70
+ :param confirmation_strategy_context:
71
+ Optional dictionary for passing request-scoped resources. Useful in web/server environments
72
+ to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients)
73
+ that strategies can use for non-blocking user interaction.
64
74
 
65
75
  :returns:
66
76
  A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
@@ -109,6 +119,40 @@ class BlockingConfirmationStrategy:
109
119
  tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
110
120
  )
111
121
 
122
+ async def run_async(
123
+ self,
124
+ *,
125
+ tool_name: str,
126
+ tool_description: str,
127
+ tool_params: dict[str, Any],
128
+ tool_call_id: Optional[str] = None,
129
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
130
+ ) -> ToolExecutionDecision:
131
+ """
132
+ Async version of run. Calls the sync run() method by default.
133
+
134
+ :param tool_name:
135
+ The name of the tool to be executed.
136
+ :param tool_description:
137
+ The description of the tool.
138
+ :param tool_params:
139
+ The parameters to be passed to the tool.
140
+ :param tool_call_id:
141
+ Optional unique identifier for the tool call.
142
+ :param confirmation_strategy_context:
143
+ Optional dictionary for passing request-scoped resources.
144
+
145
+ :returns:
146
+ A ToolExecutionDecision indicating whether to execute the tool with the given parameters.
147
+ """
148
+ return self.run(
149
+ tool_name=tool_name,
150
+ tool_description=tool_description,
151
+ tool_params=tool_params,
152
+ tool_call_id=tool_call_id,
153
+ confirmation_strategy_context=confirmation_strategy_context,
154
+ )
155
+
112
156
  def to_dict(self) -> dict[str, Any]:
113
157
  """
114
158
  Serializes the BlockingConfirmationStrategy to a dictionary.
@@ -161,7 +205,13 @@ class BreakpointConfirmationStrategy:
161
205
  self.snapshot_file_path = snapshot_file_path
162
206
 
163
207
  def run(
164
- self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
208
+ self,
209
+ *,
210
+ tool_name: str,
211
+ tool_description: str,
212
+ tool_params: dict[str, Any],
213
+ tool_call_id: Optional[str] = None,
214
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
165
215
  ) -> ToolExecutionDecision:
166
216
  """
167
217
  Run the breakpoint confirmation strategy for a given tool and its parameters.
@@ -175,6 +225,9 @@ class BreakpointConfirmationStrategy:
175
225
  :param tool_call_id:
176
226
  Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
177
227
  specific tool invocation.
228
+ :param confirmation_strategy_context:
229
+ Optional dictionary for passing request-scoped resources. Not used by this strategy but included for
230
+ interface compatibility.
178
231
 
179
232
  :raises HITLBreakpointException:
180
233
  Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
@@ -189,6 +242,43 @@ class BreakpointConfirmationStrategy:
189
242
  snapshot_file_path=self.snapshot_file_path,
190
243
  )
191
244
 
245
+ async def run_async(
246
+ self,
247
+ *,
248
+ tool_name: str,
249
+ tool_description: str,
250
+ tool_params: dict[str, Any],
251
+ tool_call_id: Optional[str] = None,
252
+ confirmation_strategy_context: Optional[dict[str, Any]] = None,
253
+ ) -> ToolExecutionDecision:
254
+ """
255
+ Async version of run. Calls the sync run() method.
256
+
257
+ :param tool_name:
258
+ The name of the tool to be executed.
259
+ :param tool_description:
260
+ The description of the tool.
261
+ :param tool_params:
262
+ The parameters to be passed to the tool.
263
+ :param tool_call_id:
264
+ Optional unique identifier for the tool call.
265
+ :param confirmation_strategy_context:
266
+ Optional dictionary for passing request-scoped resources.
267
+
268
+ :raises HITLBreakpointException:
269
+ Always raises an `HITLBreakpointException` exception to signal that user confirmation is required.
270
+
271
+ :returns:
272
+ This method does not return; it always raises an exception.
273
+ """
274
+ return self.run(
275
+ tool_name=tool_name,
276
+ tool_description=tool_description,
277
+ tool_params=tool_params,
278
+ tool_call_id=tool_call_id,
279
+ confirmation_strategy_context=confirmation_strategy_context,
280
+ )
281
+
192
282
  def to_dict(self) -> dict[str, Any]:
193
283
  """
194
284
  Serializes the BreakpointConfirmationStrategy to a dictionary.
@@ -285,6 +375,46 @@ def _process_confirmation_strategies(
285
375
  return modified_tool_call_messages, new_chat_history
286
376
 
287
377
 
378
+ async def _process_confirmation_strategies_async(
379
+ *,
380
+ confirmation_strategies: dict[str, ConfirmationStrategy],
381
+ messages_with_tool_calls: list[ChatMessage],
382
+ execution_context: "_ExecutionContext",
383
+ ) -> tuple[list[ChatMessage], list[ChatMessage]]:
384
+ """
385
+ Async version of _process_confirmation_strategies.
386
+
387
+ Run the confirmation strategies and return modified tool call messages and updated chat history.
388
+
389
+ :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
390
+ :param messages_with_tool_calls: Chat messages containing tool calls
391
+ :param execution_context: The current execution context of the agent
392
+ :returns:
393
+ Tuple of modified messages with confirmed tool calls and updated chat history
394
+ """
395
+ # Run confirmation strategies and get tool execution decisions (async version)
396
+ teds = await _run_confirmation_strategies_async(
397
+ confirmation_strategies=confirmation_strategies,
398
+ messages_with_tool_calls=messages_with_tool_calls,
399
+ execution_context=execution_context,
400
+ )
401
+
402
+ # Apply tool execution decisions to messages_with_tool_calls
403
+ rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions(
404
+ tool_call_messages=messages_with_tool_calls,
405
+ tool_execution_decisions=teds,
406
+ )
407
+
408
+ # Update the chat history with rejection messages and new tool call messages
409
+ new_chat_history = _update_chat_history(
410
+ chat_history=execution_context.state.get("messages"),
411
+ rejection_messages=rejection_messages,
412
+ tool_call_and_explanation_messages=modified_tool_call_messages,
413
+ )
414
+
415
+ return modified_tool_call_messages, new_chat_history
416
+
417
+
288
418
  def _run_confirmation_strategies(
289
419
  confirmation_strategies: dict[str, ConfirmationStrategy],
290
420
  messages_with_tool_calls: list[ChatMessage],
@@ -344,13 +474,100 @@ def _run_confirmation_strategies(
344
474
  # If not, run the confirmation strategy
345
475
  if not ted:
346
476
  ted = confirmation_strategies[tool_name].run(
347
- tool_name=tool_name, tool_description=tool_to_invoke.description, tool_params=final_args
477
+ tool_name=tool_name,
478
+ tool_description=tool_to_invoke.description,
479
+ tool_params=final_args,
480
+ tool_call_id=tool_call.id,
481
+ confirmation_strategy_context=execution_context.confirmation_strategy_context,
348
482
  )
349
483
  teds.append(ted)
350
484
 
351
485
  return teds
352
486
 
353
487
 
488
+ async def _run_confirmation_strategies_async(
489
+ confirmation_strategies: dict[str, ConfirmationStrategy],
490
+ messages_with_tool_calls: list[ChatMessage],
491
+ execution_context: "_ExecutionContext",
492
+ ) -> list[ToolExecutionDecision]:
493
+ """
494
+ Async version of _run_confirmation_strategies.
495
+
496
+ Run confirmation strategies for tool calls in the provided chat messages.
497
+
498
+ :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
499
+ :param messages_with_tool_calls: Messages containing tool calls to process
500
+ :param execution_context: The current execution context containing state and inputs
501
+ :returns:
502
+ A list of ToolExecutionDecision objects representing the decisions made for each tool call.
503
+ """
504
+ state = execution_context.state
505
+ tools_with_names = {tool.name: tool for tool in execution_context.tool_invoker_inputs["tools"]}
506
+ existing_teds = execution_context.tool_execution_decisions if execution_context.tool_execution_decisions else []
507
+ existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name}
508
+ existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id}
509
+
510
+ teds = []
511
+ for message in messages_with_tool_calls:
512
+ if not message.tool_calls:
513
+ continue
514
+
515
+ for tool_call in message.tool_calls:
516
+ tool_name = tool_call.tool_name
517
+ tool_to_invoke = tools_with_names[tool_name]
518
+
519
+ # Prepare final tool args
520
+ final_args = _prepare_tool_args(
521
+ tool=tool_to_invoke,
522
+ tool_call_arguments=tool_call.arguments,
523
+ state=state,
524
+ streaming_callback=execution_context.tool_invoker_inputs.get("streaming_callback"),
525
+ enable_streaming_passthrough=execution_context.tool_invoker_inputs.get(
526
+ "enable_streaming_passthrough", False
527
+ ),
528
+ )
529
+
530
+ # Get tool execution decisions from confirmation strategies
531
+ # If no confirmation strategy is defined for this tool, proceed with execution
532
+ if tool_name not in confirmation_strategies:
533
+ teds.append(
534
+ ToolExecutionDecision(
535
+ tool_call_id=tool_call.id,
536
+ tool_name=tool_name,
537
+ execute=True,
538
+ final_tool_params=final_args,
539
+ )
540
+ )
541
+ continue
542
+
543
+ # Check if there's already a decision for this tool call in the execution context
544
+ ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name)
545
+
546
+ # If not, run the confirmation strategy (async version)
547
+ if not ted:
548
+ strategy = confirmation_strategies[tool_name]
549
+ # Use run_async if available, otherwise fall back to sync run
550
+ if hasattr(strategy, "run_async"):
551
+ ted = await strategy.run_async(
552
+ tool_name=tool_name,
553
+ tool_description=tool_to_invoke.description,
554
+ tool_params=final_args,
555
+ tool_call_id=tool_call.id,
556
+ confirmation_strategy_context=execution_context.confirmation_strategy_context,
557
+ )
558
+ else:
559
+ ted = strategy.run(
560
+ tool_name=tool_name,
561
+ tool_description=tool_to_invoke.description,
562
+ tool_params=final_args,
563
+ tool_call_id=tool_call.id,
564
+ confirmation_strategy_context=execution_context.confirmation_strategy_context,
565
+ )
566
+ teds.append(ted)
567
+
568
+ return teds
569
+
570
+
354
571
  def _apply_tool_execution_decisions(
355
572
  tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
356
573
  ) -> tuple[list[ChatMessage], list[ChatMessage]]:
@@ -63,7 +63,12 @@ class ConfirmationPolicy(Protocol):
63
63
 
64
64
  class ConfirmationStrategy(Protocol):
65
65
  def run(
66
- self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
66
+ self,
67
+ tool_name: str,
68
+ tool_description: str,
69
+ tool_params: dict[str, Any],
70
+ tool_call_id: Optional[str] = None,
71
+ **kwargs: Optional[dict[str, Any]],
67
72
  ) -> ToolExecutionDecision:
68
73
  """
69
74
  Run the confirmation strategy for a given tool and its parameters.
@@ -73,6 +78,36 @@ class ConfirmationStrategy(Protocol):
73
78
  :param tool_params: The parameters to be passed to the tool.
74
79
  :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
75
80
  the decision with a specific tool invocation.
81
+ :param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context`
82
+ for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server
83
+ environments.
84
+
85
+ :returns:
86
+ The result of the confirmation strategy (e.g., tool output, rejection message, etc.).
87
+ """
88
+ ...
89
+
90
+ async def run_async(
91
+ self,
92
+ tool_name: str,
93
+ tool_description: str,
94
+ tool_params: dict[str, Any],
95
+ tool_call_id: Optional[str] = None,
96
+ **kwargs: Optional[dict[str, Any]],
97
+ ) -> ToolExecutionDecision:
98
+ """
99
+ Async version of run. Run the confirmation strategy for a given tool and its parameters.
100
+
101
+ Default implementation calls the sync run() method. Override for true async behavior.
102
+
103
+ :param tool_name: The name of the tool to be executed.
104
+ :param tool_description: The description of the tool.
105
+ :param tool_params: The parameters to be passed to the tool.
106
+ :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
107
+ the decision with a specific tool invocation.
108
+ :param kwargs: Additional keyword arguments. Implementations may accept `confirmation_strategy_context`
109
+ for passing request-scoped resources (e.g., WebSocket connections, async queues) in web/server
110
+ environments.
76
111
 
77
112
  :returns:
78
113
  The result of the confirmation strategy (e.g., tool output, rejection message, etc.).
@@ -2,7 +2,7 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from typing import Any, Dict, List, Protocol
5
+ from typing import Any, Protocol
6
6
 
7
7
  from haystack import Document
8
8
 
@@ -15,7 +15,7 @@ class DocumentEmbedder(Protocol):
15
15
  Protocol for Document Embedders.
16
16
  """
17
17
 
18
- def run(self, documents: List[Document]) -> Dict[str, Any]:
18
+ def run(self, documents: list[Document]) -> dict[str, Any]:
19
19
  """
20
20
  Generate embeddings for the input documents.
21
21
 
@@ -9,10 +9,12 @@ from lazy_imports import LazyImporter
9
9
 
10
10
  _import_structure = {
11
11
  "embedding_based_document_splitter": ["EmbeddingBasedDocumentSplitter"],
12
+ "md_header_level_inferrer": ["MarkdownHeaderLevelInferrer"],
12
13
  }
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from .embedding_based_document_splitter import EmbeddingBasedDocumentSplitter
17
+ from .md_header_level_inferrer import MarkdownHeaderLevelInferrer
16
18
 
17
19
  else:
18
20
  sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
@@ -3,7 +3,7 @@
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  from copy import deepcopy
6
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Optional
7
7
 
8
8
  import numpy as np
9
9
  from haystack import Document, component, logging
@@ -136,8 +136,8 @@ class EmbeddingBasedDocumentSplitter:
136
136
  self.document_embedder.warm_up()
137
137
  self._is_warmed_up = True
138
138
 
139
- @component.output_types(documents=List[Document])
140
- def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
139
+ @component.output_types(documents=list[Document])
140
+ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
141
141
  """
142
142
  Split documents based on embedding similarity.
143
143
 
@@ -162,7 +162,7 @@ class EmbeddingBasedDocumentSplitter:
162
162
  if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
163
163
  raise TypeError("EmbeddingBasedDocumentSplitter expects a List of Documents as input.")
164
164
 
165
- split_docs: List[Document] = []
165
+ split_docs: list[Document] = []
166
166
  for doc in documents:
167
167
  if doc.content is None:
168
168
  raise ValueError(
@@ -178,7 +178,7 @@ class EmbeddingBasedDocumentSplitter:
178
178
 
179
179
  return {"documents": split_docs}
180
180
 
181
- def _split_document(self, doc: Document) -> List[Document]:
181
+ def _split_document(self, doc: Document) -> list[Document]:
182
182
  """
183
183
  Split a single document based on embedding similarity.
184
184
  """
@@ -194,7 +194,7 @@ class EmbeddingBasedDocumentSplitter:
194
194
  # Create Document objects from the final splits
195
195
  return EmbeddingBasedDocumentSplitter._create_documents_from_splits(splits=final_splits, original_doc=doc)
196
196
 
197
- def _split_text(self, text: str) -> List[str]:
197
+ def _split_text(self, text: str) -> list[str]:
198
198
  """
199
199
  Split a text into smaller chunks based on embedding similarity.
200
200
  """
@@ -221,7 +221,7 @@ class EmbeddingBasedDocumentSplitter:
221
221
 
222
222
  return sub_splits
223
223
 
224
- def _group_sentences(self, sentences: List[str]) -> List[str]:
224
+ def _group_sentences(self, sentences: list[str]) -> list[str]:
225
225
  """
226
226
  Group sentences into groups of sentences_per_group.
227
227
  """
@@ -235,7 +235,7 @@ class EmbeddingBasedDocumentSplitter:
235
235
 
236
236
  return groups
237
237
 
238
- def _calculate_embeddings(self, sentence_groups: List[str]) -> List[List[float]]:
238
+ def _calculate_embeddings(self, sentence_groups: list[str]) -> list[list[float]]:
239
239
  """
240
240
  Calculate embeddings for each sentence group using the DocumentEmbedder.
241
241
  """
@@ -246,7 +246,7 @@ class EmbeddingBasedDocumentSplitter:
246
246
  embeddings = [doc.embedding for doc in embedded_docs]
247
247
  return embeddings
248
248
 
249
- def _find_split_points(self, embeddings: List[List[float]]) -> List[int]:
249
+ def _find_split_points(self, embeddings: list[list[float]]) -> list[int]:
250
250
  """
251
251
  Find split points based on cosine distances between sequential embeddings.
252
252
  """
@@ -273,7 +273,7 @@ class EmbeddingBasedDocumentSplitter:
273
273
  return split_points
274
274
 
275
275
  @staticmethod
276
- def _cosine_distance(embedding1: List[float], embedding2: List[float]) -> float:
276
+ def _cosine_distance(embedding1: list[float], embedding2: list[float]) -> float:
277
277
  """
278
278
  Calculate cosine distance between two embeddings.
279
279
  """
@@ -291,7 +291,7 @@ class EmbeddingBasedDocumentSplitter:
291
291
  return 1.0 - cosine_sim
292
292
 
293
293
  @staticmethod
294
- def _create_splits_from_points(sentence_groups: List[str], split_points: List[int]) -> List[str]:
294
+ def _create_splits_from_points(sentence_groups: list[str], split_points: list[int]) -> list[str]:
295
295
  """
296
296
  Create splits based on split points.
297
297
  """
@@ -315,7 +315,7 @@ class EmbeddingBasedDocumentSplitter:
315
315
 
316
316
  return splits
317
317
 
318
- def _merge_small_splits(self, splits: List[str]) -> List[str]:
318
+ def _merge_small_splits(self, splits: list[str]) -> list[str]:
319
319
  """
320
320
  Merge splits that are below min_length.
321
321
  """
@@ -341,7 +341,7 @@ class EmbeddingBasedDocumentSplitter:
341
341
 
342
342
  return merged
343
343
 
344
- def _split_large_splits(self, splits: List[str]) -> List[str]:
344
+ def _split_large_splits(self, splits: list[str]) -> list[str]:
345
345
  """
346
346
  Recursively split splits that are above max_length.
347
347
 
@@ -375,7 +375,7 @@ class EmbeddingBasedDocumentSplitter:
375
375
  return final_splits
376
376
 
377
377
  @staticmethod
378
- def _create_documents_from_splits(splits: List[str], original_doc: Document) -> List[Document]:
378
+ def _create_documents_from_splits(splits: list[str], original_doc: Document) -> list[Document]:
379
379
  """
380
380
  Create Document objects from splits.
381
381
  """
@@ -405,7 +405,7 @@ class EmbeddingBasedDocumentSplitter:
405
405
 
406
406
  return documents
407
407
 
408
- def to_dict(self) -> Dict[str, Any]:
408
+ def to_dict(self) -> dict[str, Any]:
409
409
  """
410
410
  Serializes the component to a dictionary.
411
411
  """
@@ -422,7 +422,7 @@ class EmbeddingBasedDocumentSplitter:
422
422
  )
423
423
 
424
424
  @classmethod
425
- def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingBasedDocumentSplitter":
425
+ def from_dict(cls, data: dict[str, Any]) -> "EmbeddingBasedDocumentSplitter":
426
426
  """
427
427
  Deserializes the component from a dictionary.
428
428
  """
@@ -24,7 +24,7 @@ class MarkdownHeaderLevelInferrer:
24
24
  from haystack_experimental.components.preprocessors import MarkdownHeaderLevelInferrer
25
25
 
26
26
  # Create a document with uniform header levels
27
- text = "## Title\nSome content\n## Section\nMore content\n## Subsection\nFinal content"
27
+ text = "## Title\n## Subheader\nSection\n## Subheader\nMore Content"
28
28
  doc = Document(content=text)
29
29
 
30
30
  # Initialize the inferrer and process the document
@@ -33,7 +33,7 @@ class MarkdownHeaderLevelInferrer:
33
33
 
34
34
  # The headers are now normalized with proper hierarchy
35
35
  print(result["documents"][0].content)
36
- > # Title\nSome content\n## Section\nMore content\n### Subsection\nFinal content
36
+ > # Title\n## Subheader\nSection\n## Subheader\nMore Content
37
37
  ```
38
38
  """
39
39
 
@@ -3,7 +3,5 @@
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  from haystack_experimental.components.retrievers.chat_message_retriever import ChatMessageRetriever
6
- from haystack_experimental.components.retrievers.multi_query_embedding_retriever import MultiQueryEmbeddingRetriever
7
- from haystack_experimental.components.retrievers.multi_query_text_retriever import MultiQueryTextRetriever
8
6
 
9
- _all_ = ["ChatMessageRetriever", "MultiQueryTextRetriever", "MultiQueryEmbeddingRetriever"]
7
+ _all_ = ["ChatMessageRetriever"]
@@ -2,11 +2,11 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from typing import Any, Dict, List, Optional
5
+ from typing import Any, Optional
6
6
 
7
7
  from haystack import DeserializationError, component, default_from_dict, default_to_dict, logging
8
8
  from haystack.core.serialization import import_class_by_name
9
- from haystack.dataclasses import ChatMessage
9
+ from haystack.dataclasses import ChatMessage, ChatRole
10
10
 
11
11
  from haystack_experimental.chat_message_stores.types import ChatMessageStore
12
12
 
@@ -30,41 +30,40 @@ class ChatMessageRetriever:
30
30
  ]
31
31
 
32
32
  message_store = InMemoryChatMessageStore()
33
- message_store.write_messages(messages)
33
+ message_store.write_messages(chat_history_id="user_456_session_123", messages=messages)
34
34
  retriever = ChatMessageRetriever(message_store)
35
35
 
36
- result = retriever.run()
36
+ result = retriever.run(chat_history_id="user_456_session_123")
37
37
 
38
38
  print(result["messages"])
39
39
  ```
40
40
  """
41
41
 
42
- def __init__(self, message_store: ChatMessageStore, last_k: int = 10):
42
+ def __init__(self, chat_message_store: ChatMessageStore, last_k: Optional[int] = 10):
43
43
  """
44
44
  Create the ChatMessageRetriever component.
45
45
 
46
- :param message_store:
46
+ :param chat_message_store:
47
47
  An instance of a ChatMessageStore.
48
48
  :param last_k:
49
49
  The number of last messages to retrieve. Defaults to 10 messages if not specified.
50
50
  """
51
- self.message_store = message_store
52
- if last_k <= 0:
53
- raise ValueError(f"last_k must be greater than 0. Currently, the last_k is {last_k}")
51
+ self.chat_message_store = chat_message_store
52
+ if last_k and last_k <= 0:
53
+ raise ValueError(f"last_k must be greater than 0. Currently, last_k is {last_k}")
54
54
  self.last_k = last_k
55
55
 
56
- def to_dict(self) -> Dict[str, Any]:
56
+ def to_dict(self) -> dict[str, Any]:
57
57
  """
58
58
  Serializes the component to a dictionary.
59
59
 
60
60
  :returns:
61
61
  Dictionary with serialized data.
62
62
  """
63
- message_store = self.message_store.to_dict()
64
- return default_to_dict(self, message_store=message_store, last_k=self.last_k)
63
+ return default_to_dict(self, chat_message_store=self.chat_message_store.to_dict(), last_k=self.last_k)
65
64
 
66
65
  @classmethod
67
- def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageRetriever":
66
+ def from_dict(cls, data: dict[str, Any]) -> "ChatMessageRetriever":
68
67
  """
69
68
  Deserializes the component from a dictionary.
70
69
 
@@ -74,35 +73,67 @@ class ChatMessageRetriever:
74
73
  The deserialized component.
75
74
  """
76
75
  init_params = data.get("init_parameters", {})
77
- if "message_store" not in init_params:
78
- raise DeserializationError("Missing 'message_store' in serialization data")
79
- if "type" not in init_params["message_store"]:
76
+ if "chat_message_store" not in init_params:
77
+ raise DeserializationError("Missing 'chat_message_store' in serialization data")
78
+ if "type" not in init_params["chat_message_store"]:
80
79
  raise DeserializationError("Missing 'type' in message store's serialization data")
81
80
 
82
- message_store_data = init_params["message_store"]
81
+ message_store_data = init_params["chat_message_store"]
83
82
  try:
84
83
  message_store_class = import_class_by_name(message_store_data["type"])
85
84
  except ImportError as e:
86
85
  raise DeserializationError(f"Class '{message_store_data['type']}' not correctly imported") from e
86
+ if not hasattr(message_store_class, "from_dict"):
87
+ raise DeserializationError(f"{message_store_class} does not have from_dict method implemented.")
88
+ init_params["chat_message_store"] = message_store_class.from_dict(message_store_data)
87
89
 
88
- data["init_parameters"]["message_store"] = default_from_dict(message_store_class, message_store_data)
89
90
  return default_from_dict(cls, data)
90
91
 
91
- @component.output_types(messages=List[ChatMessage])
92
- def run(self, last_k: Optional[int] = None) -> Dict[str, List[ChatMessage]]:
92
+ @component.output_types(messages=list[ChatMessage])
93
+ def run(
94
+ self,
95
+ chat_history_id: str,
96
+ *,
97
+ last_k: Optional[int] = None,
98
+ current_messages: Optional[list[ChatMessage]] = None,
99
+ ) -> dict[str, list[ChatMessage]]:
93
100
  """
94
101
  Run the ChatMessageRetriever
95
102
 
103
+ :param chat_history_id:
104
+ A unique identifier for the chat session or conversation whose messages should be retrieved.
105
+ Each `chat_history_id` corresponds to a distinct chat history stored in the underlying ChatMessageStore.
106
+ For example, use a session ID or conversation ID to isolate messages from different chat sessions.
96
107
  :param last_k: The number of last messages to retrieve. This parameter takes precedence over the last_k
97
108
  parameter passed to the ChatMessageRetriever constructor. If unspecified, the last_k parameter passed
98
109
  to the constructor will be used.
110
+ :param current_messages:
111
+ A list of incoming chat messages to combine with the retrieved messages. System messages from this list
112
+ are prepended before the retrieved history, while all other messages (e.g., user messages) are appended
113
+ after. This is useful for including new conversational context alongside stored history so the output
114
+ can be directly used as input to a ChatGenerator or an Agent. If not provided, only the stored messages
115
+ will be returned.
116
+
99
117
  :returns:
100
- - `messages` - The retrieved chat messages.
101
- :raises ValueError: If last_k is not None and is less than 1
118
+ A dictionary with the following key:
119
+ - `messages` - The retrieved chat messages combined with any provided current messages.
120
+ :raises ValueError: If last_k is not None and is less than 0.
102
121
  """
103
- if last_k is not None and last_k <= 0:
104
- raise ValueError("last_k must be greater than 0")
122
+ if last_k is not None and last_k < 0:
123
+ raise ValueError("last_k must be 0 or greater")
124
+
125
+ resolved_last_k = last_k or self.last_k
126
+ if resolved_last_k == 0:
127
+ return {"messages": current_messages or []}
128
+
129
+ retrieved_messages = self.chat_message_store.retrieve_messages(
130
+ chat_history_id=chat_history_id, last_k=last_k or self.last_k
131
+ )
105
132
 
106
- last_k = last_k or self.last_k
133
+ if not current_messages:
134
+ return {"messages": retrieved_messages}
107
135
 
108
- return {"messages": self.message_store.retrieve()[-last_k:]}
136
+ # We maintain the order: system messages first, then stored messages, then new user messages
137
+ system_messages = [msg for msg in current_messages if msg.is_from(ChatRole.SYSTEM)]
138
+ other_messages = [msg for msg in current_messages if not msg.is_from(ChatRole.SYSTEM)]
139
+ return {"messages": system_messages + retrieved_messages + other_messages}