rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of rasa-pro has been flagged as potentially problematic in the source registry.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0

rasa/core/policies/enterprise_search_policy.py

@@ -8,6 +8,11 @@ import structlog
 from jinja2 import Template
 from pydantic import ValidationError

+from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
+    _LangchainEmbeddingClientAdapter,
+)
+from rasa.shared.providers.llm.llm_client import LLMClient
+
 import rasa.shared.utils.io
 from rasa.telemetry import (
     track_enterprise_search_policy_predict,
@@ -19,6 +24,7 @@ from rasa.core.constants import (
     POLICY_MAX_HISTORY,
     POLICY_PRIORITY,
     SEARCH_POLICY_PRIORITY,
+    UTTER_SOURCE_METADATA_KEY,
 )
 from rasa.core.policies.policy import Policy, PolicyPrediction
 from rasa.core.utils import AvailableEndpoints
@@ -39,13 +45,23 @@ from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
 from rasa.graph_components.providers.forms_provider import Forms
 from rasa.graph_components.providers.responses_provider import Responses
+from rasa.shared.constants import (
+    EMBEDDINGS_CONFIG_KEY,
+    LLM_CONFIG_KEY,
+    MODEL_CONFIG_KEY,
+    MODEL_NAME_CONFIG_KEY,
+    PROMPT_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    OPENAI_PROVIDER,
+    TIMEOUT_CONFIG_KEY,
+)
 from rasa.shared.core.constants import (
     ACTION_CANCEL_FLOW,
     ACTION_SEND_TEXT_NAME,
     DEFAULT_SLOT_NAMES,
 )
 from rasa.shared.core.domain import Domain
-from rasa.shared.core.events import Event
+from rasa.shared.core.events import Event, UserUttered, BotUttered
 from rasa.shared.core.generator import TrackerWithCachedStates
 from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
 from rasa.shared.nlu.training_data.training_data import TrainingData
@@ -59,6 +75,8 @@ from rasa.shared.utils.llm import (
     llm_factory,
     sanitize_message_for_prompt,
     tracker_as_readable_transcript,
+    try_instantiate_llm_client,
+    try_instantiate_embedder,
 )
 from rasa.core.information_retrieval.faiss import FAISS_Store
 from rasa.core.information_retrieval import (
@@ -70,7 +88,6 @@ from rasa.core.information_retrieval import (

 if TYPE_CHECKING:
     from langchain.schema.embeddings import Embeddings
-    from langchain.llms.base import BaseLLM
     from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer

 from rasa.utils.log_utils import log_llm
@@ -86,6 +103,7 @@ VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
 TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
 CITATION_ENABLED_PROPERTY = "citation_enabled"
 USE_LLM_PROPERTY = "use_generative_llm"
+MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"

 DEFAULT_VECTOR_STORE_TYPE = "faiss"
 DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
@@ -96,23 +114,24 @@ DEFAULT_VECTOR_STORE = {
 }

 DEFAULT_LLM_CONFIG = {
-    "_type": "openai",
-    "request_timeout": 10,
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
+    TIMEOUT_CONFIG_KEY: 10,
     "temperature": 0.0,
     "max_tokens": 256,
-    "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
     "max_retries": 1,
 }

 DEFAULT_EMBEDDINGS_CONFIG = {
-    "_type": "openai",
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }

-EMBEDDINGS_CONFIG_KEY = "embeddings"
-LLM_CONFIG_KEY = "llm"
 ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"

+SEARCH_RESULTS_METADATA_KEY = "search_results"
+SEARCH_QUERY_METADATA_KEY = "search_query"
+
 DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE = importlib.resources.read_text(
     "rasa.core.policies", "enterprise_search_prompt_template.jinja2"
 )
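
The defaults above are now keyed on the shared provider/model constants from rasa.shared.constants rather than the old "_type", "model_name" and "request_timeout" strings, and the telemetry calls further down read the model with a fallback to the legacy model-name key. A minimal sketch of that lookup pattern (the helper name resolve_provider_and_model is illustrative, not part of Rasa):

    from typing import Any, Dict, Optional, Tuple

    from rasa.shared.constants import (
        MODEL_CONFIG_KEY,
        MODEL_NAME_CONFIG_KEY,
        PROVIDER_CONFIG_KEY,
    )

    def resolve_provider_and_model(
        llm_config: Dict[str, Any],
    ) -> Tuple[Optional[str], Optional[str]]:
        # Prefer the new model key; fall back to the legacy model_name key.
        provider = llm_config.get(PROVIDER_CONFIG_KEY)
        model = llm_config.get(MODEL_CONFIG_KEY) or llm_config.get(MODEL_NAME_CONFIG_KEY)
        return provider, model
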
@@ -179,26 +198,42 @@ class EnterpriseSearchPolicy(Policy):
         """Constructs a new Policy object."""
         super().__init__(config, model_storage, resource, execution_context, featurizer)

+        # Vector store object and configuration
         self.vector_store = vector_store
         self.vector_store_config = config.get(
             VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
         )
-        self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
+        # Embeddings configuration for encoding the search query
         self.embeddings_config = self.config.get(
             EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
         )
+        # Maximum number of turns to include in the prompt
         self.max_history = self.config.get(POLICY_MAX_HISTORY)
-        self.prompt_template = prompt_template or get_prompt_template(
-            self.config.get("prompt"),
-            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
-        )
+
+        # Maximum number of messages to include in the search query
+        self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
+
+        # LLM Configuration for response generation
+        self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
+
+        # boolean to enable/disable tracing of prompt tokens
         self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
+
+        # boolean to enable/disable the use of LLM for response generation
         self.use_llm = self.config.get(USE_LLM_PROPERTY, True)
+
+        # boolean to enable/disable citation generation
         self.citation_enabled = self.config.get(CITATION_ENABLED_PROPERTY, False)
+
+        self.prompt_template = prompt_template or get_prompt_template(
+            self.config.get(PROMPT_CONFIG_KEY),
+            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
+        )
         self.citation_prompt_template = get_prompt_template(
-            self.config.get("prompt"),
+            self.config.get(PROMPT_CONFIG_KEY),
             DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
         )
+        # If citation is enabled, use the citation prompt template
         if self.citation_enabled:
             self.prompt_template = self.citation_prompt_template

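The constructor reads every option straight from the policy's config dict; max_messages_in_query (default 2) is the new setting. An illustrative config with hypothetical values, using the literal keys defined near the top of this file (the key names inside the vector_store block are an assumption, not shown in this diff):

    # Hypothetical EnterpriseSearchPolicy configuration; comments name the constants above.
    policy_config = {
        "vector_store": {"type": "faiss", "threshold": 0.0},  # assumed keys for the vector store block
        "max_messages_in_query": 2,    # MAX_MESSAGES_IN_QUERY_KEY
        "use_generative_llm": True,    # USE_LLM_PROPERTY
        "citation_enabled": False,     # CITATION_ENABLED_PROPERTY
        "trace_prompt_tokens": False,  # TRACE_TOKENS_PROPERTY
    }
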
@@ -209,9 +244,10 @@
         Returns:
             The embedder.
         """
-        return embedder_factory(
+        client = embedder_factory(
             config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
         )
+        return _LangchainEmbeddingClientAdapter(client)

     def train(  # type: ignore[override]
         self,
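
_create_plain_embedder now instantiates a Rasa embedding client and wraps it in the adapter so callers that expect a langchain Embeddings object (for example the FAISS store) keep working. A usage sketch, assuming the adapter exposes the standard langchain embedding methods:

    # Assumption: the adapter implements langchain's Embeddings interface
    # (embed_documents / embed_query) on top of Rasa's own embedding client.
    client = embedder_factory(config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG)
    embedder = _LangchainEmbeddingClientAdapter(client)
    vectors = embedder.embed_documents(["What is the refund policy?"])
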
@@ -245,20 +281,24 @@
         # validate embedding configuration
         try:
             embeddings = self._create_plain_embedder(self.config)
-        except ValidationError as e:
+        except (ValidationError, Exception) as e:
+            logger.error(
+                "enterprise_search_policy.train.embedder_instantiation_failed",
+                message="Unable to instantiate the embedding client.",
+                error=e,
+            )
             print_error_and_exit(
                 "Unable to create embedder. Please make sure you specified the "
                 f"required environment variables. Error: {e}"
             )

         # validate llm configuration
-        try:
-            llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
-        except (ImportError, ValueError, ValidationError) as e:
-            # ImportError: llm library is likely not installed
-            # ValueError: llm config is likely invalid
-            # ValidationError: environment variables are likely not set
-            print_error_and_exit(f"Unable to create LLM. Error: {e}")
+        try_instantiate_llm_client(
+            self.config.get(LLM_CONFIG_KEY),
+            DEFAULT_LLM_CONFIG,
+            "enterprise_search_policy.train",
+            "EnterpriseSearchPolicy",
+        )

         if store_type == DEFAULT_VECTOR_STORE_TYPE:
             logger.info("enterprise_search_policy.train.faiss")
@@ -275,11 +315,12 @@
         # telemetry call to track training completion
         track_enterprise_search_policy_train_completed(
             vector_store_type=store_type,
-            embeddings_type=self.embeddings_config.get("_type"),
-            embeddings_model=self.embeddings_config.get("model")
-            or self.embeddings_config.get("model_name"),
-            llm_type=self.llm_config.get("_type"),
-            llm_model=self.llm_config.get("model") or self.llm_config.get("model_name"),
+            embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
+            embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
+            or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
+            llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
+            llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
+            or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
             citation_enabled=self.citation_enabled,
         )
         self.persist()
@@ -343,24 +384,31 @@
             logger.error(
                 "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
                 error=e,
+                config=config,
             )
             raise VectorStoreConnectionError(
                 f"Unable to connect to the vector store. Error: {e}"
             )

-    def _get_last_user_message(self, tracker: DialogueStateTracker) -> str:
-        """Get the last user message from the tracker.
+    def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
+        """Prepares the search query.
+        The search query is the last N messages in the conversation history.

         Args:
             tracker: The tracker containing the conversation history up to now.
+            history: The number of messages to include in the search query.

         Returns:
-            The last user message.
+            The search query.
         """
-        for event in reversed(tracker.events):
-            if isinstance(event, rasa.shared.core.events.UserUttered):
-                return sanitize_message_for_prompt(event.text)
-        return ""
+        transcript = []
+        for event in tracker.applied_events():
+            if isinstance(event, UserUttered) or isinstance(event, BotUttered):
+                transcript.append(sanitize_message_for_prompt(event.text))
+
+        search_query = " ".join(transcript[-history:][::-1])
+        logger.debug("search_query", search_query=search_query)
+        return search_query

     async def predict_action_probabilities(  # type: ignore[override]
         self,
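
_prepare_search_query replaces the old last-user-message lookup: it collects user and bot messages from the applied events, keeps the last `history` of them, and joins them most recent first. A behaviour sketch with a hypothetical transcript:

    transcript = [
        "I want to open an account",     # user
        "Sure, which type of account?",  # bot
        "What documents do I need?",     # user
    ]
    history = 2
    print(" ".join(transcript[-history:][::-1]))
    # -> "What documents do I need? Sure, which type of account?"
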
@@ -404,7 +452,9 @@
             logger.error(f"{logger_key}.connection_error", error=e)
             return self._create_prediction_internal_error(domain, tracker)

-        search_query = self._get_last_user_message(tracker)
+        search_query = self._prepare_search_query(
+            tracker, int(self.max_messages_in_query)
+        )
         tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)

         try:
@@ -448,17 +498,23 @@
         action_metadata = {
             "message": {
                 "text": response,
+                SEARCH_RESULTS_METADATA_KEY: [
+                    result.text for result in documents.results
+                ],
+                UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
+                SEARCH_QUERY_METADATA_KEY: search_query,
             }
         }

         # telemetry call to track policy prediction
         track_enterprise_search_policy_predict(
             vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
-            embeddings_type=self.embeddings_config.get("_type"),
-            embeddings_model=self.embeddings_config.get("model")
-            or self.embeddings_config.get("model_name"),
-            llm_type=self.llm_config.get("_type"),
-            llm_model=self.llm_config.get("model") or self.llm_config.get("model_name"),
+            embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
+            embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
+            or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
+            llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
+            llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
+            or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
             citation_enabled=self.citation_enabled,
         )
         return self._create_prediction(
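
The answer returned by the policy now carries extra metadata: the raw text of the retrieved documents, the component that produced the utterance, and the search query that was used. A read-side sketch, assuming a metadata dict shaped exactly like the one built above:

    message = action_metadata["message"]
    answer = message["text"]
    retrieved_texts = message[SEARCH_RESULTS_METADATA_KEY]  # list of retrieved document texts
    source = message[UTTER_SOURCE_METADATA_KEY]             # e.g. "EnterpriseSearchPolicy"
    query_used = message[SEARCH_QUERY_METADATA_KEY]         # the query sent to the vector store
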
@@ -495,10 +551,11 @@
         return prompt

     async def _generate_llm_answer(
-        self, llm: "BaseLLM", prompt: Text
+        self, llm: LLMClient, prompt: Text
     ) -> Optional[Text]:
         try:
-            llm_answer = await llm.apredict(prompt)
+            llm_response = await llm.acompletion(prompt)
+            llm_answer = llm_response.choices[0]
         except Exception as e:
             # unfortunately, langchain does not wrap LLM exceptions which means
             # we have to catch all exceptions here
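
_generate_llm_answer now takes the provider-agnostic LLMClient instead of a langchain BaseLLM: a single async completion call, with the answer read from the first choice of the response. A minimal sketch of that call pattern, assuming a client built with the same config and defaults this policy uses:

    llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
    llm_response = await llm.acompletion(prompt)  # provider-agnostic response object
    llm_answer = llm_response.choices[0]          # first generated completion as a string
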
@@ -605,6 +662,18 @@
         execution_context: ExecutionContext,
         **kwargs: Any,
     ) -> "EnterpriseSearchPolicy":
+        try_instantiate_llm_client(
+            config.get(LLM_CONFIG_KEY),
+            DEFAULT_LLM_CONFIG,
+            "enterprise_search_policy.load",
+            EnterpriseSearchPolicy.__name__,
+        )
+        try_instantiate_embedder(
+            config.get(EMBEDDINGS_CONFIG_KEY),
+            DEFAULT_EMBEDDINGS_CONFIG,
+            "enterprise_search_policy.load",
+            EnterpriseSearchPolicy.__name__,
+        )
         """Loads a trained policy (see parent class for full docstring)."""
         prompt_template = None
         store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
@@ -639,7 +708,6 @@
             logger.warning(
                 "enterprise_search_policy.load.failed", error=e, resource=resource.name
             )
-
         return cls(
             config,
             model_storage,
@@ -684,7 +752,7 @@
         local_knowledge_data = cls._get_local_knowledge_data(config)

         prompt_template = get_prompt_template(
-            config.get("prompt"),
+            config.get(PROMPT_CONFIG_KEY),
             DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
         )
         return deep_container_fingerprint([prompt_template, local_knowledge_data])

rasa/core/policies/flows/flow_executor.py

@@ -2,12 +2,14 @@ from __future__ import annotations

 from typing import Any, Dict, Text, List, Optional

+import structlog
 from jinja2 import Template
-from rasa.dialogue_understanding.commands import CancelFlowCommand
-from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
+from pypred import Predicate
 from structlog.contextvars import (
     bound_contextvars,
 )
+
+from rasa.core.constants import STEP_ID_METADATA_KEY, ACTIVE_FLOW_METADATA_KEY
 from rasa.core.policies.flows.flow_exceptions import (
     FlowCircuitBreakerTrippedException,
     FlowException,
@@ -19,6 +21,20 @@ from rasa.core.policies.flows.flow_step_result import (
     FlowStepResult,
     PauseFlowReturnPrediction,
 )
+from rasa.dialogue_understanding.commands import CancelFlowCommand
+from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
+from rasa.dialogue_understanding.patterns.collect_information import (
+    CollectInformationPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.completed import (
+    CompletedPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.continue_interrupted import (
+    ContinueInterruptedPatternFlowStackFrame,
+)
+from rasa.dialogue_understanding.patterns.human_handoff import (
+    HumanHandoffPatternFlowStackFrame,
+)
 from rasa.dialogue_understanding.patterns.internal_error import (
     InternalErrorPatternFlowStackFrame,
 )
@@ -29,24 +45,13 @@ from rasa.dialogue_understanding.stack.frames import (
     DialogueStackFrame,
     UserFlowStackFrame,
 )
-from rasa.dialogue_understanding.patterns.collect_information import (
-    CollectInformationPatternFlowStackFrame,
-)
-from rasa.dialogue_understanding.patterns.completed import (
-    CompletedPatternFlowStackFrame,
-)
-from rasa.dialogue_understanding.patterns.continue_interrupted import (
-    ContinueInterruptedPatternFlowStackFrame,
-)
 from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
     FlowStackFrameType,
 )
 from rasa.dialogue_understanding.stack.utils import (
     top_user_flow_frame,
 )
-
-from pypred import Predicate
-
+from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
 from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
 from rasa.shared.core.events import (
     Event,
@@ -56,6 +61,11 @@ from rasa.shared.core.events import (
     SlotSet,
 )
 from rasa.shared.core.flows import FlowsList
+from rasa.shared.core.flows.flow import (
+    END_STEP,
+    Flow,
+    FlowStep,
+)
 from rasa.shared.core.flows.flow_step_links import (
     StaticFlowStepLink,
     IfFlowStepLink,
@@ -71,17 +81,11 @@ from rasa.shared.core.flows.steps import (
     CollectInformationFlowStep,
     NoOperationFlowStep,
 )
-from rasa.shared.core.flows.flow import (
-    END_STEP,
-    Flow,
-    FlowStep,
-)
 from rasa.shared.core.flows.steps.collect import SlotRejection
 from rasa.shared.core.slots import Slot
 from rasa.shared.core.trackers import (
     DialogueStateTracker,
 )
-import structlog

 structlogger = structlog.get_logger()

@@ -466,6 +470,10 @@
             # make sure we really return all events that got created during the
             # step execution of all steps (not only the last one)
             prediction.events = gathered_events
+            prediction.metadata = {
+                ACTIVE_FLOW_METADATA_KEY: tracker.active_flow,
+                STEP_ID_METADATA_KEY: tracker.current_step_id,
+            }
             return prediction
         else:
             structlogger.warning("flow.step.execution.no_action")
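
With the metadata attached above, every action prediction coming out of the flow executor records which flow and step produced it. A read-side sketch (flow and step ids are hypothetical), assuming prediction is the object returned by advance_flows_until_next_action:

    active_flow = prediction.metadata[ACTIVE_FLOW_METADATA_KEY]  # e.g. "transfer_money"
    step_id = prediction.metadata[STEP_ID_METADATA_KEY]          # e.g. "ask_recipient"
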
@@ -484,7 +492,8 @@
     A collect step can be executed if either the `utter_ask` or the `action_ask` is
     defined in the domain. If neither is defined, the collect step can still be
     executed if the slot has an initial value defined in the domain, which would cause
-    the step to be skipped."""
+    the step to be skipped.
+    """
     slot = slots.get(step.collect)
     slot_has_initial_value_defined = slot and slot.initial_value is not None
     if (
@@ -585,7 +594,7 @@
     """
     initial_events: List[Event] = []
     if step == flow.first_step_in_flow():
-        initial_events.append(FlowStarted(flow.id))
+        initial_events.append(FlowStarted(flow.id, metadata=stack.current_context()))

     if isinstance(step, CollectInformationFlowStep):
         return _run_collect_information_step(
@@ -600,77 +609,114 @@
     elif isinstance(step, ActionFlowStep):
         if not step.action:
             raise FlowException(f"Action not specified for step {step}")
-
-        context = {"context": stack.current_context()}
-        action_name = render_template_variables(step.action, context)
-
-        if action_name in available_actions:
-            structlogger.debug("flow.step.run.action", context=context)
-            return PauseFlowReturnPrediction(
-                FlowActionPrediction(action_name, 1.0, events=initial_events)
-            )
-        else:
-            if step.action != "validate_{{context.collect}}":
-                # do not log about non-existing validation actions of collect steps
-                utter_action_name = render_template_variables(
-                    "{{context.utter}}", context
-                )
-                if utter_action_name not in available_actions:
-                    structlogger.warning(
-                        "flow.step.run.action.unknown", action=action_name
-                    )
-            return ContinueFlowWithNextStep(events=initial_events)
+        return _run_action_step(available_actions, initial_events, stack, step)

     elif isinstance(step, LinkFlowStep):
-        structlogger.debug("flow.step.run.link")
-        stack.push(
-            UserFlowStackFrame(
-                flow_id=step.link,
-                frame_type=FlowStackFrameType.LINK,
-            ),
-            # push this below the current stack frame so that we can
-            # complete the current flow first and then continue with the
-            # linked flow
-            index=-1,
-        )
-        return ContinueFlowWithNextStep(events=initial_events)
+        return _run_link_step(initial_events, stack, step)

     elif isinstance(step, CallFlowStep):
-        structlogger.debug("flow.step.run.call")
-        stack.push(
-            UserFlowStackFrame(
-                flow_id=step.call,
-                frame_type=FlowStackFrameType.CALL,
-            ),
-        )
-        return ContinueFlowWithNextStep()
+        return _run_call_step(initial_events, stack, step)

     elif isinstance(step, SetSlotsFlowStep):
-        structlogger.debug("flow.step.run.slot")
-        slot_events: List[Event] = events_from_set_slots_step(step)
-        return ContinueFlowWithNextStep(events=initial_events + slot_events)
+        return _run_set_slot_step(initial_events, step)

     elif isinstance(step, NoOperationFlowStep):
         structlogger.debug("flow.step.run.no_operation")
         return ContinueFlowWithNextStep(events=initial_events)

     elif isinstance(step, EndFlowStep):
-        # this is the end of the flow, so we'll pop it from the stack
-        structlogger.debug("flow.step.run.flow_end")
-        current_frame = stack.pop()
-        trigger_pattern_completed(current_frame, stack, flows)
-        resumed_events = trigger_pattern_continue_interrupted(
-            current_frame, stack, flows
-        )
-        reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
-        return ContinueFlowWithNextStep(
-            events=initial_events + reset_events + resumed_events, has_flow_ended=True
-        )
+        return _run_end_step(flow, flows, initial_events, stack, tracker)

     else:
         raise FlowException(f"Unknown flow step type {type(step)}")


+def _run_end_step(
+    flow: Flow,
+    flows: FlowsList,
+    initial_events: List[Event],
+    stack: DialogueStack,
+    tracker: DialogueStateTracker,
+) -> FlowStepResult:
+    # this is the end of the flow, so we'll pop it from the stack
+    structlogger.debug("flow.step.run.flow_end")
+    current_frame = stack.pop()
+    trigger_pattern_completed(current_frame, stack, flows)
+    resumed_events = trigger_pattern_continue_interrupted(current_frame, stack, flows)
+    reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
+    return ContinueFlowWithNextStep(
+        events=initial_events + reset_events + resumed_events, has_flow_ended=True
+    )
+
+
+def _run_set_slot_step(
+    initial_events: List[Event], step: SetSlotsFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.slot")
+    slot_events: List[Event] = events_from_set_slots_step(step)
+    return ContinueFlowWithNextStep(events=initial_events + slot_events)
+
+
+def _run_call_step(
+    initial_events: List[Event], stack: DialogueStack, step: CallFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.call")
+    stack.push(
+        UserFlowStackFrame(
+            flow_id=step.call,
+            frame_type=FlowStackFrameType.CALL,
+        ),
+    )
+    return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_link_step(
+    initial_events: List[Event], stack: DialogueStack, step: LinkFlowStep
+) -> FlowStepResult:
+    structlogger.debug("flow.step.run.link")
+
+    if step.link == RASA_PATTERN_HUMAN_HANDOFF:
+        linked_stack_frame: DialogueStackFrame = HumanHandoffPatternFlowStackFrame()
+    else:
+        linked_stack_frame = UserFlowStackFrame(
+            flow_id=step.link,
+            frame_type=FlowStackFrameType.LINK,
+        )
+
+    stack.push(
+        linked_stack_frame,
+        # push this below the current stack frame so that we can
+        # complete the current flow first and then continue with the
+        # linked flow
+        index=-1,
+    )
+
+    return ContinueFlowWithNextStep(events=initial_events)
+
+
+def _run_action_step(
+    available_actions: List[str],
+    initial_events: List[Event],
+    stack: DialogueStack,
+    step: ActionFlowStep,
+) -> FlowStepResult:
+    context = {"context": stack.current_context()}
+    action_name = render_template_variables(step.action, context)
+
+    if action_name in available_actions:
+        structlogger.debug("flow.step.run.action", context=context)
+        return PauseFlowReturnPrediction(
+            FlowActionPrediction(action_name, 1.0, events=initial_events)
+        )
+    else:
+        if step.action != "validate_{{context.collect}}":
+            # do not log about non-existing validation actions of collect steps
+            utter_action_name = render_template_variables("{{context.utter}}", context)
+            if utter_action_name not in available_actions:
+                structlogger.warning("flow.step.run.action.unknown", action=action_name)
+        return ContinueFlowWithNextStep(events=initial_events)
+
+
 def _run_collect_information_step(
     available_actions: List[str],
     initial_events: List[Event],