rasa-pro 3.14.0rc4__py3-none-any.whl → 3.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. (More details were linked from the original page.)

Files changed (79)
  1. rasa/agents/agent_manager.py +7 -5
  2. rasa/agents/protocol/a2a/a2a_agent.py +13 -11
  3. rasa/agents/protocol/mcp/mcp_base_agent.py +49 -11
  4. rasa/agents/validation.py +4 -2
  5. rasa/builder/copilot/copilot_templated_message_provider.py +1 -1
  6. rasa/builder/validation_service.py +4 -0
  7. rasa/cli/arguments/data.py +9 -0
  8. rasa/cli/data.py +72 -6
  9. rasa/cli/interactive.py +3 -0
  10. rasa/cli/llm_fine_tuning.py +1 -0
  11. rasa/cli/project_templates/defaults.py +1 -0
  12. rasa/cli/validation/bot_config.py +2 -0
  13. rasa/constants.py +2 -1
  14. rasa/core/actions/action_clean_stack.py +32 -0
  15. rasa/core/actions/action_exceptions.py +1 -1
  16. rasa/core/actions/constants.py +4 -0
  17. rasa/core/actions/custom_action_executor.py +70 -12
  18. rasa/core/actions/grpc_custom_action_executor.py +41 -2
  19. rasa/core/actions/http_custom_action_executor.py +49 -25
  20. rasa/core/agent.py +4 -1
  21. rasa/core/available_agents.py +1 -1
  22. rasa/core/channels/voice_stream/browser_audio.py +3 -3
  23. rasa/core/channels/voice_stream/voice_channel.py +27 -17
  24. rasa/core/config/credentials.py +3 -3
  25. rasa/core/exceptions.py +1 -1
  26. rasa/core/featurizers/tracker_featurizers.py +3 -2
  27. rasa/core/persistor.py +7 -7
  28. rasa/core/policies/flows/agent_executor.py +84 -4
  29. rasa/core/policies/flows/flow_exceptions.py +5 -2
  30. rasa/core/policies/flows/flow_executor.py +52 -31
  31. rasa/core/policies/flows/mcp_tool_executor.py +7 -1
  32. rasa/core/policies/rule_policy.py +1 -1
  33. rasa/core/run.py +21 -5
  34. rasa/dialogue_understanding/commands/cancel_flow_command.py +1 -1
  35. rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
  36. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
  37. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
  38. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +1 -1
  39. rasa/dialogue_understanding/processor/command_processor.py +13 -7
  40. rasa/e2e_test/e2e_config.py +4 -3
  41. rasa/engine/recipes/default_components.py +16 -6
  42. rasa/graph_components/validators/default_recipe_validator.py +10 -4
  43. rasa/model_manager/runner_service.py +1 -1
  44. rasa/nlu/classifiers/diet_classifier.py +2 -0
  45. rasa/privacy/privacy_config.py +1 -1
  46. rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +4 -7
  47. rasa/shared/core/flows/flow.py +8 -2
  48. rasa/shared/core/slots.py +55 -24
  49. rasa/shared/core/training_data/story_reader/story_reader.py +1 -1
  50. rasa/shared/exceptions.py +23 -2
  51. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
  52. rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
  53. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
  54. rasa/shared/providers/_configs/openai_client_config.py +5 -7
  55. rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
  56. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
  57. rasa/shared/providers/llm/_base_litellm_client.py +42 -14
  58. rasa/shared/providers/llm/litellm_router_llm_client.py +40 -17
  59. rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
  60. rasa/shared/utils/common.py +9 -1
  61. rasa/shared/utils/configs.py +5 -8
  62. rasa/shared/utils/llm.py +21 -4
  63. rasa/shared/utils/mcp/server_connection.py +7 -4
  64. rasa/studio/download.py +3 -0
  65. rasa/studio/prompts.py +1 -0
  66. rasa/studio/upload.py +4 -0
  67. rasa/utils/common.py +9 -0
  68. rasa/utils/endpoints.py +6 -0
  69. rasa/utils/installation_utils.py +111 -0
  70. rasa/utils/log_utils.py +20 -1
  71. rasa/utils/tensorflow/callback.py +2 -0
  72. rasa/utils/tensorflow/models.py +3 -0
  73. rasa/utils/train_utils.py +2 -0
  74. rasa/version.py +1 -1
  75. {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +3 -3
  76. {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +79 -78
  77. {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
  78. {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
  79. {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0
@@ -195,7 +195,7 @@ def fetch_remote_model_to_dir(
195
195
  try:
196
196
  return persistor.retrieve(model_name=model_name, target_path=target_path)
197
197
  except FileNotFoundError as e:
198
- raise ModelNotFound() from e
198
+ raise ModelNotFound("Model not found") from e
199
199
 
200
200
 
201
201
  def fetch_size_of_remote_model(
@@ -9,9 +9,11 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type, TypeVar, Union
9
9
  import numpy as np
10
10
  import scipy.sparse
11
11
 
12
+ from rasa.utils.installation_utils import check_for_installation_issues
12
13
  from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
13
14
 
14
15
  if TENSORFLOW_AVAILABLE:
16
+ check_for_installation_issues()
15
17
  import tensorflow as tf
16
18
  else:
17
19
  tf = None
@@ -211,7 +211,7 @@ def get_cron_trigger(cron_expression: str) -> CronTrigger:
211
211
  "privacy_config.invalid_cron_expression",
212
212
  cron=cron_expression,
213
213
  )
214
- raise RasaException from exc
214
+ raise RasaException("Invalid cron expression") from exc
215
215
 
216
216
  return cron
217
217
 
@@ -139,20 +139,17 @@ class OAuth2AuthStrategy(AgentAuthStrategy):
139
139
  resp.raise_for_status()
140
140
  token_data = resp.json()
141
141
  except httpx.HTTPStatusError as e:
142
- raise ValueError(
143
- f"OAuth2 token request failed with status {e.response.status_code}: "
144
- f"{e.response.text}"
145
- ) from e
142
+ raise e
146
143
  except httpx.RequestError as e:
147
- raise ValueError(f"OAuth2 token request failed: {e}") from e
144
+ raise ValueError(f"OAuth2 token request failed - {e}") from e
148
145
  except Exception as e:
149
146
  raise ValueError(
150
- f"Unexpected error during OAuth2 token request: {e}"
147
+ f"Unexpected error during OAuth2 token request - {e}"
151
148
  ) from e
152
149
 
153
150
  # Validate token data
154
151
  if KEY_ACCESS_TOKEN not in token_data:
155
- raise ValueError(f"No {KEY_ACCESS_TOKEN} in OAuth2 response")
152
+ raise ValueError(f"No `{KEY_ACCESS_TOKEN}` in OAuth2 response")
156
153
 
157
154
  # Set access token and expires at
158
155
  self._access_token = token_data[KEY_ACCESS_TOKEN]
@@ -322,9 +322,15 @@ class Flow:
322
322
 
323
323
  def get_collect_steps(self) -> List[CollectInformationFlowStep]:
324
324
  """Return all CollectInformationFlowSteps in the flow."""
325
- collect_steps = []
325
+ collect_steps: List[CollectInformationFlowStep] = []
326
326
  for step in self.steps_with_calls_resolved:
327
- if isinstance(step, CollectInformationFlowStep):
327
+ # Only add collect steps that are not already in the list.
328
+ # This is to avoid returning duplicate collect steps from called flows
329
+ # in case the called flow is called multiple times.
330
+ if (
331
+ isinstance(step, CollectInformationFlowStep)
332
+ and step not in collect_steps
333
+ ):
328
334
  collect_steps.append(step)
329
335
  return collect_steps
330
336
 
rasa/shared/core/slots.py CHANGED
@@ -355,8 +355,8 @@ class FloatSlot(Slot):
355
355
  mappings: List[Dict[Text, Any]],
356
356
  initial_value: Optional[float] = None,
357
357
  value_reset_delay: Optional[int] = None,
358
- max_value: float = 1.0,
359
- min_value: float = 0.0,
358
+ max_value: Optional[float] = None,
359
+ min_value: Optional[float] = None,
360
360
  influence_conversation: bool = True,
361
361
  is_builtin: bool = False,
362
362
  shared_for_coexistence: bool = False,
@@ -380,32 +380,24 @@ class FloatSlot(Slot):
380
380
  filled_by=filled_by,
381
381
  validation=validation,
382
382
  )
383
+ self.validate_min_max_range(min_value, max_value)
384
+
383
385
  self.max_value = max_value
384
386
  self.min_value = min_value
385
387
 
386
- if min_value >= max_value:
387
- raise InvalidSlotConfigError(
388
- "Float slot ('{}') created with an invalid range "
389
- "using min ({}) and max ({}) values. Make sure "
390
- "min is smaller than max."
391
- "".format(self.name, self.min_value, self.max_value)
392
- )
393
-
394
- if initial_value is not None and not (min_value <= initial_value <= max_value):
395
- rasa.shared.utils.io.raise_warning(
396
- f"Float slot ('{self.name}') created with an initial value "
397
- f"{self.value}. This value is outside of the configured min "
398
- f"({self.min_value}) and max ({self.max_value}) values."
399
- )
400
-
401
388
  def _as_feature(self) -> List[float]:
389
+ # set default min and max values used in prior releases
390
+ # to prevent regressions for existing models
391
+ min_value = self.min_value or 0.0
392
+ max_value = self.max_value or 1.0
393
+
402
394
  try:
403
- capped_value = max(self.min_value, min(self.max_value, float(self.value)))
404
- if abs(self.max_value - self.min_value) > 0:
405
- covered_range = abs(self.max_value - self.min_value)
395
+ capped_value = max(min_value, min(max_value, float(self.value)))
396
+ if abs(max_value - min_value) > 0:
397
+ covered_range = abs(max_value - min_value)
406
398
  else:
407
399
  covered_range = 1
408
- return [1.0, (capped_value - self.min_value) / covered_range]
400
+ return [1.0, (capped_value - min_value) / covered_range]
409
401
  except (TypeError, ValueError):
410
402
  return [0.0, 0.0]
411
403
 
@@ -424,13 +416,52 @@ class FloatSlot(Slot):
424
416
  return value
425
417
 
426
418
  def is_valid_value(self, value: Any) -> bool:
427
- """Checks if the slot contains the value."""
428
- # check that coerced type is float
429
- return value is None or isinstance(self.coerce_value(value), float)
419
+ """Checks if the slot value is valid."""
420
+ if value is None:
421
+ return True
422
+
423
+ if not isinstance(self.coerce_value(value), float):
424
+ return False
425
+
426
+ if (
427
+ self.min_value is not None
428
+ and self.max_value is not None
429
+ and not (self.min_value <= value <= self.max_value)
430
+ ):
431
+ return False
432
+
433
+ return True
430
434
 
431
435
  def _feature_dimensionality(self) -> int:
432
436
  return len(self.as_feature())
433
437
 
438
+ def validate_min_max_range(
439
+ self, min_value: Optional[float], max_value: Optional[float]
440
+ ) -> None:
441
+ """Validates the min-max range for the slot.
442
+
443
+ Raises:
444
+ InvalidSlotConfigError, if the min-max range is invalid.
445
+ """
446
+ if min_value is not None and max_value is not None and min_value >= max_value:
447
+ raise InvalidSlotConfigError(
448
+ f"Float slot ('{self.name}') created with an invalid range "
449
+ f"using min ({min_value}) and max ({max_value}) values. Make sure "
450
+ f"min is smaller than max."
451
+ )
452
+
453
+ if (
454
+ self.initial_value is not None
455
+ and min_value is not None
456
+ and max_value is not None
457
+ and not (min_value <= self.initial_value <= max_value)
458
+ ):
459
+ raise InvalidSlotConfigError(
460
+ f"Float slot ('{self.name}') created with an initial value "
461
+ f"{self.initial_value}. This value is outside of the configured min "
462
+ f"({min_value}) and max ({max_value}) values."
463
+ )
464
+
434
465
 
435
466
  class BooleanSlot(Slot):
436
467
  """A slot storing a truth value."""
@@ -126,4 +126,4 @@ class StoryParseError(RasaCoreException, ValueError):
126
126
 
127
127
  def __init__(self, message: Text) -> None:
128
128
  self.message = message
129
- super(StoryParseError, self).__init__()
129
+ super(StoryParseError, self).__init__(message)
rasa/shared/exceptions.py CHANGED
@@ -16,6 +16,17 @@ class RasaException(Exception):
16
16
  to the users, but will be ignored in telemetry.
17
17
  """
18
18
 
19
+ def __init__(self, message: str, suppress_stack_trace: bool = False, **kwargs: Any):
20
+ """Initialize the exception.
21
+
22
+ Args:
23
+ message: The error message.
24
+ suppress_stack_trace: If True, the stack trace will be suppressed in logs.
25
+ **kwargs: Additional keyword arguments (e.g., cause for exception chaining).
26
+ """
27
+ Exception.__init__(self, message)
28
+ self.suppress_stack_trace = suppress_stack_trace
29
+
19
30
 
20
31
  class RasaCoreException(RasaException):
21
32
  """Basic exception for errors raised by Rasa Core."""
@@ -113,6 +124,17 @@ class SchemaValidationError(RasaException, jsonschema.ValidationError):
113
124
  class InvalidEntityFormatException(RasaException, json.JSONDecodeError):
114
125
  """Raised if the format of an entity is invalid."""
115
126
 
127
+ def __init__(self, msg: str, doc: str = "", pos: int = 0):
128
+ """Initialize the exception.
129
+
130
+ Args:
131
+ msg: The error message.
132
+ doc: The document that caused the error.
133
+ pos: The position in the document where the error occurred.
134
+ """
135
+ RasaException.__init__(self, msg)
136
+ json.JSONDecodeError.__init__(self, msg, doc, pos)
137
+
116
138
  @classmethod
117
139
  def create_from(
118
140
  cls, other: json.JSONDecodeError, msg: Text
@@ -130,8 +152,7 @@ class ConnectionException(RasaException):
130
152
 
131
153
 
132
154
  class ProviderClientAPIException(RasaException):
133
- """Raised for errors that occur during API interactions
134
- with LLM / embedding providers.
155
+ """For errors during API interactions with LLM / embedding providers.
135
156
 
136
157
  Attributes:
137
158
  original_exception (Exception): The original exception that was
@@ -167,8 +167,9 @@ class OAuthConfigWrapper(OAuth, BaseModel):
167
167
 
168
168
  @dataclass
169
169
  class AzureOpenAIClientConfig:
170
- """Parses configuration for Azure OpenAI client, resolves aliases and
171
- raises deprecation warnings.
170
+ """Parses configuration for Azure OpenAI client.
171
+
172
+ Resolves aliases and raises deprecation warnings.
172
173
 
173
174
  Raises:
174
175
  ValueError: Raised in cases of invalid configuration:
@@ -301,9 +302,7 @@ class AzureOpenAIClientConfig:
301
302
 
302
303
 
303
304
  def is_azure_openai_config(config: dict) -> bool:
304
- """Check whether the configuration is meant to configure
305
- an Azure OpenAI client.
306
- """
305
+ """Check whether the configuration is meant to configure an Azure OpenAI client."""
307
306
  # Resolve any aliases that are specific to Azure OpenAI configuration
308
307
  config = AzureOpenAIClientConfig.resolve_config_aliases(config)
309
308
 
@@ -40,8 +40,9 @@ FORBIDDEN_KEYS = [
40
40
 
41
41
  @dataclass
42
42
  class DefaultLiteLLMClientConfig:
43
- """Parses configuration for default LiteLLM client, resolves aliases and
44
- raises deprecation warnings.
43
+ """Parses configuration for default LiteLLM client.
44
+
45
+ Resolves aliases and raises deprecation warnings.
45
46
 
46
47
  Raises:
47
48
  ValueError: Raised in cases of invalid configuration:
@@ -72,8 +73,7 @@ class DefaultLiteLLMClientConfig:
72
73
 
73
74
  @classmethod
74
75
  def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
75
- """
76
- Initializes a dataclass from the passed config.
76
+ """Initializes a dataclass from the passed config.
77
77
 
78
78
  Args:
79
79
  config: (dict) The config from which to initialize.
@@ -38,8 +38,9 @@ _LITELLM_UNSUPPORTED_KEYS = [
38
38
 
39
39
  @dataclass
40
40
  class LiteLLMRouterClientConfig:
41
- """Parses configuration for a LiteLLM Router client. The configuration is expected
42
- to be in the following format:
41
+ """Parses configuration for a LiteLLM Router client.
42
+
43
+ The configuration is expected to be in the following format:
43
44
 
44
45
  {
45
46
  "id": "model_group_id",
@@ -64,8 +64,9 @@ FORBIDDEN_KEYS = [
64
64
 
65
65
  @dataclass
66
66
  class OpenAIClientConfig:
67
- """Parses configuration for Azure OpenAI client, resolves aliases and
68
- raises deprecation warnings.
67
+ """Parses configuration for OpenAI client.
68
+
69
+ Resolves aliases and raises deprecation warnings.
69
70
 
70
71
  Raises:
71
72
  ValueError: Raised in cases of invalid configuration:
@@ -118,8 +119,7 @@ class OpenAIClientConfig:
118
119
 
119
120
  @classmethod
120
121
  def from_dict(cls, config: dict) -> OpenAIClientConfig:
121
- """
122
- Initializes a dataclass from the passed config.
122
+ """Initializes a dataclass from the passed config.
123
123
 
124
124
  Args:
125
125
  config: (dict) The config from which to initialize.
@@ -168,9 +168,7 @@ class OpenAIClientConfig:
168
168
 
169
169
 
170
170
  def is_openai_config(config: dict) -> bool:
171
- """Check whether the configuration is meant to configure
172
- an OpenAI client.
173
- """
171
+ """Check whether the configuration is meant to configure an OpenAI client."""
174
172
  # Process the config to handle all the aliases
175
173
  config = OpenAIClientConfig.resolve_config_aliases(config)
176
174
 
@@ -22,8 +22,9 @@ structlogger = structlog.get_logger()
22
22
 
23
23
  @dataclass
24
24
  class RasaLLMClientConfig:
25
- """Parses configuration for a Rasa Hosted LiteLLM client,
26
- checks required keys present.
25
+ """Parses configuration for a Rasa Hosted LiteLLM client.
26
+
27
+ Checks required keys present.
27
28
 
28
29
  Raises:
29
30
  ValueError: Raised in cases of invalid configuration:
@@ -40,8 +41,7 @@ class RasaLLMClientConfig:
40
41
 
41
42
  @classmethod
42
43
  def from_dict(cls, config: dict) -> RasaLLMClientConfig:
43
- """
44
- Initializes a dataclass from the passed config.
44
+ """Initializes a dataclass from the passed config.
45
45
 
46
46
  Args:
47
47
  config: (dict) The config from which to initialize.
@@ -61,8 +61,9 @@ FORBIDDEN_KEYS = [
61
61
 
62
62
  @dataclass
63
63
  class SelfHostedLLMClientConfig:
64
- """Parses configuration for Self Hosted LiteLLM client, resolves aliases and
65
- raises deprecation warnings.
64
+ """Parses configuration for Self Hosted LiteLLM client.
65
+
66
+ Resolves aliases and raises deprecation warnings.
66
67
 
67
68
  Raises:
68
69
  ValueError: Raised in cases of invalid configuration:
@@ -116,8 +117,7 @@ class SelfHostedLLMClientConfig:
116
117
 
117
118
  @classmethod
118
119
  def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
119
- """
120
- Initializes a dataclass from the passed config.
120
+ """Initializes a dataclass from the passed config.
121
121
 
122
122
  Args:
123
123
  config: (dict) The config from which to initialize.
@@ -1,12 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import logging
4
5
  from abc import abstractmethod
5
- from typing import Any, Dict, List, Union, cast
6
+ from typing import Any, Dict, List, NoReturn, Union, cast
6
7
 
7
8
  import structlog
8
9
  from litellm import acompletion, completion, validate_environment
9
10
 
11
+ from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
10
12
  from rasa.shared.constants import (
11
13
  _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
12
14
  API_BASE_CONFIG_KEY,
@@ -57,26 +59,24 @@ class _BaseLiteLLMClient:
57
59
  @property
58
60
  @abstractmethod
59
61
  def config(self) -> dict:
60
- """Returns the configuration for that the llm client
61
- in dictionary form.
62
- """
62
+ """Returns the configuration for that the llm client in dictionary form."""
63
63
  pass
64
64
 
65
65
  @property
66
66
  @abstractmethod
67
67
  def _litellm_model_name(self) -> str:
68
- """Returns the value of LiteLLM's model parameter to be used in
69
- completion/acompletion in LiteLLM format:
68
+ """Returns the value of LiteLLM's model parameter.
70
69
 
70
+ To be used in completion/acompletion in LiteLLM format:
71
71
  <provider>/<model or deployment name>
72
72
  """
73
73
  pass
74
74
 
75
75
  @property
76
76
  def _litellm_extra_parameters(self) -> Dict[str, Any]:
77
- """Returns a dictionary of extra parameters which include model
78
- parameters as well as LiteLLM specific input parameters.
77
+ """Returns a dictionary of extra parameters.
79
78
 
79
+ Includes model parameters as well as LiteLLM specific input parameters.
80
80
  By default, this returns an empty dictionary (no extra parameters).
81
81
  """
82
82
  return {}
@@ -96,8 +96,9 @@ class _BaseLiteLLMClient:
96
96
  }
97
97
 
98
98
  def validate_client_setup(self) -> None:
99
- """Perform client validation. By default only environment variables
100
- are validated.
99
+ """Perform client validation.
100
+
101
+ By default only environment variables are validated.
101
102
 
102
103
  Raises:
103
104
  ProviderClientValidationError if validation fails.
@@ -188,10 +189,17 @@ class _BaseLiteLLMClient:
188
189
  arguments = cast(
189
190
  Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
190
191
  )
191
- response = await acompletion(
192
- messages=formatted_messages, **{**arguments, **kwargs}
192
+
193
+ timeout = self._litellm_extra_parameters.get(
194
+ "timeout", DEFAULT_REQUEST_TIMEOUT
195
+ )
196
+ response = await asyncio.wait_for(
197
+ acompletion(messages=formatted_messages, **{**arguments, **kwargs}),
198
+ timeout=timeout,
193
199
  )
194
200
  return self._format_response(response)
201
+ except asyncio.TimeoutError:
202
+ self._handle_timeout_error()
195
203
  except Exception as e:
196
204
  message = ""
197
205
  from rasa.shared.providers.llm.self_hosted_llm_client import (
@@ -211,6 +219,25 @@ class _BaseLiteLLMClient:
211
219
  )
212
220
  raise ProviderClientAPIException(e, message) from e
213
221
 
222
+ def _handle_timeout_error(self) -> NoReturn:
223
+ """Handle asyncio.TimeoutError and raise ProviderClientAPIException.
224
+
225
+ Raises:
226
+ ProviderClientAPIException: Always raised with formatted timeout error.
227
+ """
228
+ timeout = self._litellm_extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
229
+ error_message = (
230
+ f"APITimeoutError - Request timed out. Error_str: "
231
+ f"Request timed out. - timeout value={timeout:.6f}, "
232
+ f"time taken={timeout:.6f} seconds"
233
+ )
234
+ # nosemgrep: semgrep.rules.pii-positional-arguments-in-logging
235
+ # Error message contains only numeric timeout values, not PII
236
+ structlogger.error(
237
+ f"{self.__class__.__name__.lower()}.llm.timeout", error=error_message
238
+ )
239
+ raise ProviderClientAPIException(asyncio.TimeoutError(error_message)) from None
240
+
214
241
  def _get_formatted_messages(
215
242
  self, messages: Union[List[dict], List[str], str]
216
243
  ) -> List[Dict[str, str]]:
@@ -312,8 +339,9 @@ class _BaseLiteLLMClient:
312
339
 
313
340
  @staticmethod
314
341
  def _ensure_certificates() -> None:
315
- """Configures SSL certificates for LiteLLM. This method is invoked during
316
- client initialization.
342
+ """Configures SSL certificates for LiteLLM.
343
+
344
+ This method is invoked during client initialization.
317
345
 
318
346
  LiteLLM may utilize `openai` clients or other providers that require
319
347
  SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import logging
4
5
  from typing import Any, Dict, List, Union
5
6
 
6
7
  import structlog
7
8
 
9
+ from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
8
10
  from rasa.shared.exceptions import ProviderClientAPIException
9
11
  from rasa.shared.providers._configs.litellm_router_client_config import (
10
12
  LiteLLMRouterClientConfig,
@@ -79,13 +81,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
79
81
 
80
82
  @suppress_logs(log_level=logging.WARNING)
81
83
  def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
82
- """
83
- Synchronously generate completions for given prompt.
84
+ """Synchronously generate completions for given prompt.
84
85
 
85
86
  Args:
86
87
  prompt: Prompt to generate the completion for.
88
+
87
89
  Returns:
88
90
  List of message completions.
91
+
89
92
  Raises:
90
93
  ProviderClientAPIException: If the API request fails.
91
94
  """
@@ -103,21 +106,30 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
103
106
 
104
107
  @suppress_logs(log_level=logging.WARNING)
105
108
  async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
106
- """
107
- Asynchronously generate completions for given prompt.
109
+ """Asynchronously generate completions for given prompt.
108
110
 
109
111
  Args:
110
112
  prompt: Prompt to generate the completion for.
113
+
111
114
  Returns:
112
115
  List of message completions.
116
+
113
117
  Raises:
114
118
  ProviderClientAPIException: If the API request fails.
115
119
  """
116
120
  try:
117
- response = await self.router_client.atext_completion(
118
- prompt=prompt, **self._completion_fn_args
121
+ timeout = self._litellm_extra_parameters.get(
122
+ "timeout", DEFAULT_REQUEST_TIMEOUT
123
+ )
124
+ response = await asyncio.wait_for(
125
+ self.router_client.atext_completion(
126
+ prompt=prompt, **self._completion_fn_args
127
+ ),
128
+ timeout=timeout,
119
129
  )
120
130
  return self._format_text_completion_response(response)
131
+ except asyncio.TimeoutError:
132
+ self._handle_timeout_error()
121
133
  except Exception as e:
122
134
  raise ProviderClientAPIException(e)
123
135
 
@@ -125,8 +137,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
125
137
  def completion(
126
138
  self, messages: Union[List[dict], List[str], str], **kwargs: Any
127
139
  ) -> LLMResponse:
128
- """
129
- Synchronously generate completions for given list of messages.
140
+ """Synchronously generate completions for given list of messages.
130
141
 
131
142
  Method overrides the base class method to call the appropriate
132
143
  completion method based on the configuration. If the chat completions
@@ -143,15 +154,17 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
143
154
  as a user message.
144
155
  - a single message as a string which will be formatted as user message.
145
156
  **kwargs: Additional parameters to pass to the completion call.
157
+
146
158
  Returns:
147
159
  List of message completions.
160
+
148
161
  Raises:
149
162
  ProviderClientAPIException: If the API request fails.
150
163
  """
151
164
  if not self._use_chat_completions_endpoint:
152
165
  return self._text_completion(messages)
153
166
  try:
154
- formatted_messages = self._format_messages(messages)
167
+ formatted_messages = self._get_formatted_messages(messages)
155
168
  response = self.router_client.completion(
156
169
  messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
157
170
  )
@@ -163,8 +176,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
163
176
  async def acompletion(
164
177
  self, messages: Union[List[dict], List[str], str], **kwargs: Any
165
178
  ) -> LLMResponse:
166
- """
167
- Asynchronously generate completions for given list of messages.
179
+ """Asynchronously generate completions for given list of messages.
168
180
 
169
181
  Method overrides the base class method to call the appropriate
170
182
  completion method based on the configuration. If the chat completions
@@ -181,28 +193,39 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
181
193
  as a user message.
182
194
  - a single message as a string which will be formatted as user message.
183
195
  **kwargs: Additional parameters to pass to the completion call.
196
+
184
197
  Returns:
185
198
  List of message completions.
199
+
186
200
  Raises:
187
201
  ProviderClientAPIException: If the API request fails.
188
202
  """
189
203
  if not self._use_chat_completions_endpoint:
190
204
  return await self._atext_completion(messages)
191
205
  try:
192
- formatted_messages = self._format_messages(messages)
193
- response = await self.router_client.acompletion(
194
- messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
206
+ formatted_messages = self._get_formatted_messages(messages)
207
+ timeout = self._litellm_extra_parameters.get(
208
+ "timeout", DEFAULT_REQUEST_TIMEOUT
209
+ )
210
+ response = await asyncio.wait_for(
211
+ self.router_client.acompletion(
212
+ messages=formatted_messages,
213
+ **{**self._completion_fn_args, **kwargs},
214
+ ),
215
+ timeout=timeout,
195
216
  )
196
217
  return self._format_response(response)
218
+ except asyncio.TimeoutError:
219
+ self._handle_timeout_error()
197
220
  except Exception as e:
198
221
  raise ProviderClientAPIException(e)
199
222
 
200
223
  @property
201
224
  def _completion_fn_args(self) -> Dict[str, Any]:
202
- """Returns the completion arguments for invoking a call through
203
- LiteLLM's completion functions.
204
- """
225
+ """Returns the completion arguments.
205
226
 
227
+ For invoking a call through LiteLLM's completion functions.
228
+ """
206
229
  return {
207
230
  **self._litellm_extra_parameters,
208
231
  LITE_LLM_MODEL_FIELD: self.model_group_id,