rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +2 -0
- rasa/cli/studio/studio.py +18 -8
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +17 -0
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +13 -2
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/voice_channel.py +13 -1
- rasa/core/nlg/contextual_response_rephraser.py +18 -10
- rasa/core/policies/enterprise_search_policy.py +27 -67
- rasa/core/policies/intentless_policy.py +25 -67
- rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +24 -21
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/engine/validation.py +57 -41
- rasa/model_service.py +3 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
- rasa/shared/providers/llm/_base_litellm_client.py +6 -1
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/llm.py +5 -2
- rasa/shared/utils/yaml.py +102 -62
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +21 -10
- rasa/telemetry.py +12 -0
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +20 -0
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +8 -16
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +3 -2
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +51 -47
- rasa/shared/utils/health_check.py +0 -533
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional, Dict, Any
|
|
3
|
+
|
|
4
|
+
from rasa.shared.constants import (
|
|
5
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
6
|
+
MODELS_CONFIG_KEY,
|
|
7
|
+
LLM_API_HEALTH_CHECK_DEFAULT_VALUE,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.exceptions import ProviderClientValidationError
|
|
10
|
+
from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
|
|
11
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
12
|
+
from rasa.shared.utils.cli import print_error_and_exit
|
|
13
|
+
from rasa.shared.utils.llm import llm_factory, structlogger, embedder_factory
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def try_instantiate_llm_client(
    custom_llm_config: Optional[Dict],
    default_llm_config: Optional[Dict],
    log_source_function: str,
    log_source_component: str,
) -> LLMClient:
    """Instantiate an LLM client from the given config, exiting on failure.

    Args:
        custom_llm_config: Component-specific LLM config, if any.
        default_llm_config: Fallback LLM config used by the factory.
        log_source_function: Identifier of the calling function, used as the
            log event prefix.
        log_source_component: Human-readable component name used in the
            error message shown to the user.

    Returns:
        The instantiated LLM client.
    """
    try:
        return llm_factory(custom_llm_config, default_llm_config)
    except (ProviderClientValidationError, ValueError) as instantiation_error:
        # Log the structured event first, then abort with a user-facing message.
        structlogger.error(
            f"{log_source_function}.llm_instantiation_failed",
            message="Unable to instantiate LLM client.",
            error=instantiation_error,
        )
        failure_message = (
            f"Unable to create the LLM client for component - {log_source_component}. "
            "Please make sure you specified the required environment variables "
            "and configuration keys. "
            f"Error: {instantiation_error}"
        )
        print_error_and_exit(failure_message)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def try_instantiate_embedder(
    custom_embeddings_config: Optional[Dict],
    default_embeddings_config: Optional[Dict],
    log_source_function: str,
    log_source_component: str,
) -> EmbeddingClient:
    """Instantiate an embedding client from the given config, exiting on failure.

    Args:
        custom_embeddings_config: Component-specific embeddings config, if any.
        default_embeddings_config: Fallback embeddings config used by the factory.
        log_source_function: Identifier of the calling function, used as the
            log event prefix.
        log_source_component: Human-readable component name used in the
            error message shown to the user.

    Returns:
        The instantiated embedding client.
    """
    try:
        return embedder_factory(custom_embeddings_config, default_embeddings_config)
    except (ProviderClientValidationError, ValueError) as instantiation_error:
        # Log the structured event first, then abort with a user-facing message.
        structlogger.error(
            f"{log_source_function}.embedder_instantiation_failed",
            message="Unable to instantiate Embedding client.",
            error=instantiation_error,
        )
        failure_message = (
            "Unable to create the Embedding client for component - "
            f"{log_source_component}. Please make sure you specified the required "
            f"environment variables and configuration keys. Error: {instantiation_error}"
        )
        print_error_and_exit(failure_message)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def perform_llm_health_check(
    custom_config: Optional[Dict[str, Any]],
    default_config: Dict[str, Any],
    log_source_function: str,
    log_source_component: str,
) -> None:
    """Try to instantiate the LLM Client to validate the provided config.

    If the LLM_API_HEALTH_CHECK environment variable is true, perform a test call
    to the LLM API. If config contains multiple models, perform a test call for each
    model in the model group.

    This method supports both single model configurations and model group
    configurations (configs that have the `models` key).
    """
    # Instantiating the client validates the config regardless of whether the
    # API health check itself is enabled.
    llm_client = try_instantiate_llm_client(
        custom_config, default_config, log_source_function, log_source_component
    )

    if not is_api_health_check_enabled():
        # Guard clause: only warn when the health check is disabled.
        structlogger.warning(
            f"{log_source_function}.perform_llm_health_check.disabled",
            event_info=(
                f"The {LLM_API_HEALTH_CHECK_ENV_VAR} environment variable is set "
                "to false, which will disable LLM health check. "
                "It is recommended to set this variable to true in production "
                "environments."
            ),
        )
        return None

    uses_model_group = bool(
        custom_config
        and MODELS_CONFIG_KEY in custom_config
        and len(custom_config[MODELS_CONFIG_KEY]) > 1
    )

    if uses_model_group:
        # The config uses a router: instantiate a client per model in the group
        # so that a test api call can be issued for each of them.
        # The Router LLM client itself is deliberately not used here, since the
        # goal is a test call per model rather than load balancing.
        for model_config in custom_config[MODELS_CONFIG_KEY]:
            per_model_client = try_instantiate_llm_client(
                model_config,
                default_config,
                log_source_function,
                log_source_component,
            )
            send_test_llm_api_request(
                per_model_client, log_source_function, log_source_component
            )
    else:
        # Single-model configs (from the config file, or a model group config
        # from the endpoint config without a router) are handled here with one
        # test api call.
        send_test_llm_api_request(
            llm_client, log_source_function, log_source_component
        )
    return None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def perform_embeddings_health_check(
    custom_config: Optional[Dict[str, Any]],
    default_config: Dict[str, Any],
    log_source_function: str,
    log_source_component: str,
) -> None:
    """Try to instantiate the Embedder to validate the provided config.

    If the LLM_API_HEALTH_CHECK environment variable is true, perform a test call
    to the Embeddings API. If config contains multiple models, perform a test call
    for each model in the model group.

    This method supports both single model configurations and model group
    configurations (configs that have the `models` key).
    """
    # Instantiating the embedder validates the config (deprecation warnings and
    # errors are logged by the factory) regardless of whether the API health
    # check itself is enabled.
    embedder = try_instantiate_embedder(
        custom_config, default_config, log_source_function, log_source_component
    )

    if not is_api_health_check_enabled():
        # Guard clause: only warn when the health check is disabled.
        structlogger.warning(
            f"{log_source_function}.perform_embeddings_health_check.disabled",
            event_info=(
                f"The {LLM_API_HEALTH_CHECK_ENV_VAR} environment variable is set "
                "to false, which will disable embeddings API health check. "
                "It is recommended to set this variable to true in production "
                "environments."
            ),
        )
        return None

    uses_model_group = bool(
        custom_config
        and MODELS_CONFIG_KEY in custom_config
        and len(custom_config[MODELS_CONFIG_KEY]) > 1
    )

    if uses_model_group:
        # The config uses a router: instantiate an embedder per model in the
        # group so that a test api call can be issued for every one of them.
        # The Router Embedding client itself is deliberately not used here,
        # since the goal is a test call per model rather than load balancing.
        for model_config in custom_config[MODELS_CONFIG_KEY]:
            per_model_embedder = try_instantiate_embedder(
                model_config,
                default_config,
                log_source_function,
                log_source_component,
            )
            send_test_embeddings_api_request(
                per_model_embedder, log_source_function, log_source_component
            )
    else:
        # Single-model configs (from the config file, or a model group config
        # from the endpoint config without a router) are handled here with one
        # test api call.
        send_test_embeddings_api_request(
            embedder, log_source_function, log_source_component
        )
    return None
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def send_test_llm_api_request(
    llm_client: LLMClient, log_source_function: str, log_source_component: str
) -> None:
    """Sends a test request to the LLM API to perform a health check.

    Exits the process with a user-facing error if the call fails.
    """
    structlogger.info(
        f"{log_source_function}.send_test_llm_api_request",
        event_info=(
            "Sending a test LLM API request for the component - "
            f"{log_source_component}."
        ),
        config=llm_client.config,
    )
    try:
        # A minimal prompt is enough to verify connectivity and credentials.
        llm_client.completion("hello")
    except Exception as call_error:
        structlogger.error(
            f"{log_source_function}.send_test_llm_api_request_failed",
            event_info="Test call to the LLM API failed.",
            error=call_error,
        )
        print_error_and_exit(
            f"Test call to the LLM API failed for component - {log_source_component}. "
            f"Error: {call_error}"
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def send_test_embeddings_api_request(
    embedder: EmbeddingClient, log_source_function: str, log_source_component: str
) -> None:
    """Sends a test request to the Embeddings API to perform a health check.

    Exits the process with a user-facing error if the call fails.
    """
    structlogger.info(
        f"{log_source_function}.send_test_embeddings_api_request",
        event_info=(
            f"Sending a test Embeddings API request for the component - "
            f"{log_source_component}."
        ),
        config=embedder.config,
    )
    try:
        # A minimal input is enough to verify connectivity and credentials.
        embedder.embed(["hello"])
    except Exception as e:
        # Fixed copy-paste bug: the failure event previously reused the LLM
        # event key "send_test_llm_api_request_failed"; it now correctly
        # identifies the embeddings health check.
        structlogger.error(
            f"{log_source_function}.send_test_embeddings_api_request_failed",
            event_info="Test call to the Embeddings API failed.",
            error=e,
        )
        print_error_and_exit(
            f"Test call to the Embeddings API failed for component - "
            f"{log_source_component}. Error: {e}"
        )
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def is_api_health_check_enabled() -> bool:
    """Determines whether the API health check is enabled.

    Reads the LLM_API_HEALTH_CHECK environment variable, falling back to its
    documented default when unset; any case variant of "true" enables the check.

    Returns:
        bool: True if the API health check is enabled, False otherwise.
    """
    raw_value = os.getenv(
        LLM_API_HEALTH_CHECK_ENV_VAR, LLM_API_HEALTH_CHECK_DEFAULT_VALUE
    )
    return raw_value.lower() == "true"
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LLMHealthCheckMixin:
    """Mixin class that provides methods for performing llm health checks during
    training and inference within components.

    This mixin offers static methods that wrap the following health check functions:
    - `perform_llm_health_check`
    """

    @staticmethod
    def perform_llm_health_check(
        custom_llm_config: Optional[Dict[str, Any]],
        default_llm_config: Dict[str, Any],
        log_source_method: str,
        log_source_component: str,
    ) -> None:
        """Wraps the `perform_llm_health_check` function to enable
        tracing and instrumentation.
        """
        # Imported lazily (and aliased, to avoid shadowing this method's name)
        # so the wrapper can be instrumented before the module is loaded.
        from rasa.shared.utils.health_check.health_check import (
            perform_llm_health_check as _perform_llm_health_check,
        )

        _perform_llm_health_check(
            custom_llm_config,
            default_llm_config,
            log_source_method,
            log_source_component,
        )
|
rasa/shared/utils/llm.py
CHANGED
|
@@ -690,14 +690,16 @@ def resolve_model_client_config(
|
|
|
690
690
|
) -> Optional[Dict[str, Any]]:
|
|
691
691
|
"""Resolve the model group in the model config.
|
|
692
692
|
|
|
693
|
-
If the config is pointing to a model group, the corresponding model group
|
|
693
|
+
1. If the config is pointing to a model group, the corresponding model group
|
|
694
694
|
of the endpoints.yml is returned.
|
|
695
|
-
If the config is using the old syntax, e.g. defining the llm
|
|
695
|
+
2. If the config is using the old syntax, e.g. defining the llm
|
|
696
696
|
directly in config.yml, the config is returned as is.
|
|
697
|
+
3. If the config is already resolved, return it as is.
|
|
697
698
|
|
|
698
699
|
Args:
|
|
699
700
|
model_config: The model config to be resolved.
|
|
700
701
|
component_name: The name of the component.
|
|
702
|
+
component_name: The method of the component.
|
|
701
703
|
|
|
702
704
|
Returns:
|
|
703
705
|
The resolved llm config.
|
|
@@ -718,6 +720,7 @@ def resolve_model_client_config(
|
|
|
718
720
|
if model_config is None:
|
|
719
721
|
return None
|
|
720
722
|
|
|
723
|
+
# Config is already resolved or defines a client without model groups
|
|
721
724
|
if MODEL_GROUP_CONFIG_KEY not in model_config:
|
|
722
725
|
return model_config
|
|
723
726
|
|
rasa/shared/utils/yaml.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import datetime
|
|
2
|
+
import io
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
4
5
|
import re
|
|
@@ -8,19 +9,13 @@ from dataclasses import field
|
|
|
8
9
|
from functools import lru_cache
|
|
9
10
|
from io import StringIO
|
|
10
11
|
from pathlib import Path
|
|
11
|
-
from typing import
|
|
12
|
+
from typing import Any, List, Optional, Tuple, Dict, Callable, Union
|
|
12
13
|
|
|
13
14
|
import jsonschema
|
|
14
15
|
from importlib_resources import files
|
|
15
16
|
from packaging import version
|
|
16
17
|
from pykwalify.core import Core
|
|
17
18
|
from pykwalify.errors import SchemaError
|
|
18
|
-
from ruamel import yaml as yaml
|
|
19
|
-
from ruamel.yaml import RoundTripRepresenter, YAMLError
|
|
20
|
-
from ruamel.yaml.comments import CommentedSeq, CommentedMap
|
|
21
|
-
from ruamel.yaml.constructor import DuplicateKeyError, BaseConstructor, ScalarNode
|
|
22
|
-
from ruamel.yaml.loader import SafeLoader
|
|
23
|
-
|
|
24
19
|
from rasa.shared.constants import (
|
|
25
20
|
ASSERTIONS_SCHEMA_EXTENSIONS_FILE,
|
|
26
21
|
ASSERTIONS_SCHEMA_FILE,
|
|
@@ -51,6 +46,11 @@ from rasa.shared.utils.io import (
|
|
|
51
46
|
raise_warning,
|
|
52
47
|
read_json_file,
|
|
53
48
|
)
|
|
49
|
+
from ruamel import yaml as yaml
|
|
50
|
+
from ruamel.yaml import YAML, RoundTripRepresenter, YAMLError
|
|
51
|
+
from ruamel.yaml.comments import CommentedSeq, CommentedMap
|
|
52
|
+
from ruamel.yaml.constructor import DuplicateKeyError, BaseConstructor, ScalarNode
|
|
53
|
+
from ruamel.yaml.loader import SafeLoader
|
|
54
54
|
|
|
55
55
|
logger = logging.getLogger(__name__)
|
|
56
56
|
|
|
@@ -64,8 +64,17 @@ SENSITIVE_DATA = [API_KEY]
|
|
|
64
64
|
|
|
65
65
|
@dataclass
|
|
66
66
|
class PathWithError:
|
|
67
|
+
"""Represents a validation error at a specific location in the YAML content.
|
|
68
|
+
|
|
69
|
+
Attributes:
|
|
70
|
+
message (str): A description of the validation error.
|
|
71
|
+
path (List[str]): Path to the node where the error occurred.
|
|
72
|
+
key (Optional[str]): The specific key associated with the error, if any.
|
|
73
|
+
"""
|
|
74
|
+
|
|
67
75
|
message: str
|
|
68
76
|
path: List[str] = field(default_factory=list)
|
|
77
|
+
key: Optional[str] = None
|
|
69
78
|
|
|
70
79
|
|
|
71
80
|
def fix_yaml_loader() -> None:
|
|
@@ -146,21 +155,72 @@ class YamlValidationException(YamlException, ValueError):
|
|
|
146
155
|
if self.validation_errors:
|
|
147
156
|
unique_errors = {}
|
|
148
157
|
for error in self.validation_errors:
|
|
149
|
-
line_number = self._line_number_for_path(
|
|
158
|
+
line_number = self._line_number_for_path(
|
|
159
|
+
self.content, error.path, error.key
|
|
160
|
+
)
|
|
150
161
|
|
|
151
162
|
if line_number and self.filename:
|
|
152
|
-
|
|
163
|
+
error_location = f" in {self.filename}:{line_number}:\n"
|
|
153
164
|
elif line_number:
|
|
154
|
-
|
|
165
|
+
error_location = f" in Line {line_number}:\n"
|
|
155
166
|
else:
|
|
156
|
-
|
|
167
|
+
error_location = ""
|
|
157
168
|
|
|
158
|
-
|
|
159
|
-
|
|
169
|
+
code_snippet = self._get_code_snippet(line_number)
|
|
170
|
+
error_message = f"{error_location}\n{code_snippet}{error.message}\n"
|
|
171
|
+
unique_errors[error.message] = error_message
|
|
160
172
|
error_msg = "\n".join(unique_errors.values())
|
|
161
173
|
msg += f":\n{error_msg}"
|
|
162
174
|
return msg
|
|
163
175
|
|
|
176
|
+
def _get_code_snippet(
|
|
177
|
+
self,
|
|
178
|
+
error_line: Optional[int],
|
|
179
|
+
context_lines: int = 2,
|
|
180
|
+
) -> str:
|
|
181
|
+
"""Extract code snippet from the YAML lines around the error.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
error_line: Line number where the error occurred (1-based).
|
|
185
|
+
context_lines: Number of context lines before and after the error line.
|
|
186
|
+
Default is 2, balancing context and readability. Adjust as needed.
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
A string containing the code snippet with the error highlighted.
|
|
190
|
+
"""
|
|
191
|
+
yaml_lines = self._get_serialized_yaml_lines()
|
|
192
|
+
if not yaml_lines or error_line is None:
|
|
193
|
+
return ""
|
|
194
|
+
|
|
195
|
+
start = max(error_line - context_lines - 1, 0)
|
|
196
|
+
end = min(error_line + context_lines, len(yaml_lines))
|
|
197
|
+
snippet_lines = yaml_lines[start:end]
|
|
198
|
+
snippet = ""
|
|
199
|
+
for idx, line_content in enumerate(snippet_lines, start=start + 1):
|
|
200
|
+
prefix = ">>> " if idx == error_line else " "
|
|
201
|
+
line_number_str = str(idx)
|
|
202
|
+
snippet += f"{prefix}{line_number_str} | {line_content}\n"
|
|
203
|
+
return snippet
|
|
204
|
+
|
|
205
|
+
def _get_serialized_yaml_lines(self) -> List[str]:
|
|
206
|
+
"""Serialize the content back to YAML and return the lines."""
|
|
207
|
+
yaml_lines = []
|
|
208
|
+
try:
|
|
209
|
+
yaml = YAML()
|
|
210
|
+
yaml.default_flow_style = False
|
|
211
|
+
# Set width to 1000, so we don't break the lines of the original YAML file
|
|
212
|
+
yaml.width = 1000 # type: ignore[assignment]
|
|
213
|
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
|
214
|
+
stream = io.StringIO()
|
|
215
|
+
yaml.dump(self.content, stream)
|
|
216
|
+
serialized_yaml = stream.getvalue()
|
|
217
|
+
yaml_lines = serialized_yaml.splitlines()
|
|
218
|
+
return yaml_lines
|
|
219
|
+
except Exception as exc:
|
|
220
|
+
logger.debug(f"Error serializing YAML content: {exc}")
|
|
221
|
+
|
|
222
|
+
return yaml_lines
|
|
223
|
+
|
|
164
224
|
def _calculate_number_of_lines(
|
|
165
225
|
self,
|
|
166
226
|
current: Union[CommentedSeq, CommentedMap],
|
|
@@ -228,7 +288,9 @@ class YamlValidationException(YamlException, ValueError):
|
|
|
228
288
|
# Return the calculated child offset and True indicating a line number was found
|
|
229
289
|
return child_offset, True
|
|
230
290
|
|
|
231
|
-
def _line_number_for_path(
|
|
291
|
+
def _line_number_for_path(
|
|
292
|
+
self, current: Any, path: List[str], key: Optional[str] = None
|
|
293
|
+
) -> Optional[int]:
|
|
232
294
|
"""Get line number for a yaml path in the current content.
|
|
233
295
|
|
|
234
296
|
Implemented using recursion: algorithm goes down the path navigating to the
|
|
@@ -237,6 +299,7 @@ class YamlValidationException(YamlException, ValueError):
|
|
|
237
299
|
Args:
|
|
238
300
|
current: current content
|
|
239
301
|
path: path to traverse within the content
|
|
302
|
+
key: the key associated with the error, if any
|
|
240
303
|
|
|
241
304
|
Returns:
|
|
242
305
|
the line number of the path in the content.
|
|
@@ -247,6 +310,10 @@ class YamlValidationException(YamlException, ValueError):
|
|
|
247
310
|
this_line = current.lc.line + 1 if hasattr(current, "lc") else None
|
|
248
311
|
|
|
249
312
|
if not path:
|
|
313
|
+
if key and hasattr(current, "lc"):
|
|
314
|
+
if hasattr(current.lc, "data") and key in current.lc.data:
|
|
315
|
+
key_line_no = current.lc.data[key][0] + 1
|
|
316
|
+
return key_line_no
|
|
250
317
|
return this_line
|
|
251
318
|
|
|
252
319
|
head, tail = path[0], path[1:]
|
|
@@ -256,7 +323,7 @@ class YamlValidationException(YamlException, ValueError):
|
|
|
256
323
|
|
|
257
324
|
if head:
|
|
258
325
|
if isinstance(current, dict) and head in current:
|
|
259
|
-
line = self._line_number_for_path(current[head], tail)
|
|
326
|
+
line = self._line_number_for_path(current[head], tail, key)
|
|
260
327
|
if line is None:
|
|
261
328
|
line_offset, found_lc = self._calculate_number_of_lines(
|
|
262
329
|
current, head
|
|
@@ -266,10 +333,13 @@ class YamlValidationException(YamlException, ValueError):
|
|
|
266
333
|
return this_line + line_offset
|
|
267
334
|
return line
|
|
268
335
|
elif isinstance(current, list) and head.isdigit():
|
|
269
|
-
return
|
|
336
|
+
return (
|
|
337
|
+
self._line_number_for_path(current[int(head)], tail, key)
|
|
338
|
+
or this_line
|
|
339
|
+
)
|
|
270
340
|
else:
|
|
271
341
|
return this_line
|
|
272
|
-
return self._line_number_for_path(current, tail) or this_line
|
|
342
|
+
return self._line_number_for_path(current, tail, key) or this_line
|
|
273
343
|
|
|
274
344
|
|
|
275
345
|
def read_schema_file(
|
|
@@ -331,13 +401,26 @@ def validate_yaml_content_using_schema(
|
|
|
331
401
|
try:
|
|
332
402
|
core.validate(raise_exception=True)
|
|
333
403
|
except SchemaError:
|
|
404
|
+
# PyKwalify propagates each validation error up the data hierarchy, resulting
|
|
405
|
+
# in multiple redundant errors for a single issue. To present a clear message
|
|
406
|
+
# about the root cause, we use only the first error.
|
|
407
|
+
error = core.errors[0]
|
|
408
|
+
|
|
409
|
+
# Increment numeric indices by 1 to convert from 0-based to 1-based indexing
|
|
410
|
+
error_message = re.sub(
|
|
411
|
+
r"(/)(\d+)", lambda m: f"/{int(m.group(2)) + 1}", str(error)
|
|
412
|
+
)
|
|
413
|
+
|
|
334
414
|
raise YamlValidationException(
|
|
335
415
|
"Please make sure the file is correct and all "
|
|
336
416
|
"mandatory parameters are specified. Here are the errors "
|
|
337
417
|
"found during validation",
|
|
338
418
|
[
|
|
339
|
-
PathWithError(
|
|
340
|
-
|
|
419
|
+
PathWithError(
|
|
420
|
+
message=error_message,
|
|
421
|
+
path=error.path.removeprefix("/").split("/"),
|
|
422
|
+
key=getattr(error, "key", None),
|
|
423
|
+
)
|
|
341
424
|
],
|
|
342
425
|
content=yaml_content,
|
|
343
426
|
)
|
|
@@ -424,46 +507,6 @@ def validate_raw_yaml_using_schema_file_with_responses(
|
|
|
424
507
|
)
|
|
425
508
|
|
|
426
509
|
|
|
427
|
-
def process_content(content: str) -> str:
|
|
428
|
-
"""Process the content to handle both Windows paths and emojis.
|
|
429
|
-
Windows paths are processed by escaping backslashes but emojis are left untouched.
|
|
430
|
-
|
|
431
|
-
Args:
|
|
432
|
-
content: yaml content to be processed
|
|
433
|
-
"""
|
|
434
|
-
# Detect common Windows path patterns: e.g., C:\ or \\
|
|
435
|
-
UNESCAPED_WINDOWS_PATH_PATTERN = re.compile(
|
|
436
|
-
r"(?<!\w)[a-zA-Z]:(\\[a-zA-Z0-9_ -]+)*(\\)?(?!\\n)"
|
|
437
|
-
)
|
|
438
|
-
ESCAPED_WINDOWS_PATH_PATTERN = re.compile(
|
|
439
|
-
r"(?<!\w)[a-zA-Z]:(\\\\[a-zA-Z0-9_ -]+)+\\\\?(?!\\n)"
|
|
440
|
-
)
|
|
441
|
-
|
|
442
|
-
# Function to escape backslashes in Windows paths but leave other content as is
|
|
443
|
-
def escape_windows_paths(match: re.Match) -> str:
|
|
444
|
-
path = str(match.group(0))
|
|
445
|
-
return path.replace("\\", "\\\\") # Escape backslashes only in Windows paths
|
|
446
|
-
|
|
447
|
-
def unescape_windows_paths(match: re.Match) -> str:
|
|
448
|
-
path = str(match.group(0))
|
|
449
|
-
return path.replace("\\\\", "\\")
|
|
450
|
-
|
|
451
|
-
# First, process Windows paths by escaping backslashes
|
|
452
|
-
content = re.sub(UNESCAPED_WINDOWS_PATH_PATTERN, escape_windows_paths, content)
|
|
453
|
-
|
|
454
|
-
# Ensure proper handling of emojis by decoding Unicode sequences
|
|
455
|
-
content = (
|
|
456
|
-
content.encode("utf-8")
|
|
457
|
-
.decode("raw_unicode_escape")
|
|
458
|
-
.encode("utf-16", "surrogatepass")
|
|
459
|
-
.decode("utf-16")
|
|
460
|
-
)
|
|
461
|
-
|
|
462
|
-
content = re.sub(ESCAPED_WINDOWS_PATH_PATTERN, unescape_windows_paths, content)
|
|
463
|
-
|
|
464
|
-
return content
|
|
465
|
-
|
|
466
|
-
|
|
467
510
|
def read_yaml(
|
|
468
511
|
content: str,
|
|
469
512
|
reader_type: Union[str, List[str]] = "safe",
|
|
@@ -479,9 +522,6 @@ def read_yaml(
|
|
|
479
522
|
Raises:
|
|
480
523
|
ruamel.yaml.parser.ParserError: If there was an error when parsing the YAML.
|
|
481
524
|
"""
|
|
482
|
-
if _is_ascii(content):
|
|
483
|
-
content = process_content(content)
|
|
484
|
-
|
|
485
525
|
custom_constructor = kwargs.get("custom_constructor", None)
|
|
486
526
|
|
|
487
527
|
# Create YAML parser with custom constructor
|
rasa/studio/auth.py
CHANGED
|
@@ -23,12 +23,10 @@ from rasa.studio.results_logger import with_studio_error_handler, StudioResult
|
|
|
23
23
|
class StudioAuth:
|
|
24
24
|
"""Handles the authentication with the Rasa Studio authentication server."""
|
|
25
25
|
|
|
26
|
-
def __init__(
|
|
27
|
-
self,
|
|
28
|
-
studio_config: StudioConfig,
|
|
29
|
-
verify: bool = True,
|
|
30
|
-
) -> None:
|
|
26
|
+
def __init__(self, studio_config: StudioConfig) -> None:
|
|
31
27
|
self.config = studio_config
|
|
28
|
+
verify = not studio_config.disable_verify
|
|
29
|
+
|
|
32
30
|
self.keycloak_openid = KeycloakOpenID(
|
|
33
31
|
server_url=studio_config.authentication_server_url,
|
|
34
32
|
client_id=studio_config.client_id,
|
rasa/studio/config.py
CHANGED
|
@@ -2,13 +2,14 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
-
from typing import Dict, Optional, Text
|
|
5
|
+
from typing import Any, Dict, Optional, Text
|
|
6
6
|
|
|
7
7
|
from rasa.utils.common import read_global_config_value, write_global_config_value
|
|
8
8
|
|
|
9
9
|
from rasa.studio.constants import (
|
|
10
10
|
RASA_STUDIO_AUTH_SERVER_URL_ENV,
|
|
11
11
|
RASA_STUDIO_CLI_CLIENT_ID_KEY_ENV,
|
|
12
|
+
RASA_STUDIO_CLI_DISABLE_VERIFY_KEY_ENV,
|
|
12
13
|
RASA_STUDIO_CLI_REALM_NAME_KEY_ENV,
|
|
13
14
|
RASA_STUDIO_CLI_STUDIO_URL_ENV,
|
|
14
15
|
STUDIO_CONFIG_KEY,
|
|
@@ -19,6 +20,7 @@ STUDIO_URL_KEY = "studio_url"
|
|
|
19
20
|
CLIENT_ID_KEY = "client_id"
|
|
20
21
|
REALM_NAME_KEY = "realm_name"
|
|
21
22
|
CLIENT_SECRET_KEY = "client_secret"
|
|
23
|
+
DISABLE_VERIFY = "disable_verify"
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
@dataclass
|
|
@@ -27,13 +29,15 @@ class StudioConfig:
|
|
|
27
29
|
studio_url: Optional[Text]
|
|
28
30
|
client_id: Optional[Text]
|
|
29
31
|
realm_name: Optional[Text]
|
|
32
|
+
disable_verify: bool = False
|
|
30
33
|
|
|
31
|
-
def to_dict(self) -> Dict[Text, Optional[
|
|
34
|
+
def to_dict(self) -> Dict[Text, Optional[Any]]:
|
|
32
35
|
return {
|
|
33
36
|
AUTH_SERVER_URL_KEY: self.authentication_server_url,
|
|
34
37
|
STUDIO_URL_KEY: self.studio_url,
|
|
35
38
|
CLIENT_ID_KEY: self.client_id,
|
|
36
39
|
REALM_NAME_KEY: self.realm_name,
|
|
40
|
+
DISABLE_VERIFY: self.disable_verify,
|
|
37
41
|
}
|
|
38
42
|
|
|
39
43
|
@classmethod
|
|
@@ -43,6 +47,7 @@ class StudioConfig:
|
|
|
43
47
|
studio_url=data[STUDIO_URL_KEY],
|
|
44
48
|
client_id=data[CLIENT_ID_KEY],
|
|
45
49
|
realm_name=data[REALM_NAME_KEY],
|
|
50
|
+
disable_verify=data.get(DISABLE_VERIFY, False),
|
|
46
51
|
)
|
|
47
52
|
|
|
48
53
|
def write_config(self) -> None:
|
|
@@ -73,7 +78,7 @@ class StudioConfig:
|
|
|
73
78
|
config = read_global_config_value(STUDIO_CONFIG_KEY, unavailable_ok=True)
|
|
74
79
|
|
|
75
80
|
if config is None:
|
|
76
|
-
return StudioConfig(None, None, None, None)
|
|
81
|
+
return StudioConfig(None, None, None, None, False)
|
|
77
82
|
|
|
78
83
|
if not isinstance(config, dict):
|
|
79
84
|
raise ValueError(
|
|
@@ -83,7 +88,7 @@ class StudioConfig:
|
|
|
83
88
|
)
|
|
84
89
|
|
|
85
90
|
for key in config:
|
|
86
|
-
if not isinstance(config[key], str):
|
|
91
|
+
if not isinstance(config[key], str) and key != DISABLE_VERIFY:
|
|
87
92
|
raise ValueError(
|
|
88
93
|
"Invalid config file format. "
|
|
89
94
|
f"Key '{key}' is not a text value."
|
|
@@ -102,6 +107,9 @@ class StudioConfig:
|
|
|
102
107
|
studio_url=StudioConfig._read_env_value(RASA_STUDIO_CLI_STUDIO_URL_ENV),
|
|
103
108
|
client_id=StudioConfig._read_env_value(RASA_STUDIO_CLI_CLIENT_ID_KEY_ENV),
|
|
104
109
|
realm_name=StudioConfig._read_env_value(RASA_STUDIO_CLI_REALM_NAME_KEY_ENV),
|
|
110
|
+
disable_verify=bool(
|
|
111
|
+
os.getenv(RASA_STUDIO_CLI_DISABLE_VERIFY_KEY_ENV, False)
|
|
112
|
+
),
|
|
105
113
|
)
|
|
106
114
|
|
|
107
115
|
@staticmethod
|
|
@@ -124,4 +132,5 @@ class StudioConfig:
|
|
|
124
132
|
studio_url=self.studio_url or other.studio_url,
|
|
125
133
|
client_id=self.client_id or other.client_id,
|
|
126
134
|
realm_name=self.realm_name or other.realm_name,
|
|
135
|
+
disable_verify=self.disable_verify or other.disable_verify,
|
|
127
136
|
)
|