rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (66):
  1. rasa/cli/inspect.py +2 -0
  2. rasa/cli/studio/studio.py +18 -8
  3. rasa/core/actions/action_repeat_bot_messages.py +17 -0
  4. rasa/core/channels/channel.py +17 -0
  5. rasa/core/channels/development_inspector.py +4 -1
  6. rasa/core/channels/voice_ready/audiocodes.py +15 -4
  7. rasa/core/channels/voice_ready/jambonz.py +13 -2
  8. rasa/core/channels/voice_ready/twilio_voice.py +6 -21
  9. rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
  10. rasa/core/channels/voice_stream/asr/azure.py +5 -7
  11. rasa/core/channels/voice_stream/asr/deepgram.py +13 -11
  12. rasa/core/channels/voice_stream/voice_channel.py +61 -19
  13. rasa/core/nlg/contextual_response_rephraser.py +20 -12
  14. rasa/core/policies/enterprise_search_policy.py +32 -72
  15. rasa/core/policies/intentless_policy.py +34 -72
  16. rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
  17. rasa/dialogue_understanding/generator/constants.py +0 -2
  18. rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
  19. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
  20. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
  21. rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
  22. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +26 -22
  23. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
  24. rasa/dialogue_understanding/processor/command_processor.py +21 -1
  25. rasa/e2e_test/e2e_test_case.py +85 -6
  26. rasa/engine/validation.py +88 -60
  27. rasa/model_service.py +3 -0
  28. rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
  29. rasa/server.py +3 -1
  30. rasa/shared/constants.py +5 -5
  31. rasa/shared/core/constants.py +1 -1
  32. rasa/shared/core/domain.py +0 -26
  33. rasa/shared/core/flows/flows_list.py +5 -1
  34. rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
  35. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
  36. rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
  37. rasa/shared/providers/llm/_base_litellm_client.py +32 -1
  38. rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
  39. rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
  40. rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
  41. rasa/shared/utils/common.py +1 -1
  42. rasa/shared/utils/health_check/__init__.py +0 -0
  43. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  44. rasa/shared/utils/health_check/health_check.py +256 -0
  45. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  46. rasa/shared/utils/llm.py +5 -2
  47. rasa/shared/utils/yaml.py +102 -62
  48. rasa/studio/auth.py +3 -5
  49. rasa/studio/config.py +13 -4
  50. rasa/studio/constants.py +1 -0
  51. rasa/studio/data_handler.py +10 -3
  52. rasa/studio/upload.py +21 -10
  53. rasa/telemetry.py +15 -1
  54. rasa/tracing/config.py +3 -1
  55. rasa/tracing/instrumentation/attribute_extractors.py +20 -0
  56. rasa/tracing/instrumentation/instrumentation.py +121 -0
  57. rasa/utils/common.py +5 -0
  58. rasa/utils/io.py +8 -16
  59. rasa/utils/sanic_error_handler.py +32 -0
  60. rasa/version.py +1 -1
  61. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/METADATA +3 -2
  62. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/RECORD +65 -61
  63. rasa/shared/utils/health_check.py +0 -533
  64. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/NOTICE +0 -0
  65. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/WHEEL +0 -0
  66. {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/entry_points.txt +0 -0
@@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Text
2
2
 
3
3
  import structlog
4
4
  from jinja2 import Template
5
-
6
5
  from rasa import telemetry
7
6
  from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
8
7
  from rasa.core.nlg.summarize import summarize_conversation
@@ -14,11 +13,12 @@ from rasa.shared.constants import (
14
13
  PROVIDER_CONFIG_KEY,
15
14
  OPENAI_PROVIDER,
16
15
  TIMEOUT_CONFIG_KEY,
17
- MODEL_GROUP_CONFIG_KEY,
16
+ MODEL_GROUP_ID_CONFIG_KEY,
18
17
  )
19
18
  from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
20
19
  from rasa.shared.core.events import BotUttered, UserUttered
21
20
  from rasa.shared.core.trackers import DialogueStateTracker
21
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
22
22
  from rasa.shared.utils.llm import (
23
23
  DEFAULT_OPENAI_GENERATE_MODEL_NAME,
24
24
  DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
@@ -28,7 +28,6 @@ from rasa.shared.utils.llm import (
28
28
  llm_factory,
29
29
  resolve_model_client_config,
30
30
  )
31
- from rasa.shared.utils.health_check import perform_training_time_llm_health_check
32
31
  from rasa.shared.utils.llm import (
33
32
  tracker_as_readable_transcript,
34
33
  )
@@ -44,6 +43,8 @@ RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
44
43
  RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
45
44
 
46
45
  DEFAULT_REPHRASE_ALL = False
46
+ DEFAULT_SUMMARIZE_HISTORY = True
47
+ DEFAULT_MAX_HISTORICAL_TURNS = 5
47
48
 
48
49
  DEFAULT_LLM_CONFIG = {
49
50
  PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
@@ -68,7 +69,9 @@ Suggested AI Response: {{suggested_response}}
68
69
  Rephrased AI Response:"""
69
70
 
70
71
 
71
- class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
72
+ class ContextualResponseRephraser(
73
+ LLMHealthCheckMixin, TemplatedNaturalLanguageGenerator
74
+ ):
72
75
  """Generates responses based on modified templates.
73
76
 
74
77
  The templates are filled with the entities and slots that are available in the
@@ -102,13 +105,19 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
102
105
  self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
103
106
  "trace_prompt_tokens", False
104
107
  )
108
+ self.summarize_history = self.nlg_endpoint.kwargs.get(
109
+ "summarize_history", DEFAULT_SUMMARIZE_HISTORY
110
+ )
111
+ self.max_historical_turns = self.nlg_endpoint.kwargs.get(
112
+ "max_historical_turns", DEFAULT_MAX_HISTORICAL_TURNS
113
+ )
105
114
 
106
115
  self.llm_config = resolve_model_client_config(
107
116
  self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
108
117
  ContextualResponseRephraser.__name__,
109
118
  )
110
119
 
111
- perform_training_time_llm_health_check(
120
+ self.perform_llm_health_check(
112
121
  self.llm_config,
113
122
  DEFAULT_LLM_CONFIG,
114
123
  "contextual_response_rephraser.init",
@@ -213,18 +222,17 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
213
222
  prompt_template_text = self._template_for_response_rephrasing(response)
214
223
 
215
224
  # Retrieve inputs for the dynamic prompt
216
- transcript = tracker_as_readable_transcript(tracker, max_turns=5)
217
225
  latest_message = self._last_message_if_human(tracker)
218
226
  current_input = f"{USER}: {latest_message}" if latest_message else ""
219
227
 
220
228
  # Only summarise conversation history if flagged
221
- summarize_conversation_flag = response.get("metadata", {}).get(
222
- RESPONSE_SUMMARISE_CONVERSATION_KEY, False
223
- )
224
- if summarize_conversation_flag:
229
+ if self.summarize_history:
225
230
  history = await self._create_history(tracker)
226
231
  else:
227
- history = transcript
232
+ # make sure the transcript/history contains the last user utterance
233
+ max_turns = max(self.max_historical_turns, 1)
234
+ history = tracker_as_readable_transcript(tracker, max_turns=max_turns)
235
+ # the history already contains the current input
228
236
  current_input = ""
229
237
 
230
238
  prompt = Template(prompt_template_text).render(
@@ -245,7 +253,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
245
253
  llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
246
254
  llm_model=self.llm_property(MODEL_CONFIG_KEY)
247
255
  or self.llm_property(MODEL_NAME_CONFIG_KEY),
248
- llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
256
+ llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
249
257
  )
250
258
  if not (updated_text := await self._generate_llm_response(prompt)):
251
259
  # If the LLM fails to generate a response, we
@@ -1,7 +1,7 @@
1
1
  import importlib.resources
2
2
  import json
3
3
  import re
4
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
5
5
  import dotenv
6
6
  import structlog
7
7
  from jinja2 import Template
@@ -25,8 +25,6 @@ from rasa.core.policies.policy import Policy, PolicyPrediction
25
25
  from rasa.core.utils import AvailableEndpoints
26
26
  from rasa.dialogue_understanding.generator.constants import (
27
27
  LLM_CONFIG_KEY,
28
- TRAINED_MODEL_NAME_CONFIG_KEY,
29
- TRAINED_EMBEDDINGS_CONFIG_KEY,
30
28
  )
31
29
  from rasa.dialogue_understanding.patterns.cannot_handle import (
32
30
  CannotHandlePatternFlowStackFrame,
@@ -53,7 +51,7 @@ from rasa.shared.constants import (
53
51
  OPENAI_PROVIDER,
54
52
  TIMEOUT_CONFIG_KEY,
55
53
  MODEL_NAME_CONFIG_KEY,
56
- MODEL_GROUP_CONFIG_KEY,
54
+ MODEL_GROUP_ID_CONFIG_KEY,
57
55
  )
58
56
  from rasa.shared.core.constants import (
59
57
  ACTION_CANCEL_FLOW,
@@ -71,6 +69,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
71
69
  )
72
70
  from rasa.shared.providers.llm.llm_client import LLMClient
73
71
  from rasa.shared.utils.cli import print_error_and_exit
72
+ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
73
+ EmbeddingsHealthCheckMixin,
74
+ )
75
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
74
76
  from rasa.shared.utils.io import deep_container_fingerprint
75
77
  from rasa.shared.utils.llm import (
76
78
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
@@ -82,12 +84,6 @@ from rasa.shared.utils.llm import (
82
84
  tracker_as_readable_transcript,
83
85
  resolve_model_client_config,
84
86
  )
85
- from rasa.shared.utils.health_check import (
86
- perform_training_time_llm_health_check,
87
- perform_training_time_embeddings_health_check,
88
- perform_inference_time_llm_health_check,
89
- perform_inference_time_embeddings_health_check,
90
- )
91
87
  from rasa.telemetry import (
92
88
  track_enterprise_search_policy_predict,
93
89
  track_enterprise_search_policy_train_completed,
@@ -161,7 +157,7 @@ class VectorStoreConfigurationError(RasaException):
161
157
  @DefaultV1Recipe.register(
162
158
  DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
163
159
  )
164
- class EnterpriseSearchPolicy(Policy):
160
+ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
165
161
  """Policy which uses a vector store and LLMs to respond to user messages.
166
162
 
167
163
  The policy uses a vector store and LLMs to respond to user messages. The
@@ -300,6 +296,9 @@ class EnterpriseSearchPolicy(Policy):
300
296
  A policy must return its resource locator so that potential children nodes
301
297
  can load the policy from the resource.
302
298
  """
299
+ # Perform health checks for both LLM and embeddings client configs
300
+ self._perform_health_checks(self.config, "enterprise_search_policy.train")
301
+
303
302
  store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
304
303
 
305
304
  # telemetry call to track training start
@@ -319,11 +318,6 @@ class EnterpriseSearchPolicy(Policy):
319
318
  f"required environment variables. Error: {e}"
320
319
  )
321
320
 
322
- (
323
- self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
324
- self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
325
- ) = self._perform_training_time_health_checks()
326
-
327
321
  if store_type == DEFAULT_VECTOR_STORE_TYPE:
328
322
  logger.info("enterprise_search_policy.train.faiss")
329
323
  with self._model_storage.write_to(self._resource) as path:
@@ -343,12 +337,12 @@ class EnterpriseSearchPolicy(Policy):
343
337
  embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
344
338
  or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
345
339
  embeddings_model_group_id=self.embeddings_config.get(
346
- MODEL_GROUP_CONFIG_KEY
340
+ MODEL_GROUP_ID_CONFIG_KEY
347
341
  ),
348
342
  llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
349
343
  llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
350
344
  or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
351
- llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
345
+ llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
352
346
  citation_enabled=self.citation_enabled,
353
347
  )
354
348
  self.persist()
@@ -544,12 +538,12 @@ class EnterpriseSearchPolicy(Policy):
544
538
  embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
545
539
  or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
546
540
  embeddings_model_group_id=self.embeddings_config.get(
547
- MODEL_GROUP_CONFIG_KEY
541
+ MODEL_GROUP_ID_CONFIG_KEY
548
542
  ),
549
543
  llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
550
544
  llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
551
545
  or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
552
- llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
546
+ llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
553
547
  citation_enabled=self.citation_enabled,
554
548
  )
555
549
  return self._create_prediction(
@@ -698,16 +692,16 @@ class EnterpriseSearchPolicy(Policy):
698
692
  **kwargs: Any,
699
693
  ) -> "EnterpriseSearchPolicy":
700
694
  """Loads a trained policy (see parent class for full docstring)."""
695
+
696
+ # Perform health checks for both LLM and embeddings client configs
697
+ cls._perform_health_checks(config, "enterprise_search_policy.load")
698
+
701
699
  prompt_template = None
702
- persisted_config = None
703
700
  try:
704
701
  with model_storage.read_from(resource) as path:
705
702
  prompt_template = rasa.shared.utils.io.read_file(
706
703
  path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
707
704
  )
708
- persisted_config = rasa.shared.utils.io.read_json_file(
709
- path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME
710
- )
711
705
  except (FileNotFoundError, FileIOException) as e:
712
706
  logger.warning(
713
707
  "enterprise_search_policy.load.failed", error=e, resource=resource.name
@@ -737,7 +731,7 @@ class EnterpriseSearchPolicy(Policy):
737
731
  embeddings=embeddings,
738
732
  ) # type: ignore
739
733
 
740
- policy = cls(
734
+ return cls(
741
735
  config,
742
736
  model_storage,
743
737
  resource,
@@ -746,14 +740,6 @@ class EnterpriseSearchPolicy(Policy):
746
740
  prompt_template=prompt_template,
747
741
  )
748
742
 
749
- cls._perform_inference_time_health_checks(
750
- persisted_config,
751
- policy.config.get(LLM_CONFIG_KEY),
752
- policy.config.get(EMBEDDINGS_CONFIG_KEY),
753
- )
754
-
755
- return policy
756
-
757
743
  @classmethod
758
744
  def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
759
745
  """This is required only for local knowledge base types.
@@ -894,52 +880,26 @@ class EnterpriseSearchPolicy(Policy):
894
880
 
895
881
  return joined_answer + joined_sources
896
882
 
897
- def _perform_training_time_health_checks(
898
- self,
899
- ) -> Tuple[Optional[str], Optional[str]]:
900
- train_model_name = perform_training_time_llm_health_check(
901
- self.config.get(LLM_CONFIG_KEY),
902
- DEFAULT_LLM_CONFIG,
903
- "enterprise_search_policy.train",
904
- EnterpriseSearchPolicy.__name__,
905
- )
906
- train_embedding_name = perform_training_time_embeddings_health_check(
907
- self.config.get(EMBEDDINGS_CONFIG_KEY),
908
- DEFAULT_EMBEDDINGS_CONFIG,
909
- "enterprise_search_policy.train",
910
- EnterpriseSearchPolicy.__name__,
911
- )
912
- return train_model_name, train_embedding_name
913
-
914
883
  @classmethod
915
- def _perform_inference_time_health_checks(
916
- cls,
917
- persisted_config: Optional[Dict[str, Any]],
918
- resolved_llm_config: Optional[Dict[str, Any]],
919
- resolved_embeddings_config: Optional[Dict[str, Any]],
884
+ def _perform_health_checks(
885
+ cls, config: Dict[Text, Any], log_source_method: str
920
886
  ) -> None:
921
- train_model_name = (
922
- persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
923
- if persisted_config
924
- else None
925
- )
926
- perform_inference_time_llm_health_check(
927
- resolved_llm_config,
887
+ # Perform health check of the LLM client config
888
+ llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
889
+ cls.perform_llm_health_check(
890
+ llm_config,
928
891
  DEFAULT_LLM_CONFIG,
929
- train_model_name,
930
- "enterprise_search_policy.load",
892
+ log_source_method,
931
893
  EnterpriseSearchPolicy.__name__,
932
894
  )
933
895
 
934
- train_embeddings_name = (
935
- persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
936
- if persisted_config
937
- else None
896
+ # Perform health check of the embeddings client config
897
+ embeddings_config = resolve_model_client_config(
898
+ config.get(EMBEDDINGS_CONFIG_KEY, {})
938
899
  )
939
- perform_inference_time_embeddings_health_check(
940
- resolved_embeddings_config,
900
+ cls.perform_embeddings_health_check(
901
+ embeddings_config,
941
902
  DEFAULT_EMBEDDINGS_CONFIG,
942
- train_embeddings_name,
943
- "enterprise_search_policy.load",
903
+ log_source_method,
944
904
  EnterpriseSearchPolicy.__name__,
945
905
  )
@@ -18,10 +18,6 @@ from rasa.core.constants import (
18
18
  UTTER_SOURCE_METADATA_KEY,
19
19
  )
20
20
  from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
21
- from rasa.dialogue_understanding.generator.constants import (
22
- TRAINED_MODEL_NAME_CONFIG_KEY,
23
- TRAINED_EMBEDDINGS_CONFIG_KEY,
24
- )
25
21
  from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
26
22
  from rasa.dialogue_understanding.stack.frames import (
27
23
  ChitChatStackFrame,
@@ -43,7 +39,7 @@ from rasa.shared.constants import (
43
39
  PROVIDER_CONFIG_KEY,
44
40
  OPENAI_PROVIDER,
45
41
  TIMEOUT_CONFIG_KEY,
46
- MODEL_GROUP_CONFIG_KEY,
42
+ MODEL_GROUP_ID_CONFIG_KEY,
47
43
  )
48
44
  from rasa.shared.core.constants import ACTION_LISTEN_NAME
49
45
  from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
@@ -64,6 +60,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
64
60
  _LangchainEmbeddingClientAdapter,
65
61
  )
66
62
  from rasa.shared.providers.llm.llm_client import LLMClient
63
+ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
64
+ EmbeddingsHealthCheckMixin,
65
+ )
66
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
67
67
  from rasa.shared.utils.io import deep_container_fingerprint
68
68
  from rasa.shared.utils.llm import (
69
69
  AI,
@@ -79,12 +79,6 @@ from rasa.shared.utils.llm import (
79
79
  tracker_as_readable_transcript,
80
80
  resolve_model_client_config,
81
81
  )
82
- from rasa.shared.utils.health_check import (
83
- perform_training_time_llm_health_check,
84
- perform_training_time_embeddings_health_check,
85
- perform_inference_time_llm_health_check,
86
- perform_inference_time_embeddings_health_check,
87
- )
88
82
  from rasa.utils.log_utils import log_llm
89
83
  from rasa.utils.ml_utils import (
90
84
  extract_ai_response_examples,
@@ -383,7 +377,7 @@ def conversation_as_prompt(conversation: Conversation) -> str:
383
377
  @DefaultV1Recipe.register(
384
378
  DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
385
379
  )
386
- class IntentlessPolicy(Policy):
380
+ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
387
381
  """Policy which uses a language model to generate the next action.
388
382
 
389
383
  The policy uses the OpenAI API to generate the next action based on the
@@ -516,10 +510,8 @@ class IntentlessPolicy(Policy):
516
510
  A policy must return its resource locator so that potential children nodes
517
511
  can load the policy from the resource.
518
512
  """
519
- (
520
- self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
521
- self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
522
- ) = self._perform_training_time_health_checks()
513
+ # Perform health checks of both LLM and embeddings client configs
514
+ self._perform_health_checks(self.config, "intentless_policy.train")
523
515
 
524
516
  responses = filter_responses(responses, forms, flows or FlowsList([]))
525
517
  telemetry.track_intentless_policy_train()
@@ -566,11 +558,13 @@ class IntentlessPolicy(Policy):
566
558
  embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
567
559
  embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
568
560
  or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
569
- embeddings_model_group_id=self.embeddings_property(MODEL_GROUP_CONFIG_KEY),
561
+ embeddings_model_group_id=self.embeddings_property(
562
+ MODEL_GROUP_ID_CONFIG_KEY
563
+ ),
570
564
  llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
571
565
  llm_model=self.llm_property(MODEL_CONFIG_KEY)
572
566
  or self.llm_property(MODEL_NAME_CONFIG_KEY),
573
- llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
567
+ llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
574
568
  )
575
569
 
576
570
  self.persist()
@@ -650,11 +644,13 @@ class IntentlessPolicy(Policy):
650
644
  embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
651
645
  embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
652
646
  or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
653
- embeddings_model_group_id=self.embeddings_property(MODEL_GROUP_CONFIG_KEY),
647
+ embeddings_model_group_id=self.embeddings_property(
648
+ MODEL_GROUP_ID_CONFIG_KEY
649
+ ),
654
650
  llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
655
651
  llm_model=self.llm_property(MODEL_CONFIG_KEY)
656
652
  or self.llm_property(MODEL_NAME_CONFIG_KEY),
657
- llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
653
+ llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
658
654
  score=score,
659
655
  )
660
656
 
@@ -952,10 +948,13 @@ class IntentlessPolicy(Policy):
952
948
  **kwargs: Any,
953
949
  ) -> "IntentlessPolicy":
954
950
  """Loads a trained policy (see parent class for full docstring)."""
951
+
952
+ # Perform health checks of both LLM and embeddings client configs
953
+ cls._perform_health_checks(config, "intentless_policy.load")
954
+
955
955
  responses_docsearch = None
956
956
  samples_docsearch = None
957
957
  prompt_template = None
958
- persisted_config = None
959
958
  try:
960
959
  with model_storage.read_from(resource) as path:
961
960
  responses_docsearch = load_faiss_vector_store(
@@ -973,15 +972,12 @@ class IntentlessPolicy(Policy):
973
972
  prompt_template = rasa.shared.utils.io.read_file(
974
973
  path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
975
974
  )
976
- persisted_config = rasa.shared.utils.io.read_json_file(
977
- path / INTENTLESS_CONFIG_FILE_NAME
978
- )
979
975
  except (ValueError, FileNotFoundError, FileIOException) as e:
980
976
  structlogger.warning(
981
977
  "intentless_policy.load.failed", error=e, resource_name=resource.name
982
978
  )
983
979
 
984
- policy = cls(
980
+ return cls(
985
981
  config,
986
982
  model_storage,
987
983
  resource,
@@ -991,14 +987,6 @@ class IntentlessPolicy(Policy):
991
987
  prompt_template=prompt_template,
992
988
  )
993
989
 
994
- cls._perform_inference_time_health_checks(
995
- persisted_config,
996
- policy.config.get(LLM_CONFIG_KEY),
997
- policy.config.get(EMBEDDINGS_CONFIG_KEY),
998
- )
999
-
1000
- return policy
1001
-
1002
990
  @classmethod
1003
991
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
1004
992
  """Add a fingerprint of intentless policy for the graph."""
@@ -1018,52 +1006,26 @@ class IntentlessPolicy(Policy):
1018
1006
  [prompt_template, llm_config, embedding_config]
1019
1007
  )
1020
1008
 
1021
- def _perform_training_time_health_checks(
1022
- self,
1023
- ) -> Tuple[Optional[str], Optional[str]]:
1024
- train_model_name = perform_training_time_llm_health_check(
1025
- self.config.get(LLM_CONFIG_KEY),
1026
- DEFAULT_LLM_CONFIG,
1027
- "intentless_policy.train",
1028
- IntentlessPolicy.__name__,
1029
- )
1030
- train_embedding_name = perform_training_time_embeddings_health_check(
1031
- self.config.get(EMBEDDINGS_CONFIG_KEY),
1032
- DEFAULT_EMBEDDINGS_CONFIG,
1033
- "intentless_policy.train",
1034
- IntentlessPolicy.__name__,
1035
- )
1036
- return train_model_name, train_embedding_name
1037
-
1038
1009
  @classmethod
1039
- def _perform_inference_time_health_checks(
1040
- cls,
1041
- persisted_config: Optional[Dict[str, Any]],
1042
- resolved_llm_config: Optional[Dict[str, Any]],
1043
- resolved_embeddings_config: Optional[Dict[str, Any]],
1010
+ def _perform_health_checks(
1011
+ cls, config: Dict[Text, Any], log_source_method: str
1044
1012
  ) -> None:
1045
- train_model_name = (
1046
- persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
1047
- if persisted_config
1048
- else None
1049
- )
1050
- perform_inference_time_llm_health_check(
1051
- resolved_llm_config,
1013
+ # Perform health check of the LLM client config
1014
+ llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
1015
+ cls.perform_llm_health_check(
1016
+ llm_config,
1052
1017
  DEFAULT_LLM_CONFIG,
1053
- train_model_name,
1054
- "intentless_policy.load",
1018
+ log_source_method,
1055
1019
  IntentlessPolicy.__name__,
1056
1020
  )
1057
1021
 
1058
- train_embeddings_name = (
1059
- persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
1060
- if persisted_config
1061
- else None
1022
+ # Perform health check of the embeddings client config
1023
+ embeddings_config = resolve_model_client_config(
1024
+ config.get(EMBEDDINGS_CONFIG_KEY, {})
1062
1025
  )
1063
- perform_inference_time_embeddings_health_check(
1064
- resolved_embeddings_config,
1026
+ cls.perform_embeddings_health_check(
1027
+ embeddings_config,
1065
1028
  DEFAULT_EMBEDDINGS_CONFIG,
1066
- train_embeddings_name,
1067
- "intentless_policy.load",
1029
+ log_source_method,
1068
1030
  IntentlessPolicy.__name__,
1069
1031
  )
@@ -17,7 +17,6 @@ from rasa.dialogue_understanding.commands import Command, SetSlotCommand
17
17
  from rasa.dialogue_understanding.commands.noop_command import NoopCommand
18
18
  from rasa.dialogue_understanding.generator.constants import (
19
19
  LLM_CONFIG_KEY,
20
- TRAINED_MODEL_NAME_CONFIG_KEY,
21
20
  )
22
21
  from rasa.engine.graph import ExecutionContext, GraphComponent
23
22
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
@@ -36,6 +35,7 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
36
35
  from rasa.shared.nlu.constants import COMMANDS, TEXT
37
36
  from rasa.shared.nlu.training_data.message import Message
38
37
  from rasa.shared.nlu.training_data.training_data import TrainingData
38
+ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
39
39
  from rasa.shared.utils.io import deep_container_fingerprint
40
40
  from rasa.shared.utils.llm import (
41
41
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
@@ -43,10 +43,6 @@ from rasa.shared.utils.llm import (
43
43
  llm_factory,
44
44
  resolve_model_client_config,
45
45
  )
46
- from rasa.shared.utils.health_check import (
47
- perform_training_time_llm_health_check,
48
- perform_inference_time_llm_health_check,
49
- )
50
46
  from rasa.utils.log_utils import log_llm
51
47
 
52
48
  LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
@@ -80,7 +76,7 @@ structlogger = structlog.get_logger()
80
76
  ],
81
77
  is_trainable=True,
82
78
  )
83
- class LLMBasedRouter(GraphComponent):
79
+ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
84
80
  @staticmethod
85
81
  def get_default_config() -> Dict[str, Any]:
86
82
  """The component's default config (see parent class for full docstring)."""
@@ -144,13 +140,11 @@ class LLMBasedRouter(GraphComponent):
144
140
 
145
141
  def train(self, training_data: TrainingData) -> Resource:
146
142
  """Train the intent classifier on a data set."""
147
- self.config[TRAINED_MODEL_NAME_CONFIG_KEY] = (
148
- perform_training_time_llm_health_check(
149
- self.config.get(LLM_CONFIG_KEY),
150
- DEFAULT_LLM_CONFIG,
151
- "llm_based_router.train",
152
- LLMBasedRouter.__name__,
153
- )
143
+ self.perform_llm_health_check(
144
+ self.config.get(LLM_CONFIG_KEY),
145
+ DEFAULT_LLM_CONFIG,
146
+ "llm_based_router.train",
147
+ LLMBasedRouter.__name__,
154
148
  )
155
149
 
156
150
  self.persist()
@@ -166,37 +160,28 @@ class LLMBasedRouter(GraphComponent):
166
160
  **kwargs: Any,
167
161
  ) -> "LLMBasedRouter":
168
162
  """Loads trained component (see parent class for full docstring)."""
163
+
164
+ # Perform health check on the resolved LLM client config
165
+ llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
166
+ cls.perform_llm_health_check(
167
+ llm_config,
168
+ DEFAULT_LLM_CONFIG,
169
+ "llm_based_router.load",
170
+ LLMBasedRouter.__name__,
171
+ )
172
+
169
173
  prompt_template = None
170
- persisted_config = None
171
174
  try:
172
175
  with model_storage.read_from(resource) as path:
173
176
  prompt_template = rasa.shared.utils.io.read_file(
174
177
  path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
175
178
  )
176
- persisted_config = rasa.shared.utils.io.read_json_file(
177
- path / LLM_BASED_ROUTER_CONFIG_FILE_NAME
178
- )
179
179
  except (FileNotFoundError, FileIOException) as e:
180
180
  structlogger.warning(
181
181
  "llm_based_router.load.failed", error=e, resource=resource.name
182
182
  )
183
183
 
184
- router = cls(config, model_storage, resource, prompt_template=prompt_template)
185
-
186
- train_model_name = (
187
- persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
188
- if persisted_config
189
- else None
190
- )
191
- perform_inference_time_llm_health_check(
192
- router.config.get(LLM_CONFIG_KEY),
193
- DEFAULT_LLM_CONFIG,
194
- train_model_name,
195
- "llm_based_router.load",
196
- LLMBasedRouter.__name__,
197
- )
198
-
199
- return router
184
+ return cls(config, model_storage, resource, prompt_template=prompt_template)
200
185
 
201
186
  @classmethod
202
187
  def create(
@@ -18,8 +18,6 @@ DEFAULT_LLM_CONFIG = {
18
18
  }
19
19
 
20
20
  LLM_CONFIG_KEY = "llm"
21
- TRAINED_MODEL_NAME_CONFIG_KEY = "trained_llm_model_name"
22
- TRAINED_EMBEDDINGS_CONFIG_KEY = "trained_embeddings_model_name"
23
21
  USER_INPUT_CONFIG_KEY = "user_input"
24
22
 
25
23
  FLOW_RETRIEVAL_KEY = "flow_retrieval"