rasa-pro 3.13.0.dev2__py3-none-any.whl → 3.13.0.dev3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (48)
  1. rasa/cli/run.py +10 -6
  2. rasa/cli/utils.py +7 -0
  3. rasa/core/channels/channel.py +30 -0
  4. rasa/core/channels/voice_ready/jambonz.py +25 -5
  5. rasa/core/channels/voice_ready/jambonz_protocol.py +4 -0
  6. rasa/core/information_retrieval/faiss.py +7 -68
  7. rasa/core/information_retrieval/information_retrieval.py +2 -40
  8. rasa/core/information_retrieval/milvus.py +2 -7
  9. rasa/core/information_retrieval/qdrant.py +2 -7
  10. rasa/core/nlg/contextual_response_rephraser.py +3 -0
  11. rasa/core/policies/enterprise_search_policy.py +310 -61
  12. rasa/core/policies/intentless_policy.py +3 -0
  13. rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -0
  14. rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
  15. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  16. rasa/dialogue_understanding/generator/flow_retrieval.py +1 -4
  17. rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -2
  18. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +13 -0
  19. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  20. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
  21. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +2 -24
  22. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +22 -17
  23. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +27 -12
  24. rasa/dialogue_understanding_test/io.py +8 -13
  25. rasa/e2e_test/utils/validation.py +3 -3
  26. rasa/engine/recipes/default_components.py +0 -2
  27. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +3 -0
  28. rasa/shared/utils/constants.py +3 -0
  29. rasa/shared/utils/llm.py +70 -24
  30. rasa/tracing/instrumentation/attribute_extractors.py +7 -10
  31. rasa/tracing/instrumentation/instrumentation.py +12 -12
  32. rasa/version.py +1 -1
  33. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/METADATA +2 -2
  34. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/RECORD +37 -48
  35. rasa/document_retrieval/__init__.py +0 -0
  36. rasa/document_retrieval/constants.py +0 -32
  37. rasa/document_retrieval/document_post_processor.py +0 -351
  38. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  39. rasa/document_retrieval/document_retriever.py +0 -333
  40. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  41. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +0 -39
  42. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +0 -34
  43. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +0 -226
  44. rasa/document_retrieval/query_rewriter.py +0 -234
  45. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +0 -8
  46. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/NOTICE +0 -0
  47. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/WHEEL +0 -0
  48. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,12 @@
1
1
  import importlib.resources
2
+ import json
2
3
  import re
3
4
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
4
5
 
5
6
  import dotenv
6
7
  import structlog
7
8
  from jinja2 import Template
9
+ from pydantic import ValidationError
8
10
 
9
11
  import rasa.shared.utils.io
10
12
  from rasa.core.constants import (
@@ -14,9 +16,12 @@ from rasa.core.constants import (
14
16
  UTTER_SOURCE_METADATA_KEY,
15
17
  )
16
18
  from rasa.core.information_retrieval import (
19
+ InformationRetrieval,
20
+ InformationRetrievalException,
17
21
  SearchResult,
18
- SearchResultList,
22
+ create_from_endpoint_config,
19
23
  )
24
+ from rasa.core.information_retrieval.faiss import FAISS_Store
20
25
  from rasa.core.policies.policy import Policy, PolicyPrediction
21
26
  from rasa.core.utils import AvailableEndpoints
22
27
  from rasa.dialogue_understanding.generator.constants import (
@@ -33,10 +38,6 @@ from rasa.dialogue_understanding.stack.frames import (
33
38
  PatternFlowStackFrame,
34
39
  SearchStackFrame,
35
40
  )
36
- from rasa.document_retrieval.constants import (
37
- POST_PROCESSED_DOCUMENTS_KEY,
38
- SEARCH_QUERY_KEY,
39
- )
40
41
  from rasa.engine.graph import ExecutionContext
41
42
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
42
43
  from rasa.engine.storage.resource import Resource
@@ -44,7 +45,10 @@ from rasa.engine.storage.storage import ModelStorage
44
45
  from rasa.graph_components.providers.forms_provider import Forms
45
46
  from rasa.graph_components.providers.responses_provider import Responses
46
47
  from rasa.shared.constants import (
48
+ EMBEDDINGS_CONFIG_KEY,
47
49
  MODEL_CONFIG_KEY,
50
+ MODEL_GROUP_ID_CONFIG_KEY,
51
+ MODEL_NAME_CONFIG_KEY,
48
52
  OPENAI_PROVIDER,
49
53
  PROMPT_CONFIG_KEY,
50
54
  PROVIDER_CONFIG_KEY,
@@ -56,10 +60,10 @@ from rasa.shared.core.constants import (
56
60
  DEFAULT_SLOT_NAMES,
57
61
  )
58
62
  from rasa.shared.core.domain import Domain
59
- from rasa.shared.core.events import Event
63
+ from rasa.shared.core.events import BotUttered, Event, UserUttered
60
64
  from rasa.shared.core.generator import TrackerWithCachedStates
61
- from rasa.shared.core.trackers import DialogueStateTracker
62
- from rasa.shared.exceptions import FileIOException
65
+ from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
66
+ from rasa.shared.exceptions import FileIOException, RasaException
63
67
  from rasa.shared.nlu.constants import (
64
68
  KEY_COMPONENT_NAME,
65
69
  KEY_LLM_RESPONSE_METADATA,
@@ -68,8 +72,16 @@ from rasa.shared.nlu.constants import (
68
72
  PROMPTS,
69
73
  )
70
74
  from rasa.shared.nlu.training_data.training_data import TrainingData
75
+ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
76
+ _LangchainEmbeddingClientAdapter,
77
+ )
71
78
  from rasa.shared.providers.llm.llm_client import LLMClient
72
79
  from rasa.shared.providers.llm.llm_response import LLMResponse, measure_llm_latency
80
+ from rasa.shared.utils.cli import print_error_and_exit
81
+ from rasa.shared.utils.constants import (
82
+ LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
83
+ LOG_COMPONENT_SOURCE_METHOD_INIT,
84
+ )
73
85
  from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
74
86
  EmbeddingsHealthCheckMixin,
75
87
  )
@@ -77,13 +89,23 @@ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheck
77
89
  from rasa.shared.utils.io import deep_container_fingerprint
78
90
  from rasa.shared.utils.llm import (
79
91
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
92
+ DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
93
+ embedder_factory,
80
94
  get_prompt_template,
81
95
  llm_factory,
82
96
  resolve_model_client_config,
97
+ sanitize_message_for_prompt,
83
98
  tracker_as_readable_transcript,
84
99
  )
100
+ from rasa.telemetry import (
101
+ track_enterprise_search_policy_predict,
102
+ track_enterprise_search_policy_train_completed,
103
+ track_enterprise_search_policy_train_started,
104
+ )
85
105
 
86
106
  if TYPE_CHECKING:
107
+ from langchain.schema.embeddings import Embeddings
108
+
87
109
  from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
88
110
 
89
111
  from rasa.utils.log_utils import log_llm
@@ -92,11 +114,22 @@ logger = structlog.get_logger()
92
114
 
93
115
  dotenv.load_dotenv("./.env")
94
116
 
117
+ SOURCE_PROPERTY = "source"
118
+ VECTOR_STORE_TYPE_PROPERTY = "type"
119
+ VECTOR_STORE_PROPERTY = "vector_store"
120
+ VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
95
121
  TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
96
122
  CITATION_ENABLED_PROPERTY = "citation_enabled"
97
123
  USE_LLM_PROPERTY = "use_generative_llm"
98
124
  MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
99
125
 
126
+ DEFAULT_VECTOR_STORE_TYPE = "faiss"
127
+ DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
128
+ DEFAULT_VECTOR_STORE = {
129
+ VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
130
+ SOURCE_PROPERTY: "./docs",
131
+ VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
132
+ }
100
133
 
101
134
  DEFAULT_LLM_CONFIG = {
102
135
  PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
@@ -107,6 +140,11 @@ DEFAULT_LLM_CONFIG = {
107
140
  "max_retries": 1,
108
141
  }
109
142
 
143
+ DEFAULT_EMBEDDINGS_CONFIG = {
144
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
145
+ "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
146
+ }
147
+
110
148
  ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
111
149
  ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
112
150
 
@@ -122,6 +160,14 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
122
160
  )
123
161
 
124
162
 
163
+ class VectorStoreConnectionError(RasaException):
164
+ """Exception raised for errors in connecting to the vector store."""
165
+
166
+
167
+ class VectorStoreConfigurationError(RasaException):
168
+ """Exception raised for errors in vector store configuration."""
169
+
170
+
125
171
  @DefaultV1Recipe.register(
126
172
  DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
127
173
  )
@@ -155,6 +201,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
155
201
  """Returns the default config of the policy."""
156
202
  return {
157
203
  POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
204
+ VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
158
205
  }
159
206
 
160
207
  def __init__(
@@ -163,6 +210,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
163
210
  model_storage: ModelStorage,
164
211
  resource: Resource,
165
212
  execution_context: ExecutionContext,
213
+ vector_store: Optional[InformationRetrieval] = None,
166
214
  featurizer: Optional["TrackerFeaturizer"] = None,
167
215
  prompt_template: Optional[Text] = None,
168
216
  ) -> None:
@@ -173,6 +221,21 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
173
221
  self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
174
222
  self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
175
223
  )
224
+ # Resolve embeddings config
225
+ self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
226
+ self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
227
+ )
228
+
229
+ # Vector store object and configuration
230
+ self.vector_store = vector_store
231
+ self.vector_store_config = self.config.get(
232
+ VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
233
+ )
234
+
235
+ # Embeddings configuration for encoding the search query
236
+ self.embeddings_config = (
237
+ self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
238
+ )
176
239
 
177
240
  # LLM Configuration for response generation
178
241
  self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
@@ -180,6 +243,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
180
243
  # Maximum number of turns to include in the prompt
181
244
  self.max_history = self.config.get(POLICY_MAX_HISTORY)
182
245
 
246
+ # Maximum number of messages to include in the search query
247
+ self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
248
+
183
249
  # boolean to enable/disable tracing of prompt tokens
184
250
  self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
185
251
 
@@ -192,15 +258,38 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
192
258
  self.prompt_template = prompt_template or get_prompt_template(
193
259
  self.config.get(PROMPT_CONFIG_KEY),
194
260
  DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
261
+ log_source_component=EnterpriseSearchPolicy.__name__,
262
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
195
263
  )
196
264
  self.citation_prompt_template = get_prompt_template(
197
265
  self.config.get(PROMPT_CONFIG_KEY),
198
266
  DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
267
+ log_source_component=EnterpriseSearchPolicy.__name__,
268
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
199
269
  )
200
270
  # If citation is enabled, use the citation prompt template
201
271
  if self.citation_enabled:
202
272
  self.prompt_template = self.citation_prompt_template
203
273
 
274
+ @classmethod
275
+ def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
276
+ """Creates an embedder based on the given configuration.
277
+
278
+ Returns:
279
+ The embedder.
280
+ """
281
+ # Copy the config so original config is not modified
282
+ config = config.copy()
283
+ # Resolve config and instantiate the embedding client
284
+ config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
285
+ config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
286
+ )
287
+ client = embedder_factory(
288
+ config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
289
+ )
290
+ # Wrap the embedding client in the adapter
291
+ return _LangchainEmbeddingClientAdapter(client)
292
+
204
293
  @classmethod
205
294
  def _add_prompt_and_llm_response_to_latest_message(
206
295
  cls,
@@ -265,24 +354,52 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
265
354
  # Perform health checks for both LLM and embeddings client configs
266
355
  self._perform_health_checks(self.config, "enterprise_search_policy.train")
267
356
 
268
- # # telemetry call to track training start
269
- # track_enterprise_search_policy_train_started()
270
- # # telemetry call to track training completion
271
- # track_enterprise_search_policy_train_completed(
272
- # vector_store_type=store_type,
273
- # embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
274
- # embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
275
- # or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
276
- # embeddings_model_group_id=self.embeddings_config.get(
277
- # MODEL_GROUP_ID_CONFIG_KEY
278
- # ),
279
- # llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
280
- # llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
281
- # or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
282
- # llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
283
- # citation_enabled=self.citation_enabled,
284
- # )
357
+ store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
358
+
359
+ # telemetry call to track training start
360
+ track_enterprise_search_policy_train_started()
285
361
 
362
+ # validate embedding configuration
363
+ try:
364
+ embeddings = self._create_plain_embedder(self.config)
365
+ except (ValidationError, Exception) as e:
366
+ logger.error(
367
+ "enterprise_search_policy.train.embedder_instantiation_failed",
368
+ message="Unable to instantiate the embedding client.",
369
+ error=e,
370
+ )
371
+ print_error_and_exit(
372
+ "Unable to create embedder. Please make sure you specified the "
373
+ f"required environment variables. Error: {e}"
374
+ )
375
+
376
+ if store_type == DEFAULT_VECTOR_STORE_TYPE:
377
+ logger.info("enterprise_search_policy.train.faiss")
378
+ with self._model_storage.write_to(self._resource) as path:
379
+ self.vector_store = FAISS_Store(
380
+ docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
381
+ embeddings=embeddings,
382
+ index_path=path,
383
+ create_index=True,
384
+ )
385
+ else:
386
+ logger.info("enterprise_search_policy.train.custom", store_type=store_type)
387
+
388
+ # telemetry call to track training completion
389
+ track_enterprise_search_policy_train_completed(
390
+ vector_store_type=store_type,
391
+ embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
392
+ embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
393
+ or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
394
+ embeddings_model_group_id=self.embeddings_config.get(
395
+ MODEL_GROUP_ID_CONFIG_KEY
396
+ ),
397
+ llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
398
+ llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
399
+ or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
400
+ llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
401
+ citation_enabled=self.citation_enabled,
402
+ )
286
403
  self.persist()
287
404
  return self._resource
288
405
 
@@ -319,6 +436,60 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
319
436
  )
320
437
  return template_slots
321
438
 
439
+ def _connect_vector_store_or_raise(
440
+ self, endpoints: Optional[AvailableEndpoints]
441
+ ) -> None:
442
+ """Connects to the vector store or raises an exception.
443
+
444
+ Raise exceptions for the following cases:
445
+ - The configuration is not specified
446
+ - Unable to connect to the vector store
447
+
448
+ Args:
449
+ endpoints: Endpoints configuration.
450
+ """
451
+ config = endpoints.vector_store if endpoints else None
452
+ store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
453
+ if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
454
+ logger.error(
455
+ "enterprise_search_policy._connect_vector_store_or_raise.no_config"
456
+ )
457
+ raise VectorStoreConfigurationError(
458
+ """No vector store specified. Please specify a vector
459
+ store in the endpoints configuration"""
460
+ )
461
+ try:
462
+ self.vector_store.connect(config) # type: ignore
463
+ except Exception as e:
464
+ logger.error(
465
+ "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
466
+ error=e,
467
+ config=config,
468
+ )
469
+ raise VectorStoreConnectionError(
470
+ f"Unable to connect to the vector store. Error: {e}"
471
+ )
472
+
473
+ def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
474
+ """Prepares the search query.
475
+ The search query is the last N messages in the conversation history.
476
+
477
+ Args:
478
+ tracker: The tracker containing the conversation history up to now.
479
+ history: The number of messages to include in the search query.
480
+
481
+ Returns:
482
+ The search query.
483
+ """
484
+ transcript = []
485
+ for event in tracker.applied_events():
486
+ if isinstance(event, UserUttered) or isinstance(event, BotUttered):
487
+ transcript.append(sanitize_message_for_prompt(event.text))
488
+
489
+ search_query = " ".join(transcript[-history:][::-1])
490
+ logger.debug("search_query", search_query=search_query)
491
+ return search_query
492
+
322
493
  async def predict_action_probabilities( # type: ignore[override]
323
494
  self,
324
495
  tracker: DialogueStateTracker,
@@ -342,34 +513,44 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
342
513
  The prediction.
343
514
  """
344
515
  logger_key = "enterprise_search_policy.predict_action_probabilities"
345
-
516
+ vector_search_threshold = self.vector_store_config.get(
517
+ VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
518
+ )
519
+ llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
346
520
  if not self.supports_current_stack_frame(
347
521
  tracker, False, False
348
522
  ) or self.should_abstain_in_coexistence(tracker, True):
349
523
  return self._prediction(self._default_predictions(domain))
350
524
 
351
- # retrieve documents from the latest message
352
- # document retrieval happened earlier in the pipeline
353
- if tracker.latest_message is None or tracker.latest_message.parse_data is None:
354
- logger.info(f"{logger_key}.no_documents")
355
- return self._create_prediction_cannot_handle(domain, tracker)
525
+ if not self.vector_store:
526
+ logger.error(f"{logger_key}.no_vector_store")
527
+ return self._create_prediction_internal_error(domain, tracker)
356
528
 
357
- documents_data = tracker.latest_message.parse_data.get(
358
- POST_PROCESSED_DOCUMENTS_KEY
359
- )
529
+ try:
530
+ self._connect_vector_store_or_raise(endpoints)
531
+ except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
532
+ logger.error(f"{logger_key}.connection_error", error=e)
533
+ return self._create_prediction_internal_error(domain, tracker)
360
534
 
361
- if not documents_data:
362
- logger.info(f"{logger_key}.no_documents")
363
- return self._create_prediction_cannot_handle(domain, tracker)
535
+ search_query = self._prepare_search_query(
536
+ tracker, int(self.max_messages_in_query)
537
+ )
538
+ tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
364
539
 
365
- documents = SearchResultList.from_dict(documents_data)
540
+ try:
541
+ documents = await self.vector_store.search(
542
+ query=search_query,
543
+ tracker_state=tracker_state,
544
+ threshold=vector_search_threshold,
545
+ )
546
+ except InformationRetrievalException as e:
547
+ logger.error(f"{logger_key}.search_error", error=e)
548
+ return self._create_prediction_internal_error(domain, tracker)
366
549
 
367
550
  if not documents.results:
368
551
  logger.info(f"{logger_key}.no_documents")
369
552
  return self._create_prediction_cannot_handle(domain, tracker)
370
553
 
371
- llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
372
-
373
554
  if self.use_llm:
374
555
  prompt = self._render_prompt(tracker, documents.results)
375
556
  llm_response = await self._generate_llm_answer(llm, prompt)
@@ -414,29 +595,25 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
414
595
  result.text for result in documents.results
415
596
  ],
416
597
  UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
417
- SEARCH_QUERY_METADATA_KEY: tracker.latest_message.parse_data.get(
418
- SEARCH_QUERY_KEY
419
- ),
598
+ SEARCH_QUERY_METADATA_KEY: search_query,
420
599
  }
421
600
  }
422
601
 
423
- # # telemetry call to track policy prediction
424
- # track_enterprise_search_policy_predict(
425
- # vector_store_type=self.vector_store_config.get(
426
- # VECTOR_STORE_TYPE_PROPERTY),
427
- # embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
428
- # embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
429
- # or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
430
- # embeddings_model_group_id=self.embeddings_config.get(
431
- # MODEL_GROUP_ID_CONFIG_KEY
432
- # ),
433
- # llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
434
- # llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
435
- # or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
436
- # llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
437
- # citation_enabled=self.citation_enabled,
438
- # )
439
-
602
+ # telemetry call to track policy prediction
603
+ track_enterprise_search_policy_predict(
604
+ vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
605
+ embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
606
+ embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
607
+ or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
608
+ embeddings_model_group_id=self.embeddings_config.get(
609
+ MODEL_GROUP_ID_CONFIG_KEY
610
+ ),
611
+ llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
612
+ llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
613
+ or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
614
+ llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
615
+ citation_enabled=self.citation_enabled,
616
+ )
440
617
  return self._create_prediction(
441
618
  domain=domain, tracker=tracker, action_metadata=action_metadata
442
619
  )
@@ -604,28 +781,89 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
604
781
  "enterprise_search_policy.load.failed", error=e, resource=resource.name
605
782
  )
606
783
 
784
+ store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
785
+ VECTOR_STORE_TYPE_PROPERTY
786
+ )
787
+
788
+ embeddings = cls._create_plain_embedder(config)
789
+
607
790
  logger.info("enterprise_search_policy.load", config=config)
791
+ if store_type == DEFAULT_VECTOR_STORE_TYPE:
792
+ # if a vector store is not specified,
793
+ # default to using FAISS with the index stored in the model
794
+ # TODO figure out a way to get path without context manager
795
+ with model_storage.read_from(resource) as path:
796
+ vector_store = FAISS_Store(
797
+ embeddings=embeddings,
798
+ index_path=path,
799
+ docs_folder=None,
800
+ create_index=False,
801
+ )
802
+ else:
803
+ vector_store = create_from_endpoint_config(
804
+ config_type=store_type,
805
+ embeddings=embeddings,
806
+ ) # type: ignore
608
807
 
609
808
  return cls(
610
809
  config,
611
810
  model_storage,
612
811
  resource,
613
812
  execution_context,
813
+ vector_store=vector_store,
614
814
  prompt_template=prompt_template,
615
815
  )
616
816
 
817
+ @classmethod
818
+ def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
819
+ """This is required only for local knowledge base types.
820
+
821
+ e.g. FAISS, to ensure that the graph component is retrained when the knowledge
822
+ base is updated.
823
+ """
824
+ merged_config = {**cls.get_default_config(), **config}
825
+
826
+ store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
827
+ VECTOR_STORE_TYPE_PROPERTY
828
+ )
829
+ if store_type != DEFAULT_VECTOR_STORE_TYPE:
830
+ return None
831
+
832
+ source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
833
+ if not source:
834
+ return None
835
+
836
+ docs = FAISS_Store.load_documents(source)
837
+
838
+ if len(docs) == 0:
839
+ return None
840
+
841
+ docs_as_strings = [
842
+ json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
843
+ ]
844
+ return sorted(docs_as_strings)
845
+
617
846
  @classmethod
618
847
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
619
848
  """Add a fingerprint of enterprise search policy for the graph."""
849
+ local_knowledge_data = cls._get_local_knowledge_data(config)
850
+
620
851
  prompt_template = get_prompt_template(
621
852
  config.get(PROMPT_CONFIG_KEY),
622
853
  DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
854
+ log_source_component=EnterpriseSearchPolicy.__name__,
855
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
623
856
  )
624
857
 
625
858
  llm_config = resolve_model_client_config(
626
859
  config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
627
860
  )
628
- return deep_container_fingerprint([prompt_template, llm_config])
861
+ embedding_config = resolve_model_client_config(
862
+ config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
863
+ )
864
+ return deep_container_fingerprint(
865
+ [prompt_template, local_knowledge_data, llm_config, embedding_config]
866
+ )
629
867
 
630
868
  @staticmethod
631
869
  def post_process_citations(llm_answer: str) -> str:
@@ -730,3 +968,14 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
730
968
  log_source_method,
731
969
  EnterpriseSearchPolicy.__name__,
732
970
  )
971
+
972
+ # Perform health check of the embeddings client config
973
+ embeddings_config = resolve_model_client_config(
974
+ config.get(EMBEDDINGS_CONFIG_KEY, {})
975
+ )
976
+ cls.perform_embeddings_health_check(
977
+ embeddings_config,
978
+ DEFAULT_EMBEDDINGS_CONFIG,
979
+ log_source_method,
980
+ EnterpriseSearchPolicy.__name__,
981
+ )
@@ -58,6 +58,7 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
58
58
  _LangchainEmbeddingClientAdapter,
59
59
  )
60
60
  from rasa.shared.providers.llm.llm_client import LLMClient
61
+ from rasa.shared.utils.constants import LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON
61
62
  from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
62
63
  EmbeddingsHealthCheckMixin,
63
64
  )
@@ -939,6 +940,8 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
939
940
  prompt_template = get_prompt_template(
940
941
  config.get(PROMPT_CONFIG_KEY),
941
942
  DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
943
+ log_source_component=IntentlessPolicy.__name__,
944
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
942
945
  )
943
946
 
944
947
  llm_config = resolve_model_client_config(
@@ -35,6 +35,10 @@ from rasa.shared.exceptions import FileIOException, InvalidConfigException
35
35
  from rasa.shared.nlu.constants import COMMANDS, TEXT
36
36
  from rasa.shared.nlu.training_data.message import Message
37
37
  from rasa.shared.nlu.training_data.training_data import TrainingData
38
+ from rasa.shared.utils.constants import (
39
+ LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
40
+ LOG_COMPONENT_SOURCE_METHOD_INIT,
41
+ )
38
42
  from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
39
43
  from rasa.shared.utils.io import deep_container_fingerprint
40
44
  from rasa.shared.utils.llm import (
@@ -107,6 +111,8 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
107
111
  or get_prompt_template(
108
112
  config.get(PROMPT_CONFIG_KEY),
109
113
  DEFAULT_COMMAND_PROMPT_TEMPLATE,
114
+ log_source_component=LLMBasedRouter.__name__,
115
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
110
116
  ).strip()
111
117
  )
112
118
 
@@ -318,6 +324,8 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
318
324
  prompt_template = get_prompt_template(
319
325
  config.get(PROMPT_CONFIG_KEY),
320
326
  DEFAULT_COMMAND_PROMPT_TEMPLATE,
327
+ log_source_component=LLMBasedRouter.__name__,
328
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
321
329
  )
322
330
 
323
331
  llm_config = resolve_model_client_config(
@@ -65,7 +65,7 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
65
65
  """Converts the command to a DSL string."""
66
66
  mapper = {
67
67
  CommandSyntaxVersion.v1: "SearchAndReply()",
68
- CommandSyntaxVersion.v2: "search and reply",
68
+ CommandSyntaxVersion.v2: "provide info",
69
69
  }
70
70
  return mapper.get(
71
71
  CommandSyntaxManager.get_syntax_version(),
@@ -81,7 +81,7 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
81
81
  def regex_pattern() -> str:
82
82
  mapper = {
83
83
  CommandSyntaxVersion.v1: r"SearchAndReply\(\)",
84
- CommandSyntaxVersion.v2: r"""^[\s\W\d]*search and reply['"`]*$""",
84
+ CommandSyntaxVersion.v2: r"""^[\s\W\d]*provide info['"`]*$""",
85
85
  }
86
86
  return mapper.get(
87
87
  CommandSyntaxManager.get_syntax_version(),
@@ -169,7 +169,7 @@ def _parse_standard_commands(
169
169
  commands: List[Command] = []
170
170
  for command_clz in standard_commands:
171
171
  pattern = _get_compiled_pattern(command_clz.regex_pattern())
172
- if match := pattern.search(action.strip()):
172
+ if match := pattern.search(action):
173
173
  parsed_command = command_clz.from_dsl(match, **kwargs)
174
174
  if _additional_parsing_fn := _get_additional_parsing_logic(command_clz):
175
175
  parsed_command = _additional_parsing_fn(parsed_command, flows, **kwargs)
@@ -52,7 +52,6 @@ from rasa.shared.utils.llm import (
52
52
  USER,
53
53
  allowed_values_for_slot,
54
54
  embedder_factory,
55
- get_prompt_template,
56
55
  resolve_model_client_config,
57
56
  tracker_as_readable_transcript,
58
57
  )
@@ -103,9 +102,7 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
103
102
  self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
104
103
  )
105
104
  self.vector_store: Optional[FAISS] = None
106
- self.flow_document_template = get_prompt_template(
107
- None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
108
- )
105
+ self.flow_document_template = DEFAULT_FLOW_DOCUMENT_TEMPLATE
109
106
  self._model_storage = model_storage
110
107
  self._resource = resource
111
108
 
@@ -390,8 +390,7 @@ class LLMBasedCommandGenerator(
390
390
  "slots": slots_with_info,
391
391
  }
392
392
  )
393
-
394
- return sorted(result, key=lambda x: x["name"])
393
+ return result
395
394
 
396
395
  @staticmethod
397
396
  def is_extractable(