rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic; see the package's page on its registry for more details.

Files changed (184)
  1. rasa/__main__.py +22 -12
  2. rasa/api.py +1 -1
  3. rasa/cli/arguments/default_arguments.py +1 -2
  4. rasa/cli/arguments/shell.py +5 -1
  5. rasa/cli/e2e_test.py +1 -1
  6. rasa/cli/evaluate.py +8 -8
  7. rasa/cli/inspect.py +6 -4
  8. rasa/cli/llm_fine_tuning.py +1 -1
  9. rasa/cli/project_templates/calm/config.yml +5 -7
  10. rasa/cli/project_templates/calm/endpoints.yml +8 -0
  11. rasa/cli/project_templates/tutorial/config.yml +8 -5
  12. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  13. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  14. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  15. rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
  16. rasa/cli/run.py +1 -1
  17. rasa/cli/scaffold.py +4 -2
  18. rasa/cli/studio/studio.py +18 -8
  19. rasa/cli/utils.py +5 -0
  20. rasa/cli/x.py +8 -8
  21. rasa/constants.py +1 -1
  22. rasa/core/actions/action_repeat_bot_messages.py +17 -0
  23. rasa/core/channels/channel.py +20 -0
  24. rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
  25. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  26. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
  27. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  28. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
  29. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
  30. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  31. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  32. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  33. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  34. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  37. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
  39. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  57. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  58. rasa/core/channels/inspector/dist/index.html +1 -1
  59. rasa/core/channels/inspector/src/App.tsx +1 -1
  60. rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
  61. rasa/core/channels/socketio.py +2 -1
  62. rasa/core/channels/telegram.py +1 -1
  63. rasa/core/channels/twilio.py +1 -1
  64. rasa/core/channels/voice_ready/audiocodes.py +12 -0
  65. rasa/core/channels/voice_ready/jambonz.py +15 -4
  66. rasa/core/channels/voice_ready/twilio_voice.py +6 -21
  67. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  68. rasa/core/channels/voice_stream/asr/azure.py +122 -0
  69. rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
  70. rasa/core/channels/voice_stream/audio_bytes.py +1 -0
  71. rasa/core/channels/voice_stream/browser_audio.py +31 -8
  72. rasa/core/channels/voice_stream/call_state.py +23 -0
  73. rasa/core/channels/voice_stream/tts/azure.py +6 -2
  74. rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
  75. rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
  76. rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
  77. rasa/core/channels/voice_stream/util.py +4 -4
  78. rasa/core/channels/voice_stream/voice_channel.py +189 -39
  79. rasa/core/featurizers/single_state_featurizer.py +22 -1
  80. rasa/core/featurizers/tracker_featurizers.py +115 -18
  81. rasa/core/nlg/contextual_response_rephraser.py +32 -30
  82. rasa/core/persistor.py +86 -39
  83. rasa/core/policies/enterprise_search_policy.py +119 -60
  84. rasa/core/policies/flows/flow_executor.py +7 -4
  85. rasa/core/policies/intentless_policy.py +78 -22
  86. rasa/core/policies/ted_policy.py +58 -33
  87. rasa/core/policies/unexpected_intent_policy.py +15 -7
  88. rasa/core/processor.py +25 -0
  89. rasa/core/training/interactive.py +34 -35
  90. rasa/core/utils.py +8 -3
  91. rasa/dialogue_understanding/coexistence/llm_based_router.py +39 -12
  92. rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
  93. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  94. rasa/dialogue_understanding/commands/utils.py +5 -0
  95. rasa/dialogue_understanding/generator/constants.py +2 -0
  96. rasa/dialogue_understanding/generator/flow_retrieval.py +49 -4
  97. rasa/dialogue_understanding/generator/llm_based_command_generator.py +37 -23
  98. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -10
  99. rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
  100. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +71 -11
  101. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
  102. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  103. rasa/dialogue_understanding/processor/command_processor.py +21 -1
  104. rasa/e2e_test/e2e_test_case.py +85 -6
  105. rasa/e2e_test/e2e_test_runner.py +4 -2
  106. rasa/e2e_test/utils/io.py +1 -1
  107. rasa/engine/validation.py +316 -10
  108. rasa/model_manager/config.py +15 -3
  109. rasa/model_manager/model_api.py +15 -7
  110. rasa/model_manager/runner_service.py +8 -6
  111. rasa/model_manager/socket_bridge.py +6 -3
  112. rasa/model_manager/trainer_service.py +7 -5
  113. rasa/model_manager/utils.py +28 -7
  114. rasa/model_service.py +9 -2
  115. rasa/model_training.py +2 -0
  116. rasa/nlu/classifiers/diet_classifier.py +38 -25
  117. rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
  118. rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
  119. rasa/nlu/extractors/crf_entity_extractor.py +93 -50
  120. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
  121. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
  122. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
  123. rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
  124. rasa/server.py +3 -1
  125. rasa/shared/constants.py +36 -3
  126. rasa/shared/core/constants.py +7 -0
  127. rasa/shared/core/domain.py +26 -0
  128. rasa/shared/core/flows/flow.py +5 -0
  129. rasa/shared/core/flows/flows_list.py +5 -1
  130. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  131. rasa/shared/core/flows/utils.py +39 -0
  132. rasa/shared/core/flows/validation.py +96 -0
  133. rasa/shared/core/slots.py +5 -0
  134. rasa/shared/nlu/training_data/features.py +120 -2
  135. rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
  136. rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
  137. rasa/shared/providers/_configs/model_group_config.py +167 -0
  138. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  139. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  140. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  141. rasa/shared/providers/_configs/utils.py +16 -0
  142. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +18 -29
  143. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  144. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +37 -31
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  147. rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
  148. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  149. rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
  150. rasa/shared/providers/mappings.py +19 -0
  151. rasa/shared/providers/router/__init__.py +0 -0
  152. rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
  153. rasa/shared/providers/router/router_client.py +73 -0
  154. rasa/shared/utils/common.py +8 -0
  155. rasa/shared/utils/health_check/__init__.py +0 -0
  156. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  157. rasa/shared/utils/health_check/health_check.py +256 -0
  158. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  159. rasa/shared/utils/io.py +28 -6
  160. rasa/shared/utils/llm.py +353 -46
  161. rasa/shared/utils/yaml.py +111 -73
  162. rasa/studio/auth.py +3 -5
  163. rasa/studio/config.py +13 -4
  164. rasa/studio/constants.py +1 -0
  165. rasa/studio/data_handler.py +10 -3
  166. rasa/studio/upload.py +81 -26
  167. rasa/telemetry.py +92 -17
  168. rasa/tracing/config.py +2 -0
  169. rasa/tracing/instrumentation/attribute_extractors.py +94 -17
  170. rasa/tracing/instrumentation/instrumentation.py +121 -0
  171. rasa/utils/common.py +5 -0
  172. rasa/utils/io.py +7 -81
  173. rasa/utils/log_utils.py +9 -2
  174. rasa/utils/sanic_error_handler.py +32 -0
  175. rasa/utils/tensorflow/feature_array.py +366 -0
  176. rasa/utils/tensorflow/model_data.py +2 -193
  177. rasa/validator.py +70 -0
  178. rasa/version.py +1 -1
  179. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +11 -10
  180. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +183 -163
  181. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
  182. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
  183. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
  184. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,200 @@
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List
4
+
5
+ import structlog
6
+ from rasa.shared.constants import (
7
+ ROUTER_CONFIG_KEY,
8
+ MODELS_CONFIG_KEY,
9
+ MODEL_GROUP_ID_CONFIG_KEY,
10
+ MODEL_NAME_CONFIG_KEY,
11
+ LITELLM_PARAMS_KEY,
12
+ PROVIDER_CONFIG_KEY,
13
+ DEPLOYMENT_CONFIG_KEY,
14
+ API_TYPE_CONFIG_KEY,
15
+ MODEL_CONFIG_KEY,
16
+ MODEL_LIST_KEY,
17
+ )
18
+ from rasa.shared.providers._configs.model_group_config import (
19
+ ModelGroupConfig,
20
+ ModelConfig,
21
+ )
22
+ from rasa.shared.providers.mappings import get_prefix_from_provider
23
+ from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
24
+
25
+
26
structlogger = structlog.get_logger()

# Rasa-specific keys that are consumed while building the LiteLLM model list and
# must not be forwarded to LiteLLM inside `litellm_params`.
_LITELLM_UNSUPPORTED_KEYS = [
    PROVIDER_CONFIG_KEY,
    DEPLOYMENT_CONFIG_KEY,
    API_TYPE_CONFIG_KEY,
]


@dataclass
class LiteLLMRouterClientConfig:
    """Parses configuration for a LiteLLM Router client.

    The configuration is expected to be in the following format:

        {
            "id": "model_group_id",
            "models": [
                {
                    "provider": "provider_name",
                    "model": "model_name",
                    "api_base": "api_base",
                    "api_key": "api_key",
                    "api_version": "api_version",
                },
                {
                    "provider": "provider_name",
                    "model": "model_name",
                },
            ],
            "router": {"router_setting": "value"},
        }

    `router` must be present and non-empty. Any other top-level keys are kept
    as `extra_parameters`.

    This configuration is converted into the LiteLLM required format:

        {
            "id": "model_group_id",
            "model_list": [
                {
                    "model_name": "model_group_id",
                    "litellm_params": {
                        "model": "provider_name/model_name",
                        "api_base": "api_base",
                        "api_key": "api_key",
                        "api_version": "api_version",
                    },
                },
                {
                    "model_name": "model_group_id",
                    "litellm_params": {
                        "model": "provider_name/model_name",
                    },
                },
            ],
            "router": {"router_setting": "value"},
        }

    For providers in `DEPLOYMENT_CENTRIC_PROVIDERS` the `deployment` value is
    used instead of `model` to build the prefixed LiteLLM model name.

    Raises:
        ValueError: If the configuration is missing required keys or the router
            settings are missing/empty.
    """

    _model_group_config: ModelGroupConfig
    router: Dict[str, Any]
    extra_parameters: dict = field(default_factory=dict)

    @property
    def model_group_id(self) -> str:
        """ID of the model group; also used as LiteLLM `model_name`."""
        return self._model_group_config.model_group_id

    @property
    def models(self) -> List[ModelConfig]:
        """Model configurations that belong to this model group."""
        return self._model_group_config.models

    @property
    def litellm_model_list(self) -> List[Dict[str, Any]]:
        """The models converted to LiteLLM's `model_list` (rebuilt on access)."""
        return self._convert_models_to_litellm_model_list()

    def __post_init__(self) -> None:
        # `from_dict` pops the router settings with a default of None, so a
        # missing `router` key is rejected here together with an empty dict.
        if not self.router:
            message = "Router cannot be empty."
            structlogger.error(
                "litellm_router_client_config.validation_error",
                message=message,
                model_group_id=self._model_group_config.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize. It is not
                mutated.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            LiteLLMRouterClientConfig
        """
        model_group_config = ModelGroupConfig.from_dict(config)

        # Work on a copy so the caller's config is not mutated.
        remaining = copy.deepcopy(config)
        # Drop the keys already consumed by ModelGroupConfig.
        remaining.pop(MODEL_GROUP_ID_CONFIG_KEY, None)
        remaining.pop(MODELS_CONFIG_KEY, None)
        router_settings = remaining.pop(ROUTER_CONFIG_KEY, None)

        # Whatever is left over is treated as extra parameters.
        return cls(
            _model_group_config=model_group_config,
            router=router_settings,
            extra_parameters=remaining,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        d = self._model_group_config.to_dict()
        d[ROUTER_CONFIG_KEY] = self.router
        if self.extra_parameters:
            d.update(self.extra_parameters)
        return d

    def to_litellm_dict(self) -> dict:
        """Builds the LiteLLM representation of this model group.

        Extra parameters are spread first, so the group id, the model list and
        the router settings always take precedence over same-named extras.
        """
        litellm_model_list = self._convert_models_to_litellm_model_list()
        d = {
            **self.extra_parameters,
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODEL_LIST_KEY: litellm_model_list,
            ROUTER_CONFIG_KEY: self.router,
        }
        return d

    def _convert_models_to_litellm_model_list(self) -> List[Dict[str, Any]]:
        """Converts every model config into a LiteLLM `model_list` entry.

        Raises:
            ValueError: If a model lacks the `deployment` value (deployment-centric
                providers) or the `model` value (all other providers).
        """
        litellm_model_list = []

        for model_config_object in self.models:
            # Dict representation; keys whose value is None are already dropped.
            litellm_model_config = model_config_object.to_dict()

            provider = litellm_model_config[PROVIDER_CONFIG_KEY]

            # Get the litellm prefixing for the provider
            prefix = get_prefix_from_provider(provider)

            # Determine whether to use model or deployment key based on the provider.
            name_key = (
                DEPLOYMENT_CONFIG_KEY
                if provider in DEPLOYMENT_CENTRIC_PROVIDERS
                else MODEL_CONFIG_KEY
            )
            model_name = litellm_model_config.get(name_key)
            if not model_name:
                # Report the documented ValueError instead of a bare KeyError.
                message = (
                    f"Missing '{name_key}' for a model with provider "
                    f"'{provider}' in model group '{self.model_group_id}'."
                )
                structlogger.error(
                    "litellm_router_client_config.validation_error",
                    message=message,
                    model_group_id=self.model_group_id,
                )
                raise ValueError(message)

            # Set 'model' to a provider prefixed model name e.g. openai/gpt-4,
            # without prefixing names that already carry the prefix.
            if model_name.startswith(f"{prefix}/"):
                litellm_model_config[MODEL_CONFIG_KEY] = model_name
            else:
                litellm_model_config[MODEL_CONFIG_KEY] = f"{prefix}/{model_name}"

            # Remove parameters that are None and not supported by LiteLLM.
            litellm_model_config = {
                key: value
                for key, value in litellm_model_config.items()
                if key not in _LITELLM_UNSUPPORTED_KEYS and value is not None
            }

            litellm_model_list.append(
                {
                    MODEL_NAME_CONFIG_KEY: self.model_group_id,
                    LITELLM_PARAMS_KEY: litellm_model_config,
                }
            )

        return litellm_model_list
@@ -0,0 +1,167 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import List, Optional
3
+
4
+ import structlog
5
+ from rasa.shared.constants import (
6
+ API_BASE_CONFIG_KEY,
7
+ API_KEY,
8
+ API_TYPE_CONFIG_KEY,
9
+ API_VERSION_CONFIG_KEY,
10
+ DEPLOYMENT_CONFIG_KEY,
11
+ PROVIDER_CONFIG_KEY,
12
+ MODEL_CONFIG_KEY,
13
+ MODEL_GROUP_ID_CONFIG_KEY,
14
+ MODELS_CONFIG_KEY,
15
+ MODEL_GROUPS_CONFIG_KEY,
16
+ EXTRA_PARAMETERS_KEY,
17
+ )
18
+ from rasa.shared.providers.mappings import get_client_config_class_from_provider
19
+
20
structlogger = structlog.get_logger()


@dataclass
class ModelConfig:
    """Parses the config of a single model inside a model group.

    Raises:
        ValueError: If the provider config key is missing in the config.
    """

    provider: str
    model: Optional[str] = None
    deployment: Optional[str] = None
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    api_version: Optional[str] = None
    # Remaining settings returned by the provider's client config class;
    # `to_dict` flattens them back onto the top level.
    extra_parameters: dict = field(default_factory=dict)
    # Retained for backward compatibility with older configurations,
    # but intentionally not included in extra_parameters
    api_type: Optional[str] = None

    @classmethod
    def from_dict(cls, config: dict) -> "ModelConfig":
        """Initializes a dataclass from the passed config.

        The provider config param is used to determine the client config class
        to use. The client config class takes care of resolving config aliases
        and throwing deprecation warnings.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            ModelConfig
        """
        # Local import; presumably avoids a circular import with
        # `rasa.shared.utils.llm` — TODO confirm.
        from rasa.shared.utils.llm import get_provider_from_config

        # Get the provider from the config; this also infers the provider from
        # deprecated configurations.
        provider = get_provider_from_config(config)

        # Round-trip through the provider's client config class so deprecated
        # aliases are resolved and deprecation warnings are emitted.
        client_config_clazz = get_client_config_class_from_provider(provider)
        client_config = client_config_clazz.from_dict(config).to_dict()

        # Check for provider after resolving all aliases
        if PROVIDER_CONFIG_KEY not in client_config:
            raise ValueError(
                f"Missing required key '{PROVIDER_CONFIG_KEY}' in "
                f"'{MODELS_CONFIG_KEY}' config."
            )

        # Keyword arguments are evaluated left to right, so whatever is left in
        # `client_config` after the pops becomes `extra_parameters`.
        return cls(
            provider=client_config.pop(PROVIDER_CONFIG_KEY, None),
            model=client_config.pop(MODEL_CONFIG_KEY, None),
            deployment=client_config.pop(DEPLOYMENT_CONFIG_KEY, None),
            api_type=client_config.pop(API_TYPE_CONFIG_KEY, None),
            api_base=client_config.pop(API_BASE_CONFIG_KEY, None),
            api_key=client_config.pop(API_KEY, None),
            api_version=client_config.pop(API_VERSION_CONFIG_KEY, None),
            extra_parameters=client_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        Extra parameters are merged onto the top level, and keys whose value
        is None are omitted.
        """
        d = asdict(self)

        # Extra parameters should also be on the top level
        d.pop(EXTRA_PARAMETERS_KEY, None)
        d.update(self.extra_parameters)

        # Remove keys with None values
        return {key: value for key, value in d.items() if value is not None}
101
+
102
+
103
@dataclass
class ModelGroupConfig:
    """Parses the models config. The models config is a list of model configs.

    Raises:
        ValueError: If the model group ID is None or if the models list is empty.
    """

    model_group_id: str
    models: List[ModelConfig]

    def __post_init__(self) -> None:
        if self.model_group_id is None:
            message = "Model group ID cannot be set to None."
            structlogger.error(
                "model_group_config.validation_error",
                message=message,
                model_group_id=self.model_group_id,
            )
            raise ValueError(message)
        if not self.models:
            message = "Models cannot be empty."
            structlogger.error(
                "model_group_config.validation_error",
                message=message,
                model_group_id=self.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "ModelGroupConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, the models list is
                empty or null, or the model group ID is missing.

        Returns:
            ModelGroupConfig
        """
        if MODELS_CONFIG_KEY not in config:
            raise ValueError(
                f"Missing required key '{MODELS_CONFIG_KEY}' in "
                f"'{MODEL_GROUPS_CONFIG_KEY}' config."
            )

        # An explicit null `models` value is treated as an empty list. That way
        # `__post_init__` reports it as a ValueError instead of a TypeError being
        # raised while iterating over None.
        raw_models = config[MODELS_CONFIG_KEY] or []
        models_config = [ModelConfig.from_dict(model) for model in raw_models]

        return cls(
            model_group_id=config.get(MODEL_GROUP_ID_CONFIG_KEY),
            models=models_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        d = {
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODELS_CONFIG_KEY: [model.to_dict() for model in self.models],
        }
        return d
@@ -19,8 +19,8 @@ from rasa.shared.constants import (
19
19
  REQUEST_TIMEOUT_CONFIG_KEY,
20
20
  TIMEOUT_CONFIG_KEY,
21
21
  PROVIDER_CONFIG_KEY,
22
- OPENAI_PROVIDER,
23
22
  OPENAI_API_TYPE,
23
+ OPENAI_PROVIDER,
24
24
  )
25
25
  from rasa.shared.providers._configs.utils import (
26
26
  resolve_aliases,
@@ -0,0 +1,73 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ RASA_PROVIDER,
9
+ PROVIDER_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ )
12
+ from rasa.shared.providers._configs.utils import (
13
+ validate_required_keys,
14
+ )
15
+
16
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY, API_BASE_CONFIG_KEY]

structlogger = structlog.get_logger()


@dataclass
class RasaLLMClientConfig:
    """Parses configuration for a Rasa Hosted LiteLLM client.

    Checks that the required keys are present.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
    """

    model: Optional[str]
    api_base: Optional[str]
    # Provider is not used by LiteLLM backend, but we define it here since it's
    # used as switch between different clients.
    provider: str = RASA_PROVIDER

    # Every config key that is not listed in REQUIRED_KEYS.
    extra_parameters: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
        """Initializes a dataclass from the passed config.

        All keys not listed in `REQUIRED_KEYS` are stored in `extra_parameters`.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Raised in cases of invalid configuration:
                - If any of the required configuration keys are missing.

        Returns:
            RasaLLMClientConfig
        """
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)

        extra_parameters = {k: v for k, v in config.items() if k not in REQUIRED_KEYS}

        return cls(
            model=config.get(MODEL_CONFIG_KEY),
            api_base=config.get(API_BASE_CONFIG_KEY),
            provider=config.get(PROVIDER_CONFIG_KEY, RASA_PROVIDER),
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d
@@ -23,6 +23,7 @@ from rasa.shared.constants import (
23
23
  SELF_HOSTED_PROVIDER,
24
24
  USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
25
25
  )
26
+
26
27
  from rasa.shared.providers._configs.utils import (
27
28
  raise_deprecation_warnings,
28
29
  resolve_aliases,
@@ -99,3 +99,19 @@ def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
99
99
  config=config,
100
100
  )
101
101
  raise ValueError(message)
102
+
103
+
104
def get_provider_prefixed_model_name(provider: str, model: str) -> str:
    """Returns the model name with the provider prefixed.

    The prefix is only added when the model name does not already start with
    `"<provider>/"`. A prefix test is used rather than a substring test, so
    names that merely contain `"<provider>/"` further in, such as
    `"my-openai/gpt"` for provider `"openai"`, are still prefixed.

    Args:
        provider: The provider of the model.
        model: The model name. Falsy values (e.g. None or "") are returned
            unchanged.

    Returns:
        The model name with the provider prefixed.
    """
    if not model:
        return model
    prefix = f"{provider}/"
    if model.startswith(prefix):
        return model
    return f"{prefix}{model}"
@@ -1,10 +1,12 @@
1
+ import logging
1
2
  from abc import abstractmethod
2
3
  from typing import Any, Dict, List
3
4
 
4
5
  import litellm
5
- import logging
6
6
  import structlog
7
7
  from litellm import aembedding, embedding, validate_environment
8
+
9
+ from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
8
10
  from rasa.shared.exceptions import (
9
11
  ProviderClientAPIException,
10
12
  ProviderClientValidationError,
@@ -17,7 +19,7 @@ from rasa.shared.providers.embedding.embedding_response import (
17
19
  EmbeddingResponse,
18
20
  EmbeddingUsage,
19
21
  )
20
- from rasa.shared.utils.io import suppress_logs
22
+ from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
21
23
 
22
24
  structlogger = structlog.get_logger()
23
25
 
@@ -25,8 +27,7 @@ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
25
27
 
26
28
 
27
29
  class _BaseLiteLLMEmbeddingClient:
28
- """
29
- An abstract base class for LiteLLM embedding clients.
30
+ """An abstract base class for LiteLLM embedding clients.
30
31
 
31
32
  This class defines the interface and common functionality for all clients
32
33
  based on LiteLLM.
@@ -81,11 +82,14 @@ class _BaseLiteLLMEmbeddingClient:
81
82
  ProviderClientValidationError if validation fails.
82
83
  """
83
84
  self._validate_environment_variables()
84
- self._validate_api_key_not_in_config()
85
85
 
86
86
  def _validate_environment_variables(self) -> None:
87
87
  """Validate that the required environment variables are set."""
88
- validation_info = validate_environment(self._litellm_model_name)
88
+ validation_info = validate_environment(
89
+ self._litellm_model_name,
90
+ api_key=self._litellm_extra_parameters.get(API_KEY),
91
+ api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
92
+ )
89
93
  if missing_environment_variables := validation_info.get(
90
94
  _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
91
95
  ):
@@ -100,21 +104,8 @@ class _BaseLiteLLMEmbeddingClient:
100
104
  )
101
105
  raise ProviderClientValidationError(event_info)
102
106
 
103
- def _validate_api_key_not_in_config(self) -> None:
104
- if "api_key" in self._litellm_extra_parameters:
105
- event_info = (
106
- "API Key is set through `api_key` extra parameter."
107
- "Set API keys through environment variables."
108
- )
109
- structlogger.error(
110
- "base_litellm_client.validate_api_key_not_in_config",
111
- event_info=event_info,
112
- )
113
- raise ProviderClientValidationError(event_info)
114
-
115
107
  def validate_documents(self, documents: List[str]) -> None:
116
- """
117
- Validates a list of documents to ensure they are suitable for embedding.
108
+ """Validates a list of documents to ensure they are suitable for embedding.
118
109
 
119
110
  Args:
120
111
  documents: List of documents to be validated.
@@ -130,8 +121,7 @@ class _BaseLiteLLMEmbeddingClient:
130
121
 
131
122
  @suppress_logs(log_level=logging.WARNING)
132
123
  def embed(self, documents: List[str]) -> EmbeddingResponse:
133
- """
134
- Embeds a list of documents synchronously.
124
+ """Embeds a list of documents synchronously.
135
125
 
136
126
  Args:
137
127
  documents: List of documents to be embedded.
@@ -144,7 +134,8 @@ class _BaseLiteLLMEmbeddingClient:
144
134
  """
145
135
  self.validate_documents(documents)
146
136
  try:
147
- response = embedding(input=documents, **self._embedding_fn_args)
137
+ arguments = resolve_environment_variables(self._embedding_fn_args)
138
+ response = embedding(input=documents, **arguments)
148
139
  return self._format_response(response)
149
140
  except Exception as e:
150
141
  raise ProviderClientAPIException(
@@ -153,8 +144,7 @@ class _BaseLiteLLMEmbeddingClient:
153
144
 
154
145
  @suppress_logs(log_level=logging.WARNING)
155
146
  async def aembed(self, documents: List[str]) -> EmbeddingResponse:
156
- """
157
- Embeds a list of documents asynchronously.
147
+ """Embeds a list of documents asynchronously.
158
148
 
159
149
  Args:
160
150
  documents: List of documents to be embedded.
@@ -167,7 +157,8 @@ class _BaseLiteLLMEmbeddingClient:
167
157
  """
168
158
  self.validate_documents(documents)
169
159
  try:
170
- response = await aembedding(input=documents, **self._embedding_fn_args)
160
+ arguments = resolve_environment_variables(self._embedding_fn_args)
161
+ response = await aembedding(input=documents, **arguments)
171
162
  return self._format_response(response)
172
163
  except Exception as e:
173
164
  raise ProviderClientAPIException(
@@ -182,7 +173,6 @@ class _BaseLiteLLMEmbeddingClient:
182
173
  Raises:
183
174
  ValueError: If any response data is None.
184
175
  """
185
-
186
176
  # If data is not available (None), raise a ValueError
187
177
  if response.data is None:
188
178
  message = (
@@ -239,8 +229,7 @@ class _BaseLiteLLMEmbeddingClient:
239
229
 
240
230
  @staticmethod
241
231
  def _ensure_certificates() -> None:
242
- """
243
- Configures SSL certificates for LiteLLM. This method is invoked during
232
+ """Configures SSL certificates for LiteLLM. This method is invoked during
244
233
  client initialization.
245
234
 
246
235
  LiteLLM may utilize `openai` clients or other providers that require