rasa-pro 3.11.0rc2__py3-none-any.whl → 3.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- rasa/__main__.py +9 -3
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/utils.py +1 -1
- rasa/core/channels/development_inspector.py +8 -2
- rasa/core/channels/voice_ready/audiocodes.py +3 -4
- rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +16 -9
- rasa/core/channels/voice_stream/asr/deepgram.py +17 -14
- rasa/core/channels/voice_stream/tts/azure.py +3 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
- rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
- rasa/core/channels/voice_stream/voice_channel.py +48 -18
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +2 -2
- rasa/core/persistor.py +93 -49
- rasa/core/policies/enterprise_search_policy.py +5 -5
- rasa/core/policies/flows/flow_executor.py +18 -8
- rasa/core/policies/intentless_policy.py +9 -5
- rasa/core/processor.py +7 -5
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +2 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
- rasa/e2e_test/assertions.py +133 -16
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_runner.py +2 -2
- rasa/engine/loader.py +12 -0
- rasa/engine/validation.py +310 -86
- rasa/model_manager/config.py +8 -0
- rasa/model_manager/model_api.py +166 -61
- rasa/model_manager/runner_service.py +31 -26
- rasa/model_manager/trainer_service.py +14 -23
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +3 -5
- rasa/model_training.py +3 -1
- rasa/shared/constants.py +27 -5
- rasa/shared/core/constants.py +1 -1
- rasa/shared/core/domain.py +8 -31
- rasa/shared/core/flows/yaml_flows_io.py +13 -4
- rasa/shared/importers/importer.py +19 -2
- rasa/shared/importers/rasa.py +5 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
- rasa/shared/providers/llm/_base_litellm_client.py +26 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
- rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
- rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
- rasa/shared/utils/common.py +30 -3
- rasa/shared/utils/health_check/health_check.py +26 -24
- rasa/shared/utils/yaml.py +116 -31
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +119 -57
- rasa/telemetry.py +3 -1
- rasa/tracing/config.py +1 -1
- rasa/validator.py +40 -4
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +65 -63
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
rasa/shared/providers/_utils.py ADDED

```diff
@@ -0,0 +1,79 @@
+import structlog
+
+from rasa.shared.constants import (
+    AWS_ACCESS_KEY_ID_ENV_VAR,
+    AWS_ACCESS_KEY_ID_CONFIG_KEY,
+    AWS_SECRET_ACCESS_KEY_ENV_VAR,
+    AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+    AWS_REGION_NAME_ENV_VAR,
+    AWS_REGION_NAME_CONFIG_KEY,
+    AWS_SESSION_TOKEN_CONFIG_KEY,
+    AWS_SESSION_TOKEN_ENV_VAR,
+)
+from rasa.shared.exceptions import ProviderClientValidationError
+from litellm import validate_environment
+from rasa.shared.providers.embedding._base_litellm_embedding_client import (
+    _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
+)
+
+structlogger = structlog.get_logger()
+
+
+def validate_aws_setup_for_litellm_clients(
+    litellm_model_name: str, litellm_call_kwargs: dict, source_log: str
+) -> None:
+    """Validates the AWS setup for LiteLLM clients to ensure all required
+    environment variables or corresponding call kwargs are set.
+
+    Args:
+        litellm_model_name (str): The name of the LiteLLM model being validated.
+        litellm_call_kwargs (dict): Additional keyword arguments passed to the client,
+            which may include configuration values for AWS credentials.
+        source_log (str): The source log identifier for structured logging.
+
+    Raises:
+        ProviderClientValidationError: If any required AWS environment variable
+            or corresponding configuration key is missing.
+    """
+
+    # Mapping of environment variable names to their corresponding config keys
+    envs_to_args = {
+        AWS_ACCESS_KEY_ID_ENV_VAR: AWS_ACCESS_KEY_ID_CONFIG_KEY,
+        AWS_SECRET_ACCESS_KEY_ENV_VAR: AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+        AWS_REGION_NAME_ENV_VAR: AWS_REGION_NAME_CONFIG_KEY,
+        AWS_SESSION_TOKEN_ENV_VAR: AWS_SESSION_TOKEN_CONFIG_KEY,
+    }
+
+    # Validate the environment setup for the model
+    validation_info = validate_environment(litellm_model_name)
+    missing_environment_variables = validation_info.get(
+        _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY, []
+    )
+    # Filter out missing environment variables that have been set through arguments
+    # in extra parameters
+    missing_environment_variables = [
+        missing_env_var
+        for missing_env_var in missing_environment_variables
+        if litellm_call_kwargs.get(envs_to_args.get(missing_env_var)) is None
+    ]
+
+    if missing_environment_variables:
+        missing_environment_details = [
+            (
+                f"'{missing_env_var}' environment variable or "
+                f"'{envs_to_args.get(missing_env_var)}' config key"
+            )
+            for missing_env_var in missing_environment_variables
+        ]
+        event_info = (
+            f"The following environment variables or configuration keys are "
+            f"missing: "
+            f"{', '.join(missing_environment_details)}. "
+            f"These settings are required for API calls."
+        )
+        structlogger.error(
+            f"{source_log}.validate_aws_environment_variables",
+            event_info=event_info,
+            missing_environment_variables=missing_environment_variables,
+        )
+        raise ProviderClientValidationError(event_info)
```
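For orientation, here is a sketch of how a caller might invoke the new helper. The kwarg key strings mirror the `AWS_*_CONFIG_KEY` constants imported above, but their literal values are assumptions, as are the model name and credentials.

```python
# Illustrative call only -- not code from this release. The kwarg key strings
# ("aws_access_key_id", etc.) are assumed values of the AWS_*_CONFIG_KEY
# constants; the model name, credentials, and region are placeholders.
from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients

validate_aws_setup_for_litellm_clients(
    litellm_model_name="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    litellm_call_kwargs={
        "aws_access_key_id": "AKIA...",
        "aws_secret_access_key": "...",
        "aws_region_name": "us-east-1",
    },
    source_log="default_litellm_llm_client",
)
# Raises ProviderClientValidationError if a required credential is present
# neither in the environment nor in the kwargs.
```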
rasa/shared/providers/embedding/default_litellm_embedding_client.py CHANGED

```diff
@@ -1,8 +1,13 @@
 from typing import Any, Dict
 
+from rasa.shared.constants import (
+    AWS_BEDROCK_PROVIDER,
+    AWS_SAGEMAKER_PROVIDER,
+)
 from rasa.shared.providers._configs.default_litellm_client_config import (
     DefaultLiteLLMClientConfig,
 )
+from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
 from rasa.shared.providers.embedding._base_litellm_embedding_client import (
     _BaseLiteLLMEmbeddingClient,
 )
@@ -100,3 +105,22 @@ class DefaultLiteLLMEmbeddingClient(_BaseLiteLLMEmbeddingClient):
             "model": self._litellm_model_name,
             **self._litellm_extra_parameters,
         }
+
+    def validate_client_setup(self) -> None:
+        # TODO: Temporarily disable environment variable validation for AWS setup
+        # (Bedrock and SageMaker) until resolved by either:
+        # 1. An update from the LiteLLM package addressing the issue.
+        # 2. The implementation of a Bedrock client on our end.
+        # ---
+        # This fix ensures a consistent user experience for Bedrock (and
+        # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+        # parameters without triggering validation errors due to missing AWS
+        # environment variables.
+        if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+            validate_aws_setup_for_litellm_clients(
+                self._litellm_model_name,
+                self._litellm_extra_parameters,
+                "default_litellm_embedding_client",
+            )
+        else:
+            super().validate_client_setup()
```
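The practical effect of the override above is that Bedrock and SageMaker credentials can be supplied as extra client parameters instead of environment variables. A hypothetical config sketch follows; the credential key names are assumptions inferred from the constants in this diff.

```python
# Hypothetical endpoint configuration: AWS secrets passed as extra parameters
# rather than environment variables. The credential key names are assumptions.
bedrock_embeddings_config = {
    "provider": "bedrock",
    "model": "amazon.titan-embed-text-v2:0",
    "aws_access_key_id": "AKIA...",
    "aws_secret_access_key": "...",
    "aws_region_name": "us-east-1",
}
# With the override above, validate_client_setup() routes this through
# validate_aws_setup_for_litellm_clients() instead of failing on missing
# AWS environment variables.
```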
rasa/shared/providers/embedding/litellm_router_embedding_client.py CHANGED

```diff
@@ -72,7 +72,7 @@ class LiteLLMRouterEmbeddingClient(
         return cls(
             model_group_id=client_config.model_group_id,
             model_configurations=client_config.litellm_model_list,
-            router_settings=client_config.
+            router_settings=client_config.litellm_router_settings,
             **client_config.extra_parameters,
         )
 
```
rasa/shared/providers/llm/_base_litellm_client.py CHANGED

```diff
@@ -221,6 +221,32 @@ class _BaseLiteLLMClient:
         )
         return formatted_response
 
+    def _format_text_completion_response(self, response: Any) -> LLMResponse:
+        """Parses the LiteLLM text completion response to Rasa format."""
+        formatted_response = LLMResponse(
+            id=response.id,
+            created=response.created,
+            choices=[choice.text for choice in response.choices],
+            model=response.model,
+        )
+        if (usage := response.usage) is not None:
+            prompt_tokens = (
+                num_tokens
+                if isinstance(num_tokens := usage.prompt_tokens, (int, float))
+                else 0
+            )
+            completion_tokens = (
+                num_tokens
+                if isinstance(num_tokens := usage.completion_tokens, (int, float))
+                else 0
+            )
+            formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
+        structlogger.debug(
+            "base_litellm_client.formatted_response",
+            formatted_response=formatted_response.to_dict(),
+        )
+        return formatted_response
+
     @staticmethod
     def _ensure_certificates() -> None:
         """Configures SSL certificates for LiteLLM. This method is invoked during
```
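The walrus-and-`isinstance` guards above default token counts to 0 when the provider reports no usable usage numbers. Below is a minimal, self-contained sketch of the response shape the method reads; the stub dataclasses are illustrative stand-ins, not LiteLLM's actual types.

```python
from dataclasses import dataclass
from typing import List, Optional

# Stand-ins for a LiteLLM text-completion response; illustrative stubs only.
@dataclass
class _Choice:
    text: str

@dataclass
class _Usage:
    prompt_tokens: Optional[int]
    completion_tokens: Optional[int]

@dataclass
class _TextCompletionResponse:
    id: str
    created: int
    model: str
    choices: List[_Choice]
    usage: Optional[_Usage]

resp = _TextCompletionResponse(
    id="cmpl-123",
    created=1700000000,
    model="hosted_vllm/my-model",
    choices=[_Choice(text="Hello!")],
    usage=_Usage(prompt_tokens=None, completion_tokens=7),
)

# Mirrors the guard in _format_text_completion_response: a non-numeric
# prompt_tokens falls back to 0, while completion_tokens stays 7.
prompt_tokens = n if isinstance(n := resp.usage.prompt_tokens, (int, float)) else 0
completion_tokens = n if isinstance(n := resp.usage.completion_tokens, (int, float)) else 0
assert (prompt_tokens, completion_tokens) == (0, 7)
```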
rasa/shared/providers/llm/default_litellm_llm_client.py CHANGED

```diff
@@ -1,8 +1,13 @@
 from typing import Dict, Any
 
+from rasa.shared.constants import (
+    AWS_BEDROCK_PROVIDER,
+    AWS_SAGEMAKER_PROVIDER,
+)
 from rasa.shared.providers._configs.default_litellm_client_config import (
     DefaultLiteLLMClientConfig,
 )
+from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
 from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
 
 
@@ -82,3 +87,22 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
         to the client provider and deployed model.
         """
         return self._extra_parameters
+
+    def validate_client_setup(self) -> None:
+        # TODO: Temporarily change the environment variable validation for AWS setup
+        # (Bedrock and SageMaker) until resolved by either:
+        # 1. An update from the LiteLLM package addressing the issue.
+        # 2. The implementation of a Bedrock client on our end.
+        # ---
+        # This fix ensures a consistent user experience for Bedrock (and
+        # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+        # parameters without triggering validation errors due to missing AWS
+        # environment variables.
+        if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+            validate_aws_setup_for_litellm_clients(
+                self._litellm_model_name,
+                self._litellm_extra_parameters,
+                "default_litellm_llm_client",
+            )
+        else:
+            super().validate_client_setup()
```
rasa/shared/providers/llm/litellm_router_llm_client.py CHANGED

```diff
@@ -68,15 +68,61 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
         return cls(
             model_group_id=client_config.model_group_id,
             model_configurations=client_config.litellm_model_list,
-            router_settings=client_config.
+            router_settings=client_config.litellm_router_settings,
+            use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
             **client_config.extra_parameters,
         )
 
+    @suppress_logs(log_level=logging.WARNING)
+    def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+        """
+        Synchronously generate completions for given prompt.
+
+        Args:
+            prompt: Prompt to generate the completion for.
+        Returns:
+            List of message completions.
+        Raises:
+            ProviderClientAPIException: If the API request fails.
+        """
+        try:
+            response = self.router_client.text_completion(
+                prompt=prompt, **self._completion_fn_args
+            )
+            return self._format_text_completion_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(e)
+
+    @suppress_logs(log_level=logging.WARNING)
+    async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+        """
+        Asynchronously generate completions for given prompt.
+
+        Args:
+            prompt: Prompt to generate the completion for.
+        Returns:
+            List of message completions.
+        Raises:
+            ProviderClientAPIException: If the API request fails.
+        """
+        try:
+            response = await self.router_client.atext_completion(
+                prompt=prompt, **self._completion_fn_args
+            )
+            return self._format_text_completion_response(response)
+        except Exception as e:
+            raise ProviderClientAPIException(e)
+
     @suppress_logs(log_level=logging.WARNING)
     def completion(self, messages: Union[List[str], str]) -> LLMResponse:
         """
         Synchronously generate completions for given list of messages.
 
+        Method overrides the base class method to call the appropriate
+        completion method based on the configuration. If the chat completions
+        endpoint is enabled, the completion method is called. Otherwise, the
+        text_completion method is called.
+
         Args:
             messages: List of messages or a single message to generate the
                 completion for.
@@ -85,6 +131,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
+        if not self._use_chat_completions_endpoint:
+            return self._text_completion(messages)
         try:
             formatted_messages = self._format_messages(messages)
             response = self.router_client.completion(
@@ -99,6 +147,11 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
         """
         Asynchronously generate completions for given list of messages.
 
+        Method overrides the base class method to call the appropriate
+        completion method based on the configuration. If the chat completions
+        endpoint is enabled, the completion method is called. Otherwise, the
+        text_completion method is called.
+
         Args:
             messages: List of messages or a single message to generate the
                 completion for.
@@ -107,6 +160,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
+        if not self._use_chat_completions_endpoint:
+            return await self._atext_completion(messages)
         try:
             formatted_messages = self._format_messages(messages)
             response = await self.router_client.acompletion(
```
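Taken together, these hunks let a router configuration opt out of the chat-completions endpoint, in which case `completion`/`acompletion` delegate to the new `_text_completion`/`_atext_completion` helpers. A hypothetical config sketch follows; the literal key strings are guesses at the values of the `*_KEY` constants referenced in this diff, and the model-list shape mirrors general LiteLLM router conventions rather than anything confirmed here.

```python
# Hypothetical router configuration; key strings and structure are assumptions.
router_llm_config = {
    "id": "self_hosted_group",
    "models": [
        {
            "model_name": "self_hosted_group",
            "litellm_params": {
                "model": "hosted_vllm/meta-llama/Llama-3.1-8B-Instruct",
                "api_base": "http://localhost:8000/v1",
            },
        }
    ],
    "router": {"routing_strategy": "simple-shuffle"},
    # New in 3.11.1: route completion()/acompletion() through the plain
    # completions endpoint instead of chat completions.
    "use_chat_completions_endpoint": False,
}
```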
rasa/shared/providers/llm/self_hosted_llm_client.py CHANGED

```diff
@@ -10,13 +10,14 @@ import structlog
 from rasa.shared.constants import (
     SELF_HOSTED_VLLM_PREFIX,
     SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
+    API_KEY,
 )
 from rasa.shared.providers._configs.self_hosted_llm_client_config import (
     SelfHostedLLMClientConfig,
 )
 from rasa.shared.exceptions import ProviderClientAPIException
 from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
-from rasa.shared.providers.llm.llm_response import LLMResponse
+from rasa.shared.providers.llm.llm_response import LLMResponse
 from rasa.shared.utils.io import suppress_logs
 
 structlogger = structlog.get_logger()
@@ -61,7 +62,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
         self._api_version = api_version
         self._use_chat_completions_endpoint = use_chat_completions_endpoint
         self._extra_parameters = kwargs or {}
-        self.
+        if self._extra_parameters.get(API_KEY) is None:
+            self._apply_dummy_api_key_if_missing()
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
@@ -259,32 +261,6 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
             return super().completion(messages)
         return self._text_completion(messages)
 
-    def _format_text_completion_response(self, response: Any) -> LLMResponse:
-        """Parses the LiteLLM text completion response to Rasa format."""
-        formatted_response = LLMResponse(
-            id=response.id,
-            created=response.created,
-            choices=[choice.text for choice in response.choices],
-            model=response.model,
-        )
-        if (usage := response.usage) is not None:
-            prompt_tokens = (
-                num_tokens
-                if isinstance(num_tokens := usage.prompt_tokens, (int, float))
-                else 0
-            )
-            completion_tokens = (
-                num_tokens
-                if isinstance(num_tokens := usage.completion_tokens, (int, float))
-                else 0
-            )
-            formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
-        structlogger.debug(
-            "base_litellm_client.formatted_response",
-            formatted_response=formatted_response.to_dict(),
-        )
-        return formatted_response
-
     @staticmethod
     def _apply_dummy_api_key_if_missing() -> None:
         if not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR):
```
rasa/shared/providers/router/_base_litellm_router_client.py CHANGED

```diff
@@ -1,4 +1,5 @@
 from typing import Any, Dict, List
+import os
 import structlog
 
 from litellm import Router
@@ -7,6 +8,12 @@ from rasa.shared.constants import (
     MODEL_LIST_KEY,
     MODEL_GROUP_ID_CONFIG_KEY,
     ROUTER_CONFIG_KEY,
+    SELF_HOSTED_VLLM_PREFIX,
+    SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
+    LITELLM_PARAMS_KEY,
+    API_KEY,
+    MODEL_CONFIG_KEY,
+    USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
 )
 from rasa.shared.exceptions import ProviderClientValidationError
 from rasa.shared.providers._configs.litellm_router_client_config import (
@@ -42,12 +49,15 @@ class _BaseLiteLLMRouterClient:
         model_group_id: str,
         model_configurations: List[Dict[str, Any]],
         router_settings: Dict[str, Any],
+        use_chat_completions_endpoint: bool = True,
         **kwargs: Any,
     ):
         self._model_group_id = model_group_id
         self._model_configurations = model_configurations
         self._router_settings = router_settings
+        self._use_chat_completions_endpoint = use_chat_completions_endpoint
         self._extra_parameters = kwargs or {}
+        self.additional_client_setup()
         try:
             resolved_model_configurations = (
                 self._resolve_env_vars_in_model_configurations()
@@ -67,6 +77,21 @@ class _BaseLiteLLMRouterClient:
             )
             raise ProviderClientValidationError(f"{event_info} Original error: {e}")
 
+    def additional_client_setup(self) -> None:
+        """Additional setup for the LiteLLM Router client."""
+        # If the model configuration is self-hosted VLLM, set a dummy API key if not
+        # provided. A bug in the LiteLLM library requires an API key to be set even if
+        # it is not required.
+        for model_configuration in self.model_configurations:
+            if (
+                f"{SELF_HOSTED_VLLM_PREFIX}/"
+                in model_configuration[LITELLM_PARAMS_KEY][MODEL_CONFIG_KEY]
+                and API_KEY not in model_configuration[LITELLM_PARAMS_KEY]
+                and not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR)
+            ):
+                os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = "dummy api key"
+                return
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMRouterClient":
         """Instantiates a LiteLLM Router Embedding client from a configuration dict.
@@ -95,7 +120,8 @@ class _BaseLiteLLMRouterClient:
         return cls(
             model_group_id=client_config.model_group_id,
             model_configurations=client_config.litellm_model_list,
-            router_settings=client_config.
+            router_settings=client_config.litellm_router_settings,
+            use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
             **client_config.extra_parameters,
         )
 
@@ -119,6 +145,11 @@ class _BaseLiteLLMRouterClient:
         """Returns the instantiated LiteLLM Router client."""
         return self._router_client
 
+    @property
+    def use_chat_completions_endpoint(self) -> bool:
+        """Returns whether to use the chat completions endpoint."""
+        return self._use_chat_completions_endpoint
+
     @property
     def _litellm_extra_parameters(self) -> Dict[str, Any]:
         """
@@ -136,6 +167,9 @@ class _BaseLiteLLMRouterClient:
             MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
             MODEL_LIST_KEY: self.model_configurations,
             ROUTER_CONFIG_KEY: self.router_settings,
+            USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY: (
+                self.use_chat_completions_endpoint
+            ),
             **self._litellm_extra_parameters,
         }
 
```
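A sketch of the condition `additional_client_setup` checks: a self-hosted vLLM entry whose `litellm_params` lacks an API key, with no corresponding environment variable set, triggers the dummy-key workaround. The literal strings below are assumed values for the constants used in the diff.

```python
import os

# Illustrative reimplementation of the check, with assumed literal values for
# LITELLM_PARAMS_KEY ("litellm_params"), MODEL_CONFIG_KEY ("model"), API_KEY
# ("api_key"), and SELF_HOSTED_VLLM_API_KEY_ENV_VAR ("HOSTED_VLLM_API_KEY").
model_configurations = [
    {"litellm_params": {"model": "hosted_vllm/my-model"}},  # no api_key set
]

for cfg in model_configurations:
    params = cfg["litellm_params"]
    if (
        "hosted_vllm/" in params["model"]
        and "api_key" not in params
        and not os.getenv("HOSTED_VLLM_API_KEY")
    ):
        # Same workaround as additional_client_setup(): LiteLLM insists on an
        # API key even when the self-hosted server does not require one.
        os.environ["HOSTED_VLLM_API_KEY"] = "dummy api key"
        break
```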
rasa/shared/utils/common.py CHANGED

```diff
@@ -3,14 +3,16 @@ import functools
 import importlib
 import inspect
 import logging
+import os
 import pkgutil
 import sys
 from types import ModuleType
-from typing import Text, Dict, Optional, Any, List, Callable, Collection, Type
+from typing import Sequence, Text, Dict, Optional, Any, List, Callable, Collection, Type
 
 import rasa.shared.utils.io
+from rasa.exceptions import MissingDependencyException
 from rasa.shared.constants import DOCS_URL_MIGRATION_GUIDE
-from rasa.shared.exceptions import RasaException
+from rasa.shared.exceptions import ProviderClientValidationError, RasaException
 
 logger = logging.getLogger(__name__)
 
@@ -193,7 +195,7 @@ def mark_as_experimental_feature(feature_name: Text) -> None:
 def mark_as_beta_feature(feature_name: Text) -> None:
     """Warns users that they are using a beta feature."""
     logger.warning(
-        f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected"
+        f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected "
         "behaviour and might be changed in the future."
     )
 
@@ -295,3 +297,28 @@ def warn_and_exit_if_module_path_contains_rasa_plus(
         docs=DOCS_URL_MIGRATION_GUIDE,
     )
     sys.exit(1)
+
+
+def validate_environment(
+    required_env_vars: Sequence[str],
+    required_packages: Sequence[str],
+    component_name: str,
+) -> None:
+    """Make sure all needed requirements for a component are met.
+    Args:
+        required_env_vars: List of environment variables that should be set
+        required_packages: List of packages that should be installed
+        component_name: component name that needs the requirements
+    """
+    for e in required_env_vars:
+        if not os.environ.get(e):
+            raise ProviderClientValidationError(
+                f"Missing environment variable for {component_name}: {e}"
+            )
+    for p in required_packages:
+        try:
+            importlib.import_module(p)
+        except ImportError:
+            raise MissingDependencyException(
+                f"Missing package for {component_name}: {p}"
+            )
```
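Usage sketch for the new `validate_environment` helper; the component, environment variable, and package names below are invented for illustration.

```python
from rasa.shared.utils.common import validate_environment

# Hypothetical requirements for some component; the names are examples only.
validate_environment(
    required_env_vars=["MY_PROVIDER_API_KEY"],
    required_packages=["requests"],
    component_name="MyCustomComponent",
)
# Raises ProviderClientValidationError if the variable is unset, or
# MissingDependencyException if the package cannot be imported.
```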
rasa/shared/utils/health_check/health_check.py CHANGED

```diff
@@ -1,4 +1,5 @@
 import os
+import sys
 from typing import Optional, Dict, Any
 
 from rasa.shared.constants import (
@@ -9,7 +10,6 @@ from rasa.shared.constants import (
 from rasa.shared.exceptions import ProviderClientValidationError
 from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
 from rasa.shared.providers.llm.llm_client import LLMClient
-from rasa.shared.utils.cli import print_error_and_exit
 from rasa.shared.utils.llm import llm_factory, structlogger, embedder_factory
 
 
@@ -25,15 +25,15 @@ def try_instantiate_llm_client(
     except (ProviderClientValidationError, ValueError) as e:
         structlogger.error(
             f"{log_source_function}.llm_instantiation_failed",
-
+            event_info=(
+                f"Unable to create the LLM client for component - "
+                f"{log_source_component}. "
+                f"Please make sure you specified the required environment variables "
+                f"and configuration keys. "
+            ),
             error=e,
         )
-        print_error_and_exit(
-            f"Unable to create the LLM client for component - {log_source_component}. "
-            f"Please make sure you specified the required environment variables "
-            f"and configuration keys. "
-            f"Error: {e}"
-        )
+        sys.exit(1)
 
 
 def try_instantiate_embedder(
@@ -48,14 +48,14 @@ def try_instantiate_embedder(
     except (ProviderClientValidationError, ValueError) as e:
         structlogger.error(
             f"{log_source_function}.embedder_instantiation_failed",
-
+            event_info=(
+                f"Unable to create the Embedding client for component - "
+                f"{log_source_component}. Please make sure you specified the required "
+                f"environment variables and configuration keys."
+            ),
             error=e,
         )
-        print_error_and_exit(
-            f"Unable to create the Embedding client for component - "
-            f"{log_source_component}. Please make sure you specified the required "
-            f"environment variables and configuration keys. Error: {e}"
-        )
+        sys.exit(1)
 
 
 def perform_llm_health_check(
@@ -202,13 +202,14 @@ def send_test_llm_api_request(
     except Exception as e:
         structlogger.error(
             f"{log_source_function}.send_test_llm_api_request_failed",
-            event_info=
+            event_info=(
+                f"Test call to the LLM API failed for component - "
+                f"{log_source_component}.",
+            ),
+            config=llm_client.config,
             error=e,
         )
-        print_error_and_exit(
-            f"Test call to the LLM API failed for component - {log_source_component}. "
-            f"Error: {e}"
-        )
+        sys.exit(1)
 
 
 def send_test_embeddings_api_request(
@@ -232,13 +233,14 @@ def send_test_embeddings_api_request(
     except Exception as e:
         structlogger.error(
             f"{log_source_function}.send_test_llm_api_request_failed",
-            event_info=
+            event_info=(
+                f"Test call to the Embeddings API failed for component - "
+                f"{log_source_component}."
+            ),
+            config=embedder.config,
            error=e,
        )
-        print_error_and_exit(
-            f"Test call to the Embeddings API failed for component - "
-            f"{log_source_component}. Error: {e}"
-        )
+        sys.exit(1)
 
 
 def is_api_health_check_enabled() -> bool:
```