rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +6 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/studio.py +18 -8
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +20 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +15 -4
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +189 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +32 -30
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +119 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +78 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +39 -12
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +49 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +37 -23
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -10
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +71 -11
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +316 -10
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +9 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +18 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +37 -31
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +353 -46
- rasa/shared/utils/yaml.py +111 -73
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +81 -26
- rasa/telemetry.py +92 -17
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +94 -17
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +7 -81
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +11 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +183 -163
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Dict, List
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
ROUTER_CONFIG_KEY,
|
|
8
|
+
MODELS_CONFIG_KEY,
|
|
9
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
10
|
+
MODEL_NAME_CONFIG_KEY,
|
|
11
|
+
LITELLM_PARAMS_KEY,
|
|
12
|
+
PROVIDER_CONFIG_KEY,
|
|
13
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
14
|
+
API_TYPE_CONFIG_KEY,
|
|
15
|
+
MODEL_CONFIG_KEY,
|
|
16
|
+
MODEL_LIST_KEY,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.providers._configs.model_group_config import (
|
|
19
|
+
ModelGroupConfig,
|
|
20
|
+
ModelConfig,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.providers.mappings import get_prefix_from_provider
|
|
23
|
+
from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
structlogger = structlog.get_logger()
|
|
27
|
+
|
|
28
|
+
# Config keys that are filtered out of every model's parameters when the
# LiteLLM model list is built (see the filtering step in
# `_convert_models_to_litellm_model_list`).
_LITELLM_UNSUPPORTED_KEYS = [PROVIDER_CONFIG_KEY, DEPLOYMENT_CONFIG_KEY, API_TYPE_CONFIG_KEY]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class LiteLLMRouterClientConfig:
    """Parses configuration for a LiteLLM Router client.

    The configuration is expected to be in the following format:

    {
        "id": "model_group_id",
        "models": [
            {
                "provider": "provider_name",
                "model": "model_name",
                "api_base": "api_base",
                "api_key": "api_key",
                "api_version": "api_version",
            },
            {
                "provider": "provider_name",
                "model": "model_name",
            },
        ],
        "router": {...},  # must not be empty, see `__post_init__`
    }

    This configuration is converted into the LiteLLM required format:

    {
        "id": "model_group_id",
        "model_list": [
            {
                "model_name": "model_group_id",
                "litellm_params": {
                    "model": "provider_name/model_name",
                    "api_base": "api_base",
                    "api_key": "api_key",
                    "api_version": "api_version",
                },
            },
            {
                "model_name": "model_group_id",
                "litellm_params": {
                    "model": "provider_name/model_name",
                },
            },
        ],
        "router": {...},
    }

    Any top-level key other than the model group ID, the models list and the
    router settings is kept in `extra_parameters`.

    Raises:
        ValueError: If the configuration is missing required keys or the
            router settings are missing or empty.
    """

    _model_group_config: ModelGroupConfig
    router: Dict[str, Any]
    extra_parameters: dict = field(default_factory=dict)

    @property
    def model_group_id(self) -> str:
        """The ID of the wrapped model group."""
        return self._model_group_config.model_group_id

    @property
    def models(self) -> List[ModelConfig]:
        """The model configs of the wrapped model group."""
        return self._model_group_config.models

    @property
    def litellm_model_list(self) -> List[Dict[str, Any]]:
        """The models converted to LiteLLM model list entries.

        The list is rebuilt on every access.
        """
        return self._convert_models_to_litellm_model_list()

    def __post_init__(self) -> None:
        """Validates that router settings were provided.

        Raises:
            ValueError: If `router` is falsy (None or an empty dict).
        """
        if not self.router:
            message = "Router cannot be empty."
            structlogger.error(
                "litellm_router_client_config.validation_error",
                message=message,
                model_group_id=self._model_group_config.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize. It is not
                mutated; a deep copy is used to split off the extra parameters.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            LiteLLMRouterClientConfig
        """
        model_group_config = ModelGroupConfig.from_dict(config)

        # Copy config to avoid mutating the original
        config_copy = copy.deepcopy(config)
        # Pop the keys used by ModelGroupConfig
        config_copy.pop(MODEL_GROUP_ID_CONFIG_KEY, None)
        config_copy.pop(MODELS_CONFIG_KEY, None)
        # Get the router settings
        router_settings = config_copy.pop(ROUTER_CONFIG_KEY, None)
        # The rest is considered as extra parameters
        extra_parameters = config_copy

        # Instantiate through `cls` so that subclasses calling `from_dict`
        # get an instance of their own type.
        return cls(
            _model_group_config=model_group_config,
            router=router_settings,
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The result has the model group dict as its base, then the router
        settings, then the extra parameters merged on the top level.
        """
        d = self._model_group_config.to_dict()
        d[ROUTER_CONFIG_KEY] = self.router
        if self.extra_parameters:
            d.update(self.extra_parameters)
        return d

    def to_litellm_dict(self) -> dict:
        """Converts the config instance into the LiteLLM router format.

        Extra parameters are spread first, so the model group ID, the model
        list and the router settings always win over an extra parameter that
        uses the same key.
        """
        litellm_model_list = self._convert_models_to_litellm_model_list()
        d = {
            **self.extra_parameters,
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODEL_LIST_KEY: litellm_model_list,
            ROUTER_CONFIG_KEY: self.router,
        }
        return d

    def _convert_models_to_litellm_model_list(self) -> List[Dict[str, Any]]:
        """Builds one LiteLLM model list entry per configured model.

        Every entry uses the model group ID as its model name and carries the
        model's remaining parameters under the LiteLLM params key.
        """
        litellm_model_list = []

        for model_config_object in self.models:
            # Convert the model config to a dict representation
            litellm_model_config = model_config_object.to_dict()

            provider = litellm_model_config[PROVIDER_CONFIG_KEY]

            # Get the litellm prefixing for the provider
            prefix = get_prefix_from_provider(provider)

            # Determine whether to use model or deployment key based on the provider.
            # NOTE(review): plain indexing, so a missing deployment/model key
            # surfaces as a KeyError here - presumably ruled out by the
            # provider's client config validation; confirm.
            litellm_model_name_without_prefix = (
                litellm_model_config[DEPLOYMENT_CONFIG_KEY]
                if provider in DEPLOYMENT_CENTRIC_PROVIDERS
                else litellm_model_config[MODEL_CONFIG_KEY]
            )

            # Set 'model' to a provider prefixed model name e.g. openai/gpt-4
            litellm_model_config[MODEL_CONFIG_KEY] = (
                f"{prefix}/{litellm_model_name_without_prefix}"
            )

            # Remove parameters that are None and not supported by LiteLLM.
            litellm_model_config = {
                key: value
                for key, value in litellm_model_config.items()
                if key not in _LITELLM_UNSUPPORTED_KEYS and value is not None
            }

            litellm_model_list_item = {
                MODEL_NAME_CONFIG_KEY: self.model_group_id,
                LITELLM_PARAMS_KEY: litellm_model_config,
            }

            litellm_model_list.append(litellm_model_list_item)

        return litellm_model_list
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from rasa.shared.constants import (
|
|
6
|
+
API_BASE_CONFIG_KEY,
|
|
7
|
+
API_KEY,
|
|
8
|
+
API_TYPE_CONFIG_KEY,
|
|
9
|
+
API_VERSION_CONFIG_KEY,
|
|
10
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
11
|
+
PROVIDER_CONFIG_KEY,
|
|
12
|
+
MODEL_CONFIG_KEY,
|
|
13
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
14
|
+
MODELS_CONFIG_KEY,
|
|
15
|
+
MODEL_GROUPS_CONFIG_KEY,
|
|
16
|
+
EXTRA_PARAMETERS_KEY,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.providers.mappings import get_client_config_class_from_provider
|
|
19
|
+
|
|
20
|
+
structlogger = structlog.get_logger()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class ModelConfig:
    """Parses the model config.

    Raises:
        ValueError: If the provider config key is missing in the config.
    """

    provider: str
    model: Optional[str] = None
    deployment: Optional[str] = None
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    api_version: Optional[str] = None
    extra_parameters: dict = field(default_factory=dict)
    # Retained for backward compatibility with older configurations,
    # but intentionally not included in extra_parameters
    api_type: Optional[str] = None

    @classmethod
    def from_dict(cls, config: dict) -> "ModelConfig":
        """Initializes a dataclass from the passed config.

        The provider config param is used to determine the client config
        class to use. The client config class takes care of resolving config
        aliases and throwing deprecation warnings.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            ModelConfig
        """
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with `rasa.shared.utils.llm` - confirm before moving it.
        from rasa.shared.utils.llm import get_provider_from_config

        # Get the provider from the config; this also infers the provider
        # from deprecated configurations.
        provider = get_provider_from_config(config)

        # Retrieve the client configuration class for the specified provider.
        client_config_clazz = get_client_config_class_from_provider(provider)

        # Try to instantiate the config object in order to resolve deprecated
        # aliases and throw deprecation warnings.
        client_config_obj = client_config_clazz.from_dict(config)

        # Convert back to dictionary and instantiate the ModelConfig object.
        client_config = client_config_obj.to_dict()

        # Check for provider after resolving all aliases
        if PROVIDER_CONFIG_KEY not in client_config:
            raise ValueError(
                f"Missing required key '{PROVIDER_CONFIG_KEY}' in "
                f"'{MODELS_CONFIG_KEY}' config."
            )

        # Known keys are popped off one by one; whatever is left in
        # `client_config` afterwards becomes the extra parameters.
        # Instantiate through `cls` so subclasses get their own type back.
        return cls(
            provider=client_config.pop(PROVIDER_CONFIG_KEY, None),
            model=client_config.pop(MODEL_CONFIG_KEY, None),
            deployment=client_config.pop(DEPLOYMENT_CONFIG_KEY, None),
            api_type=client_config.pop(API_TYPE_CONFIG_KEY, None),
            api_base=client_config.pop(API_BASE_CONFIG_KEY, None),
            api_key=client_config.pop(API_KEY, None),
            api_version=client_config.pop(API_VERSION_CONFIG_KEY, None),
            extra_parameters=client_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        Extra parameters are flattened onto the top level and every key whose
        value is None is dropped from the result.
        """
        d = asdict(self)

        # Extra parameters should also be on the top level
        d.pop(EXTRA_PARAMETERS_KEY, None)
        d.update(self.extra_parameters)

        # Remove keys with None values
        return {key: value for key, value in d.items() if value is not None}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass
class ModelGroupConfig:
    """Holds a model group: a group ID plus the list of its model configs.

    Raises:
        ValueError: If the model group ID is None or if the models list is empty.
    """

    model_group_id: str
    models: List[ModelConfig]

    def __post_init__(self) -> None:
        """Validates the group ID and the models list.

        The ID is checked first; a failure is logged and then raised.

        Raises:
            ValueError: If the ID is None or the models list is empty.
        """
        if self.model_group_id is None:
            message = "Model group ID cannot be set to None."
        elif not self.models:
            message = "Models cannot be empty."
        else:
            return

        structlogger.error(
            "model_group_config.validation_error",
            message=message,
            model_group_id=self.model_group_id,
        )
        raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "ModelGroupConfig":
        """Builds a model group config from a plain dictionary.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            ModelGroupConfig
        """
        if MODELS_CONFIG_KEY not in config:
            raise ValueError(
                f"Missing required key '{MODELS_CONFIG_KEY}' in "
                f"'{MODEL_GROUPS_CONFIG_KEY}' config."
            )

        # Parse every entry of the models list into a ModelConfig.
        parsed_models = []
        for raw_model_config in config[MODELS_CONFIG_KEY]:
            parsed_models.append(ModelConfig.from_dict(raw_model_config))

        group_id = config.get(MODEL_GROUP_ID_CONFIG_KEY)
        return cls(model_group_id=group_id, models=parsed_models)

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        serialized_models = [model.to_dict() for model in self.models]
        return {
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODELS_CONFIG_KEY: serialized_models,
        }
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
RASA_PROVIDER,
|
|
9
|
+
PROVIDER_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.providers._configs.utils import (
|
|
13
|
+
validate_required_keys,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Keys passed to `validate_required_keys` in `RasaLLMClientConfig.from_dict`;
# every other key of the config ends up in `extra_parameters`.
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY, API_BASE_CONFIG_KEY]
|
|
17
|
+
|
|
18
|
+
structlogger = structlog.get_logger()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class RasaLLMClientConfig:
    """Configuration for a Rasa Hosted LiteLLM client.

    Construction through `from_dict` checks that the required keys are present.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
    """

    model: Optional[str]
    api_base: Optional[str]
    # Provider is not used by LiteLLM backend, but we define it here since it's
    # used as switch between different clients.
    provider: str = RASA_PROVIDER

    extra_parameters: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
        """Builds the config object from a plain dictionary.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Raised in cases of invalid configuration:
                - If any of the required configuration keys are missing.

        Returns:
            RasaLLMClientConfig
        """
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)

        # Everything that is not a required key is carried along untouched.
        remaining = {
            name: setting
            for name, setting in config.items()
            if name not in REQUIRED_KEYS
        }

        model_name = config.get(MODEL_CONFIG_KEY)
        base_url = config.get(API_BASE_CONFIG_KEY)
        provider_name = config.get(PROVIDER_CONFIG_KEY, RASA_PROVIDER)
        return cls(
            model=model_name,
            api_base=base_url,
            provider=provider_name,
            extra_parameters=remaining,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The extra parameters are merged onto the top level instead of being
        nested under their own key.
        """
        named_fields = {
            name: value
            for name, value in asdict(self).items()
            if name != "extra_parameters"
        }
        return {**named_fields, **self.extra_parameters}
|
|
@@ -99,3 +99,19 @@ def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
|
|
|
99
99
|
config=config,
|
|
100
100
|
)
|
|
101
101
|
raise ValueError(message)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_provider_prefixed_model_name(provider: str, model: str) -> str:
    """Returns the model name with the provider prefixed.

    The prefix is only added when the model name does not already start with
    `<provider>/`. A name that merely contains `<provider>/` somewhere in the
    middle (e.g. provider `ai` and model `openai/gpt-4`) is not considered
    prefixed.

    Args:
        provider: The provider of the model.
        model: The model name.

    Returns:
        The model name with the provider prefixed. A falsy model name
        (None or an empty string) is returned unchanged.
    """
    # Nothing to prefix; hand the falsy value back as the original code did.
    if not model:
        return model

    prefix = f"{provider}/"
    # Prefix check instead of a substring check so that the provider name
    # appearing inside the model name is not mistaken for an existing prefix.
    if model.startswith(prefix):
        return model
    return f"{prefix}{model}"
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from abc import abstractmethod
|
|
2
3
|
from typing import Any, Dict, List
|
|
3
4
|
|
|
4
5
|
import litellm
|
|
5
|
-
import logging
|
|
6
6
|
import structlog
|
|
7
7
|
from litellm import aembedding, embedding, validate_environment
|
|
8
|
+
|
|
9
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
|
|
8
10
|
from rasa.shared.exceptions import (
|
|
9
11
|
ProviderClientAPIException,
|
|
10
12
|
ProviderClientValidationError,
|
|
@@ -17,7 +19,7 @@ from rasa.shared.providers.embedding.embedding_response import (
|
|
|
17
19
|
EmbeddingResponse,
|
|
18
20
|
EmbeddingUsage,
|
|
19
21
|
)
|
|
20
|
-
from rasa.shared.utils.io import suppress_logs
|
|
22
|
+
from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
|
|
21
23
|
|
|
22
24
|
structlogger = structlog.get_logger()
|
|
23
25
|
|
|
@@ -25,8 +27,7 @@ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
|
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class _BaseLiteLLMEmbeddingClient:
|
|
28
|
-
"""
|
|
29
|
-
An abstract base class for LiteLLM embedding clients.
|
|
30
|
+
"""An abstract base class for LiteLLM embedding clients.
|
|
30
31
|
|
|
31
32
|
This class defines the interface and common functionality for all clients
|
|
32
33
|
based on LiteLLM.
|
|
@@ -81,11 +82,14 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
81
82
|
ProviderClientValidationError if validation fails.
|
|
82
83
|
"""
|
|
83
84
|
self._validate_environment_variables()
|
|
84
|
-
self._validate_api_key_not_in_config()
|
|
85
85
|
|
|
86
86
|
def _validate_environment_variables(self) -> None:
|
|
87
87
|
"""Validate that the required environment variables are set."""
|
|
88
|
-
validation_info = validate_environment(
|
|
88
|
+
validation_info = validate_environment(
|
|
89
|
+
self._litellm_model_name,
|
|
90
|
+
api_key=self._litellm_extra_parameters.get(API_KEY),
|
|
91
|
+
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
92
|
+
)
|
|
89
93
|
if missing_environment_variables := validation_info.get(
|
|
90
94
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
|
|
91
95
|
):
|
|
@@ -100,21 +104,8 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
100
104
|
)
|
|
101
105
|
raise ProviderClientValidationError(event_info)
|
|
102
106
|
|
|
103
|
-
def _validate_api_key_not_in_config(self) -> None:
|
|
104
|
-
if "api_key" in self._litellm_extra_parameters:
|
|
105
|
-
event_info = (
|
|
106
|
-
"API Key is set through `api_key` extra parameter."
|
|
107
|
-
"Set API keys through environment variables."
|
|
108
|
-
)
|
|
109
|
-
structlogger.error(
|
|
110
|
-
"base_litellm_client.validate_api_key_not_in_config",
|
|
111
|
-
event_info=event_info,
|
|
112
|
-
)
|
|
113
|
-
raise ProviderClientValidationError(event_info)
|
|
114
|
-
|
|
115
107
|
def validate_documents(self, documents: List[str]) -> None:
|
|
116
|
-
"""
|
|
117
|
-
Validates a list of documents to ensure they are suitable for embedding.
|
|
108
|
+
"""Validates a list of documents to ensure they are suitable for embedding.
|
|
118
109
|
|
|
119
110
|
Args:
|
|
120
111
|
documents: List of documents to be validated.
|
|
@@ -130,8 +121,7 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
130
121
|
|
|
131
122
|
@suppress_logs(log_level=logging.WARNING)
|
|
132
123
|
def embed(self, documents: List[str]) -> EmbeddingResponse:
|
|
133
|
-
"""
|
|
134
|
-
Embeds a list of documents synchronously.
|
|
124
|
+
"""Embeds a list of documents synchronously.
|
|
135
125
|
|
|
136
126
|
Args:
|
|
137
127
|
documents: List of documents to be embedded.
|
|
@@ -144,7 +134,8 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
144
134
|
"""
|
|
145
135
|
self.validate_documents(documents)
|
|
146
136
|
try:
|
|
147
|
-
|
|
137
|
+
arguments = resolve_environment_variables(self._embedding_fn_args)
|
|
138
|
+
response = embedding(input=documents, **arguments)
|
|
148
139
|
return self._format_response(response)
|
|
149
140
|
except Exception as e:
|
|
150
141
|
raise ProviderClientAPIException(
|
|
@@ -153,8 +144,7 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
153
144
|
|
|
154
145
|
@suppress_logs(log_level=logging.WARNING)
|
|
155
146
|
async def aembed(self, documents: List[str]) -> EmbeddingResponse:
|
|
156
|
-
"""
|
|
157
|
-
Embeds a list of documents asynchronously.
|
|
147
|
+
"""Embeds a list of documents asynchronously.
|
|
158
148
|
|
|
159
149
|
Args:
|
|
160
150
|
documents: List of documents to be embedded.
|
|
@@ -167,7 +157,8 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
167
157
|
"""
|
|
168
158
|
self.validate_documents(documents)
|
|
169
159
|
try:
|
|
170
|
-
|
|
160
|
+
arguments = resolve_environment_variables(self._embedding_fn_args)
|
|
161
|
+
response = await aembedding(input=documents, **arguments)
|
|
171
162
|
return self._format_response(response)
|
|
172
163
|
except Exception as e:
|
|
173
164
|
raise ProviderClientAPIException(
|
|
@@ -182,7 +173,6 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
182
173
|
Raises:
|
|
183
174
|
ValueError: If any response data is None.
|
|
184
175
|
"""
|
|
185
|
-
|
|
186
176
|
# If data is not available (None), raise a ValueError
|
|
187
177
|
if response.data is None:
|
|
188
178
|
message = (
|
|
@@ -239,8 +229,7 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
239
229
|
|
|
240
230
|
@staticmethod
|
|
241
231
|
def _ensure_certificates() -> None:
|
|
242
|
-
"""
|
|
243
|
-
Configures SSL certificates for LiteLLM. This method is invoked during
|
|
232
|
+
"""Configures SSL certificates for LiteLLM. This method is invoked during
|
|
244
233
|
client initialization.
|
|
245
234
|
|
|
246
235
|
LiteLLM may utilize `openai` clients or other providers that require
|