rasa-pro 3.9.18__py3-none-any.whl → 3.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +26 -57
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +26 -22
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +108 -433
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +395 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +14 -13
- rasa/cli/scaffold.py +10 -8
- rasa/cli/train.py +8 -7
- rasa/cli/utils.py +15 -0
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/inspector/dist/index.html +0 -2
- rasa/core/channels/inspector/index.html +0 -2
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +100 -44
- rasa/core/policies/flows/flow_executor.py +130 -94
- rasa/core/policies/intentless_policy.py +52 -28
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +20 -53
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/utils.py +45 -56
- rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
- rasa/dialogue_understanding/commands/utils.py +38 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +13 -14
- rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
- rasa/e2e_test/assertions.py +1181 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +491 -72
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +596 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/storage/local_model_storage.py +0 -1
- rasa/env.py +9 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +48 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +129 -32
- rasa/server.py +45 -10
- rasa/shared/constants.py +63 -15
- rasa/shared/core/domain.py +15 -12
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +28 -10
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +60 -11
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +227 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -3
- rasa/shared/utils/llm.py +223 -200
- rasa/shared/utils/yaml.py +122 -7
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +2 -3
- rasa/telemetry.py +113 -58
- rasa/tracing/config.py +2 -3
- rasa/tracing/instrumentation/attribute_extractors.py +29 -17
- rasa/tracing/instrumentation/instrumentation.py +4 -47
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/io.py +66 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +195 -1
- rasa/version.py +1 -1
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +47 -72
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +185 -121
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa/utils/tensorflow/feature_array.py +0 -366
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
STREAM_CONFIG_KEY,
|
|
10
|
+
N_REPHRASES_CONFIG_KEY,
|
|
11
|
+
PROVIDER_CONFIG_KEY,
|
|
12
|
+
TIMEOUT_CONFIG_KEY,
|
|
13
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
14
|
+
)
|
|
15
|
+
from rasa.shared.providers._configs.utils import (
|
|
16
|
+
validate_required_keys,
|
|
17
|
+
validate_forbidden_keys,
|
|
18
|
+
resolve_aliases,
|
|
19
|
+
raise_deprecation_warnings,
|
|
20
|
+
)
|
|
21
|
+
import rasa.shared.utils.cli
|
|
22
|
+
|
|
23
|
+
structlogger = structlog.get_logger()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
|
|
27
|
+
# Timeout aliases
|
|
28
|
+
REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]
|
|
32
|
+
|
|
33
|
+
FORBIDDEN_KEYS = [
|
|
34
|
+
STREAM_CONFIG_KEY,
|
|
35
|
+
N_REPHRASES_CONFIG_KEY,
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class DefaultLiteLLMClientConfig:
    """Parses configuration for default LiteLLM client, resolves aliases and
    raises deprecation warnings.

    Attributes:
        model: Name of the model; must not be `None`.
        provider: Name of the provider; must not be `None`.
        extra_parameters: Every configuration entry other than the model and
            the provider (e.g. model parameters).

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If `model` or `provider` is `None`.
            - If any of the required configuration keys are missing.
    """

    model: str
    provider: str
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Rejects instances whose `model` or `provider` is `None`.

        Raises:
            ValueError: If `model` or `provider` is `None`. The problem is
                logged before the exception is raised.
        """
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "default_litellm_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)
        if self.provider is None:
            message = "Provider cannot be set to None."
            structlogger.error(
                "default_litellm_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "DefaultLiteLLMClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            An instance of the class this method is called on
            (`DefaultLiteLLMClientConfig` or a subclass of it).
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Raise error for using `model_name` instead of `model`
        cls.check_and_error_for_model_name_in_config(config)
        # Resolve any potential aliases.
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Build `cls` rather than a hard-coded class so that subclasses calling
        # `from_dict` get an instance of their own type.
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        Returns:
            A flat dictionary: the entries of `extra_parameters` sit on the
            top level next to `model` and `provider`.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolves the deprecated aliases of this client via `resolve_aliases`."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)

    @staticmethod
    def check_and_error_for_model_name_in_config(config: Dict[str, Any]) -> None:
        """Check for usage of deprecated model_name and raise an error if found.

        The error is only reported when `model_name` is set while `model` is
        not; it is logged and then handed to `print_error_and_exit`.
        """
        if config.get(MODEL_NAME_CONFIG_KEY) and not config.get(MODEL_CONFIG_KEY):
            event_info = (
                f"Unsupported parameter - {MODEL_NAME_CONFIG_KEY} is set. Please use "
                f"{MODEL_CONFIG_KEY} instead."
            )
            structlogger.error(
                "default_litellm_client_config.unsupported_parameter_in_config",
                event_info=event_info,
                config=config,
            )
            rasa.shared.utils.cli.print_error_and_exit(event_info)
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
RASA_TYPE_CONFIG_KEY,
|
|
10
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
11
|
+
HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
|
|
12
|
+
HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
|
|
13
|
+
HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
|
|
14
|
+
HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
|
|
15
|
+
HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY,
|
|
16
|
+
HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER,
|
|
17
|
+
PROVIDER_CONFIG_KEY,
|
|
18
|
+
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
|
|
19
|
+
TIMEOUT_CONFIG_KEY,
|
|
20
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.providers._configs.utils import (
|
|
23
|
+
resolve_aliases,
|
|
24
|
+
raise_deprecation_warnings,
|
|
25
|
+
validate_required_keys,
|
|
26
|
+
)
|
|
27
|
+
from rasa.shared.utils.io import raise_deprecation_warning
|
|
28
|
+
|
|
29
|
+
structlogger = structlog.get_logger()
|
|
30
|
+
|
|
31
|
+
DEPRECATED_HUGGINGFACE_TYPE = "huggingface"
|
|
32
|
+
|
|
33
|
+
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
|
|
34
|
+
# Provider aliases
|
|
35
|
+
RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
|
|
36
|
+
LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
|
|
37
|
+
# Model name aliases
|
|
38
|
+
MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
|
|
39
|
+
# Timeout aliases
|
|
40
|
+
REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class HuggingFaceLocalEmbeddingClientConfig:
    """Parses configuration for HuggingFace local embeddings client, resolves
    aliases and raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `provider` has a value different from the HuggingFace local
              embedding provider.
            - If `model` is `None`.
    """

    model: str

    multi_process: Optional[bool]
    cache_folder: Optional[str]
    show_progress: Optional[bool]

    # Provider is not actually used by sentence-transformers, but we define
    # it here because it's used as a switch denominator for HuggingFace
    # local embedding client.
    provider: str = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER

    model_kwargs: dict = field(default_factory=dict)
    encode_kwargs: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates the provider and the model of a freshly created instance.

        Raises:
            ValueError: If `provider` is not the HuggingFace local embedding
                provider, or if `model` is `None`. The problem is logged
                before the exception is raised.
        """
        if self.provider != HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER:
            # The checked attribute is the provider, so the message names the
            # provider (it used to talk about the API type).
            message = (
                f"Provider must be set to '{HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}'."
            )
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "HuggingFaceLocalEmbeddingClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            An instance of the class this method is called on
            (`HuggingFaceLocalEmbeddingClientConfig` or a subclass of it).
        """
        # Check for usage of deprecated switching key and value:
        #   1. type: huggingface
        #   2. _type: huggingface
        _raise_deprecation_warning_for_huggingface_deprecated_switch_value(config)
        # Check for other deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)
        # Build `cls` rather than a hard-coded class so that subclasses calling
        # `from_dict` get an instance of their own type.
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # Optional
            multi_process=config.pop(HUGGINGFACE_MULTIPROCESS_CONFIG_KEY, False),
            cache_folder=config.pop(
                HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
                str(HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER),
            ),
            show_progress=config.pop(HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY, False),
            model_kwargs=config.pop(HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY, {}),
            encode_kwargs=config.pop(HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY, {}),
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        return asdict(self)

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolves the deprecated `huggingface` switch value first, then the
        deprecated key aliases via `resolve_aliases`.
        """
        config = _resolve_huggingface_deprecated_switch_value(config)
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def is_huggingface_local_config(config: dict) -> bool:
    """Tell whether `config` is meant to configure a local HuggingFace
    embedding client.

    The deprecated switches `type: huggingface` and `_type: huggingface` are
    resolved first (alias key and value), which turns them into
    `provider: huggingface_local`; the decision is then made on the provider
    entry of the resolved configuration alone.
    """
    resolved = HuggingFaceLocalEmbeddingClientConfig.resolve_config_aliases(config)

    # Providers that select the local HuggingFace embedding client.
    accepted_providers = (HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,)
    return resolved.get(PROVIDER_CONFIG_KEY) in accepted_providers
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _raise_deprecation_warning_for_huggingface_deprecated_switch_value(
    config: dict,
) -> None:
    """Warn about each deprecated switch key that is set to `huggingface`.

    Args:
        config: The configuration to inspect; it is only read.
    """
    # `{{deprecated_switch_key}}` survives the f-string as a placeholder that
    # is filled in per key with `str.format` below.
    warning_template = (
        f"Configuration `{{deprecated_switch_key}}: {DEPRECATED_HUGGINGFACE_TYPE}` "
        f"is deprecated and will be removed in 4.0.0. Please use "
        f"`{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}` instead."
    )
    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        if switch_key not in config:
            continue
        if config[switch_key] != DEPRECATED_HUGGINGFACE_TYPE:
            continue
        raise_deprecation_warning(
            message=warning_template.format(deprecated_switch_key=switch_key)
        )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _resolve_huggingface_deprecated_switch_value(config: dict) -> dict:
    """Rewrite the deprecated HuggingFace switch into the provider setting.

    The following settings (key + value) are deprecated:
        1. `type: huggingface`
        2. `_type: huggingface`
    in favor of `provider: huggingface_local`.

    Args:
        config: given config; it is copied, not modified.

    Returns:
        New config with resolved switch mechanism
    """
    resolved = config.copy()

    # `{{deprecated_switch_key}}` survives the f-string as a placeholder that
    # is filled in per key with `str.format` below.
    log_template = (
        f"Switching `{{deprecated_switch_key}}: {DEPRECATED_HUGGINGFACE_TYPE}` "
        f"to `{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}`."
    )

    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        if switch_key not in resolved:
            continue
        if resolved[switch_key] != DEPRECATED_HUGGINGFACE_TYPE:
            continue

        # Switch to the new mechanism, then drop the deprecated key.
        resolved[PROVIDER_CONFIG_KEY] = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
        resolved.pop(switch_key, None)

        structlogger.debug(
            "HuggingFaceLocalEmbeddingClientConfig"
            "._resolve_huggingface_deprecated_switch_value",
            message=log_template.format(deprecated_switch_key=switch_key),
            new_config=resolved,
        )

    return resolved
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
+
API_TYPE_CONFIG_KEY,
|
|
13
|
+
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
+
API_VERSION_CONFIG_KEY,
|
|
15
|
+
RASA_TYPE_CONFIG_KEY,
|
|
16
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
+
STREAM_CONFIG_KEY,
|
|
18
|
+
N_REPHRASES_CONFIG_KEY,
|
|
19
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
20
|
+
TIMEOUT_CONFIG_KEY,
|
|
21
|
+
PROVIDER_CONFIG_KEY,
|
|
22
|
+
OPENAI_PROVIDER,
|
|
23
|
+
OPENAI_API_TYPE,
|
|
24
|
+
)
|
|
25
|
+
from rasa.shared.providers._configs.utils import (
|
|
26
|
+
resolve_aliases,
|
|
27
|
+
validate_required_keys,
|
|
28
|
+
raise_deprecation_warnings,
|
|
29
|
+
validate_forbidden_keys,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
structlogger = structlog.get_logger()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
|
|
36
|
+
# Model name aliases
|
|
37
|
+
MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
|
|
38
|
+
# Provider aliases
|
|
39
|
+
RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
|
|
40
|
+
LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
|
|
41
|
+
# API type aliases
|
|
42
|
+
OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
|
|
43
|
+
# API base aliases
|
|
44
|
+
OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
|
|
45
|
+
# API version aliases
|
|
46
|
+
OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
|
|
47
|
+
# Timeout aliases
|
|
48
|
+
REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
REQUIRED_KEYS = [MODEL_CONFIG_KEY]
|
|
52
|
+
|
|
53
|
+
FORBIDDEN_KEYS = [
|
|
54
|
+
STREAM_CONFIG_KEY,
|
|
55
|
+
N_REPHRASES_CONFIG_KEY,
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class OpenAIClientConfig:
    """Parses configuration for OpenAI client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `api_type` has a value different from `openai`.
            - If `provider` has a value different from the OpenAI provider.
            - If `model` is `None`.
    """

    model: str
    api_base: Optional[str]
    api_version: Optional[str]

    # API Type is not actually used by LiteLLM backend, but we define
    # it here for backward compatibility.
    api_type: str = OPENAI_API_TYPE

    # Provider is not used by LiteLLM backend, but we define
    # it here since it's used as switch between different
    # clients
    provider: str = OPENAI_PROVIDER

    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates the API type, the provider and the model of the instance.

        Raises:
            ValueError: If `api_type` or `provider` differs from the OpenAI
                value, or if `model` is `None`. The problem is logged before
                the exception is raised.
        """
        # In case of OpenAI hosting, it doesn't make sense
        # for API type to be anything else than 'openai'
        if self.api_type != OPENAI_API_TYPE:
            message = f"API type must be set to '{OPENAI_API_TYPE}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)
        if self.provider != OPENAI_PROVIDER:
            message = f"Provider must be set to '{OPENAI_PROVIDER}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "OpenAIClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            An instance of the class this method is called on
            (`OpenAIClientConfig` or a subclass of it).
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Build `cls` rather than a hard-coded class so that subclasses calling
        # `from_dict` get an instance of their own type.
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            # Pop the 'provider' key. Currently, it's *optional* because of
            # backward compatibility with older versions.
            provider=config.pop(PROVIDER_CONFIG_KEY, OPENAI_PROVIDER),
            # Optional parameters
            api_base=config.pop(API_BASE_CONFIG_KEY, None),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_API_TYPE),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters (this also includes timeout).
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        Returns:
            A flat dictionary: the entries of `extra_parameters` sit on the
            top level next to the declared fields.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolves the deprecated aliases of this client via `resolve_aliases`."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def is_openai_config(config: dict) -> bool:
    """Tell whether `config` is meant to configure an OpenAI client.

    Aliases are resolved first, so a configuration that names the provider
    through a deprecated key is judged on its resolved provider entry.
    """
    resolved = OpenAIClientConfig.resolve_config_aliases(config)

    # Case: Configuration contains `provider: openai`
    configured_provider = resolved.get(PROVIDER_CONFIG_KEY)
    return bool(configured_provider == OPENAI_PROVIDER)
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
+
API_TYPE_CONFIG_KEY,
|
|
13
|
+
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
+
API_VERSION_CONFIG_KEY,
|
|
15
|
+
RASA_TYPE_CONFIG_KEY,
|
|
16
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
+
STREAM_CONFIG_KEY,
|
|
18
|
+
N_REPHRASES_CONFIG_KEY,
|
|
19
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
20
|
+
TIMEOUT_CONFIG_KEY,
|
|
21
|
+
PROVIDER_CONFIG_KEY,
|
|
22
|
+
OPENAI_PROVIDER,
|
|
23
|
+
SELF_HOSTED_PROVIDER,
|
|
24
|
+
)
|
|
25
|
+
from rasa.shared.providers._configs.utils import (
|
|
26
|
+
raise_deprecation_warnings,
|
|
27
|
+
resolve_aliases,
|
|
28
|
+
validate_forbidden_keys,
|
|
29
|
+
validate_required_keys,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
structlogger = structlog.get_logger()


# Deprecated configuration keys, each mapped to the standard key replacing it.
# The entry order is kept unchanged: two aliases target PROVIDER_CONFIG_KEY.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,  # model name
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,  # API type
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,  # API base
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,  # API version
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,  # timeout
}

# Keys passed to `validate_required_keys` once aliases have been resolved.
REQUIRED_KEYS = [
    API_BASE_CONFIG_KEY,
    MODEL_CONFIG_KEY,
    PROVIDER_CONFIG_KEY,
]

# Keys passed to `validate_forbidden_keys` once aliases have been resolved.
FORBIDDEN_KEYS = [STREAM_CONFIG_KEY, N_REPHRASES_CONFIG_KEY]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class SelfHostedLLMClientConfig:
    """Configuration for the self-hosted LiteLLM client.

    Instances are normally created through `from_dict`, which raises
    deprecation warnings, resolves aliases and validates required/forbidden
    keys before the dataclass is constructed.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If `model`, `provider` or `api_base` is None.
            - If `api_type` is anything other than `OPENAI_PROVIDER`.
    """

    # Required settings.
    model: str
    provider: str
    api_base: str
    # Optional settings.
    api_version: Optional[str] = None
    api_type: Optional[str] = OPENAI_PROVIDER
    # Every remaining key of the source config (e.g. model parameters).
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the field values right after initialisation.

        Raises:
            ValueError: If `model`, `provider` or `api_base` is None, or if
                `api_type` differs from `OPENAI_PROVIDER`.
        """
        if self.model is None:
            self._raise_validation_error(
                "Model cannot be set to None.", model=self.model
            )
        if self.provider is None:
            self._raise_validation_error(
                "Provider cannot be set to None.", provider=self.provider
            )
        if self.api_base is None:
            # Log the field that actually failed; this branch used to log
            # `provider`, which made the error event misleading.
            self._raise_validation_error(
                "API base cannot be set to None.", api_base=self.api_base
            )
        if self.api_type != OPENAI_PROVIDER:
            self._raise_validation_error(
                f"Currently supports only {OPENAI_PROVIDER} endpoints. "
                f"API type must be set to '{OPENAI_PROVIDER}'.",
                api_type=self.api_type,
            )

    @staticmethod
    def _raise_validation_error(message: str, **context: Any) -> None:
        """Log a validation error event with `context` and raise ValueError."""
        structlogger.error(
            "self_hosted_llm_client_config.validation_error",
            message=message,
            **context,
        )
        raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "SelfHostedLLMClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize. Deprecated
                alias keys are accepted; the passed dictionary is not modified
                by the key extraction done here.

        Raises:
            ValueError: Config is missing required keys, or a field value
                fails the checks in `__post_init__`.

        Returns:
            SelfHostedLLMClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases. Work on a private shallow copy so the
        # pops below never touch the caller's dictionary, whether or not
        # `resolve_aliases` already returns a copy.
        config = dict(cls.resolve_config_aliases(config))
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # `cls` (not the hard-coded class name) keeps this alternate
        # constructor usable from subclasses.
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            api_base=config.pop(API_BASE_CONFIG_KEY),
            # Optional parameters
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_PROVIDER),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        The nested `extra_parameters` entry is removed and its items are
        merged into the top level, overriding any field sharing a key.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve deprecated alias keys via `resolve_aliases`."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def is_self_hosted_config(config: dict) -> bool:
    """Tell whether `config` is meant to configure a self-hosted client.

    Aliases are resolved first, so a provider given under a deprecated key is
    recognised as well. The config matches when its provider equals
    `SELF_HOSTED_PROVIDER`.
    """
    resolved = SelfHostedLLMClientConfig.resolve_config_aliases(config)
    provider = resolved.get(PROVIDER_CONFIG_KEY)
    return provider == SELF_HOSTED_PROVIDER
|