rasa-pro 3.13.0.dev1__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (58):
  1. rasa/core/actions/action.py +0 -6
  2. rasa/core/channels/voice_ready/audiocodes.py +52 -17
  3. rasa/core/channels/voice_stream/audiocodes.py +53 -9
  4. rasa/core/channels/voice_stream/genesys.py +146 -16
  5. rasa/core/information_retrieval/faiss.py +6 -1
  6. rasa/core/information_retrieval/information_retrieval.py +40 -2
  7. rasa/core/information_retrieval/milvus.py +7 -2
  8. rasa/core/information_retrieval/qdrant.py +7 -2
  9. rasa/core/policies/enterprise_search_policy.py +61 -301
  10. rasa/core/policies/flows/flow_executor.py +3 -38
  11. rasa/core/processor.py +27 -6
  12. rasa/core/utils.py +53 -0
  13. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  14. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  15. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  16. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  17. rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -13
  18. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  19. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +20 -1
  20. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -0
  21. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  22. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  23. rasa/dialogue_understanding/stack/utils.py +0 -38
  24. rasa/dialogue_understanding_test/io.py +13 -8
  25. rasa/document_retrieval/__init__.py +0 -0
  26. rasa/document_retrieval/constants.py +32 -0
  27. rasa/document_retrieval/document_post_processor.py +351 -0
  28. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  29. rasa/document_retrieval/document_retriever.py +333 -0
  30. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  31. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
  32. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
  33. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
  34. rasa/document_retrieval/query_rewriter.py +234 -0
  35. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
  36. rasa/engine/recipes/default_components.py +2 -0
  37. rasa/shared/core/constants.py +0 -8
  38. rasa/shared/core/domain.py +12 -3
  39. rasa/shared/core/flows/flow.py +0 -17
  40. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  41. rasa/shared/core/flows/steps/collect.py +5 -18
  42. rasa/shared/core/flows/utils.py +1 -16
  43. rasa/shared/core/slot_mappings.py +11 -5
  44. rasa/shared/nlu/constants.py +0 -1
  45. rasa/shared/utils/common.py +11 -1
  46. rasa/shared/utils/llm.py +1 -1
  47. rasa/tracing/instrumentation/attribute_extractors.py +10 -7
  48. rasa/tracing/instrumentation/instrumentation.py +12 -12
  49. rasa/validator.py +1 -123
  50. rasa/version.py +1 -1
  51. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +1 -1
  52. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +55 -47
  53. rasa/core/actions/action_handle_digressions.py +0 -164
  54. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  55. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  56. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
  57. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
  58. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,333 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import time
5
+ import uuid
6
+ from enum import Enum
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
8
+
9
+ import structlog
10
+
11
+ from rasa.core.information_retrieval.faiss import FAISS_Store
12
+ from rasa.dialogue_understanding.utils import add_prompt_to_message_parse_data
13
+ from rasa.document_retrieval.constants import (
14
+ CONNECTOR_CONFIG_KEY,
15
+ DEFAULT_K,
16
+ DEFAULT_THRESHOLD,
17
+ K_CONFIG_KEY,
18
+ POST_PROCESSED_DOCUMENTS_KEY,
19
+ POST_PROCESSING_CONFIG_KEY,
20
+ QUERY_REWRITING_CONFIG_KEY,
21
+ RETRIEVED_DOCUMENTS_KEY,
22
+ SEARCH_QUERY_KEY,
23
+ SOURCE_PROPERTY,
24
+ THRESHOLD_CONFIG_KEY,
25
+ USE_LLM_PROPERTY,
26
+ VECTOR_STORE_CONFIG_KEY,
27
+ VECTOR_STORE_TYPE_CONFIG_KEY,
28
+ )
29
+ from rasa.document_retrieval.document_post_processor import DocumentPostProcessor
30
+ from rasa.document_retrieval.knowledge_base_connectors.api_connector import APIConnector
31
+ from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
32
+ KnowledgeBaseConnector,
33
+ )
34
+ from rasa.document_retrieval.knowledge_base_connectors.vector_store_connector import (
35
+ DEFAULT_EMBEDDINGS_CONFIG,
36
+ VectorStoreConnector,
37
+ VectorStoreType,
38
+ )
39
+ from rasa.document_retrieval.query_rewriter import QueryRewriter
40
+ from rasa.engine.graph import ExecutionContext, GraphComponent
41
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
42
+ from rasa.engine.storage.resource import Resource
43
+ from rasa.engine.storage.storage import ModelStorage
44
+ from rasa.shared.constants import (
45
+ EMBEDDINGS_CONFIG_KEY,
46
+ )
47
+ from rasa.shared.core.trackers import DialogueStateTracker
48
+ from rasa.shared.nlu.training_data.message import Message
49
+ from rasa.shared.nlu.training_data.training_data import TrainingData
50
+ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
51
+ _LangchainEmbeddingClientAdapter,
52
+ )
53
+ from rasa.shared.providers.llm.llm_response import LLMResponse
54
+ from rasa.shared.utils.llm import (
55
+ embedder_factory,
56
+ resolve_model_client_config,
57
+ )
58
+
59
+ if TYPE_CHECKING:
60
+ from langchain.schema.embeddings import Embeddings
61
+
62
+ structlogger = structlog.get_logger()
63
+
64
+
65
class ConnectorType(Enum):
    """Kinds of knowledge base connectors selectable in the component config."""

    API = "API"
    VECTOR_STORE = "VECTOR_STORE"

    def __str__(self) -> str:
        """Render the member as its raw string value."""
        return str(self.value)
71
+
72
+
73
@DefaultV1Recipe.register(
    [
        DefaultV1Recipe.ComponentType.COEXISTENCE_ROUTER,
    ],
    is_trainable=True,
)
class DocumentRetriever(GraphComponent):
    """Graph component that retrieves documents for incoming messages.

    For each message the component prepares a search query via the query
    rewriter, retrieves documents through the configured knowledge base
    connector, post-processes them, and stores the query and both result sets
    on the message.
    """

    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """The component's default config (see parent class for full docstring)."""
        return {
            THRESHOLD_CONFIG_KEY: DEFAULT_THRESHOLD,
            K_CONFIG_KEY: DEFAULT_K,
            CONNECTOR_CONFIG_KEY: ConnectorType.VECTOR_STORE.value,
            EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
            VECTOR_STORE_CONFIG_KEY: {
                VECTOR_STORE_TYPE_CONFIG_KEY: VectorStoreType.FAISS.value,
            },
        }

    def __init__(
        self,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        query_rewriter: Optional[QueryRewriter] = None,
        document_post_processor: Optional[DocumentPostProcessor] = None,
        knowledge_base_connector: Optional[KnowledgeBaseConnector] = None,
    ) -> None:
        """Set up the retriever and its collaborators.

        Args:
            config: Component configuration; merged over the default config.
            model_storage: Storage used to persist and load model artifacts.
            resource: Resource identifying this component in the storage.
            query_rewriter: Pre-built query rewriter; created from the config
                when omitted.
            document_post_processor: Pre-built post processor; created from
                the config when omitted.
            knowledge_base_connector: Pre-built connector; created from the
                config when omitted.

        Raises:
            ValueError: If no connector is passed and the configured connector
                type is not a known `ConnectorType` value.
        """
        self.config = {**self.get_default_config(), **config}
        self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
            self.config.get(EMBEDDINGS_CONFIG_KEY), DocumentRetriever.__name__
        )

        self._model_storage = model_storage
        self._resource = resource

        # Disable query rewriting and post processing if they are not set
        query_rewriting_config = config.get(
            QUERY_REWRITING_CONFIG_KEY, {"type": "PLAIN"}
        )
        post_processing_config = config.get(
            POST_PROCESSING_CONFIG_KEY, {"type": "PLAIN"}
        )

        self.query_rewriter = query_rewriter or QueryRewriter(
            query_rewriting_config, model_storage, resource
        )
        self.document_post_processor = document_post_processor or DocumentPostProcessor(
            post_processing_config, model_storage, resource
        )
        self.knowledge_base_connector = (
            knowledge_base_connector or self.initialize_knowledge_base_connector()
        )

        self.use_llm = self.config.get(USE_LLM_PROPERTY, False)

    def persist(self) -> None:
        """Persist this component to disk for future loading."""
        self.query_rewriter.persist()
        self.document_post_processor.persist()

    @classmethod
    def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
        """Creates an embedder based on the given configuration.

        Args:
            config: Component configuration holding the embeddings config.

        Returns:
            The embedder.
        """
        # Copy the config so original config is not modified
        config = config.copy()
        # Resolve config and instantiate the embedding client
        config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
            config.get(EMBEDDINGS_CONFIG_KEY), DocumentRetriever.__name__
        )
        client = embedder_factory(
            config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
        )
        # Wrap the embedding client in the adapter
        return _LangchainEmbeddingClientAdapter(client)

    def train(self, training_data: TrainingData) -> Resource:
        """Train the document retriever on a data set.

        A FAISS index is only built when the configured vector store type is
        FAISS; for any other store type just the sub-components are persisted.

        Args:
            training_data: Training data (not read by this method).

        Returns:
            The resource this component was persisted under.
        """
        vector_store_config = self.config.get(VECTOR_STORE_CONFIG_KEY, {})
        store_type = vector_store_config.get(VECTOR_STORE_TYPE_CONFIG_KEY)
        if store_type == VectorStoreType.FAISS.value:
            structlogger.info("document_retriever.train.faiss")
            embeddings = self._create_plain_embedder(self.config)
            with self._model_storage.write_to(self._resource) as path:
                self.vector_store = FAISS_Store(
                    docs_folder=vector_store_config.get(SOURCE_PROPERTY),
                    embeddings=embeddings,
                    index_path=path,
                    create_index=True,
                    use_llm=self.use_llm,
                )
        self.persist()
        return self._resource

    @classmethod
    def load(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        **kwargs: Any,
    ) -> "DocumentRetriever":
        """Loads trained component (see parent class for full docstring).

        Raises:
            ValueError: If the configured connector type is not a known
                `ConnectorType` value.
        """
        # Disable query rewriting and post processing if they are not set
        query_rewriting_config = config.get(
            QUERY_REWRITING_CONFIG_KEY, {"type": "PLAIN"}
        )
        post_processing_config = config.get(
            POST_PROCESSING_CONFIG_KEY, {"type": "PLAIN"}
        )

        # Load query rewriter and document post processor
        query_rewriter = QueryRewriter.load(
            query_rewriting_config, model_storage, resource
        )
        document_post_processor = DocumentPostProcessor.load(
            post_processing_config, model_storage, resource
        )

        connector_type = config.get(CONNECTOR_CONFIG_KEY)
        knowledge_base_connector: KnowledgeBaseConnector

        if connector_type == ConnectorType.VECTOR_STORE.value:
            knowledge_base_connector = VectorStoreConnector.load(
                config, model_storage, resource
            )
        elif connector_type == ConnectorType.API.value:
            knowledge_base_connector = APIConnector.load(
                config, model_storage, resource
            )
        else:
            raise ValueError(f"Invalid knowledge base connector: {connector_type}")

        return cls(
            config,
            model_storage,
            resource,
            query_rewriter,
            document_post_processor,
            knowledge_base_connector,
        )

    @classmethod
    def create(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> DocumentRetriever:
        """Creates component (see parent class for full docstring)."""
        return cls(config, model_storage, resource)

    def initialize_knowledge_base_connector(self) -> KnowledgeBaseConnector:
        """Build the knowledge base connector selected in the config.

        Returns:
            A new connector matching the configured connector type.

        Raises:
            ValueError: If the configured connector type is not a known
                `ConnectorType` value.
        """
        connector_type = self.config.get(CONNECTOR_CONFIG_KEY)

        if connector_type == ConnectorType.VECTOR_STORE.value:
            return VectorStoreConnector(
                self.config,
                self._model_storage,
                self._resource,
            )
        if connector_type == ConnectorType.API.value:
            return APIConnector(self.config)
        # Report the offending config value (previously the builtin `type`
        # was formatted into the message by mistake).
        raise ValueError(f"Invalid knowledge base connector: {connector_type}")

    async def process(
        self,
        messages: List[Message],
        tracker: Optional[DialogueStateTracker] = None,
    ) -> List[Message]:
        """Process a list of messages.

        For every message the search query, the retrieved documents and the
        post-processed documents are set on the message. When the connector
        returns no result, both document entries are set to empty lists.

        Args:
            messages: Messages to retrieve documents for.
            tracker: Conversation tracker handed to the sub-components.

        Returns:
            The same list of messages, updated in place.
        """
        self.knowledge_base_connector.connect_or_raise()

        # Retrieval parameters do not depend on the message; resolve them once.
        k = self.config[K_CONFIG_KEY] or DEFAULT_K
        # Only fall back when the threshold is unset so that an explicitly
        # configured threshold of 0.0 is honoured.
        threshold = self.config[THRESHOLD_CONFIG_KEY]
        if threshold is None:
            threshold = DEFAULT_THRESHOLD

        for message in messages:
            start = time.time()

            # Prepare search query
            search_query = await self.query_rewriter.prepare_search_query(
                message, tracker
            )
            message.set(
                SEARCH_QUERY_KEY,
                search_query,
                add_to_output=True,
            )

            # Retrieve documents
            search_result = await self.knowledge_base_connector.retrieve_documents(
                search_query,
                k,
                threshold,
                tracker,
            )

            if search_result is None:
                message.set(
                    RETRIEVED_DOCUMENTS_KEY,
                    [],
                    add_to_output=True,
                )
                message.set(
                    POST_PROCESSED_DOCUMENTS_KEY,
                    [],
                    add_to_output=True,
                )
                continue

            message.set(
                RETRIEVED_DOCUMENTS_KEY,
                search_result.to_dict(),
                add_to_output=True,
            )

            # Post process documents
            final_search_result = await self.document_post_processor.process_documents(
                message, search_query, search_result, tracker
            )
            post_processed_documents = final_search_result.to_dict()
            message.set(
                POST_PROCESSED_DOCUMENTS_KEY,
                post_processed_documents,
                add_to_output=True,
            )

            # Serialised after post-processing, as in the log/prompt below.
            retrieved_documents = search_result.to_dict()
            structlogger.debug(
                "document_retriever.process",
                search_query=search_query,
                search_result=retrieved_documents,
                final_search_result=post_processed_documents,
            )

            end = time.time()
            # Record the retrieval outcome in the message's prompt data as a
            # synthetic LLM response, with the measured latency.
            add_prompt_to_message_parse_data(
                message,
                DocumentRetriever.__name__,
                "document_retriever_process",
                user_prompt="Dummy prompt for document retriever process.",
                llm_response=LLMResponse(
                    id=str(uuid.uuid4()),
                    choices=[
                        f"search_query: {search_query}\n"
                        f"retrieved_documents: {retrieved_documents}\n"
                        f"post_processed_documents: {post_processed_documents}",
                    ],
                    created=int(datetime.datetime.now().timestamp()),
                    latency=end - start,
                ),
            )

        return messages
@@ -0,0 +1,39 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from rasa.core.information_retrieval import SearchResultList
4
+ from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
5
+ KnowledgeBaseConnector,
6
+ )
7
+ from rasa.engine.storage.resource import Resource
8
+ from rasa.engine.storage.storage import ModelStorage
9
+ from rasa.shared.core.trackers import DialogueStateTracker
10
+
11
+
12
class APIConnector(KnowledgeBaseConnector):
    """Knowledge base connector for an API backend.

    All operations are still placeholders: nothing is loaded, no connection is
    made and retrieval always yields an empty result list.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Store the connector configuration.

        Args:
            config: Component configuration.
        """
        self.config = config

    @classmethod
    def load(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        **kwargs: Any,
    ) -> "APIConnector":
        """Create a connector from the given configuration.

        Args:
            config: Component configuration.
            model_storage: Model storage (currently unused).
            resource: Resource of the owning component (currently unused).
            **kwargs: Additional arguments (currently unused).

        Returns:
            A new connector instance of the class `load` was called on.
        """
        # TODO implement
        # Instantiate via `cls` so subclasses get an instance of their own type.
        return cls(config)

    async def retrieve_documents(
        self,
        search_query: str,
        k: int,
        threshold: float,
        tracker: Optional[DialogueStateTracker],
    ) -> Optional[SearchResultList]:
        """Return an empty result list (retrieval is not implemented yet).

        Args:
            search_query: Query to search for (currently unused).
            k: Number of documents requested (currently unused).
            threshold: Score threshold (currently unused).
            tracker: Conversation tracker (currently unused).

        Returns:
            A `SearchResultList` without results or metadata.
        """
        # TODO implement
        return SearchResultList(results=[], metadata={})

    def connect_or_raise(self) -> None:
        """Do nothing (connecting is not implemented yet)."""
        # TODO implement
        return None
@@ -0,0 +1,34 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, Optional
3
+
4
+ from rasa.core.information_retrieval import SearchResultList
5
+ from rasa.engine.storage.resource import Resource
6
+ from rasa.engine.storage.storage import ModelStorage
7
+ from rasa.shared.core.trackers import DialogueStateTracker
8
+
9
+
10
class KnowledgeBaseConnector(ABC):
    """Abstract interface every knowledge base connector has to implement."""

    @abstractmethod
    def connect_or_raise(self) -> None:
        """Connect to the knowledge base; implementations raise on failure."""

    @abstractmethod
    async def retrieve_documents(
        self,
        search_query: str,
        k: int,
        threshold: float,
        tracker: Optional[DialogueStateTracker],
    ) -> Optional[SearchResultList]:
        """Retrieve documents for a search query.

        Args:
            search_query: Query to search for.
            k: Number of documents to retrieve.
            threshold: Score threshold passed to the search.
            tracker: Conversation tracker, if available.

        Returns:
            The search results, or `None`.
        """

    @classmethod
    @abstractmethod
    def load(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        **kwargs: Any,
    ) -> "KnowledgeBaseConnector":
        """Create a connector from configuration and persisted storage.

        Args:
            config: Component configuration.
            model_storage: Storage holding persisted artifacts.
            resource: Resource identifying the owning component.
            **kwargs: Additional, implementation-specific arguments.

        Returns:
            The loaded connector.
        """
@@ -0,0 +1,226 @@
1
+ import copy
2
+ from enum import Enum
3
+ from typing import TYPE_CHECKING, Any, Dict, Optional
4
+
5
+ import structlog
6
+
7
+ from rasa.core.information_retrieval import (
8
+ InformationRetrieval,
9
+ InformationRetrievalException,
10
+ SearchResultList,
11
+ create_from_endpoint_config,
12
+ )
13
+ from rasa.core.information_retrieval.faiss import FAISS_Store
14
+ from rasa.document_retrieval.constants import (
15
+ DEFAULT_EMBEDDINGS_CONFIG,
16
+ DEFAULT_VECTOR_STORE,
17
+ DEFAULT_VECTOR_STORE_TYPE,
18
+ VECTOR_STORE_CONFIG_KEY,
19
+ VECTOR_STORE_TYPE_CONFIG_KEY,
20
+ )
21
+ from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
22
+ KnowledgeBaseConnector,
23
+ )
24
+ from rasa.engine.storage.resource import Resource
25
+ from rasa.engine.storage.storage import ModelStorage
26
+ from rasa.shared.constants import EMBEDDINGS_CONFIG_KEY
27
+ from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
28
+ from rasa.shared.exceptions import RasaException
29
+ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
30
+ _LangchainEmbeddingClientAdapter,
31
+ )
32
+ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
33
+ EmbeddingsHealthCheckMixin,
34
+ )
35
+ from rasa.shared.utils.health_check.health_check import perform_embeddings_health_check
36
+ from rasa.shared.utils.llm import embedder_factory, resolve_model_client_config
37
+
38
+ if TYPE_CHECKING:
39
+ from langchain.schema.embeddings import Embeddings
40
+
41
+
42
+ structlogger = structlog.get_logger()
43
+
44
+
45
class VectorStoreConnectionError(RasaException):
    """Raised when connecting to the vector store fails."""
47
+
48
+
49
class VectorStoreConfigurationError(RasaException):
    """Raised when the vector store configuration is missing or invalid."""
51
+
52
+
53
class VectorStoreType(Enum):
    """Vector store backends that can be named in the component config."""

    FAISS = "FAISS"
    QDRANT = "QDRANT"
    MILVUS = "MILVUS"

    def __str__(self) -> str:
        """Render the member as its raw string value."""
        return str(self.value)
60
+
61
+
62
class VectorStoreConnector(KnowledgeBaseConnector, EmbeddingsHealthCheckMixin):
    """Knowledge base connector that searches documents in a vector store."""

    def __init__(
        self,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        vector_store: Optional[InformationRetrieval] = None,
    ) -> None:
        """Store configuration and the (optional) vector store client.

        Args:
            config: Component configuration.
            model_storage: Storage holding persisted artifacts.
            resource: Resource identifying the owning component.
            vector_store: Ready-to-use vector store; when omitted, retrieval
                returns `None`.
        """
        self.config = config
        self.vector_store_type = config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
            VECTOR_STORE_TYPE_CONFIG_KEY
        )

        # Vector store object and configuration
        self.vector_store = vector_store
        self.vector_store_config = self.config.get(
            VECTOR_STORE_CONFIG_KEY, DEFAULT_VECTOR_STORE
        )

        # Embeddings configuration for encoding the search query.
        # `.get` keeps the default fallback reachable when the key is absent.
        self.embeddings_config = (
            self.config.get(EMBEDDINGS_CONFIG_KEY) or DEFAULT_EMBEDDINGS_CONFIG
        )

        self._model_storage = model_storage
        self._resource = resource

    @classmethod
    def _create_plain_embedder(cls, config: Dict[str, Any]) -> "Embeddings":
        """Creates an embedder based on the given configuration.

        Args:
            config: Component configuration holding the embeddings config.

        Returns:
            The embedder.
        """
        # Copy the config so original config is not modified
        config = copy.deepcopy(config)
        # Resolve config and instantiate the embedding client
        config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
            config.get(EMBEDDINGS_CONFIG_KEY), VectorStoreConnector.__name__
        )
        client = embedder_factory(
            config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
        )
        # Wrap the embedding client in the adapter
        return _LangchainEmbeddingClientAdapter(client)

    @classmethod
    def load(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        **kwargs: Any,
    ) -> "VectorStoreConnector":
        """Create a connector together with its vector store.

        Runs the embeddings health check, builds the embedder, and then either
        opens the FAISS index stored in the model or creates the store for the
        configured type.

        Args:
            config: Component configuration.
            model_storage: Storage holding persisted artifacts.
            resource: Resource identifying the owning component.
            **kwargs: Additional arguments (unused).

        Returns:
            The loaded connector.
        """
        # Perform health check on the resolved embeddings client config
        embedding_config = resolve_model_client_config(
            config.get(EMBEDDINGS_CONFIG_KEY, {})
        )
        perform_embeddings_health_check(
            embedding_config,
            DEFAULT_EMBEDDINGS_CONFIG,
            "vector_store_connector.load",
            VectorStoreConnector.__name__,
        )

        store_type = config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
            VECTOR_STORE_TYPE_CONFIG_KEY
        )
        embeddings = cls._create_plain_embedder(config)

        structlogger.info("vector_store_connector.load", config=config)
        if store_type == VectorStoreType.FAISS.value:
            # FAISS reads its index from the model storage.
            # TODO figure out a way to get path without context manager
            with model_storage.read_from(resource) as path:
                vector_store = FAISS_Store(
                    embeddings=embeddings,
                    index_path=path,
                    docs_folder=None,
                    create_index=False,
                )
        else:
            # Every other store type is created from the endpoint config.
            vector_store = create_from_endpoint_config(
                config_type=store_type,
                embeddings=embeddings,
            )  # type: ignore

        return cls(
            config=config,
            model_storage=model_storage,
            resource=resource,
            vector_store=vector_store,
        )

    def connect_or_raise(self) -> None:
        """Connects to the vector store or raises an exception.

        Nothing is done for a FAISS store. For other store types the vector
        store settings are read from the available endpoints.

        Raises:
            VectorStoreConfigurationError: If no vector store is specified in
                the endpoints configuration.
            VectorStoreConnectionError: If connecting to the vector store
                fails.
        """
        if self.vector_store_type == VectorStoreType.FAISS.value:
            return

        # Imported here rather than at module level, as in the original code.
        from rasa.core.utils import AvailableEndpoints

        endpoints = AvailableEndpoints.get_instance()
        config = endpoints.vector_store if endpoints else None

        if config is None and self.vector_store_type != DEFAULT_VECTOR_STORE_TYPE:
            structlogger.error("vector_store_connector._connect_or_raise.no_config")
            raise VectorStoreConfigurationError(
                "No vector store specified. Please specify a vector "
                "store in the endpoints configuration."
            )
        try:
            self.vector_store.connect(config)  # type: ignore
        except Exception as e:
            # Boundary to the store client: report any failure as a connection
            # error while keeping the original exception as the cause.
            structlogger.error(
                "vector_store_connector._connect_or_raise.connect_error",
                error=e,
                config=config,
            )
            raise VectorStoreConnectionError(
                f"Unable to connect to the vector store. Error: {e}"
            ) from e

    async def retrieve_documents(
        self,
        search_query: str,
        k: int,
        threshold: float,
        tracker: Optional[DialogueStateTracker],
    ) -> Optional[SearchResultList]:
        """Search the vector store for documents matching the query.

        Args:
            search_query: Query to search for.
            k: Number of documents to retrieve.
            threshold: Score threshold passed to the search.
            tracker: Conversation tracker; its current state is passed to the
                search when available.

        Returns:
            The search results, or `None` if there is no vector store, the
            connection cannot be established, or the search fails.
        """
        if self.vector_store is None:
            return None

        try:
            self.connect_or_raise()
        except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
            structlogger.error("vector_store_connector.connection_error", error=e)
            return None

        if tracker is not None:
            tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
        else:
            tracker_state = {}

        try:
            return await self.vector_store.search(
                query=search_query,
                threshold=threshold,
                tracker_state=tracker_state,
                k=k,
            )
        except InformationRetrievalException as e:
            structlogger.error("vector_store.search_error", error=e)
            return None