rasa-pro 3.9.18__py3-none-any.whl → 3.10.4__py3-none-any.whl

This diff shows the changes between the two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.

Note: this version of rasa-pro has been flagged as a potentially problematic release.

Files changed (190)
  1. README.md +26 -57
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/inspector/dist/index.html +0 -2
  32. rasa/core/channels/inspector/index.html +0 -2
  33. rasa/core/channels/voice_aware/__init__.py +0 -0
  34. rasa/core/channels/voice_aware/jambonz.py +103 -0
  35. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  36. rasa/core/channels/voice_aware/utils.py +20 -0
  37. rasa/core/channels/voice_native/__init__.py +0 -0
  38. rasa/core/constants.py +6 -1
  39. rasa/core/featurizers/single_state_featurizer.py +1 -22
  40. rasa/core/featurizers/tracker_featurizers.py +18 -115
  41. rasa/core/information_retrieval/faiss.py +7 -4
  42. rasa/core/information_retrieval/information_retrieval.py +8 -0
  43. rasa/core/information_retrieval/milvus.py +9 -2
  44. rasa/core/information_retrieval/qdrant.py +1 -1
  45. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  46. rasa/core/nlg/summarize.py +4 -3
  47. rasa/core/policies/enterprise_search_policy.py +100 -44
  48. rasa/core/policies/flows/flow_executor.py +130 -94
  49. rasa/core/policies/intentless_policy.py +52 -28
  50. rasa/core/policies/ted_policy.py +33 -58
  51. rasa/core/policies/unexpected_intent_policy.py +7 -15
  52. rasa/core/processor.py +20 -53
  53. rasa/core/run.py +5 -4
  54. rasa/core/tracker_store.py +8 -4
  55. rasa/core/utils.py +45 -56
  56. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  57. rasa/dialogue_understanding/commands/__init__.py +4 -0
  58. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  59. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  60. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  61. rasa/dialogue_understanding/commands/utils.py +38 -0
  62. rasa/dialogue_understanding/generator/constants.py +10 -3
  63. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  64. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  65. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  66. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  67. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  68. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  69. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  70. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  71. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  72. rasa/e2e_test/assertions.py +1181 -0
  73. rasa/e2e_test/assertions_schema.yml +106 -0
  74. rasa/e2e_test/constants.py +20 -0
  75. rasa/e2e_test/e2e_config.py +220 -0
  76. rasa/e2e_test/e2e_config_schema.yml +26 -0
  77. rasa/e2e_test/e2e_test_case.py +131 -8
  78. rasa/e2e_test/e2e_test_converter.py +363 -0
  79. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  80. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  81. rasa/e2e_test/e2e_test_result.py +26 -6
  82. rasa/e2e_test/e2e_test_runner.py +491 -72
  83. rasa/e2e_test/e2e_test_schema.yml +96 -0
  84. rasa/e2e_test/pykwalify_extensions.py +39 -0
  85. rasa/e2e_test/stub_custom_action.py +70 -0
  86. rasa/e2e_test/utils/__init__.py +0 -0
  87. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  88. rasa/e2e_test/utils/io.py +596 -0
  89. rasa/e2e_test/utils/validation.py +80 -0
  90. rasa/engine/recipes/default_components.py +0 -2
  91. rasa/engine/storage/local_model_storage.py +0 -1
  92. rasa/env.py +9 -0
  93. rasa/keys +1 -0
  94. rasa/llm_fine_tuning/__init__.py +0 -0
  95. rasa/llm_fine_tuning/annotation_module.py +241 -0
  96. rasa/llm_fine_tuning/conversations.py +144 -0
  97. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  98. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  99. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  100. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  101. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  102. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  104. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  105. rasa/llm_fine_tuning/storage.py +174 -0
  106. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  107. rasa/model_training.py +48 -16
  108. rasa/nlu/classifiers/diet_classifier.py +25 -38
  109. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  110. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  111. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  112. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  113. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  114. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  115. rasa/nlu/persistor.py +129 -32
  116. rasa/server.py +45 -10
  117. rasa/shared/constants.py +63 -15
  118. rasa/shared/core/domain.py +15 -12
  119. rasa/shared/core/events.py +28 -2
  120. rasa/shared/core/flows/flow.py +208 -13
  121. rasa/shared/core/flows/flow_path.py +84 -0
  122. rasa/shared/core/flows/flows_list.py +28 -10
  123. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  124. rasa/shared/core/flows/validation.py +112 -25
  125. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  126. rasa/shared/core/trackers.py +6 -0
  127. rasa/shared/core/training_data/visualization.html +2 -2
  128. rasa/shared/exceptions.py +4 -0
  129. rasa/shared/importers/importer.py +60 -11
  130. rasa/shared/importers/remote_importer.py +196 -0
  131. rasa/shared/nlu/constants.py +2 -0
  132. rasa/shared/nlu/training_data/features.py +2 -120
  133. rasa/shared/providers/_configs/__init__.py +0 -0
  134. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  135. rasa/shared/providers/_configs/client_config.py +57 -0
  136. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  137. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  138. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  139. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  140. rasa/shared/providers/_configs/utils.py +101 -0
  141. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  142. rasa/shared/providers/embedding/__init__.py +0 -0
  143. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  144. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  145. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  146. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  147. rasa/shared/providers/embedding/embedding_client.py +90 -0
  148. rasa/shared/providers/embedding/embedding_response.py +41 -0
  149. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  150. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  151. rasa/shared/providers/llm/__init__.py +0 -0
  152. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  153. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  154. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  155. rasa/shared/providers/llm/llm_client.py +76 -0
  156. rasa/shared/providers/llm/llm_response.py +50 -0
  157. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  158. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  159. rasa/shared/providers/mappings.py +75 -0
  160. rasa/shared/utils/cli.py +30 -0
  161. rasa/shared/utils/io.py +65 -3
  162. rasa/shared/utils/llm.py +223 -200
  163. rasa/shared/utils/yaml.py +122 -7
  164. rasa/studio/download.py +19 -13
  165. rasa/studio/train.py +2 -3
  166. rasa/studio/upload.py +2 -3
  167. rasa/telemetry.py +113 -58
  168. rasa/tracing/config.py +2 -3
  169. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  170. rasa/tracing/instrumentation/instrumentation.py +4 -47
  171. rasa/utils/common.py +18 -19
  172. rasa/utils/endpoints.py +7 -4
  173. rasa/utils/io.py +66 -0
  174. rasa/utils/json_utils.py +60 -0
  175. rasa/utils/licensing.py +9 -1
  176. rasa/utils/ml_utils.py +4 -2
  177. rasa/utils/tensorflow/model_data.py +193 -2
  178. rasa/validator.py +196 -1
  179. rasa/version.py +1 -1
  180. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/METADATA +47 -72
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/RECORD +186 -121
  182. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  183. rasa/shared/providers/openai/clients.py +0 -43
  184. rasa/shared/providers/openai/session_handler.py +0 -110
  185. rasa/utils/tensorflow/feature_array.py +0 -366
  186. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  187. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  188. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/NOTICE +0 -0
  189. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/WHEEL +0 -0
  190. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/entry_points.txt +0 -0

rasa/core/policies/enterprise_search_policy.py

@@ -8,6 +8,11 @@ import structlog
 from jinja2 import Template
 from pydantic import ValidationError

+from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
+    _LangchainEmbeddingClientAdapter,
+)
+from rasa.shared.providers.llm.llm_client import LLMClient
+
 import rasa.shared.utils.io
 from rasa.telemetry import (
     track_enterprise_search_policy_predict,
@@ -19,6 +24,7 @@ from rasa.core.constants import (
     POLICY_MAX_HISTORY,
     POLICY_PRIORITY,
     SEARCH_POLICY_PRIORITY,
+    UTTER_SOURCE_METADATA_KEY,
 )
 from rasa.core.policies.policy import Policy, PolicyPrediction
 from rasa.core.utils import AvailableEndpoints
@@ -39,13 +45,23 @@ from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
 from rasa.graph_components.providers.forms_provider import Forms
 from rasa.graph_components.providers.responses_provider import Responses
+from rasa.shared.constants import (
+    EMBEDDINGS_CONFIG_KEY,
+    LLM_CONFIG_KEY,
+    MODEL_CONFIG_KEY,
+    MODEL_NAME_CONFIG_KEY,
+    PROMPT_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    OPENAI_PROVIDER,
+    TIMEOUT_CONFIG_KEY,
+)
 from rasa.shared.core.constants import (
     ACTION_CANCEL_FLOW,
     ACTION_SEND_TEXT_NAME,
     DEFAULT_SLOT_NAMES,
 )
 from rasa.shared.core.domain import Domain
-from rasa.shared.core.events import Event
+from rasa.shared.core.events import Event, UserUttered, BotUttered
 from rasa.shared.core.generator import TrackerWithCachedStates
 from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
 from rasa.shared.nlu.training_data.training_data import TrainingData
@@ -59,6 +75,7 @@ from rasa.shared.utils.llm import (
     llm_factory,
     sanitize_message_for_prompt,
     tracker_as_readable_transcript,
+    try_instantiate_llm_client,
 )
 from rasa.core.information_retrieval.faiss import FAISS_Store
 from rasa.core.information_retrieval import (
@@ -70,7 +87,6 @@ from rasa.core.information_retrieval import (

 if TYPE_CHECKING:
     from langchain.schema.embeddings import Embeddings
-    from langchain.llms.base import BaseLLM
     from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer

 from rasa.utils.log_utils import log_llm
@@ -86,6 +102,7 @@ VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
 TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
 CITATION_ENABLED_PROPERTY = "citation_enabled"
 USE_LLM_PROPERTY = "use_generative_llm"
+MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"

 DEFAULT_VECTOR_STORE_TYPE = "faiss"
 DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
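The constants above are the configuration keys the policy reads (see the constructor changes later in this diff). For orientation, a hypothetical policy configuration using them might look like the sketch below. The "llm" and "embeddings" key strings match the module-level constants removed further down (now imported from rasa.shared.constants); "prompt" and the "provider" key string inside the sub-configs are assumptions, and all values are illustrative.

# Hypothetical EnterpriseSearchPolicy config sketch, not taken from the package docs.
example_policy_config = {
    "max_messages_in_query": 3,     # MAX_MESSAGES_IN_QUERY_KEY; constructor defaults this to 2
    "trace_prompt_tokens": False,   # TRACE_TOKENS_PROPERTY
    "use_generative_llm": True,     # USE_LLM_PROPERTY
    "citation_enabled": True,       # CITATION_ENABLED_PROPERTY
    "prompt": "prompts/enterprise_search.jinja2",  # assumed key for PROMPT_CONFIG_KEY; hypothetical path
    "llm": {"provider": "openai"},          # falls back to DEFAULT_LLM_CONFIG when omitted
    "embeddings": {"provider": "openai"},   # falls back to DEFAULT_EMBEDDINGS_CONFIG when omitted
}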
@@ -96,23 +113,24 @@ DEFAULT_VECTOR_STORE = {
 }

 DEFAULT_LLM_CONFIG = {
-    "_type": "openai",
-    "request_timeout": 10,
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
+    TIMEOUT_CONFIG_KEY: 10,
     "temperature": 0.0,
     "max_tokens": 256,
-    "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
     "max_retries": 1,
 }

 DEFAULT_EMBEDDINGS_CONFIG = {
-    "_type": "openai",
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }

-EMBEDDINGS_CONFIG_KEY = "embeddings"
-LLM_CONFIG_KEY = "llm"
 ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"

+SEARCH_RESULTS_METADATA_KEY = "search_results"
+SEARCH_QUERY_METADATA_KEY = "search_query"
+
 DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE = importlib.resources.read_text(
     "rasa.core.policies", "enterprise_search_prompt_template.jinja2"
 )
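For illustration, this is roughly what the new defaults amount to once the shared constants are resolved. It is a sketch, not the shipped code: it assumes PROVIDER_CONFIG_KEY, MODEL_CONFIG_KEY and TIMEOUT_CONFIG_KEY resolve to the plain strings "provider", "model" and "timeout", and that OPENAI_PROVIDER is "openai".

# Sketch only: assumed key strings; the real values come from rasa.shared.constants.
DEFAULT_OPENAI_CHAT_MODEL_NAME = "..."  # placeholder for the shared constant

assumed_default_llm_config = {
    "provider": "openai",                     # was "_type": "openai"
    "model": DEFAULT_OPENAI_CHAT_MODEL_NAME,  # was keyed as "model_name"
    "timeout": 10,                            # was "request_timeout": 10
    "temperature": 0.0,
    "max_tokens": 256,
    "max_retries": 1,
}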
@@ -179,26 +197,42 @@ class EnterpriseSearchPolicy(Policy):
         """Constructs a new Policy object."""
         super().__init__(config, model_storage, resource, execution_context, featurizer)

+        # Vector store object and configuration
         self.vector_store = vector_store
         self.vector_store_config = config.get(
             VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
         )
-        self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
+        # Embeddings configuration for encoding the search query
         self.embeddings_config = self.config.get(
             EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
         )
+        # Maximum number of turns to include in the prompt
         self.max_history = self.config.get(POLICY_MAX_HISTORY)
-        self.prompt_template = prompt_template or get_prompt_template(
-            self.config.get("prompt"),
-            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
-        )
+
+        # Maximum number of messages to include in the search query
+        self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
+
+        # LLM Configuration for response generation
+        self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
+
+        # boolean to enable/disable tracing of prompt tokens
         self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
+
+        # boolean to enable/disable the use of LLM for response generation
         self.use_llm = self.config.get(USE_LLM_PROPERTY, True)
+
+        # boolean to enable/disable citation generation
         self.citation_enabled = self.config.get(CITATION_ENABLED_PROPERTY, False)
+
+        self.prompt_template = prompt_template or get_prompt_template(
+            self.config.get(PROMPT_CONFIG_KEY),
+            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
+        )
         self.citation_prompt_template = get_prompt_template(
-            self.config.get("prompt"),
+            self.config.get(PROMPT_CONFIG_KEY),
             DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
         )
+        # If citation is enabled, use the citation prompt template
         if self.citation_enabled:
             self.prompt_template = self.citation_prompt_template
@@ -209,9 +243,10 @@ class EnterpriseSearchPolicy(Policy):
         Returns:
             The embedder.
         """
-        return embedder_factory(
+        client = embedder_factory(
             config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
         )
+        return _LangchainEmbeddingClientAdapter(client)

     def train(  # type: ignore[override]
         self,
@@ -245,20 +280,24 @@ class EnterpriseSearchPolicy(Policy):
         # validate embedding configuration
         try:
             embeddings = self._create_plain_embedder(self.config)
-        except ValidationError as e:
+        except (ValidationError, Exception) as e:
+            logger.error(
+                "enterprise_search_policy.train.embedder_instantiation_failed",
+                message="Unable to instantiate the embedding client.",
+                error=e,
+            )
             print_error_and_exit(
                 "Unable to create embedder. Please make sure you specified the "
                 f"required environment variables. Error: {e}"
             )

         # validate llm configuration
-        try:
-            llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
-        except (ImportError, ValueError, ValidationError) as e:
-            # ImportError: llm library is likely not installed
-            # ValueError: llm config is likely invalid
-            # ValidationError: environment variables are likely not set
-            print_error_and_exit(f"Unable to create LLM. Error: {e}")
+        try_instantiate_llm_client(
+            self.config.get(LLM_CONFIG_KEY),
+            DEFAULT_LLM_CONFIG,
+            "enterprise_search_policy.train",
+            "EnterpriseSearchPolicy",
+        )

         if store_type == DEFAULT_VECTOR_STORE_TYPE:
             logger.info("enterprise_search_policy.train.faiss")
@@ -275,11 +314,12 @@ class EnterpriseSearchPolicy(Policy):
         # telemetry call to track training completion
         track_enterprise_search_policy_train_completed(
             vector_store_type=store_type,
-            embeddings_type=self.embeddings_config.get("_type"),
-            embeddings_model=self.embeddings_config.get("model")
-            or self.embeddings_config.get("model_name"),
-            llm_type=self.llm_config.get("_type"),
-            llm_model=self.llm_config.get("model") or self.llm_config.get("model_name"),
+            embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
+            embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
+            or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
+            llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
+            llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
+            or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
             citation_enabled=self.citation_enabled,
         )
         self.persist()
@@ -343,24 +383,31 @@ class EnterpriseSearchPolicy(Policy):
             logger.error(
                 "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
                 error=e,
+                config=config,
             )
             raise VectorStoreConnectionError(
                 f"Unable to connect to the vector store. Error: {e}"
             )

-    def _get_last_user_message(self, tracker: DialogueStateTracker) -> str:
-        """Get the last user message from the tracker.
+    def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
+        """Prepares the search query.
+        The search query is the last N messages in the conversation history.

         Args:
             tracker: The tracker containing the conversation history up to now.
+            history: The number of messages to include in the search query.

         Returns:
-            The last user message.
+            The search query.
         """
-        for event in reversed(tracker.events):
-            if isinstance(event, rasa.shared.core.events.UserUttered):
-                return sanitize_message_for_prompt(event.text)
-        return ""
+        transcript = []
+        for event in tracker.applied_events():
+            if isinstance(event, UserUttered) or isinstance(event, BotUttered):
+                transcript.append(sanitize_message_for_prompt(event.text))
+
+        search_query = " ".join(transcript[-history:][::-1])
+        logger.debug("search_query", search_query=search_query)
+        return search_query

     async def predict_action_probabilities(  # type: ignore[override]
         self,
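A quick illustration of the windowing done by _prepare_search_query above: it keeps the last `history` user and bot messages and reverses them so the most recent message comes first. This is a minimal, self-contained sketch; the real method pulls the messages from the tracker's applied events.

# Toy transcript in chronological order (oldest first).
transcript = ["Hi", "How can I help?", "What is my account balance?"]
history = 2

search_query = " ".join(transcript[-history:][::-1])
print(search_query)  # -> "What is my account balance? How can I help?"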
@@ -404,7 +451,9 @@ class EnterpriseSearchPolicy(Policy):
             logger.error(f"{logger_key}.connection_error", error=e)
             return self._create_prediction_internal_error(domain, tracker)

-        search_query = self._get_last_user_message(tracker)
+        search_query = self._prepare_search_query(
+            tracker, int(self.max_messages_in_query)
+        )
         tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)

         try:
@@ -448,17 +497,23 @@ class EnterpriseSearchPolicy(Policy):
         action_metadata = {
             "message": {
                 "text": response,
+                SEARCH_RESULTS_METADATA_KEY: [
+                    result.text for result in documents.results
+                ],
+                UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
+                SEARCH_QUERY_METADATA_KEY: search_query,
             }
         }

         # telemetry call to track policy prediction
         track_enterprise_search_policy_predict(
             vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
-            embeddings_type=self.embeddings_config.get("_type"),
-            embeddings_model=self.embeddings_config.get("model")
-            or self.embeddings_config.get("model_name"),
-            llm_type=self.llm_config.get("_type"),
-            llm_model=self.llm_config.get("model") or self.llm_config.get("model_name"),
+            embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
+            embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
+            or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
+            llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
+            llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
+            or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
             citation_enabled=self.citation_enabled,
         )
         return self._create_prediction(
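The action metadata now carries the retrieved passages, the originating component, and the search query alongside the answer text. Roughly, as a sketch: "search_results" and "search_query" are the constant values defined earlier in this file, while "utter_source" is an assumption about what UTTER_SOURCE_METADATA_KEY resolves to, and all values are made up.

# Example shape of the enriched action metadata (illustrative values only).
action_metadata = {
    "message": {
        "text": "Your balance is shown in the app under Accounts.",  # LLM answer
        "search_results": ["doc snippet 1", "doc snippet 2"],         # documents.results texts
        "utter_source": "EnterpriseSearchPolicy",                     # self.__class__.__name__
        "search_query": "What is my account balance? How can I help?",
    }
}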
@@ -495,10 +550,11 @@ class EnterpriseSearchPolicy(Policy):
         return prompt

     async def _generate_llm_answer(
-        self, llm: "BaseLLM", prompt: Text
+        self, llm: LLMClient, prompt: Text
     ) -> Optional[Text]:
         try:
-            llm_answer = await llm.apredict(prompt)
+            llm_response = await llm.acompletion(prompt)
+            llm_answer = llm_response.choices[0]
         except Exception as e:
             # unfortunately, langchain does not wrap LLM exceptions which means
             # we have to catch all exceptions here
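The helper now talks to Rasa's provider-agnostic LLMClient instead of a LangChain BaseLLM. A minimal sketch of the new call pattern, assuming an `llm` object whose acompletion() returns a response with a `choices` list, as the hunk above shows:

async def generate_answer_sketch(llm, prompt: str):
    # replaces the old LangChain-style `await llm.apredict(prompt)`
    llm_response = await llm.acompletion(prompt)
    # the first generated completion is used as the answer
    return llm_response.choices[0]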
@@ -684,7 +740,7 @@ class EnterpriseSearchPolicy(Policy):
         local_knowledge_data = cls._get_local_knowledge_data(config)

         prompt_template = get_prompt_template(
-            config.get("prompt"),
+            config.get(PROMPT_CONFIG_KEY),
             DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
         )
         return deep_container_fingerprint([prompt_template, local_knowledge_data])

rasa/core/policies/flows/flow_executor.py

@@ -2,12 +2,14 @@ from __future__ import annotations

 from typing import Any, Dict, Text, List, Optional

+import structlog
 from jinja2 import Template
-from rasa.dialogue_understanding.commands import CancelFlowCommand
-from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
+from pypred import Predicate
 from structlog.contextvars import (
     bound_contextvars,
 )
+
+from rasa.core.constants import STEP_ID_METADATA_KEY, ACTIVE_FLOW_METADATA_KEY
 from rasa.core.policies.flows.flow_exceptions import (
     FlowCircuitBreakerTrippedException,
     FlowException,
@@ -19,6 +21,20 @@ from rasa.core.policies.flows.flow_step_result import (
     FlowStepResult,
     PauseFlowReturnPrediction,
 )
+from rasa.dialogue_understanding.commands import CancelFlowCommand
+from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
+from rasa.dialogue_understanding.patterns.collect_information import (
+    CollectInformationPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.completed import (
+    CompletedPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.continue_interrupted import (
+    ContinueInterruptedPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.human_handoff import (
+    HumanHandoffPatternFlowStackFrame,
+)
 from rasa.dialogue_understanding.patterns.internal_error import (
     InternalErrorPatternFlowStackFrame,
 )
@@ -29,24 +45,13 @@ from rasa.dialogue_understanding.stack.frames import (
     DialogueStackFrame,
     UserFlowStackFrame,
 )
-from rasa.dialogue_understanding.patterns.collect_information import (
-    CollectInformationPatternFlowStackFrame,
-)
-from rasa.dialogue_understanding.patterns.completed import (
-    CompletedPatternFlowStackFrame,
-)
-from rasa.dialogue_understanding.patterns.continue_interrupted import (
-    ContinueInterruptedPatternFlowStackFrame,
-)
 from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
     FlowStackFrameType,
 )
 from rasa.dialogue_understanding.stack.utils import (
     top_user_flow_frame,
 )
-
-from pypred import Predicate
-
+from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
 from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
 from rasa.shared.core.events import (
     Event,
@@ -56,6 +61,11 @@ from rasa.shared.core.events import (
     SlotSet,
 )
 from rasa.shared.core.flows import FlowsList
+from rasa.shared.core.flows.flow import (
+    END_STEP,
+    Flow,
+    FlowStep,
+)
 from rasa.shared.core.flows.flow_step_links import (
     StaticFlowStepLink,
     IfFlowStepLink,
@@ -71,17 +81,11 @@ from rasa.shared.core.flows.steps import (
     CollectInformationFlowStep,
     NoOperationFlowStep,
 )
-from rasa.shared.core.flows.flow import (
-    END_STEP,
-    Flow,
-    FlowStep,
-)
 from rasa.shared.core.flows.steps.collect import SlotRejection
 from rasa.shared.core.slots import Slot
 from rasa.shared.core.trackers import (
     DialogueStateTracker,
 )
-import structlog

 structlogger = structlog.get_logger()

@@ -466,6 +470,10 @@ def advance_flows_until_next_action(
         # make sure we really return all events that got created during the
         # step execution of all steps (not only the last one)
         prediction.events = gathered_events
+        prediction.metadata = {
+            ACTIVE_FLOW_METADATA_KEY: tracker.active_flow,
+            STEP_ID_METADATA_KEY: tracker.current_step_id,
+        }
         return prediction
     else:
         structlogger.warning("flow.step.execution.no_action")
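For reference, the metadata attached to the flow prediction above would look roughly like the sketch below; the actual key strings live in rasa.core.constants and the flow and step ids here are hypothetical.

# Sketch only: assumes ACTIVE_FLOW_METADATA_KEY == "active_flow" and
# STEP_ID_METADATA_KEY == "step_id"; the values are made up.
prediction_metadata = {
    "active_flow": "transfer_money",  # from tracker.active_flow
    "step_id": "ask_amount",          # from tracker.current_step_id
}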
@@ -476,15 +484,15 @@ def validate_collect_step(
     step: CollectInformationFlowStep,
     stack: DialogueStack,
     available_actions: List[str],
-    slots: Dict[str, Slot],
-    flow_name: str,
+    slots: Dict[Text, Slot],
 ) -> bool:
     """Validate that a collect step can be executed.

     A collect step can be executed if either the `utter_ask` or the `action_ask` is
     defined in the domain. If neither is defined, the collect step can still be
     executed if the slot has an initial value defined in the domain, which would cause
-    the step to be skipped."""
+    the step to be skipped.
+    """
     slot = slots.get(step.collect)
     slot_has_initial_value_defined = slot and slot.initial_value is not None
     if (
@@ -499,12 +507,12 @@
             slot_name=step.collect,
         )

-        cancel_flow_and_push_internal_error(stack, flow_name)
+        cancel_flow_and_push_internal_error(stack)

         return False


-def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) -> None:
+def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
     """Cancel the top user flow and push the internal error pattern."""
     top_frame = stack.top()

@@ -516,7 +524,7 @@ def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) ->
     canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
     stack.push(
         CancelPatternFlowStackFrame(
-            canceled_name=flow_name,
+            canceled_name=top_frame.flow_id,
             canceled_frames=canceled_frames,
         )
     )
@@ -528,7 +536,6 @@ def validate_custom_slot_mappings(
     stack: DialogueStack,
     tracker: DialogueStateTracker,
     available_actions: List[str],
-    flow_name: str,
 ) -> bool:
     """Validate a slot with custom mappings.

@@ -549,7 +556,7 @@
             action=step.collect_action,
             collect=step.collect,
         )
-        cancel_flow_and_push_internal_error(stack, flow_name)
+        cancel_flow_and_push_internal_error(stack)
         return False

     return True
@@ -585,110 +592,139 @@ def run_step(
     """
     initial_events: List[Event] = []
     if step == flow.first_step_in_flow():
-        initial_events.append(FlowStarted(flow.id))
+        initial_events.append(FlowStarted(flow.id, metadata=stack.current_context()))

     if isinstance(step, CollectInformationFlowStep):
         return _run_collect_information_step(
-            available_actions,
-            initial_events,
-            stack,
-            step,
-            tracker,
-            flow.readable_name(),
+            available_actions, initial_events, stack, step, tracker
         )

     elif isinstance(step, ActionFlowStep):
         if not step.action:
             raise FlowException(f"Action not specified for step {step}")
-
-        context = {"context": stack.current_context()}
-        action_name = render_template_variables(step.action, context)
-
-        if action_name in available_actions:
-            structlogger.debug("flow.step.run.action", context=context)
-            return PauseFlowReturnPrediction(
-                FlowActionPrediction(action_name, 1.0, events=initial_events)
-            )
-        else:
-            if step.action != "validate_{{context.collect}}":
-                # do not log about non-existing validation actions of collect steps
-                utter_action_name = render_template_variables(
-                    "{{context.utter}}", context
-                )
-                if utter_action_name not in available_actions:
-                    structlogger.warning(
-                        "flow.step.run.action.unknown", action=action_name
-                    )
-            return ContinueFlowWithNextStep(events=initial_events)
+        return _run_action_step(available_actions, initial_events, stack, step)

     elif isinstance(step, LinkFlowStep):
-        structlogger.debug("flow.step.run.link")
-        stack.push(
-            UserFlowStackFrame(
-                flow_id=step.link,
-                frame_type=FlowStackFrameType.LINK,
-            ),
-            # push this below the current stack frame so that we can
-            # complete the current flow first and then continue with the
-            # linked flow
-            index=-1,
-        )
-        return ContinueFlowWithNextStep(events=initial_events)
+        return _run_link_step(initial_events, stack, step)

     elif isinstance(step, CallFlowStep):
-        structlogger.debug("flow.step.run.call")
-        stack.push(
-            UserFlowStackFrame(
-                flow_id=step.call,
-                frame_type=FlowStackFrameType.CALL,
-            ),
-        )
-        return ContinueFlowWithNextStep()
+        return _run_call_step(initial_events, stack, step)

     elif isinstance(step, SetSlotsFlowStep):
-        structlogger.debug("flow.step.run.slot")
-        slot_events: List[Event] = events_from_set_slots_step(step)
-        return ContinueFlowWithNextStep(events=initial_events + slot_events)
+        return _run_set_slot_step(initial_events, step)

     elif isinstance(step, NoOperationFlowStep):
         structlogger.debug("flow.step.run.no_operation")
         return ContinueFlowWithNextStep(events=initial_events)

     elif isinstance(step, EndFlowStep):
-        # this is the end of the flow, so we'll pop it from the stack
-        structlogger.debug("flow.step.run.flow_end")
-        current_frame = stack.pop()
-        trigger_pattern_completed(current_frame, stack, flows)
-        resumed_events = trigger_pattern_continue_interrupted(
-            current_frame, stack, flows
-        )
-        reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
-        return ContinueFlowWithNextStep(
-            events=initial_events + reset_events + resumed_events, has_flow_ended=True
-        )
+        return _run_end_step(flow, flows, initial_events, stack, tracker)

     else:
         raise FlowException(f"Unknown flow step type {type(step)}")


+def _run_end_step(
+    flow: Flow,
+    flows: FlowsList,
+    initial_events: List[Event],
+    stack: DialogueStack,
+    tracker: DialogueStateTracker,
+) -> FlowStepResult:
+    # this is the end of the flow, so we'll pop it from the stack
+    structlogger.debug("flow.step.run.flow_end")
+    current_frame = stack.pop()
+    trigger_pattern_completed(current_frame, stack, flows)
+    resumed_events = trigger_pattern_continue_interrupted(current_frame, stack, flows)
+    reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
+    return ContinueFlowWithNextStep(
+        events=initial_events + reset_events + resumed_events, has_flow_ended=True
+    )
+
+
+def _run_set_slot_step(
+    initial_events: List[Event], step: SetSlotsFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.slot")
+    slot_events: List[Event] = events_from_set_slots_step(step)
+    return ContinueFlowWithNextStep(events=initial_events + slot_events)
+
+
+def _run_call_step(
+    initial_events: List[Event], stack: DialogueStack, step: CallFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.call")
+    stack.push(
+        UserFlowStackFrame(
+            flow_id=step.call,
+            frame_type=FlowStackFrameType.CALL,
+        ),
+    )
+    return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_link_step(
+    initial_events: List[Event], stack: DialogueStack, step: LinkFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.link")
+
+    if step.link == RASA_PATTERN_HUMAN_HANDOFF:
+        linked_stack_frame: DialogueStackFrame = HumanHandoffPatternFlowStackFrame()
+    else:
+        linked_stack_frame = UserFlowStackFrame(
+            flow_id=step.link,
+            frame_type=FlowStackFrameType.LINK,
+        )
+
+    stack.push(
+        linked_stack_frame,
+        # push this below the current stack frame so that we can
+        # complete the current flow first and then continue with the
+        # linked flow
+        index=-1,
+    )
+
+    return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_action_step(
+    available_actions: List[str],
+    initial_events: List[Event],
+    stack: DialogueStack,
+    step: ActionFlowStep,
+) -> FlowStepResult:
+    context = {"context": stack.current_context()}
+    action_name = render_template_variables(step.action, context)
+
+    if action_name in available_actions:
+        structlogger.debug("flow.step.run.action", context=context)
+        return PauseFlowReturnPrediction(
+            FlowActionPrediction(action_name, 1.0, events=initial_events)
+        )
+    else:
+        if step.action != "validate_{{context.collect}}":
+            # do not log about non-existing validation actions of collect steps
+            utter_action_name = render_template_variables("{{context.utter}}", context)
+            if utter_action_name not in available_actions:
+                structlogger.warning("flow.step.run.action.unknown", action=action_name)
+        return ContinueFlowWithNextStep(events=initial_events)
+
+
 def _run_collect_information_step(
     available_actions: List[str],
     initial_events: List[Event],
     stack: DialogueStack,
     step: CollectInformationFlowStep,
     tracker: DialogueStateTracker,
-    flow_name: str,
 ) -> FlowStepResult:
-    is_step_valid = validate_collect_step(
-        step, stack, available_actions, tracker.slots, flow_name
-    )
+    is_step_valid = validate_collect_step(step, stack, available_actions, tracker.slots)

     if not is_step_valid:
         # if we return any other FlowStepResult, the assistant will stay silent
         # instead of triggering the internal error pattern
         return ContinueFlowWithNextStep(events=initial_events)
     is_mapping_valid = validate_custom_slot_mappings(
-        step, stack, tracker, available_actions, flow_name
+        step, stack, tracker, available_actions
     )

     if not is_mapping_valid:
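One behavioral addition in the hunk above: a link step whose target is the human-handoff pattern now pushes the pattern's stack frame instead of a regular user flow frame. Below is a standalone sketch of that branch using stand-in types; RASA_PATTERN_HUMAN_HANDOFF is assumed to resolve to the pattern's flow id string.

from dataclasses import dataclass

RASA_PATTERN_HUMAN_HANDOFF = "pattern_human_handoff"  # assumed value of the shared constant

@dataclass
class StackFrameSketch:
    name: str
    flow_id: str = ""

def choose_linked_frame(link_target: str) -> StackFrameSketch:
    # mirrors the branch in _run_link_step above, with stand-in frame types
    if link_target == RASA_PATTERN_HUMAN_HANDOFF:
        return StackFrameSketch(name="HumanHandoffPatternFlowStackFrame")
    return StackFrameSketch(name="UserFlowStackFrame", flow_id=link_target)

print(choose_linked_frame("pattern_human_handoff").name)  # HumanHandoffPatternFlowStackFrame
print(choose_linked_frame("check_balance").name)          # UserFlowStackFrame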