rasa-pro 3.9.17__py3-none-any.whl → 3.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.

Files changed (187)
  1. README.md +5 -37
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/voice_aware/__init__.py +0 -0
  32. rasa/core/channels/voice_aware/jambonz.py +103 -0
  33. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  34. rasa/core/channels/voice_aware/utils.py +20 -0
  35. rasa/core/channels/voice_native/__init__.py +0 -0
  36. rasa/core/constants.py +6 -1
  37. rasa/core/featurizers/single_state_featurizer.py +1 -22
  38. rasa/core/featurizers/tracker_featurizers.py +18 -115
  39. rasa/core/information_retrieval/faiss.py +7 -4
  40. rasa/core/information_retrieval/information_retrieval.py +8 -0
  41. rasa/core/information_retrieval/milvus.py +9 -2
  42. rasa/core/information_retrieval/qdrant.py +1 -1
  43. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  44. rasa/core/nlg/summarize.py +4 -3
  45. rasa/core/policies/enterprise_search_policy.py +100 -44
  46. rasa/core/policies/flows/flow_executor.py +155 -98
  47. rasa/core/policies/intentless_policy.py +52 -28
  48. rasa/core/policies/ted_policy.py +33 -58
  49. rasa/core/policies/unexpected_intent_policy.py +7 -15
  50. rasa/core/processor.py +15 -46
  51. rasa/core/run.py +5 -4
  52. rasa/core/tracker_store.py +8 -4
  53. rasa/core/utils.py +45 -56
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  55. rasa/dialogue_understanding/commands/__init__.py +4 -0
  56. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  59. rasa/dialogue_understanding/commands/utils.py +38 -0
  60. rasa/dialogue_understanding/generator/constants.py +10 -3
  61. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  62. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  63. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  64. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  65. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  66. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  67. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  68. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  70. rasa/e2e_test/assertions.py +1181 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +498 -73
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +596 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/recipes/default_components.py +0 -2
  89. rasa/engine/storage/local_model_storage.py +0 -1
  90. rasa/env.py +9 -0
  91. rasa/llm_fine_tuning/__init__.py +0 -0
  92. rasa/llm_fine_tuning/annotation_module.py +241 -0
  93. rasa/llm_fine_tuning/conversations.py +144 -0
  94. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  95. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  96. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  97. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  98. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  99. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  100. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  101. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  102. rasa/llm_fine_tuning/storage.py +174 -0
  103. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  104. rasa/model_training.py +48 -16
  105. rasa/nlu/classifiers/diet_classifier.py +25 -38
  106. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  107. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  108. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  109. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  110. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  111. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  112. rasa/nlu/persistor.py +129 -32
  113. rasa/server.py +45 -10
  114. rasa/shared/constants.py +63 -15
  115. rasa/shared/core/domain.py +15 -12
  116. rasa/shared/core/events.py +28 -2
  117. rasa/shared/core/flows/flow.py +208 -13
  118. rasa/shared/core/flows/flow_path.py +84 -0
  119. rasa/shared/core/flows/flows_list.py +28 -10
  120. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  121. rasa/shared/core/flows/validation.py +112 -25
  122. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  123. rasa/shared/core/trackers.py +6 -0
  124. rasa/shared/core/training_data/visualization.html +2 -2
  125. rasa/shared/exceptions.py +4 -0
  126. rasa/shared/importers/importer.py +60 -11
  127. rasa/shared/importers/remote_importer.py +196 -0
  128. rasa/shared/nlu/constants.py +2 -0
  129. rasa/shared/nlu/training_data/features.py +2 -120
  130. rasa/shared/providers/_configs/__init__.py +0 -0
  131. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  132. rasa/shared/providers/_configs/client_config.py +57 -0
  133. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  134. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  135. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  136. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  137. rasa/shared/providers/_configs/utils.py +101 -0
  138. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  139. rasa/shared/providers/embedding/__init__.py +0 -0
  140. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  141. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  142. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  143. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  144. rasa/shared/providers/embedding/embedding_client.py +90 -0
  145. rasa/shared/providers/embedding/embedding_response.py +41 -0
  146. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  147. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  148. rasa/shared/providers/llm/__init__.py +0 -0
  149. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  150. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  151. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  152. rasa/shared/providers/llm/llm_client.py +76 -0
  153. rasa/shared/providers/llm/llm_response.py +50 -0
  154. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  155. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  156. rasa/shared/providers/mappings.py +75 -0
  157. rasa/shared/utils/cli.py +30 -0
  158. rasa/shared/utils/io.py +65 -3
  159. rasa/shared/utils/llm.py +223 -200
  160. rasa/shared/utils/yaml.py +122 -7
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +2 -3
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/config.py +2 -3
  166. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  167. rasa/tracing/instrumentation/instrumentation.py +4 -47
  168. rasa/utils/common.py +18 -19
  169. rasa/utils/endpoints.py +7 -4
  170. rasa/utils/io.py +66 -0
  171. rasa/utils/json_utils.py +60 -0
  172. rasa/utils/licensing.py +9 -1
  173. rasa/utils/ml_utils.py +4 -2
  174. rasa/utils/tensorflow/model_data.py +193 -2
  175. rasa/validator.py +195 -1
  176. rasa/version.py +1 -1
  177. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +25 -51
  178. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +183 -119
  179. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  180. rasa/shared/providers/openai/clients.py +0 -43
  181. rasa/shared/providers/openai/session_handler.py +0 -110
  182. rasa/utils/tensorflow/feature_array.py +0 -366
  183. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  184. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  185. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
  186. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
  187. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
rasa/core/policies/enterprise_search_policy.py

@@ -8,6 +8,11 @@ import structlog
 from jinja2 import Template
 from pydantic import ValidationError
 
+from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
+    _LangchainEmbeddingClientAdapter,
+)
+from rasa.shared.providers.llm.llm_client import LLMClient
+
 import rasa.shared.utils.io
 from rasa.telemetry import (
     track_enterprise_search_policy_predict,
@@ -19,6 +24,7 @@ from rasa.core.constants import (
     POLICY_MAX_HISTORY,
     POLICY_PRIORITY,
     SEARCH_POLICY_PRIORITY,
+    UTTER_SOURCE_METADATA_KEY,
 )
 from rasa.core.policies.policy import Policy, PolicyPrediction
 from rasa.core.utils import AvailableEndpoints
@@ -39,13 +45,23 @@ from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
 from rasa.graph_components.providers.forms_provider import Forms
 from rasa.graph_components.providers.responses_provider import Responses
+from rasa.shared.constants import (
+    EMBEDDINGS_CONFIG_KEY,
+    LLM_CONFIG_KEY,
+    MODEL_CONFIG_KEY,
+    MODEL_NAME_CONFIG_KEY,
+    PROMPT_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    OPENAI_PROVIDER,
+    TIMEOUT_CONFIG_KEY,
+)
 from rasa.shared.core.constants import (
     ACTION_CANCEL_FLOW,
     ACTION_SEND_TEXT_NAME,
     DEFAULT_SLOT_NAMES,
 )
 from rasa.shared.core.domain import Domain
-from rasa.shared.core.events import Event
+from rasa.shared.core.events import Event, UserUttered, BotUttered
 from rasa.shared.core.generator import TrackerWithCachedStates
 from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
 from rasa.shared.nlu.training_data.training_data import TrainingData
@@ -59,6 +75,7 @@ from rasa.shared.utils.llm import (
     llm_factory,
     sanitize_message_for_prompt,
     tracker_as_readable_transcript,
+    try_instantiate_llm_client,
 )
 from rasa.core.information_retrieval.faiss import FAISS_Store
 from rasa.core.information_retrieval import (
@@ -70,7 +87,6 @@ from rasa.core.information_retrieval import (
 
 if TYPE_CHECKING:
     from langchain.schema.embeddings import Embeddings
-    from langchain.llms.base import BaseLLM
     from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
 
 from rasa.utils.log_utils import log_llm
@@ -86,6 +102,7 @@ VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
 TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
 CITATION_ENABLED_PROPERTY = "citation_enabled"
 USE_LLM_PROPERTY = "use_generative_llm"
+MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
 
 DEFAULT_VECTOR_STORE_TYPE = "faiss"
 DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
@@ -96,23 +113,24 @@ DEFAULT_VECTOR_STORE = {
 }
 
 DEFAULT_LLM_CONFIG = {
-    "_type": "openai",
-    "request_timeout": 10,
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
+    TIMEOUT_CONFIG_KEY: 10,
     "temperature": 0.0,
     "max_tokens": 256,
-    "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
     "max_retries": 1,
 }
 
 DEFAULT_EMBEDDINGS_CONFIG = {
-    "_type": "openai",
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }
 
-EMBEDDINGS_CONFIG_KEY = "embeddings"
-LLM_CONFIG_KEY = "llm"
 ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
 
+SEARCH_RESULTS_METADATA_KEY = "search_results"
+SEARCH_QUERY_METADATA_KEY = "search_query"
+
 DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE = importlib.resources.read_text(
     "rasa.core.policies", "enterprise_search_prompt_template.jinja2"
 )
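For orientation, here is a minimal, self-contained sketch of the configuration-key rename implied by the new defaults above. The literal key strings are assumptions inferred from the constant names (PROVIDER_CONFIG_KEY, MODEL_CONFIG_KEY, TIMEOUT_CONFIG_KEY); they are not taken from this diff.

OLD_TO_NEW_KEY = {
    "_type": "provider",           # assumed value of PROVIDER_CONFIG_KEY
    "model_name": "model",         # assumed value of MODEL_CONFIG_KEY
    "request_timeout": "timeout",  # assumed value of TIMEOUT_CONFIG_KEY
}


def migrate_llm_config(old_config: dict) -> dict:
    """Rewrite a 3.9-style llm/embeddings config dict using the 3.10-style keys."""
    return {OLD_TO_NEW_KEY.get(key, key): value for key, value in old_config.items()}


print(migrate_llm_config({"_type": "openai", "model_name": "gpt-3.5-turbo", "request_timeout": 10}))
# {'provider': 'openai', 'model': 'gpt-3.5-turbo', 'timeout': 10}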
@@ -179,26 +197,42 @@ class EnterpriseSearchPolicy(Policy):
         """Constructs a new Policy object."""
         super().__init__(config, model_storage, resource, execution_context, featurizer)
 
+        # Vector store object and configuration
         self.vector_store = vector_store
         self.vector_store_config = config.get(
             VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
         )
-        self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
+        # Embeddings configuration for encoding the search query
         self.embeddings_config = self.config.get(
             EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
         )
+        # Maximum number of turns to include in the prompt
         self.max_history = self.config.get(POLICY_MAX_HISTORY)
-        self.prompt_template = prompt_template or get_prompt_template(
-            self.config.get("prompt"),
-            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
-        )
+
+        # Maximum number of messages to include in the search query
+        self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
+
+        # LLM Configuration for response generation
+        self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
+
+        # boolean to enable/disable tracing of prompt tokens
         self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
+
+        # boolean to enable/disable the use of LLM for response generation
         self.use_llm = self.config.get(USE_LLM_PROPERTY, True)
+
+        # boolean to enable/disable citation generation
         self.citation_enabled = self.config.get(CITATION_ENABLED_PROPERTY, False)
+
+        self.prompt_template = prompt_template or get_prompt_template(
+            self.config.get(PROMPT_CONFIG_KEY),
+            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
+        )
         self.citation_prompt_template = get_prompt_template(
-            self.config.get("prompt"),
+            self.config.get(PROMPT_CONFIG_KEY),
             DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
         )
+        # If citation is enabled, use the citation prompt template
         if self.citation_enabled:
             self.prompt_template = self.citation_prompt_template
 
@@ -209,9 +243,10 @@ class EnterpriseSearchPolicy(Policy):
         Returns:
             The embedder.
         """
-        return embedder_factory(
+        client = embedder_factory(
             config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
         )
+        return _LangchainEmbeddingClientAdapter(client)
 
     def train(  # type: ignore[override]
         self,
@@ -245,20 +280,24 @@ class EnterpriseSearchPolicy(Policy):
         # validate embedding configuration
         try:
             embeddings = self._create_plain_embedder(self.config)
-        except ValidationError as e:
+        except (ValidationError, Exception) as e:
+            logger.error(
+                "enterprise_search_policy.train.embedder_instantiation_failed",
+                message="Unable to instantiate the embedding client.",
+                error=e,
+            )
             print_error_and_exit(
                 "Unable to create embedder. Please make sure you specified the "
                 f"required environment variables. Error: {e}"
             )
 
         # validate llm configuration
-        try:
-            llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
-        except (ImportError, ValueError, ValidationError) as e:
-            # ImportError: llm library is likely not installed
-            # ValueError: llm config is likely invalid
-            # ValidationError: environment variables are likely not set
-            print_error_and_exit(f"Unable to create LLM. Error: {e}")
+        try_instantiate_llm_client(
+            self.config.get(LLM_CONFIG_KEY),
+            DEFAULT_LLM_CONFIG,
+            "enterprise_search_policy.train",
+            "EnterpriseSearchPolicy",
+        )
 
         if store_type == DEFAULT_VECTOR_STORE_TYPE:
             logger.info("enterprise_search_policy.train.faiss")
@@ -275,11 +314,12 @@ class EnterpriseSearchPolicy(Policy):
         # telemetry call to track training completion
         track_enterprise_search_policy_train_completed(
             vector_store_type=store_type,
-            embeddings_type=self.embeddings_config.get("_type"),
-            embeddings_model=self.embeddings_config.get("model")
-            or self.embeddings_config.get("model_name"),
-            llm_type=self.llm_config.get("_type"),
-            llm_model=self.llm_config.get("model") or self.llm_config.get("model_name"),
+            embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
+            embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
+            or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
+            llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
+            llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
+            or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
             citation_enabled=self.citation_enabled,
         )
         self.persist()
@@ -343,24 +383,31 @@ class EnterpriseSearchPolicy(Policy):
             logger.error(
                 "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
                 error=e,
+                config=config,
             )
             raise VectorStoreConnectionError(
                 f"Unable to connect to the vector store. Error: {e}"
            )
 
-    def _get_last_user_message(self, tracker: DialogueStateTracker) -> str:
-        """Get the last user message from the tracker.
+    def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
+        """Prepares the search query.
+        The search query is the last N messages in the conversation history.
 
         Args:
             tracker: The tracker containing the conversation history up to now.
+            history: The number of messages to include in the search query.
 
         Returns:
-            The last user message.
+            The search query.
         """
-        for event in reversed(tracker.events):
-            if isinstance(event, rasa.shared.core.events.UserUttered):
-                return sanitize_message_for_prompt(event.text)
-        return ""
+        transcript = []
+        for event in tracker.applied_events():
+            if isinstance(event, UserUttered) or isinstance(event, BotUttered):
+                transcript.append(sanitize_message_for_prompt(event.text))
+
+        search_query = " ".join(transcript[-history:][::-1])
+        logger.debug("search_query", search_query=search_query)
+        return search_query
 
     async def predict_action_probabilities(  # type: ignore[override]
         self,
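A minimal, self-contained sketch of the new query construction shown in `_prepare_search_query` above: the last `history` user and bot messages are joined, most recent first. The plain `(sender, text)` transcript below is a stand-in for the DialogueStateTracker, purely for illustration.

from typing import List, Tuple


def prepare_search_query(transcript: List[Tuple[str, str]], history: int = 2) -> str:
    """Join the last `history` messages, most recent first, into one query string."""
    texts = [text for sender, text in transcript if sender in ("user", "bot")]
    return " ".join(texts[-history:][::-1])


print(prepare_search_query(
    [
        ("user", "I lost my card"),
        ("bot", "Do you want to block it?"),
        ("user", "yes please"),
    ],
    history=2,
))
# "yes please Do you want to block it?"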
@@ -404,7 +451,9 @@
             logger.error(f"{logger_key}.connection_error", error=e)
             return self._create_prediction_internal_error(domain, tracker)
 
-        search_query = self._get_last_user_message(tracker)
+        search_query = self._prepare_search_query(
+            tracker, int(self.max_messages_in_query)
+        )
         tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
 
         try:
@@ -448,17 +497,23 @@
         action_metadata = {
             "message": {
                 "text": response,
+                SEARCH_RESULTS_METADATA_KEY: [
+                    result.text for result in documents.results
+                ],
+                UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
+                SEARCH_QUERY_METADATA_KEY: search_query,
             }
         }
 
         # telemetry call to track policy prediction
         track_enterprise_search_policy_predict(
             vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
-            embeddings_type=self.embeddings_config.get("_type"),
-            embeddings_model=self.embeddings_config.get("model")
-            or self.embeddings_config.get("model_name"),
-            llm_type=self.llm_config.get("_type"),
-            llm_model=self.llm_config.get("model") or self.llm_config.get("model_name"),
+            embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
+            embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
+            or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
+            llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
+            llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
+            or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
             citation_enabled=self.citation_enabled,
         )
         return self._create_prediction(
@@ -495,10 +550,11 @@
         return prompt
 
     async def _generate_llm_answer(
-        self, llm: "BaseLLM", prompt: Text
+        self, llm: LLMClient, prompt: Text
     ) -> Optional[Text]:
         try:
-            llm_answer = await llm.apredict(prompt)
+            llm_response = await llm.acompletion(prompt)
+            llm_answer = llm_response.choices[0]
         except Exception as e:
             # unfortunately, langchain does not wrap LLM exceptions which means
             # we have to catch all exceptions here
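A minimal sketch of the calling-convention change in `_generate_llm_answer`: the old langchain `llm.apredict(prompt)` returned a plain string, while the new `LLMClient.acompletion(prompt)` returns a response object whose `choices` list carries the generated texts. `DummyLLMClient` and `DummyResponse` are stand-ins written for this example, not Rasa classes.

import asyncio
from dataclasses import dataclass, field
from typing import List


@dataclass
class DummyResponse:
    choices: List[str] = field(default_factory=list)


class DummyLLMClient:
    async def acompletion(self, prompt: str) -> DummyResponse:
        # a real client would call the configured provider here
        return DummyResponse(choices=[f"echo: {prompt}"])


async def generate_answer(llm: DummyLLMClient, prompt: str) -> str:
    response = await llm.acompletion(prompt)
    return response.choices[0]


print(asyncio.run(generate_answer(DummyLLMClient(), "What is the refund policy?")))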
@@ -684,7 +740,7 @@
         local_knowledge_data = cls._get_local_knowledge_data(config)
 
         prompt_template = get_prompt_template(
-            config.get("prompt"),
+            config.get(PROMPT_CONFIG_KEY),
             DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
         )
         return deep_container_fingerprint([prompt_template, local_knowledge_data])
rasa/core/policies/flows/flow_executor.py

@@ -2,12 +2,14 @@ from __future__ import annotations
 
 from typing import Any, Dict, Text, List, Optional
 
+import structlog
 from jinja2 import Template
-from rasa.dialogue_understanding.commands import CancelFlowCommand
-from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
+from pypred import Predicate
 from structlog.contextvars import (
     bound_contextvars,
 )
+
+from rasa.core.constants import STEP_ID_METADATA_KEY, ACTIVE_FLOW_METADATA_KEY
 from rasa.core.policies.flows.flow_exceptions import (
     FlowCircuitBreakerTrippedException,
     FlowException,
@@ -19,6 +21,20 @@ from rasa.core.policies.flows.flow_step_result import (
     FlowStepResult,
     PauseFlowReturnPrediction,
 )
+from rasa.dialogue_understanding.commands import CancelFlowCommand
+from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
+from rasa.dialogue_understanding.patterns.collect_information import (
+    CollectInformationPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.completed import (
+    CompletedPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.continue_interrupted import (
+    ContinueInterruptedPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.human_handoff import (
+    HumanHandoffPatternFlowStackFrame,
+)
 from rasa.dialogue_understanding.patterns.internal_error import (
     InternalErrorPatternFlowStackFrame,
 )
@@ -29,24 +45,13 @@ from rasa.dialogue_understanding.stack.frames import (
     DialogueStackFrame,
     UserFlowStackFrame,
 )
-from rasa.dialogue_understanding.patterns.collect_information import (
-    CollectInformationPatternFlowStackFrame,
-)
-from rasa.dialogue_understanding.patterns.completed import (
-    CompletedPatternFlowStackFrame,
-)
-from rasa.dialogue_understanding.patterns.continue_interrupted import (
-    ContinueInterruptedPatternFlowStackFrame,
-)
 from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
     FlowStackFrameType,
 )
 from rasa.dialogue_understanding.stack.utils import (
     top_user_flow_frame,
 )
-
-from pypred import Predicate
-
+from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
 from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
 from rasa.shared.core.events import (
     Event,
@@ -56,6 +61,11 @@ from rasa.shared.core.events import (
     SlotSet,
 )
 from rasa.shared.core.flows import FlowsList
+from rasa.shared.core.flows.flow import (
+    END_STEP,
+    Flow,
+    FlowStep,
+)
 from rasa.shared.core.flows.flow_step_links import (
     StaticFlowStepLink,
     IfFlowStepLink,
@@ -71,17 +81,11 @@ from rasa.shared.core.flows.steps import (
     CollectInformationFlowStep,
     NoOperationFlowStep,
 )
-from rasa.shared.core.flows.flow import (
-    END_STEP,
-    Flow,
-    FlowStep,
-)
 from rasa.shared.core.flows.steps.collect import SlotRejection
 from rasa.shared.core.slots import Slot
 from rasa.shared.core.trackers import (
     DialogueStateTracker,
 )
-import structlog
 
 structlogger = structlog.get_logger()
 
@@ -466,6 +470,10 @@ def advance_flows_until_next_action(
             # make sure we really return all events that got created during the
             # step execution of all steps (not only the last one)
             prediction.events = gathered_events
+            prediction.metadata = {
+                ACTIVE_FLOW_METADATA_KEY: tracker.active_flow,
+                STEP_ID_METADATA_KEY: tracker.current_step_id,
+            }
             return prediction
         else:
             structlogger.warning("flow.step.execution.no_action")
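A small sketch of the metadata the flow executor now attaches to its prediction, using plain dicts. The literal key strings are assumptions based on the constant names ACTIVE_FLOW_METADATA_KEY and STEP_ID_METADATA_KEY imported above; the helper is illustrative, not a Rasa function.

from typing import Any, Dict, Optional


def attach_flow_metadata(
    metadata: Optional[Dict[str, Any]], active_flow: Optional[str], step_id: Optional[str]
) -> Dict[str, Any]:
    """Merge the current flow position into existing prediction metadata."""
    merged = dict(metadata or {})
    merged["active_flow"] = active_flow  # assumed value of ACTIVE_FLOW_METADATA_KEY
    merged["step_id"] = step_id          # assumed value of STEP_ID_METADATA_KEY
    return merged


print(attach_flow_metadata(None, "transfer_money", "ask_amount"))
# {'active_flow': 'transfer_money', 'step_id': 'ask_amount'}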
@@ -483,7 +491,8 @@ def validate_collect_step(
     A collect step can be executed if either the `utter_ask` or the `action_ask` is
     defined in the domain. If neither is defined, the collect step can still be
     executed if the slot has an initial value defined in the domain, which would cause
-    the step to be skipped."""
+    the step to be skipped.
+    """
     slot = slots.get(step.collect)
     slot_has_initial_value_defined = slot and slot.initial_value is not None
     if (
@@ -583,102 +592,150 @@ def run_step(
     """
     initial_events: List[Event] = []
     if step == flow.first_step_in_flow():
-        initial_events.append(FlowStarted(flow.id))
+        initial_events.append(FlowStarted(flow.id, metadata=stack.current_context()))
 
     if isinstance(step, CollectInformationFlowStep):
-        is_step_valid = validate_collect_step(
-            step, stack, available_actions, tracker.slots
+        return _run_collect_information_step(
+            available_actions, initial_events, stack, step, tracker
         )
-        if not is_step_valid:
-            # if we return any other FlowStepResult, the assistant will stay silent
-            # instead of triggering the internal error pattern
-            return ContinueFlowWithNextStep(events=initial_events)
-
-        is_mapping_valid = validate_custom_slot_mappings(
-            step, stack, tracker, available_actions
-        )
-        if not is_mapping_valid:
-            # if we return any other FlowStepResult, the assistant will stay silent
-            # instead of triggering the internal error pattern
-            return ContinueFlowWithNextStep(events=initial_events)
-
-        structlogger.debug("flow.step.run.collect")
-        trigger_pattern_ask_collect_information(
-            step.collect, stack, step.rejections, step.utter, step.collect_action
-        )
-
-        events: List[Event] = events_for_collect_step_execution(step, tracker)
-        return ContinueFlowWithNextStep(events=initial_events + events)
 
     elif isinstance(step, ActionFlowStep):
         if not step.action:
             raise FlowException(f"Action not specified for step {step}")
-
-        context = {"context": stack.current_context()}
-        action_name = render_template_variables(step.action, context)
-
-        if action_name in available_actions:
-            structlogger.debug("flow.step.run.action", context=context)
-            return PauseFlowReturnPrediction(
-                FlowActionPrediction(action_name, 1.0, events=initial_events)
-            )
-        else:
-            if step.action != "validate_{{context.collect}}":
-                # do not log about non-existing validation actions of collect steps
-                utter_action_name = render_template_variables(
-                    "{{context.utter}}", context
-                )
-                if utter_action_name not in available_actions:
-                    structlogger.warning(
-                        "flow.step.run.action.unknown", action=action_name
-                    )
-            return ContinueFlowWithNextStep(events=initial_events)
+        return _run_action_step(available_actions, initial_events, stack, step)
 
     elif isinstance(step, LinkFlowStep):
-        structlogger.debug("flow.step.run.link")
-        stack.push(
-            UserFlowStackFrame(
-                flow_id=step.link,
-                frame_type=FlowStackFrameType.LINK,
-            ),
-            # push this below the current stack frame so that we can
-            # complete the current flow first and then continue with the
-            # linked flow
-            index=-1,
-        )
-        return ContinueFlowWithNextStep(events=initial_events)
+        return _run_link_step(initial_events, stack, step)
 
     elif isinstance(step, CallFlowStep):
-        structlogger.debug("flow.step.run.call")
-        stack.push(
-            UserFlowStackFrame(
-                flow_id=step.call,
-                frame_type=FlowStackFrameType.CALL,
-            ),
-        )
-        return ContinueFlowWithNextStep()
+        return _run_call_step(initial_events, stack, step)
 
     elif isinstance(step, SetSlotsFlowStep):
-        structlogger.debug("flow.step.run.slot")
-        slot_events: List[Event] = events_from_set_slots_step(step)
-        return ContinueFlowWithNextStep(events=initial_events + slot_events)
+        return _run_set_slot_step(initial_events, step)
 
     elif isinstance(step, NoOperationFlowStep):
         structlogger.debug("flow.step.run.no_operation")
         return ContinueFlowWithNextStep(events=initial_events)
 
     elif isinstance(step, EndFlowStep):
-        # this is the end of the flow, so we'll pop it from the stack
-        structlogger.debug("flow.step.run.flow_end")
-        current_frame = stack.pop()
-        trigger_pattern_completed(current_frame, stack, flows)
-        resumed_events = trigger_pattern_continue_interrupted(
-            current_frame, stack, flows
-        )
-        reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
-        return ContinueFlowWithNextStep(
-            events=initial_events + reset_events + resumed_events, has_flow_ended=True
-        )
+        return _run_end_step(flow, flows, initial_events, stack, tracker)
 
     else:
         raise FlowException(f"Unknown flow step type {type(step)}")
+
+
+def _run_end_step(
+    flow: Flow,
+    flows: FlowsList,
+    initial_events: List[Event],
+    stack: DialogueStack,
+    tracker: DialogueStateTracker,
+) -> FlowStepResult:
+    # this is the end of the flow, so we'll pop it from the stack
+    structlogger.debug("flow.step.run.flow_end")
+    current_frame = stack.pop()
+    trigger_pattern_completed(current_frame, stack, flows)
+    resumed_events = trigger_pattern_continue_interrupted(current_frame, stack, flows)
+    reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
+    return ContinueFlowWithNextStep(
+        events=initial_events + reset_events + resumed_events, has_flow_ended=True
+    )
+
+
+def _run_set_slot_step(
+    initial_events: List[Event], step: SetSlotsFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.slot")
+    slot_events: List[Event] = events_from_set_slots_step(step)
+    return ContinueFlowWithNextStep(events=initial_events + slot_events)
+
+
+def _run_call_step(
+    initial_events: List[Event], stack: DialogueStack, step: CallFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.call")
+    stack.push(
+        UserFlowStackFrame(
+            flow_id=step.call,
+            frame_type=FlowStackFrameType.CALL,
+        ),
+    )
+    return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_link_step(
+    initial_events: List[Event], stack: DialogueStack, step: LinkFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.link")
+
+    if step.link == RASA_PATTERN_HUMAN_HANDOFF:
+        linked_stack_frame: DialogueStackFrame = HumanHandoffPatternFlowStackFrame()
+    else:
+        linked_stack_frame = UserFlowStackFrame(
+            flow_id=step.link,
+            frame_type=FlowStackFrameType.LINK,
+        )
+
+    stack.push(
+        linked_stack_frame,
+        # push this below the current stack frame so that we can
+        # complete the current flow first and then continue with the
+        # linked flow
+        index=-1,
+    )
+
+    return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_action_step(
+    available_actions: List[str],
+    initial_events: List[Event],
+    stack: DialogueStack,
+    step: ActionFlowStep,
+) -> FlowStepResult:
+    context = {"context": stack.current_context()}
+    action_name = render_template_variables(step.action, context)
+
+    if action_name in available_actions:
+        structlogger.debug("flow.step.run.action", context=context)
+        return PauseFlowReturnPrediction(
+            FlowActionPrediction(action_name, 1.0, events=initial_events)
+        )
+    else:
+        if step.action != "validate_{{context.collect}}":
+            # do not log about non-existing validation actions of collect steps
+            utter_action_name = render_template_variables("{{context.utter}}", context)
+            if utter_action_name not in available_actions:
+                structlogger.warning("flow.step.run.action.unknown", action=action_name)
+        return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_collect_information_step(
+    available_actions: List[str],
+    initial_events: List[Event],
+    stack: DialogueStack,
+    step: CollectInformationFlowStep,
+    tracker: DialogueStateTracker,
+) -> FlowStepResult:
+    is_step_valid = validate_collect_step(step, stack, available_actions, tracker.slots)
+
+    if not is_step_valid:
+        # if we return any other FlowStepResult, the assistant will stay silent
+        # instead of triggering the internal error pattern
+        return ContinueFlowWithNextStep(events=initial_events)
+
+    is_mapping_valid = validate_custom_slot_mappings(
+        step, stack, tracker, available_actions
+    )
+
+    if not is_mapping_valid:
+        # if we return any other FlowStepResult, the assistant will stay silent
+        # instead of triggering the internal error pattern
+        return ContinueFlowWithNextStep(events=initial_events)
+
+    structlogger.debug("flow.step.run.collect")
+    trigger_pattern_ask_collect_information(
+        step.collect, stack, step.rejections, step.utter, step.collect_action
+    )
+
+    events: List[Event] = events_for_collect_step_execution(step, tracker)
+    return ContinueFlowWithNextStep(events=initial_events + events)