rasa-pro 3.11.0rc1 (py3-none-any wheel) → rasa-pro 3.11.0rc2 (py3-none-any wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. (The original page linked to further details here.)

Files changed (52):
  1. rasa/cli/inspect.py +2 -0
  2. rasa/cli/studio/studio.py +18 -8
  3. rasa/core/actions/action_repeat_bot_messages.py +17 -0
  4. rasa/core/channels/channel.py +17 -0
  5. rasa/core/channels/voice_ready/audiocodes.py +12 -0
  6. rasa/core/channels/voice_ready/jambonz.py +13 -2
  7. rasa/core/channels/voice_ready/twilio_voice.py +6 -21
  8. rasa/core/channels/voice_stream/voice_channel.py +13 -1
  9. rasa/core/nlg/contextual_response_rephraser.py +18 -10
  10. rasa/core/policies/enterprise_search_policy.py +27 -67
  11. rasa/core/policies/intentless_policy.py +25 -67
  12. rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
  13. rasa/dialogue_understanding/generator/constants.py +0 -2
  14. rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
  15. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
  16. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
  17. rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
  18. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +24 -21
  19. rasa/dialogue_understanding/processor/command_processor.py +21 -1
  20. rasa/e2e_test/e2e_test_case.py +85 -6
  21. rasa/engine/validation.py +57 -41
  22. rasa/model_service.py +3 -0
  23. rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
  24. rasa/server.py +3 -1
  25. rasa/shared/core/flows/flows_list.py +5 -1
  26. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
  27. rasa/shared/providers/llm/_base_litellm_client.py +6 -1
  28. rasa/shared/utils/health_check/__init__.py +0 -0
  29. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  30. rasa/shared/utils/health_check/health_check.py +256 -0
  31. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  32. rasa/shared/utils/llm.py +5 -2
  33. rasa/shared/utils/yaml.py +102 -62
  34. rasa/studio/auth.py +3 -5
  35. rasa/studio/config.py +13 -4
  36. rasa/studio/constants.py +1 -0
  37. rasa/studio/data_handler.py +10 -3
  38. rasa/studio/upload.py +21 -10
  39. rasa/telemetry.py +12 -0
  40. rasa/tracing/config.py +2 -0
  41. rasa/tracing/instrumentation/attribute_extractors.py +20 -0
  42. rasa/tracing/instrumentation/instrumentation.py +121 -0
  43. rasa/utils/common.py +5 -0
  44. rasa/utils/io.py +8 -16
  45. rasa/utils/sanic_error_handler.py +32 -0
  46. rasa/version.py +1 -1
  47. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +3 -2
  48. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +51 -47
  49. rasa/shared/utils/health_check.py +0 -533
  50. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
  51. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
  52. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
@@ -18,10 +18,6 @@ from rasa.core.constants import (
18
18
  UTTER_SOURCE_METADATA_KEY,
19
19
  )
20
20
  from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
21
- from rasa.dialogue_understanding.generator.constants import (
22
- TRAINED_MODEL_NAME_CONFIG_KEY,
23
- TRAINED_EMBEDDINGS_CONFIG_KEY,
24
- )
25
21
  from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
26
22
  from rasa.dialogue_understanding.stack.frames import (
27
23
  ChitChatStackFrame,
@@ -64,6 +60,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
64
60
  _LangchainEmbeddingClientAdapter,
65
61
  )
66
62
  from rasa.shared.providers.llm.llm_client import LLMClient
63
+ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
64
+ EmbeddingsHealthCheckMixin,
65
+ )
66
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
67
67
  from rasa.shared.utils.io import deep_container_fingerprint
68
68
  from rasa.shared.utils.llm import (
69
69
  AI,
@@ -79,12 +79,6 @@ from rasa.shared.utils.llm import (
79
79
  tracker_as_readable_transcript,
80
80
  resolve_model_client_config,
81
81
  )
82
- from rasa.shared.utils.health_check import (
83
- perform_training_time_llm_health_check,
84
- perform_training_time_embeddings_health_check,
85
- perform_inference_time_llm_health_check,
86
- perform_inference_time_embeddings_health_check,
87
- )
88
82
  from rasa.utils.log_utils import log_llm
89
83
  from rasa.utils.ml_utils import (
90
84
  extract_ai_response_examples,
@@ -383,7 +377,7 @@ def conversation_as_prompt(conversation: Conversation) -> str:
383
377
  @DefaultV1Recipe.register(
384
378
  DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
385
379
  )
386
- class IntentlessPolicy(Policy):
380
+ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
387
381
  """Policy which uses a language model to generate the next action.
388
382
 
389
383
  The policy uses the OpenAI API to generate the next action based on the
@@ -516,10 +510,8 @@ class IntentlessPolicy(Policy):
516
510
  A policy must return its resource locator so that potential children nodes
517
511
  can load the policy from the resource.
518
512
  """
519
- (
520
- self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
521
- self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
522
- ) = self._perform_training_time_health_checks()
513
+ # Perform health checks of both LLM and embeddings client configs
514
+ self._perform_health_checks(self.config, "intentless_policy.train")
523
515
 
524
516
  responses = filter_responses(responses, forms, flows or FlowsList([]))
525
517
  telemetry.track_intentless_policy_train()
@@ -952,10 +944,13 @@ class IntentlessPolicy(Policy):
952
944
  **kwargs: Any,
953
945
  ) -> "IntentlessPolicy":
954
946
  """Loads a trained policy (see parent class for full docstring)."""
947
+
948
+ # Perform health checks of both LLM and embeddings client configs
949
+ cls._perform_health_checks(config, "intentless_policy.load")
950
+
955
951
  responses_docsearch = None
956
952
  samples_docsearch = None
957
953
  prompt_template = None
958
- persisted_config = None
959
954
  try:
960
955
  with model_storage.read_from(resource) as path:
961
956
  responses_docsearch = load_faiss_vector_store(
@@ -973,15 +968,12 @@ class IntentlessPolicy(Policy):
973
968
  prompt_template = rasa.shared.utils.io.read_file(
974
969
  path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
975
970
  )
976
- persisted_config = rasa.shared.utils.io.read_json_file(
977
- path / INTENTLESS_CONFIG_FILE_NAME
978
- )
979
971
  except (ValueError, FileNotFoundError, FileIOException) as e:
980
972
  structlogger.warning(
981
973
  "intentless_policy.load.failed", error=e, resource_name=resource.name
982
974
  )
983
975
 
984
- policy = cls(
976
+ return cls(
985
977
  config,
986
978
  model_storage,
987
979
  resource,
@@ -991,14 +983,6 @@ class IntentlessPolicy(Policy):
991
983
  prompt_template=prompt_template,
992
984
  )
993
985
 
994
- cls._perform_inference_time_health_checks(
995
- persisted_config,
996
- policy.config.get(LLM_CONFIG_KEY),
997
- policy.config.get(EMBEDDINGS_CONFIG_KEY),
998
- )
999
-
1000
- return policy
1001
-
1002
986
  @classmethod
1003
987
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
1004
988
  """Add a fingerprint of intentless policy for the graph."""
@@ -1018,52 +1002,26 @@ class IntentlessPolicy(Policy):
1018
1002
  [prompt_template, llm_config, embedding_config]
1019
1003
  )
1020
1004
 
1021
- def _perform_training_time_health_checks(
1022
- self,
1023
- ) -> Tuple[Optional[str], Optional[str]]:
1024
- train_model_name = perform_training_time_llm_health_check(
1025
- self.config.get(LLM_CONFIG_KEY),
1026
- DEFAULT_LLM_CONFIG,
1027
- "intentless_policy.train",
1028
- IntentlessPolicy.__name__,
1029
- )
1030
- train_embedding_name = perform_training_time_embeddings_health_check(
1031
- self.config.get(EMBEDDINGS_CONFIG_KEY),
1032
- DEFAULT_EMBEDDINGS_CONFIG,
1033
- "intentless_policy.train",
1034
- IntentlessPolicy.__name__,
1035
- )
1036
- return train_model_name, train_embedding_name
1037
-
1038
1005
  @classmethod
1039
- def _perform_inference_time_health_checks(
1040
- cls,
1041
- persisted_config: Optional[Dict[str, Any]],
1042
- resolved_llm_config: Optional[Dict[str, Any]],
1043
- resolved_embeddings_config: Optional[Dict[str, Any]],
1006
+ def _perform_health_checks(
1007
+ cls, config: Dict[Text, Any], log_source_method: str
1044
1008
  ) -> None:
1045
- train_model_name = (
1046
- persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
1047
- if persisted_config
1048
- else None
1049
- )
1050
- perform_inference_time_llm_health_check(
1051
- resolved_llm_config,
1009
+ # Perform health check of the LLM client config
1010
+ llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
1011
+ cls.perform_llm_health_check(
1012
+ llm_config,
1052
1013
  DEFAULT_LLM_CONFIG,
1053
- train_model_name,
1054
- "intentless_policy.load",
1014
+ log_source_method,
1055
1015
  IntentlessPolicy.__name__,
1056
1016
  )
1057
1017
 
1058
- train_embeddings_name = (
1059
- persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
1060
- if persisted_config
1061
- else None
1018
+ # Perform health check of the embeddings client config
1019
+ embeddings_config = resolve_model_client_config(
1020
+ config.get(EMBEDDINGS_CONFIG_KEY, {})
1062
1021
  )
1063
- perform_inference_time_embeddings_health_check(
1064
- resolved_embeddings_config,
1022
+ cls.perform_embeddings_health_check(
1023
+ embeddings_config,
1065
1024
  DEFAULT_EMBEDDINGS_CONFIG,
1066
- train_embeddings_name,
1067
- "intentless_policy.load",
1025
+ log_source_method,
1068
1026
  IntentlessPolicy.__name__,
1069
1027
  )
@@ -17,7 +17,6 @@ from rasa.dialogue_understanding.commands import Command, SetSlotCommand
17
17
  from rasa.dialogue_understanding.commands.noop_command import NoopCommand
18
18
  from rasa.dialogue_understanding.generator.constants import (
19
19
  LLM_CONFIG_KEY,
20
- TRAINED_MODEL_NAME_CONFIG_KEY,
21
20
  )
22
21
  from rasa.engine.graph import ExecutionContext, GraphComponent
23
22
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
@@ -36,6 +35,7 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
36
35
  from rasa.shared.nlu.constants import COMMANDS, TEXT
37
36
  from rasa.shared.nlu.training_data.message import Message
38
37
  from rasa.shared.nlu.training_data.training_data import TrainingData
38
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
39
39
  from rasa.shared.utils.io import deep_container_fingerprint
40
40
  from rasa.shared.utils.llm import (
41
41
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
@@ -43,10 +43,6 @@ from rasa.shared.utils.llm import (
43
43
  llm_factory,
44
44
  resolve_model_client_config,
45
45
  )
46
- from rasa.shared.utils.health_check import (
47
- perform_training_time_llm_health_check,
48
- perform_inference_time_llm_health_check,
49
- )
50
46
  from rasa.utils.log_utils import log_llm
51
47
 
52
48
  LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
@@ -80,7 +76,7 @@ structlogger = structlog.get_logger()
80
76
  ],
81
77
  is_trainable=True,
82
78
  )
83
- class LLMBasedRouter(GraphComponent):
79
+ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
84
80
  @staticmethod
85
81
  def get_default_config() -> Dict[str, Any]:
86
82
  """The component's default config (see parent class for full docstring)."""
@@ -144,13 +140,11 @@ class LLMBasedRouter(GraphComponent):
144
140
 
145
141
  def train(self, training_data: TrainingData) -> Resource:
146
142
  """Train the intent classifier on a data set."""
147
- self.config[TRAINED_MODEL_NAME_CONFIG_KEY] = (
148
- perform_training_time_llm_health_check(
149
- self.config.get(LLM_CONFIG_KEY),
150
- DEFAULT_LLM_CONFIG,
151
- "llm_based_router.train",
152
- LLMBasedRouter.__name__,
153
- )
143
+ self.perform_llm_health_check(
144
+ self.config.get(LLM_CONFIG_KEY),
145
+ DEFAULT_LLM_CONFIG,
146
+ "llm_based_router.train",
147
+ LLMBasedRouter.__name__,
154
148
  )
155
149
 
156
150
  self.persist()
@@ -166,37 +160,28 @@ class LLMBasedRouter(GraphComponent):
166
160
  **kwargs: Any,
167
161
  ) -> "LLMBasedRouter":
168
162
  """Loads trained component (see parent class for full docstring)."""
163
+
164
+ # Perform health check on the resolved LLM client config
165
+ llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
166
+ cls.perform_llm_health_check(
167
+ llm_config,
168
+ DEFAULT_LLM_CONFIG,
169
+ "llm_based_router.load",
170
+ LLMBasedRouter.__name__,
171
+ )
172
+
169
173
  prompt_template = None
170
- persisted_config = None
171
174
  try:
172
175
  with model_storage.read_from(resource) as path:
173
176
  prompt_template = rasa.shared.utils.io.read_file(
174
177
  path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
175
178
  )
176
- persisted_config = rasa.shared.utils.io.read_json_file(
177
- path / LLM_BASED_ROUTER_CONFIG_FILE_NAME
178
- )
179
179
  except (FileNotFoundError, FileIOException) as e:
180
180
  structlogger.warning(
181
181
  "llm_based_router.load.failed", error=e, resource=resource.name
182
182
  )
183
183
 
184
- router = cls(config, model_storage, resource, prompt_template=prompt_template)
185
-
186
- train_model_name = (
187
- persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
188
- if persisted_config
189
- else None
190
- )
191
- perform_inference_time_llm_health_check(
192
- router.config.get(LLM_CONFIG_KEY),
193
- DEFAULT_LLM_CONFIG,
194
- train_model_name,
195
- "llm_based_router.load",
196
- LLMBasedRouter.__name__,
197
- )
198
-
199
- return router
184
+ return cls(config, model_storage, resource, prompt_template=prompt_template)
200
185
 
201
186
  @classmethod
202
187
  def create(
@@ -18,8 +18,6 @@ DEFAULT_LLM_CONFIG = {
18
18
  }
19
19
 
20
20
  LLM_CONFIG_KEY = "llm"
21
- TRAINED_MODEL_NAME_CONFIG_KEY = "trained_llm_model_name"
22
- TRAINED_EMBEDDINGS_CONFIG_KEY = "trained_embeddings_model_name"
23
21
  USER_INPUT_CONFIG_KEY = "user_input"
24
22
 
25
23
  FLOW_RETRIEVAL_KEY = "flow_retrieval"
@@ -27,12 +27,9 @@ from langchain.schema.embeddings import Embeddings
27
27
  from langchain_community.vectorstores.faiss import FAISS
28
28
  from langchain_community.vectorstores.utils import DistanceStrategy
29
29
 
30
- from rasa.dialogue_understanding.generator.constants import (
31
- TRAINED_EMBEDDINGS_CONFIG_KEY,
32
- )
33
30
  from rasa.engine.storage.resource import Resource
34
31
  from rasa.engine.storage.storage import ModelStorage
35
-
32
+ import rasa.shared.utils.io
36
33
  from rasa.shared.constants import (
37
34
  EMBEDDINGS_CONFIG_KEY,
38
35
  PROVIDER_CONFIG_KEY,
@@ -41,12 +38,15 @@ from rasa.shared.constants import (
41
38
  from rasa.shared.core.domain import Domain
42
39
  from rasa.shared.core.flows import FlowsList
43
40
  from rasa.shared.core.trackers import DialogueStateTracker
44
- from rasa.shared.exceptions import ProviderClientAPIException, FileIOException
41
+ from rasa.shared.exceptions import ProviderClientAPIException
45
42
  from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
46
43
  from rasa.shared.nlu.training_data.message import Message
47
44
  from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
48
45
  _LangchainEmbeddingClientAdapter,
49
46
  )
47
+ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
48
+ EmbeddingsHealthCheckMixin,
49
+ )
50
50
  from rasa.shared.utils.llm import (
51
51
  tracker_as_readable_transcript,
52
52
  embedder_factory,
@@ -56,11 +56,6 @@ from rasa.shared.utils.llm import (
56
56
  allowed_values_for_slot,
57
57
  resolve_model_client_config,
58
58
  )
59
- from rasa.shared.utils.health_check import (
60
- perform_training_time_embeddings_health_check,
61
- perform_inference_time_embeddings_health_check,
62
- )
63
- from rasa.shared.utils.io import dump_obj_as_json_to_file, read_json_file
64
59
 
65
60
  DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
66
61
  "rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
@@ -85,7 +80,7 @@ DEFAULT_SHOULD_EMBED_SLOTS = True
85
80
  structlogger = structlog.get_logger()
86
81
 
87
82
 
88
- class FlowRetrieval:
83
+ class FlowRetrieval(EmbeddingsHealthCheckMixin):
89
84
  @classmethod
90
85
  def get_default_config(cls) -> Dict[str, Any]:
91
86
  """The default config for the flow retrieval."""
@@ -94,7 +89,6 @@ class FlowRetrieval:
94
89
  MAX_FLOWS_FROM_SEMANTIC_SEARCH_KEY: DEFAULT_MAX_FLOWS_FROM_SEMANTIC_SEARCH,
95
90
  TURNS_TO_EMBED_KEY: DEFAULT_TURNS_TO_EMBED,
96
91
  SHOULD_EMBED_SLOTS_KEY: DEFAULT_SHOULD_EMBED_SLOTS,
97
- TRAINED_EMBEDDINGS_CONFIG_KEY: None,
98
92
  }
99
93
 
100
94
  def __init__(
@@ -147,16 +141,6 @@ class FlowRetrieval:
147
141
 
148
142
  return config
149
143
 
150
- def train(self) -> None:
151
- self.config[TRAINED_EMBEDDINGS_CONFIG_KEY] = (
152
- perform_training_time_embeddings_health_check(
153
- self.config.get(EMBEDDINGS_CONFIG_KEY),
154
- DEFAULT_EMBEDDINGS_CONFIG,
155
- "flow_retrieval.train",
156
- FlowRetrieval.__name__,
157
- )
158
- )
159
-
160
144
  @classmethod
161
145
  def load(
162
146
  cls,
@@ -166,6 +150,18 @@ class FlowRetrieval:
166
150
  **kwargs: Any,
167
151
  ) -> "FlowRetrieval":
168
152
  """Load flow retrieval with previously populated FAISS vector store."""
153
+
154
+ # Perform health check on resolved embedding client config
155
+ embeddings_config = resolve_model_client_config(
156
+ config.get(EMBEDDINGS_CONFIG_KEY, {})
157
+ )
158
+ cls.perform_embeddings_health_check(
159
+ embeddings_config,
160
+ DEFAULT_EMBEDDINGS_CONFIG,
161
+ "flow_retrieval.load",
162
+ FlowRetrieval.__name__,
163
+ )
164
+
169
165
  # initialize base flow retrieval
170
166
  flow_retrieval = FlowRetrieval(config, model_storage, resource)
171
167
  # load vector store
@@ -174,30 +170,6 @@ class FlowRetrieval:
174
170
  )
175
171
  flow_retrieval.vector_store = vector_store
176
172
 
177
- persisted_config = None
178
- try:
179
- with model_storage.read_from(resource) as path:
180
- persisted_config = read_json_file(
181
- path / FLOW_RETRIEVAL_CONFIG_FILE_NAME
182
- )
183
- except (FileNotFoundError, FileIOException) as e:
184
- structlogger.warning(
185
- "flow_retrieval.load.failed", error=e, resource=resource.name
186
- )
187
-
188
- train_embeddings_name = (
189
- persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
190
- if persisted_config
191
- else None
192
- )
193
- perform_inference_time_embeddings_health_check(
194
- flow_retrieval.config.get(EMBEDDINGS_CONFIG_KEY),
195
- DEFAULT_EMBEDDINGS_CONFIG,
196
- train_embeddings_name,
197
- "flow_retrieval.load",
198
- FlowRetrieval.__name__,
199
- )
200
-
201
173
  return flow_retrieval
202
174
 
203
175
  @classmethod
@@ -243,10 +215,7 @@ class FlowRetrieval:
243
215
 
244
216
  def persist(self) -> None:
245
217
  self._persist_vector_store()
246
- with self._model_storage.write_to(self._resource) as path:
247
- dump_obj_as_json_to_file(
248
- path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
249
- )
218
+ self._persist_config()
250
219
 
251
220
  def _persist_vector_store(self) -> None:
252
221
  """Persists the FAISS vector store."""
@@ -259,6 +228,12 @@ class FlowRetrieval:
259
228
  event_info="Vector store is None, not persisted.",
260
229
  )
261
230
 
231
+ def _persist_config(self) -> None:
232
+ with self._model_storage.write_to(self._resource) as path:
233
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
234
+ path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
235
+ )
236
+
262
237
  def populate(self, flows: FlowsList, domain: Domain) -> None:
263
238
  """Populates the vector store with embeddings generated from
264
239
  documents based on the flow descriptions, and flow slots
@@ -268,6 +243,14 @@ class FlowRetrieval:
268
243
  flows: List of flows to populate the vector store with.
269
244
  domain: The domain containing relevant slot information.
270
245
  """
246
+ # Perform health check before populating the vector store with flows
247
+ self.perform_embeddings_health_check(
248
+ self.config.get(EMBEDDINGS_CONFIG_KEY),
249
+ DEFAULT_EMBEDDINGS_CONFIG,
250
+ "flow_retrieval.train",
251
+ FlowRetrieval.__name__,
252
+ )
253
+
271
254
  flows_to_embedd = flows.exclude_link_only_flows()
272
255
  embeddings = self._create_embedder(self.config)
273
256
  documents = self._generate_flow_documents(flows_to_embedd, domain)
@@ -17,7 +17,6 @@ from rasa.dialogue_understanding.generator.constants import (
17
17
  FLOW_RETRIEVAL_KEY,
18
18
  FLOW_RETRIEVAL_ACTIVE_KEY,
19
19
  FLOW_RETRIEVAL_FLOW_THRESHOLD,
20
- TRAINED_MODEL_NAME_CONFIG_KEY,
21
20
  )
22
21
  from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
23
22
  from rasa.engine.graph import GraphComponent, ExecutionContext
@@ -33,27 +32,26 @@ from rasa.shared.exceptions import ProviderClientAPIException
33
32
  from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
34
33
  from rasa.shared.nlu.training_data.message import Message
35
34
  from rasa.shared.nlu.training_data.training_data import TrainingData
35
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
36
36
  from rasa.shared.utils.llm import (
37
37
  allowed_values_for_slot,
38
38
  llm_factory,
39
39
  resolve_model_client_config,
40
40
  )
41
- from rasa.shared.utils.health_check import perform_training_time_llm_health_check
42
41
  from rasa.utils.log_utils import log_llm
43
42
 
44
43
  structlogger = structlog.get_logger()
45
44
 
46
45
 
47
- LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
48
-
49
-
50
46
  @DefaultV1Recipe.register(
51
47
  [
52
48
  DefaultV1Recipe.ComponentType.COMMAND_GENERATOR,
53
49
  ],
54
50
  is_trainable=True,
55
51
  )
56
- class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
52
+ class LLMBasedCommandGenerator(
53
+ LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
54
+ ):
57
55
  """An abstract class defining interface and common functionality
58
56
  of an LLM-based command generators.
59
57
  """
@@ -106,11 +104,7 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
106
104
  @abstractmethod
107
105
  def persist(self) -> None:
108
106
  """Persist the component to disk for future loading."""
109
- # persist the config to store the resolved llm and embedding config
110
- with self._model_storage.write_to(self._resource) as path:
111
- rasa.shared.utils.io.dump_obj_as_json_to_file(
112
- path / LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE, self.config
113
- )
107
+ pass
114
108
 
115
109
  @abstractmethod
116
110
  async def predict_commands(
@@ -173,13 +167,11 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
173
167
  """Train the llm based command generator. Stores all flows into a vector
174
168
  store.
175
169
  """
176
- self.config[TRAINED_MODEL_NAME_CONFIG_KEY] = (
177
- perform_training_time_llm_health_check(
178
- self.config.get(LLM_CONFIG_KEY),
179
- DEFAULT_LLM_CONFIG,
180
- "llm_based_command_generator.train",
181
- LLMBasedCommandGenerator.__name__,
182
- )
170
+ self.perform_llm_health_check(
171
+ self.config.get(LLM_CONFIG_KEY),
172
+ DEFAULT_LLM_CONFIG,
173
+ "llm_based_command_generator.train",
174
+ LLMBasedCommandGenerator.__name__,
183
175
  )
184
176
 
185
177
  if (
@@ -210,12 +202,11 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
210
202
  except Exception as e:
211
203
  structlogger.error(
212
204
  "llm_based_command_generator.train.failed",
213
- event_info="Flow retrieval store isinaccessible.",
205
+ event_info="Flow retrieval store is inaccessible.",
214
206
  error=e,
215
207
  )
216
208
  raise
217
- if self.flow_retrieval is not None:
218
- self.flow_retrieval.train()
209
+
219
210
  self.persist()
220
211
  return self._resource
221
212
 
@@ -251,25 +242,6 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
251
242
  )
252
243
  return None
253
244
 
254
- @classmethod
255
- def load_config_from_model_storage(
256
- cls,
257
- model_storage: ModelStorage,
258
- resource: Resource,
259
- ) -> Optional[Text]:
260
- try:
261
- with model_storage.read_from(resource) as path:
262
- return rasa.shared.utils.io.read_json_file(
263
- path / LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE
264
- )
265
- except (FileNotFoundError, FileIOException) as e:
266
- structlogger.warning(
267
- "llm_based_command_generator.load_config.failed",
268
- error=e,
269
- resource=resource.name,
270
- )
271
- return None
272
-
273
245
  @classmethod
274
246
  def load_flow_retrival(
275
247
  cls,
@@ -24,7 +24,6 @@ from rasa.dialogue_understanding.generator.constants import (
24
24
  LLM_CONFIG_KEY,
25
25
  USER_INPUT_CONFIG_KEY,
26
26
  FLOW_RETRIEVAL_KEY,
27
- TRAINED_MODEL_NAME_CONFIG_KEY,
28
27
  DEFAULT_LLM_CONFIG,
29
28
  )
30
29
  from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
@@ -60,7 +59,6 @@ from rasa.shared.utils.llm import (
60
59
  allowed_values_for_slot,
61
60
  resolve_model_client_config,
62
61
  )
63
- from rasa.shared.utils.health_check import perform_inference_time_llm_health_check
64
62
 
65
63
  # multistep template keys
66
64
  HANDLE_FLOWS_KEY = "handle_flows"
@@ -77,6 +75,7 @@ DEFAULT_HANDLE_FLOWS_TEMPLATE = importlib.resources.read_text(
77
75
  DEFAULT_FILL_SLOTS_TEMPLATE = importlib.resources.read_text(
78
76
  "rasa.dialogue_understanding.generator.multi_step", "fill_slots_prompt.jinja2"
79
77
  ).strip()
78
+ MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
80
79
 
81
80
  # dictionary of template names and associated file names and default values
82
81
  PROMPT_TEMPLATES = {
@@ -145,15 +144,18 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
145
144
  **kwargs: Any,
146
145
  ) -> "MultiStepLLMCommandGenerator":
147
146
  """Loads trained component (see parent class for full docstring)."""
148
- prompts = cls._load_prompt_templates(model_storage, resource)
149
147
 
150
- persisted_config = cls.load_config_from_model_storage(model_storage, resource)
151
- train_model_name = (
152
- persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
153
- if persisted_config
154
- else None
148
+ # Perform health check of the LLM client config
149
+ llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
150
+ cls.perform_llm_health_check(
151
+ llm_config,
152
+ DEFAULT_LLM_CONFIG,
153
+ "multi_step_llm_command_generator.load",
154
+ MultiStepLLMCommandGenerator.__name__,
155
155
  )
156
156
 
157
+ prompts = cls._load_prompt_templates(model_storage, resource)
158
+
157
159
  # init base command generator
158
160
  command_generator = cls(config, model_storage, resource, prompts)
159
161
  # load flow retrieval if enabled
@@ -162,23 +164,12 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
162
164
  command_generator.config, model_storage, resource
163
165
  )
164
166
 
165
- perform_inference_time_llm_health_check(
166
- command_generator.config.get(LLM_CONFIG_KEY),
167
- DEFAULT_LLM_CONFIG,
168
- train_model_name,
169
- "multi_step_llm_command_generator.load",
170
- MultiStepLLMCommandGenerator.__name__,
171
- )
172
-
173
167
  return command_generator
174
168
 
175
169
  def persist(self) -> None:
176
170
  """Persist this component to disk for future loading."""
177
- super().persist()
178
-
179
- # persist prompt template
180
171
  self._persist_prompt_templates()
181
- # persist flow retrieval
172
+ self._persist_config()
182
173
  if self.flow_retrieval is not None:
183
174
  self.flow_retrieval.persist()
184
175
 
@@ -411,6 +402,13 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
411
402
  file_path = path / file_name
412
403
  rasa.shared.utils.io.write_text_file(template, file_path)
413
404
 
405
+ def _persist_config(self) -> None:
406
+ """Persist config as a source of truth for resolved clients."""
407
+ with self._model_storage.write_to(self._resource) as path:
408
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
409
+ path / MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE, self.config
410
+ )
411
+
414
412
  async def _predict_commands_with_multi_step(
415
413
  self,
416
414
  message: Message,
@@ -19,6 +19,7 @@ from rasa.engine.storage.storage import ModelStorage
19
19
  from rasa.shared.constants import ROUTE_TO_CALM_SLOT
20
20
  from rasa.shared.core.domain import Domain
21
21
  from rasa.shared.core.flows.flows_list import FlowsList
22
+ from rasa.shared.core.flows.steps import CollectInformationFlowStep
22
23
  from rasa.shared.core.slot_mappings import (
23
24
  SlotFillingManager,
24
25
  extract_slot_value,
@@ -217,7 +218,24 @@ def _issue_set_slot_commands(
217
218
  commands: List[Command] = []
218
219
  domain = domain if domain else Domain.empty()
219
220
  slot_filling_manager = SlotFillingManager(domain, tracker, message)
220
- available_slot_names = flows.available_slot_names()
221
+
222
+ # only use slots that don't have ask_before_filling set to True
223
+ available_slot_names = flows.available_slot_names(ask_before_filling=False)
224
+
225
+ # check if the current step is a CollectInformationFlowStep
226
+ # in case it has ask_before_filling set to True, we need to add the
227
+ # slot to the available_slot_names
228
+ if tracker.active_flow:
229
+ flow = flows.flow_by_id(tracker.active_flow)
230
+ step_id = tracker.current_step_id
231
+ if flow is not None:
232
+ current_step = flow.step_by_id(step_id)
233
+ if (
234
+ current_step
235
+ and isinstance(current_step, CollectInformationFlowStep)
236
+ and current_step.ask_before_filling
237
+ ):
238
+ available_slot_names.add(current_step.collect)
221
239
 
222
240
  for _, slot in tracker.slots.items():
223
241
  # if a slot is not collected in available flows,