rasa-pro 3.10.16__py3-none-any.whl → 3.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic; see the package registry page for more details.

Files changed (240)
  1. rasa/__main__.py +31 -15
  2. rasa/api.py +12 -2
  3. rasa/cli/arguments/default_arguments.py +24 -4
  4. rasa/cli/arguments/run.py +15 -0
  5. rasa/cli/arguments/shell.py +5 -1
  6. rasa/cli/arguments/train.py +17 -9
  7. rasa/cli/evaluate.py +7 -7
  8. rasa/cli/inspect.py +19 -7
  9. rasa/cli/interactive.py +1 -0
  10. rasa/cli/llm_fine_tuning.py +11 -14
  11. rasa/cli/project_templates/calm/config.yml +5 -7
  12. rasa/cli/project_templates/calm/endpoints.yml +15 -2
  13. rasa/cli/project_templates/tutorial/config.yml +8 -5
  14. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  15. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  16. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  17. rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
  18. rasa/cli/run.py +7 -0
  19. rasa/cli/scaffold.py +4 -2
  20. rasa/cli/studio/upload.py +0 -15
  21. rasa/cli/train.py +14 -53
  22. rasa/cli/utils.py +14 -11
  23. rasa/cli/x.py +7 -7
  24. rasa/constants.py +3 -1
  25. rasa/core/actions/action.py +77 -33
  26. rasa/core/actions/action_hangup.py +29 -0
  27. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
  29. rasa/core/actions/http_custom_action_executor.py +4 -0
  30. rasa/core/agent.py +2 -2
  31. rasa/core/brokers/kafka.py +3 -1
  32. rasa/core/brokers/pika.py +3 -1
  33. rasa/core/channels/__init__.py +10 -6
  34. rasa/core/channels/channel.py +41 -4
  35. rasa/core/channels/development_inspector.py +150 -46
  36. rasa/core/channels/inspector/README.md +1 -1
  37. rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  47. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
  52. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
  57. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  58. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  59. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  60. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  61. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  71. rasa/core/channels/inspector/dist/index.html +18 -17
  72. rasa/core/channels/inspector/index.html +17 -16
  73. rasa/core/channels/inspector/package.json +5 -1
  74. rasa/core/channels/inspector/src/App.tsx +118 -68
  75. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  76. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
  77. rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
  78. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
  79. rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
  80. rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
  81. rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
  82. rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
  83. rasa/core/channels/inspector/src/types.ts +21 -1
  84. rasa/core/channels/inspector/yarn.lock +94 -1
  85. rasa/core/channels/rest.py +51 -46
  86. rasa/core/channels/socketio.py +28 -1
  87. rasa/core/channels/telegram.py +1 -1
  88. rasa/core/channels/twilio.py +1 -1
  89. rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
  90. rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
  91. rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
  92. rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
  93. rasa/core/channels/voice_ready/utils.py +37 -0
  94. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  95. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  96. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  97. rasa/core/channels/voice_stream/asr/azure.py +129 -0
  98. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  99. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  100. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  101. rasa/core/channels/voice_stream/call_state.py +23 -0
  102. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  103. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  104. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  105. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  106. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  107. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  108. rasa/core/channels/voice_stream/util.py +57 -0
  109. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  110. rasa/core/information_retrieval/qdrant.py +1 -0
  111. rasa/core/nlg/contextual_response_rephraser.py +45 -17
  112. rasa/{nlu → core}/persistor.py +203 -68
  113. rasa/core/policies/enterprise_search_policy.py +119 -63
  114. rasa/core/policies/flows/flow_executor.py +15 -22
  115. rasa/core/policies/intentless_policy.py +83 -28
  116. rasa/core/processor.py +25 -0
  117. rasa/core/run.py +12 -2
  118. rasa/core/secrets_manager/constants.py +4 -0
  119. rasa/core/secrets_manager/factory.py +8 -0
  120. rasa/core/secrets_manager/vault.py +11 -1
  121. rasa/core/training/interactive.py +33 -34
  122. rasa/core/utils.py +47 -21
  123. rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
  124. rasa/dialogue_understanding/commands/__init__.py +6 -0
  125. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  126. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  127. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  128. rasa/dialogue_understanding/commands/utils.py +5 -0
  129. rasa/dialogue_understanding/generator/constants.py +2 -0
  130. rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
  131. rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
  132. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
  133. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
  134. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
  135. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
  136. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
  137. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  138. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  139. rasa/dialogue_understanding/processor/command_processor.py +21 -1
  140. rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
  141. rasa/e2e_test/assertions.py +136 -61
  142. rasa/e2e_test/assertions_schema.yml +23 -0
  143. rasa/e2e_test/e2e_test_case.py +85 -6
  144. rasa/e2e_test/e2e_test_runner.py +2 -3
  145. rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
  146. rasa/engine/graph.py +3 -10
  147. rasa/engine/loader.py +12 -0
  148. rasa/engine/recipes/config_files/default_config.yml +0 -3
  149. rasa/engine/recipes/default_recipe.py +0 -1
  150. rasa/engine/recipes/graph_recipe.py +0 -1
  151. rasa/engine/runner/dask.py +2 -2
  152. rasa/engine/storage/local_model_storage.py +12 -42
  153. rasa/engine/storage/storage.py +1 -5
  154. rasa/engine/validation.py +527 -74
  155. rasa/model_manager/__init__.py +0 -0
  156. rasa/model_manager/config.py +40 -0
  157. rasa/model_manager/model_api.py +559 -0
  158. rasa/model_manager/runner_service.py +286 -0
  159. rasa/model_manager/socket_bridge.py +146 -0
  160. rasa/model_manager/studio_jwt_auth.py +86 -0
  161. rasa/model_manager/trainer_service.py +325 -0
  162. rasa/model_manager/utils.py +87 -0
  163. rasa/model_manager/warm_rasa_process.py +187 -0
  164. rasa/model_service.py +112 -0
  165. rasa/model_training.py +42 -23
  166. rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
  167. rasa/server.py +4 -2
  168. rasa/shared/constants.py +60 -8
  169. rasa/shared/core/constants.py +13 -0
  170. rasa/shared/core/domain.py +107 -50
  171. rasa/shared/core/events.py +29 -0
  172. rasa/shared/core/flows/flow.py +5 -0
  173. rasa/shared/core/flows/flows_list.py +19 -6
  174. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  175. rasa/shared/core/flows/utils.py +39 -0
  176. rasa/shared/core/flows/validation.py +121 -0
  177. rasa/shared/core/flows/yaml_flows_io.py +15 -27
  178. rasa/shared/core/slots.py +5 -0
  179. rasa/shared/importers/importer.py +59 -41
  180. rasa/shared/importers/multi_project.py +23 -11
  181. rasa/shared/importers/rasa.py +12 -3
  182. rasa/shared/importers/remote_importer.py +196 -0
  183. rasa/shared/importers/utils.py +3 -1
  184. rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
  185. rasa/shared/nlu/training_data/training_data.py +18 -19
  186. rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
  187. rasa/shared/providers/_configs/model_group_config.py +167 -0
  188. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  189. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  190. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  191. rasa/shared/providers/_configs/utils.py +16 -0
  192. rasa/shared/providers/_utils.py +79 -0
  193. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
  194. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  195. rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
  196. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  197. rasa/shared/providers/llm/_base_litellm_client.py +34 -22
  198. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  199. rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
  200. rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
  201. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  202. rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
  203. rasa/shared/providers/mappings.py +19 -0
  204. rasa/shared/providers/router/__init__.py +0 -0
  205. rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
  206. rasa/shared/providers/router/router_client.py +73 -0
  207. rasa/shared/utils/common.py +40 -24
  208. rasa/shared/utils/health_check/__init__.py +0 -0
  209. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  210. rasa/shared/utils/health_check/health_check.py +258 -0
  211. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  212. rasa/shared/utils/io.py +27 -6
  213. rasa/shared/utils/llm.py +354 -44
  214. rasa/shared/utils/schemas/events.py +2 -0
  215. rasa/shared/utils/schemas/model_config.yml +0 -10
  216. rasa/shared/utils/yaml.py +181 -38
  217. rasa/studio/data_handler.py +3 -1
  218. rasa/studio/upload.py +160 -74
  219. rasa/telemetry.py +94 -17
  220. rasa/tracing/config.py +3 -1
  221. rasa/tracing/instrumentation/attribute_extractors.py +95 -18
  222. rasa/tracing/instrumentation/instrumentation.py +121 -0
  223. rasa/utils/common.py +5 -0
  224. rasa/utils/endpoints.py +27 -1
  225. rasa/utils/io.py +8 -16
  226. rasa/utils/log_utils.py +9 -2
  227. rasa/utils/sanic_error_handler.py +32 -0
  228. rasa/validator.py +110 -16
  229. rasa/version.py +1 -1
  230. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +16 -14
  231. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +236 -185
  232. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
  233. rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
  234. rasa/core/channels/voice_aware/utils.py +0 -20
  235. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
  236. /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
  237. /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
  238. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
  239. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
  240. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,167 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import List, Optional
3
+
4
+ import structlog
5
+ from rasa.shared.constants import (
6
+ API_BASE_CONFIG_KEY,
7
+ API_KEY,
8
+ API_TYPE_CONFIG_KEY,
9
+ API_VERSION_CONFIG_KEY,
10
+ DEPLOYMENT_CONFIG_KEY,
11
+ PROVIDER_CONFIG_KEY,
12
+ MODEL_CONFIG_KEY,
13
+ MODEL_GROUP_ID_CONFIG_KEY,
14
+ MODELS_CONFIG_KEY,
15
+ MODEL_GROUPS_CONFIG_KEY,
16
+ EXTRA_PARAMETERS_KEY,
17
+ )
18
+ from rasa.shared.providers.mappings import get_client_config_class_from_provider
19
+
20
# Module-level structured logger; `ModelGroupConfig` uses it to log validation
# errors before raising them.
structlogger = structlog.get_logger()
21
+
22
+
23
@dataclass
class ModelConfig:
    """Parses the model config.

    Raises:
        ValueError: If the provider config key is missing in the config.
    """

    provider: str
    model: Optional[str] = None
    deployment: Optional[str] = None
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    api_version: Optional[str] = None
    extra_parameters: dict = field(default_factory=dict)
    # Retained for backward compatibility with older configurations,
    # but intentionally not included in extra_parameters
    api_type: Optional[str] = None

    @classmethod
    def from_dict(cls, config: dict) -> "ModelConfig":
        """Initializes a dataclass from the passed config.

        The provider config param is used to determine the client config class
        to use. The client config class takes care of resolving config aliases
        and throwing deprecation warnings.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            An instance of `cls` (``ModelConfig`` or a subclass).
        """
        # Imported locally, as in the original implementation.
        from rasa.shared.utils.llm import get_provider_from_config

        # Get the provider from config; this also infers the provider from
        # deprecated configurations.
        provider = get_provider_from_config(config)

        # Retrieve the client configuration class for the specified provider.
        client_config_clazz = get_client_config_class_from_provider(provider)

        # Instantiate the config object so that deprecated aliases are resolved
        # and deprecation warnings are emitted.
        client_config_obj = client_config_clazz.from_dict(config)

        # Convert back to a dictionary; the known keys are popped below and
        # whatever remains becomes `extra_parameters`.
        client_config = client_config_obj.to_dict()

        # Check for provider after resolving all aliases
        if PROVIDER_CONFIG_KEY not in client_config:
            raise ValueError(
                f"Missing required key '{PROVIDER_CONFIG_KEY}' in "
                f"'{MODELS_CONFIG_KEY}' config."
            )

        # Use `cls` so subclasses are constructed as themselves.
        return cls(
            provider=client_config.pop(PROVIDER_CONFIG_KEY, None),
            model=client_config.pop(MODEL_CONFIG_KEY, None),
            deployment=client_config.pop(DEPLOYMENT_CONFIG_KEY, None),
            api_type=client_config.pop(API_TYPE_CONFIG_KEY, None),
            api_base=client_config.pop(API_BASE_CONFIG_KEY, None),
            api_key=client_config.pop(API_KEY, None),
            api_version=client_config.pop(API_VERSION_CONFIG_KEY, None),
            extra_parameters=client_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        Entries of `extra_parameters` are merged into the top level (overriding
        same-named fields), and every key whose value is None is dropped.

        Returns:
            A new dictionary that does not share mutable values with this
            instance.
        """
        d = asdict(self)

        # `asdict` nests the extras under the dataclass field name and
        # deep-copies them. Merge that copy into the top level, so that mutating
        # the result cannot leak back into `self.extra_parameters`.
        extras = d.pop("extra_parameters", None) or {}
        d.update(extras)

        # Remove keys with None values
        return {key: value for key, value in d.items() if value is not None}
101
+
102
+
103
@dataclass
class ModelGroupConfig:
    """Parses the models config. The models config is a list of model configs.

    Raises:
        ValueError: If the model group ID is None or if the models list is empty.
    """

    model_group_id: str
    models: List[ModelConfig]

    def __post_init__(self) -> None:
        if self.model_group_id is None:
            self._raise_validation_error("Model group ID cannot be set to None.")
        if not self.models:
            self._raise_validation_error("Models cannot be empty.")

    def _raise_validation_error(self, message: str) -> None:
        """Logs a validation error for this model group and raises it."""
        structlogger.error(
            "model_group_config.validation_error",
            message=message,
            model_group_id=self.model_group_id,
        )
        raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "ModelGroupConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, the models entry is not
                a list, or the resulting group fails validation.

        Returns:
            ModelGroupConfig
        """
        if MODELS_CONFIG_KEY not in config:
            raise ValueError(
                f"Missing required key '{MODELS_CONFIG_KEY}' in "
                f"'{MODEL_GROUPS_CONFIG_KEY}' config."
            )

        raw_models = config[MODELS_CONFIG_KEY]
        # An empty YAML entry (`models:`) loads as None. Treat it as an empty
        # list so `__post_init__` reports a clear ValueError instead of a
        # TypeError from iterating None.
        if raw_models is None:
            raw_models = []
        elif not isinstance(raw_models, (list, tuple)):
            raise ValueError(
                f"'{MODELS_CONFIG_KEY}' in '{MODEL_GROUPS_CONFIG_KEY}' config "
                f"must be a list of model configs."
            )

        models_config = [
            ModelConfig.from_dict(model_config) for model_config in raw_models
        ]

        return cls(
            model_group_id=config.get(MODEL_GROUP_ID_CONFIG_KEY),
            models=models_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        d = {
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODELS_CONFIG_KEY: [model.to_dict() for model in self.models],
        }
        return d
@@ -19,8 +19,8 @@ from rasa.shared.constants import (
19
19
  REQUEST_TIMEOUT_CONFIG_KEY,
20
20
  TIMEOUT_CONFIG_KEY,
21
21
  PROVIDER_CONFIG_KEY,
22
- OPENAI_PROVIDER,
23
22
  OPENAI_API_TYPE,
23
+ OPENAI_PROVIDER,
24
24
  )
25
25
  from rasa.shared.providers._configs.utils import (
26
26
  resolve_aliases,
@@ -0,0 +1,73 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ RASA_PROVIDER,
9
+ PROVIDER_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ )
12
+ from rasa.shared.providers._configs.utils import (
13
+ validate_required_keys,
14
+ )
15
+
16
# Keys that must be present in a Rasa-hosted LLM client config. They are checked
# with `validate_required_keys` and excluded from `extra_parameters` when
# parsing `RasaLLMClientConfig`.
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY, API_BASE_CONFIG_KEY]

# NOTE(review): not referenced by the code visible in this module.
structlogger = structlog.get_logger()
19
+
20
+
21
@dataclass
class RasaLLMClientConfig:
    """Parses configuration for a Rasa-hosted LiteLLM client and checks that
    the required keys are present.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
    """

    model: Optional[str]
    api_base: Optional[str]
    # Provider is not used by LiteLLM backend, but we define it here since it's
    # used as switch between different clients.
    provider: str = RASA_PROVIDER

    extra_parameters: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
        """Initializes a dataclass from the passed config.

        Every key not listed in `REQUIRED_KEYS` is collected into
        `extra_parameters`.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Raised in cases of invalid configuration:
                - If any of the required configuration keys are missing.

        Returns:
            RasaLLMClientConfig
        """
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)

        extra_parameters = {k: v for k, v in config.items() if k not in REQUIRED_KEYS}

        return cls(
            model=config.get(MODEL_CONFIG_KEY),
            api_base=config.get(API_BASE_CONFIG_KEY),
            provider=config.get(PROVIDER_CONFIG_KEY, RASA_PROVIDER),
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        Entries of `extra_parameters` are merged into the top level. The result
        does not share mutable values with this instance.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level. `asdict` has
        # already deep-copied them, so merge that copy rather than
        # `self.extra_parameters`, which would alias nested values.
        extras = d.pop("extra_parameters", None) or {}
        d.update(extras)
        return d
@@ -23,6 +23,7 @@ from rasa.shared.constants import (
23
23
  SELF_HOSTED_PROVIDER,
24
24
  USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
25
25
  )
26
+
26
27
  from rasa.shared.providers._configs.utils import (
27
28
  raise_deprecation_warnings,
28
29
  resolve_aliases,
@@ -99,3 +99,19 @@ def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
99
99
  config=config,
100
100
  )
101
101
  raise ValueError(message)
102
+
103
+
104
+ def get_provider_prefixed_model_name(provider: str, model: str) -> str:
105
+ """
106
+ Returns the model name with the provider prefixed.
107
+
108
+ Args:
109
+ provider: The provider of the model.
110
+ model: The model name.
111
+
112
+ Returns:
113
+ The model name with the provider prefixed.
114
+ """
115
+ if model and f"{provider}/" not in model:
116
+ return f"{provider}/{model}"
117
+ return model
@@ -0,0 +1,79 @@
1
+ import structlog
2
+
3
+ from rasa.shared.constants import (
4
+ AWS_ACCESS_KEY_ID_ENV_VAR,
5
+ AWS_ACCESS_KEY_ID_CONFIG_KEY,
6
+ AWS_SECRET_ACCESS_KEY_ENV_VAR,
7
+ AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
8
+ AWS_REGION_NAME_ENV_VAR,
9
+ AWS_REGION_NAME_CONFIG_KEY,
10
+ AWS_SESSION_TOKEN_CONFIG_KEY,
11
+ AWS_SESSION_TOKEN_ENV_VAR,
12
+ )
13
+ from rasa.shared.exceptions import ProviderClientValidationError
14
+ from litellm import validate_environment
15
+ from rasa.shared.providers.embedding._base_litellm_embedding_client import (
16
+ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
17
+ )
18
+
19
# Module-level structured logger; `validate_aws_setup_for_litellm_clients` uses
# it to log missing AWS settings before raising.
structlogger = structlog.get_logger()
20
+
21
+
22
def validate_aws_setup_for_litellm_clients(
    litellm_model_name: str, litellm_call_kwargs: dict, source_log: str
) -> None:
    """Validates the AWS setup for LiteLLM clients to ensure all required
    environment variables or corresponding call kwargs are set.

    Args:
        litellm_model_name (str): The name of the LiteLLM model being validated.
        litellm_call_kwargs (dict): Additional keyword arguments passed to the client,
            which may include configuration values for AWS credentials.
        source_log (str): The source log identifier for structured logging.

    Raises:
        ProviderClientValidationError: If any required AWS environment variable
            or corresponding configuration key is missing.
    """
    # Mapping of environment variable names to their corresponding config keys
    envs_to_args = {
        AWS_ACCESS_KEY_ID_ENV_VAR: AWS_ACCESS_KEY_ID_CONFIG_KEY,
        AWS_SECRET_ACCESS_KEY_ENV_VAR: AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
        AWS_REGION_NAME_ENV_VAR: AWS_REGION_NAME_CONFIG_KEY,
        AWS_SESSION_TOKEN_ENV_VAR: AWS_SESSION_TOKEN_CONFIG_KEY,
    }

    # Validate the environment setup for the model. `or []` also covers the
    # key being present with an explicit None value.
    validation_info = validate_environment(litellm_model_name)
    missing_environment_variables = (
        validation_info.get(_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY) or []
    )

    # Filter out missing environment variables that have been set through
    # arguments in extra parameters. Variables without a config-key alternative
    # can only be satisfied through the environment.
    missing_environment_variables = [
        missing_env_var
        for missing_env_var in missing_environment_variables
        if missing_env_var not in envs_to_args
        or litellm_call_kwargs.get(envs_to_args[missing_env_var]) is None
    ]

    if not missing_environment_variables:
        return

    missing_environment_details = []
    for missing_env_var in missing_environment_variables:
        config_key = envs_to_args.get(missing_env_var)
        if config_key is None:
            # No config-key alternative exists; don't print "'None' config key".
            missing_environment_details.append(
                f"'{missing_env_var}' environment variable"
            )
        else:
            missing_environment_details.append(
                f"'{missing_env_var}' environment variable or "
                f"'{config_key}' config key"
            )

    event_info = (
        "The following environment variables or configuration keys are "
        "missing: "
        f"{', '.join(missing_environment_details)}. "
        "These settings are required for API calls."
    )
    structlogger.error(
        f"{source_log}.validate_aws_environment_variables",
        event_info=event_info,
        missing_environment_variables=missing_environment_variables,
    )
    raise ProviderClientValidationError(event_info)
@@ -1,12 +1,12 @@
1
+ import logging
1
2
  from abc import abstractmethod
2
3
  from typing import Any, Dict, List
3
4
 
4
5
  import litellm
5
- import logging
6
6
  import structlog
7
7
  from litellm import aembedding, embedding, validate_environment
8
8
 
9
- from rasa.shared.constants import API_BASE_CONFIG_KEY
9
+ from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
10
10
  from rasa.shared.exceptions import (
11
11
  ProviderClientAPIException,
12
12
  ProviderClientValidationError,
@@ -19,7 +19,7 @@ from rasa.shared.providers.embedding.embedding_response import (
19
19
  EmbeddingResponse,
20
20
  EmbeddingUsage,
21
21
  )
22
- from rasa.shared.utils.io import suppress_logs
22
+ from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
23
23
 
24
24
  structlogger = structlog.get_logger()
25
25
 
@@ -27,8 +27,7 @@ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
27
27
 
28
28
 
29
29
  class _BaseLiteLLMEmbeddingClient:
30
- """
31
- An abstract base class for LiteLLM embedding clients.
30
+ """An abstract base class for LiteLLM embedding clients.
32
31
 
33
32
  This class defines the interface and common functionality for all clients
34
33
  based on LiteLLM.
@@ -83,12 +82,12 @@ class _BaseLiteLLMEmbeddingClient:
83
82
  ProviderClientValidationError if validation fails.
84
83
  """
85
84
  self._validate_environment_variables()
86
- self._validate_api_key_not_in_config()
87
85
 
88
86
  def _validate_environment_variables(self) -> None:
89
87
  """Validate that the required environment variables are set."""
90
88
  validation_info = validate_environment(
91
89
  self._litellm_model_name,
90
+ api_key=self._litellm_extra_parameters.get(API_KEY),
92
91
  api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
93
92
  )
94
93
  if missing_environment_variables := validation_info.get(
@@ -105,21 +104,8 @@ class _BaseLiteLLMEmbeddingClient:
105
104
  )
106
105
  raise ProviderClientValidationError(event_info)
107
106
 
108
- def _validate_api_key_not_in_config(self) -> None:
109
- if "api_key" in self._litellm_extra_parameters:
110
- event_info = (
111
- "API Key is set through `api_key` extra parameter."
112
- "Set API keys through environment variables."
113
- )
114
- structlogger.error(
115
- "base_litellm_client.validate_api_key_not_in_config",
116
- event_info=event_info,
117
- )
118
- raise ProviderClientValidationError(event_info)
119
-
120
107
  def validate_documents(self, documents: List[str]) -> None:
121
- """
122
- Validates a list of documents to ensure they are suitable for embedding.
108
+ """Validates a list of documents to ensure they are suitable for embedding.
123
109
 
124
110
  Args:
125
111
  documents: List of documents to be validated.
@@ -135,8 +121,7 @@ class _BaseLiteLLMEmbeddingClient:
135
121
 
136
122
  @suppress_logs(log_level=logging.WARNING)
137
123
  def embed(self, documents: List[str]) -> EmbeddingResponse:
138
- """
139
- Embeds a list of documents synchronously.
124
+ """Embeds a list of documents synchronously.
140
125
 
141
126
  Args:
142
127
  documents: List of documents to be embedded.
@@ -149,7 +134,8 @@ class _BaseLiteLLMEmbeddingClient:
149
134
  """
150
135
  self.validate_documents(documents)
151
136
  try:
152
- response = embedding(input=documents, **self._embedding_fn_args)
137
+ arguments = resolve_environment_variables(self._embedding_fn_args)
138
+ response = embedding(input=documents, **arguments)
153
139
  return self._format_response(response)
154
140
  except Exception as e:
155
141
  raise ProviderClientAPIException(
@@ -158,8 +144,7 @@ class _BaseLiteLLMEmbeddingClient:
158
144
 
159
145
  @suppress_logs(log_level=logging.WARNING)
160
146
  async def aembed(self, documents: List[str]) -> EmbeddingResponse:
161
- """
162
- Embeds a list of documents asynchronously.
147
+ """Embeds a list of documents asynchronously.
163
148
 
164
149
  Args:
165
150
  documents: List of documents to be embedded.
@@ -172,7 +157,8 @@ class _BaseLiteLLMEmbeddingClient:
172
157
  """
173
158
  self.validate_documents(documents)
174
159
  try:
175
- response = await aembedding(input=documents, **self._embedding_fn_args)
160
+ arguments = resolve_environment_variables(self._embedding_fn_args)
161
+ response = await aembedding(input=documents, **arguments)
176
162
  return self._format_response(response)
177
163
  except Exception as e:
178
164
  raise ProviderClientAPIException(
@@ -187,7 +173,6 @@ class _BaseLiteLLMEmbeddingClient:
187
173
  Raises:
188
174
  ValueError: If any response data is None.
189
175
  """
190
-
191
176
  # If data is not available (None), raise a ValueError
192
177
  if response.data is None:
193
178
  message = (
@@ -244,8 +229,7 @@ class _BaseLiteLLMEmbeddingClient:
244
229
 
245
230
  @staticmethod
246
231
  def _ensure_certificates() -> None:
247
- """
248
- Configures SSL certificates for LiteLLM. This method is invoked during
232
+ """Configures SSL certificates for LiteLLM. This method is invoked during
249
233
  client initialization.
250
234
 
251
235
  LiteLLM may utilize `openai` clients or other providers that require
@@ -1,5 +1,6 @@
1
- from typing import Any, Dict, List, Optional
2
1
  import os
2
+ from typing import Any, Dict, List, Optional
3
+
3
4
  import structlog
4
5
 
5
6
  from rasa.shared.constants import (
@@ -42,6 +43,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
42
43
  If not provided, it will be set via environment variable.
43
44
  kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
44
45
  to the embedding model deployment.
46
+
45
47
  Raises:
46
48
  ProviderClientValidationError: If validation of the client setup fails.
47
49
  DeprecationWarning: If deprecated environment variables are used for
@@ -60,6 +62,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
60
62
  super().__init__() # type: ignore
61
63
  self._deployment = deployment
62
64
  self._model = model
65
+ self._extra_parameters = kwargs or {}
63
66
 
64
67
  # Set api_base with the following priority:
65
68
  # parameter -> Azure Env Var -> (deprecated) OpenAI Env Var
@@ -81,17 +84,55 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
81
84
  # Litellm does not support use of OPENAI_API_KEY, so we need to map it
82
85
  # because of backward compatibility. However, we're first looking at
83
86
  # AZURE_API_KEY.
84
- self._api_key = os.environ.get(AZURE_API_KEY_ENV_VAR) or os.environ.get(
85
- OPENAI_API_KEY_ENV_VAR
86
- )
87
+ self._api_key_env_var = self._resolve_api_key_env_var()
87
88
 
88
- self._extra_parameters = kwargs or {}
89
89
  self.validate_client_setup()
90
90
 
91
+ def _resolve_api_key_env_var(self) -> str:
92
+ """Resolves the environment variable to use for the API key.
93
+
94
+ Returns:
95
+ str: The env variable in dollar syntax format to use for the API key.
96
+ """
97
+ if API_KEY in self._extra_parameters:
98
+ # API key is set to an env var in the config itself
99
+ # in case the model is defined in the endpoints.yml
100
+ return self._extra_parameters[API_KEY]
101
+
102
+ if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
103
+ return "${AZURE_API_KEY}"
104
+
105
+ if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
106
+ # API key can be set through OPENAI_API_KEY too,
107
+ # because of the backward compatibility
108
+ raise_deprecation_warning(
109
+ message=(
110
+ f"Usage of '{OPENAI_API_KEY_ENV_VAR}' environment variable "
111
+ "for setting the API key of "
112
+ "Azure OpenAI client is deprecated and will "
113
+ "be removed in 4.0.0. Please "
114
+ f"use '{AZURE_API_KEY_ENV_VAR}' instead."
115
+ )
116
+ )
117
+ return "${OPENAI_API_KEY}"
118
+
119
+ structlogger.error(
120
+ "azure_openai_embedding_client.api_key_not_set",
121
+ event_info=(
122
+ "API key not set, it is required for API calls. "
123
+ f"Set it either via the environment variable "
124
+ f"'{AZURE_API_KEY_ENV_VAR}' or directly"
125
+ f"via the config key '{API_KEY}'."
126
+ ),
127
+ )
128
+ raise ProviderClientValidationError(
129
+ f"Missing required environment variable/config key '{API_KEY}' for "
130
+ f"API calls."
131
+ )
132
+
91
133
  @classmethod
92
134
  def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAIEmbeddingClient":
93
- """
94
- Initializes the client from given configuration.
135
+ """Initializes the client from given configuration.
95
136
 
96
137
  Args:
97
138
  config (Dict[str, Any]): Configuration.
@@ -142,8 +183,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
142
183
 
143
184
  @property
144
185
  def model(self) -> Optional[str]:
145
- """
146
- Returns the name of the model deployed on Azure. If model name is not
186
+ """Returns the name of the model deployed on Azure. If model name is not
147
187
  provided, returns "N/A".
148
188
  """
149
189
  return self._model
@@ -170,8 +210,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
170
210
 
171
211
  @property
172
212
  def _litellm_extra_parameters(self) -> Dict[str, Any]:
173
- """
174
- Returns the model parameters for the azure openai embedding client.
213
+ """Returns the model parameters for the azure openai embedding client.
175
214
 
176
215
  Returns:
177
216
  Dictionary containing the model parameters.
@@ -186,7 +225,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
186
225
  "api_base": self.api_base,
187
226
  "api_type": self.api_type,
188
227
  "api_version": self.api_version,
189
- "api_key": self._api_key,
228
+ "api_key": self._api_key_env_var,
190
229
  }
191
230
 
192
231
  @property
@@ -197,8 +236,9 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
197
236
  return self.deployment
198
237
 
199
238
  def validate_client_setup(self) -> None:
200
- """Perform client validation. By default only environment variables
201
- are validated.
239
+ """Perform client validation.
240
+
241
+ By default, only environment variables are validated.
202
242
 
203
243
  Raises:
204
244
  ProviderClientValidationError if validation fails.
@@ -214,13 +254,6 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
214
254
  "current_value": self.api_base,
215
255
  "new_env_key": AZURE_API_BASE_ENV_VAR,
216
256
  },
217
- {
218
- "param_name": "API key",
219
- "config_key": API_KEY,
220
- "deprecated_env_key": OPENAI_API_KEY_ENV_VAR,
221
- "current_value": self._api_key,
222
- "new_env_key": AZURE_API_KEY_ENV_VAR,
223
- },
224
257
  {
225
258
  "param_name": "API version",
226
259
  "config_key": API_VERSION_CONFIG_KEY,
@@ -1,8 +1,13 @@
1
1
  from typing import Any, Dict
2
2
 
3
+ from rasa.shared.constants import (
4
+ AWS_BEDROCK_PROVIDER,
5
+ AWS_SAGEMAKER_PROVIDER,
6
+ )
3
7
  from rasa.shared.providers._configs.default_litellm_client_config import (
4
8
  DefaultLiteLLMClientConfig,
5
9
  )
10
+ from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
6
11
  from rasa.shared.providers.embedding._base_litellm_embedding_client import (
7
12
  _BaseLiteLLMEmbeddingClient,
8
13
  )
@@ -100,3 +105,22 @@ class DefaultLiteLLMEmbeddingClient(_BaseLiteLLMEmbeddingClient):
100
105
  "model": self._litellm_model_name,
101
106
  **self._litellm_extra_parameters,
102
107
  }
108
+
109
+ def validate_client_setup(self) -> None:
110
+ # TODO: Temporarily disable environment variable validation for AWS setup
111
+ # (Bedrock and SageMaker) until resolved by either:
112
+ # 1. An update from the LiteLLM package addressing the issue.
113
+ # 2. The implementation of a Bedrock client on our end.
114
+ # ---
115
+ # This fix ensures a consistent user experience for Bedrock (and
116
+ # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
117
+ # parameters without triggering validation errors due to missing AWS
118
+ # environment variables.
119
+ if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
120
+ validate_aws_setup_for_litellm_clients(
121
+ self._litellm_model_name,
122
+ self._litellm_extra_parameters,
123
+ "default_litellm_embedding_client",
124
+ )
125
+ else:
126
+ super().validate_client_setup()