rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
RASA_TYPE_CONFIG_KEY,
|
|
10
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
11
|
+
HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
|
|
12
|
+
HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
|
|
13
|
+
HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
|
|
14
|
+
HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
|
|
15
|
+
HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY,
|
|
16
|
+
HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER,
|
|
17
|
+
PROVIDER_CONFIG_KEY,
|
|
18
|
+
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
|
|
19
|
+
TIMEOUT_CONFIG_KEY,
|
|
20
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.providers._configs.utils import (
|
|
23
|
+
resolve_aliases,
|
|
24
|
+
raise_deprecation_warnings,
|
|
25
|
+
validate_required_keys,
|
|
26
|
+
)
|
|
27
|
+
from rasa.shared.utils.io import raise_deprecation_warning
|
|
28
|
+
|
|
29
|
+
structlogger = structlog.get_logger()

# Value of the deprecated switch keys (`type` / `_type`) that used to select
# the HuggingFace local embedding client.
DEPRECATED_HUGGINGFACE_TYPE = "huggingface"

# Deprecated configuration keys (aliases) mapped to the standard keys that
# replace them.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]


@dataclass
class HuggingFaceLocalEmbeddingClientConfig:
    """Parses configuration for HuggingFace local embeddings client, resolves
    aliases and raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `provider` has a value different from `huggingface_local`.
              The deprecated `type: huggingface` / `_type: huggingface`
              settings are translated to `provider: huggingface_local`
              by `from_dict` before this check runs.
            - If `model` is `None`.
    """

    model: str

    multi_process: Optional[bool]
    cache_folder: Optional[str]
    show_progress: Optional[bool]

    # Provider is not actually used by sentence-transformers, but we define
    # it here because it's used as a switch denominator for HuggingFace
    # local embedding client.
    provider: str = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER

    model_kwargs: dict = field(default_factory=dict)
    encode_kwargs: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the parsed values right after dataclass initialization.

        Raises:
            ValueError: If `provider` is not the HuggingFace local provider,
                or if `model` is `None`.
        """
        if self.provider != HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER:
            # The check is on `provider`, so the message must name `provider`
            # (consistent with the OpenAI client config validation).
            message = (
                f"Provider must be set to '{HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}'."
            )
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "HuggingFaceLocalEmbeddingClientConfig":
        """Initializes a dataclass from the passed config.

        Deprecation warnings are raised for deprecated keys and values, and
        aliases are resolved on a copy of `config`. The known keys are then
        popped from that copy. Keys not consumed here are ignored.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys or contains invalid
                values.

        Returns:
            HuggingFaceLocalEmbeddingClientConfig
        """
        # Check for usage of deprecated switching key and value:
        # 1. type: huggingface
        # 2. _type: huggingface
        _raise_deprecation_warning_for_huggingface_deprecated_switch_value(config)
        # Check for other deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # Optional
            multi_process=config.pop(HUGGINGFACE_MULTIPROCESS_CONFIG_KEY, False),
            cache_folder=config.pop(
                HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
                str(HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER),
            ),
            show_progress=config.pop(HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY, False),
            model_kwargs=config.pop(HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY, {}),
            encode_kwargs=config.pop(HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY, {}),
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        return asdict(self)

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of `config` with deprecated keys/values resolved.

        The deprecated HuggingFace switch (`type`/`_type: huggingface`) is
        translated first. The generic key aliases are then resolved.
        """
        config = _resolve_huggingface_deprecated_switch_value(config)
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def is_huggingface_local_config(config: dict) -> bool:
    """Return whether `config` selects the local HuggingFace embedding client.

    The deprecated switch settings `type: huggingface` and `_type: huggingface`,
    along with the other key aliases, are resolved on a copy first. They are
    therefore treated exactly like `provider: huggingface_local`. The passed
    config itself is not modified.
    """
    resolved = HuggingFaceLocalEmbeddingClientConfig.resolve_config_aliases(config)
    selected_provider = resolved.get(PROVIDER_CONFIG_KEY)
    return selected_provider == HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _raise_deprecation_warning_for_huggingface_deprecated_switch_value(
    config: dict,
) -> None:
    """Warn about the deprecated `type`/`_type: huggingface` switch.

    `raise_deprecation_warning` is called once for each deprecated switch key
    (`type`, then `_type`) whose value in `config` is `huggingface`. The
    config is only read, never modified.
    """
    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        if switch_key not in config:
            continue
        if config[switch_key] != DEPRECATED_HUGGINGFACE_TYPE:
            continue
        raise_deprecation_warning(
            message=(
                "Configuration "
                f"`{switch_key}: {DEPRECATED_HUGGINGFACE_TYPE}` "
                "is deprecated and will be removed in 4.0.0. "
                "Please use "
                f"`{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}` "
                "instead."
            )
        )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _resolve_huggingface_deprecated_switch_value(config: dict) -> dict:
    """Translate the deprecated HuggingFace switch into the `provider` key.

    The settings `type: huggingface` and `_type: huggingface` are deprecated
    in favor of `provider: huggingface_local`. Each matching deprecated key
    is removed, `provider` is set to the HuggingFace local provider, and a
    debug message is logged.

    Args:
        config: The configuration to process. It is not modified.

    Returns:
        A shallow copy of `config` with the switch mechanism resolved.
    """
    resolved = config.copy()

    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        uses_deprecated_switch = (
            switch_key in resolved
            and resolved[switch_key] == DEPRECATED_HUGGINGFACE_TYPE
        )
        if not uses_deprecated_switch:
            continue

        # Replace the deprecated switch with the new one and drop the old key.
        resolved[PROVIDER_CONFIG_KEY] = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
        del resolved[switch_key]

        structlogger.debug(
            "HuggingFaceLocalEmbeddingClientConfig"
            "._resolve_huggingface_deprecated_switch_value",
            message=(
                "Switching "
                f"`{switch_key}: {DEPRECATED_HUGGINGFACE_TYPE}` "
                f"to `{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}`."
            ),
            new_config=resolved,
        )

    return resolved
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
+
API_TYPE_CONFIG_KEY,
|
|
13
|
+
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
+
API_VERSION_CONFIG_KEY,
|
|
15
|
+
RASA_TYPE_CONFIG_KEY,
|
|
16
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
+
STREAM_CONFIG_KEY,
|
|
18
|
+
N_REPHRASES_CONFIG_KEY,
|
|
19
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
20
|
+
TIMEOUT_CONFIG_KEY,
|
|
21
|
+
PROVIDER_CONFIG_KEY,
|
|
22
|
+
OPENAI_PROVIDER,
|
|
23
|
+
OPENAI_API_TYPE,
|
|
24
|
+
)
|
|
25
|
+
from rasa.shared.providers._configs.utils import (
|
|
26
|
+
resolve_aliases,
|
|
27
|
+
validate_required_keys,
|
|
28
|
+
raise_deprecation_warnings,
|
|
29
|
+
validate_forbidden_keys,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
structlogger = structlog.get_logger()


# Deprecated configuration keys (aliases) mapped to the standard keys that
# replace them.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # API type aliases
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
    # API base aliases
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
    # API version aliases
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [MODEL_CONFIG_KEY]

# Keys that must not be present (checked with `validate_forbidden_keys`).
FORBIDDEN_KEYS = [
    STREAM_CONFIG_KEY,
    N_REPHRASES_CONFIG_KEY,
]


@dataclass
class OpenAIClientConfig:
    """Parses configuration for OpenAI client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `api_type` has a value different from `openai`.
            - If `provider` has a value different from `openai`.
            - If `model` is `None`.
    """

    model: str
    api_base: Optional[str]
    api_version: Optional[str]

    # API Type is not actually used by LiteLLM backend, but we define
    # it here for backward compatibility.
    api_type: str = OPENAI_API_TYPE

    # Provider is not used by LiteLLM backend, but we define
    # it here since it's used as switch between different
    # clients
    provider: str = OPENAI_PROVIDER

    # All keys not consumed by `from_dict` (e.g. model parameters, timeout).
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the parsed values right after dataclass initialization.

        Raises:
            ValueError: If `api_type` or `provider` is not the OpenAI value,
                or if `model` is `None`.
        """
        # In case of OpenAI hosting, it doesn't make sense
        # for API type to be anything other than 'openai'
        if self.api_type != OPENAI_API_TYPE:
            message = f"API type must be set to '{OPENAI_API_TYPE}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)
        if self.provider != OPENAI_PROVIDER:
            message = f"Provider must be set to '{OPENAI_PROVIDER}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "OpenAIClientConfig":
        """Initializes a dataclass from the passed config.

        Deprecation warnings are raised for deprecated keys, and aliases are
        resolved on a copy of `config`. Required and forbidden keys are then
        validated. Every key that is not a known field ends up in
        `extra_parameters`.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys or contains invalid
                values. Forbidden keys are rejected by `validate_forbidden_keys`.

        Returns:
            OpenAIClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            # Pop the 'provider' key. Currently, it's *optional* because of
            # backward compatibility with older versions.
            provider=config.pop(PROVIDER_CONFIG_KEY, OPENAI_PROVIDER),
            # Optional parameters
            api_base=config.pop(API_BASE_CONFIG_KEY, None),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_API_TYPE),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters (this also includes timeout).
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The entries of `extra_parameters` are flattened into the top level
        instead of being nested under an `extra_parameters` key.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of `config` with deprecated key aliases resolved."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def is_openai_config(config: dict) -> bool:
    """Return whether `config` is meant to configure an OpenAI client.

    Deprecated key aliases are resolved on a copy before the `provider`
    value is inspected, so the passed config is left untouched.
    """
    resolved = OpenAIClientConfig.resolve_config_aliases(config)
    return resolved.get(PROVIDER_CONFIG_KEY) == OPENAI_PROVIDER
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
+
API_TYPE_CONFIG_KEY,
|
|
13
|
+
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
+
API_VERSION_CONFIG_KEY,
|
|
15
|
+
RASA_TYPE_CONFIG_KEY,
|
|
16
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
+
STREAM_CONFIG_KEY,
|
|
18
|
+
N_REPHRASES_CONFIG_KEY,
|
|
19
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
20
|
+
TIMEOUT_CONFIG_KEY,
|
|
21
|
+
PROVIDER_CONFIG_KEY,
|
|
22
|
+
OPENAI_PROVIDER,
|
|
23
|
+
SELF_HOSTED_PROVIDER,
|
|
24
|
+
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
|
|
25
|
+
)
|
|
26
|
+
from rasa.shared.providers._configs.utils import (
|
|
27
|
+
raise_deprecation_warnings,
|
|
28
|
+
resolve_aliases,
|
|
29
|
+
validate_forbidden_keys,
|
|
30
|
+
validate_required_keys,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
structlogger = structlog.get_logger()


# Deprecated configuration keys (aliases) mapped to the standard keys that
# replace them.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # API type aliases
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
    # API base aliases
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
    # API version aliases
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [API_BASE_CONFIG_KEY, MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]

# Keys that must not be present (checked with `validate_forbidden_keys`).
FORBIDDEN_KEYS = [
    STREAM_CONFIG_KEY,
    N_REPHRASES_CONFIG_KEY,
]


@dataclass
class SelfHostedLLMClientConfig:
    """Parses configuration for Self Hosted LiteLLM client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `model`, `provider` or `api_base` is `None`.
            - If `api_type` is anything other than `openai` (only
              OpenAI-compatible endpoints are supported).
    """

    model: str
    provider: str
    api_base: str
    api_version: Optional[str] = None
    api_type: Optional[str] = OPENAI_PROVIDER
    use_chat_completions_endpoint: Optional[bool] = True
    # All keys not consumed by `from_dict` (e.g. model parameters).
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the parsed values right after dataclass initialization.

        Raises:
            ValueError: If `model`, `provider` or `api_base` is `None`, or if
                `api_type` is not the OpenAI value.
        """
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)
        if self.provider is None:
            message = "Provider cannot be set to None."
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.api_base is None:
            message = "API base cannot be set to None."
            # Log the offending field (previously logged `provider` by mistake).
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                api_base=self.api_base,
            )
            raise ValueError(message)
        if self.api_type != OPENAI_PROVIDER:
            message = (
                f"Currently supports only {OPENAI_PROVIDER} endpoints. "
                f"API type must be set to '{OPENAI_PROVIDER}'."
            )
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "SelfHostedLLMClientConfig":
        """Initializes a dataclass from the passed config.

        Deprecation warnings are raised for deprecated keys, and aliases are
        resolved on a copy of `config`. Required and forbidden keys are then
        validated. Every key that is not a known field ends up in
        `extra_parameters`.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys or contains invalid
                values. Forbidden keys are rejected by `validate_forbidden_keys`.

        Returns:
            SelfHostedLLMClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            api_base=config.pop(API_BASE_CONFIG_KEY),
            # Optional parameters
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_PROVIDER),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            use_chat_completions_endpoint=config.pop(
                USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY, True
            ),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The entries of `extra_parameters` are flattened into the top level
        instead of being nested under an `extra_parameters` key.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of `config` with deprecated key aliases resolved."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def is_self_hosted_config(config: dict) -> bool:
    """Check whether the configuration is meant to configure a self-hosted client.

    Deprecated alias keys are resolved to their standard keys before the check.

    Args:
        config: The client configuration to inspect.

    Returns:
        True if the alias-resolved configuration has its provider key set to
        the self-hosted provider, False otherwise.
    """
    # Process the config to handle all the aliases
    config = SelfHostedLLMClientConfig.resolve_config_aliases(config)
    # Case: Configuration contains `provider: self-hosted`
    return config.get(PROVIDER_CONFIG_KEY) == SELF_HOSTED_PROVIDER
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import structlog
|
|
2
|
+
from rasa.shared.utils.io import raise_deprecation_warning
|
|
3
|
+
|
|
4
|
+
# Module-level structured logger. The validation helpers in this module use it
# to log an error event before raising ValueError.
structlogger = structlog.get_logger()
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
    """
    Return a copy of ``config`` in which deprecated alias keys are renamed.

    Args:
        config: Dictionary containing the configuration. It is not modified.
        deprecated_alias_mapping: Dictionary mapping aliases to
            their standard keys.

    Returns:
        New (shallow-copied) dictionary containing the processed configuration.
        If both an alias and its standard key are present, the alias value
        wins.
    """
    resolved = config.copy()

    for deprecated_key, canonical_key in deprecated_alias_mapping.items():
        # Membership is re-checked on every iteration (not precomputed) because
        # an earlier rename may introduce a key that is itself an alias.
        if deprecated_key not in resolved:
            continue
        # The standard key is always part of the default component
        # configuration, so an explicitly given alias must overwrite it.
        resolved[canonical_key] = resolved.pop(deprecated_key)

    return resolved
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def raise_deprecation_warnings(config: dict, deprecated_alias_mapping: dict) -> None:
    """
    Report every deprecated alias key that appears in the configuration.

    Args:
        config: Dictionary containing the configuration (only read).
        deprecated_alias_mapping: Dictionary mapping deprecated keys to
            their standard keys.

    Warns:
        One call to ``raise_deprecation_warning`` per deprecated key found in
        the config, in the iteration order of ``deprecated_alias_mapping``.
        The warning category is determined by that helper.
    """
    used_aliases = (
        (alias, standard_key)
        for alias, standard_key in deprecated_alias_mapping.items()
        if alias in config
    )
    for alias, standard_key in used_aliases:
        warning_text = (
            f"'{alias}' is deprecated and will be removed in "
            f"4.0.0. Use '{standard_key}' instead."
        )
        raise_deprecation_warning(message=warning_text)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def validate_required_keys(config: dict, required_keys: list) -> None:
    """
    Ensure that every key listed in ``required_keys`` is present in ``config``.

    Args:
        config: Dictionary containing the configuration.
        required_keys: List of keys that must be present in the config.

    Raises:
        ValueError: If at least one required key is absent. The problem is
            logged as an error event before the exception is raised.
    """
    absent_keys = [
        required_key for required_key in required_keys if required_key not in config
    ]
    if not absent_keys:
        return

    error_message = f"Missing required keys '{absent_keys}' for configuration."
    structlogger.error(
        "validate_required_keys",
        message=error_message,
        missing_keys=absent_keys,
        config=config,
    )
    raise ValueError(error_message)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
    """
    Validates that the passed config doesn't contain any forbidden keys.

    Args:
        config: Dictionary containing the configuration.
        forbidden_keys: List of keys that are forbidden in the config.

    Raises:
        ValueError: If any forbidden key is present. The offending keys are
            logged as an error event before the exception is raised.
    """
    # `dict.keys()` views support set operations directly; the result is a set.
    forbidden_keys_in_config = config.keys() & set(forbidden_keys)

    if forbidden_keys_in_config:
        message = (
            f"Forbidden keys '{forbidden_keys_in_config}' present "
            "in the configuration."
        )
        structlogger.error(
            "validate_forbidden_keys",
            message=message,
            forbidden_keys=forbidden_keys_in_config,
            config=config,
        )
        raise ValueError(message)
|