rasa-pro 3.11.0a4.dev2__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of rasa-pro has been flagged as potentially problematic.

Files changed (163)
  1. rasa/__main__.py +22 -12
  2. rasa/api.py +1 -1
  3. rasa/cli/arguments/default_arguments.py +1 -2
  4. rasa/cli/arguments/shell.py +5 -1
  5. rasa/cli/e2e_test.py +1 -1
  6. rasa/cli/evaluate.py +8 -8
  7. rasa/cli/inspect.py +4 -4
  8. rasa/cli/llm_fine_tuning.py +1 -1
  9. rasa/cli/project_templates/calm/config.yml +5 -7
  10. rasa/cli/project_templates/calm/endpoints.yml +8 -0
  11. rasa/cli/project_templates/tutorial/config.yml +8 -5
  12. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  13. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  14. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  15. rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
  16. rasa/cli/run.py +1 -1
  17. rasa/cli/scaffold.py +4 -2
  18. rasa/cli/utils.py +5 -0
  19. rasa/cli/x.py +8 -8
  20. rasa/constants.py +1 -1
  21. rasa/core/channels/channel.py +3 -0
  22. rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
  23. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  24. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
  25. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  26. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
  27. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
  28. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  29. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  30. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  31. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  32. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  33. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  34. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
  37. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  56. rasa/core/channels/inspector/dist/index.html +1 -1
  57. rasa/core/channels/inspector/src/App.tsx +1 -1
  58. rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
  59. rasa/core/channels/socketio.py +2 -1
  60. rasa/core/channels/telegram.py +1 -1
  61. rasa/core/channels/twilio.py +1 -1
  62. rasa/core/channels/voice_ready/jambonz.py +2 -2
  63. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  64. rasa/core/channels/voice_stream/asr/azure.py +122 -0
  65. rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
  66. rasa/core/channels/voice_stream/audio_bytes.py +1 -0
  67. rasa/core/channels/voice_stream/browser_audio.py +31 -8
  68. rasa/core/channels/voice_stream/call_state.py +23 -0
  69. rasa/core/channels/voice_stream/tts/azure.py +6 -2
  70. rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
  71. rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
  72. rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
  73. rasa/core/channels/voice_stream/util.py +4 -4
  74. rasa/core/channels/voice_stream/voice_channel.py +177 -39
  75. rasa/core/featurizers/single_state_featurizer.py +22 -1
  76. rasa/core/featurizers/tracker_featurizers.py +115 -18
  77. rasa/core/nlg/contextual_response_rephraser.py +16 -22
  78. rasa/core/persistor.py +86 -39
  79. rasa/core/policies/enterprise_search_policy.py +159 -60
  80. rasa/core/policies/flows/flow_executor.py +7 -4
  81. rasa/core/policies/intentless_policy.py +120 -22
  82. rasa/core/policies/ted_policy.py +58 -33
  83. rasa/core/policies/unexpected_intent_policy.py +15 -7
  84. rasa/core/processor.py +25 -0
  85. rasa/core/training/interactive.py +34 -35
  86. rasa/core/utils.py +8 -3
  87. rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
  88. rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
  89. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  90. rasa/dialogue_understanding/commands/utils.py +5 -0
  91. rasa/dialogue_understanding/generator/constants.py +4 -0
  92. rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
  93. rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
  94. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
  95. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
  96. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
  97. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  98. rasa/e2e_test/e2e_test_runner.py +4 -2
  99. rasa/e2e_test/utils/io.py +1 -1
  100. rasa/engine/validation.py +297 -7
  101. rasa/model_manager/config.py +17 -3
  102. rasa/model_manager/model_api.py +16 -8
  103. rasa/model_manager/runner_service.py +8 -6
  104. rasa/model_manager/socket_bridge.py +6 -3
  105. rasa/model_manager/trainer_service.py +7 -5
  106. rasa/model_manager/utils.py +28 -7
  107. rasa/model_service.py +7 -5
  108. rasa/model_training.py +2 -0
  109. rasa/nlu/classifiers/diet_classifier.py +38 -25
  110. rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
  111. rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
  112. rasa/nlu/extractors/crf_entity_extractor.py +93 -50
  113. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
  114. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
  115. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
  116. rasa/shared/constants.py +36 -3
  117. rasa/shared/core/constants.py +7 -0
  118. rasa/shared/core/domain.py +26 -0
  119. rasa/shared/core/flows/flow.py +5 -0
  120. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  121. rasa/shared/core/flows/utils.py +39 -0
  122. rasa/shared/core/flows/validation.py +96 -0
  123. rasa/shared/core/slots.py +5 -0
  124. rasa/shared/nlu/training_data/features.py +120 -2
  125. rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
  126. rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
  127. rasa/shared/providers/_configs/model_group_config.py +167 -0
  128. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  129. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  130. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  131. rasa/shared/providers/_configs/utils.py +16 -0
  132. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
  133. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  134. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  135. rasa/shared/providers/llm/_base_litellm_client.py +31 -30
  136. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  137. rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
  138. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  139. rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
  140. rasa/shared/providers/mappings.py +19 -0
  141. rasa/shared/providers/router/__init__.py +0 -0
  142. rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
  143. rasa/shared/providers/router/router_client.py +73 -0
  144. rasa/shared/utils/common.py +8 -0
  145. rasa/shared/utils/health_check.py +533 -0
  146. rasa/shared/utils/io.py +28 -6
  147. rasa/shared/utils/llm.py +350 -46
  148. rasa/shared/utils/yaml.py +11 -13
  149. rasa/studio/upload.py +64 -20
  150. rasa/telemetry.py +80 -17
  151. rasa/tracing/instrumentation/attribute_extractors.py +74 -17
  152. rasa/utils/io.py +0 -66
  153. rasa/utils/log_utils.py +9 -2
  154. rasa/utils/tensorflow/feature_array.py +366 -0
  155. rasa/utils/tensorflow/model_data.py +2 -193
  156. rasa/validator.py +70 -0
  157. rasa/version.py +1 -1
  158. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
  159. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
  160. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
  161. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
  162. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
  163. {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
rasa/shared/providers/embedding/_base_litellm_embedding_client.py

@@ -1,10 +1,11 @@
+ import logging
  from abc import abstractmethod
  from typing import Any, Dict, List

  import litellm
- import logging
  import structlog
  from litellm import aembedding, embedding, validate_environment
+
  from rasa.shared.exceptions import (
      ProviderClientAPIException,
      ProviderClientValidationError,
@@ -17,7 +18,7 @@ from rasa.shared.providers.embedding.embedding_response import (
      EmbeddingResponse,
      EmbeddingUsage,
  )
- from rasa.shared.utils.io import suppress_logs
+ from rasa.shared.utils.io import suppress_logs, resolve_environment_variables

  structlogger = structlog.get_logger()

@@ -25,8 +26,7 @@ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"


  class _BaseLiteLLMEmbeddingClient:
-     """
-     An abstract base class for LiteLLM embedding clients.
+     """An abstract base class for LiteLLM embedding clients.

      This class defines the interface and common functionality for all clients
      based on LiteLLM.
@@ -113,8 +113,7 @@ class _BaseLiteLLMEmbeddingClient:
              raise ProviderClientValidationError(event_info)

      def validate_documents(self, documents: List[str]) -> None:
-         """
-         Validates a list of documents to ensure they are suitable for embedding.
+         """Validates a list of documents to ensure they are suitable for embedding.

          Args:
              documents: List of documents to be validated.
@@ -130,8 +129,7 @@ class _BaseLiteLLMEmbeddingClient:

      @suppress_logs(log_level=logging.WARNING)
      def embed(self, documents: List[str]) -> EmbeddingResponse:
-         """
-         Embeds a list of documents synchronously.
+         """Embeds a list of documents synchronously.

          Args:
              documents: List of documents to be embedded.
@@ -144,7 +142,8 @@ class _BaseLiteLLMEmbeddingClient:
          """
          self.validate_documents(documents)
          try:
-             response = embedding(input=documents, **self._embedding_fn_args)
+             arguments = resolve_environment_variables(self._embedding_fn_args)
+             response = embedding(input=documents, **arguments)
              return self._format_response(response)
          except Exception as e:
              raise ProviderClientAPIException(
@@ -153,8 +152,7 @@ class _BaseLiteLLMEmbeddingClient:

      @suppress_logs(log_level=logging.WARNING)
      async def aembed(self, documents: List[str]) -> EmbeddingResponse:
-         """
-         Embeds a list of documents asynchronously.
+         """Embeds a list of documents asynchronously.

          Args:
              documents: List of documents to be embedded.
@@ -167,7 +165,8 @@ class _BaseLiteLLMEmbeddingClient:
          """
          self.validate_documents(documents)
          try:
-             response = await aembedding(input=documents, **self._embedding_fn_args)
+             arguments = resolve_environment_variables(self._embedding_fn_args)
+             response = await aembedding(input=documents, **arguments)
              return self._format_response(response)
          except Exception as e:
              raise ProviderClientAPIException(
@@ -182,7 +181,6 @@ class _BaseLiteLLMEmbeddingClient:
          Raises:
              ValueError: If any response data is None.
          """
-
          # If data is not available (None), raise a ValueError
          if response.data is None:
              message = (
@@ -239,8 +237,7 @@ class _BaseLiteLLMEmbeddingClient:

      @staticmethod
      def _ensure_certificates() -> None:
-         """
-         Configures SSL certificates for LiteLLM. This method is invoked during
+         """Configures SSL certificates for LiteLLM. This method is invoked during
          client initialization.

          LiteLLM may utilize `openai` clients or other providers that require
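
The functional change in this file is that embed and aembed no longer pass self._embedding_fn_args straight to LiteLLM; the arguments are first run through resolve_environment_variables, so values such as "${AZURE_API_KEY}" are only turned into real secrets at call time. The helper itself lives in rasa/shared/utils/io.py and is not shown in this diff; the sketch below is a hypothetical stand-in that only illustrates the assumed dollar-syntax substitution.

import os
import re
from typing import Any, Dict

_ENV_REF = re.compile(r"^\$\{(\w+)\}$")


def resolve_env_placeholders(call_args: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for rasa.shared.utils.io.resolve_environment_variables:
    # replace string values of the form "${VAR}" with os.environ["VAR"], leaving
    # everything else untouched.
    resolved = {}
    for key, value in call_args.items():
        match = _ENV_REF.match(value) if isinstance(value, str) else None
        resolved[key] = os.environ.get(match.group(1), value) if match else value
    return resolved


# Example: with AZURE_API_KEY set in the environment,
# {"api_key": "${AZURE_API_KEY}", "model": "azure/my-deployment"}
# resolves to {"api_key": "<actual key>", "model": "azure/my-deployment"}.
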
rasa/shared/providers/embedding/azure_openai_embedding_client.py

@@ -1,5 +1,6 @@
- from typing import Any, Dict, List, Optional
  import os
+ from typing import Any, Dict, List, Optional
+
  import structlog

  from rasa.shared.constants import (
@@ -42,6 +43,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
              If not provided, it will be set via environment variable.
          kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
              to the embedding model deployment.
+
      Raises:
          ProviderClientValidationError: If validation of the client setup fails.
          DeprecationWarning: If deprecated environment variables are used for
@@ -60,6 +62,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
          super().__init__()  # type: ignore
          self._deployment = deployment
          self._model = model
+         self._extra_parameters = kwargs or {}

          # Set api_base with the following priority:
          # parameter -> Azure Env Var -> (deprecated) OpenAI Env Var
@@ -81,17 +84,55 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
          # Litellm does not support use of OPENAI_API_KEY, so we need to map it
          # because of backward compatibility. However, we're first looking at
          # AZURE_API_KEY.
-         self._api_key = os.environ.get(AZURE_API_KEY_ENV_VAR) or os.environ.get(
-             OPENAI_API_KEY_ENV_VAR
-         )
+         self._api_key_env_var = self._resolve_api_key_env_var()

-         self._extra_parameters = kwargs or {}
          self.validate_client_setup()

+     def _resolve_api_key_env_var(self) -> str:
+         """Resolves the environment variable to use for the API key.
+
+         Returns:
+             str: The env variable in dollar syntax format to use for the API key.
+         """
+         if API_KEY in self._extra_parameters:
+             # API key is set to an env var in the config itself
+             # in case the model is defined in the endpoints.yml
+             return self._extra_parameters[API_KEY]
+
+         if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
+             return "${AZURE_API_KEY}"
+
+         if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
+             # API key can be set through OPENAI_API_KEY too,
+             # because of the backward compatibility
+             raise_deprecation_warning(
+                 message=(
+                     f"Usage of '{OPENAI_API_KEY_ENV_VAR}' environment variable "
+                     "for setting the API key of "
+                     "Azure OpenAI client is deprecated and will "
+                     "be removed in 4.0.0. Please "
+                     f"use '{AZURE_API_KEY_ENV_VAR}' instead."
+                 )
+             )
+             return "${OPENAI_API_KEY}"
+
+         structlogger.error(
+             "azure_openai_embedding_client.api_key_not_set",
+             event_info=(
+                 "API key not set, it is required for API calls. "
+                 f"Set it either via the environment variable "
+                 f"'{AZURE_API_KEY_ENV_VAR}' or directly"
+                 f"via the config key '{API_KEY}'."
+             ),
+         )
+         raise ProviderClientValidationError(
+             f"Missing required environment variable/config key '{API_KEY}' for "
+             f"API calls."
+         )
+
      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAIEmbeddingClient":
-         """
-         Initializes the client from given configuration.
+         """Initializes the client from given configuration.

          Args:
              config (Dict[str, Any]): Configuration.
@@ -142,8 +183,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):

      @property
      def model(self) -> Optional[str]:
-         """
-         Returns the name of the model deployed on Azure. If model name is not
+         """Returns the name of the model deployed on Azure. If model name is not
          provided, returns "N/A".
          """
          return self._model
@@ -170,8 +210,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):

      @property
      def _litellm_extra_parameters(self) -> Dict[str, Any]:
-         """
-         Returns the model parameters for the azure openai embedding client.
+         """Returns the model parameters for the azure openai embedding client.

          Returns:
              Dictionary containing the model parameters.
@@ -186,7 +225,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
              "api_base": self.api_base,
              "api_type": self.api_type,
              "api_version": self.api_version,
-             "api_key": self._api_key,
+             "api_key": self._api_key_env_var,
          }

      @property
@@ -197,8 +236,9 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
          return self.deployment

      def validate_client_setup(self) -> None:
-         """Perform client validation. By default only environment variables
-         are validated.
+         """Perform client validation.
+
+         By default, only environment variables are validated.

          Raises:
              ProviderClientValidationError if validation fails.
@@ -214,13 +254,6 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
                  "current_value": self.api_base,
                  "new_env_key": AZURE_API_BASE_ENV_VAR,
              },
-             {
-                 "param_name": "API key",
-                 "config_key": API_KEY,
-                 "deprecated_env_key": OPENAI_API_KEY_ENV_VAR,
-                 "current_value": self._api_key,
-                 "new_env_key": AZURE_API_KEY_ENV_VAR,
-             },
              {
                  "param_name": "API version",
                  "config_key": API_VERSION_CONFIG_KEY,
rasa/shared/providers/embedding/litellm_router_embedding_client.py (new file)

@@ -0,0 +1,135 @@
+ from typing import Any, Dict, List
+ import logging
+ import structlog
+
+ from rasa.shared.exceptions import ProviderClientAPIException
+ from rasa.shared.providers._configs.litellm_router_client_config import (
+     LiteLLMRouterClientConfig,
+ )
+ from rasa.shared.providers.embedding._base_litellm_embedding_client import (
+     _BaseLiteLLMEmbeddingClient,
+ )
+ from rasa.shared.providers.embedding.embedding_response import EmbeddingResponse
+ from rasa.shared.providers.router._base_litellm_router_client import (
+     _BaseLiteLLMRouterClient,
+ )
+ from rasa.shared.utils.io import suppress_logs
+
+ structlogger = structlog.get_logger()
+
+
+ class LiteLLMRouterEmbeddingClient(
+     _BaseLiteLLMRouterClient, _BaseLiteLLMEmbeddingClient
+ ):
+     """A client for interfacing with LiteLLM Router Embedding endpoints.
+
+     Parameters:
+         model_group_id (str): The model group ID.
+         model_configurations (List[Dict[str, Any]]): The list of model configurations.
+         router_settings (Dict[str, Any]): The router settings.
+         kwargs (Optional[Dict[str, Any]]): Additional configuration parameters.
+
+     Raises:
+         ProviderClientValidationError: If validation of the client setup fails.
+     """
+
+     def __init__(
+         self,
+         model_group_id: str,
+         model_configurations: List[Dict[str, Any]],
+         router_settings: Dict[str, Any],
+         **kwargs: Any,
+     ):
+         super().__init__(
+             model_group_id, model_configurations, router_settings, **kwargs
+         )
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterEmbeddingClient":
+         """Instantiates a LiteLLM Router Embedding client from a configuration dict.
+
+         Args:
+             config: (Dict[str, Any]) The configuration dictionary.
+
+         Returns:
+             LiteLLMRouterLLMClient: The instantiated LiteLLM Router LLM client.
+
+         Raises:
+             ValueError: If the configuration is invalid.
+         """
+         try:
+             client_config = LiteLLMRouterClientConfig.from_dict(config)
+         except (KeyError, ValueError) as e:
+             message = "Cannot instantiate a client from the passed configuration."
+             structlogger.error(
+                 "litellm_router_llm_client.from_config.error",
+                 message=message,
+                 config=config,
+                 original_error=e,
+             )
+             raise
+
+         return cls(
+             model_group_id=client_config.model_group_id,
+             model_configurations=client_config.litellm_model_list,
+             router_settings=client_config.router,
+             **client_config.extra_parameters,
+         )
+
+     @suppress_logs(log_level=logging.WARNING)
+     def embed(self, documents: List[str]) -> EmbeddingResponse:
+         """
+         Embeds a list of documents synchronously.
+
+         Args:
+             documents: List of documents to be embedded.
+
+         Returns:
+             List of embedding vectors.
+
+         Raises:
+             ProviderClientAPIException: If API calls raised an error.
+         """
+         self.validate_documents(documents)
+         try:
+             response = self.router_client.embedding(
+                 input=documents, **self._embedding_fn_args
+             )
+             return self._format_response(response)
+         except Exception as e:
+             raise ProviderClientAPIException(
+                 message="Failed to embed documents", original_exception=e
+             )
+
+     @suppress_logs(log_level=logging.WARNING)
+     async def aembed(self, documents: List[str]) -> EmbeddingResponse:
+         """
+         Embeds a list of documents asynchronously.
+
+         Args:
+             documents: List of documents to be embedded.
+
+         Returns:
+             List of embedding vectors.
+
+         Raises:
+             ProviderClientAPIException: If API calls raised an error.
+         """
+         self.validate_documents(documents)
+         try:
+             response = await self.router_client.aembedding(
+                 input=documents, **self._embedding_fn_args
+             )
+             return self._format_response(response)
+         except Exception as e:
+             raise ProviderClientAPIException(
+                 message="Failed to embed documents", original_exception=e
+             )
+
+     @property
+     def _embedding_fn_args(self) -> Dict[str, Any]:
+         """Returns the arguments to be passed to the embedding function."""
+         return {
+             **self._litellm_extra_parameters,
+             "model": self._model_group_id,
+         }
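
A short usage sketch for the new client, using the direct constructor shown above. The shape of model_configurations and router_settings follows LiteLLM's router conventions (model_name plus litellm_params, retry settings, and so on) and is illustrative rather than taken from this diff; the group ID and deployment name are made up.

from rasa.shared.providers.embedding.litellm_router_embedding_client import (
    LiteLLMRouterEmbeddingClient,
)

client = LiteLLMRouterEmbeddingClient(
    model_group_id="my_embedding_group",  # made-up model group id
    model_configurations=[
        {
            "model_name": "my_embedding_group",
            "litellm_params": {
                "model": "azure/my-embedding-deployment",  # made-up deployment
                "api_key": "${AZURE_API_KEY}",  # resolved from the environment
            },
        },
    ],
    router_settings={"num_retries": 2},
)

response = client.embed(["How do I reset my password?"])
print(response.data)  # embedding vectors wrapped in an EmbeddingResponse

Instantiation via from_config with a dict parsed by LiteLLMRouterClientConfig.from_dict is also possible, but the accepted configuration keys are defined in litellm_router_client_config.py and are not shown in this diff.
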
rasa/shared/providers/llm/_base_litellm_client.py

@@ -1,7 +1,7 @@
+ import logging
  from abc import abstractmethod
  from typing import Dict, List, Any, Union

- import logging
  import structlog
  from litellm import (
      completion,
@@ -18,7 +18,7 @@ from rasa.shared.providers._ssl_verification_utils import (
      ensure_ssl_certificates_for_litellm_openai_based_clients,
  )
  from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
- from rasa.shared.utils.io import suppress_logs
+ from rasa.shared.utils.io import suppress_logs, resolve_environment_variables

  structlogger = structlog.get_logger()

@@ -29,8 +29,7 @@ logging.getLogger("LiteLLM").setLevel(logging.WARNING)


  class _BaseLiteLLMClient:
-     """
-     An abstract base class for LiteLLM clients.
+     """An abstract base class for LiteLLM clients.

      This class defines the interface and common functionality for all clients
      based on LiteLLM.
@@ -99,7 +98,6 @@ class _BaseLiteLLMClient:
              ProviderClientValidationError if validation fails.
          """
          self._validate_environment_variables()
-         self._validate_api_key_not_in_config()

      def _validate_environment_variables(self) -> None:
          """Validate that the required environment variables are set."""
@@ -118,61 +116,65 @@ class _BaseLiteLLMClient:
              )
              raise ProviderClientValidationError(event_info)

-     def _validate_api_key_not_in_config(self) -> None:
-         if "api_key" in self._litellm_extra_parameters:
-             event_info = (
-                 "API Key is set through `api_key` extra parameter."
-                 "Set API keys through environment variables."
-             )
-             structlogger.error(
-                 "base_litellm_client.validate_api_key_not_in_config",
-                 event_info=event_info,
-             )
-             raise ProviderClientValidationError(event_info)
-
      @suppress_logs(log_level=logging.WARNING)
      def completion(self, messages: Union[List[str], str]) -> LLMResponse:
-         """
-         Synchronously generate completions for given list of messages.
+         """Synchronously generate completions for given list of messages.

          Args:
              messages: List of messages or a single message to generate the
                  completion for.
+
          Returns:
              List of message completions.
+
          Raises:
              ProviderClientAPIException: If the API request fails.
          """
          try:
              formatted_messages = self._format_messages(messages)
-             response = completion(
-                 messages=formatted_messages, **self._completion_fn_args
-             )
+             arguments = resolve_environment_variables(self._completion_fn_args)
+             response = completion(messages=formatted_messages, **arguments)
              return self._format_response(response)
          except Exception as e:
              raise ProviderClientAPIException(e)

      @suppress_logs(log_level=logging.WARNING)
      async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
-         """
-         Asynchronously generate completions for given list of messages.
+         """Asynchronously generate completions for given list of messages.

          Args:
              messages: List of messages or a single message to generate the
                  completion for.
+
          Returns:
              List of message completions.
+
          Raises:
              ProviderClientAPIException: If the API request fails.
          """
          try:
              formatted_messages = self._format_messages(messages)
-             response = await acompletion(
-                 messages=formatted_messages, **self._completion_fn_args
-             )
+             arguments = resolve_environment_variables(self._completion_fn_args)
+             response = await acompletion(messages=formatted_messages, **arguments)
              return self._format_response(response)
          except Exception as e:
-             raise ProviderClientAPIException(e)
+             message = ""
+             from rasa.shared.providers.llm.self_hosted_llm_client import (
+                 SelfHostedLLMClient,
+             )
+
+             if isinstance(self, SelfHostedLLMClient):
+                 message = (
+                     "If you are using 'provider=self-hosted' to call a hosted vllm "
+                     "server make sure your config is correctly setup. You should have "
+                     "the following mandatory keys in your config: "
+                     "provider=self-hosted; "
+                     "model='<your-vllm-model-name>'; "
+                     "api_base='your-hosted-vllm-serv'."
+                     "In case you are getting OpenAI connection errors, such as missing "
+                     "API key, your configuration is incorrect."
+                 )
+             raise ProviderClientAPIException(e, message)

      def _format_messages(self, messages: Union[List[str], str]) -> List[Dict[str, str]]:
          """Formats messages (or a single message) to OpenAI format."""
@@ -216,8 +218,7 @@ class _BaseLiteLLMClient:

      @staticmethod
      def _ensure_certificates() -> None:
-         """
-         Configures SSL certificates for LiteLLM. This method is invoked during
+         """Configures SSL certificates for LiteLLM. This method is invoked during
          client initialization.

          LiteLLM may utilize `openai` clients or other providers that require
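
Two behavioral notes on this file: api_key may now appear in the client's extra parameters (the _validate_api_key_not_in_config check is gone), because placeholders are resolved only when completion/acompletion run; and a self-hosted misconfiguration now produces a much more specific error. Expressed as the config dict those mandatory keys correspond to (values are placeholders, not taken from this diff):

# Mandatory keys named in the new self-hosted error message; values are illustrative.
self_hosted_llm_config = {
    "provider": "self-hosted",
    "model": "my-vllm-model-name",              # the model name served by your vLLM server
    "api_base": "http://my-vllm-host:8000/v1",  # the hosted vLLM server URL
}
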
rasa/shared/providers/llm/azure_openai_llm_client.py

@@ -17,6 +17,7 @@ from rasa.shared.constants import (
      OPENAI_API_KEY_ENV_VAR,
      AZURE_API_TYPE_ENV_VAR,
      AZURE_OPENAI_PROVIDER,
+     API_KEY,
  )
  from rasa.shared.exceptions import ProviderClientValidationError
  from rasa.shared.providers._configs.azure_openai_client_config import (
@@ -29,8 +30,7 @@ structlogger = structlog.get_logger()


  class AzureOpenAILLMClient(_BaseLiteLLMClient):
-     """
-     A client for interfacing with Azure's OpenAI LLM deployments.
+     """A client for interfacing with Azure's OpenAI LLM deployments.

      Parameters:
          deployment (str): The deployment name.
@@ -80,11 +80,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
              or os.getenv(OPENAI_API_VERSION_ENV_VAR)
          )

-         # API key can be set through OPENAI_API_KEY too,
-         # because of the backward compatibility
-         self._api_key = os.getenv(AZURE_API_KEY_ENV_VAR) or os.getenv(
-             OPENAI_API_KEY_ENV_VAR
-         )
+         self._api_key_env_var = self._resolve_api_key_env_var()

          # Not used by LiteLLM, here for backward compatibility
          self._api_type = (
@@ -117,11 +113,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
                  "env_var": AZURE_API_VERSION_ENV_VAR,
                  "deprecated_var": OPENAI_API_VERSION_ENV_VAR,
              },
-             "API Key": {
-                 "current_value": self._api_key,
-                 "env_var": AZURE_API_KEY_ENV_VAR,
-                 "deprecated_var": OPENAI_API_KEY_ENV_VAR,
-             },
          }

          deprecation_warning_message = (
@@ -154,10 +145,51 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
              )
              raise_deprecation_warning(message=message)

+     def _resolve_api_key_env_var(self) -> str:
+         """Resolves the environment variable to use for the API key.
+
+         Returns:
+             str: The env variable in dollar syntax format to use for the API key.
+         """
+         if API_KEY in self._extra_parameters:
+             # API key is set to an env var in the config itself
+             # in case the model is defined in the endpoints.yml
+             return self._extra_parameters[API_KEY]
+
+         if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
+             return "${AZURE_API_KEY}"
+
+         if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
+             # API key can be set through OPENAI_API_KEY too,
+             # because of the backward compatibility
+             raise_deprecation_warning(
+                 message=(
+                     f"Usage of '{OPENAI_API_KEY_ENV_VAR}' environment variable "
+                     "for setting the API key for Azure OpenAI "
+                     "client is deprecated and will be removed "
+                     f"in 4.0.0. Please use '{AZURE_API_KEY_ENV_VAR}' "
+                     "environment variable."
+                 )
+             )
+             return "${OPENAI_API_KEY}"
+
+         structlogger.error(
+             "azure_openai_llm_client.api_key_not_set",
+             event_info=(
+                 "API key not set, it is required for API calls. "
+                 f"Set it either via the environment variable"
+                 f"'{AZURE_API_KEY_ENV_VAR}' or directly"
+                 f"via the config key '{API_KEY}'."
+             ),
+         )
+         raise ProviderClientValidationError(
+             f"Missing required environment variable/config key '{API_KEY}' for "
+             f"API calls."
+         )
+
      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAILLMClient":
-         """
-         Initializes the client from given configuration.
+         """Initializes the client from given configuration.

          Args:
              config (Dict[str, Any]): Configuration.
@@ -212,23 +244,17 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):

      @property
      def model(self) -> Optional[str]:
-         """
-         Returns the name of the model deployed on Azure.
-         """
+         """Returns the name of the model deployed on Azure."""
          return self._model

      @property
      def api_base(self) -> Optional[str]:
-         """
-         Returns the API base URL for the Azure OpenAI llm client.
-         """
+         """Returns the API base URL for the Azure OpenAI llm client."""
          return self._api_base

      @property
      def api_version(self) -> Optional[str]:
-         """
-         Returns the API version for the Azure OpenAI llm client.
-         """
+         """Returns the API version for the Azure OpenAI llm client."""
          return self._api_version

      @property
@@ -261,7 +287,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
              {
                  "api_base": self.api_base,
                  "api_version": self.api_version,
-                 "api_key": self._api_key,
+                 "api_key": self._api_key_env_var,
              }
          )
          return fn_args
@@ -305,11 +331,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
                  "env_var": None,
                  "config_key": DEPLOYMENT_CONFIG_KEY,
              },
-             "API Key": {
-                 "current_value": self._api_key,
-                 "env_var": AZURE_API_KEY_ENV_VAR,
-                 "config_key": None,
-             },
          }

          missing_settings = [
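
Taken together with the base-client change, the Azure LLM client now keeps only an environment-variable reference in its completion arguments, and resolve_environment_variables swaps in the real value immediately before each LiteLLM call. A rough illustration of the data flow, with example values assumed:

# What the client stores after __init__ (no secret material is kept in the arguments):
completion_fn_args = {
    "api_base": "https://my-resource.openai.azure.com",  # example value
    "api_version": "2024-02-15-preview",                 # example value
    "api_key": "${AZURE_API_KEY}",                       # a reference, not the key itself
}

# At call time the base client resolves the reference, so litellm.completion()
# receives {"api_base": ..., "api_version": ..., "api_key": "<value of $AZURE_API_KEY>"}.
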