rasa-pro 3.9.17__py3-none-any.whl → 3.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- README.md +5 -37
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +26 -22
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +108 -433
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +395 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +14 -13
- rasa/cli/scaffold.py +10 -8
- rasa/cli/train.py +8 -7
- rasa/cli/utils.py +15 -0
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +100 -44
- rasa/core/policies/flows/flow_executor.py +155 -98
- rasa/core/policies/intentless_policy.py +52 -28
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +15 -46
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/utils.py +45 -56
- rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
- rasa/dialogue_understanding/commands/utils.py +38 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +13 -14
- rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
- rasa/e2e_test/assertions.py +1181 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +498 -73
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +596 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/storage/local_model_storage.py +0 -1
- rasa/env.py +9 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +48 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +129 -32
- rasa/server.py +45 -10
- rasa/shared/constants.py +63 -15
- rasa/shared/core/domain.py +15 -12
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +28 -10
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +60 -11
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +227 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -3
- rasa/shared/utils/llm.py +223 -200
- rasa/shared/utils/yaml.py +122 -7
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +2 -3
- rasa/telemetry.py +113 -58
- rasa/tracing/config.py +2 -3
- rasa/tracing/instrumentation/attribute_extractors.py +29 -17
- rasa/tracing/instrumentation/instrumentation.py +4 -47
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/io.py +66 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +195 -1
- rasa/version.py +1 -1
- {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +25 -51
- {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +183 -119
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa/utils/tensorflow/feature_array.py +0 -366
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
rasa/shared/utils/llm.py
CHANGED

@@ -1,46 +1,60 @@
-import os
-import warnings
-
-
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Text,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+    cast,
+)
+import json
 import structlog
 
 import rasa.shared.utils.io
 from rasa.shared.constants import (
     RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
     RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
-
-    OPENAI_API_VERSION_ENV_VAR,
-    OPENAI_API_BASE_ENV_VAR,
-    REQUESTS_CA_BUNDLE_ENV_VAR,
-    OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_VERSION_CONFIG_KEY,
-    OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_TYPE_CONFIG_KEY,
-    OPENAI_API_BASE_CONFIG_KEY,
-    OPENAI_DEPLOYMENT_NAME_CONFIG_KEY,
-    OPENAI_DEPLOYMENT_CONFIG_KEY,
-    OPENAI_ENGINE_CONFIG_KEY,
-    LANGCHAIN_TYPE_CONFIG_KEY,
-    RASA_TYPE_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
 )
 from rasa.shared.core.events import BotUttered, UserUttered
 from rasa.shared.core.slots import Slot, BooleanSlot, CategoricalSlot
-from rasa.shared.engine.caching import get_local_cache_location
+from rasa.shared.engine.caching import (
+    get_local_cache_location,
+)
 from rasa.shared.exceptions import (
     FileIOException,
     FileNotFoundException,
+    ProviderClientValidationError,
+)
+from rasa.shared.providers._configs.azure_openai_client_config import (
+    is_azure_openai_config,
+)
+from rasa.shared.providers._configs.huggingface_local_embedding_client_config import (
+    is_huggingface_local_config,
 )
+from rasa.shared.providers._configs.openai_client_config import is_openai_config
+from rasa.shared.providers._configs.self_hosted_llm_client_config import (
+    is_self_hosted_config,
+)
+from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
+from rasa.shared.providers.llm.llm_client import LLMClient
+from rasa.shared.providers.mappings import (
+    get_llm_client_from_provider,
+    AZURE_OPENAI_PROVIDER,
+    OPENAI_PROVIDER,
+    SELF_HOSTED_PROVIDER,
+    get_embedding_client_from_provider,
+    HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
+    get_client_config_class_from_provider,
+)
+from rasa.shared.utils.cli import print_error_and_exit
 
 if TYPE_CHECKING:
-    from langchain.chat_models import AzureChatOpenAI
-    from langchain.schema.embeddings import Embeddings
-    from langchain.llms.base import BaseLLM
     from rasa.shared.core.trackers import DialogueStateTracker
-    from rasa.shared.providers.openai.clients import (
-        AioHTTPSessionAzureChatOpenAI,
-        AioHTTPSessionOpenAIChat,
-    )
 
 structlogger = structlog.get_logger()
 
@@ -70,6 +84,94 @@ ERROR_PLACEHOLDER = {
     "default": "[User input triggered an error]",
 }
 
+_Factory_F = TypeVar(
+    "_Factory_F",
+    bound=Callable[[Dict[str, Any], Dict[str, Any]], Union[EmbeddingClient, LLMClient]],
+)
+_CombineConfigs_F = TypeVar(
+    "_CombineConfigs_F",
+    bound=Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
+)
+
+
+def _compute_hash_for_cache_from_configs(
+    config_x: Dict[str, Any], config_y: Dict[str, Any]
+) -> int:
+    """Get a unique hash of the default and custom configs."""
+    return hash(
+        json.dumps(config_x, sort_keys=True) + json.dumps(config_y, sort_keys=True)
+    )
+
+
+def _retrieve_from_cache(
+    cache: Dict[int, Any], unique_hash: int, function: Callable, function_kwargs: dict
+) -> Any:
+    """Retrieve the value from the cache if it exists. If it does not exist, cache it"""
+    if unique_hash in cache:
+        return cache[unique_hash]
+    else:
+        return_value = function(**function_kwargs)
+        cache[unique_hash] = return_value
+        return return_value
+
+
+def _cache_factory(function: _Factory_F) -> _Factory_F:
+    """Memoize the factory methods based on the arguments."""
+    cache: Dict[int, Union[EmbeddingClient, LLMClient]] = {}
+
+    @wraps(function)
+    def factory_method_wrapper(
+        config_x: Dict[str, Any], config_y: Dict[str, Any]
+    ) -> Union[EmbeddingClient, LLMClient]:
+        # Get a unique hash of the default and custom configs.
+        unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
+        return _retrieve_from_cache(
+            cache=cache,
+            unique_hash=unique_hash,
+            function=function,
+            function_kwargs={"custom_config": config_x, "default_config": config_y},
+        )
+
+    def clear_cache() -> None:
+        cache.clear()
+        structlogger.debug(
+            "Cleared cache for factory method",
+            function_name=function.__name__,
+        )
+
+    setattr(factory_method_wrapper, "clear_cache", clear_cache)
+    return cast(_Factory_F, factory_method_wrapper)
+
+
+def _cache_combine_custom_and_default_configs(
+    function: _CombineConfigs_F,
+) -> _CombineConfigs_F:
+    """Memoize the combine_custom_and_default_config method based on the arguments."""
+    cache: Dict[int, dict] = {}
+
+    @wraps(function)
+    def combine_configs_wrapper(
+        config_x: Dict[str, Any], config_y: Dict[str, Any]
+    ) -> dict:
+        # Get a unique hash of the default and custom configs.
+        unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
+        return _retrieve_from_cache(
+            cache=cache,
+            unique_hash=unique_hash,
+            function=function,
+            function_kwargs={"custom_config": config_x, "default_config": config_y},
+        )
+
+    def clear_cache() -> None:
+        cache.clear()
+        structlogger.debug(
+            "Cleared cache for combine_custom_and_default_config method",
+            function_name=function.__name__,
+        )
+
+    setattr(combine_configs_wrapper, "clear_cache", clear_cache)
+    return cast(_CombineConfigs_F, combine_configs_wrapper)
+
 
 def tracker_as_readable_transcript(
     tracker: "DialogueStateTracker",
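Note on the block above: _cache_factory and _cache_combine_custom_and_default_configs memoize their wrapped callables on a hash of the JSON-serialized config pair and attach a clear_cache hook to the wrapper via setattr. A minimal, self-contained sketch of the same pattern (illustrative names, not the rasa code itself):

    import json
    from functools import wraps
    from typing import Any, Callable, Dict

    def cache_by_configs(function: Callable) -> Callable:
        """Memoize a two-config callable, keyed on the serialized configs."""
        cache: Dict[int, Any] = {}

        @wraps(function)
        def wrapper(custom: Dict[str, Any], default: Dict[str, Any]) -> Any:
            # Same keying idea as _compute_hash_for_cache_from_configs above;
            # it assumes both configs are JSON-serializable.
            key = hash(
                json.dumps(custom, sort_keys=True)
                + json.dumps(default, sort_keys=True)
            )
            if key not in cache:
                cache[key] = function(custom, default)
            return cache[key]

        wrapper.clear_cache = cache.clear  # mirrors the setattr(...) above
        return wrapper

    @cache_by_configs
    def build_client(custom: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]:
        return {**default, **custom}  # stand-in for real client construction

    a = build_client({"model": "x"}, {"provider": "openai"})
    b = build_client({"model": "x"}, {"provider": "openai"})
    assert a is b               # the second call is served from the cache
    build_client.clear_cache()  # drops all memoized values

Because sort_keys=True normalizes key order, two semantically equal configs hit the same cache entry, while any value change produces a fresh client.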
@@ -138,11 +240,15 @@ def sanitize_message_for_prompt(text: Optional[str]) -> str:
     return text.replace("\n", " ") if text else ""
 
 
+@_cache_combine_custom_and_default_configs
 def combine_custom_and_default_config(
-    custom_config: Optional[Dict[
+    custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
 ) -> Dict[Text, Any]:
     """Merges the given llm config with the default config.
 
+    This method guarantees that the provider is set and all the deprecated keys are
+    resolved. Hence, produces only a valid client config.
+
     Only uses the default configuration arguments, if the type set in the
     custom config matches the type in the default config. Otherwise, only
     the custom config is used.
@@ -155,155 +261,96 @@ def combine_custom_and_default_config(
         The merged config.
     """
     if custom_config is None:
-        return default_config
-
-
-
-
-
-
-
+        return default_config.copy()
+
+    # Get the provider from the custom config.
+    custom_config_provider = get_provider_from_config(custom_config)
+    # We expect the provider to be set in the default configs of all Rasa components.
+    default_config_provider = default_config[PROVIDER_CONFIG_KEY]
+
+    if (
+        custom_config_provider is not None
+        and custom_config_provider != default_config_provider
+    ):
+        # Get the provider-specific config class
+        client_config_clazz = get_client_config_class_from_provider(
+            custom_config_provider
         )
+        # Checks for deprecated keys, resolves aliases and returns a valid config.
+        # This is done to ensure that the custom config is valid.
+        return client_config_clazz.from_dict(custom_config).to_dict()
+
+    # If the provider is the same in both configs
+    # OR provider is not specified in the custom config
+    # perform MERGE by overriding the default config keys and values
+    # with custom config keys and values.
+    merged_config = {**default_config.copy(), **custom_config.copy()}
+    # Check for deprecated keys, resolve aliases and return a valid config.
+    # This is done to ensure that the merged config is valid.
+    default_config_clazz = get_client_config_class_from_provider(
+        default_config_provider
+    )
+    return default_config_clazz.from_dict(merged_config).to_dict()
 
-
-
-
-
-
+
+def get_provider_from_config(config: dict) -> Optional[str]:
+    """Try to get the provider from the passed llm/embeddings configuration.
+    If no provider can be found, return None.
+    """
+    if not config:
+        return None
+    if is_self_hosted_config(config):
+        return SELF_HOSTED_PROVIDER
+    elif is_azure_openai_config(config):
+        return AZURE_OPENAI_PROVIDER
+    elif is_openai_config(config):
+        return OPENAI_PROVIDER
+    elif is_huggingface_local_config(config):
+        return HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
+    else:
+        return config.get(PROVIDER_CONFIG_KEY)
 
 
 def ensure_cache() -> None:
     """Ensures that the cache is initialized."""
-    import langchain
-    from langchain.cache import SQLiteCache
+    import litellm
 
-    #
-    cache_location = get_local_cache_location()
+    # Ensure the cache directory exists
+    cache_location = get_local_cache_location() / "rasa-llm-cache"
     cache_location.mkdir(parents=True, exist_ok=True)
 
-
-
-
-
-def preprocess_config_for_azure(config: Dict[str, Any]) -> Dict[str, Any]:
-    """Preprocesses the config for Azure deployments.
-
-    This function is used to preprocess the config for Azure deployments.
-    AzureChatOpenAI does not expect the _type key, as it is not a defined parameter
-    in the class. So we need to remove it before passing the config to the class.
-    AzureChatOpenAI expects the openai_api_type key to be set instead.
-
-    Args:
-        config: The config to preprocess.
-
-    Returns:
-        The preprocessed config.
-    """
-    config["deployment_name"] = (
-        config.get(OPENAI_DEPLOYMENT_NAME_CONFIG_KEY)
-        or config.get(OPENAI_DEPLOYMENT_CONFIG_KEY)
-        or config.get(OPENAI_ENGINE_CONFIG_KEY)
-    )
-    config["openai_api_base"] = (
-        config.get(OPENAI_API_BASE_CONFIG_KEY)
-        or config.get(OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_BASE_ENV_VAR)
-    )
-    config["openai_api_type"] = (
-        config.get(OPENAI_API_TYPE_CONFIG_KEY)
-        or config.get(OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_TYPE_ENV_VAR)
-    )
-    config["openai_api_version"] = (
-        config.get(OPENAI_API_VERSION_CONFIG_KEY)
-        or config.get(OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_VERSION_ENV_VAR)
-    )
-    for keys in [
-        OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
-        OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
-        OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
-        OPENAI_DEPLOYMENT_CONFIG_KEY,
-        OPENAI_ENGINE_CONFIG_KEY,
-        LANGCHAIN_TYPE_CONFIG_KEY,
-    ]:
-        config.pop(keys, None)
-
-    return config
-
-
-def process_config_for_aiohttp_chat_openai(config: Dict[str, Any]) -> Dict[str, Any]:
-    config = config.copy()
-    config.pop(LANGCHAIN_TYPE_CONFIG_KEY)
-    return config
+    # Set diskcache as a caching option
+    litellm.cache = litellm.Cache(type="disk", disk_cache_dir=cache_location)
 
 
+@_cache_factory
 def llm_factory(
     custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
-) -> Union[
-    "BaseLLM",
-    "AzureChatOpenAI",
-    "AioHTTPSessionAzureChatOpenAI",
-    "AioHTTPSessionOpenAIChat",
-]:
+) -> LLMClient:
     """Creates an LLM from the given config.
 
     Args:
         custom_config: The custom config containing values to overwrite defaults
         default_config: The default config.
 
-
     Returns:
-
+        Instantiated LLM based on the configuration.
     """
-    from langchain.llms.loading import load_llm_from_config
-
-    ensure_cache()
-
     config = combine_custom_and_default_config(custom_config, default_config)
 
-
-    # config in place...
-    structlogger.debug("llmfactory.create.llm", config=config)
-    # langchain issues a user warning when using chat models. at the same time
-    # it doesn't provide a way to instantiate a chat model directly using the
-    # config. so for now, we need to suppress the warning here. Original
-    # warning:
-    # packages/langchain/llms/openai.py:189: UserWarning: You are trying to
-    # use a chat model. This way of initializing it is no longer supported.
-    # Instead, please use: `from langchain.chat_models import ChatOpenAI
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", category=UserWarning)
-        if is_azure_config(config):
-            # Azure deployments are treated differently. This is done as the
-            # GPT-3.5 Turbo newer versions 0613 and 1106 only support the
-            # Chat Completions API.
-            from langchain.chat_models import AzureChatOpenAI
-            from rasa.shared.providers.openai.clients import (
-                AioHTTPSessionAzureChatOpenAI,
-            )
-
-            transformed_config = preprocess_config_for_azure(config.copy())
-            if os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is None:
-                return AzureChatOpenAI(**transformed_config)
-            else:
-                return AioHTTPSessionAzureChatOpenAI(**transformed_config)
-
-    if (
-        os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
-        and config.get(LANGCHAIN_TYPE_CONFIG_KEY) == "openai"
-    ):
-        from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIChat
-
-        config = process_config_for_aiohttp_chat_openai(config)
-        return AioHTTPSessionOpenAIChat(**config.copy())
+    ensure_cache()
 
-
+    client_clazz: Type[LLMClient] = get_llm_client_from_provider(
+        config[PROVIDER_CONFIG_KEY]
+    )
+    client = client_clazz.from_config(config)
+    return client
 
 
+@_cache_factory
 def embedder_factory(
     custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
-) ->
+) -> EmbeddingClient:
     """Creates an Embedder from the given config.
 
     Args:
@@ -312,55 +359,17 @@ def embedder_factory(
 
 
     Returns:
-
+        Instantiated Embedder based on the configuration.
     """
-    from langchain.schema.embeddings import Embeddings
-    from langchain.embeddings import (
-        CohereEmbeddings,
-        HuggingFaceHubEmbeddings,
-        HuggingFaceInstructEmbeddings,
-        HuggingFaceEmbeddings,
-        HuggingFaceBgeEmbeddings,
-        LlamaCppEmbeddings,
-        OpenAIEmbeddings,
-        SpacyEmbeddings,
-        VertexAIEmbeddings,
-    )
-    from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIEmbeddings
-
-    type_to_embedding_cls_dict: Dict[str, Type[Embeddings]] = {
-        "azure": OpenAIEmbeddings,
-        "openai": OpenAIEmbeddings,
-        "openai-aiohttp-session": AioHTTPSessionOpenAIEmbeddings,
-        "cohere": CohereEmbeddings,
-        "spacy": SpacyEmbeddings,
-        "vertexai": VertexAIEmbeddings,
-        "huggingface_instruct": HuggingFaceInstructEmbeddings,
-        "huggingface_hub": HuggingFaceHubEmbeddings,
-        "huggingface_bge": HuggingFaceBgeEmbeddings,
-        "huggingface": HuggingFaceEmbeddings,
-        "llamacpp": LlamaCppEmbeddings,
-    }
-
     config = combine_custom_and_default_config(custom_config, default_config)
-    embedding_type = config.get(LANGCHAIN_TYPE_CONFIG_KEY)
-
-    if (
-        os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
-        and embedding_type is not None
-    ):
-        embedding_type = f"{embedding_type}-aiohttp-session"
 
-
+    ensure_cache()
 
-
-
-
-
-
-        return embeddings_cls(**parameters)
-    else:
-        raise ValueError(f"Unsupported embeddings type '{embedding_type}'")
+    client_clazz: Type[EmbeddingClient] = get_embedding_client_from_provider(
+        config[PROVIDER_CONFIG_KEY]
+    )
+    client = client_clazz.from_config(config)
+    return client
 
 
 def get_prompt_template(
@@ -396,9 +405,23 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
         return None
 
 
-def
-
-
-
-
-
+def try_instantiate_llm_client(
+    custom_llm_config: Optional[Dict],
+    default_llm_config: Optional[Dict],
+    log_source_function: str,
+    log_source_component: str,
+) -> None:
+    """Validate llm configuration."""
+    try:
+        llm_factory(custom_llm_config, default_llm_config)
+    except (ProviderClientValidationError, ValueError) as e:
+        structlogger.error(
+            f"{log_source_function}.llm_instantiation_failed",
+            message="Unable to instantiate LLM client.",
+            error=e,
+        )
+        print_error_and_exit(
+            f"Unable to create the LLM client for component - {log_source_component}. "
+            f"Please make sure you specified the required environment variables. "
+            f"Error: {e}"
+        )
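Taken together, the llm.py changes drop the langchain-specific code paths (AzureChatOpenAI, the aiohttp session clients gated on REQUESTS_CA_BUNDLE, the type_to_embedding_cls_dict lookup) in favor of provider-keyed client classes backed by litellm; SSL verification handling appears to move into its own module (rasa/shared/providers/_ssl_verification_utils.py in the file list above). A hedged usage sketch, assuming rasa-pro 3.10 is installed and OPENAI_API_KEY is set; the model names are illustrative:

    from rasa.shared.utils.llm import combine_custom_and_default_config, llm_factory

    default_config = {"provider": "openai", "model": "gpt-3.5-turbo"}
    custom_config = {"model": "gpt-4"}  # no "provider" key

    # Provider missing or matching: keys are merged over the default and
    # deprecated aliases are resolved into a valid client config.
    merged = combine_custom_and_default_config(custom_config, default_config)

    # A custom config naming a *different* provider would instead be used on
    # its own, bypassing the default config entirely (see the branch above).

    client = llm_factory(custom_config, default_config)  # returns an LLMClient
    # Equal configs on a later call return the same client instance, courtesy
    # of the @_cache_factory memoization.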
rasa/shared/utils/yaml.py
CHANGED

@@ -1,3 +1,4 @@
+import datetime
 import logging
 import os
 import re
@@ -12,15 +13,17 @@ from typing import Dict, List, Optional, Any, Callable, Tuple, Union
 import jsonschema
 from importlib_resources import files
 from packaging import version
-from packaging.version import LegacyVersion
 from pykwalify.core import Core
 from pykwalify.errors import SchemaError
 from ruamel import yaml as yaml
 from ruamel.yaml import RoundTripRepresenter, YAMLError
 from ruamel.yaml.constructor import DuplicateKeyError, BaseConstructor, ScalarNode
 from ruamel.yaml.comments import CommentedSeq, CommentedMap
+from ruamel.yaml.loader import SafeLoader
 
 from rasa.shared.constants import (
+    ASSERTIONS_SCHEMA_EXTENSIONS_FILE,
+    ASSERTIONS_SCHEMA_FILE,
     MODEL_CONFIG_SCHEMA_FILE,
     CONFIG_SCHEMA_FILE,
     DOCS_URL_TRAINING_DATA,
@@ -413,12 +416,17 @@ def validate_raw_yaml_using_schema_file_with_responses(
     )
 
 
-def read_yaml(content: str, reader_type: Union[str, List[str]] = "safe") -> Any:
+def read_yaml(
+    content: str,
+    reader_type: Union[str, List[str]] = "safe",
+    **kwargs: Any,
+) -> Any:
     """Parses yaml from a text.
 
     Args:
         content: A text containing yaml content.
         reader_type: Reader type to use. By default, "safe" will be used.
+        **kwargs: Any
 
     Raises:
         ruamel.yaml.parser.ParserError: If there was an error when parsing the YAML.
@@ -432,11 +440,93 @@ def read_yaml(content: str, reader_type: Union[str, List[str]] = "safe") -> Any:
             .decode("utf-16")
         )
 
+    custom_constructor = kwargs.get("custom_constructor", None)
+
+    # Create YAML parser with custom constructor
+    yaml_parser, reset_constructors = create_yaml_parser(
+        reader_type, custom_constructor
+    )
+    yaml_content = yaml_parser.load(content) or {}
+
+    # Reset to default constructors
+    reset_constructors()
+
+    return yaml_content
+
+
+def create_yaml_parser(
+    reader_type: str,
+    custom_constructor: Optional[Callable] = None,
+) -> Tuple[yaml.YAML, Callable[[], None]]:
+    """Create a YAML parser with an optional custom constructor.
+
+    Args:
+        reader_type (str): The type of the reader
+            (e.g., 'safe', 'rt', 'unsafe').
+        custom_constructor (Optional[Callable]):
+            A custom constructor function for YAML parsing.
+
+    Returns:
+        Tuple[yaml.YAML, Callable[[], None]]: A tuple containing
+            the YAML parser and a function to reset constructors to
+            their original state.
+    """
     yaml_parser = yaml.YAML(typ=reader_type)
     yaml_parser.version = YAML_VERSION  # type: ignore[assignment]
     yaml_parser.preserve_quotes = True  # type: ignore[assignment]
 
-
+    # Save the original constructors
+    original_mapping_constructor = yaml_parser.constructor.yaml_constructors.get(
+        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+    )
+    original_sequence_constructor = yaml_parser.constructor.yaml_constructors.get(
+        yaml.resolver.BaseResolver.DEFAULT_SEQUENCE_TAG
+    )
+
+    if custom_constructor is not None:
+        # Attach the custom constructor to the loader
+        yaml_parser.constructor.add_constructor(
+            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, custom_constructor
+        )
+        yaml_parser.constructor.add_constructor(
+            yaml.resolver.BaseResolver.DEFAULT_SEQUENCE_TAG, custom_constructor
+        )
+
+    def reset_constructors() -> None:
+        """Reset the constructors back to their original state."""
+        yaml_parser.constructor.add_constructor(
+            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, original_mapping_constructor
+        )
+        yaml_parser.constructor.add_constructor(
+            yaml.resolver.BaseResolver.DEFAULT_SEQUENCE_TAG,
+            original_sequence_constructor,
+        )
+
+    def custom_date_constructor(loader: SafeLoader, node: ScalarNode) -> str:
+        """Custom constructor for parsing dates in the format '%Y-%m-%d'.
+
+        This constructor parses dates in the '%Y-%m-%d' format and returns them as
+        strings instead of datetime objects. This change was introduced because the
+        default timestamp constructor in ruamel.yaml returns datetime objects, which
+        caused issues in our use case where the `api_version` in the LLM config must
+        be a string, but was being interpreted as a datetime object.
+        """
+        value = loader.construct_scalar(node)
+        try:
+            # Attempt to parse the date
+            date_obj = datetime.datetime.strptime(value, "%Y-%m-%d").date()
+            # Return the date as a string instead of a datetime object
+            return date_obj.strftime("%Y-%m-%d")
+        except ValueError:
+            # If the date is not in the correct format, return the original value
+            return value
+
+    # Add the custom date constructor
+    yaml_parser.constructor.add_constructor(
+        "tag:yaml.org,2002:timestamp", custom_date_constructor
+    )
+
+    return yaml_parser, reset_constructors
 
 
 def _is_ascii(text: str) -> bool:
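Background for the timestamp constructor above: YAML's core schema resolves a bare scalar such as 2024-02-01 to the timestamp tag, so ruamel.yaml materializes it as a datetime.date, which broke LLM configs where api_version must stay a string. A small sketch of the difference (plain ruamel.yaml shown for contrast):

    from ruamel.yaml import YAML

    doc = "api_version: 2024-02-01\n"

    # Default ruamel.yaml safe loading: the value resolves to datetime.date.
    plain = YAML(typ="safe").load(doc)
    print(type(plain["api_version"]))  # <class 'datetime.date'>

    # rasa's read_yaml now registers the custom timestamp constructor, so the
    # same document yields the string "2024-02-01" instead.
    from rasa.shared.utils.yaml import read_yaml
    print(read_yaml(doc)["api_version"])  # 2024-02-01 (a str)

Quoted values ("2024-02-01") were never affected, since quoted scalars are not resolved as timestamps; the constructor only changes how unquoted date-like scalars are materialized.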
@@ -684,9 +774,6 @@ def validate_training_data_format_version(
         parsed_version = version.parse(version_value)
         latest_version = version.parse(LATEST_TRAINING_DATA_FORMAT_VERSION)
 
-        if isinstance(parsed_version, LegacyVersion):
-            raise TypeError
-
         if parsed_version < latest_version:
             raise_warning(
                 f"Training data file {filename} has a lower "
@@ -702,7 +789,7 @@ def validate_training_data_format_version(
         if latest_version >= parsed_version:
             return True
 
-    except TypeError:
+    except (TypeError, version.InvalidVersion):
         raise_warning(
             f"Training data file {filename} must specify "
             f"'{KEY_TRAINING_DATA_FORMAT_VERSION}' as string, for example:\n"
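The two hunks above track a behavior change in the packaging library: packaging 22.0 removed LegacyVersion, and version.parse now raises InvalidVersion for non-PEP 440 strings instead of returning a LegacyVersion object. A quick illustration (assuming packaging >= 22):

    from packaging import version

    version.parse("3.1")  # Version('3.1')

    try:
        version.parse("not a version")
    except version.InvalidVersion:
        # Previously this returned a LegacyVersion, which the removed
        # isinstance check turned into a TypeError; the except clause
        # now catches InvalidVersion directly.
        pass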
@@ -784,3 +871,31 @@ def validate_yaml_with_jsonschema(
             errors,
             content=source_data,
         )
+
+
+def validate_yaml_data_using_schema_with_assertions(
+    yaml_data: Any,
+    schema_content: Union[List[Any], Dict[str, Any]],
+    package_name: str = PACKAGE_NAME,
+) -> None:
+    """Validate raw yaml content using a schema with assertions sub-schema.
+
+    Args:
+        yaml_data: the parsed yaml data to be validated
+        schema_content: the content of the YAML schema
+        package_name: the name of the package the schema is located in. defaults
+            to `rasa`.
+    """
+    # test case assertions are part of the schema extension
+    # it will be included if the schema explicitly references it with
+    # include: assertions
+    e2e_test_cases_schema_content = read_schema_file(
+        ASSERTIONS_SCHEMA_FILE, package_name
+    )
+
+    schema_content = dict(schema_content, **e2e_test_cases_schema_content)
+    schema_extensions = [
+        str(files(package_name).joinpath(ASSERTIONS_SCHEMA_EXTENSIONS_FILE))
+    ]
+
+    validate_yaml_content_using_schema(yaml_data, schema_content, schema_extensions)