rasa-pro 3.14.0rc4__py3-none-any.whl → 3.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/agents/agent_manager.py +7 -5
- rasa/agents/protocol/a2a/a2a_agent.py +13 -11
- rasa/agents/protocol/mcp/mcp_base_agent.py +49 -11
- rasa/agents/validation.py +4 -2
- rasa/builder/copilot/copilot_templated_message_provider.py +1 -1
- rasa/builder/validation_service.py +4 -0
- rasa/cli/arguments/data.py +9 -0
- rasa/cli/data.py +72 -6
- rasa/cli/interactive.py +3 -0
- rasa/cli/llm_fine_tuning.py +1 -0
- rasa/cli/project_templates/defaults.py +1 -0
- rasa/cli/validation/bot_config.py +2 -0
- rasa/constants.py +2 -1
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/action_exceptions.py +1 -1
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/agent.py +4 -1
- rasa/core/available_agents.py +1 -1
- rasa/core/channels/voice_stream/browser_audio.py +3 -3
- rasa/core/channels/voice_stream/voice_channel.py +27 -17
- rasa/core/config/credentials.py +3 -3
- rasa/core/exceptions.py +1 -1
- rasa/core/featurizers/tracker_featurizers.py +3 -2
- rasa/core/persistor.py +7 -7
- rasa/core/policies/flows/agent_executor.py +84 -4
- rasa/core/policies/flows/flow_exceptions.py +5 -2
- rasa/core/policies/flows/flow_executor.py +52 -31
- rasa/core/policies/flows/mcp_tool_executor.py +7 -1
- rasa/core/policies/rule_policy.py +1 -1
- rasa/core/run.py +21 -5
- rasa/dialogue_understanding/commands/cancel_flow_command.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +1 -1
- rasa/dialogue_understanding/processor/command_processor.py +13 -7
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/engine/recipes/default_components.py +16 -6
- rasa/graph_components/validators/default_recipe_validator.py +10 -4
- rasa/model_manager/runner_service.py +1 -1
- rasa/nlu/classifiers/diet_classifier.py +2 -0
- rasa/privacy/privacy_config.py +1 -1
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +4 -7
- rasa/shared/core/flows/flow.py +8 -2
- rasa/shared/core/slots.py +55 -24
- rasa/shared/core/training_data/story_reader/story_reader.py +1 -1
- rasa/shared/exceptions.py +23 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -14
- rasa/shared/providers/llm/litellm_router_llm_client.py +40 -17
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/common.py +9 -1
- rasa/shared/utils/configs.py +5 -8
- rasa/shared/utils/llm.py +21 -4
- rasa/shared/utils/mcp/server_connection.py +7 -4
- rasa/studio/download.py +3 -0
- rasa/studio/prompts.py +1 -0
- rasa/studio/upload.py +4 -0
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +6 -0
- rasa/utils/installation_utils.py +111 -0
- rasa/utils/log_utils.py +20 -1
- rasa/utils/tensorflow/callback.py +2 -0
- rasa/utils/tensorflow/models.py +3 -0
- rasa/utils/train_utils.py +2 -0
- rasa/version.py +1 -1
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +3 -3
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +79 -78
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0
|
@@ -195,7 +195,7 @@ def fetch_remote_model_to_dir(
|
|
|
195
195
|
try:
|
|
196
196
|
return persistor.retrieve(model_name=model_name, target_path=target_path)
|
|
197
197
|
except FileNotFoundError as e:
|
|
198
|
-
raise ModelNotFound() from e
|
|
198
|
+
raise ModelNotFound("Model not found") from e
|
|
199
199
|
|
|
200
200
|
|
|
201
201
|
def fetch_size_of_remote_model(
|
|
@@ -9,9 +9,11 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type, TypeVar, Union
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import scipy.sparse
|
|
11
11
|
|
|
12
|
+
from rasa.utils.installation_utils import check_for_installation_issues
|
|
12
13
|
from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
|
|
13
14
|
|
|
14
15
|
if TENSORFLOW_AVAILABLE:
|
|
16
|
+
check_for_installation_issues()
|
|
15
17
|
import tensorflow as tf
|
|
16
18
|
else:
|
|
17
19
|
tf = None
|
rasa/privacy/privacy_config.py
CHANGED
|
@@ -211,7 +211,7 @@ def get_cron_trigger(cron_expression: str) -> CronTrigger:
|
|
|
211
211
|
"privacy_config.invalid_cron_expression",
|
|
212
212
|
cron=cron_expression,
|
|
213
213
|
)
|
|
214
|
-
raise RasaException from exc
|
|
214
|
+
raise RasaException("Invalid cron expression") from exc
|
|
215
215
|
|
|
216
216
|
return cron
|
|
217
217
|
|
|
@@ -139,20 +139,17 @@ class OAuth2AuthStrategy(AgentAuthStrategy):
|
|
|
139
139
|
resp.raise_for_status()
|
|
140
140
|
token_data = resp.json()
|
|
141
141
|
except httpx.HTTPStatusError as e:
|
|
142
|
-
raise
|
|
143
|
-
f"OAuth2 token request failed with status {e.response.status_code}: "
|
|
144
|
-
f"{e.response.text}"
|
|
145
|
-
) from e
|
|
142
|
+
raise e
|
|
146
143
|
except httpx.RequestError as e:
|
|
147
|
-
raise ValueError(f"OAuth2 token request failed
|
|
144
|
+
raise ValueError(f"OAuth2 token request failed - {e}") from e
|
|
148
145
|
except Exception as e:
|
|
149
146
|
raise ValueError(
|
|
150
|
-
f"Unexpected error during OAuth2 token request
|
|
147
|
+
f"Unexpected error during OAuth2 token request - {e}"
|
|
151
148
|
) from e
|
|
152
149
|
|
|
153
150
|
# Validate token data
|
|
154
151
|
if KEY_ACCESS_TOKEN not in token_data:
|
|
155
|
-
raise ValueError(f"No {KEY_ACCESS_TOKEN} in OAuth2 response")
|
|
152
|
+
raise ValueError(f"No `{KEY_ACCESS_TOKEN}` in OAuth2 response")
|
|
156
153
|
|
|
157
154
|
# Set access token and expires at
|
|
158
155
|
self._access_token = token_data[KEY_ACCESS_TOKEN]
|
rasa/shared/core/flows/flow.py
CHANGED
|
@@ -322,9 +322,15 @@ class Flow:
|
|
|
322
322
|
|
|
323
323
|
def get_collect_steps(self) -> List[CollectInformationFlowStep]:
|
|
324
324
|
"""Return all CollectInformationFlowSteps in the flow."""
|
|
325
|
-
collect_steps = []
|
|
325
|
+
collect_steps: List[CollectInformationFlowStep] = []
|
|
326
326
|
for step in self.steps_with_calls_resolved:
|
|
327
|
-
|
|
327
|
+
# Only add collect steps that are not already in the list.
|
|
328
|
+
# This is to avoid returning duplicate collect steps from called flows
|
|
329
|
+
# in case the called flow is called multiple times.
|
|
330
|
+
if (
|
|
331
|
+
isinstance(step, CollectInformationFlowStep)
|
|
332
|
+
and step not in collect_steps
|
|
333
|
+
):
|
|
328
334
|
collect_steps.append(step)
|
|
329
335
|
return collect_steps
|
|
330
336
|
|
rasa/shared/core/slots.py
CHANGED
|
@@ -355,8 +355,8 @@ class FloatSlot(Slot):
|
|
|
355
355
|
mappings: List[Dict[Text, Any]],
|
|
356
356
|
initial_value: Optional[float] = None,
|
|
357
357
|
value_reset_delay: Optional[int] = None,
|
|
358
|
-
max_value: float =
|
|
359
|
-
min_value: float =
|
|
358
|
+
max_value: Optional[float] = None,
|
|
359
|
+
min_value: Optional[float] = None,
|
|
360
360
|
influence_conversation: bool = True,
|
|
361
361
|
is_builtin: bool = False,
|
|
362
362
|
shared_for_coexistence: bool = False,
|
|
@@ -380,32 +380,24 @@ class FloatSlot(Slot):
|
|
|
380
380
|
filled_by=filled_by,
|
|
381
381
|
validation=validation,
|
|
382
382
|
)
|
|
383
|
+
self.validate_min_max_range(min_value, max_value)
|
|
384
|
+
|
|
383
385
|
self.max_value = max_value
|
|
384
386
|
self.min_value = min_value
|
|
385
387
|
|
|
386
|
-
if min_value >= max_value:
|
|
387
|
-
raise InvalidSlotConfigError(
|
|
388
|
-
"Float slot ('{}') created with an invalid range "
|
|
389
|
-
"using min ({}) and max ({}) values. Make sure "
|
|
390
|
-
"min is smaller than max."
|
|
391
|
-
"".format(self.name, self.min_value, self.max_value)
|
|
392
|
-
)
|
|
393
|
-
|
|
394
|
-
if initial_value is not None and not (min_value <= initial_value <= max_value):
|
|
395
|
-
rasa.shared.utils.io.raise_warning(
|
|
396
|
-
f"Float slot ('{self.name}') created with an initial value "
|
|
397
|
-
f"{self.value}. This value is outside of the configured min "
|
|
398
|
-
f"({self.min_value}) and max ({self.max_value}) values."
|
|
399
|
-
)
|
|
400
|
-
|
|
401
388
|
def _as_feature(self) -> List[float]:
|
|
389
|
+
# set default min and max values used in prior releases
|
|
390
|
+
# to prevent regressions for existing models
|
|
391
|
+
min_value = self.min_value or 0.0
|
|
392
|
+
max_value = self.max_value or 1.0
|
|
393
|
+
|
|
402
394
|
try:
|
|
403
|
-
capped_value = max(
|
|
404
|
-
if abs(
|
|
405
|
-
covered_range = abs(
|
|
395
|
+
capped_value = max(min_value, min(max_value, float(self.value)))
|
|
396
|
+
if abs(max_value - min_value) > 0:
|
|
397
|
+
covered_range = abs(max_value - min_value)
|
|
406
398
|
else:
|
|
407
399
|
covered_range = 1
|
|
408
|
-
return [1.0, (capped_value -
|
|
400
|
+
return [1.0, (capped_value - min_value) / covered_range]
|
|
409
401
|
except (TypeError, ValueError):
|
|
410
402
|
return [0.0, 0.0]
|
|
411
403
|
|
|
@@ -424,13 +416,52 @@ class FloatSlot(Slot):
|
|
|
424
416
|
return value
|
|
425
417
|
|
|
426
418
|
def is_valid_value(self, value: Any) -> bool:
|
|
427
|
-
"""Checks if the slot
|
|
428
|
-
|
|
429
|
-
|
|
419
|
+
"""Checks if the slot value is valid."""
|
|
420
|
+
if value is None:
|
|
421
|
+
return True
|
|
422
|
+
|
|
423
|
+
if not isinstance(self.coerce_value(value), float):
|
|
424
|
+
return False
|
|
425
|
+
|
|
426
|
+
if (
|
|
427
|
+
self.min_value is not None
|
|
428
|
+
and self.max_value is not None
|
|
429
|
+
and not (self.min_value <= value <= self.max_value)
|
|
430
|
+
):
|
|
431
|
+
return False
|
|
432
|
+
|
|
433
|
+
return True
|
|
430
434
|
|
|
431
435
|
def _feature_dimensionality(self) -> int:
|
|
432
436
|
return len(self.as_feature())
|
|
433
437
|
|
|
438
|
+
def validate_min_max_range(
|
|
439
|
+
self, min_value: Optional[float], max_value: Optional[float]
|
|
440
|
+
) -> None:
|
|
441
|
+
"""Validates the min-max range for the slot.
|
|
442
|
+
|
|
443
|
+
Raises:
|
|
444
|
+
InvalidSlotConfigError, if the min-max range is invalid.
|
|
445
|
+
"""
|
|
446
|
+
if min_value is not None and max_value is not None and min_value >= max_value:
|
|
447
|
+
raise InvalidSlotConfigError(
|
|
448
|
+
f"Float slot ('{self.name}') created with an invalid range "
|
|
449
|
+
f"using min ({min_value}) and max ({max_value}) values. Make sure "
|
|
450
|
+
f"min is smaller than max."
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
if (
|
|
454
|
+
self.initial_value is not None
|
|
455
|
+
and min_value is not None
|
|
456
|
+
and max_value is not None
|
|
457
|
+
and not (min_value <= self.initial_value <= max_value)
|
|
458
|
+
):
|
|
459
|
+
raise InvalidSlotConfigError(
|
|
460
|
+
f"Float slot ('{self.name}') created with an initial value "
|
|
461
|
+
f"{self.initial_value}. This value is outside of the configured min "
|
|
462
|
+
f"({min_value}) and max ({max_value}) values."
|
|
463
|
+
)
|
|
464
|
+
|
|
434
465
|
|
|
435
466
|
class BooleanSlot(Slot):
|
|
436
467
|
"""A slot storing a truth value."""
|
rasa/shared/exceptions.py
CHANGED
|
@@ -16,6 +16,17 @@ class RasaException(Exception):
|
|
|
16
16
|
to the users, but will be ignored in telemetry.
|
|
17
17
|
"""
|
|
18
18
|
|
|
19
|
+
def __init__(self, message: str, suppress_stack_trace: bool = False, **kwargs: Any):
|
|
20
|
+
"""Initialize the exception.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
message: The error message.
|
|
24
|
+
suppress_stack_trace: If True, the stack trace will be suppressed in logs.
|
|
25
|
+
**kwargs: Additional keyword arguments (e.g., cause for exception chaining).
|
|
26
|
+
"""
|
|
27
|
+
Exception.__init__(self, message)
|
|
28
|
+
self.suppress_stack_trace = suppress_stack_trace
|
|
29
|
+
|
|
19
30
|
|
|
20
31
|
class RasaCoreException(RasaException):
|
|
21
32
|
"""Basic exception for errors raised by Rasa Core."""
|
|
@@ -113,6 +124,17 @@ class SchemaValidationError(RasaException, jsonschema.ValidationError):
|
|
|
113
124
|
class InvalidEntityFormatException(RasaException, json.JSONDecodeError):
|
|
114
125
|
"""Raised if the format of an entity is invalid."""
|
|
115
126
|
|
|
127
|
+
def __init__(self, msg: str, doc: str = "", pos: int = 0):
|
|
128
|
+
"""Initialize the exception.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
msg: The error message.
|
|
132
|
+
doc: The document that caused the error.
|
|
133
|
+
pos: The position in the document where the error occurred.
|
|
134
|
+
"""
|
|
135
|
+
RasaException.__init__(self, msg)
|
|
136
|
+
json.JSONDecodeError.__init__(self, msg, doc, pos)
|
|
137
|
+
|
|
116
138
|
@classmethod
|
|
117
139
|
def create_from(
|
|
118
140
|
cls, other: json.JSONDecodeError, msg: Text
|
|
@@ -130,8 +152,7 @@ class ConnectionException(RasaException):
|
|
|
130
152
|
|
|
131
153
|
|
|
132
154
|
class ProviderClientAPIException(RasaException):
|
|
133
|
-
"""
|
|
134
|
-
with LLM / embedding providers.
|
|
155
|
+
"""For errors during API interactions with LLM / embedding providers.
|
|
135
156
|
|
|
136
157
|
Attributes:
|
|
137
158
|
original_exception (Exception): The original exception that was
|
|
@@ -167,8 +167,9 @@ class OAuthConfigWrapper(OAuth, BaseModel):
|
|
|
167
167
|
|
|
168
168
|
@dataclass
|
|
169
169
|
class AzureOpenAIClientConfig:
|
|
170
|
-
"""Parses configuration for Azure OpenAI client
|
|
171
|
-
|
|
170
|
+
"""Parses configuration for Azure OpenAI client.
|
|
171
|
+
|
|
172
|
+
Resolves aliases and raises deprecation warnings.
|
|
172
173
|
|
|
173
174
|
Raises:
|
|
174
175
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -301,9 +302,7 @@ class AzureOpenAIClientConfig:
|
|
|
301
302
|
|
|
302
303
|
|
|
303
304
|
def is_azure_openai_config(config: dict) -> bool:
|
|
304
|
-
"""Check whether the configuration is meant to configure
|
|
305
|
-
an Azure OpenAI client.
|
|
306
|
-
"""
|
|
305
|
+
"""Check whether the configuration is meant to configure an Azure OpenAI client."""
|
|
307
306
|
# Resolve any aliases that are specific to Azure OpenAI configuration
|
|
308
307
|
config = AzureOpenAIClientConfig.resolve_config_aliases(config)
|
|
309
308
|
|
|
@@ -40,8 +40,9 @@ FORBIDDEN_KEYS = [
|
|
|
40
40
|
|
|
41
41
|
@dataclass
|
|
42
42
|
class DefaultLiteLLMClientConfig:
|
|
43
|
-
"""Parses configuration for default LiteLLM client
|
|
44
|
-
|
|
43
|
+
"""Parses configuration for default LiteLLM client.
|
|
44
|
+
|
|
45
|
+
Resolves aliases and raises deprecation warnings.
|
|
45
46
|
|
|
46
47
|
Raises:
|
|
47
48
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -72,8 +73,7 @@ class DefaultLiteLLMClientConfig:
|
|
|
72
73
|
|
|
73
74
|
@classmethod
|
|
74
75
|
def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
|
|
75
|
-
"""
|
|
76
|
-
Initializes a dataclass from the passed config.
|
|
76
|
+
"""Initializes a dataclass from the passed config.
|
|
77
77
|
|
|
78
78
|
Args:
|
|
79
79
|
config: (dict) The config from which to initialize.
|
|
@@ -38,8 +38,9 @@ _LITELLM_UNSUPPORTED_KEYS = [
|
|
|
38
38
|
|
|
39
39
|
@dataclass
|
|
40
40
|
class LiteLLMRouterClientConfig:
|
|
41
|
-
"""Parses configuration for a LiteLLM Router client.
|
|
42
|
-
|
|
41
|
+
"""Parses configuration for a LiteLLM Router client.
|
|
42
|
+
|
|
43
|
+
The configuration is expected to be in the following format:
|
|
43
44
|
|
|
44
45
|
{
|
|
45
46
|
"id": "model_group_id",
|
|
@@ -64,8 +64,9 @@ FORBIDDEN_KEYS = [
|
|
|
64
64
|
|
|
65
65
|
@dataclass
|
|
66
66
|
class OpenAIClientConfig:
|
|
67
|
-
"""Parses configuration for
|
|
68
|
-
|
|
67
|
+
"""Parses configuration for OpenAI client.
|
|
68
|
+
|
|
69
|
+
Resolves aliases and raises deprecation warnings.
|
|
69
70
|
|
|
70
71
|
Raises:
|
|
71
72
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -118,8 +119,7 @@ class OpenAIClientConfig:
|
|
|
118
119
|
|
|
119
120
|
@classmethod
|
|
120
121
|
def from_dict(cls, config: dict) -> OpenAIClientConfig:
|
|
121
|
-
"""
|
|
122
|
-
Initializes a dataclass from the passed config.
|
|
122
|
+
"""Initializes a dataclass from the passed config.
|
|
123
123
|
|
|
124
124
|
Args:
|
|
125
125
|
config: (dict) The config from which to initialize.
|
|
@@ -168,9 +168,7 @@ class OpenAIClientConfig:
|
|
|
168
168
|
|
|
169
169
|
|
|
170
170
|
def is_openai_config(config: dict) -> bool:
|
|
171
|
-
"""Check whether the configuration is meant to configure
|
|
172
|
-
an OpenAI client.
|
|
173
|
-
"""
|
|
171
|
+
"""Check whether the configuration is meant to configure an OpenAI client."""
|
|
174
172
|
# Process the config to handle all the aliases
|
|
175
173
|
config = OpenAIClientConfig.resolve_config_aliases(config)
|
|
176
174
|
|
|
@@ -22,8 +22,9 @@ structlogger = structlog.get_logger()
|
|
|
22
22
|
|
|
23
23
|
@dataclass
|
|
24
24
|
class RasaLLMClientConfig:
|
|
25
|
-
"""Parses configuration for a Rasa Hosted LiteLLM client
|
|
26
|
-
|
|
25
|
+
"""Parses configuration for a Rasa Hosted LiteLLM client.
|
|
26
|
+
|
|
27
|
+
Checks required keys present.
|
|
27
28
|
|
|
28
29
|
Raises:
|
|
29
30
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -40,8 +41,7 @@ class RasaLLMClientConfig:
|
|
|
40
41
|
|
|
41
42
|
@classmethod
|
|
42
43
|
def from_dict(cls, config: dict) -> RasaLLMClientConfig:
|
|
43
|
-
"""
|
|
44
|
-
Initializes a dataclass from the passed config.
|
|
44
|
+
"""Initializes a dataclass from the passed config.
|
|
45
45
|
|
|
46
46
|
Args:
|
|
47
47
|
config: (dict) The config from which to initialize.
|
|
@@ -61,8 +61,9 @@ FORBIDDEN_KEYS = [
|
|
|
61
61
|
|
|
62
62
|
@dataclass
|
|
63
63
|
class SelfHostedLLMClientConfig:
|
|
64
|
-
"""Parses configuration for Self Hosted LiteLLM client
|
|
65
|
-
|
|
64
|
+
"""Parses configuration for Self Hosted LiteLLM client.
|
|
65
|
+
|
|
66
|
+
Resolves aliases and raises deprecation warnings.
|
|
66
67
|
|
|
67
68
|
Raises:
|
|
68
69
|
ValueError: Raised in cases of invalid configuration:
|
|
@@ -116,8 +117,7 @@ class SelfHostedLLMClientConfig:
|
|
|
116
117
|
|
|
117
118
|
@classmethod
|
|
118
119
|
def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
|
|
119
|
-
"""
|
|
120
|
-
Initializes a dataclass from the passed config.
|
|
120
|
+
"""Initializes a dataclass from the passed config.
|
|
121
121
|
|
|
122
122
|
Args:
|
|
123
123
|
config: (dict) The config from which to initialize.
|
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
from abc import abstractmethod
|
|
5
|
-
from typing import Any, Dict, List, Union, cast
|
|
6
|
+
from typing import Any, Dict, List, NoReturn, Union, cast
|
|
6
7
|
|
|
7
8
|
import structlog
|
|
8
9
|
from litellm import acompletion, completion, validate_environment
|
|
9
10
|
|
|
11
|
+
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
10
12
|
from rasa.shared.constants import (
|
|
11
13
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
|
|
12
14
|
API_BASE_CONFIG_KEY,
|
|
@@ -57,26 +59,24 @@ class _BaseLiteLLMClient:
|
|
|
57
59
|
@property
|
|
58
60
|
@abstractmethod
|
|
59
61
|
def config(self) -> dict:
|
|
60
|
-
"""Returns the configuration for that the llm client
|
|
61
|
-
in dictionary form.
|
|
62
|
-
"""
|
|
62
|
+
"""Returns the configuration for that the llm client in dictionary form."""
|
|
63
63
|
pass
|
|
64
64
|
|
|
65
65
|
@property
|
|
66
66
|
@abstractmethod
|
|
67
67
|
def _litellm_model_name(self) -> str:
|
|
68
|
-
"""Returns the value of LiteLLM's model parameter
|
|
69
|
-
completion/acompletion in LiteLLM format:
|
|
68
|
+
"""Returns the value of LiteLLM's model parameter.
|
|
70
69
|
|
|
70
|
+
To be used in completion/acompletion in LiteLLM format:
|
|
71
71
|
<provider>/<model or deployment name>
|
|
72
72
|
"""
|
|
73
73
|
pass
|
|
74
74
|
|
|
75
75
|
@property
|
|
76
76
|
def _litellm_extra_parameters(self) -> Dict[str, Any]:
|
|
77
|
-
"""Returns a dictionary of extra parameters
|
|
78
|
-
parameters as well as LiteLLM specific input parameters.
|
|
77
|
+
"""Returns a dictionary of extra parameters.
|
|
79
78
|
|
|
79
|
+
Includes model parameters as well as LiteLLM specific input parameters.
|
|
80
80
|
By default, this returns an empty dictionary (no extra parameters).
|
|
81
81
|
"""
|
|
82
82
|
return {}
|
|
@@ -96,8 +96,9 @@ class _BaseLiteLLMClient:
|
|
|
96
96
|
}
|
|
97
97
|
|
|
98
98
|
def validate_client_setup(self) -> None:
|
|
99
|
-
"""Perform client validation.
|
|
100
|
-
|
|
99
|
+
"""Perform client validation.
|
|
100
|
+
|
|
101
|
+
By default only environment variables are validated.
|
|
101
102
|
|
|
102
103
|
Raises:
|
|
103
104
|
ProviderClientValidationError if validation fails.
|
|
@@ -188,10 +189,17 @@ class _BaseLiteLLMClient:
|
|
|
188
189
|
arguments = cast(
|
|
189
190
|
Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
|
|
190
191
|
)
|
|
191
|
-
|
|
192
|
-
|
|
192
|
+
|
|
193
|
+
timeout = self._litellm_extra_parameters.get(
|
|
194
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
195
|
+
)
|
|
196
|
+
response = await asyncio.wait_for(
|
|
197
|
+
acompletion(messages=formatted_messages, **{**arguments, **kwargs}),
|
|
198
|
+
timeout=timeout,
|
|
193
199
|
)
|
|
194
200
|
return self._format_response(response)
|
|
201
|
+
except asyncio.TimeoutError:
|
|
202
|
+
self._handle_timeout_error()
|
|
195
203
|
except Exception as e:
|
|
196
204
|
message = ""
|
|
197
205
|
from rasa.shared.providers.llm.self_hosted_llm_client import (
|
|
@@ -211,6 +219,25 @@ class _BaseLiteLLMClient:
|
|
|
211
219
|
)
|
|
212
220
|
raise ProviderClientAPIException(e, message) from e
|
|
213
221
|
|
|
222
|
+
def _handle_timeout_error(self) -> NoReturn:
|
|
223
|
+
"""Handle asyncio.TimeoutError and raise ProviderClientAPIException.
|
|
224
|
+
|
|
225
|
+
Raises:
|
|
226
|
+
ProviderClientAPIException: Always raised with formatted timeout error.
|
|
227
|
+
"""
|
|
228
|
+
timeout = self._litellm_extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
|
|
229
|
+
error_message = (
|
|
230
|
+
f"APITimeoutError - Request timed out. Error_str: "
|
|
231
|
+
f"Request timed out. - timeout value={timeout:.6f}, "
|
|
232
|
+
f"time taken={timeout:.6f} seconds"
|
|
233
|
+
)
|
|
234
|
+
# nosemgrep: semgrep.rules.pii-positional-arguments-in-logging
|
|
235
|
+
# Error message contains only numeric timeout values, not PII
|
|
236
|
+
structlogger.error(
|
|
237
|
+
f"{self.__class__.__name__.lower()}.llm.timeout", error=error_message
|
|
238
|
+
)
|
|
239
|
+
raise ProviderClientAPIException(asyncio.TimeoutError(error_message)) from None
|
|
240
|
+
|
|
214
241
|
def _get_formatted_messages(
|
|
215
242
|
self, messages: Union[List[dict], List[str], str]
|
|
216
243
|
) -> List[Dict[str, str]]:
|
|
@@ -312,8 +339,9 @@ class _BaseLiteLLMClient:
|
|
|
312
339
|
|
|
313
340
|
@staticmethod
|
|
314
341
|
def _ensure_certificates() -> None:
|
|
315
|
-
"""Configures SSL certificates for LiteLLM.
|
|
316
|
-
|
|
342
|
+
"""Configures SSL certificates for LiteLLM.
|
|
343
|
+
|
|
344
|
+
This method is invoked during client initialization.
|
|
317
345
|
|
|
318
346
|
LiteLLM may utilize `openai` clients or other providers that require
|
|
319
347
|
SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
from typing import Any, Dict, List, Union
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
7
8
|
|
|
9
|
+
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
8
10
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
9
11
|
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
10
12
|
LiteLLMRouterClientConfig,
|
|
@@ -79,13 +81,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
79
81
|
|
|
80
82
|
@suppress_logs(log_level=logging.WARNING)
|
|
81
83
|
def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
82
|
-
"""
|
|
83
|
-
Synchronously generate completions for given prompt.
|
|
84
|
+
"""Synchronously generate completions for given prompt.
|
|
84
85
|
|
|
85
86
|
Args:
|
|
86
87
|
prompt: Prompt to generate the completion for.
|
|
88
|
+
|
|
87
89
|
Returns:
|
|
88
90
|
List of message completions.
|
|
91
|
+
|
|
89
92
|
Raises:
|
|
90
93
|
ProviderClientAPIException: If the API request fails.
|
|
91
94
|
"""
|
|
@@ -103,21 +106,30 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
103
106
|
|
|
104
107
|
@suppress_logs(log_level=logging.WARNING)
|
|
105
108
|
async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
106
|
-
"""
|
|
107
|
-
Asynchronously generate completions for given prompt.
|
|
109
|
+
"""Asynchronously generate completions for given prompt.
|
|
108
110
|
|
|
109
111
|
Args:
|
|
110
112
|
prompt: Prompt to generate the completion for.
|
|
113
|
+
|
|
111
114
|
Returns:
|
|
112
115
|
List of message completions.
|
|
116
|
+
|
|
113
117
|
Raises:
|
|
114
118
|
ProviderClientAPIException: If the API request fails.
|
|
115
119
|
"""
|
|
116
120
|
try:
|
|
117
|
-
|
|
118
|
-
|
|
121
|
+
timeout = self._litellm_extra_parameters.get(
|
|
122
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
123
|
+
)
|
|
124
|
+
response = await asyncio.wait_for(
|
|
125
|
+
self.router_client.atext_completion(
|
|
126
|
+
prompt=prompt, **self._completion_fn_args
|
|
127
|
+
),
|
|
128
|
+
timeout=timeout,
|
|
119
129
|
)
|
|
120
130
|
return self._format_text_completion_response(response)
|
|
131
|
+
except asyncio.TimeoutError:
|
|
132
|
+
self._handle_timeout_error()
|
|
121
133
|
except Exception as e:
|
|
122
134
|
raise ProviderClientAPIException(e)
|
|
123
135
|
|
|
@@ -125,8 +137,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
125
137
|
def completion(
|
|
126
138
|
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
127
139
|
) -> LLMResponse:
|
|
128
|
-
"""
|
|
129
|
-
Synchronously generate completions for given list of messages.
|
|
140
|
+
"""Synchronously generate completions for given list of messages.
|
|
130
141
|
|
|
131
142
|
Method overrides the base class method to call the appropriate
|
|
132
143
|
completion method based on the configuration. If the chat completions
|
|
@@ -143,15 +154,17 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
143
154
|
as a user message.
|
|
144
155
|
- a single message as a string which will be formatted as user message.
|
|
145
156
|
**kwargs: Additional parameters to pass to the completion call.
|
|
157
|
+
|
|
146
158
|
Returns:
|
|
147
159
|
List of message completions.
|
|
160
|
+
|
|
148
161
|
Raises:
|
|
149
162
|
ProviderClientAPIException: If the API request fails.
|
|
150
163
|
"""
|
|
151
164
|
if not self._use_chat_completions_endpoint:
|
|
152
165
|
return self._text_completion(messages)
|
|
153
166
|
try:
|
|
154
|
-
formatted_messages = self.
|
|
167
|
+
formatted_messages = self._get_formatted_messages(messages)
|
|
155
168
|
response = self.router_client.completion(
|
|
156
169
|
messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
|
|
157
170
|
)
|
|
@@ -163,8 +176,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
163
176
|
async def acompletion(
|
|
164
177
|
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
165
178
|
) -> LLMResponse:
|
|
166
|
-
"""
|
|
167
|
-
Asynchronously generate completions for given list of messages.
|
|
179
|
+
"""Asynchronously generate completions for given list of messages.
|
|
168
180
|
|
|
169
181
|
Method overrides the base class method to call the appropriate
|
|
170
182
|
completion method based on the configuration. If the chat completions
|
|
@@ -181,28 +193,39 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
181
193
|
as a user message.
|
|
182
194
|
- a single message as a string which will be formatted as user message.
|
|
183
195
|
**kwargs: Additional parameters to pass to the completion call.
|
|
196
|
+
|
|
184
197
|
Returns:
|
|
185
198
|
List of message completions.
|
|
199
|
+
|
|
186
200
|
Raises:
|
|
187
201
|
ProviderClientAPIException: If the API request fails.
|
|
188
202
|
"""
|
|
189
203
|
if not self._use_chat_completions_endpoint:
|
|
190
204
|
return await self._atext_completion(messages)
|
|
191
205
|
try:
|
|
192
|
-
formatted_messages = self.
|
|
193
|
-
|
|
194
|
-
|
|
206
|
+
formatted_messages = self._get_formatted_messages(messages)
|
|
207
|
+
timeout = self._litellm_extra_parameters.get(
|
|
208
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
209
|
+
)
|
|
210
|
+
response = await asyncio.wait_for(
|
|
211
|
+
self.router_client.acompletion(
|
|
212
|
+
messages=formatted_messages,
|
|
213
|
+
**{**self._completion_fn_args, **kwargs},
|
|
214
|
+
),
|
|
215
|
+
timeout=timeout,
|
|
195
216
|
)
|
|
196
217
|
return self._format_response(response)
|
|
218
|
+
except asyncio.TimeoutError:
|
|
219
|
+
self._handle_timeout_error()
|
|
197
220
|
except Exception as e:
|
|
198
221
|
raise ProviderClientAPIException(e)
|
|
199
222
|
|
|
200
223
|
@property
|
|
201
224
|
def _completion_fn_args(self) -> Dict[str, Any]:
|
|
202
|
-
"""Returns the completion arguments
|
|
203
|
-
LiteLLM's completion functions.
|
|
204
|
-
"""
|
|
225
|
+
"""Returns the completion arguments.
|
|
205
226
|
|
|
227
|
+
For invoking a call through LiteLLM's completion functions.
|
|
228
|
+
"""
|
|
206
229
|
return {
|
|
207
230
|
**self._litellm_extra_parameters,
|
|
208
231
|
LITE_LLM_MODEL_FIELD: self.model_group_id,
|