rasa-pro 3.14.1__py3-none-any.whl → 3.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.
Files changed (44)
  1. rasa/constants.py +1 -0
  2. rasa/core/actions/action_clean_stack.py +32 -0
  3. rasa/core/actions/constants.py +4 -0
  4. rasa/core/actions/custom_action_executor.py +70 -12
  5. rasa/core/actions/grpc_custom_action_executor.py +41 -2
  6. rasa/core/actions/http_custom_action_executor.py +49 -25
  7. rasa/core/channels/voice_stream/browser_audio.py +3 -3
  8. rasa/core/channels/voice_stream/voice_channel.py +27 -17
  9. rasa/core/config/credentials.py +3 -3
  10. rasa/core/policies/flows/flow_executor.py +49 -29
  11. rasa/core/run.py +21 -5
  12. rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
  13. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
  14. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
  15. rasa/dialogue_understanding/processor/command_processor.py +13 -7
  16. rasa/e2e_test/e2e_config.py +4 -3
  17. rasa/engine/recipes/default_components.py +16 -6
  18. rasa/graph_components/validators/default_recipe_validator.py +10 -4
  19. rasa/nlu/classifiers/diet_classifier.py +2 -0
  20. rasa/shared/core/flows/flow.py +8 -2
  21. rasa/shared/core/slots.py +55 -24
  22. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
  23. rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
  24. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
  25. rasa/shared/providers/_configs/openai_client_config.py +5 -7
  26. rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
  27. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
  28. rasa/shared/providers/llm/_base_litellm_client.py +42 -14
  29. rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
  30. rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
  31. rasa/shared/utils/common.py +9 -1
  32. rasa/shared/utils/configs.py +5 -8
  33. rasa/utils/common.py +9 -0
  34. rasa/utils/endpoints.py +6 -0
  35. rasa/utils/installation_utils.py +111 -0
  36. rasa/utils/tensorflow/callback.py +2 -0
  37. rasa/utils/tensorflow/models.py +3 -0
  38. rasa/utils/train_utils.py +2 -0
  39. rasa/version.py +1 -1
  40. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +2 -2
  41. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +44 -43
  42. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
  43. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
  44. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0
rasa/dialogue_understanding/generator/llm_based_command_generator.py CHANGED
@@ -62,7 +62,9 @@ structlogger = structlog.get_logger()
 class LLMBasedCommandGenerator(
     LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
 ):
-    """An abstract class defining interface and common functionality
+    """This class provides common functionality for all LLM-based command generators.
+
+    An abstract class defining interface and common functionality
     of an LLM-based command generators.
     """
 
@@ -174,8 +176,9 @@ class LLMBasedCommandGenerator(
     def train(
         self, training_data: TrainingData, flows: FlowsList, domain: Domain
     ) -> Resource:
-        """Train the llm based command generator. Stores all flows into a vector
-        store.
+        """Trains the LLM-based command generator and prepares flow retrieval data.
+
+        Stores all flows into a vector store.
         """
         self.perform_llm_health_check(
             self.config.get(LLM_CONFIG_KEY),
rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py CHANGED
@@ -168,6 +168,20 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
         if prompt_template is not None:
             return prompt_template
 
+        # Try to load the template from the given path or fallback to the default for
+        # the component.
+        custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
+        if custom_prompt_template_path is not None:
+            custom_prompt_template = get_prompt_template(
+                custom_prompt_template_path,
+                None,  # Default will be based on the model
+                log_source_component=log_source_component,
+                log_source_method=log_context,
+            )
+            if custom_prompt_template is not None:
+                return custom_prompt_template
+
+        # Fallback to the default prompt template based on the model.
         default_command_prompt_template = get_default_prompt_template_based_on_model(
             llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
             model_prompt_mapping=cls.get_model_prompt_mapper(),
@@ -177,10 +191,4 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
             log_source_method=log_context,
         )
 
-        # Return the prompt template either from the config or the default prompt.
-        return get_prompt_template(
-            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
-            default_command_prompt_template,
-            log_source_component=log_source_component,
-            log_source_method=log_context,
-        )
+        return default_command_prompt_template
rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py CHANGED
@@ -165,7 +165,20 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
         if prompt_template is not None:
             return prompt_template
 
-        # Get the default prompt template based on the model name.
+        # Try to load the template from the given path or fallback to the default for
+        # the component.
+        custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
+        if custom_prompt_template_path is not None:
+            custom_prompt_template = get_prompt_template(
+                custom_prompt_template_path,
+                None,  # Default will be based on the model
+                log_source_component=log_source_component,
+                log_source_method=log_context,
+            )
+            if custom_prompt_template is not None:
+                return custom_prompt_template
+
+        # Fallback to the default prompt template based on the model.
         default_command_prompt_template = get_default_prompt_template_based_on_model(
             llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
             model_prompt_mapping=cls.get_model_prompt_mapper(),
@@ -175,10 +188,4 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
             log_source_method=log_context,
         )
 
-        # Return the prompt template either from the config or the default prompt.
-        return get_prompt_template(
-            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
-            default_command_prompt_template,
-            log_source_component=log_source_component,
-            log_source_method=log_context,
-        )
+        return default_command_prompt_template
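In both CompactLLMCommandGenerator and SearchReadyLLMCommandGenerator, template resolution now follows a three-step precedence: a template persisted with the trained model, then a user-configured template path, then the model-specific default. A minimal sketch of that precedence, with a hypothetical load_file callable standing in for Rasa's get_prompt_template:

from typing import Callable, Optional

def resolve_prompt_template(
    persisted_template: Optional[str],
    custom_template_path: Optional[str],
    default_template: str,
    load_file: Callable[[str], Optional[str]],
) -> str:
    """Sketch of the new precedence: persisted > custom path > model default."""
    # 1. A template stored alongside the trained model wins.
    if persisted_template is not None:
        return persisted_template
    # 2. Otherwise try the path configured under the prompt template config key.
    if custom_template_path is not None:
        custom = load_file(custom_template_path)
        if custom is not None:
            return custom
    # 3. Fall back to the default chosen for the configured model.
    return default_template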
rasa/dialogue_understanding/processor/command_processor.py CHANGED
@@ -499,10 +499,22 @@ def clean_up_commands(
         else:
             clean_commands.append(command)
 
+    # ensure that there is only one command of a certain command type
+    clean_commands = ensure_max_number_of_command_type(
+        clean_commands, CannotHandleCommand, 1
+    )
+    clean_commands = ensure_max_number_of_command_type(
+        clean_commands, RepeatBotMessagesCommand, 1
+    )
+    clean_commands = ensure_max_number_of_command_type(
+        clean_commands, ChitChatAnswerCommand, 1
+    )
+
     # Replace CannotHandleCommands with ContinueAgentCommand when an agent is active
     # to keep the agent running, but preserve chitchat
     clean_commands = _replace_cannot_handle_with_continue_agent(clean_commands, tracker)
 
+    # filter out cannot handle commands if there are other commands present
     # when coexistence is enabled, by default there will be a SetSlotCommand
     # for the ROUTE_TO_CALM_SLOT slot.
     if tracker.has_coexistence_routing_slot and len(clean_commands) > 2:
@@ -510,12 +522,6 @@ def clean_up_commands(
     elif not tracker.has_coexistence_routing_slot and len(clean_commands) > 1:
         clean_commands = filter_cannot_handle_command(clean_commands)
 
-    clean_commands = ensure_max_number_of_command_type(
-        clean_commands, RepeatBotMessagesCommand, 1
-    )
-    clean_commands = ensure_max_number_of_command_type(
-        clean_commands, ContinueAgentCommand, 1
-    )
    structlogger.debug(
        "command_processor.clean_up_commands.final_commands",
        command=clean_commands,
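Net effect on clean_up_commands: deduplication now runs earlier (before the CannotHandle replacement) and covers CannotHandleCommand, RepeatBotMessagesCommand, and ChitChatAnswerCommand. The diff only shows ensure_max_number_of_command_type's call sites; a plausible sketch of such a helper, assuming it keeps the first occurrences in order:

from typing import List, Type

def ensure_max_number_of_command_type(
    commands: List[object], command_type: Type, max_count: int
) -> List[object]:
    """Keep at most max_count commands of command_type, preserving order."""
    kept: List[object] = []
    seen = 0
    for command in commands:
        if isinstance(command, command_type):
            seen += 1
            if seen > max_count:
                continue  # drop surplus commands of this type
        kept.append(command)
    return kept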
@@ -580,7 +586,7 @@ def clean_up_start_flow_command(
         # drop a start flow command if the starting flow is equal
         # to the currently active flow
         structlogger.debug(
-            "command_processor.clean_up_commands." "skip_command_flow_already_active",
+            "command_processor.clean_up_commands.skip_command_flow_already_active",
             command=command,
         )
         return clean_commands
rasa/e2e_test/e2e_config.py CHANGED
@@ -72,9 +72,10 @@ class LLMJudgeConfig(BaseModel):
 
         llm_config = resolve_model_client_config(llm_config)
         llm_config, llm_extra_parameters = cls.extract_attributes(llm_config)
-        llm_config = combine_custom_and_default_config(
-            llm_config, cls.get_default_llm_config()
-        )
+        if not llm_config:
+            llm_config = combine_custom_and_default_config(
+                llm_config, cls.get_default_llm_config()
+            )
         embeddings_config = resolve_model_client_config(embeddings)
         embeddings_config, embeddings_extra_parameters = cls.extract_attributes(
             embeddings_config
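The judge's default LLM config is now merged in only when no llm config was supplied at all; a partially specified config is taken as-is rather than padded with defaults. Sketched behavior, with a plain dict merge standing in for combine_custom_and_default_config and illustrative default values:

DEFAULT_LLM_CONFIG = {"provider": "openai", "model": "gpt-4"}  # illustrative defaults

def resolve_judge_llm_config(llm_config: dict) -> dict:
    # 3.14.2 behavior: defaults apply only when the user config is empty/missing.
    if not llm_config:
        return {**DEFAULT_LLM_CONFIG, **llm_config}
    return llm_config

assert resolve_judge_llm_config({}) == DEFAULT_LLM_CONFIG
assert resolve_judge_llm_config({"model": "my-model"}) == {"model": "my-model"}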
rasa/engine/recipes/default_components.py CHANGED
@@ -27,22 +27,32 @@ from rasa.shared.utils.common import conditional_import
 
 # components dependent on tensorflow
 TEDPolicy, TED_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.ted_policy", "TEDPolicy"
+    "rasa.core.policies.ted_policy", "TEDPolicy", check_installation_setup=True
 )
 UnexpecTEDIntentPolicy, UNEXPECTED_INTENT_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.unexpected_intent_policy", "UnexpecTEDIntentPolicy"
+    "rasa.core.policies.unexpected_intent_policy",
+    "UnexpecTEDIntentPolicy",
+    check_installation_setup=True,
 )
 DIETClassifier, DIET_CLASSIFIER_AVAILABLE = conditional_import(
-    "rasa.nlu.classifiers.diet_classifier", "DIETClassifier"
+    "rasa.nlu.classifiers.diet_classifier",
+    "DIETClassifier",
+    check_installation_setup=True,
 )
 ConveRTFeaturizer, CONVERT_FEATURIZER_AVAILABLE = conditional_import(
-    "rasa.nlu.featurizers.dense_featurizer.convert_featurizer", "ConveRTFeaturizer"
+    "rasa.nlu.featurizers.dense_featurizer.convert_featurizer",
+    "ConveRTFeaturizer",
+    check_installation_setup=True,
 )
 LanguageModelFeaturizer, LANGUAGE_MODEL_FEATURIZER_AVAILABLE = conditional_import(
-    "rasa.nlu.featurizers.dense_featurizer.lm_featurizer", "LanguageModelFeaturizer"
+    "rasa.nlu.featurizers.dense_featurizer.lm_featurizer",
+    "LanguageModelFeaturizer",
+    check_installation_setup=True,
 )
 ResponseSelector, RESPONSE_SELECTOR_AVAILABLE = conditional_import(
-    "rasa.nlu.selectors.response_selector", "ResponseSelector"
+    "rasa.nlu.selectors.response_selector",
+    "ResponseSelector",
+    check_installation_setup=True,
 )
 
 # components dependent on skops
rasa/graph_components/validators/default_recipe_validator.py CHANGED
@@ -40,16 +40,22 @@ from rasa.shared.utils.common import conditional_import
 
 # Conditional imports for TensorFlow-dependent components
 TEDPolicy, TED_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.ted_policy", "TEDPolicy"
+    "rasa.core.policies.ted_policy", "TEDPolicy", check_installation_setup=True
 )
 UnexpecTEDIntentPolicy, UNEXPECTED_INTENT_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.unexpected_intent_policy", "UnexpecTEDIntentPolicy"
+    "rasa.core.policies.unexpected_intent_policy",
+    "UnexpecTEDIntentPolicy",
+    check_installation_setup=True,
 )
 DIETClassifier, DIET_CLASSIFIER_AVAILABLE = conditional_import(
-    "rasa.nlu.classifiers.diet_classifier", "DIETClassifier"
+    "rasa.nlu.classifiers.diet_classifier",
+    "DIETClassifier",
+    check_installation_setup=True,
 )
 ResponseSelector, RESPONSE_SELECTOR_AVAILABLE = conditional_import(
-    "rasa.nlu.selectors.response_selector", "ResponseSelector"
+    "rasa.nlu.selectors.response_selector",
+    "ResponseSelector",
+    check_installation_setup=True,
 )
 
 # Conditional imports for nlu components requiring other dependencies than tensorflow
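Both modules now pass check_installation_setup=True to conditional_import, presumably wiring in the checks from the new rasa/utils/installation_utils.py module. The diff shows only the call sites; a plausible sketch of such a helper (the real implementation may differ):

import importlib
from typing import Any, Optional, Tuple

def conditional_import(
    module_path: str, attribute: str, check_installation_setup: bool = False
) -> Tuple[Optional[Any], bool]:
    """Import attribute from module_path, returning (object, is_available)."""
    try:
        if check_installation_setup:
            # Hypothetical wiring: validate the environment (e.g. TensorFlow extras)
            # via the new installation_utils module before importing heavy deps.
            from rasa.utils.installation_utils import check_for_installation_issues
            check_for_installation_issues()
        module = importlib.import_module(module_path)
        return getattr(module, attribute), True
    except ImportError:
        return None, False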
rasa/nlu/classifiers/diet_classifier.py CHANGED
@@ -9,9 +9,11 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type, TypeVar, Union
 import numpy as np
 import scipy.sparse
 
+from rasa.utils.installation_utils import check_for_installation_issues
 from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
 
 if TENSORFLOW_AVAILABLE:
+    check_for_installation_issues()
     import tensorflow as tf
 else:
     tf = None
rasa/shared/core/flows/flow.py CHANGED
@@ -322,9 +322,15 @@ class Flow:
 
     def get_collect_steps(self) -> List[CollectInformationFlowStep]:
         """Return all CollectInformationFlowSteps in the flow."""
-        collect_steps = []
+        collect_steps: List[CollectInformationFlowStep] = []
         for step in self.steps_with_calls_resolved:
-            if isinstance(step, CollectInformationFlowStep):
+            # Only add collect steps that are not already in the list.
+            # This is to avoid returning duplicate collect steps from called flows
+            # in case the called flow is called multiple times.
+            if (
+                isinstance(step, CollectInformationFlowStep)
+                and step not in collect_steps
+            ):
                 collect_steps.append(step)
         return collect_steps
 
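When a flow calls the same subflow more than once, steps_with_calls_resolved yields that subflow's collect steps once per call; the new membership check collapses those duplicates. The same order-preserving deduplication in isolation:

def dedupe_preserving_order(steps: list) -> list:
    """Order-preserving dedup, equivalent to the `step not in collect_steps` check."""
    unique: list = []
    for step in steps:
        # Linear membership scan; fine for the small step counts found in flows.
        if step not in unique:
            unique.append(step)
    return unique

assert dedupe_preserving_order(["a", "b", "a", "c"]) == ["a", "b", "c"]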
rasa/shared/core/slots.py CHANGED
@@ -355,8 +355,8 @@ class FloatSlot(Slot):
         mappings: List[Dict[Text, Any]],
         initial_value: Optional[float] = None,
         value_reset_delay: Optional[int] = None,
-        max_value: float = 1.0,
-        min_value: float = 0.0,
+        max_value: Optional[float] = None,
+        min_value: Optional[float] = None,
         influence_conversation: bool = True,
         is_builtin: bool = False,
         shared_for_coexistence: bool = False,
@@ -380,32 +380,24 @@ class FloatSlot(Slot):
             filled_by=filled_by,
             validation=validation,
         )
+        self.validate_min_max_range(min_value, max_value)
+
         self.max_value = max_value
         self.min_value = min_value
 
-        if min_value >= max_value:
-            raise InvalidSlotConfigError(
-                "Float slot ('{}') created with an invalid range "
-                "using min ({}) and max ({}) values. Make sure "
-                "min is smaller than max."
-                "".format(self.name, self.min_value, self.max_value)
-            )
-
-        if initial_value is not None and not (min_value <= initial_value <= max_value):
-            rasa.shared.utils.io.raise_warning(
-                f"Float slot ('{self.name}') created with an initial value "
-                f"{self.value}. This value is outside of the configured min "
-                f"({self.min_value}) and max ({self.max_value}) values."
-            )
-
     def _as_feature(self) -> List[float]:
+        # set default min and max values used in prior releases
+        # to prevent regressions for existing models
+        min_value = self.min_value or 0.0
+        max_value = self.max_value or 1.0
+
         try:
-            capped_value = max(self.min_value, min(self.max_value, float(self.value)))
-            if abs(self.max_value - self.min_value) > 0:
-                covered_range = abs(self.max_value - self.min_value)
+            capped_value = max(min_value, min(max_value, float(self.value)))
+            if abs(max_value - min_value) > 0:
+                covered_range = abs(max_value - min_value)
             else:
                 covered_range = 1
-            return [1.0, (capped_value - self.min_value) / covered_range]
+            return [1.0, (capped_value - min_value) / covered_range]
         except (TypeError, ValueError):
             return [0.0, 0.0]
 
@@ -424,13 +416,52 @@ class FloatSlot(Slot):
         return value
 
     def is_valid_value(self, value: Any) -> bool:
-        """Checks if the slot contains the value."""
-        # check that coerced type is float
-        return value is None or isinstance(self.coerce_value(value), float)
+        """Checks if the slot value is valid."""
+        if value is None:
+            return True
+
+        if not isinstance(self.coerce_value(value), float):
+            return False
+
+        if (
+            self.min_value is not None
+            and self.max_value is not None
+            and not (self.min_value <= value <= self.max_value)
+        ):
+            return False
+
+        return True
 
     def _feature_dimensionality(self) -> int:
         return len(self.as_feature())
 
+    def validate_min_max_range(
+        self, min_value: Optional[float], max_value: Optional[float]
+    ) -> None:
+        """Validates the min-max range for the slot.
+
+        Raises:
+            InvalidSlotConfigError, if the min-max range is invalid.
+        """
+        if min_value is not None and max_value is not None and min_value >= max_value:
+            raise InvalidSlotConfigError(
+                f"Float slot ('{self.name}') created with an invalid range "
+                f"using min ({min_value}) and max ({max_value}) values. Make sure "
+                f"min is smaller than max."
+            )
+
+        if (
+            self.initial_value is not None
+            and min_value is not None
+            and max_value is not None
+            and not (min_value <= self.initial_value <= max_value)
+        ):
+            raise InvalidSlotConfigError(
+                f"Float slot ('{self.name}') created with an initial value "
+                f"{self.initial_value}. This value is outside of the configured min "
+                f"({min_value}) and max ({max_value}) values."
+            )
+
 
 class BooleanSlot(Slot):
     """A slot storing a truth value."""
rasa/shared/providers/_configs/azure_openai_client_config.py CHANGED
@@ -167,8 +167,9 @@ class OAuthConfigWrapper(OAuth, BaseModel):
 
 @dataclass
 class AzureOpenAIClientConfig:
-    """Parses configuration for Azure OpenAI client, resolves aliases and
-    raises deprecation warnings.
+    """Parses configuration for Azure OpenAI client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -301,9 +302,7 @@ class AzureOpenAIClientConfig:
 
 
 def is_azure_openai_config(config: dict) -> bool:
-    """Check whether the configuration is meant to configure
-    an Azure OpenAI client.
-    """
+    """Check whether the configuration is meant to configure an Azure OpenAI client."""
     # Resolve any aliases that are specific to Azure OpenAI configuration
     config = AzureOpenAIClientConfig.resolve_config_aliases(config)
 
rasa/shared/providers/_configs/default_litellm_client_config.py CHANGED
@@ -40,8 +40,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class DefaultLiteLLMClientConfig:
-    """Parses configuration for default LiteLLM client, resolves aliases and
-    raises deprecation warnings.
+    """Parses configuration for default LiteLLM client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -72,8 +73,7 @@ class DefaultLiteLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
rasa/shared/providers/_configs/litellm_router_client_config.py CHANGED
@@ -38,8 +38,9 @@ _LITELLM_UNSUPPORTED_KEYS = [
 
 @dataclass
 class LiteLLMRouterClientConfig:
-    """Parses configuration for a LiteLLM Router client. The configuration is expected
-    to be in the following format:
+    """Parses configuration for a LiteLLM Router client.
+
+    The configuration is expected to be in the following format:
 
     {
         "id": "model_group_id",
rasa/shared/providers/_configs/openai_client_config.py CHANGED
@@ -64,8 +64,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class OpenAIClientConfig:
-    """Parses configuration for Azure OpenAI client, resolves aliases and
-    raises deprecation warnings.
+    """Parses configuration for OpenAI client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -118,8 +119,7 @@ class OpenAIClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> OpenAIClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
@@ -168,9 +168,7 @@ class OpenAIClientConfig:
 
 
 def is_openai_config(config: dict) -> bool:
-    """Check whether the configuration is meant to configure
-    an OpenAI client.
-    """
+    """Check whether the configuration is meant to configure an OpenAI client."""
     # Process the config to handle all the aliases
     config = OpenAIClientConfig.resolve_config_aliases(config)
 
rasa/shared/providers/_configs/rasa_llm_client_config.py CHANGED
@@ -22,8 +22,9 @@ structlogger = structlog.get_logger()
 
 @dataclass
 class RasaLLMClientConfig:
-    """Parses configuration for a Rasa Hosted LiteLLM client,
-    checks required keys present.
+    """Parses configuration for a Rasa Hosted LiteLLM client.
+
+    Checks required keys present.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -40,8 +41,7 @@ class RasaLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> RasaLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
rasa/shared/providers/_configs/self_hosted_llm_client_config.py CHANGED
@@ -61,8 +61,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class SelfHostedLLMClientConfig:
-    """Parses configuration for Self Hosted LiteLLM client, resolves aliases and
-    raises deprecation warnings.
+    """Parses configuration for Self Hosted LiteLLM client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -116,8 +117,7 @@ class SelfHostedLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
rasa/shared/providers/llm/_base_litellm_client.py CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, List, NoReturn, Union, cast
 
 import structlog
 from litellm import acompletion, completion, validate_environment
 
+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.constants import (
     _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
     API_BASE_CONFIG_KEY,
@@ -57,26 +59,24 @@ class _BaseLiteLLMClient:
     @property
     @abstractmethod
     def config(self) -> dict:
-        """Returns the configuration for that the llm client
-        in dictionary form.
-        """
+        """Returns the configuration for that the llm client in dictionary form."""
         pass
 
     @property
     @abstractmethod
     def _litellm_model_name(self) -> str:
-        """Returns the value of LiteLLM's model parameter to be used in
-        completion/acompletion in LiteLLM format:
+        """Returns the value of LiteLLM's model parameter.
 
+        To be used in completion/acompletion in LiteLLM format:
         <provider>/<model or deployment name>
         """
         pass
 
     @property
     def _litellm_extra_parameters(self) -> Dict[str, Any]:
-        """Returns a dictionary of extra parameters which include model
-        parameters as well as LiteLLM specific input parameters.
+        """Returns a dictionary of extra parameters.
 
+        Includes model parameters as well as LiteLLM specific input parameters.
         By default, this returns an empty dictionary (no extra parameters).
         """
         return {}
@@ -96,8 +96,9 @@ class _BaseLiteLLMClient:
         }
 
     def validate_client_setup(self) -> None:
-        """Perform client validation. By default only environment variables
-        are validated.
+        """Perform client validation.
+
+        By default only environment variables are validated.
 
         Raises:
             ProviderClientValidationError if validation fails.
@@ -188,10 +189,17 @@ class _BaseLiteLLMClient:
             arguments = cast(
                 Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
             )
-            response = await acompletion(
-                messages=formatted_messages, **{**arguments, **kwargs}
+
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                acompletion(messages=formatted_messages, **{**arguments, **kwargs}),
+                timeout=timeout,
             )
             return self._format_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             message = ""
             from rasa.shared.providers.llm.self_hosted_llm_client import (
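The async completion call is now bounded by asyncio.wait_for rather than relying only on LiteLLM's internal timeout handling, so a hung request surfaces as asyncio.TimeoutError after the configured timeout. The pattern in isolation, with a stand-in coroutine:

import asyncio

DEFAULT_REQUEST_TIMEOUT = 60  # stand-in value; the real constant lives in rasa.core.constants

async def slow_completion() -> str:
    await asyncio.sleep(120)  # stand-in for a hanging LLM call
    return "response"

async def call_with_timeout(timeout: float = DEFAULT_REQUEST_TIMEOUT) -> str:
    try:
        return await asyncio.wait_for(slow_completion(), timeout=timeout)
    except asyncio.TimeoutError:
        # Mirrors _handle_timeout_error below: surface a clear, typed failure.
        raise RuntimeError(f"LLM request timed out after {timeout} seconds") from None

# asyncio.run(call_with_timeout(0.1))  # would raise RuntimeError after 0.1s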
@@ -211,6 +219,25 @@ class _BaseLiteLLMClient:
             )
             raise ProviderClientAPIException(e, message) from e
 
+    def _handle_timeout_error(self) -> NoReturn:
+        """Handle asyncio.TimeoutError and raise ProviderClientAPIException.
+
+        Raises:
+            ProviderClientAPIException: Always raised with formatted timeout error.
+        """
+        timeout = self._litellm_extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
+        error_message = (
+            f"APITimeoutError - Request timed out. Error_str: "
+            f"Request timed out. - timeout value={timeout:.6f}, "
+            f"time taken={timeout:.6f} seconds"
+        )
+        # nosemgrep: semgrep.rules.pii-positional-arguments-in-logging
+        # Error message contains only numeric timeout values, not PII
+        structlogger.error(
+            f"{self.__class__.__name__.lower()}.llm.timeout", error=error_message
+        )
+        raise ProviderClientAPIException(asyncio.TimeoutError(error_message)) from None
+
     def _get_formatted_messages(
         self, messages: Union[List[dict], List[str], str]
     ) -> List[Dict[str, str]]:
@@ -312,8 +339,9 @@ class _BaseLiteLLMClient:
 
     @staticmethod
     def _ensure_certificates() -> None:
-        """Configures SSL certificates for LiteLLM. This method is invoked during
-        client initialization.
+        """Configures SSL certificates for LiteLLM.
+
+        This method is invoked during client initialization.
 
         LiteLLM may utilize `openai` clients or other providers that require
         SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`