rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +2 -0
- rasa/cli/studio/studio.py +18 -8
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +17 -0
- rasa/core/channels/development_inspector.py +4 -1
- rasa/core/channels/voice_ready/audiocodes.py +15 -4
- rasa/core/channels/voice_ready/jambonz.py +13 -2
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +5 -7
- rasa/core/channels/voice_stream/asr/deepgram.py +13 -11
- rasa/core/channels/voice_stream/voice_channel.py +61 -19
- rasa/core/nlg/contextual_response_rephraser.py +20 -12
- rasa/core/policies/enterprise_search_policy.py +32 -72
- rasa/core/policies/intentless_policy.py +34 -72
- rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +26 -22
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/engine/validation.py +88 -60
- rasa/model_service.py +3 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +5 -5
- rasa/shared/core/constants.py +1 -1
- rasa/shared/core/domain.py +0 -26
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
- rasa/shared/providers/llm/_base_litellm_client.py +32 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
- rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
- rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
- rasa/shared/utils/common.py +1 -1
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/llm.py +5 -2
- rasa/shared/utils/yaml.py +102 -62
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +21 -10
- rasa/telemetry.py +15 -1
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +20 -0
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +8 -16
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/METADATA +3 -2
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/RECORD +65 -61
- rasa/shared/utils/health_check.py +0 -533
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/entry_points.txt +0 -0
|
@@ -6,6 +6,7 @@ import litellm
|
|
|
6
6
|
import structlog
|
|
7
7
|
from litellm import aembedding, embedding, validate_environment
|
|
8
8
|
|
|
9
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
|
|
9
10
|
from rasa.shared.exceptions import (
|
|
10
11
|
ProviderClientAPIException,
|
|
11
12
|
ProviderClientValidationError,
|
|
@@ -81,11 +82,14 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
81
82
|
ProviderClientValidationError if validation fails.
|
|
82
83
|
"""
|
|
83
84
|
self._validate_environment_variables()
|
|
84
|
-
self._validate_api_key_not_in_config()
|
|
85
85
|
|
|
86
86
|
def _validate_environment_variables(self) -> None:
|
|
87
87
|
"""Validate that the required environment variables are set."""
|
|
88
|
-
validation_info = validate_environment(
|
|
88
|
+
validation_info = validate_environment(
|
|
89
|
+
self._litellm_model_name,
|
|
90
|
+
api_key=self._litellm_extra_parameters.get(API_KEY),
|
|
91
|
+
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
92
|
+
)
|
|
89
93
|
if missing_environment_variables := validation_info.get(
|
|
90
94
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
|
|
91
95
|
):
|
|
@@ -100,18 +104,6 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
100
104
|
)
|
|
101
105
|
raise ProviderClientValidationError(event_info)
|
|
102
106
|
|
|
103
|
-
def _validate_api_key_not_in_config(self) -> None:
|
|
104
|
-
if "api_key" in self._litellm_extra_parameters:
|
|
105
|
-
event_info = (
|
|
106
|
-
"API Key is set through `api_key` extra parameter."
|
|
107
|
-
"Set API keys through environment variables."
|
|
108
|
-
)
|
|
109
|
-
structlogger.error(
|
|
110
|
-
"base_litellm_client.validate_api_key_not_in_config",
|
|
111
|
-
event_info=event_info,
|
|
112
|
-
)
|
|
113
|
-
raise ProviderClientValidationError(event_info)
|
|
114
|
-
|
|
115
107
|
def validate_documents(self, documents: List[str]) -> None:
|
|
116
108
|
"""Validates a list of documents to ensure they are suitable for embedding.
|
|
117
109
|
|
|
@@ -72,7 +72,7 @@ class LiteLLMRouterEmbeddingClient(
|
|
|
72
72
|
return cls(
|
|
73
73
|
model_group_id=client_config.model_group_id,
|
|
74
74
|
model_configurations=client_config.litellm_model_list,
|
|
75
|
-
router_settings=client_config.
|
|
75
|
+
router_settings=client_config.litellm_router_settings,
|
|
76
76
|
**client_config.extra_parameters,
|
|
77
77
|
)
|
|
78
78
|
|
|
@@ -9,6 +9,7 @@ from litellm import (
|
|
|
9
9
|
validate_environment,
|
|
10
10
|
)
|
|
11
11
|
|
|
12
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
|
|
12
13
|
from rasa.shared.exceptions import (
|
|
13
14
|
ProviderClientAPIException,
|
|
14
15
|
ProviderClientValidationError,
|
|
@@ -101,7 +102,11 @@ class _BaseLiteLLMClient:
|
|
|
101
102
|
|
|
102
103
|
def _validate_environment_variables(self) -> None:
|
|
103
104
|
"""Validate that the required environment variables are set."""
|
|
104
|
-
validation_info = validate_environment(
|
|
105
|
+
validation_info = validate_environment(
|
|
106
|
+
self._litellm_model_name,
|
|
107
|
+
api_key=self._litellm_extra_parameters.get(API_KEY),
|
|
108
|
+
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
109
|
+
)
|
|
105
110
|
if missing_environment_variables := validation_info.get(
|
|
106
111
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
|
|
107
112
|
):
|
|
@@ -216,6 +221,32 @@ class _BaseLiteLLMClient:
|
|
|
216
221
|
)
|
|
217
222
|
return formatted_response
|
|
218
223
|
|
|
224
|
+
def _format_text_completion_response(self, response: Any) -> LLMResponse:
|
|
225
|
+
"""Parses the LiteLLM text completion response to Rasa format."""
|
|
226
|
+
formatted_response = LLMResponse(
|
|
227
|
+
id=response.id,
|
|
228
|
+
created=response.created,
|
|
229
|
+
choices=[choice.text for choice in response.choices],
|
|
230
|
+
model=response.model,
|
|
231
|
+
)
|
|
232
|
+
if (usage := response.usage) is not None:
|
|
233
|
+
prompt_tokens = (
|
|
234
|
+
num_tokens
|
|
235
|
+
if isinstance(num_tokens := usage.prompt_tokens, (int, float))
|
|
236
|
+
else 0
|
|
237
|
+
)
|
|
238
|
+
completion_tokens = (
|
|
239
|
+
num_tokens
|
|
240
|
+
if isinstance(num_tokens := usage.completion_tokens, (int, float))
|
|
241
|
+
else 0
|
|
242
|
+
)
|
|
243
|
+
formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
|
|
244
|
+
structlogger.debug(
|
|
245
|
+
"base_litellm_client.formatted_response",
|
|
246
|
+
formatted_response=formatted_response.to_dict(),
|
|
247
|
+
)
|
|
248
|
+
return formatted_response
|
|
249
|
+
|
|
219
250
|
@staticmethod
|
|
220
251
|
def _ensure_certificates() -> None:
|
|
221
252
|
"""Configures SSL certificates for LiteLLM. This method is invoked during
|
|
@@ -68,15 +68,61 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
68
68
|
return cls(
|
|
69
69
|
model_group_id=client_config.model_group_id,
|
|
70
70
|
model_configurations=client_config.litellm_model_list,
|
|
71
|
-
router_settings=client_config.
|
|
71
|
+
router_settings=client_config.litellm_router_settings,
|
|
72
|
+
use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
|
|
72
73
|
**client_config.extra_parameters,
|
|
73
74
|
)
|
|
74
75
|
|
|
76
|
+
@suppress_logs(log_level=logging.WARNING)
|
|
77
|
+
def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
78
|
+
"""
|
|
79
|
+
Synchronously generate completions for given prompt.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
prompt: Prompt to generate the completion for.
|
|
83
|
+
Returns:
|
|
84
|
+
List of message completions.
|
|
85
|
+
Raises:
|
|
86
|
+
ProviderClientAPIException: If the API request fails.
|
|
87
|
+
"""
|
|
88
|
+
try:
|
|
89
|
+
response = self.router_client.text_completion(
|
|
90
|
+
prompt=prompt, **self._completion_fn_args
|
|
91
|
+
)
|
|
92
|
+
return self._format_text_completion_response(response)
|
|
93
|
+
except Exception as e:
|
|
94
|
+
raise ProviderClientAPIException(e)
|
|
95
|
+
|
|
96
|
+
@suppress_logs(log_level=logging.WARNING)
|
|
97
|
+
async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
98
|
+
"""
|
|
99
|
+
Asynchronously generate completions for given prompt.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
prompt: Prompt to generate the completion for.
|
|
103
|
+
Returns:
|
|
104
|
+
List of message completions.
|
|
105
|
+
Raises:
|
|
106
|
+
ProviderClientAPIException: If the API request fails.
|
|
107
|
+
"""
|
|
108
|
+
try:
|
|
109
|
+
response = await self.router_client.atext_completion(
|
|
110
|
+
prompt=prompt, **self._completion_fn_args
|
|
111
|
+
)
|
|
112
|
+
return self._format_text_completion_response(response)
|
|
113
|
+
except Exception as e:
|
|
114
|
+
raise ProviderClientAPIException(e)
|
|
115
|
+
|
|
75
116
|
@suppress_logs(log_level=logging.WARNING)
|
|
76
117
|
def completion(self, messages: Union[List[str], str]) -> LLMResponse:
|
|
77
118
|
"""
|
|
78
119
|
Synchronously generate completions for given list of messages.
|
|
79
120
|
|
|
121
|
+
Method overrides the base class method to call the appropriate
|
|
122
|
+
completion method based on the configuration. If the chat completions
|
|
123
|
+
endpoint is enabled, the completion method is called. Otherwise, the
|
|
124
|
+
text_completion method is called.
|
|
125
|
+
|
|
80
126
|
Args:
|
|
81
127
|
messages: List of messages or a single message to generate the
|
|
82
128
|
completion for.
|
|
@@ -85,6 +131,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
85
131
|
Raises:
|
|
86
132
|
ProviderClientAPIException: If the API request fails.
|
|
87
133
|
"""
|
|
134
|
+
if not self._use_chat_completions_endpoint:
|
|
135
|
+
return self._text_completion(messages)
|
|
88
136
|
try:
|
|
89
137
|
formatted_messages = self._format_messages(messages)
|
|
90
138
|
response = self.router_client.completion(
|
|
@@ -99,6 +147,11 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
99
147
|
"""
|
|
100
148
|
Asynchronously generate completions for given list of messages.
|
|
101
149
|
|
|
150
|
+
Method overrides the base class method to call the appropriate
|
|
151
|
+
completion method based on the configuration. If the chat completions
|
|
152
|
+
endpoint is enabled, the completion method is called. Otherwise, the
|
|
153
|
+
text_completion method is called.
|
|
154
|
+
|
|
102
155
|
Args:
|
|
103
156
|
messages: List of messages or a single message to generate the
|
|
104
157
|
completion for.
|
|
@@ -107,6 +160,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
107
160
|
Raises:
|
|
108
161
|
ProviderClientAPIException: If the API request fails.
|
|
109
162
|
"""
|
|
163
|
+
if not self._use_chat_completions_endpoint:
|
|
164
|
+
return await self._atext_completion(messages)
|
|
110
165
|
try:
|
|
111
166
|
formatted_messages = self._format_messages(messages)
|
|
112
167
|
response = await self.router_client.acompletion(
|
|
@@ -10,13 +10,14 @@ import structlog
|
|
|
10
10
|
from rasa.shared.constants import (
|
|
11
11
|
SELF_HOSTED_VLLM_PREFIX,
|
|
12
12
|
SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
|
|
13
|
+
API_KEY,
|
|
13
14
|
)
|
|
14
15
|
from rasa.shared.providers._configs.self_hosted_llm_client_config import (
|
|
15
16
|
SelfHostedLLMClientConfig,
|
|
16
17
|
)
|
|
17
18
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
18
19
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
19
|
-
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
20
|
+
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
20
21
|
from rasa.shared.utils.io import suppress_logs
|
|
21
22
|
|
|
22
23
|
structlogger = structlog.get_logger()
|
|
@@ -61,7 +62,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
61
62
|
self._api_version = api_version
|
|
62
63
|
self._use_chat_completions_endpoint = use_chat_completions_endpoint
|
|
63
64
|
self._extra_parameters = kwargs or {}
|
|
64
|
-
self.
|
|
65
|
+
if self._extra_parameters.get(API_KEY) is None:
|
|
66
|
+
self._apply_dummy_api_key_if_missing()
|
|
65
67
|
|
|
66
68
|
@classmethod
|
|
67
69
|
def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
|
|
@@ -259,32 +261,6 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
259
261
|
return super().completion(messages)
|
|
260
262
|
return self._text_completion(messages)
|
|
261
263
|
|
|
262
|
-
def _format_text_completion_response(self, response: Any) -> LLMResponse:
|
|
263
|
-
"""Parses the LiteLLM text completion response to Rasa format."""
|
|
264
|
-
formatted_response = LLMResponse(
|
|
265
|
-
id=response.id,
|
|
266
|
-
created=response.created,
|
|
267
|
-
choices=[choice.text for choice in response.choices],
|
|
268
|
-
model=response.model,
|
|
269
|
-
)
|
|
270
|
-
if (usage := response.usage) is not None:
|
|
271
|
-
prompt_tokens = (
|
|
272
|
-
num_tokens
|
|
273
|
-
if isinstance(num_tokens := usage.prompt_tokens, (int, float))
|
|
274
|
-
else 0
|
|
275
|
-
)
|
|
276
|
-
completion_tokens = (
|
|
277
|
-
num_tokens
|
|
278
|
-
if isinstance(num_tokens := usage.completion_tokens, (int, float))
|
|
279
|
-
else 0
|
|
280
|
-
)
|
|
281
|
-
formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
|
|
282
|
-
structlogger.debug(
|
|
283
|
-
"base_litellm_client.formatted_response",
|
|
284
|
-
formatted_response=formatted_response.to_dict(),
|
|
285
|
-
)
|
|
286
|
-
return formatted_response
|
|
287
|
-
|
|
288
264
|
@staticmethod
|
|
289
265
|
def _apply_dummy_api_key_if_missing() -> None:
|
|
290
266
|
if not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR):
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from typing import Any, Dict, List
|
|
2
|
+
import os
|
|
2
3
|
import structlog
|
|
3
4
|
|
|
4
5
|
from litellm import Router
|
|
@@ -7,6 +8,12 @@ from rasa.shared.constants import (
|
|
|
7
8
|
MODEL_LIST_KEY,
|
|
8
9
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
9
10
|
ROUTER_CONFIG_KEY,
|
|
11
|
+
SELF_HOSTED_VLLM_PREFIX,
|
|
12
|
+
SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
|
|
13
|
+
LITELLM_PARAMS_KEY,
|
|
14
|
+
API_KEY,
|
|
15
|
+
MODEL_CONFIG_KEY,
|
|
16
|
+
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
|
|
10
17
|
)
|
|
11
18
|
from rasa.shared.exceptions import ProviderClientValidationError
|
|
12
19
|
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
@@ -42,12 +49,15 @@ class _BaseLiteLLMRouterClient:
|
|
|
42
49
|
model_group_id: str,
|
|
43
50
|
model_configurations: List[Dict[str, Any]],
|
|
44
51
|
router_settings: Dict[str, Any],
|
|
52
|
+
use_chat_completions_endpoint: bool = True,
|
|
45
53
|
**kwargs: Any,
|
|
46
54
|
):
|
|
47
55
|
self._model_group_id = model_group_id
|
|
48
56
|
self._model_configurations = model_configurations
|
|
49
57
|
self._router_settings = router_settings
|
|
58
|
+
self._use_chat_completions_endpoint = use_chat_completions_endpoint
|
|
50
59
|
self._extra_parameters = kwargs or {}
|
|
60
|
+
self.additional_client_setup()
|
|
51
61
|
try:
|
|
52
62
|
resolved_model_configurations = (
|
|
53
63
|
self._resolve_env_vars_in_model_configurations()
|
|
@@ -67,6 +77,21 @@ class _BaseLiteLLMRouterClient:
|
|
|
67
77
|
)
|
|
68
78
|
raise ProviderClientValidationError(f"{event_info} Original error: {e}")
|
|
69
79
|
|
|
80
|
+
def additional_client_setup(self) -> None:
|
|
81
|
+
"""Additional setup for the LiteLLM Router client."""
|
|
82
|
+
# If the model configuration is self-hosted VLLM, set a dummy API key if not
|
|
83
|
+
# provided. A bug in the LiteLLM library requires an API key to be set even if
|
|
84
|
+
# it is not required.
|
|
85
|
+
for model_configuration in self.model_configurations:
|
|
86
|
+
if (
|
|
87
|
+
f"{SELF_HOSTED_VLLM_PREFIX}/"
|
|
88
|
+
in model_configuration[LITELLM_PARAMS_KEY][MODEL_CONFIG_KEY]
|
|
89
|
+
and API_KEY not in model_configuration[LITELLM_PARAMS_KEY]
|
|
90
|
+
and not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR)
|
|
91
|
+
):
|
|
92
|
+
os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = "dummy api key"
|
|
93
|
+
return
|
|
94
|
+
|
|
70
95
|
@classmethod
|
|
71
96
|
def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMRouterClient":
|
|
72
97
|
"""Instantiates a LiteLLM Router Embedding client from a configuration dict.
|
|
@@ -95,7 +120,8 @@ class _BaseLiteLLMRouterClient:
|
|
|
95
120
|
return cls(
|
|
96
121
|
model_group_id=client_config.model_group_id,
|
|
97
122
|
model_configurations=client_config.litellm_model_list,
|
|
98
|
-
router_settings=client_config.
|
|
123
|
+
router_settings=client_config.litellm_router_settings,
|
|
124
|
+
use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
|
|
99
125
|
**client_config.extra_parameters,
|
|
100
126
|
)
|
|
101
127
|
|
|
@@ -119,6 +145,11 @@ class _BaseLiteLLMRouterClient:
|
|
|
119
145
|
"""Returns the instantiated LiteLLM Router client."""
|
|
120
146
|
return self._router_client
|
|
121
147
|
|
|
148
|
+
@property
|
|
149
|
+
def use_chat_completions_endpoint(self) -> bool:
|
|
150
|
+
"""Returns whether to use the chat completions endpoint."""
|
|
151
|
+
return self._use_chat_completions_endpoint
|
|
152
|
+
|
|
122
153
|
@property
|
|
123
154
|
def _litellm_extra_parameters(self) -> Dict[str, Any]:
|
|
124
155
|
"""
|
|
@@ -136,6 +167,9 @@ class _BaseLiteLLMRouterClient:
|
|
|
136
167
|
MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
|
|
137
168
|
MODEL_LIST_KEY: self.model_configurations,
|
|
138
169
|
ROUTER_CONFIG_KEY: self.router_settings,
|
|
170
|
+
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY: (
|
|
171
|
+
self.use_chat_completions_endpoint
|
|
172
|
+
),
|
|
139
173
|
**self._litellm_extra_parameters,
|
|
140
174
|
}
|
|
141
175
|
|
rasa/shared/utils/common.py
CHANGED
|
@@ -193,7 +193,7 @@ def mark_as_experimental_feature(feature_name: Text) -> None:
|
|
|
193
193
|
def mark_as_beta_feature(feature_name: Text) -> None:
|
|
194
194
|
"""Warns users that they are using a beta feature."""
|
|
195
195
|
logger.warning(
|
|
196
|
-
f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected"
|
|
196
|
+
f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected "
|
|
197
197
|
"behaviour and might be changed in the future."
|
|
198
198
|
)
|
|
199
199
|
|
|
File without changes
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EmbeddingsHealthCheckMixin:
|
|
5
|
+
"""Mixin class that provides methods for performing embeddings health checks during
|
|
6
|
+
training and inference within components.
|
|
7
|
+
|
|
8
|
+
This mixin offers static methods that wrap the following health check functions:
|
|
9
|
+
- `perform_embeddings_health_check`
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
@staticmethod
|
|
13
|
+
def perform_embeddings_health_check(
|
|
14
|
+
custom_embeddings_config: Optional[Dict[str, Any]],
|
|
15
|
+
default_embeddings_config: Dict[str, Any],
|
|
16
|
+
log_source_method: str,
|
|
17
|
+
log_source_component: str,
|
|
18
|
+
) -> None:
|
|
19
|
+
"""Wraps the `perform_embeddings_health_check` function to enable
|
|
20
|
+
tracing and instrumentation.
|
|
21
|
+
"""
|
|
22
|
+
from rasa.shared.utils.health_check.health_check import (
|
|
23
|
+
perform_embeddings_health_check,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
perform_embeddings_health_check(
|
|
27
|
+
custom_embeddings_config,
|
|
28
|
+
default_embeddings_config,
|
|
29
|
+
log_source_method,
|
|
30
|
+
log_source_component,
|
|
31
|
+
)
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional, Dict, Any
|
|
3
|
+
|
|
4
|
+
from rasa.shared.constants import (
|
|
5
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
6
|
+
MODELS_CONFIG_KEY,
|
|
7
|
+
LLM_API_HEALTH_CHECK_DEFAULT_VALUE,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.exceptions import ProviderClientValidationError
|
|
10
|
+
from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
|
|
11
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
12
|
+
from rasa.shared.utils.cli import print_error_and_exit
|
|
13
|
+
from rasa.shared.utils.llm import llm_factory, structlogger, embedder_factory
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def try_instantiate_llm_client(
|
|
17
|
+
custom_llm_config: Optional[Dict],
|
|
18
|
+
default_llm_config: Optional[Dict],
|
|
19
|
+
log_source_function: str,
|
|
20
|
+
log_source_component: str,
|
|
21
|
+
) -> LLMClient:
|
|
22
|
+
"""Validate llm configuration."""
|
|
23
|
+
try:
|
|
24
|
+
return llm_factory(custom_llm_config, default_llm_config)
|
|
25
|
+
except (ProviderClientValidationError, ValueError) as e:
|
|
26
|
+
structlogger.error(
|
|
27
|
+
f"{log_source_function}.llm_instantiation_failed",
|
|
28
|
+
message="Unable to instantiate LLM client.",
|
|
29
|
+
error=e,
|
|
30
|
+
)
|
|
31
|
+
print_error_and_exit(
|
|
32
|
+
f"Unable to create the LLM client for component - {log_source_component}. "
|
|
33
|
+
f"Please make sure you specified the required environment variables "
|
|
34
|
+
f"and configuration keys. "
|
|
35
|
+
f"Error: {e}"
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def try_instantiate_embedder(
|
|
40
|
+
custom_embeddings_config: Optional[Dict],
|
|
41
|
+
default_embeddings_config: Optional[Dict],
|
|
42
|
+
log_source_function: str,
|
|
43
|
+
log_source_component: str,
|
|
44
|
+
) -> EmbeddingClient:
|
|
45
|
+
"""Validate embeddings configuration."""
|
|
46
|
+
try:
|
|
47
|
+
return embedder_factory(custom_embeddings_config, default_embeddings_config)
|
|
48
|
+
except (ProviderClientValidationError, ValueError) as e:
|
|
49
|
+
structlogger.error(
|
|
50
|
+
f"{log_source_function}.embedder_instantiation_failed",
|
|
51
|
+
message="Unable to instantiate Embedding client.",
|
|
52
|
+
error=e,
|
|
53
|
+
)
|
|
54
|
+
print_error_and_exit(
|
|
55
|
+
f"Unable to create the Embedding client for component - "
|
|
56
|
+
f"{log_source_component}. Please make sure you specified the required "
|
|
57
|
+
f"environment variables and configuration keys. Error: {e}"
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def perform_llm_health_check(
|
|
62
|
+
custom_config: Optional[Dict[str, Any]],
|
|
63
|
+
default_config: Dict[str, Any],
|
|
64
|
+
log_source_function: str,
|
|
65
|
+
log_source_component: str,
|
|
66
|
+
) -> None:
|
|
67
|
+
"""Try to instantiate the LLM Client to validate the provided config.
|
|
68
|
+
If the LLM_API_HEALTH_CHECK environment variable is true, perform a test call
|
|
69
|
+
to the LLM API. If config contains multiple models, perform a test call for each
|
|
70
|
+
model in the model group.
|
|
71
|
+
|
|
72
|
+
This method supports both single model configurations and model group configurations
|
|
73
|
+
(configs that have the `models` key).
|
|
74
|
+
"""
|
|
75
|
+
# Instantiate the LLM client or Router LLM client to validate the provided config.
|
|
76
|
+
llm_client = try_instantiate_llm_client(
|
|
77
|
+
custom_config, default_config, log_source_function, log_source_component
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
if is_api_health_check_enabled():
|
|
81
|
+
if (
|
|
82
|
+
custom_config
|
|
83
|
+
and MODELS_CONFIG_KEY in custom_config
|
|
84
|
+
and len(custom_config[MODELS_CONFIG_KEY]) > 1
|
|
85
|
+
):
|
|
86
|
+
# If the config uses a router, instantiate the LLM client for each model
|
|
87
|
+
# in the model group. This is required to perform a test api call for each
|
|
88
|
+
# model in the group.
|
|
89
|
+
# Note: The Router LLM client is not used here as we need to perform a test
|
|
90
|
+
# api call and not load balance the requests.
|
|
91
|
+
for model_config in custom_config[MODELS_CONFIG_KEY]:
|
|
92
|
+
llm_client = try_instantiate_llm_client(
|
|
93
|
+
model_config,
|
|
94
|
+
default_config,
|
|
95
|
+
log_source_function,
|
|
96
|
+
log_source_component,
|
|
97
|
+
)
|
|
98
|
+
send_test_llm_api_request(
|
|
99
|
+
llm_client, log_source_function, log_source_component
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
# Make a test api call to perform a health check for the LLM client.
|
|
103
|
+
# LLM config from config file and model group config from endpoint config
|
|
104
|
+
# without router are handled here.
|
|
105
|
+
send_test_llm_api_request(
|
|
106
|
+
llm_client,
|
|
107
|
+
log_source_function,
|
|
108
|
+
log_source_component,
|
|
109
|
+
)
|
|
110
|
+
else:
|
|
111
|
+
structlogger.warning(
|
|
112
|
+
f"{log_source_function}.perform_llm_health_check.disabled",
|
|
113
|
+
event_info=(
|
|
114
|
+
f"The {LLM_API_HEALTH_CHECK_ENV_VAR} environment variable is set "
|
|
115
|
+
f"to false, which will disable LLM health check. "
|
|
116
|
+
f"It is recommended to set this variable to true in production "
|
|
117
|
+
f"environments."
|
|
118
|
+
),
|
|
119
|
+
)
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def perform_embeddings_health_check(
|
|
124
|
+
custom_config: Optional[Dict[str, Any]],
|
|
125
|
+
default_config: Dict[str, Any],
|
|
126
|
+
log_source_function: str,
|
|
127
|
+
log_source_component: str,
|
|
128
|
+
) -> None:
|
|
129
|
+
"""Try to instantiate the Embedder to validate the provided config.
|
|
130
|
+
If the LLM_API_HEALTH_CHECK environment variable is true, perform a test call
|
|
131
|
+
to the Embeddings API. If config contains multiple models, perform a test call for
|
|
132
|
+
each model in the model group.
|
|
133
|
+
|
|
134
|
+
This method supports both single model configurations and model group configurations
|
|
135
|
+
(configs that have the `models` key).
|
|
136
|
+
"""
|
|
137
|
+
# Instantiate the Embedder client or the Embedder Router client to validate the
|
|
138
|
+
# provided config. Deprecation warnings and errors are logged here.
|
|
139
|
+
embedder = try_instantiate_embedder(
|
|
140
|
+
custom_config, default_config, log_source_function, log_source_component
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
if is_api_health_check_enabled():
|
|
144
|
+
if (
|
|
145
|
+
custom_config
|
|
146
|
+
and MODELS_CONFIG_KEY in custom_config
|
|
147
|
+
and len(custom_config[MODELS_CONFIG_KEY]) > 1
|
|
148
|
+
):
|
|
149
|
+
# If the config uses a router, instantiate the Embedder client for each
|
|
150
|
+
# model in the model group. This is required to perform a test api call
|
|
151
|
+
# for every model in the group.
|
|
152
|
+
# Note: The Router Embedding client is not used here as we need to perform
|
|
153
|
+
# a test API call and not load balance the requests.
|
|
154
|
+
for model_config in custom_config[MODELS_CONFIG_KEY]:
|
|
155
|
+
embedder = try_instantiate_embedder(
|
|
156
|
+
model_config,
|
|
157
|
+
default_config,
|
|
158
|
+
log_source_function,
|
|
159
|
+
log_source_component,
|
|
160
|
+
)
|
|
161
|
+
send_test_embeddings_api_request(
|
|
162
|
+
embedder, log_source_function, log_source_component
|
|
163
|
+
)
|
|
164
|
+
else:
|
|
165
|
+
# Make a test api call to perform a health check for the Embedding client.
|
|
166
|
+
# Embeddings config from config file and model group config from endpoint
|
|
167
|
+
# config without router are handled here.
|
|
168
|
+
send_test_embeddings_api_request(
|
|
169
|
+
embedder, log_source_function, log_source_component
|
|
170
|
+
)
|
|
171
|
+
else:
|
|
172
|
+
structlogger.warning(
|
|
173
|
+
f"{log_source_function}" f".perform_embeddings_health_check.disabled",
|
|
174
|
+
event_info=(
|
|
175
|
+
f"The {LLM_API_HEALTH_CHECK_ENV_VAR} environment variable is set "
|
|
176
|
+
f"to false, which will disable embeddings API health check. "
|
|
177
|
+
f"It is recommended to set this variable to true in production "
|
|
178
|
+
f"environments."
|
|
179
|
+
),
|
|
180
|
+
)
|
|
181
|
+
return None
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def send_test_llm_api_request(
|
|
185
|
+
llm_client: LLMClient, log_source_function: str, log_source_component: str
|
|
186
|
+
) -> None:
|
|
187
|
+
"""Sends a test request to the LLM API to perform a health check.
|
|
188
|
+
|
|
189
|
+
Raises:
|
|
190
|
+
Exception: If the API call fails.
|
|
191
|
+
"""
|
|
192
|
+
structlogger.info(
|
|
193
|
+
f"{log_source_function}.send_test_llm_api_request",
|
|
194
|
+
event_info=(
|
|
195
|
+
f"Sending a test LLM API request for the component - "
|
|
196
|
+
f"{log_source_component}."
|
|
197
|
+
),
|
|
198
|
+
config=llm_client.config,
|
|
199
|
+
)
|
|
200
|
+
try:
|
|
201
|
+
llm_client.completion("hello")
|
|
202
|
+
except Exception as e:
|
|
203
|
+
structlogger.error(
|
|
204
|
+
f"{log_source_function}.send_test_llm_api_request_failed",
|
|
205
|
+
event_info="Test call to the LLM API failed.",
|
|
206
|
+
error=e,
|
|
207
|
+
)
|
|
208
|
+
print_error_and_exit(
|
|
209
|
+
f"Test call to the LLM API failed for component - {log_source_component}. "
|
|
210
|
+
f"Error: {e}"
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def send_test_embeddings_api_request(
|
|
215
|
+
embedder: EmbeddingClient, log_source_function: str, log_source_component: str
|
|
216
|
+
) -> None:
|
|
217
|
+
"""Sends a test request to the Embeddings API to perform a health check.
|
|
218
|
+
|
|
219
|
+
Raises:
|
|
220
|
+
Exception: If the API call fails.
|
|
221
|
+
"""
|
|
222
|
+
structlogger.info(
|
|
223
|
+
f"{log_source_function}.send_test_embeddings_api_request",
|
|
224
|
+
event_info=(
|
|
225
|
+
f"Sending a test Embeddings API request for the component - "
|
|
226
|
+
f"{log_source_component}."
|
|
227
|
+
),
|
|
228
|
+
config=embedder.config,
|
|
229
|
+
)
|
|
230
|
+
try:
|
|
231
|
+
embedder.embed(["hello"])
|
|
232
|
+
except Exception as e:
|
|
233
|
+
structlogger.error(
|
|
234
|
+
f"{log_source_function}.send_test_llm_api_request_failed",
|
|
235
|
+
event_info="Test call to the Embeddings API failed.",
|
|
236
|
+
error=e,
|
|
237
|
+
)
|
|
238
|
+
print_error_and_exit(
|
|
239
|
+
f"Test call to the Embeddings API failed for component - "
|
|
240
|
+
f"{log_source_component}. Error: {e}"
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def is_api_health_check_enabled() -> bool:
|
|
245
|
+
"""Determines whether the API health check is enabled based on an environment
|
|
246
|
+
variable.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
bool: True if the API health check is enabled, False otherwise.
|
|
250
|
+
"""
|
|
251
|
+
return (
|
|
252
|
+
os.getenv(
|
|
253
|
+
LLM_API_HEALTH_CHECK_ENV_VAR, LLM_API_HEALTH_CHECK_DEFAULT_VALUE
|
|
254
|
+
).lower()
|
|
255
|
+
== "true"
|
|
256
|
+
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LLMHealthCheckMixin:
|
|
5
|
+
"""Mixin class that provides methods for performing llm health checks during
|
|
6
|
+
training and inference within components.
|
|
7
|
+
|
|
8
|
+
This mixin offers static methods that wrap the following health check functions:
|
|
9
|
+
- `perform_llm_health_check`
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
@staticmethod
|
|
13
|
+
def perform_llm_health_check(
|
|
14
|
+
custom_llm_config: Optional[Dict[str, Any]],
|
|
15
|
+
default_llm_config: Dict[str, Any],
|
|
16
|
+
log_source_method: str,
|
|
17
|
+
log_source_component: str,
|
|
18
|
+
) -> None:
|
|
19
|
+
"""Wraps the `perform_llm_health_check` function to enable
|
|
20
|
+
tracing and instrumentation.
|
|
21
|
+
"""
|
|
22
|
+
from rasa.shared.utils.health_check.health_check import (
|
|
23
|
+
perform_llm_health_check,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
perform_llm_health_check(
|
|
27
|
+
custom_llm_config,
|
|
28
|
+
default_llm_config,
|
|
29
|
+
log_source_method,
|
|
30
|
+
log_source_component,
|
|
31
|
+
)
|