haystack-experimental 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. haystack_experimental/chat_message_stores/__init__.py +1 -1
  2. haystack_experimental/chat_message_stores/in_memory.py +176 -31
  3. haystack_experimental/chat_message_stores/types.py +33 -21
  4. haystack_experimental/components/agents/agent.py +147 -44
  5. haystack_experimental/components/agents/human_in_the_loop/strategies.py +220 -3
  6. haystack_experimental/components/agents/human_in_the_loop/types.py +36 -1
  7. haystack_experimental/components/embedders/types/protocol.py +2 -2
  8. haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +16 -16
  9. haystack_experimental/components/retrievers/__init__.py +1 -3
  10. haystack_experimental/components/retrievers/chat_message_retriever.py +57 -26
  11. haystack_experimental/components/writers/__init__.py +1 -1
  12. haystack_experimental/components/writers/chat_message_writer.py +25 -22
  13. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/METADATA +24 -31
  14. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/RECORD +17 -24
  15. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/WHEEL +1 -1
  16. haystack_experimental/components/query/__init__.py +0 -18
  17. haystack_experimental/components/query/query_expander.py +0 -294
  18. haystack_experimental/components/retrievers/multi_query_embedding_retriever.py +0 -173
  19. haystack_experimental/components/retrievers/multi_query_text_retriever.py +0 -150
  20. haystack_experimental/super_components/__init__.py +0 -3
  21. haystack_experimental/super_components/indexers/__init__.py +0 -11
  22. haystack_experimental/super_components/indexers/sentence_transformers_document_indexer.py +0 -199
  23. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/licenses/LICENSE +0 -0
  24. {haystack_experimental-0.14.3.dist-info → haystack_experimental-0.15.1.dist-info}/licenses/LICENSE-MIT.txt +0 -0
@@ -4,4 +4,4 @@
4
4
 
5
5
  from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
6
6
 
7
- _all_ = ["InMemoryChatMessageStore"]
7
+ __all__ = ["InMemoryChatMessageStore"]
@@ -2,42 +2,69 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from typing import Any, Dict, Iterable, List
5
+ from dataclasses import replace
6
+ from typing import Any, Iterable, Optional
6
7
 
7
- from haystack import default_from_dict, default_to_dict, logging
8
- from haystack.dataclasses import ChatMessage
8
+ from haystack import default_from_dict, default_to_dict
9
+ from haystack.dataclasses import ChatMessage, ChatRole
9
10
 
10
- from haystack_experimental.chat_message_stores.types import ChatMessageStore
11
+ # Global storage for all InMemoryDocumentStore instances, indexed by the chat history id.
12
+ _STORAGES: dict[str, list[ChatMessage]] = {}
11
13
 
12
- logger = logging.getLogger(__name__)
13
14
 
14
-
15
- class InMemoryChatMessageStore(ChatMessageStore):
15
+ class InMemoryChatMessageStore:
16
16
  """
17
17
  Stores chat messages in-memory.
18
+
19
+ The `chat_history_id` parameter is used as a unique identifier for each conversation or chat session.
20
+ It acts as a namespace that isolates messages from different sessions. Each `chat_history_id` value corresponds to a
21
+ separate list of `ChatMessage` objects stored in memory.
22
+
23
+ Typical usage involves providing a unique `chat_history_id` (for example, a session ID or conversation ID)
24
+ whenever you write, read, or delete messages. This ensures that chat messages from different
25
+ conversations do not overlap.
26
+
27
+ Usage example:
28
+ ```python
29
+ from haystack.dataclasses import ChatMessage
30
+ from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
31
+
32
+ message_store = InMemoryChatMessageStore()
33
+
34
+ messages = [
35
+ ChatMessage.from_assistant("Hello, how can I help you?"),
36
+ ChatMessage.from_user("Hi, I have a question about Python. What is a Protocol?"),
37
+ ]
38
+ message_store.write_messages(chat_history_id="user_456_session_123", messages=messages)
39
+ retrieved_messages = message_store.retrieve_messages(chat_history_id="user_456_session_123")
40
+
41
+ print(retrieved_messages)
42
+ ```
18
43
  """
19
44
 
20
- def __init__(
21
- self,
22
- ):
45
+ def __init__(self, skip_system_messages: bool = True, last_k: Optional[int] = 10) -> None:
23
46
  """
24
- Initializes the InMemoryChatMessageStore.
47
+ Create an InMemoryChatMessageStore.
48
+
49
+ :param skip_system_messages:
50
+ Whether to skip storing system messages. Defaults to True.
51
+ :param last_k:
52
+ The number of last messages to retrieve. Defaults to 10 messages if not specified.
25
53
  """
26
- self.messages = []
54
+ self.skip_system_messages = skip_system_messages
55
+ self.last_k = last_k
27
56
 
28
- def to_dict(self) -> Dict[str, Any]:
57
+ def to_dict(self) -> dict[str, Any]:
29
58
  """
30
59
  Serializes the component to a dictionary.
31
60
 
32
61
  :returns:
33
62
  Dictionary with serialized data.
34
63
  """
35
- return default_to_dict(
36
- self,
37
- )
64
+ return default_to_dict(self, skip_system_messages=self.skip_system_messages, last_k=self.last_k)
38
65
 
39
66
  @classmethod
40
- def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore":
67
+ def from_dict(cls, data: dict[str, Any]) -> "InMemoryChatMessageStore":
41
68
  """
42
69
  Deserializes the component from a dictionary.
43
70
 
@@ -48,19 +75,26 @@ class InMemoryChatMessageStore(ChatMessageStore):
48
75
  """
49
76
  return default_from_dict(cls, data)
50
77
 
51
- def count_messages(self) -> int:
78
+ def count_messages(self, chat_history_id: str) -> int:
52
79
  """
53
- Returns the number of chat messages stored.
80
+ Returns the number of chat messages stored in this store.
81
+
82
+ :param chat_history_id:
83
+ The chat history id for which to count messages.
54
84
 
55
85
  :returns: The number of messages.
56
86
  """
57
- return len(self.messages)
87
+ return len(_STORAGES.get(chat_history_id, []))
58
88
 
59
- def write_messages(self, messages: List[ChatMessage]) -> int:
89
+ def write_messages(self, chat_history_id: str, messages: list[ChatMessage]) -> int:
60
90
  """
61
91
  Writes chat messages to the ChatMessageStore.
62
92
 
63
- :param messages: A list of ChatMessages to write.
93
+ :param chat_history_id:
94
+ The chat history id under which to store the messages.
95
+ :param messages:
96
+ A list of ChatMessages to write.
97
+
64
98
  :returns: The number of messages written.
65
99
 
66
100
  :raises ValueError: If messages is not a list of ChatMessages.
@@ -68,19 +102,130 @@ class InMemoryChatMessageStore(ChatMessageStore):
68
102
  if not isinstance(messages, Iterable) or any(not isinstance(message, ChatMessage) for message in messages):
69
103
  raise ValueError("Please provide a list of ChatMessages.")
70
104
 
71
- self.messages.extend(messages)
72
- return len(messages)
105
+ # We assign an ID to messages that don't have one yet. The ID simply corresponds to the chat_history_id in the
106
+ # array.
107
+ counter = self.count_messages(chat_history_id)
108
+ messages_with_id = []
109
+ for msg in messages:
110
+ # Skip system messages if configured to do so
111
+ if self.skip_system_messages and msg.is_from(ChatRole.SYSTEM):
112
+ continue
113
+
114
+ chat_message_id = msg.meta.get("chat_message_id")
115
+ if chat_message_id is None:
116
+ # We use replace to avoid mutating the original message
117
+ msg = replace(msg, _meta={"chat_message_id": str(counter), **msg.meta})
118
+ counter += 1
119
+
120
+ messages_with_id.append(msg)
121
+
122
+ # For now, we always skip messages that are already stored based on their ID.
123
+ existing_messages = _STORAGES.get(chat_history_id, [])
124
+ existing_ids = {
125
+ msg.meta.get("chat_message_id") for msg in existing_messages if msg.meta.get("chat_message_id") is not None
126
+ }
127
+ messages_to_write = [
128
+ message for message in messages_with_id if message.meta["chat_message_id"] not in existing_ids
129
+ ]
130
+
131
+ for message in messages_to_write:
132
+ if chat_history_id not in _STORAGES:
133
+ _STORAGES[chat_history_id] = []
134
+ _STORAGES[chat_history_id].append(message)
135
+
136
+ return len(messages_to_write)
137
+
138
+ def retrieve_messages(self, chat_history_id: str, last_k: Optional[int] = None) -> list[ChatMessage]:
139
+ """
140
+ Retrieves all stored chat messages.
73
141
 
74
- def delete_messages(self) -> None:
142
+ :param chat_history_id:
143
+ The chat history id from which to retrieve messages.
144
+ :param last_k:
145
+ The number of last messages to retrieve. If unspecified, the last_k parameter passed
146
+ to the constructor will be used.
147
+
148
+ :returns: A list of chat messages.
149
+ :raises ValueError:
150
+ If last_k is not None and is less than 0.
75
151
  """
76
- Deletes all stored chat messages.
152
+
153
+ if last_k is not None and last_k < 0:
154
+ raise ValueError("last_k must be 0 or greater")
155
+
156
+ resolved_last_k = last_k if last_k is not None else self.last_k
157
+ if resolved_last_k == 0:
158
+ return []
159
+
160
+ messages = _STORAGES.get(chat_history_id, [])
161
+ if resolved_last_k is not None:
162
+ messages = self._get_last_k_messages(messages=messages, last_k=resolved_last_k)
163
+
164
+ return messages
165
+
166
+ @staticmethod
167
+ def _get_last_k_messages(messages: list[ChatMessage], last_k: int) -> list[ChatMessage]:
77
168
  """
78
- self.messages = []
169
+ Get the last_k rounds of messages from the incoming list of messages.
170
+
171
+ This is done in such a way such the returned list of messages is always valid. By valid we mean it will
172
+ be submittable to an LLM without causing issues. For example, we want to avoid slicing the chat history in a
173
+ way that a ToolCall is present without its corresponding ToolOutput.
79
174
 
80
- def retrieve(self) -> List[ChatMessage]:
175
+ This is handled by treating ToolCalls and its corresponding ToolOutput(s) as a single unit when slicing the chat
176
+ history.
177
+
178
+ :param messages:
179
+ List of chat messages.
180
+ :param last_k:
181
+ The number of last rounds of messages to retrieve. By rounds of messages we mean pairs of
182
+ User -> Assistant messages. ToolCalls and ToolOutputs are considered part of the Assistant message.
183
+ :returns:
184
+ The sliced list of chat messages.
81
185
  """
82
- Retrieves all stored chat messages.
186
+ rounds = []
187
+ current = []
188
+
189
+ for msg in messages:
190
+ # Treat system messages as separate rounds
191
+ if msg.role == "system":
192
+ rounds.append([msg])
193
+ continue
194
+
195
+ # User messages always start a new round
196
+ if msg.role == "user":
197
+ current.append(msg)
198
+ continue
199
+
200
+ # Assistant messages can either end a round or continue it (in case of tool calls)
201
+ if msg.role == "assistant":
202
+ current.append(msg)
203
+ if msg.text and not msg.tool_calls:
204
+ rounds.append(current)
205
+ current = []
206
+ continue
207
+
208
+ # Append all other messages (e.g., tool outputs) to the current round
209
+ current.append(msg)
210
+
211
+ # Catch any remaining messages in the current round
212
+ if current:
213
+ rounds.append(current)
214
+
215
+ selected = rounds[-last_k:]
216
+ return [m for r in selected for m in r]
217
+
218
+ def delete_messages(self, chat_history_id: str) -> None:
219
+ """
220
+ Deletes all stored chat messages.
83
221
 
84
- :returns: A list of chat messages.
222
+ :param chat_history_id:
223
+ The chat history id from which to delete messages.
224
+ """
225
+ _STORAGES.pop(chat_history_id, None)
226
+
227
+ def delete_all_messages(self) -> None:
228
+ """
229
+ Deletes all stored chat messages from all chat history ids.
85
230
  """
86
- return self.messages
231
+ _STORAGES.clear()
@@ -2,16 +2,15 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from abc import ABC, abstractmethod
6
- from typing import Any, Dict, List
5
+ from typing import Any, Optional, Protocol
7
6
 
8
- from haystack import logging
9
7
  from haystack.dataclasses import ChatMessage
10
8
 
11
- logger = logging.getLogger(__name__)
9
+ # Ellipsis are needed for the type checker, it's safe to disable module-wide
10
+ # pylint: disable=unnecessary-ellipsis
12
11
 
13
12
 
14
- class ChatMessageStore(ABC):
13
+ class ChatMessageStore(Protocol):
15
14
  """
16
15
  Stores ChatMessages to be used by the components of a Pipeline.
17
16
 
@@ -22,53 +21,66 @@ class ChatMessageStore(ABC):
22
21
  In order to write or retrieve chat messages, consider using a ChatMessageWriter or ChatMessageRetriever.
23
22
  """
24
23
 
25
- @abstractmethod
26
- def to_dict(self) -> Dict[str, Any]:
24
+ def to_dict(self) -> dict[str, Any]:
27
25
  """
28
26
  Serializes this store to a dictionary.
29
27
 
30
28
  :returns: The serialized store as a dictionary.
31
29
  """
30
+ ...
32
31
 
33
32
  @classmethod
34
- @abstractmethod
35
- def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageStore":
33
+ def from_dict(cls, data: dict[str, Any]) -> "ChatMessageStore":
36
34
  """
37
35
  Deserializes the store from a dictionary.
38
36
 
39
37
  :param data: The dictionary to deserialize from.
40
38
  :returns: The deserialized store.
41
39
  """
40
+ ...
42
41
 
43
- @abstractmethod
44
- def count_messages(self) -> int:
42
+ def count_messages(self, chat_history_id: str) -> int:
45
43
  """
46
44
  Returns the number of chat messages stored.
47
45
 
46
+ :param chat_history_id: The chat history id for which to count messages.
47
+
48
48
  :returns: The number of messages.
49
49
  """
50
+ ...
50
51
 
51
- @abstractmethod
52
- def write_messages(self, messages: List[ChatMessage]) -> int:
52
+ def write_messages(self, chat_history_id: str, messages: list[ChatMessage]) -> int:
53
53
  """
54
54
  Writes chat messages to the ChatMessageStore.
55
55
 
56
+ :param chat_history_id: The chat history id under which to store the messages.
56
57
  :param messages: A list of ChatMessages to write.
57
- :returns: The number of messages written.
58
58
 
59
- :raises ValueError: If messages is not a list of ChatMessages.
59
+ :returns: The number of messages written.
60
60
  """
61
+ ...
61
62
 
62
- @abstractmethod
63
- def delete_messages(self) -> None:
63
+ def delete_messages(self, chat_history_id: str) -> None:
64
64
  """
65
65
  Deletes all stored chat messages.
66
+
67
+ :param chat_history_id: The chat history id from which to delete all messages.
66
68
  """
69
+ ...
67
70
 
68
- @abstractmethod
69
- def retrieve(self) -> List[ChatMessage]:
71
+ def delete_all_messages(self) -> None:
70
72
  """
71
- Retrieves all stored chat messages.
73
+ Deletes all stored chat messages from all indices.
74
+ """
75
+ ...
76
+
77
+ def retrieve_messages(self, chat_history_id: str, last_k: Optional[int] = None) -> list[ChatMessage]:
78
+ """
79
+ Retrieves chat messages from the ChatMessageStore.
80
+
81
+ :param chat_history_id: The chat history id from which to retrieve messages.
82
+ :param last_k: The number of last messages to retrieve. If None, retrieves all messages.
72
83
 
73
- :returns: A list of chat messages.
84
+ :returns: A list of retrieved ChatMessages.
74
85
  """
86
+ ...