rasa-pro 3.11.0a4.dev2__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (163) hide show
  1. rasa/__main__.py +22 -12
  2. rasa/api.py +1 -1
  3. rasa/cli/arguments/default_arguments.py +1 -2
  4. rasa/cli/arguments/shell.py +5 -1
  5. rasa/cli/e2e_test.py +1 -1
  6. rasa/cli/evaluate.py +8 -8
  7. rasa/cli/inspect.py +4 -4
  8. rasa/cli/llm_fine_tuning.py +1 -1
  9. rasa/cli/project_templates/calm/config.yml +5 -7
  10. rasa/cli/project_templates/calm/endpoints.yml +8 -0
  11. rasa/cli/project_templates/tutorial/config.yml +8 -5
  12. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  13. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  14. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  15. rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
  16. rasa/cli/run.py +1 -1
  17. rasa/cli/scaffold.py +4 -2
  18. rasa/cli/utils.py +5 -0
  19. rasa/cli/x.py +8 -8
  20. rasa/constants.py +1 -1
  21. rasa/core/channels/channel.py +3 -0
  22. rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
  23. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  24. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
  25. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  26. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
  27. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
  28. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  29. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  30. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  31. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  32. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  33. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  34. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
  37. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  56. rasa/core/channels/inspector/dist/index.html +1 -1
  57. rasa/core/channels/inspector/src/App.tsx +1 -1
  58. rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
  59. rasa/core/channels/socketio.py +2 -1
  60. rasa/core/channels/telegram.py +1 -1
  61. rasa/core/channels/twilio.py +1 -1
  62. rasa/core/channels/voice_ready/jambonz.py +2 -2
  63. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  64. rasa/core/channels/voice_stream/asr/azure.py +122 -0
  65. rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
  66. rasa/core/channels/voice_stream/audio_bytes.py +1 -0
  67. rasa/core/channels/voice_stream/browser_audio.py +31 -8
  68. rasa/core/channels/voice_stream/call_state.py +23 -0
  69. rasa/core/channels/voice_stream/tts/azure.py +6 -2
  70. rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
  71. rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
  72. rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
  73. rasa/core/channels/voice_stream/util.py +4 -4
  74. rasa/core/channels/voice_stream/voice_channel.py +177 -39
  75. rasa/core/featurizers/single_state_featurizer.py +22 -1
  76. rasa/core/featurizers/tracker_featurizers.py +115 -18
  77. rasa/core/nlg/contextual_response_rephraser.py +16 -22
  78. rasa/core/persistor.py +86 -39
  79. rasa/core/policies/enterprise_search_policy.py +159 -60
  80. rasa/core/policies/flows/flow_executor.py +7 -4
  81. rasa/core/policies/intentless_policy.py +120 -22
  82. rasa/core/policies/ted_policy.py +58 -33
  83. rasa/core/policies/unexpected_intent_policy.py +15 -7
  84. rasa/core/processor.py +25 -0
  85. rasa/core/training/interactive.py +34 -35
  86. rasa/core/utils.py +8 -3
  87. rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
  88. rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
  89. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  90. rasa/dialogue_understanding/commands/utils.py +5 -0
  91. rasa/dialogue_understanding/generator/constants.py +4 -0
  92. rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
  93. rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
  94. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
  95. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
  96. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
  97. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  98. rasa/e2e_test/e2e_test_runner.py +4 -2
  99. rasa/e2e_test/utils/io.py +1 -1
  100. rasa/engine/validation.py +297 -7
  101. rasa/model_manager/config.py +17 -3
  102. rasa/model_manager/model_api.py +16 -8
  103. rasa/model_manager/runner_service.py +8 -6
  104. rasa/model_manager/socket_bridge.py +6 -3
  105. rasa/model_manager/trainer_service.py +7 -5
  106. rasa/model_manager/utils.py +28 -7
  107. rasa/model_service.py +7 -5
  108. rasa/model_training.py +2 -0
  109. rasa/nlu/classifiers/diet_classifier.py +38 -25
  110. rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
  111. rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
  112. rasa/nlu/extractors/crf_entity_extractor.py +93 -50
  113. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
  114. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
  115. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
  116. rasa/shared/constants.py +36 -3
  117. rasa/shared/core/constants.py +7 -0
  118. rasa/shared/core/domain.py +26 -0
  119. rasa/shared/core/flows/flow.py +5 -0
  120. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  121. rasa/shared/core/flows/utils.py +39 -0
  122. rasa/shared/core/flows/validation.py +96 -0
  123. rasa/shared/core/slots.py +5 -0
  124. rasa/shared/nlu/training_data/features.py +120 -2
  125. rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
  126. rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
  127. rasa/shared/providers/_configs/model_group_config.py +167 -0
  128. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  129. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  130. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  131. rasa/shared/providers/_configs/utils.py +16 -0
  132. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
  133. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  134. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  135. rasa/shared/providers/llm/_base_litellm_client.py +31 -30
  136. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  137. rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
  138. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  139. rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
  140. rasa/shared/providers/mappings.py +19 -0
  141. rasa/shared/providers/router/__init__.py +0 -0
  142. rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
  143. rasa/shared/providers/router/router_client.py +73 -0
  144. rasa/shared/utils/common.py +8 -0
  145. rasa/shared/utils/health_check.py +533 -0
  146. rasa/shared/utils/io.py +28 -6
  147. rasa/shared/utils/llm.py +350 -46
  148. rasa/shared/utils/yaml.py +11 -13
  149. rasa/studio/upload.py +64 -20
  150. rasa/telemetry.py +80 -17
  151. rasa/tracing/instrumentation/attribute_extractors.py +74 -17
  152. rasa/utils/io.py +0 -66
  153. rasa/utils/log_utils.py +9 -2
  154. rasa/utils/tensorflow/feature_array.py +366 -0
  155. rasa/utils/tensorflow/model_data.py +2 -193
  156. rasa/validator.py +70 -0
  157. rasa/version.py +1 -1
  158. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
  159. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
  160. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
  161. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
  162. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
  163. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
@@ -1,11 +1,9 @@
1
1
  from __future__ import annotations
2
- from pathlib import Path
3
- from collections import defaultdict
4
- from abc import abstractmethod
5
- import jsonpickle
6
- import logging
7
2
 
8
- from tqdm import tqdm
3
+ import logging
4
+ from abc import abstractmethod
5
+ from collections import defaultdict
6
+ from pathlib import Path
9
7
  from typing import (
10
8
  Tuple,
11
9
  List,
@@ -18,25 +16,30 @@ from typing import (
18
16
  Set,
19
17
  DefaultDict,
20
18
  cast,
19
+ Type,
20
+ Callable,
21
+ ClassVar,
21
22
  )
23
+
22
24
  import numpy as np
25
+ from tqdm import tqdm
23
26
 
24
- from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
25
- from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
26
- from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
27
27
  import rasa.shared.core.trackers
28
28
  import rasa.shared.utils.io
29
- from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
30
- from rasa.shared.nlu.training_data.features import Features
31
- from rasa.shared.core.trackers import DialogueStateTracker
32
- from rasa.shared.core.domain import State, Domain
33
- from rasa.shared.core.events import Event, ActionExecuted, UserUttered
29
+ from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
30
+ from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
31
+ from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
34
32
  from rasa.shared.core.constants import (
35
33
  USER,
36
34
  ACTION_UNLIKELY_INTENT_NAME,
37
35
  PREVIOUS_ACTION,
38
36
  )
37
+ from rasa.shared.core.domain import State, Domain
38
+ from rasa.shared.core.events import Event, ActionExecuted, UserUttered
39
+ from rasa.shared.core.trackers import DialogueStateTracker
39
40
  from rasa.shared.exceptions import RasaException
41
+ from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
42
+ from rasa.shared.nlu.training_data.features import Features
40
43
  from rasa.utils.tensorflow.constants import LABEL_PAD_ID
41
44
  from rasa.utils.tensorflow.model_data import ragged_array_to_ndarray
42
45
 
@@ -64,6 +67,10 @@ class InvalidStory(RasaException):
64
67
  class TrackerFeaturizer:
65
68
  """Base class for actual tracker featurizers."""
66
69
 
70
+ # Class registry to store all subclasses
71
+ _registry: ClassVar[Dict[str, Type["TrackerFeaturizer"]]] = {}
72
+ _featurizer_type: str = "TrackerFeaturizer"
73
+
67
74
  def __init__(
68
75
  self, state_featurizer: Optional[SingleStateFeaturizer] = None
69
76
  ) -> None:
@@ -74,6 +81,36 @@ class TrackerFeaturizer:
74
81
  """
75
82
  self.state_featurizer = state_featurizer
76
83
 
84
+ @classmethod
85
+ def register(cls, featurizer_type: str) -> Callable:
86
+ """Decorator to register featurizer subclasses."""
87
+
88
+ def wrapper(subclass: Type["TrackerFeaturizer"]) -> Type["TrackerFeaturizer"]:
89
+ cls._registry[featurizer_type] = subclass
90
+ # Store the type identifier in the class for serialization
91
+ subclass._featurizer_type = featurizer_type
92
+ return subclass
93
+
94
+ return wrapper
95
+
96
+ @classmethod
97
+ def from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
98
+ """Create featurizer instance from dictionary."""
99
+ featurizer_type = data.pop("type")
100
+
101
+ if featurizer_type not in cls._registry:
102
+ raise ValueError(f"Unknown featurizer type: {featurizer_type}")
103
+
104
+ # Get the correct subclass and instantiate it
105
+ subclass = cls._registry[featurizer_type]
106
+ return subclass.create_from_dict(data)
107
+
108
+ @classmethod
109
+ @abstractmethod
110
+ def create_from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
111
+ """Each subclass must implement its own creation from dict method."""
112
+ pass
113
+
77
114
  @staticmethod
78
115
  def _create_states(
79
116
  tracker: DialogueStateTracker,
@@ -465,9 +502,7 @@ class TrackerFeaturizer:
465
502
  self.state_featurizer.entity_tag_specs = []
466
503
 
467
504
  # noinspection PyTypeChecker
468
- rasa.shared.utils.io.write_text_file(
469
- str(jsonpickle.encode(self)), featurizer_file
470
- )
505
+ rasa.shared.utils.io.dump_obj_as_json_to_file(featurizer_file, self.to_dict())
471
506
 
472
507
  @staticmethod
473
508
  def load(path: Union[Text, Path]) -> Optional[TrackerFeaturizer]:
@@ -481,7 +516,17 @@ class TrackerFeaturizer:
481
516
  """
482
517
  featurizer_file = Path(path) / FEATURIZER_FILE
483
518
  if featurizer_file.is_file():
484
- return jsonpickle.decode(rasa.shared.utils.io.read_file(featurizer_file))
519
+ data = rasa.shared.utils.io.read_json_file(featurizer_file)
520
+
521
+ if "type" not in data:
522
+ logger.error(
523
+ f"Couldn't load featurizer for policy. "
524
+ f"File '{featurizer_file}' does not contain all "
525
+ f"necessary information. 'type' is missing."
526
+ )
527
+ return None
528
+
529
+ return TrackerFeaturizer.from_dict(data)
485
530
 
486
531
  logger.error(
487
532
  f"Couldn't load featurizer for policy. "
@@ -508,7 +553,16 @@ class TrackerFeaturizer:
508
553
  )
509
554
  ]
510
555
 
556
+ def to_dict(self) -> Dict[str, Any]:
557
+ return {
558
+ "type": self.__class__._featurizer_type,
559
+ "state_featurizer": (
560
+ self.state_featurizer.to_dict() if self.state_featurizer else None
561
+ ),
562
+ }
563
+
511
564
 
565
+ @TrackerFeaturizer.register("FullDialogueTrackerFeaturizer")
512
566
  class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
513
567
  """Creates full dialogue training data for time distributed architectures.
514
568
 
@@ -646,7 +700,20 @@ class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
646
700
 
647
701
  return trackers_as_states
648
702
 
703
+ def to_dict(self) -> Dict[str, Any]:
704
+ return super().to_dict()
649
705
 
706
+ @classmethod
707
+ def create_from_dict(cls, data: Dict[str, Any]) -> "FullDialogueTrackerFeaturizer":
708
+ state_featurizer = SingleStateFeaturizer.create_from_dict(
709
+ data["state_featurizer"]
710
+ )
711
+ return cls(
712
+ state_featurizer,
713
+ )
714
+
715
+
716
+ @TrackerFeaturizer.register("MaxHistoryTrackerFeaturizer")
650
717
  class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
651
718
  """Truncates the tracker history into `max_history` long sequences.
652
719
 
@@ -884,7 +951,25 @@ class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
884
951
 
885
952
  return trackers_as_states
886
953
 
954
+ def to_dict(self) -> Dict[str, Any]:
955
+ data = super().to_dict()
956
+ data.update(
957
+ {
958
+ "remove_duplicates": self.remove_duplicates,
959
+ "max_history": self.max_history,
960
+ }
961
+ )
962
+ return data
963
+
964
+ @classmethod
965
+ def create_from_dict(cls, data: Dict[str, Any]) -> "MaxHistoryTrackerFeaturizer":
966
+ state_featurizer = SingleStateFeaturizer.create_from_dict(
967
+ data["state_featurizer"]
968
+ )
969
+ return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
887
970
 
971
+
972
+ @TrackerFeaturizer.register("IntentMaxHistoryTrackerFeaturizer")
888
973
  class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
889
974
  """Truncates the tracker history into `max_history` long sequences.
890
975
 
@@ -1159,6 +1244,18 @@ class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
1159
1244
 
1160
1245
  return trackers_as_states
1161
1246
 
1247
+ def to_dict(self) -> Dict[str, Any]:
1248
+ return super().to_dict()
1249
+
1250
+ @classmethod
1251
+ def create_from_dict(
1252
+ cls, data: Dict[str, Any]
1253
+ ) -> "IntentMaxHistoryTrackerFeaturizer":
1254
+ state_featurizer = SingleStateFeaturizer.create_from_dict(
1255
+ data["state_featurizer"]
1256
+ )
1257
+ return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
1258
+
1162
1259
 
1163
1260
  def _is_prev_action_unlikely_intent_in_state(state: State) -> bool:
1164
1261
  prev_action_name = state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME)
@@ -1,13 +1,12 @@
1
1
  from typing import Any, Dict, Optional, Text
2
2
 
3
- import os
4
3
  import structlog
5
4
  from jinja2 import Template
6
5
 
7
6
  from rasa import telemetry
8
7
  from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
8
+ from rasa.core.nlg.summarize import summarize_conversation
9
9
  from rasa.shared.constants import (
10
- LLM_API_HEALTH_CHECK_ENV_VAR,
11
10
  LLM_CONFIG_KEY,
12
11
  MODEL_CONFIG_KEY,
13
12
  MODEL_NAME_CONFIG_KEY,
@@ -15,6 +14,7 @@ from rasa.shared.constants import (
15
14
  PROVIDER_CONFIG_KEY,
16
15
  OPENAI_PROVIDER,
17
16
  TIMEOUT_CONFIG_KEY,
17
+ MODEL_GROUP_CONFIG_KEY,
18
18
  )
19
19
  from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
20
20
  from rasa.shared.core.events import BotUttered, UserUttered
@@ -25,17 +25,14 @@ from rasa.shared.utils.llm import (
25
25
  USER,
26
26
  combine_custom_and_default_config,
27
27
  get_prompt_template,
28
- llm_api_health_check,
29
28
  llm_factory,
30
- try_instantiate_llm_client,
29
+ resolve_model_client_config,
31
30
  )
32
- from rasa.utils.endpoints import EndpointConfig
31
+ from rasa.shared.utils.health_check import perform_training_time_llm_health_check
33
32
  from rasa.shared.utils.llm import (
34
33
  tracker_as_readable_transcript,
35
34
  )
36
-
37
- from rasa.core.nlg.summarize import summarize_conversation
38
-
35
+ from rasa.utils.endpoints import EndpointConfig
39
36
  from rasa.utils.log_utils import log_llm
40
37
 
41
38
  structlogger = structlog.get_logger()
@@ -105,18 +102,18 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
105
102
  self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
106
103
  "trace_prompt_tokens", False
107
104
  )
108
- llm_client = try_instantiate_llm_client(
105
+
106
+ self.llm_config = resolve_model_client_config(
109
107
  self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
108
+ ContextualResponseRephraser.__name__,
109
+ )
110
+
111
+ perform_training_time_llm_health_check(
112
+ self.llm_config,
110
113
  DEFAULT_LLM_CONFIG,
111
114
  "contextual_response_rephraser.init",
112
115
  ContextualResponseRephraser.__name__,
113
116
  )
114
- if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
115
- llm_api_health_check(
116
- llm_client,
117
- "contextual_response_rephraser.init",
118
- ContextualResponseRephraser.__name__,
119
- )
120
117
 
121
118
  def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
122
119
  """Returns the latest message from the tracker.
@@ -145,9 +142,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
145
142
  Returns:
146
143
  generated text
147
144
  """
148
- llm = llm_factory(
149
- self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
150
- )
145
+ llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
151
146
 
152
147
  try:
153
148
  llm_response = await llm.acompletion(prompt)
@@ -161,7 +156,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
161
156
  def llm_property(self, prop: str) -> Optional[str]:
162
157
  """Returns a property of the LLM provider."""
163
158
  return combine_custom_and_default_config(
164
- self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
159
+ self.llm_config, DEFAULT_LLM_CONFIG
165
160
  ).get(prop)
166
161
 
167
162
  def custom_prompt_template(self, prompt_template: str) -> Optional[str]:
@@ -194,9 +189,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
194
189
  Returns:
195
190
  The history for the prompt.
196
191
  """
197
- llm = llm_factory(
198
- self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
199
- )
192
+ llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
200
193
  return await summarize_conversation(tracker, llm, max_turns=5)
201
194
 
202
195
  async def rephrase(
@@ -252,6 +245,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
252
245
  llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
253
246
  llm_model=self.llm_property(MODEL_CONFIG_KEY)
254
247
  or self.llm_property(MODEL_NAME_CONFIG_KEY),
248
+ llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
255
249
  )
256
250
  if not (updated_text := await self._generate_llm_response(prompt)):
257
251
  # If the LLM fails to generate a response, we
rasa/core/persistor.py CHANGED
@@ -30,7 +30,7 @@ from rasa.shared.utils.io import raise_warning
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from azure.storage.blob import ContainerClient
33
- import botocore
33
+ from botocore.exceptions import ClientError
34
34
 
35
35
  structlogger = structlog.get_logger()
36
36
 
@@ -233,13 +233,13 @@ class AWSPersistor(Persistor):
233
233
  def _ensure_bucket_exists(
234
234
  self, bucket_name: Text, region_name: Optional[Text] = None
235
235
  ) -> None:
236
- import botocore
236
+ from botocore import exceptions
237
237
 
238
238
  # noinspection PyUnresolvedReferences
239
239
  try:
240
240
  self.s3.meta.client.head_bucket(Bucket=bucket_name)
241
- except botocore.exceptions.ClientError as e:
242
- if self.error_code(e) == HTTP_STATUS_FORBIDDEN:
241
+ except exceptions.ClientError as exc:
242
+ if self._error_code(exc) == HTTP_STATUS_FORBIDDEN:
243
243
  log = (
244
244
  f"Access to the specified bucket '{bucket_name}' is forbidden. "
245
245
  "Please make sure you have the necessary "
@@ -251,7 +251,7 @@ class AWSPersistor(Persistor):
251
251
  event_info=log,
252
252
  )
253
253
  raise RasaException(log)
254
- elif self.error_code(e) == HTTP_STATUS_NOT_FOUND:
254
+ elif self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
255
255
  log = (
256
256
  f"The specified bucket '{bucket_name}' does not exist. "
257
257
  "Please make sure to create the bucket first."
@@ -264,7 +264,7 @@ class AWSPersistor(Persistor):
264
264
  raise RasaException(log)
265
265
 
266
266
  @staticmethod
267
- def error_code(e: botocore.exceptions.ClientError) -> int:
267
+ def _error_code(e: "ClientError") -> int:
268
268
  return int(e.response["Error"]["Code"])
269
269
 
270
270
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
@@ -274,26 +274,48 @@ class AWSPersistor(Persistor):
274
274
 
275
275
  def _retrieve_tar(self, model_path: Text) -> None:
276
276
  """Downloads a model that has previously been persisted to s3."""
277
- import botocore
277
+ from botocore import exceptions
278
278
 
279
279
  target_filename = os.path.basename(model_path)
280
- try:
281
- with open(target_filename, "wb") as f:
282
- self.bucket.download_fileobj(model_path, f)
283
- except botocore.exceptions.ClientError as e:
284
- if self.error_code(e) == HTTP_STATUS_NOT_FOUND:
285
- log = (
286
- f"Model '{target_filename}' not found in the specified bucket "
287
- f"'{self.bucket_name}'. Please make sure the model exists "
288
- f"in the bucket."
289
- )
290
- structlogger.error(
291
- "gcp_persistor.retrieve_tar.model_not_found",
292
- bucket_name=self.bucket_name,
293
- target_filename=target_filename,
294
- event_info=log,
295
- )
296
- raise ModelNotFound() from e
280
+ bucket_objects = list(self.bucket.objects.all())
281
+
282
+ model_found = False
283
+
284
+ log = (
285
+ f"Model '{target_filename}' not found in the specified bucket "
286
+ f"'{self.bucket_name}'. Please make sure the model exists "
287
+ f"in the bucket."
288
+ )
289
+
290
+ for obj in bucket_objects:
291
+ if model_path not in obj.key:
292
+ continue
293
+ structlogger.debug(
294
+ "aws_persistor.retrieve_tar.object_found", object_key=obj.key
295
+ )
296
+
297
+ try:
298
+ with open(target_filename, "wb") as f:
299
+ self.bucket.download_fileobj(obj.key, f)
300
+ model_found = True
301
+ break
302
+ except exceptions.ClientError as exc:
303
+ if self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
304
+ structlogger.error(
305
+ "aws_persistor.retrieve_tar.model_not_found",
306
+ bucket_name=self.bucket_name,
307
+ target_filename=target_filename,
308
+ event_info=log,
309
+ )
310
+ raise ModelNotFound() from exc
311
+ if not model_found:
312
+ structlogger.error(
313
+ "aws_persistor.retrieve_tar.model_not_found",
314
+ bucket_name=self.bucket_name,
315
+ target_filename=target_filename,
316
+ event_info=log,
317
+ )
318
+ raise ModelNotFound()
297
319
 
298
320
 
299
321
  class GCSPersistor(Persistor):
@@ -322,7 +344,7 @@ class GCSPersistor(Persistor):
322
344
 
323
345
  try:
324
346
  self.storage_client.get_bucket(bucket_name)
325
- except auth_exceptions.GoogleAuthError as e:
347
+ except auth_exceptions.GoogleAuthError as exc:
326
348
  log = (
327
349
  f"An error occurred while authenticating with Google Cloud "
328
350
  f"Storage. Please make sure you have the necessary credentials "
@@ -333,8 +355,8 @@ class GCSPersistor(Persistor):
333
355
  bucket_name=bucket_name,
334
356
  event_info=log,
335
357
  )
336
- raise RasaException(log) from e
337
- except exceptions.NotFound as e:
358
+ raise RasaException(log) from exc
359
+ except exceptions.NotFound as exc:
338
360
  log = (
339
361
  f"The specified Google Cloud Storage bucket '{bucket_name}' "
340
362
  f"does not exist. Please make sure to create the bucket first or "
@@ -345,20 +367,20 @@ class GCSPersistor(Persistor):
345
367
  bucket_name=bucket_name,
346
368
  event_info=log,
347
369
  )
348
- raise RasaException(log) from e
349
- except exceptions.Forbidden as e:
370
+ raise RasaException(log) from exc
371
+ except exceptions.Forbidden as exc:
350
372
  log = (
351
373
  f"Access to the specified Google Cloud storage bucket '{bucket_name}' "
352
374
  f"is forbidden. Please make sure you have the necessary "
353
- f"permission to access the bucket. "
375
+ f"permissions to access the bucket. "
354
376
  )
355
377
  structlogger.error(
356
378
  "gcp_persistor.ensure_bucket_exists.bucket_access_forbidden",
357
379
  bucket_name=bucket_name,
358
380
  event_info=log,
359
381
  )
360
- raise RasaException(log) from e
361
- except ValueError as e:
382
+ raise RasaException(log) from exc
383
+ except ValueError as exc:
362
384
  # bucket_name is None
363
385
  log = (
364
386
  "The specified Google Cloud Storage bucket name is None. Please "
@@ -368,7 +390,7 @@ class GCSPersistor(Persistor):
368
390
  "gcp_persistor.ensure_bucket_exists.bucket_name_none",
369
391
  event_info=log,
370
392
  )
371
- raise RasaException(log) from e
393
+ raise RasaException(log) from exc
372
394
 
373
395
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
374
396
  """Uploads a model persisted in the `target_dir` to GCS."""
@@ -382,7 +404,7 @@ class GCSPersistor(Persistor):
382
404
  blob = self.bucket.blob(target_filename)
383
405
  try:
384
406
  blob.download_to_filename(target_filename)
385
- except exceptions.NotFound as e:
407
+ except exceptions.NotFound as exc:
386
408
  log = (
387
409
  f"Model '{target_filename}' not found in the specified bucket "
388
410
  f"'{self.bucket_name}'. Please make sure the model exists "
@@ -394,7 +416,7 @@ class GCSPersistor(Persistor):
394
416
  target_filename=target_filename,
395
417
  event_info=log,
396
418
  )
397
- raise ModelNotFound() from e
419
+ raise ModelNotFound() from exc
398
420
 
399
421
 
400
422
  class AzurePersistor(Persistor):
@@ -440,8 +462,33 @@ class AzurePersistor(Persistor):
440
462
 
441
463
  def _retrieve_tar(self, target_filename: Text) -> None:
442
464
  """Downloads a model that has previously been persisted to Azure."""
443
- blob_client = self._container_client().get_blob_client(target_filename)
465
+ try:
466
+ blob_list = self._container_client().list_blobs()
467
+
468
+ for blob in blob_list:
469
+ if target_filename not in blob.name:
470
+ continue
471
+
472
+ structlogger.debug(
473
+ "azure_persistor.retrieve_tar.blob_found", blob_name=blob.name
474
+ )
444
475
 
445
- with open(target_filename, "wb") as blob:
446
- download_stream = blob_client.download_blob()
447
- blob.write(download_stream.readall())
476
+ with open(target_filename, "wb") as model_file:
477
+ blob_client = self._container_client().get_blob_client(blob.name)
478
+ download_stream = blob_client.download_blob()
479
+ model_file.write(download_stream.readall())
480
+ except Exception as exc:
481
+ log = (
482
+ f"An exception occurred while trying to download "
483
+ f"the model '{target_filename}' in the specified container "
484
+ f"'{self.container_name}'. Please make sure the model exists "
485
+ f"in the container."
486
+ )
487
+ structlogger.error(
488
+ "azure_persistor.retrieve_tar.model_download_error",
489
+ container_name=self.container_name,
490
+ target_filename=target_filename,
491
+ event_info=log,
492
+ exception=exc,
493
+ )
494
+ raise ModelNotFound() from exc