rasa-pro 3.13.0.dev1__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (58)
  1. rasa/core/actions/action.py +0 -6
  2. rasa/core/channels/voice_ready/audiocodes.py +52 -17
  3. rasa/core/channels/voice_stream/audiocodes.py +53 -9
  4. rasa/core/channels/voice_stream/genesys.py +146 -16
  5. rasa/core/information_retrieval/faiss.py +6 -1
  6. rasa/core/information_retrieval/information_retrieval.py +40 -2
  7. rasa/core/information_retrieval/milvus.py +7 -2
  8. rasa/core/information_retrieval/qdrant.py +7 -2
  9. rasa/core/policies/enterprise_search_policy.py +61 -301
  10. rasa/core/policies/flows/flow_executor.py +3 -38
  11. rasa/core/processor.py +27 -6
  12. rasa/core/utils.py +53 -0
  13. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  14. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  15. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  16. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  17. rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -13
  18. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  19. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +20 -1
  20. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -0
  21. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  22. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  23. rasa/dialogue_understanding/stack/utils.py +0 -38
  24. rasa/dialogue_understanding_test/io.py +13 -8
  25. rasa/document_retrieval/__init__.py +0 -0
  26. rasa/document_retrieval/constants.py +32 -0
  27. rasa/document_retrieval/document_post_processor.py +351 -0
  28. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  29. rasa/document_retrieval/document_retriever.py +333 -0
  30. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  31. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
  32. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
  33. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
  34. rasa/document_retrieval/query_rewriter.py +234 -0
  35. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
  36. rasa/engine/recipes/default_components.py +2 -0
  37. rasa/shared/core/constants.py +0 -8
  38. rasa/shared/core/domain.py +12 -3
  39. rasa/shared/core/flows/flow.py +0 -17
  40. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  41. rasa/shared/core/flows/steps/collect.py +5 -18
  42. rasa/shared/core/flows/utils.py +1 -16
  43. rasa/shared/core/slot_mappings.py +11 -5
  44. rasa/shared/nlu/constants.py +0 -1
  45. rasa/shared/utils/common.py +11 -1
  46. rasa/shared/utils/llm.py +1 -1
  47. rasa/tracing/instrumentation/attribute_extractors.py +10 -7
  48. rasa/tracing/instrumentation/instrumentation.py +12 -12
  49. rasa/validator.py +1 -123
  50. rasa/version.py +1 -1
  51. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +1 -1
  52. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +55 -47
  53. rasa/core/actions/action_handle_digressions.py +0 -164
  54. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  55. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  56. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
  57. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
  58. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
rasa/core/policies/enterprise_search_policy.py CHANGED
@@ -1,12 +1,10 @@
1
1
  import importlib.resources
2
- import json
3
2
  import re
4
3
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
5
4
 
6
5
  import dotenv
7
6
  import structlog
8
7
  from jinja2 import Template
9
- from pydantic import ValidationError
10
8
 
11
9
  import rasa.shared.utils.io
12
10
  from rasa.core.constants import (
@@ -16,12 +14,9 @@ from rasa.core.constants import (
16
14
  UTTER_SOURCE_METADATA_KEY,
17
15
  )
18
16
  from rasa.core.information_retrieval import (
19
- InformationRetrieval,
20
- InformationRetrievalException,
21
17
  SearchResult,
22
- create_from_endpoint_config,
18
+ SearchResultList,
23
19
  )
24
- from rasa.core.information_retrieval.faiss import FAISS_Store
25
20
  from rasa.core.policies.policy import Policy, PolicyPrediction
26
21
  from rasa.core.utils import AvailableEndpoints
27
22
  from rasa.dialogue_understanding.generator.constants import (
@@ -38,6 +33,10 @@ from rasa.dialogue_understanding.stack.frames import (
38
33
  PatternFlowStackFrame,
39
34
  SearchStackFrame,
40
35
  )
36
+ from rasa.document_retrieval.constants import (
37
+ POST_PROCESSED_DOCUMENTS_KEY,
38
+ SEARCH_QUERY_KEY,
39
+ )
41
40
  from rasa.engine.graph import ExecutionContext
42
41
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
43
42
  from rasa.engine.storage.resource import Resource
@@ -45,10 +44,7 @@ from rasa.engine.storage.storage import ModelStorage
45
44
  from rasa.graph_components.providers.forms_provider import Forms
46
45
  from rasa.graph_components.providers.responses_provider import Responses
47
46
  from rasa.shared.constants import (
48
- EMBEDDINGS_CONFIG_KEY,
49
47
  MODEL_CONFIG_KEY,
50
- MODEL_GROUP_ID_CONFIG_KEY,
51
- MODEL_NAME_CONFIG_KEY,
52
48
  OPENAI_PROVIDER,
53
49
  PROMPT_CONFIG_KEY,
54
50
  PROVIDER_CONFIG_KEY,
@@ -60,10 +56,10 @@ from rasa.shared.core.constants import (
60
56
  DEFAULT_SLOT_NAMES,
61
57
  )
62
58
  from rasa.shared.core.domain import Domain
63
- from rasa.shared.core.events import BotUttered, Event, UserUttered
59
+ from rasa.shared.core.events import Event
64
60
  from rasa.shared.core.generator import TrackerWithCachedStates
65
- from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
66
- from rasa.shared.exceptions import FileIOException, RasaException
61
+ from rasa.shared.core.trackers import DialogueStateTracker
62
+ from rasa.shared.exceptions import FileIOException
67
63
  from rasa.shared.nlu.constants import (
68
64
  KEY_COMPONENT_NAME,
69
65
  KEY_LLM_RESPONSE_METADATA,
@@ -72,12 +68,8 @@ from rasa.shared.nlu.constants import (
72
68
  PROMPTS,
73
69
  )
74
70
  from rasa.shared.nlu.training_data.training_data import TrainingData
75
- from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
76
- _LangchainEmbeddingClientAdapter,
77
- )
78
71
  from rasa.shared.providers.llm.llm_client import LLMClient
79
72
  from rasa.shared.providers.llm.llm_response import LLMResponse, measure_llm_latency
80
- from rasa.shared.utils.cli import print_error_and_exit
81
73
  from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
82
74
  EmbeddingsHealthCheckMixin,
83
75
  )
@@ -85,23 +77,13 @@ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheck
85
77
  from rasa.shared.utils.io import deep_container_fingerprint
86
78
  from rasa.shared.utils.llm import (
87
79
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
88
- DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
89
- embedder_factory,
90
80
  get_prompt_template,
91
81
  llm_factory,
92
82
  resolve_model_client_config,
93
- sanitize_message_for_prompt,
94
83
  tracker_as_readable_transcript,
95
84
  )
96
- from rasa.telemetry import (
97
- track_enterprise_search_policy_predict,
98
- track_enterprise_search_policy_train_completed,
99
- track_enterprise_search_policy_train_started,
100
- )
101
85
 
102
86
  if TYPE_CHECKING:
103
- from langchain.schema.embeddings import Embeddings
104
-
105
87
  from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
106
88
 
107
89
  from rasa.utils.log_utils import log_llm
@@ -110,22 +92,11 @@ logger = structlog.get_logger()
110
92
 
111
93
  dotenv.load_dotenv("./.env")
112
94
 
113
- SOURCE_PROPERTY = "source"
114
- VECTOR_STORE_TYPE_PROPERTY = "type"
115
- VECTOR_STORE_PROPERTY = "vector_store"
116
- VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
117
95
  TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
118
96
  CITATION_ENABLED_PROPERTY = "citation_enabled"
119
97
  USE_LLM_PROPERTY = "use_generative_llm"
120
98
  MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
121
99
 
122
- DEFAULT_VECTOR_STORE_TYPE = "faiss"
123
- DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
124
- DEFAULT_VECTOR_STORE = {
125
- VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
126
- SOURCE_PROPERTY: "./docs",
127
- VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
128
- }
129
100
 
130
101
  DEFAULT_LLM_CONFIG = {
131
102
  PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
@@ -136,11 +107,6 @@ DEFAULT_LLM_CONFIG = {
136
107
  "max_retries": 1,
137
108
  }
138
109
 
139
- DEFAULT_EMBEDDINGS_CONFIG = {
140
- PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
141
- "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
142
- }
143
-
144
110
  ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
145
111
  ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
146
112
 
@@ -156,14 +122,6 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
156
122
  )
157
123
 
158
124
 
159
- class VectorStoreConnectionError(RasaException):
160
- """Exception raised for errors in connecting to the vector store."""
161
-
162
-
163
- class VectorStoreConfigurationError(RasaException):
164
- """Exception raised for errors in vector store configuration."""
165
-
166
-
167
125
  @DefaultV1Recipe.register(
168
126
  DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
169
127
  )
@@ -197,7 +155,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
197
155
  """Returns the default config of the policy."""
198
156
  return {
199
157
  POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
200
- VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
201
158
  }
202
159
 
203
160
  def __init__(
@@ -206,7 +163,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
206
163
  model_storage: ModelStorage,
207
164
  resource: Resource,
208
165
  execution_context: ExecutionContext,
209
- vector_store: Optional[InformationRetrieval] = None,
210
166
  featurizer: Optional["TrackerFeaturizer"] = None,
211
167
  prompt_template: Optional[Text] = None,
212
168
  ) -> None:
@@ -217,21 +173,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
217
173
  self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
218
174
  self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
219
175
  )
220
- # Resolve embeddings config
221
- self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
222
- self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
223
- )
224
-
225
- # Vector store object and configuration
226
- self.vector_store = vector_store
227
- self.vector_store_config = self.config.get(
228
- VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
229
- )
230
-
231
- # Embeddings configuration for encoding the search query
232
- self.embeddings_config = (
233
- self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
234
- )
235
176
 
236
177
  # LLM Configuration for response generation
237
178
  self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
@@ -239,9 +180,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
239
180
  # Maximum number of turns to include in the prompt
240
181
  self.max_history = self.config.get(POLICY_MAX_HISTORY)
241
182
 
242
- # Maximum number of messages to include in the search query
243
- self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
244
-
245
183
  # boolean to enable/disable tracing of prompt tokens
246
184
  self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
247
185
 
@@ -263,25 +201,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
263
201
  if self.citation_enabled:
264
202
  self.prompt_template = self.citation_prompt_template
265
203
 
266
- @classmethod
267
- def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
268
- """Creates an embedder based on the given configuration.
269
-
270
- Returns:
271
- The embedder.
272
- """
273
- # Copy the config so original config is not modified
274
- config = config.copy()
275
- # Resolve config and instantiate the embedding client
276
- config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
277
- config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
278
- )
279
- client = embedder_factory(
280
- config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
281
- )
282
- # Wrap the embedding client in the adapter
283
- return _LangchainEmbeddingClientAdapter(client)
284
-
285
204
  @classmethod
286
205
  def _add_prompt_and_llm_response_to_latest_message(
287
206
  cls,
@@ -346,53 +265,24 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
346
265
  # Perform health checks for both LLM and embeddings client configs
347
266
  self._perform_health_checks(self.config, "enterprise_search_policy.train")
348
267
 
349
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
350
-
351
- # telemetry call to track training start
352
- track_enterprise_search_policy_train_started()
353
-
354
- # validate embedding configuration
355
- try:
356
- embeddings = self._create_plain_embedder(self.config)
357
- except (ValidationError, Exception) as e:
358
- logger.error(
359
- "enterprise_search_policy.train.embedder_instantiation_failed",
360
- message="Unable to instantiate the embedding client.",
361
- error=e,
362
- )
363
- print_error_and_exit(
364
- "Unable to create embedder. Please make sure you specified the "
365
- f"required environment variables. Error: {e}"
366
- )
268
+ # # telemetry call to track training start
269
+ # track_enterprise_search_policy_train_started()
270
+ # # telemetry call to track training completion
271
+ # track_enterprise_search_policy_train_completed(
272
+ # vector_store_type=store_type,
273
+ # embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
274
+ # embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
275
+ # or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
276
+ # embeddings_model_group_id=self.embeddings_config.get(
277
+ # MODEL_GROUP_ID_CONFIG_KEY
278
+ # ),
279
+ # llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
280
+ # llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
281
+ # or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
282
+ # llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
283
+ # citation_enabled=self.citation_enabled,
284
+ # )
367
285
 
368
- if store_type == DEFAULT_VECTOR_STORE_TYPE:
369
- logger.info("enterprise_search_policy.train.faiss")
370
- with self._model_storage.write_to(self._resource) as path:
371
- self.vector_store = FAISS_Store(
372
- docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
373
- embeddings=embeddings,
374
- index_path=path,
375
- create_index=True,
376
- use_llm=self.use_llm,
377
- )
378
- else:
379
- logger.info("enterprise_search_policy.train.custom", store_type=store_type)
380
-
381
- # telemetry call to track training completion
382
- track_enterprise_search_policy_train_completed(
383
- vector_store_type=store_type,
384
- embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
385
- embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
386
- or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
387
- embeddings_model_group_id=self.embeddings_config.get(
388
- MODEL_GROUP_ID_CONFIG_KEY
389
- ),
390
- llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
391
- llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
392
- or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
393
- llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
394
- citation_enabled=self.citation_enabled,
395
- )
396
286
  self.persist()
397
287
  return self._resource
398
288
 
@@ -429,60 +319,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
429
319
  )
430
320
  return template_slots
431
321
 
432
- def _connect_vector_store_or_raise(
433
- self, endpoints: Optional[AvailableEndpoints]
434
- ) -> None:
435
- """Connects to the vector store or raises an exception.
436
-
437
- Raise exceptions for the following cases:
438
- - The configuration is not specified
439
- - Unable to connect to the vector store
440
-
441
- Args:
442
- endpoints: Endpoints configuration.
443
- """
444
- config = endpoints.vector_store if endpoints else None
445
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
446
- if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
447
- logger.error(
448
- "enterprise_search_policy._connect_vector_store_or_raise.no_config"
449
- )
450
- raise VectorStoreConfigurationError(
451
- """No vector store specified. Please specify a vector
452
- store in the endpoints configuration"""
453
- )
454
- try:
455
- self.vector_store.connect(config) # type: ignore
456
- except Exception as e:
457
- logger.error(
458
- "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
459
- error=e,
460
- config=config,
461
- )
462
- raise VectorStoreConnectionError(
463
- f"Unable to connect to the vector store. Error: {e}"
464
- )
465
-
466
- def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
467
- """Prepares the search query.
468
- The search query is the last N messages in the conversation history.
469
-
470
- Args:
471
- tracker: The tracker containing the conversation history up to now.
472
- history: The number of messages to include in the search query.
473
-
474
- Returns:
475
- The search query.
476
- """
477
- transcript = []
478
- for event in tracker.applied_events():
479
- if isinstance(event, UserUttered) or isinstance(event, BotUttered):
480
- transcript.append(sanitize_message_for_prompt(event.text))
481
-
482
- search_query = " ".join(transcript[-history:][::-1])
483
- logger.debug("search_query", search_query=search_query)
484
- return search_query
485
-
486
322
  async def predict_action_probabilities( # type: ignore[override]
487
323
  self,
488
324
  tracker: DialogueStateTracker,
@@ -506,44 +342,34 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
506
342
  The prediction.
507
343
  """
508
344
  logger_key = "enterprise_search_policy.predict_action_probabilities"
509
- vector_search_threshold = self.vector_store_config.get(
510
- VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
511
- )
512
- llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
345
+
513
346
  if not self.supports_current_stack_frame(
514
347
  tracker, False, False
515
348
  ) or self.should_abstain_in_coexistence(tracker, True):
516
349
  return self._prediction(self._default_predictions(domain))
517
350
 
518
- if not self.vector_store:
519
- logger.error(f"{logger_key}.no_vector_store")
520
- return self._create_prediction_internal_error(domain, tracker)
521
-
522
- try:
523
- self._connect_vector_store_or_raise(endpoints)
524
- except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
525
- logger.error(f"{logger_key}.connection_error", error=e)
526
- return self._create_prediction_internal_error(domain, tracker)
351
+ # retrieve documents from the latest message
352
+ # document retrieval happened earlier in the pipeline
353
+ if tracker.latest_message is None or tracker.latest_message.parse_data is None:
354
+ logger.info(f"{logger_key}.no_documents")
355
+ return self._create_prediction_cannot_handle(domain, tracker)
527
356
 
528
- search_query = self._prepare_search_query(
529
- tracker, int(self.max_messages_in_query)
357
+ documents_data = tracker.latest_message.parse_data.get(
358
+ POST_PROCESSED_DOCUMENTS_KEY
530
359
  )
531
- tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
532
360
 
533
- try:
534
- documents = await self.vector_store.search(
535
- query=search_query,
536
- tracker_state=tracker_state,
537
- threshold=vector_search_threshold,
538
- )
539
- except InformationRetrievalException as e:
540
- logger.error(f"{logger_key}.search_error", error=e)
541
- return self._create_prediction_internal_error(domain, tracker)
361
+ if not documents_data:
362
+ logger.info(f"{logger_key}.no_documents")
363
+ return self._create_prediction_cannot_handle(domain, tracker)
364
+
365
+ documents = SearchResultList.from_dict(documents_data)
542
366
 
543
367
  if not documents.results:
544
368
  logger.info(f"{logger_key}.no_documents")
545
369
  return self._create_prediction_cannot_handle(domain, tracker)
546
370
 
371
+ llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
372
+
547
373
  if self.use_llm:
548
374
  prompt = self._render_prompt(tracker, documents.results)
549
375
  llm_response = await self._generate_llm_answer(llm, prompt)
@@ -588,25 +414,29 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
588
414
  result.text for result in documents.results
589
415
  ],
590
416
  UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
591
- SEARCH_QUERY_METADATA_KEY: search_query,
417
+ SEARCH_QUERY_METADATA_KEY: tracker.latest_message.parse_data.get(
418
+ SEARCH_QUERY_KEY
419
+ ),
592
420
  }
593
421
  }
594
422
 
595
- # telemetry call to track policy prediction
596
- track_enterprise_search_policy_predict(
597
- vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
598
- embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
599
- embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
600
- or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
601
- embeddings_model_group_id=self.embeddings_config.get(
602
- MODEL_GROUP_ID_CONFIG_KEY
603
- ),
604
- llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
605
- llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
606
- or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
607
- llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
608
- citation_enabled=self.citation_enabled,
609
- )
423
+ # # telemetry call to track policy prediction
424
+ # track_enterprise_search_policy_predict(
425
+ # vector_store_type=self.vector_store_config.get(
426
+ # VECTOR_STORE_TYPE_PROPERTY),
427
+ # embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
428
+ # embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
429
+ # or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
430
+ # embeddings_model_group_id=self.embeddings_config.get(
431
+ # MODEL_GROUP_ID_CONFIG_KEY
432
+ # ),
433
+ # llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
434
+ # llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
435
+ # or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
436
+ # llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
437
+ # citation_enabled=self.citation_enabled,
438
+ # )
439
+
610
440
  return self._create_prediction(
611
441
  domain=domain, tracker=tracker, action_metadata=action_metadata
612
442
  )
@@ -774,73 +604,19 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
774
604
  "enterprise_search_policy.load.failed", error=e, resource=resource.name
775
605
  )
776
606
 
777
- store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
778
- VECTOR_STORE_TYPE_PROPERTY
779
- )
780
-
781
- embeddings = cls._create_plain_embedder(config)
782
-
783
607
  logger.info("enterprise_search_policy.load", config=config)
784
- if store_type == DEFAULT_VECTOR_STORE_TYPE:
785
- # if a vector store is not specified,
786
- # default to using FAISS with the index stored in the model
787
- # TODO figure out a way to get path without context manager
788
- with model_storage.read_from(resource) as path:
789
- vector_store = FAISS_Store(
790
- embeddings=embeddings,
791
- index_path=path,
792
- docs_folder=None,
793
- create_index=False,
794
- )
795
- else:
796
- vector_store = create_from_endpoint_config(
797
- config_type=store_type,
798
- embeddings=embeddings,
799
- ) # type: ignore
800
608
 
801
609
  return cls(
802
610
  config,
803
611
  model_storage,
804
612
  resource,
805
613
  execution_context,
806
- vector_store=vector_store,
807
614
  prompt_template=prompt_template,
808
615
  )
809
616
 
810
- @classmethod
811
- def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
812
- """This is required only for local knowledge base types.
813
-
814
- e.g. FAISS, to ensure that the graph component is retrained when the knowledge
815
- base is updated.
816
- """
817
- merged_config = {**cls.get_default_config(), **config}
818
-
819
- store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
820
- VECTOR_STORE_TYPE_PROPERTY
821
- )
822
- if store_type != DEFAULT_VECTOR_STORE_TYPE:
823
- return None
824
-
825
- source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
826
- if not source:
827
- return None
828
-
829
- docs = FAISS_Store.load_documents(source)
830
-
831
- if len(docs) == 0:
832
- return None
833
-
834
- docs_as_strings = [
835
- json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
836
- ]
837
- return sorted(docs_as_strings)
838
-
839
617
  @classmethod
840
618
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
841
619
  """Add a fingerprint of enterprise search policy for the graph."""
842
- local_knowledge_data = cls._get_local_knowledge_data(config)
843
-
844
620
  prompt_template = get_prompt_template(
845
621
  config.get(PROMPT_CONFIG_KEY),
846
622
  DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
@@ -849,12 +625,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
849
625
  llm_config = resolve_model_client_config(
850
626
  config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
851
627
  )
852
- embedding_config = resolve_model_client_config(
853
- config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
854
- )
855
- return deep_container_fingerprint(
856
- [prompt_template, local_knowledge_data, llm_config, embedding_config]
857
- )
628
+ return deep_container_fingerprint([prompt_template, llm_config])
858
629
 
859
630
  @staticmethod
860
631
  def post_process_citations(llm_answer: str) -> str:
@@ -959,14 +730,3 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
959
730
  log_source_method,
960
731
  EnterpriseSearchPolicy.__name__,
961
732
  )
962
-
963
- # Perform health check of the embeddings client config
964
- embeddings_config = resolve_model_client_config(
965
- config.get(EMBEDDINGS_CONFIG_KEY, {})
966
- )
967
- cls.perform_embeddings_health_check(
968
- embeddings_config,
969
- DEFAULT_EMBEDDINGS_CONFIG,
970
- log_source_method,
971
- EnterpriseSearchPolicy.__name__,
972
- )
rasa/core/policies/flows/flow_executor.py CHANGED
@@ -23,7 +23,6 @@ from rasa.core.policies.flows.flow_step_result import (
23
23
  )
24
24
  from rasa.dialogue_understanding.commands import CancelFlowCommand
25
25
  from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
26
- from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
27
26
  from rasa.dialogue_understanding.patterns.collect_information import (
28
27
  CollectInformationPatternFlowStackFrame,
29
28
  )
@@ -51,7 +50,6 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
51
50
  )
52
51
  from rasa.dialogue_understanding.stack.utils import (
53
52
  top_user_flow_frame,
54
- user_flows_on_the_stack,
55
53
  )
56
54
  from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
57
55
  from rasa.shared.core.constants import (
@@ -280,33 +278,6 @@ def trigger_pattern_continue_interrupted(
280
278
  return events
281
279
 
282
280
 
283
- def trigger_pattern_clarification(
284
- current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
285
- ) -> None:
286
- """Trigger the pattern to clarify which topic to continue if needed."""
287
- if not isinstance(current_frame, UserFlowStackFrame):
288
- return None
289
-
290
- if current_frame.frame_type in [
291
- FlowStackFrameType.CALL,
292
- FlowStackFrameType.INTERRUPT,
293
- ]:
294
- # we want to return to the flow that called
295
- # the current flow or the flow that was interrupted
296
- # by the current flow
297
- return None
298
-
299
- pending_flows = [
300
- flows.flow_by_id(frame.flow_id)
301
- for frame in stack.frames
302
- if isinstance(frame, UserFlowStackFrame)
303
- and frame.flow_id != current_frame.flow_id
304
- ]
305
-
306
- flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
307
- stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
308
-
309
-
310
281
  def trigger_pattern_completed(
311
282
  current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
312
283
  ) -> None:
@@ -675,15 +646,9 @@ def _run_end_step(
675
646
  structlogger.debug("flow.step.run.flow_end")
676
647
  current_frame = stack.pop()
677
648
  trigger_pattern_completed(current_frame, stack, flows)
678
- resumed_events = []
679
- if len(user_flows_on_the_stack(stack)) > 1:
680
- # if there are more user flows on the stack,
681
- # we need to trigger the pattern clarify
682
- trigger_pattern_clarification(current_frame, stack, flows)
683
- else:
684
- resumed_events = trigger_pattern_continue_interrupted(
685
- current_frame, stack, flows, tracker
686
- )
649
+ resumed_events = trigger_pattern_continue_interrupted(
650
+ current_frame, stack, flows, tracker
651
+ )
687
652
  reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
688
653
  return ContinueFlowWithNextStep(
689
654
  events=initial_events + reset_events + resumed_events, has_flow_ended=True
rasa/core/processor.py CHANGED
@@ -76,6 +76,7 @@ from rasa.shared.core.constants import (
76
76
  SLOT_SILENCE_TIMEOUT,
77
77
  USER_INTENT_RESTART,
78
78
  USER_INTENT_SILENCE_TIMEOUT,
79
+ SetSlotExtractor,
79
80
  )
80
81
  from rasa.shared.core.events import (
81
82
  ActionExecuted,
@@ -766,13 +767,26 @@ class MessageProcessor:
766
767
  if self.http_interpreter:
767
768
  parse_data = await self.http_interpreter.parse(message)
768
769
  else:
769
- regex_reader = create_regex_pattern_reader(message, self.domain)
770
-
771
770
  processed_message = Message({TEXT: message.text})
772
- if regex_reader:
773
- processed_message = regex_reader.unpack_regex_message(
774
- message=processed_message, domain=self.domain
771
+
772
+ all_flows = await self.get_flows()
773
+ should_force_slot_command, slot_name = (
774
+ rasa.core.utils.should_force_slot_filling(tracker, all_flows)
775
+ )
776
+
777
+ if should_force_slot_command:
778
+ command = SetSlotCommand(
779
+ name=slot_name,
780
+ value=message.text,
781
+ extractor=SetSlotExtractor.COMMAND_PAYLOAD_READER.value,
775
782
  )
783
+ processed_message.set(COMMANDS, [command.as_dict()], add_to_output=True)
784
+ else:
785
+ regex_reader = create_regex_pattern_reader(message, self.domain)
786
+ if regex_reader:
787
+ processed_message = regex_reader.unpack_regex_message(
788
+ message=processed_message, domain=self.domain
789
+ )
776
790
 
777
791
  # Invalid use of slash syntax, sanitize the message before passing
778
792
  # it to the graph
@@ -1009,7 +1023,14 @@ class MessageProcessor:
1009
1023
 
1010
1024
  @staticmethod
1011
1025
  def _should_handle_message(tracker: DialogueStateTracker) -> bool:
1012
- return not tracker.is_paused() or (
1026
+ return not tracker.is_paused() or MessageProcessor._last_user_intent_is_restart(
1027
+ tracker
1028
+ )
1029
+
1030
+ @staticmethod
1031
+ def _last_user_intent_is_restart(tracker: DialogueStateTracker) -> bool:
1032
+ """Check if the last user intent is a restart intent."""
1033
+ return (
1013
1034
  tracker.latest_message is not None
1014
1035
  and tracker.latest_message.intent.get(INTENT_NAME_KEY)
1015
1036
  == USER_INTENT_RESTART