rasa-pro 3.15.0a1__py3-none-any.whl → 3.15.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rasa-pro might be problematic.
- rasa/builder/constants.py +5 -0
- rasa/builder/copilot/models.py +80 -28
- rasa/builder/download.py +110 -0
- rasa/builder/evaluator/__init__.py +0 -0
- rasa/builder/evaluator/constants.py +15 -0
- rasa/builder/evaluator/copilot_executor.py +89 -0
- rasa/builder/evaluator/dataset/models.py +173 -0
- rasa/builder/evaluator/exceptions.py +4 -0
- rasa/builder/evaluator/response_classification/__init__.py +0 -0
- rasa/builder/evaluator/response_classification/constants.py +66 -0
- rasa/builder/evaluator/response_classification/evaluator.py +346 -0
- rasa/builder/evaluator/response_classification/langfuse_runner.py +463 -0
- rasa/builder/evaluator/response_classification/models.py +61 -0
- rasa/builder/evaluator/scripts/__init__.py +0 -0
- rasa/builder/evaluator/scripts/run_response_classification_evaluator.py +152 -0
- rasa/builder/jobs.py +208 -1
- rasa/builder/logging_utils.py +25 -24
- rasa/builder/main.py +6 -1
- rasa/builder/models.py +23 -0
- rasa/builder/project_generator.py +29 -10
- rasa/builder/service.py +104 -22
- rasa/builder/training_service.py +13 -1
- rasa/builder/validation_service.py +2 -1
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/channels/voice_stream/voice_channel.py +14 -2
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/processor/command_processor.py +49 -7
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -14
- rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/configs.py +5 -8
- rasa/utils/endpoints.py +6 -0
- rasa/version.py +1 -1
- {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/METADATA +12 -12
- {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/RECORD +50 -37
- {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/entry_points.txt +0 -0
rasa/shared/providers/_configs/azure_openai_client_config.py CHANGED

```diff
@@ -167,8 +167,9 @@ class OAuthConfigWrapper(OAuth, BaseModel):
 
 @dataclass
 class AzureOpenAIClientConfig:
-    """Parses configuration for Azure OpenAI client
-
+    """Parses configuration for Azure OpenAI client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -301,9 +302,7 @@ class AzureOpenAIClientConfig:
 
 
 def is_azure_openai_config(config: dict) -> bool:
-    """Check whether the configuration is meant to configure
-    an Azure OpenAI client.
-    """
+    """Check whether the configuration is meant to configure an Azure OpenAI client."""
     # Resolve any aliases that are specific to Azure OpenAI configuration
     config = AzureOpenAIClientConfig.resolve_config_aliases(config)
 
```
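Most of the hunks in this and the following `_configs` modules apply the same mechanical docstring cleanup: the summary becomes a single line ending in a period, separated from any detail by a blank line. This matches common pydocstyle conventions (the D205/D400 family), though the diff itself does not name the linter that drove the change. A minimal before/after sketch:

```python
# Before: the summary spills across two lines and does not end in a period.
def is_azure_openai_config_before(config: dict) -> bool:
    """Check whether the configuration is meant to configure
    an Azure OpenAI client.
    """
    return "deployment" in config  # placeholder body for illustration


# After: a one-line summary ending in a period.
def is_azure_openai_config_after(config: dict) -> bool:
    """Check whether the configuration is meant to configure an Azure OpenAI client."""
    return "deployment" in config  # placeholder body for illustration
```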
rasa/shared/providers/_configs/default_litellm_client_config.py CHANGED

```diff
@@ -40,8 +40,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class DefaultLiteLLMClientConfig:
-    """Parses configuration for default LiteLLM client
-
+    """Parses configuration for default LiteLLM client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -72,8 +73,7 @@ class DefaultLiteLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
```
rasa/shared/providers/_configs/litellm_router_client_config.py CHANGED

```diff
@@ -38,8 +38,9 @@ _LITELLM_UNSUPPORTED_KEYS = [
 
 @dataclass
 class LiteLLMRouterClientConfig:
-    """Parses configuration for a LiteLLM Router client.
-
+    """Parses configuration for a LiteLLM Router client.
+
+    The configuration is expected to be in the following format:
 
     {
         "id": "model_group_id",
```
rasa/shared/providers/_configs/openai_client_config.py CHANGED

```diff
@@ -64,8 +64,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class OpenAIClientConfig:
-    """Parses configuration for
-
+    """Parses configuration for OpenAI client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -118,8 +119,7 @@ class OpenAIClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> OpenAIClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
@@ -168,9 +168,7 @@ class OpenAIClientConfig:
 
 
 def is_openai_config(config: dict) -> bool:
-    """Check whether the configuration is meant to configure
-    an OpenAI client.
-    """
+    """Check whether the configuration is meant to configure an OpenAI client."""
     # Process the config to handle all the aliases
     config = OpenAIClientConfig.resolve_config_aliases(config)
 
```
rasa/shared/providers/_configs/rasa_llm_client_config.py CHANGED

```diff
@@ -22,8 +22,9 @@ structlogger = structlog.get_logger()
 
 @dataclass
 class RasaLLMClientConfig:
-    """Parses configuration for a Rasa Hosted LiteLLM client
-
+    """Parses configuration for a Rasa Hosted LiteLLM client.
+
+    Checks required keys present.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -40,8 +41,7 @@ class RasaLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> RasaLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
```
rasa/shared/providers/_configs/self_hosted_llm_client_config.py CHANGED

```diff
@@ -61,8 +61,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class SelfHostedLLMClientConfig:
-    """Parses configuration for Self Hosted LiteLLM client
-
+    """Parses configuration for Self Hosted LiteLLM client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -116,8 +117,7 @@ class SelfHostedLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
```
rasa/shared/providers/llm/_base_litellm_client.py CHANGED

```diff
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, List, NoReturn, Union, cast
 
 import structlog
 from litellm import acompletion, completion, validate_environment
 
+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.constants import (
     _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
     API_BASE_CONFIG_KEY,
@@ -57,26 +59,24 @@ class _BaseLiteLLMClient:
     @property
     @abstractmethod
     def config(self) -> dict:
-        """Returns the configuration for that the llm client
-        in dictionary form.
-        """
+        """Returns the configuration for that the llm client in dictionary form."""
        pass
 
     @property
     @abstractmethod
     def _litellm_model_name(self) -> str:
-        """Returns the value of LiteLLM's model parameter
-        completion/acompletion in LiteLLM format:
+        """Returns the value of LiteLLM's model parameter.
 
+        To be used in completion/acompletion in LiteLLM format:
         <provider>/<model or deployment name>
         """
         pass
 
     @property
     def _litellm_extra_parameters(self) -> Dict[str, Any]:
-        """Returns a dictionary of extra parameters
-        parameters as well as LiteLLM specific input parameters.
+        """Returns a dictionary of extra parameters.
 
+        Includes model parameters as well as LiteLLM specific input parameters.
         By default, this returns an empty dictionary (no extra parameters).
         """
         return {}
@@ -96,8 +96,9 @@ class _BaseLiteLLMClient:
         }
 
     def validate_client_setup(self) -> None:
-        """Perform client validation.
-
+        """Perform client validation.
+
+        By default only environment variables are validated.
 
         Raises:
             ProviderClientValidationError if validation fails.
@@ -188,10 +189,17 @@ class _BaseLiteLLMClient:
             arguments = cast(
                 Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
             )
-            response = await acompletion(
-                messages=formatted_messages, **{**arguments, **kwargs}
+
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                acompletion(messages=formatted_messages, **{**arguments, **kwargs}),
+                timeout=timeout,
             )
             return self._format_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             message = ""
             from rasa.shared.providers.llm.self_hosted_llm_client import (
@@ -211,6 +219,25 @@ class _BaseLiteLLMClient:
             )
             raise ProviderClientAPIException(e, message) from e
 
+    def _handle_timeout_error(self) -> NoReturn:
+        """Handle asyncio.TimeoutError and raise ProviderClientAPIException.
+
+        Raises:
+            ProviderClientAPIException: Always raised with formatted timeout error.
+        """
+        timeout = self._litellm_extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
+        error_message = (
+            f"APITimeoutError - Request timed out. Error_str: "
+            f"Request timed out. - timeout value={timeout:.6f}, "
+            f"time taken={timeout:.6f} seconds"
+        )
+        # nosemgrep: semgrep.rules.pii-positional-arguments-in-logging
+        # Error message contains only numeric timeout values, not PII
+        structlogger.error(
+            f"{self.__class__.__name__.lower()}.llm.timeout", error=error_message
+        )
+        raise ProviderClientAPIException(asyncio.TimeoutError(error_message)) from None
+
     def _get_formatted_messages(
         self, messages: Union[List[dict], List[str], str]
     ) -> List[Dict[str, str]]:
@@ -312,8 +339,9 @@ class _BaseLiteLLMClient:
 
     @staticmethod
     def _ensure_certificates() -> None:
-        """Configures SSL certificates for LiteLLM.
-
+        """Configures SSL certificates for LiteLLM.
+
+        This method is invoked during client initialization.
 
         LiteLLM may utilize `openai` clients or other providers that require
         SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
```
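Beyond the docstring cleanup, the substantive change in this file is the new request timeout: the awaited `acompletion` call is wrapped in `asyncio.wait_for`, with the limit read from the client's extra parameters and falling back to `DEFAULT_REQUEST_TIMEOUT`, and a timeout is logged and re-raised as `ProviderClientAPIException` via the new `_handle_timeout_error`. A self-contained sketch of the pattern (the constant's value and `slow_completion` are stand-ins, not taken from the package):

```python
import asyncio

DEFAULT_REQUEST_TIMEOUT = 300  # stand-in for rasa.core.constants, in seconds

async def slow_completion(**kwargs: object) -> str:
    await asyncio.sleep(10)  # simulate a slow LLM call
    return "response"

async def completion_with_timeout(extra_parameters: dict) -> str:
    # Same lookup the diff introduces: a per-client "timeout" key, else the default.
    timeout = extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
    try:
        return await asyncio.wait_for(slow_completion(), timeout=timeout)
    except asyncio.TimeoutError:
        # The real client logs a structured event and raises
        # ProviderClientAPIException from _handle_timeout_error().
        raise

# asyncio.run(completion_with_timeout({"timeout": 0.5}))  # raises TimeoutError
```

One consequence of this design: `asyncio.wait_for` cancels the wrapped call on timeout, so the underlying HTTP request is abandoned rather than left running in the background.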
rasa/shared/providers/llm/litellm_router_llm_client.py CHANGED

```diff
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from typing import Any, Dict, List, Union
 
 import structlog
 
+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.exceptions import ProviderClientAPIException
 from rasa.shared.providers._configs.litellm_router_client_config import (
     LiteLLMRouterClientConfig,
@@ -79,13 +81,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
 
     @suppress_logs(log_level=logging.WARNING)
     def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Synchronously generate completions for given prompt.
+        """Synchronously generate completions for given prompt.
 
         Args:
             prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -103,21 +106,30 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
 
     @suppress_logs(log_level=logging.WARNING)
     async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Asynchronously generate completions for given prompt.
+        """Asynchronously generate completions for given prompt.
 
         Args:
             prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
         try:
-            response = await self.router_client.atext_completion(
-                prompt=prompt, **self._completion_fn_args
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                self.router_client.atext_completion(
+                    prompt=prompt, **self._completion_fn_args
+                ),
+                timeout=timeout,
             )
             return self._format_text_completion_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             raise ProviderClientAPIException(e)
 
@@ -125,8 +137,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
     def completion(
         self, messages: Union[List[dict], List[str], str], **kwargs: Any
     ) -> LLMResponse:
-        """
-        Synchronously generate completions for given list of messages.
+        """Synchronously generate completions for given list of messages.
 
         Method overrides the base class method to call the appropriate
         completion method based on the configuration. If the chat completions
@@ -143,8 +154,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
                   as a user message.
                 - a single message as a string which will be formatted as user message.
             **kwargs: Additional parameters to pass to the completion call.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -163,8 +176,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
     async def acompletion(
         self, messages: Union[List[dict], List[str], str], **kwargs: Any
     ) -> LLMResponse:
-        """
-        Asynchronously generate completions for given list of messages.
+        """Asynchronously generate completions for given list of messages.
 
         Method overrides the base class method to call the appropriate
         completion method based on the configuration. If the chat completions
@@ -181,8 +193,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
                   as a user message.
                 - a single message as a string which will be formatted as user message.
             **kwargs: Additional parameters to pass to the completion call.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -190,19 +204,28 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
             return await self._atext_completion(messages)
         try:
             formatted_messages = self._get_formatted_messages(messages)
-            response = await self.router_client.acompletion(
-                messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                self.router_client.acompletion(
+                    messages=formatted_messages,
+                    **{**self._completion_fn_args, **kwargs},
+                ),
+                timeout=timeout,
             )
             return self._format_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             raise ProviderClientAPIException(e)
 
     @property
     def _completion_fn_args(self) -> Dict[str, Any]:
-        """Returns the completion arguments
-        LiteLLM's completion functions.
-        """
+        """Returns the completion arguments.
 
+        For invoking a call through LiteLLM's completion functions.
+        """
         return {
             **self._litellm_extra_parameters,
             LITE_LLM_MODEL_FIELD: self.model_group_id,
```
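The router client reads the same `timeout` key through `_litellm_extra_parameters`, so a timeout configured alongside the model group settings should flow into the `asyncio.wait_for` calls above. A hypothetical config dict showing where the key would sit (only the `id` field is confirmed by the docstring in this file; the rest is illustrative):

```python
# Illustrative model group config; the new code resolves the limit with
# extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT).
model_group_config = {
    "id": "model_group_id",
    "timeout": 30,  # seconds; omit to fall back to DEFAULT_REQUEST_TIMEOUT
}
```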
rasa/shared/providers/llm/self_hosted_llm_client.py CHANGED

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 import os
 from typing import Any, Dict, List, Optional, Union
@@ -7,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import structlog
 from litellm import atext_completion, text_completion
 
+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.constants import (
     API_KEY,
     SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
@@ -28,7 +30,7 @@ structlogger = structlog.get_logger()
 
 
 class SelfHostedLLMClient(_BaseLiteLLMClient):
-    """A client for interfacing with Self Hosted LLM endpoints
+    """A client for interfacing with Self Hosted LLM endpoints.
 
     Parameters:
         model (str): The model or deployment name.
@@ -95,8 +97,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def provider(self) -> str:
-        """
-        Returns the provider name for the self hosted llm client.
+        """Returns the provider name for the self hosted llm client.
 
         Returns:
             String representing the provider name.
@@ -105,8 +106,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def model(self) -> str:
-        """
-        Returns the model name for the self hosted llm client.
+        """Returns the model name for the self hosted llm client.
 
         Returns:
             String representing the model name.
@@ -115,8 +115,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def api_base(self) -> str:
-        """
-        Returns the base URL for the API endpoint.
+        """Returns the base URL for the API endpoint.
 
         Returns:
             String representing the base URL.
@@ -125,8 +124,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def api_type(self) -> Optional[str]:
-        """
-        Returns the type of the API endpoint. Currently only OpenAI is supported.
+        """Returns the type of the API endpoint. Currently only OpenAI is supported.
 
         Returns:
             String representing the API type.
@@ -135,8 +133,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def api_version(self) -> Optional[str]:
-        """
-        Returns the version of the API endpoint.
+        """Returns the version of the API endpoint.
 
         Returns:
             String representing the API version.
@@ -145,8 +142,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def config(self) -> Dict:
-        """
-
+        """Returns the configuration for the self hosted llm client.
+
         Returns:
             Dictionary containing the configuration.
         """
@@ -163,9 +160,9 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def _litellm_model_name(self) -> str:
-        """Returns the value of LiteLLM's model parameter
-        completion/acompletion in LiteLLM format:
+        """Returns the value of LiteLLM's model parameter.
 
+        To be used in completion/acompletion in LiteLLM format:
         <hosted_vllm>/<model or deployment name>
         """
         if self.model and f"{SELF_HOSTED_VLLM_PREFIX}/" not in self.model:
@@ -174,15 +171,17 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @property
     def _litellm_extra_parameters(self) -> Dict[str, Any]:
-        """Returns optional configuration parameters
-
+        """Returns optional configuration parameters.
+
+        Specific to the client provider and deployed model.
         """
         return self._extra_parameters
 
     @property
     def _completion_fn_args(self) -> Dict[str, Any]:
-        """Returns the completion arguments
-
+        """Returns the completion arguments.
+
+        For invoking a call through LiteLLM's completion functions.
         """
         fn_args = super()._completion_fn_args
         fn_args.update(
@@ -195,13 +194,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @suppress_logs(log_level=logging.WARNING)
     def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Synchronously generate completions for given prompt.
+        """Synchronously generate completions for given prompt.
 
         Args:
             prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -213,26 +213,28 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
 
     @suppress_logs(log_level=logging.WARNING)
     async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Asynchronously generate completions for given prompt.
+        """Asynchronously generate completions for given prompt.
 
         Args:
-            messages: The messages to generate the completion for. It can be:
-                - a list of messages. Each message is a dictionary
-                  with the following keys:
-                    - content: The message content.
-                    - role: The role of the message (e.g. user or system).
-                - a list of messages. Each message is a string and will be formatted
-                  as a user message.
-                - a single message as a string which will be formatted as user message.
+            prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
         try:
-            response = await atext_completion(prompt=prompt, **self._completion_fn_args)
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                atext_completion(prompt=prompt, **self._completion_fn_args),
+                timeout=timeout,
+            )
             return self._format_text_completion_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             raise ProviderClientAPIException(e)
 
```
rasa/shared/utils/configs.py CHANGED

```diff
@@ -8,8 +8,7 @@ structlogger = structlog.get_logger()
 
 
 def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
-    """
-    Resolve aliases in the configuration to standard keys.
+    """Resolve aliases in the configuration to standard keys.
 
     Args:
         config: Dictionary containing the configuration.
@@ -37,13 +36,13 @@ def raise_deprecation_warnings(
     deprecated_alias_mapping: dict,
     source: Optional[str] = None,
 ) -> None:
-    """
-    Raises warnings for deprecated keys in the configuration.
+    """Raises warnings for deprecated keys in the configuration.
 
     Args:
         config: Dictionary containing the configuration.
         deprecated_alias_mapping: Dictionary mapping deprecated keys to
             their standard keys.
+        source: Optional source context for the deprecation warning.
 
     Raises:
         DeprecationWarning: If any deprecated key is found in the config.
@@ -61,8 +60,7 @@ def raise_deprecation_warnings(
 
 
 def validate_required_keys(config: dict, required_keys: list) -> None:
-    """
-    Validates that the passed config contains all the required keys.
+    """Validates that the passed config contains all the required keys.
 
     Args:
         config: Dictionary containing the configuration.
@@ -84,8 +82,7 @@ def validate_required_keys(config: dict, required_keys: list) -> None:
 
 
 def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
-    """
-    Validates that the passed config doesn't contain any forbidden keys.
+    """Validates that the passed config doesn't contain any forbidden keys.
 
     Args:
         config: Dictionary containing the configuration.
```
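For context, `resolve_aliases` rewrites deprecated configuration keys to their standard names before validation runs. A minimal re-implementation of the documented signature (the alias mapping shown is invented for illustration; the real mappings live in the provider config modules):

```python
def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
    # Copy the config and rewrite each deprecated key to its standard name,
    # leaving an explicitly set standard key untouched.
    resolved = dict(config)
    for old_key, new_key in deprecated_alias_mapping.items():
        if old_key in resolved and new_key not in resolved:
            resolved[new_key] = resolved.pop(old_key)
    return resolved

# Hypothetical alias mapping and usage:
aliases = {"openai_api_key": "api_key"}
print(resolve_aliases({"openai_api_key": "sk-...", "model": "gpt-4"}, aliases))
# {'model': 'gpt-4', 'api_key': 'sk-...'}
```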
rasa/utils/endpoints.py CHANGED

```diff
@@ -10,6 +10,7 @@ import structlog
 from aiohttp.client_exceptions import ContentTypeError
 from sanic.request import Request
 
+from rasa.core.actions.constants import MISSING_DOMAIN_MARKER
 from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.exceptions import FileNotFoundException
 from rasa.shared.utils.yaml import read_config_file
@@ -224,6 +225,11 @@ class EndpointConfig:
                 ssl=sslcontext,
                 **kwargs,
             ) as response:
+                if response.status == 449:
+                    # Return a special marker that HTTPCustomActionExecutor can detect
+                    # This avoids raising an exception for this expected case
+                    return {MISSING_DOMAIN_MARKER: True}
+
                 if response.status >= 400:
                     raise ClientResponseError(
                         response.status,
```
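`EndpointConfig.request` now short-circuits HTTP 449 into a marker dict instead of raising, so the caller can treat the "action server needs the domain" case as ordinary control flow. A sketch of how a caller such as `HTTPCustomActionExecutor` might consume it (the marker value and the retry step are assumptions, not taken from the diff):

```python
from typing import Any, Optional

MISSING_DOMAIN_MARKER = "__missing_domain__"  # assumed value; the real constant
# lives in rasa.core.actions.constants

async def call_custom_action(endpoint: Any, payload: dict) -> Optional[dict]:
    response = await endpoint.request(method="post", json=payload)
    if isinstance(response, dict) and response.get(MISSING_DOMAIN_MARKER):
        # The action server answered 449: it wants the domain included.
        # Hypothetical retry with the domain attached; the real behaviour
        # is implemented in http_custom_action_executor.py.
        response = await endpoint.request(
            method="post", json={**payload, "domain": {}}  # domain dict elided
        )
    return response
```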
rasa/version.py CHANGED