rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +4 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/channels/channel.py +3 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/jambonz.py +2 -2
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +177 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +16 -22
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +159 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +120 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +297 -7
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +6 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +31 -30
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check.py +533 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +350 -46
- rasa/shared/utils/yaml.py +11 -13
- rasa/studio/upload.py +64 -20
- rasa/telemetry.py +80 -17
- rasa/tracing/instrumentation/attribute_extractors.py +74 -17
- rasa/utils/io.py +0 -66
- rasa/utils/log_utils.py +9 -2
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -1,15 +1,133 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import itertools
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Iterable, Union, Text, Optional, List, Any, Tuple, Dict, Set
|
|
4
6
|
|
|
5
7
|
import numpy as np
|
|
6
8
|
import scipy.sparse
|
|
9
|
+
from safetensors.numpy import save_file, load_file
|
|
7
10
|
|
|
8
|
-
import rasa.shared.utils.io
|
|
9
11
|
import rasa.shared.nlu.training_data.util
|
|
12
|
+
import rasa.shared.utils.io
|
|
10
13
|
from rasa.shared.nlu.constants import FEATURE_TYPE_SEQUENCE, FEATURE_TYPE_SENTENCE
|
|
11
14
|
|
|
12
15
|
|
|
16
|
+
@dataclass
|
|
17
|
+
class FeatureMetadata:
|
|
18
|
+
data_type: str
|
|
19
|
+
attribute: str
|
|
20
|
+
origin: Union[str, List[str]]
|
|
21
|
+
is_sparse: bool
|
|
22
|
+
shape: tuple
|
|
23
|
+
safetensors_key: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def save_features(
|
|
27
|
+
features_dict: Dict[Text, List[Features]], file_name: str
|
|
28
|
+
) -> Dict[str, Any]:
|
|
29
|
+
"""Save a dictionary of Features lists to disk using safetensors.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
features_dict: Dictionary mapping strings to lists of Features objects
|
|
33
|
+
file_name: File to save the features to
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
The metadata to reconstruct the features.
|
|
37
|
+
"""
|
|
38
|
+
# All tensors are stored in a single safetensors file
|
|
39
|
+
tensors_to_save = {}
|
|
40
|
+
# Metadata will be stored separately
|
|
41
|
+
metadata = {}
|
|
42
|
+
|
|
43
|
+
for key, features_list in features_dict.items():
|
|
44
|
+
feature_metadata_list = []
|
|
45
|
+
|
|
46
|
+
for idx, feature in enumerate(features_list):
|
|
47
|
+
# Create a unique key for this tensor in the safetensors file
|
|
48
|
+
safetensors_key = f"{key}_{idx}"
|
|
49
|
+
|
|
50
|
+
# Sparse matrices are stored as COO components (data, row, col); dense arrays are stored as-is
|
|
51
|
+
if feature.is_sparse():
|
|
52
|
+
# For sparse matrices, use the COO format
|
|
53
|
+
coo = feature.features.tocoo() # type:ignore[union-attr]
|
|
54
|
+
# Save data, row indices and col indices separately
|
|
55
|
+
tensors_to_save[f"{safetensors_key}_data"] = coo.data
|
|
56
|
+
tensors_to_save[f"{safetensors_key}_row"] = coo.row
|
|
57
|
+
tensors_to_save[f"{safetensors_key}_col"] = coo.col
|
|
58
|
+
else:
|
|
59
|
+
tensors_to_save[safetensors_key] = feature.features
|
|
60
|
+
|
|
61
|
+
# Store metadata
|
|
62
|
+
metadata_item = FeatureMetadata(
|
|
63
|
+
data_type=feature.type,
|
|
64
|
+
attribute=feature.attribute,
|
|
65
|
+
origin=feature.origin,
|
|
66
|
+
is_sparse=feature.is_sparse(),
|
|
67
|
+
shape=feature.features.shape,
|
|
68
|
+
safetensors_key=safetensors_key,
|
|
69
|
+
)
|
|
70
|
+
feature_metadata_list.append(vars(metadata_item))
|
|
71
|
+
|
|
72
|
+
metadata[key] = feature_metadata_list
|
|
73
|
+
|
|
74
|
+
# Save tensors
|
|
75
|
+
save_file(tensors_to_save, file_name)
|
|
76
|
+
|
|
77
|
+
return metadata
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def load_features(
|
|
81
|
+
filename: str, metadata: Dict[str, Any]
|
|
82
|
+
) -> Dict[Text, List[Features]]:
|
|
83
|
+
"""Load Features dictionary from disk.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
filename: File name of the safetensors file.
|
|
87
|
+
metadata: Metadata to reconstruct the features.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
Dictionary mapping strings to lists of Features objects
|
|
91
|
+
"""
|
|
92
|
+
# Load tensors
|
|
93
|
+
tensors = load_file(filename)
|
|
94
|
+
|
|
95
|
+
# Reconstruct the features dictionary
|
|
96
|
+
features_dict: Dict[Text, List[Features]] = {}
|
|
97
|
+
|
|
98
|
+
for key, feature_metadata_list in metadata.items():
|
|
99
|
+
features_list = []
|
|
100
|
+
|
|
101
|
+
for meta in feature_metadata_list:
|
|
102
|
+
safetensors_key = meta["safetensors_key"]
|
|
103
|
+
|
|
104
|
+
if meta["is_sparse"]:
|
|
105
|
+
# Reconstruct sparse matrix from COO format
|
|
106
|
+
data = tensors[f"{safetensors_key}_data"]
|
|
107
|
+
row = tensors[f"{safetensors_key}_row"]
|
|
108
|
+
col = tensors[f"{safetensors_key}_col"]
|
|
109
|
+
|
|
110
|
+
features_matrix = scipy.sparse.coo_matrix(
|
|
111
|
+
(data, (row, col)), shape=tuple(meta["shape"])
|
|
112
|
+
).tocsr() # Convert back to CSR format
|
|
113
|
+
else:
|
|
114
|
+
features_matrix = tensors[safetensors_key]
|
|
115
|
+
|
|
116
|
+
# Reconstruct Features object
|
|
117
|
+
features = Features(
|
|
118
|
+
features=features_matrix,
|
|
119
|
+
feature_type=meta["data_type"],
|
|
120
|
+
attribute=meta["attribute"],
|
|
121
|
+
origin=meta["origin"],
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
features_list.append(features)
|
|
125
|
+
|
|
126
|
+
features_dict[key] = features_list
|
|
127
|
+
|
|
128
|
+
return features_dict
|
|
129
|
+
|
|
130
|
+
|
|
13
131
|
class Features:
|
|
14
132
|
"""Stores the features produced by any featurizer."""
|
|
15
133
|
|
|
@@ -107,8 +107,7 @@ class AzureOpenAIClientConfig:
|
|
|
107
107
|
|
|
108
108
|
@classmethod
|
|
109
109
|
def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
|
|
110
|
-
"""
|
|
111
|
-
Initializes a dataclass from the passed config.
|
|
110
|
+
"""Initializes a dataclass from the passed config.
|
|
112
111
|
|
|
113
112
|
Args:
|
|
114
113
|
config: (dict) The config from which to initialize.
|
|
@@ -175,7 +174,10 @@ def is_azure_openai_config(config: dict) -> bool:
|
|
|
175
174
|
|
|
176
175
|
# Case: Configuration contains `deployment` key
|
|
177
176
|
# (specific to Azure OpenAI configuration)
|
|
178
|
-
if
|
|
177
|
+
if (
|
|
178
|
+
config.get(DEPLOYMENT_CONFIG_KEY) is not None
|
|
179
|
+
and config.get(PROVIDER_CONFIG_KEY) is None
|
|
180
|
+
):
|
|
179
181
|
return True
|
|
180
182
|
|
|
181
183
|
return False
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Dict, List
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
ROUTER_CONFIG_KEY,
|
|
8
|
+
MODELS_CONFIG_KEY,
|
|
9
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
10
|
+
MODEL_NAME_CONFIG_KEY,
|
|
11
|
+
LITELLM_PARAMS_KEY,
|
|
12
|
+
PROVIDER_CONFIG_KEY,
|
|
13
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
14
|
+
API_TYPE_CONFIG_KEY,
|
|
15
|
+
MODEL_CONFIG_KEY,
|
|
16
|
+
MODEL_LIST_KEY,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.providers._configs.model_group_config import (
|
|
19
|
+
ModelGroupConfig,
|
|
20
|
+
ModelConfig,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.providers.mappings import get_prefix_from_provider
|
|
23
|
+
from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
structlogger = structlog.get_logger()
|
|
27
|
+
|
|
28
|
+
_LITELLM_UNSUPPORTED_KEYS = [
|
|
29
|
+
PROVIDER_CONFIG_KEY,
|
|
30
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
31
|
+
API_TYPE_CONFIG_KEY,
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class LiteLLMRouterClientConfig:
|
|
37
|
+
"""Parses configuration for a LiteLLM Router client. The configuration is expected
|
|
38
|
+
to be in the following format:
|
|
39
|
+
|
|
40
|
+
{
|
|
41
|
+
"id": "model_group_id",
|
|
42
|
+
"models": [
|
|
43
|
+
{
|
|
44
|
+
"provider": "provider_name",
|
|
45
|
+
"model": "model_name",
|
|
46
|
+
"api_base": "api_base",
|
|
47
|
+
"api_key": "api_key",
|
|
48
|
+
"api_version": "api_version",
|
|
49
|
+
},
|
|
50
|
+
{
|
|
51
|
+
"provider": "provider_name",
|
|
52
|
+
"model": "model_name",
|
|
53
|
+
},
],
|
|
54
|
+
"router": {}
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
This configuration is converted into the LiteLLM required format:
|
|
58
|
+
|
|
59
|
+
{
|
|
60
|
+
"id": "model_group_id",
|
|
61
|
+
"model_list": [
|
|
62
|
+
{
|
|
63
|
+
"model_name": "model_group_id",
|
|
64
|
+
"litellm_params": {
|
|
65
|
+
"model": "provider_name/model_name",
|
|
66
|
+
"api_base": "api_base",
|
|
67
|
+
"api_key": "api_key",
|
|
68
|
+
"api_version": "api_version",
|
|
69
|
+
},
|
|
70
|
+
},
|
|
71
|
+
{
|
|
72
|
+
"model_name": "model_group_id",
|
|
73
|
+
"litellm_params": {
|
|
74
|
+
"model": "provider_name/model_name",
|
|
75
|
+
},
|
|
76
|
+
},
|
|
77
|
+
],
|
|
78
|
+
"router": {},
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
Raises:
|
|
82
|
+
ValueError: If the configuration is missing required keys or if the router settings are empty.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
_model_group_config: ModelGroupConfig
|
|
86
|
+
router: Dict[str, Any]
|
|
87
|
+
extra_parameters: dict = field(default_factory=dict)
|
|
88
|
+
|
|
89
|
+
@property
|
|
90
|
+
def model_group_id(self) -> str:
|
|
91
|
+
return self._model_group_config.model_group_id
|
|
92
|
+
|
|
93
|
+
@property
|
|
94
|
+
def models(self) -> List[ModelConfig]:
|
|
95
|
+
return self._model_group_config.models
|
|
96
|
+
|
|
97
|
+
@property
|
|
98
|
+
def litellm_model_list(self) -> List[Dict[str, Any]]:
|
|
99
|
+
return self._convert_models_to_litellm_model_list()
|
|
100
|
+
|
|
101
|
+
def __post_init__(self) -> None:
|
|
102
|
+
if not self.router:
|
|
103
|
+
message = "Router cannot be empty."
|
|
104
|
+
structlogger.error(
|
|
105
|
+
"litellm_router_client_config.validation_error",
|
|
106
|
+
message=message,
|
|
107
|
+
model_group_id=self._model_group_config.model_group_id,
|
|
108
|
+
)
|
|
109
|
+
raise ValueError(message)
|
|
110
|
+
|
|
111
|
+
@classmethod
|
|
112
|
+
def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
|
|
113
|
+
"""Initializes a dataclass from the passed config.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
config: (dict) The config from which to initialize.
|
|
117
|
+
|
|
118
|
+
Raises:
|
|
119
|
+
ValueError: Config is missing required keys.
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
LiteLLMRouterClientConfig
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
model_group_config = ModelGroupConfig.from_dict(config)
|
|
126
|
+
|
|
127
|
+
# Copy config to avoid mutating the original
|
|
128
|
+
config_copy = copy.deepcopy(config)
|
|
129
|
+
# Pop the keys used by ModelGroupConfig
|
|
130
|
+
config_copy.pop(MODEL_GROUP_ID_CONFIG_KEY, None)
|
|
131
|
+
config_copy.pop(MODELS_CONFIG_KEY, None)
|
|
132
|
+
# Get the router settings
|
|
133
|
+
router_settings = config_copy.pop(ROUTER_CONFIG_KEY, None)
|
|
134
|
+
# The rest is considered as extra parameters
|
|
135
|
+
extra_parameters = config_copy
|
|
136
|
+
|
|
137
|
+
this = LiteLLMRouterClientConfig(
|
|
138
|
+
_model_group_config=model_group_config,
|
|
139
|
+
router=router_settings,
|
|
140
|
+
extra_parameters=extra_parameters,
|
|
141
|
+
)
|
|
142
|
+
return this
|
|
143
|
+
|
|
144
|
+
def to_dict(self) -> dict:
|
|
145
|
+
"""Converts the config instance into a dictionary."""
|
|
146
|
+
d = self._model_group_config.to_dict()
|
|
147
|
+
d[ROUTER_CONFIG_KEY] = self.router
|
|
148
|
+
if self.extra_parameters:
|
|
149
|
+
d.update(self.extra_parameters)
|
|
150
|
+
return d
|
|
151
|
+
|
|
152
|
+
def to_litellm_dict(self) -> dict:
|
|
153
|
+
litellm_model_list = self._convert_models_to_litellm_model_list()
|
|
154
|
+
d = {
|
|
155
|
+
**self.extra_parameters,
|
|
156
|
+
MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
|
|
157
|
+
MODEL_LIST_KEY: litellm_model_list,
|
|
158
|
+
ROUTER_CONFIG_KEY: self.router,
|
|
159
|
+
}
|
|
160
|
+
return d
|
|
161
|
+
|
|
162
|
+
def _convert_models_to_litellm_model_list(self) -> List[Dict[str, Any]]:
|
|
163
|
+
litellm_model_list = []
|
|
164
|
+
|
|
165
|
+
for model_config_object in self.models:
|
|
166
|
+
# Convert the model config to a dict representation
|
|
167
|
+
litellm_model_config = model_config_object.to_dict()
|
|
168
|
+
|
|
169
|
+
provider = litellm_model_config[PROVIDER_CONFIG_KEY]
|
|
170
|
+
|
|
171
|
+
# Get the litellm prefixing for the provider
|
|
172
|
+
prefix = get_prefix_from_provider(provider)
|
|
173
|
+
|
|
174
|
+
# Determine whether to use model or deployment key based on the provider.
|
|
175
|
+
litellm_model_name_without_prefix = (
|
|
176
|
+
litellm_model_config[DEPLOYMENT_CONFIG_KEY]
|
|
177
|
+
if provider in DEPLOYMENT_CENTRIC_PROVIDERS
|
|
178
|
+
else litellm_model_config[MODEL_CONFIG_KEY]
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# Set 'model' to a provider prefixed model name e.g. openai/gpt-4
|
|
182
|
+
litellm_model_config[MODEL_CONFIG_KEY] = (
|
|
183
|
+
f"{prefix}/{litellm_model_name_without_prefix}"
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
# Remove parameters that are None or not supported by LiteLLM.
|
|
187
|
+
litellm_model_config = {
|
|
188
|
+
key: value
|
|
189
|
+
for key, value in litellm_model_config.items()
|
|
190
|
+
if key not in _LITELLM_UNSUPPORTED_KEYS and value is not None
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
litellm_model_list_item = {
|
|
194
|
+
MODEL_NAME_CONFIG_KEY: self.model_group_id,
|
|
195
|
+
LITELLM_PARAMS_KEY: litellm_model_config,
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
litellm_model_list.append(litellm_model_list_item)
|
|
199
|
+
|
|
200
|
+
return litellm_model_list
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from rasa.shared.constants import (
|
|
6
|
+
API_BASE_CONFIG_KEY,
|
|
7
|
+
API_KEY,
|
|
8
|
+
API_TYPE_CONFIG_KEY,
|
|
9
|
+
API_VERSION_CONFIG_KEY,
|
|
10
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
11
|
+
PROVIDER_CONFIG_KEY,
|
|
12
|
+
MODEL_CONFIG_KEY,
|
|
13
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
14
|
+
MODELS_CONFIG_KEY,
|
|
15
|
+
MODEL_GROUPS_CONFIG_KEY,
|
|
16
|
+
EXTRA_PARAMETERS_KEY,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.providers.mappings import get_client_config_class_from_provider
|
|
19
|
+
|
|
20
|
+
structlogger = structlog.get_logger()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
|
|
24
|
+
class ModelConfig:
|
|
25
|
+
"""Parses the model config.
|
|
26
|
+
|
|
27
|
+
Raises:
|
|
28
|
+
ValueError: If the provider config key is missing in the config.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
provider: str
|
|
32
|
+
model: Optional[str] = None
|
|
33
|
+
deployment: Optional[str] = None
|
|
34
|
+
api_base: Optional[str] = None
|
|
35
|
+
api_key: Optional[str] = None
|
|
36
|
+
api_version: Optional[str] = None
|
|
37
|
+
extra_parameters: dict = field(default_factory=dict)
|
|
38
|
+
# Retained for backward compatibility with older configurations,
|
|
39
|
+
# but intentionally not included in extra_parameters
|
|
40
|
+
api_type: Optional[str] = None
|
|
41
|
+
|
|
42
|
+
@classmethod
|
|
43
|
+
def from_dict(cls, config: dict) -> "ModelConfig":
|
|
44
|
+
"""Initializes a dataclass from the passed config. The provider config param is
|
|
45
|
+
used to determine the client config class to use. The client config class takes
|
|
46
|
+
care of resolving config aliases and throwing deprecation warnings.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
config: (dict) The config from which to initialize.
|
|
50
|
+
|
|
51
|
+
Raises:
|
|
52
|
+
ValueError: Config is missing required keys.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
ModelConfig
|
|
56
|
+
"""
|
|
57
|
+
from rasa.shared.utils.llm import get_provider_from_config
|
|
58
|
+
|
|
59
|
+
# Get the provider from config; this also infers the provider from
|
|
60
|
+
# deprecated configurations
|
|
61
|
+
provider = get_provider_from_config(config)
|
|
62
|
+
|
|
63
|
+
# Retrieve the client configuration class for the specified provider.
|
|
64
|
+
client_config_clazz = get_client_config_class_from_provider(provider)
|
|
65
|
+
|
|
66
|
+
# Try to instantiate the config object in order to resolve deprecated
|
|
67
|
+
# aliases and throw deprecation warnings.
|
|
68
|
+
client_config_obj = client_config_clazz.from_dict(config)
|
|
69
|
+
|
|
70
|
+
# Convert back to dictionary and instantiate the ModelConfig object.
|
|
71
|
+
client_config = client_config_obj.to_dict()
|
|
72
|
+
|
|
73
|
+
# Check for provider after resolving all aliases
|
|
74
|
+
if PROVIDER_CONFIG_KEY not in client_config:
|
|
75
|
+
raise ValueError(
|
|
76
|
+
f"Missing required key '{PROVIDER_CONFIG_KEY}' in "
|
|
77
|
+
f"'{MODELS_CONFIG_KEY}' config."
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
return ModelConfig(
|
|
81
|
+
provider=client_config.pop(PROVIDER_CONFIG_KEY, None),
|
|
82
|
+
model=client_config.pop(MODEL_CONFIG_KEY, None),
|
|
83
|
+
deployment=client_config.pop(DEPLOYMENT_CONFIG_KEY, None),
|
|
84
|
+
api_type=client_config.pop(API_TYPE_CONFIG_KEY, None),
|
|
85
|
+
api_base=client_config.pop(API_BASE_CONFIG_KEY, None),
|
|
86
|
+
api_key=client_config.pop(API_KEY, None),
|
|
87
|
+
api_version=client_config.pop(API_VERSION_CONFIG_KEY, None),
|
|
88
|
+
extra_parameters=client_config,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
def to_dict(self) -> dict:
|
|
92
|
+
"""Converts the config instance into a dictionary."""
|
|
93
|
+
d = asdict(self)
|
|
94
|
+
|
|
95
|
+
# Extra parameters should also be on the top level
|
|
96
|
+
d.pop(EXTRA_PARAMETERS_KEY, None)
|
|
97
|
+
d.update(self.extra_parameters)
|
|
98
|
+
|
|
99
|
+
# Remove keys with None values
|
|
100
|
+
return {key: value for key, value in d.items() if value is not None}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass
|
|
104
|
+
class ModelGroupConfig:
|
|
105
|
+
"""Parses the models config. The models config is a list of model configs.
|
|
106
|
+
|
|
107
|
+
Raises:
|
|
108
|
+
ValueError: If the model group ID is None or if the models list is empty.
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
model_group_id: str
|
|
112
|
+
models: List[ModelConfig]
|
|
113
|
+
|
|
114
|
+
def __post_init__(self) -> None:
|
|
115
|
+
if self.model_group_id is None:
|
|
116
|
+
message = "Model group ID cannot be set to None."
|
|
117
|
+
structlogger.error(
|
|
118
|
+
"model_group_config.validation_error",
|
|
119
|
+
message=message,
|
|
120
|
+
model_group_id=self.model_group_id,
|
|
121
|
+
)
|
|
122
|
+
raise ValueError(message)
|
|
123
|
+
if not self.models:
|
|
124
|
+
message = "Models cannot be empty."
|
|
125
|
+
structlogger.error(
|
|
126
|
+
"model_group_config.validation_error",
|
|
127
|
+
message=message,
|
|
128
|
+
model_group_id=self.model_group_id,
|
|
129
|
+
)
|
|
130
|
+
raise ValueError(message)
|
|
131
|
+
|
|
132
|
+
@classmethod
|
|
133
|
+
def from_dict(cls, config: dict) -> "ModelGroupConfig":
|
|
134
|
+
"""Initializes a dataclass from the passed config.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
config: (dict) The config from which to initialize.
|
|
138
|
+
|
|
139
|
+
Raises:
|
|
140
|
+
ValueError: Config is missing required keys.
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
ModelGroupConfig
|
|
144
|
+
"""
|
|
145
|
+
if MODELS_CONFIG_KEY not in config:
|
|
146
|
+
raise ValueError(
|
|
147
|
+
f"Missing required key '{MODELS_CONFIG_KEY}' in "
|
|
148
|
+
f"'{MODEL_GROUPS_CONFIG_KEY}' config."
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
models_config = [
|
|
152
|
+
ModelConfig.from_dict(model_config)
|
|
153
|
+
for model_config in config[MODELS_CONFIG_KEY]
|
|
154
|
+
]
|
|
155
|
+
|
|
156
|
+
return cls(
|
|
157
|
+
model_group_id=config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
158
|
+
models=models_config,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
def to_dict(self) -> dict:
|
|
162
|
+
"""Converts the config instance into a dictionary."""
|
|
163
|
+
d = {
|
|
164
|
+
MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
|
|
165
|
+
MODELS_CONFIG_KEY: [model.to_dict() for model in self.models],
|
|
166
|
+
}
|
|
167
|
+
return d
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
RASA_PROVIDER,
|
|
9
|
+
PROVIDER_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.providers._configs.utils import (
|
|
13
|
+
validate_required_keys,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY, API_BASE_CONFIG_KEY]
|
|
17
|
+
|
|
18
|
+
structlogger = structlog.get_logger()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
|
|
22
|
+
class RasaLLMClientConfig:
|
|
23
|
+
"""Parses configuration for a Rasa Hosted LiteLLM client,
|
|
24
|
+
and checks that the required keys are present.
|
|
25
|
+
|
|
26
|
+
Raises:
|
|
27
|
+
ValueError: Raised in cases of invalid configuration:
|
|
28
|
+
- If any of the required configuration keys are missing.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
model: Optional[str]
|
|
32
|
+
api_base: Optional[str]
|
|
33
|
+
# Provider is not used by LiteLLM backend, but we define it here since it's
|
|
34
|
+
# used as a switch between different clients.
|
|
35
|
+
provider: str = RASA_PROVIDER
|
|
36
|
+
|
|
37
|
+
extra_parameters: dict = field(default_factory=dict)
|
|
38
|
+
|
|
39
|
+
@classmethod
|
|
40
|
+
def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
|
|
41
|
+
"""
|
|
42
|
+
Initializes a dataclass from the passed config.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
config: (dict) The config from which to initialize.
|
|
46
|
+
|
|
47
|
+
Raises:
|
|
48
|
+
ValueError: Raised in cases of invalid configuration:
|
|
49
|
+
- If any of the required configuration keys are missing.
|
|
50
|
+
- Note: `api_type` is not validated by this config; only the required keys are checked.
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
RasaLLMClientConfig
|
|
54
|
+
"""
|
|
55
|
+
# Validate that required keys are set
|
|
56
|
+
validate_required_keys(config, REQUIRED_KEYS)
|
|
57
|
+
|
|
58
|
+
extra_parameters = {k: v for k, v in config.items() if k not in REQUIRED_KEYS}
|
|
59
|
+
|
|
60
|
+
return cls(
|
|
61
|
+
model=config.get(MODEL_CONFIG_KEY),
|
|
62
|
+
api_base=config.get(API_BASE_CONFIG_KEY),
|
|
63
|
+
provider=config.get(PROVIDER_CONFIG_KEY, RASA_PROVIDER),
|
|
64
|
+
extra_parameters=extra_parameters,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
def to_dict(self) -> dict:
|
|
68
|
+
"""Converts the config instance into a dictionary."""
|
|
69
|
+
d = asdict(self)
|
|
70
|
+
# Extra parameters should also be on the top level
|
|
71
|
+
d.pop("extra_parameters", None)
|
|
72
|
+
d.update(self.extra_parameters)
|
|
73
|
+
return d
|
|
@@ -99,3 +99,19 @@ def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
|
|
|
99
99
|
config=config,
|
|
100
100
|
)
|
|
101
101
|
raise ValueError(message)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_provider_prefixed_model_name(provider: str, model: str) -> str:
|
|
105
|
+
"""
|
|
106
|
+
Returns the model name with the provider prefixed.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
provider: The provider of the model.
|
|
110
|
+
model: The model name.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
The model name with the provider prefixed.
|
|
114
|
+
"""
|
|
115
|
+
if model and f"{provider}/" not in model:
|
|
116
|
+
return f"{provider}/{model}"
|
|
117
|
+
return model
|