rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (163)
  1. rasa/__main__.py +22 -12
  2. rasa/api.py +1 -1
  3. rasa/cli/arguments/default_arguments.py +1 -2
  4. rasa/cli/arguments/shell.py +5 -1
  5. rasa/cli/e2e_test.py +1 -1
  6. rasa/cli/evaluate.py +8 -8
  7. rasa/cli/inspect.py +4 -4
  8. rasa/cli/llm_fine_tuning.py +1 -1
  9. rasa/cli/project_templates/calm/config.yml +5 -7
  10. rasa/cli/project_templates/calm/endpoints.yml +8 -0
  11. rasa/cli/project_templates/tutorial/config.yml +8 -5
  12. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  13. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  14. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  15. rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
  16. rasa/cli/run.py +1 -1
  17. rasa/cli/scaffold.py +4 -2
  18. rasa/cli/utils.py +5 -0
  19. rasa/cli/x.py +8 -8
  20. rasa/constants.py +1 -1
  21. rasa/core/channels/channel.py +3 -0
  22. rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
  23. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  24. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
  25. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  26. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
  27. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
  28. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  29. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  30. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  31. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  32. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  33. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  34. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
  37. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  56. rasa/core/channels/inspector/dist/index.html +1 -1
  57. rasa/core/channels/inspector/src/App.tsx +1 -1
  58. rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
  59. rasa/core/channels/socketio.py +2 -1
  60. rasa/core/channels/telegram.py +1 -1
  61. rasa/core/channels/twilio.py +1 -1
  62. rasa/core/channels/voice_ready/jambonz.py +2 -2
  63. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  64. rasa/core/channels/voice_stream/asr/azure.py +122 -0
  65. rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
  66. rasa/core/channels/voice_stream/audio_bytes.py +1 -0
  67. rasa/core/channels/voice_stream/browser_audio.py +31 -8
  68. rasa/core/channels/voice_stream/call_state.py +23 -0
  69. rasa/core/channels/voice_stream/tts/azure.py +6 -2
  70. rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
  71. rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
  72. rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
  73. rasa/core/channels/voice_stream/util.py +4 -4
  74. rasa/core/channels/voice_stream/voice_channel.py +177 -39
  75. rasa/core/featurizers/single_state_featurizer.py +22 -1
  76. rasa/core/featurizers/tracker_featurizers.py +115 -18
  77. rasa/core/nlg/contextual_response_rephraser.py +16 -22
  78. rasa/core/persistor.py +86 -39
  79. rasa/core/policies/enterprise_search_policy.py +159 -60
  80. rasa/core/policies/flows/flow_executor.py +7 -4
  81. rasa/core/policies/intentless_policy.py +120 -22
  82. rasa/core/policies/ted_policy.py +58 -33
  83. rasa/core/policies/unexpected_intent_policy.py +15 -7
  84. rasa/core/processor.py +25 -0
  85. rasa/core/training/interactive.py +34 -35
  86. rasa/core/utils.py +8 -3
  87. rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
  88. rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
  89. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  90. rasa/dialogue_understanding/commands/utils.py +5 -0
  91. rasa/dialogue_understanding/generator/constants.py +4 -0
  92. rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
  93. rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
  94. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
  95. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
  96. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
  97. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  98. rasa/e2e_test/e2e_test_runner.py +4 -2
  99. rasa/e2e_test/utils/io.py +1 -1
  100. rasa/engine/validation.py +297 -7
  101. rasa/model_manager/config.py +15 -3
  102. rasa/model_manager/model_api.py +15 -7
  103. rasa/model_manager/runner_service.py +8 -6
  104. rasa/model_manager/socket_bridge.py +6 -3
  105. rasa/model_manager/trainer_service.py +7 -5
  106. rasa/model_manager/utils.py +28 -7
  107. rasa/model_service.py +6 -2
  108. rasa/model_training.py +2 -0
  109. rasa/nlu/classifiers/diet_classifier.py +38 -25
  110. rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
  111. rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
  112. rasa/nlu/extractors/crf_entity_extractor.py +93 -50
  113. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
  114. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
  115. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
  116. rasa/shared/constants.py +36 -3
  117. rasa/shared/core/constants.py +7 -0
  118. rasa/shared/core/domain.py +26 -0
  119. rasa/shared/core/flows/flow.py +5 -0
  120. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  121. rasa/shared/core/flows/utils.py +39 -0
  122. rasa/shared/core/flows/validation.py +96 -0
  123. rasa/shared/core/slots.py +5 -0
  124. rasa/shared/nlu/training_data/features.py +120 -2
  125. rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
  126. rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
  127. rasa/shared/providers/_configs/model_group_config.py +167 -0
  128. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  129. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  130. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  131. rasa/shared/providers/_configs/utils.py +16 -0
  132. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
  133. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  134. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  135. rasa/shared/providers/llm/_base_litellm_client.py +31 -30
  136. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  137. rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
  138. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  139. rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
  140. rasa/shared/providers/mappings.py +19 -0
  141. rasa/shared/providers/router/__init__.py +0 -0
  142. rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
  143. rasa/shared/providers/router/router_client.py +73 -0
  144. rasa/shared/utils/common.py +8 -0
  145. rasa/shared/utils/health_check.py +533 -0
  146. rasa/shared/utils/io.py +28 -6
  147. rasa/shared/utils/llm.py +350 -46
  148. rasa/shared/utils/yaml.py +11 -13
  149. rasa/studio/upload.py +64 -20
  150. rasa/telemetry.py +80 -17
  151. rasa/tracing/instrumentation/attribute_extractors.py +74 -17
  152. rasa/utils/io.py +0 -66
  153. rasa/utils/log_utils.py +9 -2
  154. rasa/utils/tensorflow/feature_array.py +366 -0
  155. rasa/utils/tensorflow/model_data.py +2 -193
  156. rasa/validator.py +70 -0
  157. rasa/version.py +1 -1
  158. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
  159. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
  160. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
  161. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
  162. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
  163. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
@@ -1,45 +1,44 @@
1
1
  import importlib.resources
2
2
  import json
3
- import os
4
3
  import re
5
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
6
-
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
7
5
  import dotenv
8
6
  import structlog
9
7
  from jinja2 import Template
10
8
  from pydantic import ValidationError
11
9
 
12
- from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
13
- _LangchainEmbeddingClientAdapter,
14
- )
15
- from rasa.shared.providers.llm.llm_client import LLMClient
16
-
17
10
  import rasa.shared.utils.io
18
- from rasa.telemetry import (
19
- track_enterprise_search_policy_predict,
20
- track_enterprise_search_policy_train_completed,
21
- track_enterprise_search_policy_train_started,
22
- )
23
- from rasa.shared.exceptions import RasaException
24
11
  from rasa.core.constants import (
25
12
  POLICY_MAX_HISTORY,
26
13
  POLICY_PRIORITY,
27
14
  SEARCH_POLICY_PRIORITY,
28
15
  UTTER_SOURCE_METADATA_KEY,
29
16
  )
17
+ from rasa.core.information_retrieval import (
18
+ InformationRetrieval,
19
+ SearchResult,
20
+ InformationRetrievalException,
21
+ create_from_endpoint_config,
22
+ )
23
+ from rasa.core.information_retrieval.faiss import FAISS_Store
30
24
  from rasa.core.policies.policy import Policy, PolicyPrediction
31
25
  from rasa.core.utils import AvailableEndpoints
32
- from rasa.dialogue_understanding.patterns.internal_error import (
33
- InternalErrorPatternFlowStackFrame,
26
+ from rasa.dialogue_understanding.generator.constants import (
27
+ LLM_CONFIG_KEY,
28
+ TRAINED_MODEL_NAME_CONFIG_KEY,
29
+ TRAINED_EMBEDDINGS_CONFIG_KEY,
34
30
  )
35
31
  from rasa.dialogue_understanding.patterns.cannot_handle import (
36
32
  CannotHandlePatternFlowStackFrame,
37
33
  )
38
- from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
34
+ from rasa.dialogue_understanding.patterns.internal_error import (
35
+ InternalErrorPatternFlowStackFrame,
36
+ )
39
37
  from rasa.dialogue_understanding.stack.frames import (
40
38
  DialogueStackFrame,
41
39
  SearchStackFrame,
42
40
  )
41
+ from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
43
42
  from rasa.engine.graph import ExecutionContext
44
43
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
45
44
  from rasa.engine.storage.resource import Resource
@@ -48,14 +47,13 @@ from rasa.graph_components.providers.forms_provider import Forms
48
47
  from rasa.graph_components.providers.responses_provider import Responses
49
48
  from rasa.shared.constants import (
50
49
  EMBEDDINGS_CONFIG_KEY,
51
- LLM_API_HEALTH_CHECK_ENV_VAR,
52
- LLM_CONFIG_KEY,
53
50
  MODEL_CONFIG_KEY,
54
- MODEL_NAME_CONFIG_KEY,
55
51
  PROMPT_CONFIG_KEY,
56
52
  PROVIDER_CONFIG_KEY,
57
53
  OPENAI_PROVIDER,
58
54
  TIMEOUT_CONFIG_KEY,
55
+ MODEL_NAME_CONFIG_KEY,
56
+ MODEL_GROUP_CONFIG_KEY,
59
57
  )
60
58
  from rasa.shared.core.constants import (
61
59
  ACTION_CANCEL_FLOW,
@@ -66,7 +64,12 @@ from rasa.shared.core.domain import Domain
66
64
  from rasa.shared.core.events import Event, UserUttered, BotUttered
67
65
  from rasa.shared.core.generator import TrackerWithCachedStates
68
66
  from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
67
+ from rasa.shared.exceptions import RasaException, FileIOException
69
68
  from rasa.shared.nlu.training_data.training_data import TrainingData
69
+ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
70
+ _LangchainEmbeddingClientAdapter,
71
+ )
72
+ from rasa.shared.providers.llm.llm_client import LLMClient
70
73
  from rasa.shared.utils.cli import print_error_and_exit
71
74
  from rasa.shared.utils.io import deep_container_fingerprint
72
75
  from rasa.shared.utils.llm import (
@@ -74,18 +77,21 @@ from rasa.shared.utils.llm import (
74
77
  DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
75
78
  embedder_factory,
76
79
  get_prompt_template,
77
- llm_api_health_check,
78
80
  llm_factory,
79
81
  sanitize_message_for_prompt,
80
82
  tracker_as_readable_transcript,
81
- try_instantiate_llm_client,
83
+ resolve_model_client_config,
82
84
  )
83
- from rasa.core.information_retrieval.faiss import FAISS_Store
84
- from rasa.core.information_retrieval import (
85
- InformationRetrieval,
86
- SearchResult,
87
- InformationRetrievalException,
88
- create_from_endpoint_config,
85
+ from rasa.shared.utils.health_check import (
86
+ perform_training_time_llm_health_check,
87
+ perform_training_time_embeddings_health_check,
88
+ perform_inference_time_llm_health_check,
89
+ perform_inference_time_embeddings_health_check,
90
+ )
91
+ from rasa.telemetry import (
92
+ track_enterprise_search_policy_predict,
93
+ track_enterprise_search_policy_train_completed,
94
+ track_enterprise_search_policy_train_started,
89
95
  )
90
96
 
91
97
  if TYPE_CHECKING:
@@ -130,6 +136,7 @@ DEFAULT_EMBEDDINGS_CONFIG = {
130
136
  }
131
137
 
132
138
  ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
139
+ ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
133
140
 
134
141
  SEARCH_RESULTS_METADATA_KEY = "search_results"
135
142
  SEARCH_QUERY_METADATA_KEY = "search_query"
@@ -200,24 +207,35 @@ class EnterpriseSearchPolicy(Policy):
200
207
  """Constructs a new Policy object."""
201
208
  super().__init__(config, model_storage, resource, execution_context, featurizer)
202
209
 
210
+ # Resolve LLM config
211
+ self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
212
+ self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
213
+ )
214
+ # Resolve embeddings config
215
+ self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
216
+ self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
217
+ )
218
+
203
219
  # Vector store object and configuration
204
220
  self.vector_store = vector_store
205
- self.vector_store_config = config.get(
221
+ self.vector_store_config = self.config.get(
206
222
  VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
207
223
  )
224
+
208
225
  # Embeddings configuration for encoding the search query
209
- self.embeddings_config = self.config.get(
210
- EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
226
+ self.embeddings_config = (
227
+ self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
211
228
  )
229
+
230
+ # LLM Configuration for response generation
231
+ self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
232
+
212
233
  # Maximum number of turns to include in the prompt
213
234
  self.max_history = self.config.get(POLICY_MAX_HISTORY)
214
235
 
215
236
  # Maximum number of messages to include in the search query
216
237
  self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
217
238
 
218
- # LLM Configuration for response generation
219
- self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
220
-
221
239
  # boolean to enable/disable tracing of prompt tokens
222
240
  self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
223
241
 
@@ -246,9 +264,16 @@ class EnterpriseSearchPolicy(Policy):
246
264
  Returns:
247
265
  The embedder.
248
266
  """
267
+ # Copy the config so original config is not modified
268
+ config = config.copy()
269
+ # Resolve config and instantiate the embedding client
270
+ config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
271
+ config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
272
+ )
249
273
  client = embedder_factory(
250
274
  config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
251
275
  )
276
+ # Wrap the embedding client in the adapter
252
277
  return _LangchainEmbeddingClientAdapter(client)
253
278
 
254
279
  def train( # type: ignore[override]
@@ -294,19 +319,10 @@ class EnterpriseSearchPolicy(Policy):
294
319
  f"required environment variables. Error: {e}"
295
320
  )
296
321
 
297
- # validate llm configuration
298
- llm_client = try_instantiate_llm_client(
299
- self.config.get(LLM_CONFIG_KEY),
300
- DEFAULT_LLM_CONFIG,
301
- "enterprise_search_policy.train",
302
- EnterpriseSearchPolicy.__name__,
303
- )
304
- if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
305
- llm_api_health_check(
306
- llm_client,
307
- "enterprise_search_policy.train",
308
- EnterpriseSearchPolicy.__name__,
309
- )
322
+ (
323
+ self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
324
+ self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
325
+ ) = self._perform_training_time_health_checks()
310
326
 
311
327
  if store_type == DEFAULT_VECTOR_STORE_TYPE:
312
328
  logger.info("enterprise_search_policy.train.faiss")
@@ -326,9 +342,13 @@ class EnterpriseSearchPolicy(Policy):
326
342
  embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
327
343
  embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
328
344
  or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
345
+ embeddings_model_group_id=self.embeddings_config.get(
346
+ MODEL_GROUP_CONFIG_KEY
347
+ ),
329
348
  llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
330
349
  llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
331
350
  or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
351
+ llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
332
352
  citation_enabled=self.citation_enabled,
333
353
  )
334
354
  self.persist()
@@ -340,6 +360,9 @@ class EnterpriseSearchPolicy(Policy):
340
360
  rasa.shared.utils.io.write_text_file(
341
361
  self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
342
362
  )
363
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
364
+ path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
365
+ )
343
366
 
344
367
  def _prepare_slots_for_template(
345
368
  self, tracker: DialogueStateTracker
@@ -520,9 +543,13 @@ class EnterpriseSearchPolicy(Policy):
520
543
  embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
521
544
  embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
522
545
  or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
546
+ embeddings_model_group_id=self.embeddings_config.get(
547
+ MODEL_GROUP_CONFIG_KEY
548
+ ),
523
549
  llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
524
550
  llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
525
551
  or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
552
+ llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
526
553
  citation_enabled=self.citation_enabled,
527
554
  )
528
555
  return self._create_prediction(
@@ -672,11 +699,26 @@ class EnterpriseSearchPolicy(Policy):
672
699
  ) -> "EnterpriseSearchPolicy":
673
700
  """Loads a trained policy (see parent class for full docstring)."""
674
701
  prompt_template = None
702
+ persisted_config = None
703
+ try:
704
+ with model_storage.read_from(resource) as path:
705
+ prompt_template = rasa.shared.utils.io.read_file(
706
+ path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
707
+ )
708
+ persisted_config = rasa.shared.utils.io.read_json_file(
709
+ path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME
710
+ )
711
+ except (FileNotFoundError, FileIOException) as e:
712
+ logger.warning(
713
+ "enterprise_search_policy.load.failed", error=e, resource=resource.name
714
+ )
715
+
675
716
  store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
676
717
  VECTOR_STORE_TYPE_PROPERTY
677
718
  )
678
719
 
679
720
  embeddings = cls._create_plain_embedder(config)
721
+
680
722
  logger.info("enterprise_search_policy.load", config=config)
681
723
  if store_type == DEFAULT_VECTOR_STORE_TYPE:
682
724
  # if a vector store is not specified,
@@ -694,18 +736,8 @@ class EnterpriseSearchPolicy(Policy):
694
736
  config_type=store_type,
695
737
  embeddings=embeddings,
696
738
  ) # type: ignore
697
- try:
698
- with model_storage.read_from(resource) as path:
699
- prompt_template = rasa.shared.utils.io.read_file(
700
- path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
701
- )
702
739
 
703
- except (FileNotFoundError, FileNotFoundError) as e:
704
- logger.warning(
705
- "enterprise_search_policy.load.failed", error=e, resource=resource.name
706
- )
707
-
708
- return cls(
740
+ policy = cls(
709
741
  config,
710
742
  model_storage,
711
743
  resource,
@@ -714,6 +746,14 @@ class EnterpriseSearchPolicy(Policy):
714
746
  prompt_template=prompt_template,
715
747
  )
716
748
 
749
+ cls._perform_inference_time_health_checks(
750
+ persisted_config,
751
+ policy.config.get(LLM_CONFIG_KEY),
752
+ policy.config.get(EMBEDDINGS_CONFIG_KEY),
753
+ )
754
+
755
+ return policy
756
+
717
757
  @classmethod
718
758
  def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
719
759
  """This is required only for local knowledge base types.
@@ -745,14 +785,23 @@ class EnterpriseSearchPolicy(Policy):
745
785
 
746
786
  @classmethod
747
787
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
748
- """Add a fingerprint of the knowledge base and prompt template for the graph."""
788
+ """Add a fingerprint of enterprise search policy for the graph."""
749
789
  local_knowledge_data = cls._get_local_knowledge_data(config)
750
790
 
751
791
  prompt_template = get_prompt_template(
752
792
  config.get(PROMPT_CONFIG_KEY),
753
793
  DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
754
794
  )
755
- return deep_container_fingerprint([prompt_template, local_knowledge_data])
795
+
796
+ llm_config = resolve_model_client_config(
797
+ config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
798
+ )
799
+ embedding_config = resolve_model_client_config(
800
+ config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
801
+ )
802
+ return deep_container_fingerprint(
803
+ [prompt_template, local_knowledge_data, llm_config, embedding_config]
804
+ )
756
805
 
757
806
  @staticmethod
758
807
  def post_process_citations(llm_answer: str) -> str:
@@ -844,3 +893,53 @@ class EnterpriseSearchPolicy(Policy):
844
893
  joined_sources = "\n".join(new_sources)
845
894
 
846
895
  return joined_answer + joined_sources
896
+
897
+ def _perform_training_time_health_checks(
898
+ self,
899
+ ) -> Tuple[Optional[str], Optional[str]]:
900
+ train_model_name = perform_training_time_llm_health_check(
901
+ self.config.get(LLM_CONFIG_KEY),
902
+ DEFAULT_LLM_CONFIG,
903
+ "enterprise_search_policy.train",
904
+ EnterpriseSearchPolicy.__name__,
905
+ )
906
+ train_embedding_name = perform_training_time_embeddings_health_check(
907
+ self.config.get(EMBEDDINGS_CONFIG_KEY),
908
+ DEFAULT_EMBEDDINGS_CONFIG,
909
+ "enterprise_search_policy.train",
910
+ EnterpriseSearchPolicy.__name__,
911
+ )
912
+ return train_model_name, train_embedding_name
913
+
914
+ @classmethod
915
+ def _perform_inference_time_health_checks(
916
+ cls,
917
+ persisted_config: Optional[Dict[str, Any]],
918
+ resolved_llm_config: Optional[Dict[str, Any]],
919
+ resolved_embeddings_config: Optional[Dict[str, Any]],
920
+ ) -> None:
921
+ train_model_name = (
922
+ persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
923
+ if persisted_config
924
+ else None
925
+ )
926
+ perform_inference_time_llm_health_check(
927
+ resolved_llm_config,
928
+ DEFAULT_LLM_CONFIG,
929
+ train_model_name,
930
+ "enterprise_search_policy.load",
931
+ EnterpriseSearchPolicy.__name__,
932
+ )
933
+
934
+ train_embeddings_name = (
935
+ persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
936
+ if persisted_config
937
+ else None
938
+ )
939
+ perform_inference_time_embeddings_health_check(
940
+ resolved_embeddings_config,
941
+ DEFAULT_EMBEDDINGS_CONFIG,
942
+ train_embeddings_name,
943
+ "enterprise_search_policy.load",
944
+ EnterpriseSearchPolicy.__name__,
945
+ )
@@ -330,24 +330,27 @@ def reset_scoped_slots(
330
330
  events: List[Event] = []
331
331
 
332
332
  not_resettable_slot_names = set()
333
+ flow_persistable_slots = current_flow.persisted_slots
333
334
 
334
335
  for step in current_flow.steps_with_calls_resolved:
335
336
  if isinstance(step, CollectInformationFlowStep):
336
337
  # reset all slots scoped to the flow
337
- if step.reset_after_flow_ends:
338
- _reset_slot(step.collect, tracker)
338
+ slot_name = step.collect
339
+ if step.reset_after_flow_ends and slot_name not in flow_persistable_slots:
340
+ _reset_slot(slot_name, tracker)
339
341
  else:
340
- not_resettable_slot_names.add(step.collect)
342
+ not_resettable_slot_names.add(slot_name)
341
343
 
342
344
  # slots set by the set slots step should be reset after the flow ends
343
345
  # unless they are also used in a collect step where `reset_after_flow_ends`
344
- # is set to `False`
346
+ # is set to `False` or set in the `persisted_slots` list.
345
347
  resettable_set_slots = [
346
348
  slot["key"]
347
349
  for step in current_flow.steps_with_calls_resolved
348
350
  if isinstance(step, SetSlotsFlowStep)
349
351
  for slot in step.slots
350
352
  if slot["key"] not in not_resettable_slot_names
353
+ and slot["key"] not in flow_persistable_slots
351
354
  ]
352
355
 
353
356
  for name in resettable_set_slots: