rasa-pro 3.14.1__py3-none-any.whl → 3.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rasa-pro might be problematic (see the registry listing for details).
- rasa/constants.py +1 -0
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/channels/voice_stream/browser_audio.py +3 -3
- rasa/core/channels/voice_stream/voice_channel.py +27 -17
- rasa/core/config/credentials.py +3 -3
- rasa/core/policies/flows/flow_executor.py +49 -29
- rasa/core/run.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/processor/command_processor.py +13 -7
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/engine/recipes/default_components.py +16 -6
- rasa/graph_components/validators/default_recipe_validator.py +10 -4
- rasa/nlu/classifiers/diet_classifier.py +2 -0
- rasa/shared/core/flows/flow.py +8 -2
- rasa/shared/core/slots.py +55 -24
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -14
- rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/common.py +9 -1
- rasa/shared/utils/configs.py +5 -8
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +6 -0
- rasa/utils/installation_utils.py +111 -0
- rasa/utils/tensorflow/callback.py +2 -0
- rasa/utils/tensorflow/models.py +3 -0
- rasa/utils/train_utils.py +2 -0
- rasa/version.py +1 -1
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +2 -2
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +44 -43
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0
rasa/dialogue_understanding/generator/llm_based_command_generator.py
CHANGED

@@ -62,7 +62,9 @@ structlogger = structlog.get_logger()
 class LLMBasedCommandGenerator(
     LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
 ):
-    """
+    """This class provides common functionality for all LLM-based command generators.
+
+    An abstract class defining interface and common functionality
     of an LLM-based command generators.
     """
 
@@ -174,8 +176,9 @@ class LLMBasedCommandGenerator(
     def train(
         self, training_data: TrainingData, flows: FlowsList, domain: Domain
     ) -> Resource:
-        """
-
+        """Trains the LLM-based command generator and prepares flow retrieval data.
+
+        Stores all flows into a vector store.
         """
         self.perform_llm_health_check(
             self.config.get(LLM_CONFIG_KEY),
rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py
CHANGED

@@ -168,6 +168,20 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
         if prompt_template is not None:
             return prompt_template
 
+        # Try to load the template from the given path or fallback to the default for
+        # the component.
+        custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
+        if custom_prompt_template_path is not None:
+            custom_prompt_template = get_prompt_template(
+                custom_prompt_template_path,
+                None,  # Default will be based on the model
+                log_source_component=log_source_component,
+                log_source_method=log_context,
+            )
+            if custom_prompt_template is not None:
+                return custom_prompt_template
+
+        # Fallback to the default prompt template based on the model.
         default_command_prompt_template = get_default_prompt_template_based_on_model(
             llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
             model_prompt_mapping=cls.get_model_prompt_mapper(),
@@ -177,10 +191,4 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
             log_source_method=log_context,
         )
 
-
-        return get_prompt_template(
-            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
-            default_command_prompt_template,
-            log_source_component=log_source_component,
-            log_source_method=log_context,
-        )
+        return default_command_prompt_template
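Net effect of the two hunks above (the same change is applied to SearchReadyLLMCommandGenerator below): the prompt template is now resolved in three explicit steps rather than through a single get_prompt_template call at the end. A minimal runnable sketch of that resolution order, assuming load_template_from_path and default_for_model as hypothetical stand-ins for get_prompt_template and get_default_prompt_template_based_on_model:

from typing import Callable, Optional

def load_template_from_path(path: str) -> Optional[str]:
    # Stand-in for get_prompt_template(path, None, ...): returns None on failure.
    try:
        with open(path, encoding="utf-8") as f:
            return f.read()
    except OSError:
        return None

def resolve_prompt_template(
    config: dict,
    default_for_model: Callable[[dict], str],
    persisted_template: Optional[str] = None,
) -> str:
    # 1. A template persisted with the trained model wins.
    if persisted_template is not None:
        return persisted_template
    # 2. Otherwise try the user-configured template path ("prompt_template"
    #    stands in for PROMPT_TEMPLATE_CONFIG_KEY).
    path = config.get("prompt_template")
    if path is not None:
        custom = load_template_from_path(path)
        if custom is not None:
            return custom
    # 3. Fall back to the default template for the configured model.
    return default_for_model(config)

print(resolve_prompt_template({}, default_for_model=lambda cfg: "model default"))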
rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py
CHANGED

@@ -165,7 +165,20 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
         if prompt_template is not None:
             return prompt_template
 
-        #
+        # Try to load the template from the given path or fallback to the default for
+        # the component.
+        custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
+        if custom_prompt_template_path is not None:
+            custom_prompt_template = get_prompt_template(
+                custom_prompt_template_path,
+                None,  # Default will be based on the model
+                log_source_component=log_source_component,
+                log_source_method=log_context,
+            )
+            if custom_prompt_template is not None:
+                return custom_prompt_template
+
+        # Fallback to the default prompt template based on the model.
         default_command_prompt_template = get_default_prompt_template_based_on_model(
             llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
             model_prompt_mapping=cls.get_model_prompt_mapper(),
@@ -175,10 +188,4 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
             log_source_method=log_context,
         )
 
-
-        return get_prompt_template(
-            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
-            default_command_prompt_template,
-            log_source_component=log_source_component,
-            log_source_method=log_context,
-        )
+        return default_command_prompt_template
rasa/dialogue_understanding/processor/command_processor.py
CHANGED

@@ -499,10 +499,22 @@ def clean_up_commands(
         else:
             clean_commands.append(command)
 
+    # ensure that there is only one command of a certain command type
+    clean_commands = ensure_max_number_of_command_type(
+        clean_commands, CannotHandleCommand, 1
+    )
+    clean_commands = ensure_max_number_of_command_type(
+        clean_commands, RepeatBotMessagesCommand, 1
+    )
+    clean_commands = ensure_max_number_of_command_type(
+        clean_commands, ChitChatAnswerCommand, 1
+    )
+
     # Replace CannotHandleCommands with ContinueAgentCommand when an agent is active
     # to keep the agent running, but preserve chitchat
     clean_commands = _replace_cannot_handle_with_continue_agent(clean_commands, tracker)
 
+    # filter out cannot handle commands if there are other commands present
     # when coexistence is enabled, by default there will be a SetSlotCommand
     # for the ROUTE_TO_CALM_SLOT slot.
     if tracker.has_coexistence_routing_slot and len(clean_commands) > 2:
@@ -510,12 +522,6 @@ def clean_up_commands(
     elif not tracker.has_coexistence_routing_slot and len(clean_commands) > 1:
         clean_commands = filter_cannot_handle_command(clean_commands)
 
-    clean_commands = ensure_max_number_of_command_type(
-        clean_commands, RepeatBotMessagesCommand, 1
-    )
-    clean_commands = ensure_max_number_of_command_type(
-        clean_commands, ContinueAgentCommand, 1
-    )
     structlogger.debug(
         "command_processor.clean_up_commands.final_commands",
         command=clean_commands,
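ensure_max_number_of_command_type itself is not shown in this diff; the following is a plausible sketch of such a helper, under the assumption that it keeps the first occurrences of the given command type and drops the surplus while preserving order:

from typing import List, Type

def ensure_max_number_of_command_type(
    commands: List[object], command_type: Type, max_count: int
) -> List[object]:
    # Keep at most `max_count` commands of `command_type`; others pass through.
    kept = 0
    result: List[object] = []
    for command in commands:
        if isinstance(command, command_type):
            if kept >= max_count:
                continue  # drop the surplus duplicate
            kept += 1
        result.append(command)
    return result

Because at most one CannotHandleCommand now survives the earlier deduplication pass, the replacement step can produce at most one ContinueAgentCommand, which is presumably why the later ensure_max_number_of_command_type calls could be deleted in the second hunk.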
@@ -580,7 +586,7 @@ def clean_up_start_flow_command(
         # drop a start flow command if the starting flow is equal
         # to the currently active flow
         structlogger.debug(
-            "command_processor.clean_up_commands.
+            "command_processor.clean_up_commands.skip_command_flow_already_active",
             command=command,
         )
         return clean_commands
rasa/e2e_test/e2e_config.py
CHANGED
@@ -72,9 +72,10 @@ class LLMJudgeConfig(BaseModel):
 
         llm_config = resolve_model_client_config(llm_config)
         llm_config, llm_extra_parameters = cls.extract_attributes(llm_config)
-
-        llm_config
-
+        if not llm_config:
+            llm_config = combine_custom_and_default_config(
+                llm_config, cls.get_default_llm_config()
+            )
         embeddings_config = resolve_model_client_config(embeddings)
         embeddings_config, embeddings_extra_parameters = cls.extract_attributes(
             embeddings_config
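The guard makes the fallback explicit: the default judge LLM configuration is only merged in when the resolved config is empty, so a non-empty user config is used as-is. A sketch assuming combine_custom_and_default_config is a merge where custom keys win (an assumption; only the call site is visible here):

def combine_custom_and_default_config(custom: dict, default: dict) -> dict:
    # Assumed semantics: start from the defaults, let custom keys override.
    return {**default, **(custom or {})}

default_llm = {"provider": "openai", "model": "gpt-4"}  # hypothetical default config
assert combine_custom_and_default_config({}, default_llm) == default_llm
assert combine_custom_and_default_config({"model": "x"}, default_llm)["model"] == "x"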
rasa/engine/recipes/default_components.py
CHANGED

@@ -27,22 +27,32 @@ from rasa.shared.utils.common import conditional_import
 
 # components dependent on tensorflow
 TEDPolicy, TED_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.ted_policy", "TEDPolicy"
+    "rasa.core.policies.ted_policy", "TEDPolicy", check_installation_setup=True
 )
 UnexpecTEDIntentPolicy, UNEXPECTED_INTENT_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.unexpected_intent_policy",
+    "rasa.core.policies.unexpected_intent_policy",
+    "UnexpecTEDIntentPolicy",
+    check_installation_setup=True,
 )
 DIETClassifier, DIET_CLASSIFIER_AVAILABLE = conditional_import(
-    "rasa.nlu.classifiers.diet_classifier",
+    "rasa.nlu.classifiers.diet_classifier",
+    "DIETClassifier",
+    check_installation_setup=True,
 )
 ConveRTFeaturizer, CONVERT_FEATURIZER_AVAILABLE = conditional_import(
-    "rasa.nlu.featurizers.dense_featurizer.convert_featurizer",
+    "rasa.nlu.featurizers.dense_featurizer.convert_featurizer",
+    "ConveRTFeaturizer",
+    check_installation_setup=True,
 )
 LanguageModelFeaturizer, LANGUAGE_MODEL_FEATURIZER_AVAILABLE = conditional_import(
-    "rasa.nlu.featurizers.dense_featurizer.lm_featurizer",
+    "rasa.nlu.featurizers.dense_featurizer.lm_featurizer",
+    "LanguageModelFeaturizer",
+    check_installation_setup=True,
 )
 ResponseSelector, RESPONSE_SELECTOR_AVAILABLE = conditional_import(
-    "rasa.nlu.selectors.response_selector",
+    "rasa.nlu.selectors.response_selector",
+    "ResponseSelector",
+    check_installation_setup=True,
 )
 
 # components dependent on skops
rasa/graph_components/validators/default_recipe_validator.py
CHANGED

@@ -40,16 +40,22 @@ from rasa.shared.utils.common import conditional_import
 
 # Conditional imports for TensorFlow-dependent components
 TEDPolicy, TED_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.ted_policy", "TEDPolicy"
+    "rasa.core.policies.ted_policy", "TEDPolicy", check_installation_setup=True
 )
 UnexpecTEDIntentPolicy, UNEXPECTED_INTENT_POLICY_AVAILABLE = conditional_import(
-    "rasa.core.policies.unexpected_intent_policy",
+    "rasa.core.policies.unexpected_intent_policy",
+    "UnexpecTEDIntentPolicy",
+    check_installation_setup=True,
 )
 DIETClassifier, DIET_CLASSIFIER_AVAILABLE = conditional_import(
-    "rasa.nlu.classifiers.diet_classifier",
+    "rasa.nlu.classifiers.diet_classifier",
+    "DIETClassifier",
+    check_installation_setup=True,
 )
 ResponseSelector, RESPONSE_SELECTOR_AVAILABLE = conditional_import(
-    "rasa.nlu.selectors.response_selector",
+    "rasa.nlu.selectors.response_selector",
+    "ResponseSelector",
+    check_installation_setup=True,
 )
 
 # Conditional imports for nlu components requiring other dependencies than tensorflow
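Every conditional_import call site now passes the class name explicitly plus check_installation_setup=True; together with the new rasa/utils/installation_utils.py (+111 lines) and the small change to rasa/shared/utils/common.py in the file list, this suggests conditional_import gained an installation check before importing heavyweight modules. A sketch of what such a helper could look like; everything beyond the call-site signature is an assumption:

import importlib
from typing import Any, Optional, Tuple

def check_for_installation_issues() -> None:
    # Stand-in for rasa.utils.installation_utils.check_for_installation_issues;
    # the real function presumably validates the TensorFlow installation setup.
    pass

def conditional_import(
    module_name: str, attribute: str, check_installation_setup: bool = False
) -> Tuple[Optional[Any], bool]:
    # Returns (attribute, True) when the module imports, (None, False) otherwise.
    try:
        if check_installation_setup:
            check_for_installation_issues()
        module = importlib.import_module(module_name)
        return getattr(module, attribute), True
    except ImportError:
        return None, False

DIETClassifier, DIET_CLASSIFIER_AVAILABLE = conditional_import(
    "rasa.nlu.classifiers.diet_classifier", "DIETClassifier",
    check_installation_setup=True,
)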
rasa/nlu/classifiers/diet_classifier.py
CHANGED

@@ -9,9 +9,11 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type, TypeVar, Union
 import numpy as np
 import scipy.sparse
 
+from rasa.utils.installation_utils import check_for_installation_issues
 from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
 
 if TENSORFLOW_AVAILABLE:
+    check_for_installation_issues()
     import tensorflow as tf
 else:
     tf = None
rasa/shared/core/flows/flow.py
CHANGED
@@ -322,9 +322,15 @@ class Flow:
 
     def get_collect_steps(self) -> List[CollectInformationFlowStep]:
         """Return all CollectInformationFlowSteps in the flow."""
-        collect_steps = []
+        collect_steps: List[CollectInformationFlowStep] = []
         for step in self.steps_with_calls_resolved:
-
+            # Only add collect steps that are not already in the list.
+            # This is to avoid returning duplicate collect steps from called flows
+            # in case the called flow is called multiple times.
+            if (
+                isinstance(step, CollectInformationFlowStep)
+                and step not in collect_steps
+            ):
                 collect_steps.append(step)
         return collect_steps
 
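The membership check matters because steps_with_calls_resolved inlines called flows: if the same flow is called twice, its collect steps would otherwise be returned twice. Schematically (strings stand in for CollectInformationFlowStep objects and the isinstance check):

shared_collect_step = "collect_email"  # yielded once per call of the called flow
resolved_steps = [shared_collect_step, "action_greet", shared_collect_step]

collect_steps = []
for step in resolved_steps:
    if step.startswith("collect") and step not in collect_steps:
        collect_steps.append(step)

assert collect_steps == ["collect_email"]  # no duplicate from the second call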
rasa/shared/core/slots.py
CHANGED
@@ -355,8 +355,8 @@ class FloatSlot(Slot):
         mappings: List[Dict[Text, Any]],
         initial_value: Optional[float] = None,
         value_reset_delay: Optional[int] = None,
-        max_value: float =
-        min_value: float =
+        max_value: Optional[float] = None,
+        min_value: Optional[float] = None,
         influence_conversation: bool = True,
         is_builtin: bool = False,
         shared_for_coexistence: bool = False,
@@ -380,32 +380,24 @@ class FloatSlot(Slot):
             filled_by=filled_by,
             validation=validation,
         )
+        self.validate_min_max_range(min_value, max_value)
+
         self.max_value = max_value
         self.min_value = min_value
 
-        if min_value >= max_value:
-            raise InvalidSlotConfigError(
-                "Float slot ('{}') created with an invalid range "
-                "using min ({}) and max ({}) values. Make sure "
-                "min is smaller than max."
-                "".format(self.name, self.min_value, self.max_value)
-            )
-
-        if initial_value is not None and not (min_value <= initial_value <= max_value):
-            rasa.shared.utils.io.raise_warning(
-                f"Float slot ('{self.name}') created with an initial value "
-                f"{self.value}. This value is outside of the configured min "
-                f"({self.min_value}) and max ({self.max_value}) values."
-            )
-
     def _as_feature(self) -> List[float]:
+        # set default min and max values used in prior releases
+        # to prevent regressions for existing models
+        min_value = self.min_value or 0.0
+        max_value = self.max_value or 1.0
+
         try:
-            capped_value = max(
-            if abs(
-            covered_range = abs(
+            capped_value = max(min_value, min(max_value, float(self.value)))
+            if abs(max_value - min_value) > 0:
+                covered_range = abs(max_value - min_value)
             else:
                 covered_range = 1
-            return [1.0, (capped_value -
+            return [1.0, (capped_value - min_value) / covered_range]
         except (TypeError, ValueError):
             return [0.0, 0.0]
 
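Featurization is numerically unchanged when no bounds are configured: the `or` fallback restores the historical 0.0/1.0 defaults, then the value is clamped to [min, max] and min-max normalized, i.e. feature = (clamp(value) - min) / (max - min). A condensed, runnable restatement of the hunk's logic:

from typing import List, Optional

def float_slot_feature(
    value: object, min_value: Optional[float] = None, max_value: Optional[float] = None
) -> List[float]:
    # Defaults match the pre-3.14.2 hard-coded range.
    min_value = min_value or 0.0
    max_value = max_value or 1.0
    try:
        capped_value = max(min_value, min(max_value, float(value)))
        if abs(max_value - min_value) > 0:
            covered_range = abs(max_value - min_value)
        else:
            covered_range = 1
        return [1.0, (capped_value - min_value) / covered_range]
    except (TypeError, ValueError):
        return [0.0, 0.0]

assert float_slot_feature(0.25) == [1.0, 0.25]            # implicit 0.0..1.0 range
assert float_slot_feature(50, 0.0, 200.0) == [1.0, 0.25]  # scaled into the range
assert float_slot_feature(500, 0.0, 200.0) == [1.0, 1.0]  # clamped at max
assert float_slot_feature(None) == [0.0, 0.0]             # unset slot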
@@ -424,13 +416,52 @@ class FloatSlot(Slot):
         return value
 
     def is_valid_value(self, value: Any) -> bool:
-        """Checks if the slot
-
-
+        """Checks if the slot value is valid."""
+        if value is None:
+            return True
+
+        if not isinstance(self.coerce_value(value), float):
+            return False
+
+        if (
+            self.min_value is not None
+            and self.max_value is not None
+            and not (self.min_value <= value <= self.max_value)
+        ):
+            return False
+
+        return True
 
     def _feature_dimensionality(self) -> int:
         return len(self.as_feature())
 
+    def validate_min_max_range(
+        self, min_value: Optional[float], max_value: Optional[float]
+    ) -> None:
+        """Validates the min-max range for the slot.
+
+        Raises:
+            InvalidSlotConfigError, if the min-max range is invalid.
+        """
+        if min_value is not None and max_value is not None and min_value >= max_value:
+            raise InvalidSlotConfigError(
+                f"Float slot ('{self.name}') created with an invalid range "
+                f"using min ({min_value}) and max ({max_value}) values. Make sure "
+                f"min is smaller than max."
+            )
+
+        if (
+            self.initial_value is not None
+            and min_value is not None
+            and max_value is not None
+            and not (min_value <= self.initial_value <= max_value)
+        ):
+            raise InvalidSlotConfigError(
+                f"Float slot ('{self.name}') created with an initial value "
+                f"{self.initial_value}. This value is outside of the configured min "
+                f"({min_value}) and max ({max_value}) values."
+            )
+
 
 class BooleanSlot(Slot):
     """A slot storing a truth value."""
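Two behavioural consequences of this hunk: is_valid_value only range-checks when both bounds are configured, and an out-of-range initial value is now a hard InvalidSlotConfigError instead of the warning the deleted constructor code used to emit. A usage sketch; constructor arguments beyond those visible in the diff are assumptions:

from rasa.shared.core.slots import FloatSlot

slot = FloatSlot("price", mappings=[], min_value=0.0, max_value=100.0)
assert slot.is_valid_value(None) is True    # an unset value is always valid
assert slot.is_valid_value(50.0) is True
assert slot.is_valid_value(150.0) is False  # outside the configured range

unbounded = FloatSlot("score", mappings=[])  # no bounds: any float passes
assert unbounded.is_valid_value(1e9) is True

# Both of these now raise InvalidSlotConfigError via validate_min_max_range:
# FloatSlot("bad_range", mappings=[], min_value=5.0, max_value=1.0)
# FloatSlot("bad_init", mappings=[], initial_value=2.0, min_value=0.0, max_value=1.0)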
rasa/shared/providers/_configs/azure_openai_client_config.py
CHANGED

@@ -167,8 +167,9 @@ class OAuthConfigWrapper(OAuth, BaseModel):
 
 @dataclass
 class AzureOpenAIClientConfig:
-    """Parses configuration for Azure OpenAI client
-
+    """Parses configuration for Azure OpenAI client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -301,9 +302,7 @@ class AzureOpenAIClientConfig:
 
 
 def is_azure_openai_config(config: dict) -> bool:
-    """Check whether the configuration is meant to configure
-    an Azure OpenAI client.
-    """
+    """Check whether the configuration is meant to configure an Azure OpenAI client."""
     # Resolve any aliases that are specific to Azure OpenAI configuration
     config = AzureOpenAIClientConfig.resolve_config_aliases(config)
 
rasa/shared/providers/_configs/default_litellm_client_config.py
CHANGED

@@ -40,8 +40,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class DefaultLiteLLMClientConfig:
-    """Parses configuration for default LiteLLM client
-
+    """Parses configuration for default LiteLLM client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -72,8 +73,7 @@ class DefaultLiteLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
rasa/shared/providers/_configs/litellm_router_client_config.py
CHANGED

@@ -38,8 +38,9 @@ _LITELLM_UNSUPPORTED_KEYS = [
 
 @dataclass
 class LiteLLMRouterClientConfig:
-    """Parses configuration for a LiteLLM Router client.
-
+    """Parses configuration for a LiteLLM Router client.
+
+    The configuration is expected to be in the following format:
 
     {
         "id": "model_group_id",
rasa/shared/providers/_configs/openai_client_config.py
CHANGED

@@ -64,8 +64,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class OpenAIClientConfig:
-    """Parses configuration for
-
+    """Parses configuration for OpenAI client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -118,8 +119,7 @@ class OpenAIClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> OpenAIClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
@@ -168,9 +168,7 @@ class OpenAIClientConfig:
 
 
 def is_openai_config(config: dict) -> bool:
-    """Check whether the configuration is meant to configure
-    an OpenAI client.
-    """
+    """Check whether the configuration is meant to configure an OpenAI client."""
     # Process the config to handle all the aliases
     config = OpenAIClientConfig.resolve_config_aliases(config)
 
rasa/shared/providers/_configs/rasa_llm_client_config.py
CHANGED

@@ -22,8 +22,9 @@ structlogger = structlog.get_logger()
 
 @dataclass
 class RasaLLMClientConfig:
-    """Parses configuration for a Rasa Hosted LiteLLM client
-
+    """Parses configuration for a Rasa Hosted LiteLLM client.
+
+    Checks required keys present.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -40,8 +41,7 @@ class RasaLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> RasaLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
rasa/shared/providers/_configs/self_hosted_llm_client_config.py
CHANGED

@@ -61,8 +61,9 @@ FORBIDDEN_KEYS = [
 
 @dataclass
 class SelfHostedLLMClientConfig:
-    """Parses configuration for Self Hosted LiteLLM client
-
+    """Parses configuration for Self Hosted LiteLLM client.
+
+    Resolves aliases and raises deprecation warnings.
 
     Raises:
         ValueError: Raised in cases of invalid configuration:
@@ -116,8 +117,7 @@ class SelfHostedLLMClientConfig:
 
     @classmethod
     def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
-        """
-        Initializes a dataclass from the passed config.
+        """Initializes a dataclass from the passed config.
 
         Args:
             config: (dict) The config from which to initialize.
rasa/shared/providers/llm/_base_litellm_client.py
CHANGED

@@ -1,12 +1,14 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, List, NoReturn, Union, cast
 
 import structlog
 from litellm import acompletion, completion, validate_environment
 
+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.constants import (
     _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
     API_BASE_CONFIG_KEY,
@@ -57,26 +59,24 @@ class _BaseLiteLLMClient:
     @property
     @abstractmethod
     def config(self) -> dict:
-        """Returns the configuration for that the llm client
-        in dictionary form.
-        """
+        """Returns the configuration for that the llm client in dictionary form."""
         pass
 
     @property
     @abstractmethod
     def _litellm_model_name(self) -> str:
-        """Returns the value of LiteLLM's model parameter
-        completion/acompletion in LiteLLM format:
+        """Returns the value of LiteLLM's model parameter.
 
+        To be used in completion/acompletion in LiteLLM format:
         <provider>/<model or deployment name>
         """
         pass
 
     @property
     def _litellm_extra_parameters(self) -> Dict[str, Any]:
-        """Returns a dictionary of extra parameters
-        parameters as well as LiteLLM specific input parameters.
+        """Returns a dictionary of extra parameters.
 
+        Includes model parameters as well as LiteLLM specific input parameters.
         By default, this returns an empty dictionary (no extra parameters).
         """
         return {}
@@ -96,8 +96,9 @@ class _BaseLiteLLMClient:
         }
 
     def validate_client_setup(self) -> None:
-        """Perform client validation.
-
+        """Perform client validation.
+
+        By default only environment variables are validated.
 
         Raises:
             ProviderClientValidationError if validation fails.
@@ -188,10 +189,17 @@ class _BaseLiteLLMClient:
             arguments = cast(
                 Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
             )
-
-
+
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                acompletion(messages=formatted_messages, **{**arguments, **kwargs}),
+                timeout=timeout,
             )
             return self._format_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             message = ""
             from rasa.shared.providers.llm.self_hosted_llm_client import (
@@ -211,6 +219,25 @@ class _BaseLiteLLMClient:
             )
             raise ProviderClientAPIException(e, message) from e
 
+    def _handle_timeout_error(self) -> NoReturn:
+        """Handle asyncio.TimeoutError and raise ProviderClientAPIException.
+
+        Raises:
+            ProviderClientAPIException: Always raised with formatted timeout error.
+        """
+        timeout = self._litellm_extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
+        error_message = (
+            f"APITimeoutError - Request timed out. Error_str: "
+            f"Request timed out. - timeout value={timeout:.6f}, "
+            f"time taken={timeout:.6f} seconds"
+        )
+        # nosemgrep: semgrep.rules.pii-positional-arguments-in-logging
+        # Error message contains only numeric timeout values, not PII
+        structlogger.error(
+            f"{self.__class__.__name__.lower()}.llm.timeout", error=error_message
+        )
+        raise ProviderClientAPIException(asyncio.TimeoutError(error_message)) from None
+
     def _get_formatted_messages(
         self, messages: Union[List[dict], List[str], str]
     ) -> List[Dict[str, str]]:
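The two hunks above implement a standard asyncio timeout pattern: the provider call is wrapped in asyncio.wait_for, and the resulting asyncio.TimeoutError is converted into the client's ProviderClientAPIException. A minimal self-contained illustration of the wrapping; slow_call stands in for litellm's acompletion:

import asyncio

async def slow_call() -> str:
    await asyncio.sleep(10)  # stands in for a hung acompletion(...) request
    return "response"

async def main() -> None:
    try:
        result = await asyncio.wait_for(slow_call(), timeout=0.1)
        print(result)
    except asyncio.TimeoutError:
        # _BaseLiteLLMClient converts this into a ProviderClientAPIException
        # in _handle_timeout_error() above.
        print("request timed out")

asyncio.run(main())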
@@ -312,8 +339,9 @@ class _BaseLiteLLMClient:
 
     @staticmethod
     def _ensure_certificates() -> None:
-        """Configures SSL certificates for LiteLLM.
-
+        """Configures SSL certificates for LiteLLM.
+
+        This method is invoked during client initialization.
 
         LiteLLM may utilize `openai` clients or other providers that require
         SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`