rasa-pro 3.13.12__py3-none-any.whl → 3.13.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/constants.py +1 -0
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/channels/voice_stream/voice_channel.py +26 -16
- rasa/core/policies/flows/flow_executor.py +20 -6
- rasa/core/run.py +0 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/processor/command_processor.py +12 -3
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/model_manager/socket_bridge.py +1 -2
- rasa/shared/core/flows/flow.py +8 -2
- rasa/shared/core/slots.py +55 -24
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -13
- rasa/shared/providers/llm/litellm_router_llm_client.py +39 -17
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/configs.py +5 -8
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +6 -0
- rasa/version.py +1 -1
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.13.14.dist-info}/METADATA +2 -2
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.13.14.dist-info}/RECORD +35 -35
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.13.14.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.13.14.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.13.14.dist-info}/entry_points.txt +0 -0
rasa/shared/core/flows/flow.py
CHANGED
|
@@ -322,9 +322,15 @@ class Flow:
|
|
|
322
322
|
|
|
323
323
|
def get_collect_steps(self) -> List[CollectInformationFlowStep]:
|
|
324
324
|
"""Return all CollectInformationFlowSteps in the flow."""
|
|
325
|
-
collect_steps = []
|
|
325
|
+
collect_steps: List[CollectInformationFlowStep] = []
|
|
326
326
|
for step in self.steps_with_calls_resolved:
|
|
327
|
-
|
|
327
|
+
# Only add collect steps that are not already in the list.
|
|
328
|
+
# This is to avoid returning duplicate collect steps from called flows
|
|
329
|
+
# in case the called flow is called multiple times.
|
|
330
|
+
if (
|
|
331
|
+
isinstance(step, CollectInformationFlowStep)
|
|
332
|
+
and step not in collect_steps
|
|
333
|
+
):
|
|
328
334
|
collect_steps.append(step)
|
|
329
335
|
return collect_steps
|
|
330
336
|
|
rasa/shared/core/slots.py
CHANGED
|
@@ -351,8 +351,8 @@ class FloatSlot(Slot):
|
|
|
351
351
|
mappings: List[Dict[Text, Any]],
|
|
352
352
|
initial_value: Optional[float] = None,
|
|
353
353
|
value_reset_delay: Optional[int] = None,
|
|
354
|
-
max_value: float = 1.0,
|
|
355
|
-
min_value: float = 0.0,
|
|
354
|
+
max_value: Optional[float] = None,
|
|
355
|
+
min_value: Optional[float] = None,
|
|
356
356
|
influence_conversation: bool = True,
|
|
357
357
|
is_builtin: bool = False,
|
|
358
358
|
shared_for_coexistence: bool = False,
|
|
@@ -376,32 +376,24 @@ class FloatSlot(Slot):
|
|
|
376
376
|
filled_by=filled_by,
|
|
377
377
|
validation=validation,
|
|
378
378
|
)
|
|
379
|
+
self.validate_min_max_range(min_value, max_value)
|
|
380
|
+
|
|
379
381
|
self.max_value = max_value
|
|
380
382
|
self.min_value = min_value
|
|
381
383
|
|
|
382
|
-
if min_value >= max_value:
|
|
383
|
-
raise InvalidSlotConfigError(
|
|
384
|
-
"Float slot ('{}') created with an invalid range "
|
|
385
|
-
"using min ({}) and max ({}) values. Make sure "
|
|
386
|
-
"min is smaller than max."
|
|
387
|
-
"".format(self.name, self.min_value, self.max_value)
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
if initial_value is not None and not (min_value <= initial_value <= max_value):
|
|
391
|
-
rasa.shared.utils.io.raise_warning(
|
|
392
|
-
f"Float slot ('{self.name}') created with an initial value "
|
|
393
|
-
f"{self.value}. This value is outside of the configured min "
|
|
394
|
-
f"({self.min_value}) and max ({self.max_value}) values."
|
|
395
|
-
)
|
|
396
|
-
|
|
397
384
|
def _as_feature(self) -> List[float]:
|
|
385
|
+
# set default min and max values used in prior releases
|
|
386
|
+
# to prevent regressions for existing models
|
|
387
|
+
min_value = self.min_value or 0.0
|
|
388
|
+
max_value = self.max_value or 1.0
|
|
389
|
+
|
|
398
390
|
try:
|
|
399
|
-
capped_value = max(self.min_value, min(self.max_value, float(self.value)))
|
|
400
|
-
if abs(self.max_value - self.min_value) > 0:
|
|
401
|
-
covered_range = abs(self.max_value - self.min_value)
|
|
391
|
+
capped_value = max(min_value, min(max_value, float(self.value)))
|
|
392
|
+
if abs(max_value - min_value) > 0:
|
|
393
|
+
covered_range = abs(max_value - min_value)
|
|
402
394
|
else:
|
|
403
395
|
covered_range = 1
|
|
404
|
-
return [1.0, (capped_value - self.min_value) / covered_range]
|
|
396
|
+
return [1.0, (capped_value - min_value) / covered_range]
|
|
405
397
|
except (TypeError, ValueError):
|
|
406
398
|
return [0.0, 0.0]
|
|
407
399
|
|
|
@@ -420,13 +412,52 @@ class FloatSlot(Slot):
|
|
|
420
412
|
return value
|
|
421
413
|
|
|
422
414
|
def is_valid_value(self, value: Any) -> bool:
|
|
423
|
-
"""Checks if the slot
|
|
424
|
-
|
|
425
|
-
|
|
415
|
+
"""Checks if the slot value is valid."""
|
|
416
|
+
if value is None:
|
|
417
|
+
return True
|
|
418
|
+
|
|
419
|
+
if not isinstance(self.coerce_value(value), float):
|
|
420
|
+
return False
|
|
421
|
+
|
|
422
|
+
if (
|
|
423
|
+
self.min_value is not None
|
|
424
|
+
and self.max_value is not None
|
|
425
|
+
and not (self.min_value <= value <= self.max_value)
|
|
426
|
+
):
|
|
427
|
+
return False
|
|
428
|
+
|
|
429
|
+
return True
|
|
426
430
|
|
|
427
431
|
def _feature_dimensionality(self) -> int:
|
|
428
432
|
return len(self.as_feature())
|
|
429
433
|
|
|
434
|
+
def validate_min_max_range(
|
|
435
|
+
self, min_value: Optional[float], max_value: Optional[float]
|
|
436
|
+
) -> None:
|
|
437
|
+
"""Validates the min-max range for the slot.
|
|
438
|
+
|
|
439
|
+
Raises:
|
|
440
|
+
InvalidSlotConfigError, if the min-max range is invalid.
|
|
441
|
+
"""
|
|
442
|
+
if min_value is not None and max_value is not None and min_value >= max_value:
|
|
443
|
+
raise InvalidSlotConfigError(
|
|
444
|
+
f"Float slot ('{self.name}') created with an invalid range "
|
|
445
|
+
f"using min ({min_value}) and max ({max_value}) values. Make sure "
|
|
446
|
+
f"min is smaller than max."
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
if (
|
|
450
|
+
self.initial_value is not None
|
|
451
|
+
and min_value is not None
|
|
452
|
+
and max_value is not None
|
|
453
|
+
and not (min_value <= self.initial_value <= max_value)
|
|
454
|
+
):
|
|
455
|
+
raise InvalidSlotConfigError(
|
|
456
|
+
f"Float slot ('{self.name}') created with an initial value "
|
|
457
|
+
f"{self.initial_value}. This value is outside of the configured min "
|
|
458
|
+
f"({min_value}) and max ({max_value}) values."
|
|
459
|
+
)
|
|
460
|
+
|
|
430
461
|
|
|
431
462
|
class BooleanSlot(Slot):
|
|
432
463
|
"""A slot storing a truth value."""
|
|
rasa/shared/providers/_configs/azure_openai_client_config.py
CHANGED
|
@@ -167,8 +167,9 @@ class OAuthConfigWrapper(OAuth, BaseModel):
|
|
|
167
167
|
|
|
168
168
|
@dataclass
|
|
169
169
|
class AzureOpenAIClientConfig:
|
|
170
|
-
"""Parses configuration for Azure OpenAI client
|
|
171
|
-
|
|
170
|
+
"""Parses configuration for Azure OpenAI client.
|
|
171
|
+
|
|
172
|
+
Resolves aliases and raises deprecation warnings.
|
|
172
173
|
|
|
173
174
|
Raises:
|
|
174
175
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -301,9 +302,7 @@ class AzureOpenAIClientConfig:
|
|
|
301
302
|
|
|
302
303
|
|
|
303
304
|
def is_azure_openai_config(config: dict) -> bool:
|
|
304
|
-
"""Check whether the configuration is meant to configure
|
|
305
|
-
an Azure OpenAI client.
|
|
306
|
-
"""
|
|
305
|
+
"""Check whether the configuration is meant to configure an Azure OpenAI client."""
|
|
307
306
|
# Resolve any aliases that are specific to Azure OpenAI configuration
|
|
308
307
|
config = AzureOpenAIClientConfig.resolve_config_aliases(config)
|
|
309
308
|
|
|
rasa/shared/providers/_configs/default_litellm_client_config.py
CHANGED
|
@@ -40,8 +40,9 @@ FORBIDDEN_KEYS = [
|
|
|
40
40
|
|
|
41
41
|
@dataclass
|
|
42
42
|
class DefaultLiteLLMClientConfig:
|
|
43
|
-
"""Parses configuration for default LiteLLM client
|
|
44
|
-
|
|
43
|
+
"""Parses configuration for default LiteLLM client.
|
|
44
|
+
|
|
45
|
+
Resolves aliases and raises deprecation warnings.
|
|
45
46
|
|
|
46
47
|
Raises:
|
|
47
48
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -72,8 +73,7 @@ class DefaultLiteLLMClientConfig:
|
|
|
72
73
|
|
|
73
74
|
@classmethod
|
|
74
75
|
def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
|
|
75
|
-
"""
|
|
76
|
-
Initializes a dataclass from the passed config.
|
|
76
|
+
"""Initializes a dataclass from the passed config.
|
|
77
77
|
|
|
78
78
|
Args:
|
|
79
79
|
config: (dict) The config from which to initialize.
|
|
rasa/shared/providers/_configs/litellm_router_client_config.py
CHANGED
|
@@ -38,8 +38,9 @@ _LITELLM_UNSUPPORTED_KEYS = [
|
|
|
38
38
|
|
|
39
39
|
@dataclass
|
|
40
40
|
class LiteLLMRouterClientConfig:
|
|
41
|
-
"""Parses configuration for a LiteLLM Router client.
|
|
42
|
-
|
|
41
|
+
"""Parses configuration for a LiteLLM Router client.
|
|
42
|
+
|
|
43
|
+
The configuration is expected to be in the following format:
|
|
43
44
|
|
|
44
45
|
{
|
|
45
46
|
"id": "model_group_id",
|
|
rasa/shared/providers/_configs/openai_client_config.py
CHANGED
|
@@ -64,8 +64,9 @@ FORBIDDEN_KEYS = [
|
|
|
64
64
|
|
|
65
65
|
@dataclass
|
|
66
66
|
class OpenAIClientConfig:
|
|
67
|
-
"""Parses configuration for
|
|
68
|
-
|
|
67
|
+
"""Parses configuration for OpenAI client.
|
|
68
|
+
|
|
69
|
+
Resolves aliases and raises deprecation warnings.
|
|
69
70
|
|
|
70
71
|
Raises:
|
|
71
72
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -118,8 +119,7 @@ class OpenAIClientConfig:
|
|
|
118
119
|
|
|
119
120
|
@classmethod
|
|
120
121
|
def from_dict(cls, config: dict) -> OpenAIClientConfig:
|
|
121
|
-
"""
|
|
122
|
-
Initializes a dataclass from the passed config.
|
|
122
|
+
"""Initializes a dataclass from the passed config.
|
|
123
123
|
|
|
124
124
|
Args:
|
|
125
125
|
config: (dict) The config from which to initialize.
|
|
@@ -168,9 +168,7 @@ class OpenAIClientConfig:
|
|
|
168
168
|
|
|
169
169
|
|
|
170
170
|
def is_openai_config(config: dict) -> bool:
|
|
171
|
-
"""Check whether the configuration is meant to configure
|
|
172
|
-
an OpenAI client.
|
|
173
|
-
"""
|
|
171
|
+
"""Check whether the configuration is meant to configure an OpenAI client."""
|
|
174
172
|
# Process the config to handle all the aliases
|
|
175
173
|
config = OpenAIClientConfig.resolve_config_aliases(config)
|
|
176
174
|
|
|
rasa/shared/providers/_configs/rasa_llm_client_config.py
CHANGED
|
@@ -22,8 +22,9 @@ structlogger = structlog.get_logger()
|
|
|
22
22
|
|
|
23
23
|
@dataclass
|
|
24
24
|
class RasaLLMClientConfig:
|
|
25
|
-
"""Parses configuration for a Rasa Hosted LiteLLM client
|
|
26
|
-
|
|
25
|
+
"""Parses configuration for a Rasa Hosted LiteLLM client.
|
|
26
|
+
|
|
27
|
+
Checks required keys present.
|
|
27
28
|
|
|
28
29
|
Raises:
|
|
29
30
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -40,8 +41,7 @@ class RasaLLMClientConfig:
|
|
|
40
41
|
|
|
41
42
|
@classmethod
|
|
42
43
|
def from_dict(cls, config: dict) -> RasaLLMClientConfig:
|
|
43
|
-
"""
|
|
44
|
-
Initializes a dataclass from the passed config.
|
|
44
|
+
"""Initializes a dataclass from the passed config.
|
|
45
45
|
|
|
46
46
|
Args:
|
|
47
47
|
config: (dict) The config from which to initialize.
|
|
rasa/shared/providers/_configs/self_hosted_llm_client_config.py
CHANGED
|
@@ -61,8 +61,9 @@ FORBIDDEN_KEYS = [
|
|
|
61
61
|
|
|
62
62
|
@dataclass
|
|
63
63
|
class SelfHostedLLMClientConfig:
|
|
64
|
-
"""Parses configuration for Self Hosted LiteLLM client
|
|
65
|
-
|
|
64
|
+
"""Parses configuration for Self Hosted LiteLLM client.
|
|
65
|
+
|
|
66
|
+
Resolves aliases and raises deprecation warnings.
|
|
66
67
|
|
|
67
68
|
Raises:
|
|
68
69
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -116,8 +117,7 @@ class SelfHostedLLMClientConfig:
|
|
|
116
117
|
|
|
117
118
|
@classmethod
|
|
118
119
|
def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
|
|
119
|
-
"""
|
|
120
|
-
Initializes a dataclass from the passed config.
|
|
120
|
+
"""Initializes a dataclass from the passed config.
|
|
121
121
|
|
|
122
122
|
Args:
|
|
123
123
|
config: (dict) The config from which to initialize.
|
|
rasa/shared/providers/llm/_base_litellm_client.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
from abc import abstractmethod
|
|
5
|
-
from typing import Any, Dict, List, Union, cast
|
|
6
|
+
from typing import Any, Dict, List, NoReturn, Union, cast
|
|
6
7
|
|
|
7
8
|
import structlog
|
|
8
9
|
from litellm import acompletion, completion, validate_environment
|
|
9
10
|
|
|
11
|
+
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
10
12
|
from rasa.shared.constants import (
|
|
11
13
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
|
|
12
14
|
API_BASE_CONFIG_KEY,
|
|
@@ -57,26 +59,24 @@ class _BaseLiteLLMClient:
|
|
|
57
59
|
@property
|
|
58
60
|
@abstractmethod
|
|
59
61
|
def config(self) -> dict:
|
|
60
|
-
"""Returns the configuration for that the llm client
|
|
61
|
-
in dictionary form.
|
|
62
|
-
"""
|
|
62
|
+
"""Returns the configuration for that the llm client in dictionary form."""
|
|
63
63
|
pass
|
|
64
64
|
|
|
65
65
|
@property
|
|
66
66
|
@abstractmethod
|
|
67
67
|
def _litellm_model_name(self) -> str:
|
|
68
|
-
"""Returns the value of LiteLLM's model parameter
|
|
69
|
-
completion/acompletion in LiteLLM format:
|
|
68
|
+
"""Returns the value of LiteLLM's model parameter.
|
|
70
69
|
|
|
70
|
+
To be used in completion/acompletion in LiteLLM format:
|
|
71
71
|
<provider>/<model or deployment name>
|
|
72
72
|
"""
|
|
73
73
|
pass
|
|
74
74
|
|
|
75
75
|
@property
|
|
76
76
|
def _litellm_extra_parameters(self) -> Dict[str, Any]:
|
|
77
|
-
"""Returns a dictionary of extra parameters
|
|
78
|
-
parameters as well as LiteLLM specific input parameters.
|
|
77
|
+
"""Returns a dictionary of extra parameters.
|
|
79
78
|
|
|
79
|
+
Includes model parameters as well as LiteLLM specific input parameters.
|
|
80
80
|
By default, this returns an empty dictionary (no extra parameters).
|
|
81
81
|
"""
|
|
82
82
|
return {}
|
|
@@ -96,8 +96,9 @@ class _BaseLiteLLMClient:
|
|
|
96
96
|
}
|
|
97
97
|
|
|
98
98
|
def validate_client_setup(self) -> None:
|
|
99
|
-
"""Perform client validation.
|
|
100
|
-
|
|
99
|
+
"""Perform client validation.
|
|
100
|
+
|
|
101
|
+
By default only environment variables are validated.
|
|
101
102
|
|
|
102
103
|
Raises:
|
|
103
104
|
ProviderClientValidationError if validation fails.
|
|
@@ -178,8 +179,16 @@ class _BaseLiteLLMClient:
|
|
|
178
179
|
try:
|
|
179
180
|
formatted_messages = self._get_formatted_messages(messages)
|
|
180
181
|
arguments = resolve_environment_variables(self._completion_fn_args)
|
|
181
|
-
|
|
182
|
+
timeout = self._litellm_extra_parameters.get(
|
|
183
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
184
|
+
)
|
|
185
|
+
response = await asyncio.wait_for(
|
|
186
|
+
acompletion(messages=formatted_messages, **arguments),
|
|
187
|
+
timeout=timeout,
|
|
188
|
+
)
|
|
182
189
|
return self._format_response(response)
|
|
190
|
+
except asyncio.TimeoutError:
|
|
191
|
+
self._handle_timeout_error()
|
|
183
192
|
except Exception as e:
|
|
184
193
|
message = ""
|
|
185
194
|
from rasa.shared.providers.llm.self_hosted_llm_client import (
|
|
@@ -199,6 +208,25 @@ class _BaseLiteLLMClient:
|
|
|
199
208
|
)
|
|
200
209
|
raise ProviderClientAPIException(e, message)
|
|
201
210
|
|
|
211
|
+
def _handle_timeout_error(self) -> NoReturn:
|
|
212
|
+
"""Handle asyncio.TimeoutError and raise ProviderClientAPIException.
|
|
213
|
+
|
|
214
|
+
Raises:
|
|
215
|
+
ProviderClientAPIException: Always raised with formatted timeout error.
|
|
216
|
+
"""
|
|
217
|
+
timeout = self._litellm_extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
|
|
218
|
+
error_message = (
|
|
219
|
+
f"APITimeoutError - Request timed out. Error_str: "
|
|
220
|
+
f"Request timed out. - timeout value={timeout:.6f}, "
|
|
221
|
+
f"time taken={timeout:.6f} seconds"
|
|
222
|
+
)
|
|
223
|
+
# nosemgrep: semgrep.rules.pii-positional-arguments-in-logging
|
|
224
|
+
# Error message contains only numeric timeout values, not PII
|
|
225
|
+
structlogger.error(
|
|
226
|
+
f"{self.__class__.__name__.lower()}.llm.timeout", error=error_message
|
|
227
|
+
)
|
|
228
|
+
raise ProviderClientAPIException(asyncio.TimeoutError(error_message)) from None
|
|
229
|
+
|
|
202
230
|
def _get_formatted_messages(
|
|
203
231
|
self, messages: Union[List[dict], List[str], str]
|
|
204
232
|
) -> List[Dict[str, str]]:
|
|
@@ -280,8 +308,9 @@ class _BaseLiteLLMClient:
|
|
|
280
308
|
|
|
281
309
|
@staticmethod
|
|
282
310
|
def _ensure_certificates() -> None:
|
|
283
|
-
"""Configures SSL certificates for LiteLLM.
|
|
284
|
-
|
|
311
|
+
"""Configures SSL certificates for LiteLLM.
|
|
312
|
+
|
|
313
|
+
This method is invoked during client initialization.
|
|
285
314
|
|
|
286
315
|
LiteLLM may utilize `openai` clients or other providers that require
|
|
287
316
|
SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
|
|
rasa/shared/providers/llm/litellm_router_llm_client.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
from typing import Any, Dict, List, Union
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
7
8
|
|
|
9
|
+
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
8
10
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
9
11
|
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
10
12
|
LiteLLMRouterClientConfig,
|
|
@@ -79,13 +81,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
79
81
|
|
|
80
82
|
@suppress_logs(log_level=logging.WARNING)
|
|
81
83
|
def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
82
|
-
"""
|
|
83
|
-
Synchronously generate completions for given prompt.
|
|
84
|
+
"""Synchronously generate completions for given prompt.
|
|
84
85
|
|
|
85
86
|
Args:
|
|
86
87
|
prompt: Prompt to generate the completion for.
|
|
88
|
+
|
|
87
89
|
Returns:
|
|
88
90
|
List of message completions.
|
|
91
|
+
|
|
89
92
|
Raises:
|
|
90
93
|
ProviderClientAPIException: If the API request fails.
|
|
91
94
|
"""
|
|
@@ -103,28 +106,36 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
103
106
|
|
|
104
107
|
@suppress_logs(log_level=logging.WARNING)
|
|
105
108
|
async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
106
|
-
"""
|
|
107
|
-
Asynchronously generate completions for given prompt.
|
|
109
|
+
"""Asynchronously generate completions for given prompt.
|
|
108
110
|
|
|
109
111
|
Args:
|
|
110
112
|
prompt: Prompt to generate the completion for.
|
|
113
|
+
|
|
111
114
|
Returns:
|
|
112
115
|
List of message completions.
|
|
116
|
+
|
|
113
117
|
Raises:
|
|
114
118
|
ProviderClientAPIException: If the API request fails.
|
|
115
119
|
"""
|
|
116
120
|
try:
|
|
117
|
-
|
|
118
|
-
|
|
121
|
+
timeout = self._litellm_extra_parameters.get(
|
|
122
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
123
|
+
)
|
|
124
|
+
response = await asyncio.wait_for(
|
|
125
|
+
self.router_client.atext_completion(
|
|
126
|
+
prompt=prompt, **self._completion_fn_args
|
|
127
|
+
),
|
|
128
|
+
timeout=timeout,
|
|
119
129
|
)
|
|
120
130
|
return self._format_text_completion_response(response)
|
|
131
|
+
except asyncio.TimeoutError:
|
|
132
|
+
self._handle_timeout_error()
|
|
121
133
|
except Exception as e:
|
|
122
134
|
raise ProviderClientAPIException(e)
|
|
123
135
|
|
|
124
136
|
@suppress_logs(log_level=logging.WARNING)
|
|
125
137
|
def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
|
|
126
|
-
"""
|
|
127
|
-
Synchronously generate completions for given list of messages.
|
|
138
|
+
"""Synchronously generate completions for given list of messages.
|
|
128
139
|
|
|
129
140
|
Method overrides the base class method to call the appropriate
|
|
130
141
|
completion method based on the configuration. If the chat completions
|
|
@@ -140,15 +151,17 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
140
151
|
- a list of messages. Each message is a string and will be formatted
|
|
141
152
|
as a user message.
|
|
142
153
|
- a single message as a string which will be formatted as user message.
|
|
154
|
+
|
|
143
155
|
Returns:
|
|
144
156
|
List of message completions.
|
|
157
|
+
|
|
145
158
|
Raises:
|
|
146
159
|
ProviderClientAPIException: If the API request fails.
|
|
147
160
|
"""
|
|
148
161
|
if not self._use_chat_completions_endpoint:
|
|
149
162
|
return self._text_completion(messages)
|
|
150
163
|
try:
|
|
151
|
-
formatted_messages = self.
|
|
164
|
+
formatted_messages = self._get_formatted_messages(messages)
|
|
152
165
|
response = self.router_client.completion(
|
|
153
166
|
messages=formatted_messages, **self._completion_fn_args
|
|
154
167
|
)
|
|
@@ -160,8 +173,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
160
173
|
async def acompletion(
|
|
161
174
|
self, messages: Union[List[dict], List[str], str]
|
|
162
175
|
) -> LLMResponse:
|
|
163
|
-
"""
|
|
164
|
-
Asynchronously generate completions for given list of messages.
|
|
176
|
+
"""Asynchronously generate completions for given list of messages.
|
|
165
177
|
|
|
166
178
|
Method overrides the base class method to call the appropriate
|
|
167
179
|
completion method based on the configuration. If the chat completions
|
|
@@ -177,28 +189,38 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
177
189
|
- a list of messages. Each message is a string and will be formatted
|
|
178
190
|
as a user message.
|
|
179
191
|
- a single message as a string which will be formatted as user message.
|
|
192
|
+
|
|
180
193
|
Returns:
|
|
181
194
|
List of message completions.
|
|
195
|
+
|
|
182
196
|
Raises:
|
|
183
197
|
ProviderClientAPIException: If the API request fails.
|
|
184
198
|
"""
|
|
185
199
|
if not self._use_chat_completions_endpoint:
|
|
186
200
|
return await self._atext_completion(messages)
|
|
187
201
|
try:
|
|
188
|
-
formatted_messages = self.
|
|
189
|
-
|
|
190
|
-
|
|
202
|
+
formatted_messages = self._get_formatted_messages(messages)
|
|
203
|
+
timeout = self._litellm_extra_parameters.get(
|
|
204
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
205
|
+
)
|
|
206
|
+
response = await asyncio.wait_for(
|
|
207
|
+
self.router_client.acompletion(
|
|
208
|
+
messages=formatted_messages, **self._completion_fn_args
|
|
209
|
+
),
|
|
210
|
+
timeout=timeout,
|
|
191
211
|
)
|
|
192
212
|
return self._format_response(response)
|
|
213
|
+
except asyncio.TimeoutError:
|
|
214
|
+
self._handle_timeout_error()
|
|
193
215
|
except Exception as e:
|
|
194
216
|
raise ProviderClientAPIException(e)
|
|
195
217
|
|
|
196
218
|
@property
|
|
197
219
|
def _completion_fn_args(self) -> Dict[str, Any]:
|
|
198
|
-
"""Returns the completion arguments
|
|
199
|
-
LiteLLM's completion functions.
|
|
200
|
-
"""
|
|
220
|
+
"""Returns the completion arguments.
|
|
201
221
|
|
|
222
|
+
For invoking a call through LiteLLM's completion functions.
|
|
223
|
+
"""
|
|
202
224
|
return {
|
|
203
225
|
**self._litellm_extra_parameters,
|
|
204
226
|
LITE_LLM_MODEL_FIELD: self.model_group_id,
|