rasa-pro 3.11.0rc2__py3-none-any.whl → 3.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.

Files changed (65)
  1. rasa/__main__.py +9 -3
  2. rasa/cli/studio/upload.py +0 -15
  3. rasa/cli/utils.py +1 -1
  4. rasa/core/channels/development_inspector.py +8 -2
  5. rasa/core/channels/voice_ready/audiocodes.py +3 -4
  6. rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
  7. rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
  8. rasa/core/channels/voice_stream/asr/azure.py +16 -9
  9. rasa/core/channels/voice_stream/asr/deepgram.py +17 -14
  10. rasa/core/channels/voice_stream/tts/azure.py +3 -1
  11. rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
  12. rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
  13. rasa/core/channels/voice_stream/voice_channel.py +48 -18
  14. rasa/core/information_retrieval/qdrant.py +1 -0
  15. rasa/core/nlg/contextual_response_rephraser.py +2 -2
  16. rasa/core/persistor.py +93 -49
  17. rasa/core/policies/enterprise_search_policy.py +5 -5
  18. rasa/core/policies/flows/flow_executor.py +18 -8
  19. rasa/core/policies/intentless_policy.py +9 -5
  20. rasa/core/processor.py +7 -5
  21. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +2 -1
  22. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
  23. rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
  24. rasa/e2e_test/assertions.py +133 -16
  25. rasa/e2e_test/assertions_schema.yml +23 -0
  26. rasa/e2e_test/e2e_test_runner.py +2 -2
  27. rasa/engine/loader.py +12 -0
  28. rasa/engine/validation.py +310 -86
  29. rasa/model_manager/config.py +8 -0
  30. rasa/model_manager/model_api.py +166 -61
  31. rasa/model_manager/runner_service.py +31 -26
  32. rasa/model_manager/trainer_service.py +14 -23
  33. rasa/model_manager/warm_rasa_process.py +187 -0
  34. rasa/model_service.py +3 -5
  35. rasa/model_training.py +3 -1
  36. rasa/shared/constants.py +27 -5
  37. rasa/shared/core/constants.py +1 -1
  38. rasa/shared/core/domain.py +8 -31
  39. rasa/shared/core/flows/yaml_flows_io.py +13 -4
  40. rasa/shared/importers/importer.py +19 -2
  41. rasa/shared/importers/rasa.py +5 -1
  42. rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
  43. rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
  44. rasa/shared/providers/_utils.py +79 -0
  45. rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
  46. rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
  47. rasa/shared/providers/llm/_base_litellm_client.py +26 -0
  48. rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
  49. rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
  50. rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
  51. rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
  52. rasa/shared/utils/common.py +30 -3
  53. rasa/shared/utils/health_check/health_check.py +26 -24
  54. rasa/shared/utils/yaml.py +116 -31
  55. rasa/studio/data_handler.py +3 -1
  56. rasa/studio/upload.py +119 -57
  57. rasa/telemetry.py +3 -1
  58. rasa/tracing/config.py +1 -1
  59. rasa/validator.py +40 -4
  60. rasa/version.py +1 -1
  61. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
  62. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +65 -63
  63. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
rasa/shared/providers/_utils.py
@@ -0,0 +1,79 @@
+ import structlog
+
+ from rasa.shared.constants import (
+     AWS_ACCESS_KEY_ID_ENV_VAR,
+     AWS_ACCESS_KEY_ID_CONFIG_KEY,
+     AWS_SECRET_ACCESS_KEY_ENV_VAR,
+     AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+     AWS_REGION_NAME_ENV_VAR,
+     AWS_REGION_NAME_CONFIG_KEY,
+     AWS_SESSION_TOKEN_CONFIG_KEY,
+     AWS_SESSION_TOKEN_ENV_VAR,
+ )
+ from rasa.shared.exceptions import ProviderClientValidationError
+ from litellm import validate_environment
+ from rasa.shared.providers.embedding._base_litellm_embedding_client import (
+     _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
+ )
+
+ structlogger = structlog.get_logger()
+
+
+ def validate_aws_setup_for_litellm_clients(
+     litellm_model_name: str, litellm_call_kwargs: dict, source_log: str
+ ) -> None:
+     """Validates the AWS setup for LiteLLM clients to ensure all required
+     environment variables or corresponding call kwargs are set.
+
+     Args:
+         litellm_model_name (str): The name of the LiteLLM model being validated.
+         litellm_call_kwargs (dict): Additional keyword arguments passed to the client,
+             which may include configuration values for AWS credentials.
+         source_log (str): The source log identifier for structured logging.
+
+     Raises:
+         ProviderClientValidationError: If any required AWS environment variable
+             or corresponding configuration key is missing.
+     """
+
+     # Mapping of environment variable names to their corresponding config keys
+     envs_to_args = {
+         AWS_ACCESS_KEY_ID_ENV_VAR: AWS_ACCESS_KEY_ID_CONFIG_KEY,
+         AWS_SECRET_ACCESS_KEY_ENV_VAR: AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+         AWS_REGION_NAME_ENV_VAR: AWS_REGION_NAME_CONFIG_KEY,
+         AWS_SESSION_TOKEN_ENV_VAR: AWS_SESSION_TOKEN_CONFIG_KEY,
+     }
+
+     # Validate the environment setup for the model
+     validation_info = validate_environment(litellm_model_name)
+     missing_environment_variables = validation_info.get(
+         _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY, []
+     )
+     # Filter out missing environment variables that have been set through arguments
+     # in extra parameters
+     missing_environment_variables = [
+         missing_env_var
+         for missing_env_var in missing_environment_variables
+         if litellm_call_kwargs.get(envs_to_args.get(missing_env_var)) is None
+     ]
+
+     if missing_environment_variables:
+         missing_environment_details = [
+             (
+                 f"'{missing_env_var}' environment variable or "
+                 f"'{envs_to_args.get(missing_env_var)}' config key"
+             )
+             for missing_env_var in missing_environment_variables
+         ]
+         event_info = (
+             f"The following environment variables or configuration keys are "
+             f"missing: "
+             f"{', '.join(missing_environment_details)}. "
+             f"These settings are required for API calls."
+         )
+         structlogger.error(
+             f"{source_log}.validate_aws_environment_variables",
+             event_info=event_info,
+             missing_environment_variables=missing_environment_variables,
+         )
+         raise ProviderClientValidationError(event_info)
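
To make the filtering behavior concrete, a hedged usage sketch: the function is from this diff, but the model name, credential values, and the assumption that the AWS_*_CONFIG_KEY constants resolve to LiteLLM's aws_* kwarg names are all illustrative.

from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients

# Placeholder credentials passed as call kwargs instead of environment
# variables; kwarg names assume the AWS_*_CONFIG_KEY constants map to
# LiteLLM's aws_* parameter names.
validate_aws_setup_for_litellm_clients(
    litellm_model_name="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    litellm_call_kwargs={
        "aws_access_key_id": "AKIA...",
        "aws_secret_access_key": "...",
        "aws_region_name": "us-east-1",
        "aws_session_token": "...",
    },
    source_log="default_litellm_llm_client",
)
# Every env var reported missing by litellm.validate_environment() has a
# counterpart in the kwargs above, so the filtered list is empty and no
# ProviderClientValidationError is raised.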

rasa/shared/providers/embedding/default_litellm_embedding_client.py
@@ -1,8 +1,13 @@
  from typing import Any, Dict

+ from rasa.shared.constants import (
+     AWS_BEDROCK_PROVIDER,
+     AWS_SAGEMAKER_PROVIDER,
+ )
  from rasa.shared.providers._configs.default_litellm_client_config import (
      DefaultLiteLLMClientConfig,
  )
+ from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
  from rasa.shared.providers.embedding._base_litellm_embedding_client import (
      _BaseLiteLLMEmbeddingClient,
  )
@@ -100,3 +105,22 @@ class DefaultLiteLLMEmbeddingClient(_BaseLiteLLMEmbeddingClient):
              "model": self._litellm_model_name,
              **self._litellm_extra_parameters,
          }
+
+     def validate_client_setup(self) -> None:
+         # TODO: Temporarily disable environment variable validation for AWS setup
+         # (Bedrock and SageMaker) until resolved by either:
+         # 1. An update from the LiteLLM package addressing the issue.
+         # 2. The implementation of a Bedrock client on our end.
+         # ---
+         # This fix ensures a consistent user experience for Bedrock (and
+         # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+         # parameters without triggering validation errors due to missing AWS
+         # environment variables.
+         if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+             validate_aws_setup_for_litellm_clients(
+                 self._litellm_model_name,
+                 self._litellm_extra_parameters,
+                 "default_litellm_embedding_client",
+             )
+         else:
+             super().validate_client_setup()
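
For illustration, a client config of roughly this shape would now take the AWS-specific path in validate_client_setup(). A sketch only; the provider string, model id, and credential keys are assumptions, not taken from the diff.

# Hypothetical Bedrock embedder config; all keys and values are placeholders.
embedder_config = {
    "provider": "bedrock",
    "model": "bedrock/amazon.titan-embed-text-v2:0",
    "aws_access_key_id": "AKIA...",
    "aws_secret_access_key": "...",
    "aws_region_name": "us-east-1",
}
# With a provider matching AWS_BEDROCK_PROVIDER, validate_client_setup()
# delegates to validate_aws_setup_for_litellm_clients(), so credentials
# supplied as extra parameters satisfy validation even when the AWS_*
# environment variables are unset.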

rasa/shared/providers/embedding/litellm_router_embedding_client.py
@@ -72,7 +72,7 @@ class LiteLLMRouterEmbeddingClient(
          return cls(
              model_group_id=client_config.model_group_id,
              model_configurations=client_config.litellm_model_list,
-             router_settings=client_config.router,
+             router_settings=client_config.litellm_router_settings,
              **client_config.extra_parameters,
          )

rasa/shared/providers/llm/_base_litellm_client.py
@@ -221,6 +221,32 @@ class _BaseLiteLLMClient:
          )
          return formatted_response

+     def _format_text_completion_response(self, response: Any) -> LLMResponse:
+         """Parses the LiteLLM text completion response to Rasa format."""
+         formatted_response = LLMResponse(
+             id=response.id,
+             created=response.created,
+             choices=[choice.text for choice in response.choices],
+             model=response.model,
+         )
+         if (usage := response.usage) is not None:
+             prompt_tokens = (
+                 num_tokens
+                 if isinstance(num_tokens := usage.prompt_tokens, (int, float))
+                 else 0
+             )
+             completion_tokens = (
+                 num_tokens
+                 if isinstance(num_tokens := usage.completion_tokens, (int, float))
+                 else 0
+             )
+             formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
+         structlogger.debug(
+             "base_litellm_client.formatted_response",
+             formatted_response=formatted_response.to_dict(),
+         )
+         return formatted_response
+
      @staticmethod
      def _ensure_certificates() -> None:
          """Configures SSL certificates for LiteLLM. This method is invoked during

rasa/shared/providers/llm/default_litellm_llm_client.py
@@ -1,8 +1,13 @@
  from typing import Dict, Any

+ from rasa.shared.constants import (
+     AWS_BEDROCK_PROVIDER,
+     AWS_SAGEMAKER_PROVIDER,
+ )
  from rasa.shared.providers._configs.default_litellm_client_config import (
      DefaultLiteLLMClientConfig,
  )
+ from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient


@@ -82,3 +87,22 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
          to the client provider and deployed model.
          """
          return self._extra_parameters
+
+     def validate_client_setup(self) -> None:
+         # TODO: Temporarily change the environment variable validation for AWS setup
+         # (Bedrock and SageMaker) until resolved by either:
+         # 1. An update from the LiteLLM package addressing the issue.
+         # 2. The implementation of a Bedrock client on our end.
+         # ---
+         # This fix ensures a consistent user experience for Bedrock (and
+         # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+         # parameters without triggering validation errors due to missing AWS
+         # environment variables.
+         if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+             validate_aws_setup_for_litellm_clients(
+                 self._litellm_model_name,
+                 self._litellm_extra_parameters,
+                 "default_litellm_llm_client",
+             )
+         else:
+             super().validate_client_setup()

rasa/shared/providers/llm/litellm_router_llm_client.py
@@ -68,15 +68,61 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
          return cls(
              model_group_id=client_config.model_group_id,
              model_configurations=client_config.litellm_model_list,
-             router_settings=client_config.router,
+             router_settings=client_config.litellm_router_settings,
+             use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
              **client_config.extra_parameters,
          )

+     @suppress_logs(log_level=logging.WARNING)
+     def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+         """
+         Synchronously generate completions for given prompt.
+
+         Args:
+             prompt: Prompt to generate the completion for.
+         Returns:
+             List of message completions.
+         Raises:
+             ProviderClientAPIException: If the API request fails.
+         """
+         try:
+             response = self.router_client.text_completion(
+                 prompt=prompt, **self._completion_fn_args
+             )
+             return self._format_text_completion_response(response)
+         except Exception as e:
+             raise ProviderClientAPIException(e)
+
+     @suppress_logs(log_level=logging.WARNING)
+     async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+         """
+         Asynchronously generate completions for given prompt.
+
+         Args:
+             prompt: Prompt to generate the completion for.
+         Returns:
+             List of message completions.
+         Raises:
+             ProviderClientAPIException: If the API request fails.
+         """
+         try:
+             response = await self.router_client.atext_completion(
+                 prompt=prompt, **self._completion_fn_args
+             )
+             return self._format_text_completion_response(response)
+         except Exception as e:
+             raise ProviderClientAPIException(e)
+
      @suppress_logs(log_level=logging.WARNING)
      def completion(self, messages: Union[List[str], str]) -> LLMResponse:
          """
          Synchronously generate completions for given list of messages.

+         Method overrides the base class method to call the appropriate
+         completion method based on the configuration. If the chat completions
+         endpoint is enabled, the completion method is called. Otherwise, the
+         text_completion method is called.
+
          Args:
              messages: List of messages or a single message to generate the
                  completion for.
@@ -85,6 +131,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
          Raises:
              ProviderClientAPIException: If the API request fails.
          """
+         if not self._use_chat_completions_endpoint:
+             return self._text_completion(messages)
          try:
              formatted_messages = self._format_messages(messages)
              response = self.router_client.completion(
@@ -99,6 +147,11 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
          """
          Asynchronously generate completions for given list of messages.

+         Method overrides the base class method to call the appropriate
+         completion method based on the configuration. If the chat completions
+         endpoint is enabled, the completion method is called. Otherwise, the
+         text_completion method is called.
+
          Args:
              messages: List of messages or a single message to generate the
                  completion for.
@@ -107,6 +160,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
          Raises:
              ProviderClientAPIException: If the API request fails.
          """
+         if not self._use_chat_completions_endpoint:
+             return await self._atext_completion(messages)
          try:
              formatted_messages = self._format_messages(messages)
              response = await self.router_client.acompletion(
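
A sketch of the dispatch this enables, assuming LiteLLM's standard router model_list shape (model_name / litellm_params) and placeholder endpoints; it is not a configuration taken from the diff.

from rasa.shared.providers.llm.litellm_router_llm_client import (
    LiteLLMRouterLLMClient,
)

# Illustrative construction only; all values are placeholders.
client = LiteLLMRouterLLMClient(
    model_group_id="vllm_group",
    model_configurations=[
        {
            "model_name": "vllm_group",
            "litellm_params": {
                "model": "hosted_vllm/meta-llama/Llama-3.1-8B-Instruct",
                "api_base": "http://localhost:8000/v1",
            },
        }
    ],
    router_settings={},
    use_chat_completions_endpoint=False,
)

# With the chat endpoint disabled, completion() bypasses message formatting
# and calls _text_completion(), which hits router_client.text_completion()
# and parses the result via _format_text_completion_response().
response = client.completion("The capital of France is")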

rasa/shared/providers/llm/self_hosted_llm_client.py
@@ -10,13 +10,14 @@ import structlog
  from rasa.shared.constants import (
      SELF_HOSTED_VLLM_PREFIX,
      SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
+     API_KEY,
  )
  from rasa.shared.providers._configs.self_hosted_llm_client_config import (
      SelfHostedLLMClientConfig,
  )
  from rasa.shared.exceptions import ProviderClientAPIException
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
- from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
+ from rasa.shared.providers.llm.llm_response import LLMResponse
  from rasa.shared.utils.io import suppress_logs

  structlogger = structlog.get_logger()
@@ -61,7 +62,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
          self._api_version = api_version
          self._use_chat_completions_endpoint = use_chat_completions_endpoint
          self._extra_parameters = kwargs or {}
-         self._apply_dummy_api_key_if_missing()
+         if self._extra_parameters.get(API_KEY) is None:
+             self._apply_dummy_api_key_if_missing()

      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
@@ -259,32 +261,6 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
              return super().completion(messages)
          return self._text_completion(messages)

-     def _format_text_completion_response(self, response: Any) -> LLMResponse:
-         """Parses the LiteLLM text completion response to Rasa format."""
-         formatted_response = LLMResponse(
-             id=response.id,
-             created=response.created,
-             choices=[choice.text for choice in response.choices],
-             model=response.model,
-         )
-         if (usage := response.usage) is not None:
-             prompt_tokens = (
-                 num_tokens
-                 if isinstance(num_tokens := usage.prompt_tokens, (int, float))
-                 else 0
-             )
-             completion_tokens = (
-                 num_tokens
-                 if isinstance(num_tokens := usage.completion_tokens, (int, float))
-                 else 0
-             )
-             formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
-         structlogger.debug(
-             "base_litellm_client.formatted_response",
-             formatted_response=formatted_response.to_dict(),
-         )
-         return formatted_response
-
      @staticmethod
      def _apply_dummy_api_key_if_missing() -> None:
          if not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR):
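
The constructor change above means an api_key passed in the extra parameters now suppresses the dummy-key fallback. A self-contained sketch of that guard; both constant values are assumptions.

import os

API_KEY = "api_key"                                       # assumed constant value
SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"  # assumed constant value

def apply_dummy_key_if_missing(extra_parameters: dict) -> None:
    # Mirrors the constructor: fall back to the dummy env var only when no
    # api_key arrived via the extra parameters.
    if extra_parameters.get(API_KEY) is None:
        if not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR):
            os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = "dummy api key"

apply_dummy_key_if_missing({"api_key": "real-key"})  # env var left untouched
apply_dummy_key_if_missing({})                       # dummy key injected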

rasa/shared/providers/router/_base_litellm_router_client.py
@@ -1,4 +1,5 @@
  from typing import Any, Dict, List
+ import os
  import structlog

  from litellm import Router
@@ -7,6 +8,12 @@ from rasa.shared.constants import (
      MODEL_LIST_KEY,
      MODEL_GROUP_ID_CONFIG_KEY,
      ROUTER_CONFIG_KEY,
+     SELF_HOSTED_VLLM_PREFIX,
+     SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
+     LITELLM_PARAMS_KEY,
+     API_KEY,
+     MODEL_CONFIG_KEY,
+     USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
  )
  from rasa.shared.exceptions import ProviderClientValidationError
  from rasa.shared.providers._configs.litellm_router_client_config import (
@@ -42,12 +49,15 @@ class _BaseLiteLLMRouterClient:
          model_group_id: str,
          model_configurations: List[Dict[str, Any]],
          router_settings: Dict[str, Any],
+         use_chat_completions_endpoint: bool = True,
          **kwargs: Any,
      ):
          self._model_group_id = model_group_id
          self._model_configurations = model_configurations
          self._router_settings = router_settings
+         self._use_chat_completions_endpoint = use_chat_completions_endpoint
          self._extra_parameters = kwargs or {}
+         self.additional_client_setup()
          try:
              resolved_model_configurations = (
                  self._resolve_env_vars_in_model_configurations()
@@ -67,6 +77,21 @@ class _BaseLiteLLMRouterClient:
              )
              raise ProviderClientValidationError(f"{event_info} Original error: {e}")

+     def additional_client_setup(self) -> None:
+         """Additional setup for the LiteLLM Router client."""
+         # If the model configuration is self-hosted VLLM, set a dummy API key if not
+         # provided. A bug in the LiteLLM library requires an API key to be set even if
+         # it is not required.
+         for model_configuration in self.model_configurations:
+             if (
+                 f"{SELF_HOSTED_VLLM_PREFIX}/"
+                 in model_configuration[LITELLM_PARAMS_KEY][MODEL_CONFIG_KEY]
+                 and API_KEY not in model_configuration[LITELLM_PARAMS_KEY]
+                 and not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR)
+             ):
+                 os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = "dummy api key"
+                 return
+
      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMRouterClient":
          """Instantiates a LiteLLM Router Embedding client from a configuration dict.
@@ -95,7 +120,8 @@ class _BaseLiteLLMRouterClient:
          return cls(
              model_group_id=client_config.model_group_id,
              model_configurations=client_config.litellm_model_list,
-             router_settings=client_config.router,
+             router_settings=client_config.litellm_router_settings,
+             use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
              **client_config.extra_parameters,
          )

@@ -119,6 +145,11 @@ class _BaseLiteLLMRouterClient:
          """Returns the instantiated LiteLLM Router client."""
          return self._router_client

+     @property
+     def use_chat_completions_endpoint(self) -> bool:
+         """Returns whether to use the chat completions endpoint."""
+         return self._use_chat_completions_endpoint
+
      @property
      def _litellm_extra_parameters(self) -> Dict[str, Any]:
          """
@@ -136,6 +167,9 @@ class _BaseLiteLLMRouterClient:
              MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
              MODEL_LIST_KEY: self.model_configurations,
              ROUTER_CONFIG_KEY: self.router_settings,
+             USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY: (
+                 self.use_chat_completions_endpoint
+             ),
              **self._litellm_extra_parameters,
          }

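additional_client_setup() applies the same dummy-key workaround per entry in a router's model list. A self-contained sketch of that scan, with assumed constant values and a placeholder model list.

import os

LITELLM_PARAMS_KEY = "litellm_params"                     # assumed constant value
MODEL_CONFIG_KEY = "model"                                # assumed constant value
API_KEY = "api_key"                                       # assumed constant value
SELF_HOSTED_VLLM_PREFIX = "hosted_vllm"                   # assumed constant value
SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"  # assumed constant value

model_configurations = [
    {
        "model_name": "vllm_group",
        "litellm_params": {
            # Self-hosted vLLM entry with no api_key of its own.
            "model": "hosted_vllm/meta-llama/Llama-3.1-8B-Instruct",
            "api_base": "http://localhost:8000/v1",
        },
    }
]

for model_configuration in model_configurations:
    if (
        f"{SELF_HOSTED_VLLM_PREFIX}/"
        in model_configuration[LITELLM_PARAMS_KEY][MODEL_CONFIG_KEY]
        and API_KEY not in model_configuration[LITELLM_PARAMS_KEY]
        and not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR)
    ):
        # One dummy key satisfies LiteLLM for all entries, hence the early
        # exit in the original method.
        os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = "dummy api key"
        break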

rasa/shared/utils/common.py
@@ -3,14 +3,16 @@ import functools
  import importlib
  import inspect
  import logging
+ import os
  import pkgutil
  import sys
  from types import ModuleType
- from typing import Text, Dict, Optional, Any, List, Callable, Collection, Type
+ from typing import Sequence, Text, Dict, Optional, Any, List, Callable, Collection, Type

  import rasa.shared.utils.io
+ from rasa.exceptions import MissingDependencyException
  from rasa.shared.constants import DOCS_URL_MIGRATION_GUIDE
- from rasa.shared.exceptions import RasaException
+ from rasa.shared.exceptions import ProviderClientValidationError, RasaException

  logger = logging.getLogger(__name__)

@@ -193,7 +195,7 @@ def mark_as_experimental_feature(feature_name: Text) -> None:
  def mark_as_beta_feature(feature_name: Text) -> None:
      """Warns users that they are using a beta feature."""
      logger.warning(
-         f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected"
+         f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected "
          "behaviour and might be changed in the future."
      )

@@ -295,3 +297,28 @@ def warn_and_exit_if_module_path_contains_rasa_plus(
          docs=DOCS_URL_MIGRATION_GUIDE,
      )
      sys.exit(1)
+
+
+ def validate_environment(
+     required_env_vars: Sequence[str],
+     required_packages: Sequence[str],
+     component_name: str,
+ ) -> None:
+     """Make sure all needed requirements for a component are met.
+     Args:
+         required_env_vars: List of environment variables that should be set
+         required_packages: List of packages that should be installed
+         component_name: component name that needs the requirements
+     """
+     for e in required_env_vars:
+         if not os.environ.get(e):
+             raise ProviderClientValidationError(
+                 f"Missing environment variable for {component_name}: {e}"
+             )
+     for p in required_packages:
+         try:
+             importlib.import_module(p)
+         except ImportError:
+             raise MissingDependencyException(
+                 f"Missing package for {component_name}: {p}"
+             )
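
A usage sketch for the new helper; the env var, package, and component name here are illustrative, not requirements from the diff.

from rasa.shared.utils.common import validate_environment

# Raises ProviderClientValidationError if OPENAI_API_KEY is unset, and
# MissingDependencyException if litellm cannot be imported.
validate_environment(
    required_env_vars=["OPENAI_API_KEY"],
    required_packages=["litellm"],
    component_name="MyCustomComponent",  # hypothetical component
)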

rasa/shared/utils/health_check/health_check.py
@@ -1,4 +1,5 @@
  import os
+ import sys
  from typing import Optional, Dict, Any

  from rasa.shared.constants import (
@@ -9,7 +10,6 @@ from rasa.shared.constants import (
  from rasa.shared.exceptions import ProviderClientValidationError
  from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
  from rasa.shared.providers.llm.llm_client import LLMClient
- from rasa.shared.utils.cli import print_error_and_exit
  from rasa.shared.utils.llm import llm_factory, structlogger, embedder_factory


@@ -25,15 +25,15 @@ def try_instantiate_llm_client(
      except (ProviderClientValidationError, ValueError) as e:
          structlogger.error(
              f"{log_source_function}.llm_instantiation_failed",
-             message="Unable to instantiate LLM client.",
+             event_info=(
+                 f"Unable to create the LLM client for component - "
+                 f"{log_source_component}. "
+                 f"Please make sure you specified the required environment variables "
+                 f"and configuration keys. "
+             ),
              error=e,
          )
-         print_error_and_exit(
-             f"Unable to create the LLM client for component - {log_source_component}. "
-             f"Please make sure you specified the required environment variables "
-             f"and configuration keys. "
-             f"Error: {e}"
-         )
+         sys.exit(1)


  def try_instantiate_embedder(
@@ -48,14 +48,14 @@ def try_instantiate_embedder(
      except (ProviderClientValidationError, ValueError) as e:
          structlogger.error(
              f"{log_source_function}.embedder_instantiation_failed",
-             message="Unable to instantiate Embedding client.",
+             event_info=(
+                 f"Unable to create the Embedding client for component - "
+                 f"{log_source_component}. Please make sure you specified the required "
+                 f"environment variables and configuration keys."
+             ),
              error=e,
          )
-         print_error_and_exit(
-             f"Unable to create the Embedding client for component - "
-             f"{log_source_component}. Please make sure you specified the required "
-             f"environment variables and configuration keys. Error: {e}"
-         )
+         sys.exit(1)


  def perform_llm_health_check(
@@ -202,13 +202,14 @@ def send_test_llm_api_request(
      except Exception as e:
          structlogger.error(
              f"{log_source_function}.send_test_llm_api_request_failed",
-             event_info="Test call to the LLM API failed.",
+             event_info=(
+                 f"Test call to the LLM API failed for component - "
+                 f"{log_source_component}."
+             ),
+             config=llm_client.config,
              error=e,
          )
-         print_error_and_exit(
-             f"Test call to the LLM API failed for component - {log_source_component}. "
-             f"Error: {e}"
-         )
+         sys.exit(1)


  def send_test_embeddings_api_request(
@@ -232,13 +233,14 @@ def send_test_embeddings_api_request(
      except Exception as e:
          structlogger.error(
              f"{log_source_function}.send_test_llm_api_request_failed",
-             event_info="Test call to the Embeddings API failed.",
+             event_info=(
+                 f"Test call to the Embeddings API failed for component - "
+                 f"{log_source_component}."
+             ),
+             config=embedder.config,
              error=e,
          )
-         print_error_and_exit(
-             f"Test call to the Embeddings API failed for component - "
-             f"{log_source_component}. Error: {e}"
-         )
+         sys.exit(1)


  def is_api_health_check_enabled() -> bool: