rasa-pro 3.10.16__py3-none-any.whl → 3.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.
Files changed (240)
  1. rasa/__main__.py +31 -15
  2. rasa/api.py +12 -2
  3. rasa/cli/arguments/default_arguments.py +24 -4
  4. rasa/cli/arguments/run.py +15 -0
  5. rasa/cli/arguments/shell.py +5 -1
  6. rasa/cli/arguments/train.py +17 -9
  7. rasa/cli/evaluate.py +7 -7
  8. rasa/cli/inspect.py +19 -7
  9. rasa/cli/interactive.py +1 -0
  10. rasa/cli/llm_fine_tuning.py +11 -14
  11. rasa/cli/project_templates/calm/config.yml +5 -7
  12. rasa/cli/project_templates/calm/endpoints.yml +15 -2
  13. rasa/cli/project_templates/tutorial/config.yml +8 -5
  14. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  15. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  16. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  17. rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
  18. rasa/cli/run.py +7 -0
  19. rasa/cli/scaffold.py +4 -2
  20. rasa/cli/studio/upload.py +0 -15
  21. rasa/cli/train.py +14 -53
  22. rasa/cli/utils.py +14 -11
  23. rasa/cli/x.py +7 -7
  24. rasa/constants.py +3 -1
  25. rasa/core/actions/action.py +77 -33
  26. rasa/core/actions/action_hangup.py +29 -0
  27. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
  29. rasa/core/actions/http_custom_action_executor.py +4 -0
  30. rasa/core/agent.py +2 -2
  31. rasa/core/brokers/kafka.py +3 -1
  32. rasa/core/brokers/pika.py +3 -1
  33. rasa/core/channels/__init__.py +10 -6
  34. rasa/core/channels/channel.py +41 -4
  35. rasa/core/channels/development_inspector.py +150 -46
  36. rasa/core/channels/inspector/README.md +1 -1
  37. rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  47. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
  52. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
  57. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  58. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  59. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  60. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  61. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  71. rasa/core/channels/inspector/dist/index.html +18 -17
  72. rasa/core/channels/inspector/index.html +17 -16
  73. rasa/core/channels/inspector/package.json +5 -1
  74. rasa/core/channels/inspector/src/App.tsx +118 -68
  75. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  76. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
  77. rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
  78. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
  79. rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
  80. rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
  81. rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
  82. rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
  83. rasa/core/channels/inspector/src/types.ts +21 -1
  84. rasa/core/channels/inspector/yarn.lock +94 -1
  85. rasa/core/channels/rest.py +51 -46
  86. rasa/core/channels/socketio.py +28 -1
  87. rasa/core/channels/telegram.py +1 -1
  88. rasa/core/channels/twilio.py +1 -1
  89. rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
  90. rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
  91. rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
  92. rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
  93. rasa/core/channels/voice_ready/utils.py +37 -0
  94. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  95. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  96. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  97. rasa/core/channels/voice_stream/asr/azure.py +129 -0
  98. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  99. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  100. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  101. rasa/core/channels/voice_stream/call_state.py +23 -0
  102. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  103. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  104. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  105. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  106. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  107. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  108. rasa/core/channels/voice_stream/util.py +57 -0
  109. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  110. rasa/core/information_retrieval/qdrant.py +1 -0
  111. rasa/core/nlg/contextual_response_rephraser.py +45 -17
  112. rasa/{nlu → core}/persistor.py +203 -68
  113. rasa/core/policies/enterprise_search_policy.py +119 -63
  114. rasa/core/policies/flows/flow_executor.py +15 -22
  115. rasa/core/policies/intentless_policy.py +83 -28
  116. rasa/core/processor.py +25 -0
  117. rasa/core/run.py +12 -2
  118. rasa/core/secrets_manager/constants.py +4 -0
  119. rasa/core/secrets_manager/factory.py +8 -0
  120. rasa/core/secrets_manager/vault.py +11 -1
  121. rasa/core/training/interactive.py +33 -34
  122. rasa/core/utils.py +47 -21
  123. rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
  124. rasa/dialogue_understanding/commands/__init__.py +6 -0
  125. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  126. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  127. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  128. rasa/dialogue_understanding/commands/utils.py +5 -0
  129. rasa/dialogue_understanding/generator/constants.py +2 -0
  130. rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
  131. rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
  132. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
  133. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
  134. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
  135. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
  136. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
  137. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  138. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  139. rasa/dialogue_understanding/processor/command_processor.py +21 -1
  140. rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
  141. rasa/e2e_test/assertions.py +136 -61
  142. rasa/e2e_test/assertions_schema.yml +23 -0
  143. rasa/e2e_test/e2e_test_case.py +85 -6
  144. rasa/e2e_test/e2e_test_runner.py +2 -3
  145. rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
  146. rasa/engine/graph.py +3 -10
  147. rasa/engine/loader.py +12 -0
  148. rasa/engine/recipes/config_files/default_config.yml +0 -3
  149. rasa/engine/recipes/default_recipe.py +0 -1
  150. rasa/engine/recipes/graph_recipe.py +0 -1
  151. rasa/engine/runner/dask.py +2 -2
  152. rasa/engine/storage/local_model_storage.py +12 -42
  153. rasa/engine/storage/storage.py +1 -5
  154. rasa/engine/validation.py +527 -74
  155. rasa/model_manager/__init__.py +0 -0
  156. rasa/model_manager/config.py +40 -0
  157. rasa/model_manager/model_api.py +559 -0
  158. rasa/model_manager/runner_service.py +286 -0
  159. rasa/model_manager/socket_bridge.py +146 -0
  160. rasa/model_manager/studio_jwt_auth.py +86 -0
  161. rasa/model_manager/trainer_service.py +325 -0
  162. rasa/model_manager/utils.py +87 -0
  163. rasa/model_manager/warm_rasa_process.py +187 -0
  164. rasa/model_service.py +112 -0
  165. rasa/model_training.py +42 -23
  166. rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
  167. rasa/server.py +4 -2
  168. rasa/shared/constants.py +60 -8
  169. rasa/shared/core/constants.py +13 -0
  170. rasa/shared/core/domain.py +107 -50
  171. rasa/shared/core/events.py +29 -0
  172. rasa/shared/core/flows/flow.py +5 -0
  173. rasa/shared/core/flows/flows_list.py +19 -6
  174. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  175. rasa/shared/core/flows/utils.py +39 -0
  176. rasa/shared/core/flows/validation.py +121 -0
  177. rasa/shared/core/flows/yaml_flows_io.py +15 -27
  178. rasa/shared/core/slots.py +5 -0
  179. rasa/shared/importers/importer.py +59 -41
  180. rasa/shared/importers/multi_project.py +23 -11
  181. rasa/shared/importers/rasa.py +12 -3
  182. rasa/shared/importers/remote_importer.py +196 -0
  183. rasa/shared/importers/utils.py +3 -1
  184. rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
  185. rasa/shared/nlu/training_data/training_data.py +18 -19
  186. rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
  187. rasa/shared/providers/_configs/model_group_config.py +167 -0
  188. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  189. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  190. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  191. rasa/shared/providers/_configs/utils.py +16 -0
  192. rasa/shared/providers/_utils.py +79 -0
  193. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
  194. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  195. rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
  196. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  197. rasa/shared/providers/llm/_base_litellm_client.py +34 -22
  198. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  199. rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
  200. rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
  201. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  202. rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
  203. rasa/shared/providers/mappings.py +19 -0
  204. rasa/shared/providers/router/__init__.py +0 -0
  205. rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
  206. rasa/shared/providers/router/router_client.py +73 -0
  207. rasa/shared/utils/common.py +40 -24
  208. rasa/shared/utils/health_check/__init__.py +0 -0
  209. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  210. rasa/shared/utils/health_check/health_check.py +258 -0
  211. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  212. rasa/shared/utils/io.py +27 -6
  213. rasa/shared/utils/llm.py +354 -44
  214. rasa/shared/utils/schemas/events.py +2 -0
  215. rasa/shared/utils/schemas/model_config.yml +0 -10
  216. rasa/shared/utils/yaml.py +181 -38
  217. rasa/studio/data_handler.py +3 -1
  218. rasa/studio/upload.py +160 -74
  219. rasa/telemetry.py +94 -17
  220. rasa/tracing/config.py +3 -1
  221. rasa/tracing/instrumentation/attribute_extractors.py +95 -18
  222. rasa/tracing/instrumentation/instrumentation.py +121 -0
  223. rasa/utils/common.py +5 -0
  224. rasa/utils/endpoints.py +27 -1
  225. rasa/utils/io.py +8 -16
  226. rasa/utils/log_utils.py +9 -2
  227. rasa/utils/sanic_error_handler.py +32 -0
  228. rasa/validator.py +110 -16
  229. rasa/version.py +1 -1
  230. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +16 -14
  231. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +236 -185
  232. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
  233. rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
  234. rasa/core/channels/voice_aware/utils.py +0 -20
  235. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
  236. /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
  237. /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
  238. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
  239. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
  240. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0

rasa/shared/providers/embedding/litellm_router_embedding_client.py
@@ -0,0 +1,135 @@
+from typing import Any, Dict, List
+import logging
+import structlog
+
+from rasa.shared.exceptions import ProviderClientAPIException
+from rasa.shared.providers._configs.litellm_router_client_config import (
+    LiteLLMRouterClientConfig,
+)
+from rasa.shared.providers.embedding._base_litellm_embedding_client import (
+    _BaseLiteLLMEmbeddingClient,
+)
+from rasa.shared.providers.embedding.embedding_response import EmbeddingResponse
+from rasa.shared.providers.router._base_litellm_router_client import (
+    _BaseLiteLLMRouterClient,
+)
+from rasa.shared.utils.io import suppress_logs
+
+structlogger = structlog.get_logger()
+
+
+class LiteLLMRouterEmbeddingClient(
+    _BaseLiteLLMRouterClient, _BaseLiteLLMEmbeddingClient
+):
+    """A client for interfacing with LiteLLM Router Embedding endpoints.
+
+    Parameters:
+        model_group_id (str): The model group ID.
+        model_configurations (List[Dict[str, Any]]): The list of model configurations.
+        router_settings (Dict[str, Any]): The router settings.
+        kwargs (Optional[Dict[str, Any]]): Additional configuration parameters.
+
+    Raises:
+        ProviderClientValidationError: If validation of the client setup fails.
+    """
+
+    def __init__(
+        self,
+        model_group_id: str,
+        model_configurations: List[Dict[str, Any]],
+        router_settings: Dict[str, Any],
+        **kwargs: Any,
+    ):
+        super().__init__(
+            model_group_id, model_configurations, router_settings, **kwargs
+        )
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterEmbeddingClient":
+        """Instantiates a LiteLLM Router Embedding client from a configuration dict.
+
+        Args:
+            config: (Dict[str, Any]) The configuration dictionary.
+
+        Returns:
+            LiteLLMRouterLLMClient: The instantiated LiteLLM Router LLM client.
+
+        Raises:
+            ValueError: If the configuration is invalid.
+        """
+        try:
+            client_config = LiteLLMRouterClientConfig.from_dict(config)
+        except (KeyError, ValueError) as e:
+            message = "Cannot instantiate a client from the passed configuration."
+            structlogger.error(
+                "litellm_router_llm_client.from_config.error",
+                message=message,
+                config=config,
+                original_error=e,
+            )
+            raise
+
+        return cls(
+            model_group_id=client_config.model_group_id,
+            model_configurations=client_config.litellm_model_list,
+            router_settings=client_config.litellm_router_settings,
+            **client_config.extra_parameters,
+        )
+
+    @suppress_logs(log_level=logging.WARNING)
+    def embed(self, documents: List[str]) -> EmbeddingResponse:
+        """
+        Embeds a list of documents synchronously.
+
+        Args:
+            documents: List of documents to be embedded.
+
+        Returns:
+            List of embedding vectors.
+
+        Raises:
+            ProviderClientAPIException: If API calls raised an error.
+        """
+        self.validate_documents(documents)
+        try:
+            response = self.router_client.embedding(
+                input=documents, **self._embedding_fn_args
+            )
+            return self._format_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(
+                message="Failed to embed documents", original_exception=e
+            )
+
+    @suppress_logs(log_level=logging.WARNING)
+    async def aembed(self, documents: List[str]) -> EmbeddingResponse:
+        """
+        Embeds a list of documents asynchronously.
+
+        Args:
+            documents: List of documents to be embedded.
+
+        Returns:
+            List of embedding vectors.
+
+        Raises:
+            ProviderClientAPIException: If API calls raised an error.
+        """
+        self.validate_documents(documents)
+        try:
+            response = await self.router_client.aembedding(
+                input=documents, **self._embedding_fn_args
+            )
+            return self._format_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(
+                message="Failed to embed documents", original_exception=e
+            )
+
+    @property
+    def _embedding_fn_args(self) -> Dict[str, Any]:
+        """Returns the arguments to be passed to the embedding function."""
+        return {
+            **self._litellm_extra_parameters,
+            "model": self._model_group_id,
+        }
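
For orientation, here is a minimal usage sketch of the new embedding router client; it is not part of the diff. The shape of each model configuration entry is an assumption based on LiteLLM's router model_list convention (model_name plus litellm_params), and the router settings shown are illustrative.

from rasa.shared.providers.embedding.litellm_router_embedding_client import (
    LiteLLMRouterEmbeddingClient,
)

# Hypothetical model configuration in LiteLLM router "model_list" style;
# the exact schema accepted by the Rasa router base client may differ.
model_configurations = [
    {
        "model_name": "embeddings",
        "litellm_params": {"model": "openai/text-embedding-3-small"},
    },
]

client = LiteLLMRouterEmbeddingClient(
    model_group_id="embeddings",
    model_configurations=model_configurations,
    router_settings={"routing_strategy": "simple-shuffle"},
)

# embed() validates the documents, calls the router's embedding endpoint,
# and returns an EmbeddingResponse (see the hunk above).
response = client.embed(["How do I reset my password?"])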

rasa/shared/providers/llm/_base_litellm_client.py
@@ -1,6 +1,6 @@
+import logging
 from abc import abstractmethod
 from typing import Dict, List, Any, Union
-import logging
 
 import structlog
 from litellm import (
@@ -9,7 +9,7 @@ from litellm import (
     validate_environment,
 )
 
-from rasa.shared.constants import API_BASE_CONFIG_KEY
+from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
 from rasa.shared.exceptions import (
     ProviderClientAPIException,
     ProviderClientValidationError,
@@ -19,7 +19,7 @@ from rasa.shared.providers._ssl_verification_utils import (
     ensure_ssl_certificates_for_litellm_openai_based_clients,
 )
 from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
-from rasa.shared.utils.io import suppress_logs
+from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
 
 structlogger = structlog.get_logger()
 
@@ -99,12 +99,12 @@ class _BaseLiteLLMClient:
             ProviderClientValidationError if validation fails.
         """
         self._validate_environment_variables()
-        self._validate_api_key_not_in_config()
 
     def _validate_environment_variables(self) -> None:
        """Validate that the required environment variables are set."""
        validation_info = validate_environment(
            self._litellm_model_name,
+            api_key=self._litellm_extra_parameters.get(API_KEY),
            api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
        )
        if missing_environment_variables := validation_info.get(
@@ -121,18 +121,6 @@ class _BaseLiteLLMClient:
             )
             raise ProviderClientValidationError(event_info)
 
-    def _validate_api_key_not_in_config(self) -> None:
-        if "api_key" in self._litellm_extra_parameters:
-            event_info = (
-                "API Key is set through `api_key` extra parameter."
-                "Set API keys through environment variables."
-            )
-            structlogger.error(
-                "base_litellm_client.validate_api_key_not_in_config",
-                event_info=event_info,
-            )
-            raise ProviderClientValidationError(event_info)
-
 
     @suppress_logs(log_level=logging.WARNING)
     def completion(self, messages: Union[List[str], str]) -> LLMResponse:
@@ -149,9 +137,8 @@
         """
         try:
             formatted_messages = self._format_messages(messages)
-            response = completion(
-                messages=formatted_messages, **self._completion_fn_args
-            )
+            arguments = resolve_environment_variables(self._completion_fn_args)
+            response = completion(messages=formatted_messages, **arguments)
             return self._format_response(response)
         except Exception as e:
             raise ProviderClientAPIException(e)
@@ -172,9 +159,8 @@
         """
         try:
             formatted_messages = self._format_messages(messages)
-            response = await acompletion(
-                messages=formatted_messages, **self._completion_fn_args
-            )
+            arguments = resolve_environment_variables(self._completion_fn_args)
+            response = await acompletion(messages=formatted_messages, **arguments)
             return self._format_response(response)
         except Exception as e:
             message = ""
@@ -235,6 +221,32 @@
         )
         return formatted_response
 
+    def _format_text_completion_response(self, response: Any) -> LLMResponse:
+        """Parses the LiteLLM text completion response to Rasa format."""
+        formatted_response = LLMResponse(
+            id=response.id,
+            created=response.created,
+            choices=[choice.text for choice in response.choices],
+            model=response.model,
+        )
+        if (usage := response.usage) is not None:
+            prompt_tokens = (
+                num_tokens
+                if isinstance(num_tokens := usage.prompt_tokens, (int, float))
+                else 0
+            )
+            completion_tokens = (
+                num_tokens
+                if isinstance(num_tokens := usage.completion_tokens, (int, float))
+                else 0
+            )
+            formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
+        structlogger.debug(
+            "base_litellm_client.formatted_response",
+            formatted_response=formatted_response.to_dict(),
+        )
+        return formatted_response
+
     @staticmethod
     def _ensure_certificates() -> None:
         """Configures SSL certificates for LiteLLM. This method is invoked during

rasa/shared/providers/llm/azure_openai_llm_client.py
@@ -17,6 +17,7 @@ from rasa.shared.constants import (
     OPENAI_API_KEY_ENV_VAR,
     AZURE_API_TYPE_ENV_VAR,
     AZURE_OPENAI_PROVIDER,
+    API_KEY,
 )
 from rasa.shared.exceptions import ProviderClientValidationError
 from rasa.shared.providers._configs.azure_openai_client_config import (
@@ -29,8 +30,7 @@ structlogger = structlog.get_logger()
 
 
 class AzureOpenAILLMClient(_BaseLiteLLMClient):
-    """
-    A client for interfacing with Azure's OpenAI LLM deployments.
+    """A client for interfacing with Azure's OpenAI LLM deployments.
 
     Parameters:
         deployment (str): The deployment name.
@@ -80,11 +80,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
             or os.getenv(OPENAI_API_VERSION_ENV_VAR)
         )
 
-        # API key can be set through OPENAI_API_KEY too,
-        # because of the backward compatibility
-        self._api_key = os.getenv(AZURE_API_KEY_ENV_VAR) or os.getenv(
-            OPENAI_API_KEY_ENV_VAR
-        )
+        self._api_key_env_var = self._resolve_api_key_env_var()
 
         # Not used by LiteLLM, here for backward compatibility
         self._api_type = (
@@ -117,11 +113,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
                 "env_var": AZURE_API_VERSION_ENV_VAR,
                 "deprecated_var": OPENAI_API_VERSION_ENV_VAR,
             },
-            "API Key": {
-                "current_value": self._api_key,
-                "env_var": AZURE_API_KEY_ENV_VAR,
-                "deprecated_var": OPENAI_API_KEY_ENV_VAR,
-            },
         }
 
         deprecation_warning_message = (
@@ -154,10 +145,51 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
             )
             raise_deprecation_warning(message=message)
 
+    def _resolve_api_key_env_var(self) -> str:
+        """Resolves the environment variable to use for the API key.
+
+        Returns:
+            str: The env variable in dollar syntax format to use for the API key.
+        """
+        if API_KEY in self._extra_parameters:
+            # API key is set to an env var in the config itself
+            # in case the model is defined in the endpoints.yml
+            return self._extra_parameters[API_KEY]
+
+        if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
+            return "${AZURE_API_KEY}"
+
+        if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
+            # API key can be set through OPENAI_API_KEY too,
+            # because of the backward compatibility
+            raise_deprecation_warning(
+                message=(
+                    f"Usage of '{OPENAI_API_KEY_ENV_VAR}' environment variable "
+                    "for setting the API key for Azure OpenAI "
+                    "client is deprecated and will be removed "
+                    f"in 4.0.0. Please use '{AZURE_API_KEY_ENV_VAR}' "
+                    "environment variable."
+                )
+            )
+            return "${OPENAI_API_KEY}"
+
+        structlogger.error(
+            "azure_openai_llm_client.api_key_not_set",
+            event_info=(
+                "API key not set, it is required for API calls. "
+                f"Set it either via the environment variable"
+                f"'{AZURE_API_KEY_ENV_VAR}' or directly"
+                f"via the config key '{API_KEY}'."
+            ),
+        )
+        raise ProviderClientValidationError(
+            f"Missing required environment variable/config key '{API_KEY}' for "
+            f"API calls."
+        )
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAILLMClient":
-        """
-        Initializes the client from given configuration.
+        """Initializes the client from given configuration.
 
         Args:
             config (Dict[str, Any]): Configuration.
@@ -212,23 +244,17 @@
 
     @property
     def model(self) -> Optional[str]:
-        """
-        Returns the name of the model deployed on Azure.
-        """
+        """Returns the name of the model deployed on Azure."""
         return self._model
 
     @property
     def api_base(self) -> Optional[str]:
-        """
-        Returns the API base URL for the Azure OpenAI llm client.
-        """
+        """Returns the API base URL for the Azure OpenAI llm client."""
         return self._api_base
 
     @property
     def api_version(self) -> Optional[str]:
-        """
-        Returns the API version for the Azure OpenAI llm client.
-        """
+        """Returns the API version for the Azure OpenAI llm client."""
         return self._api_version
 
     @property
@@ -261,7 +287,7 @@
             {
                 "api_base": self.api_base,
                 "api_version": self.api_version,
-                "api_key": self._api_key,
+                "api_key": self._api_key_env_var,
             }
         )
         return fn_args
@@ -305,11 +331,6 @@
                 "env_var": None,
                 "config_key": DEPLOYMENT_CONFIG_KEY,
             },
-            "API Key": {
-                "current_value": self._api_key,
-                "env_var": AZURE_API_KEY_ENV_VAR,
-                "config_key": None,
-            },
         }
 
         missing_settings = [
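
To make the api_key handling change concrete, here is a hedged configuration sketch, not taken from the diff. It assumes the usual Rasa Azure OpenAI config keys (provider, deployment, api_base, api_version) and shows the api_key supplied as a dollar-syntax environment variable reference, mirroring the values returned by _resolve_api_key_env_var above; the reference is only expanded to the real secret at request time via resolve_environment_variables in the base client.

from rasa.shared.providers.llm.azure_openai_llm_client import AzureOpenAILLMClient

# Sketch only: the api_key entry references an environment variable in dollar
# syntax instead of holding the raw secret; the surrounding key names are
# assumptions based on the standard Azure OpenAI client configuration.
config = {
    "provider": "azure",
    "deployment": "my-gpt-4o-deployment",
    "api_base": "https://my-resource.openai.azure.com",
    "api_version": "2024-02-15-preview",
    "api_key": "${AZURE_API_KEY}",
}

client = AzureOpenAILLMClient.from_config(config)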

rasa/shared/providers/llm/default_litellm_llm_client.py
@@ -1,8 +1,13 @@
 from typing import Dict, Any
 
+from rasa.shared.constants import (
+    AWS_BEDROCK_PROVIDER,
+    AWS_SAGEMAKER_PROVIDER,
+)
 from rasa.shared.providers._configs.default_litellm_client_config import (
     DefaultLiteLLMClientConfig,
 )
+from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
 from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
 
 
@@ -82,3 +87,22 @@
         to the client provider and deployed model.
         """
         return self._extra_parameters
+
+    def validate_client_setup(self) -> None:
+        # TODO: Temporarily change the environment variable validation for AWS setup
+        # (Bedrock and SageMaker) until resolved by either:
+        # 1. An update from the LiteLLM package addressing the issue.
+        # 2. The implementation of a Bedrock client on our end.
+        # ---
+        # This fix ensures a consistent user experience for Bedrock (and
+        # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+        # parameters without triggering validation errors due to missing AWS
+        # environment variables.
+        if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+            validate_aws_setup_for_litellm_clients(
+                self._litellm_model_name,
+                self._litellm_extra_parameters,
+                "default_litellm_llm_client",
+            )
+        else:
+            super().validate_client_setup()
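
The TODO comment above explains that AWS credentials can now be passed as extra parameters instead of environment variables for Bedrock and SageMaker. A hedged sketch of such a configuration follows; the AWS parameter names mirror LiteLLM's Bedrock options, and the from_config constructor is assumed to exist on this client as it does on its siblings.

from rasa.shared.providers.llm.default_litellm_llm_client import DefaultLiteLLMClient

# Sketch: AWS secrets supplied as extra parameters. With this change,
# validate_client_setup() routes Bedrock/SageMaker configurations through
# validate_aws_setup_for_litellm_clients instead of the generic
# environment-variable check in the base client.
config = {
    "provider": "bedrock",
    "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
    "aws_region_name": "us-east-1",
    "aws_access_key_id": "${AWS_ACCESS_KEY_ID}",
    "aws_secret_access_key": "${AWS_SECRET_ACCESS_KEY}",
}

client = DefaultLiteLLMClient.from_config(config)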

rasa/shared/providers/llm/litellm_router_llm_client.py
@@ -0,0 +1,182 @@
+from typing import Any, Dict, List, Union
+import logging
+import structlog
+
+from rasa.shared.exceptions import ProviderClientAPIException
+from rasa.shared.providers._configs.litellm_router_client_config import (
+    LiteLLMRouterClientConfig,
+)
+from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
+from rasa.shared.providers.llm.llm_response import LLMResponse
+from rasa.shared.providers.router._base_litellm_router_client import (
+    _BaseLiteLLMRouterClient,
+)
+from rasa.shared.utils.io import suppress_logs
+
+structlogger = structlog.get_logger()
+
+
+class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
+    """A client for interfacing with LiteLLM Router LLM endpoints.
+
+    Parameters:
+        model_group_id (str): The model group ID.
+        model_configurations (List[Dict[str, Any]]): The list of model configurations.
+        router_settings (Dict[str, Any]): The router settings.
+        kwargs (Optional[Dict[str, Any]]): Additional configuration parameters.
+
+    Raises:
+        ProviderClientValidationError: If validation of the client setup fails.
+    """
+
+    def __init__(
+        self,
+        model_group_id: str,
+        model_configurations: List[Dict[str, Any]],
+        router_settings: Dict[str, Any],
+        **kwargs: Any,
+    ):
+        super().__init__(
+            model_group_id, model_configurations, router_settings, **kwargs
+        )
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterLLMClient":
+        """Instantiates a LiteLLM Router LLM client from a configuration dict.
+
+        Args:
+            config: (Dict[str, Any]) The configuration dictionary.
+
+        Returns:
+            LiteLLMRouterLLMClient: The instantiated LiteLLM Router LLM client.
+
+        Raises:
+            ValueError: If the configuration is invalid.
+        """
+        try:
+            client_config = LiteLLMRouterClientConfig.from_dict(config)
+        except (KeyError, ValueError) as e:
+            message = "Cannot instantiate a client from the passed configuration."
+            structlogger.error(
+                "litellm_router_llm_client.from_config.error",
+                message=message,
+                config=config,
+                original_error=e,
+            )
+            raise
+
+        return cls(
+            model_group_id=client_config.model_group_id,
+            model_configurations=client_config.litellm_model_list,
+            router_settings=client_config.litellm_router_settings,
+            use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
+            **client_config.extra_parameters,
+        )
+
+    @suppress_logs(log_level=logging.WARNING)
+    def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+        """
+        Synchronously generate completions for given prompt.
+
+        Args:
+            prompt: Prompt to generate the completion for.
+        Returns:
+            List of message completions.
+        Raises:
+            ProviderClientAPIException: If the API request fails.
+        """
+        try:
+            response = self.router_client.text_completion(
+                prompt=prompt, **self._completion_fn_args
+            )
+            return self._format_text_completion_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(e)
+
+    @suppress_logs(log_level=logging.WARNING)
+    async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+        """
+        Asynchronously generate completions for given prompt.
+
+        Args:
+            prompt: Prompt to generate the completion for.
+        Returns:
+            List of message completions.
+        Raises:
+            ProviderClientAPIException: If the API request fails.
+        """
+        try:
+            response = await self.router_client.atext_completion(
+                prompt=prompt, **self._completion_fn_args
+            )
+            return self._format_text_completion_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(e)
+
+    @suppress_logs(log_level=logging.WARNING)
+    def completion(self, messages: Union[List[str], str]) -> LLMResponse:
+        """
+        Synchronously generate completions for given list of messages.
+
+        Method overrides the base class method to call the appropriate
+        completion method based on the configuration. If the chat completions
+        endpoint is enabled, the completion method is called. Otherwise, the
+        text_completion method is called.
+
+        Args:
+            messages: List of messages or a single message to generate the
+                completion for.
+        Returns:
+            List of message completions.
+        Raises:
+            ProviderClientAPIException: If the API request fails.
+        """
+        if not self._use_chat_completions_endpoint:
+            return self._text_completion(messages)
+        try:
+            formatted_messages = self._format_messages(messages)
+            response = self.router_client.completion(
+                messages=formatted_messages, **self._completion_fn_args
+            )
+            return self._format_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(e)
+
+    @suppress_logs(log_level=logging.WARNING)
+    async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
+        """
+        Asynchronously generate completions for given list of messages.
+
+        Method overrides the base class method to call the appropriate
+        completion method based on the configuration. If the chat completions
+        endpoint is enabled, the completion method is called. Otherwise, the
+        text_completion method is called.
+
+        Args:
+            messages: List of messages or a single message to generate the
+                completion for.
+        Returns:
+            List of message completions.
+        Raises:
+            ProviderClientAPIException: If the API request fails.
+        """
+        if not self._use_chat_completions_endpoint:
+            return await self._atext_completion(messages)
+        try:
+            formatted_messages = self._format_messages(messages)
+            response = await self.router_client.acompletion(
+                messages=formatted_messages, **self._completion_fn_args
+            )
+            return self._format_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(e)
+
+    @property
+    def _completion_fn_args(self) -> Dict[str, Any]:
+        """Returns the completion arguments for invoking a call through
+        LiteLLM's completion functions.
+        """
+        return {
+            **self._litellm_extra_parameters,
+            "model": self.model_group_id,
+        }
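
A minimal usage sketch of the router LLM client defined above, not part of the diff. The model entry shape follows LiteLLM's router model_list convention and the router settings are illustrative; handling of the use_chat_completions_endpoint keyword by the router base class is inferred from the from_config call in the hunk.

from rasa.shared.providers.llm.litellm_router_llm_client import LiteLLMRouterLLMClient

# Hypothetical model configuration; the exact schema accepted by the
# Rasa router base client may differ.
model_configurations = [
    {
        "model_name": "gpt-4o",
        "litellm_params": {"model": "openai/gpt-4o"},
    },
]

client = LiteLLMRouterLLMClient(
    model_group_id="gpt-4o",
    model_configurations=model_configurations,
    router_settings={"routing_strategy": "simple-shuffle"},
    # When this flag is False, completion()/acompletion() fall back to the
    # router's text completion endpoints via _text_completion/_atext_completion.
    use_chat_completions_endpoint=True,
)

response = client.completion(["Summarise the user's last message."])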