rasa-pro 3.11.0rc3__py3-none-any.whl → 3.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- rasa/__main__.py +9 -3
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/utils.py +1 -1
- rasa/core/channels/development_inspector.py +4 -1
- rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
- rasa/core/channels/voice_stream/asr/azure.py +11 -2
- rasa/core/channels/voice_stream/asr/deepgram.py +4 -3
- rasa/core/channels/voice_stream/tts/azure.py +3 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
- rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/persistor.py +93 -49
- rasa/core/policies/flows/flow_executor.py +18 -8
- rasa/core/processor.py +7 -5
- rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
- rasa/e2e_test/assertions.py +133 -16
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_runner.py +2 -2
- rasa/engine/loader.py +12 -0
- rasa/engine/validation.py +291 -79
- rasa/model_manager/config.py +8 -0
- rasa/model_manager/model_api.py +166 -61
- rasa/model_manager/runner_service.py +31 -26
- rasa/model_manager/trainer_service.py +14 -23
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +3 -5
- rasa/model_training.py +3 -1
- rasa/shared/constants.py +22 -0
- rasa/shared/core/domain.py +8 -5
- rasa/shared/core/flows/yaml_flows_io.py +13 -4
- rasa/shared/importers/importer.py +19 -2
- rasa/shared/importers/rasa.py +5 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/utils/common.py +29 -2
- rasa/shared/utils/health_check/health_check.py +26 -24
- rasa/shared/utils/yaml.py +116 -31
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +119 -57
- rasa/validator.py +40 -4
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +48 -46
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
rasa/model_manager/warm_rasa_process.py
ADDED
@@ -0,0 +1,187 @@
+import shlex
+import subprocess
+from rasa.__main__ import main
+import os
+from typing import List
+import structlog
+from dataclasses import dataclass
+import uuid
+
+from rasa.model_manager import config
+from rasa.model_manager.utils import ensure_base_directory_exists, logs_path
+
+structlogger = structlog.get_logger(__name__)
+
+warm_rasa_processes: List["WarmRasaProcess"] = []
+
+NUMBER_OF_INITIAL_PROCESSES = 3
+
+
+@dataclass
+class WarmRasaProcess:
+    """Data class to store a warm Rasa process.
+
+    A "warm" Rasa process is one where we've done the heavy lifting of
+    importing key modules ahead of time (e.g. litellm). This is to avoid
+    long import times when we actually want to run a command.
+
+    This is a started process waiting for a Rasa CLI command. It's
+    output is stored in a log file identified by `log_id`.
+    """
+
+    process: subprocess.Popen
+    log_id: str
+
+
+def _create_warm_rasa_process() -> WarmRasaProcess:
+    """Create a new warm Rasa process."""
+    command = [
+        config.RASA_PYTHON_PATH,
+        "-m",
+        "rasa.model_manager.warm_rasa_process",
+    ]
+
+    envs = os.environ.copy()
+    envs["RASA_TELEMETRY_ENABLED"] = "false"
+
+    log_id = uuid.uuid4().hex
+    log_path = logs_path(log_id)
+
+    ensure_base_directory_exists(log_path)
+
+    process = subprocess.Popen(
+        command,
+        stdout=open(log_path, "w"),
+        stderr=subprocess.STDOUT,
+        stdin=subprocess.PIPE,
+        env=envs,
+    )
+
+    structlogger.debug(
+        "model_trainer.created_warm_rasa_process",
+        pid=process.pid,
+        command=command,
+        log_path=log_path,
+    )
+
+    return WarmRasaProcess(process=process, log_id=log_id)
+
+
+def initialize_warm_rasa_process() -> None:
+    """Initialize the warm Rasa processes."""
+    global warm_rasa_processes
+    for _ in range(NUMBER_OF_INITIAL_PROCESSES):
+        warm_rasa_processes.append(_create_warm_rasa_process())
+
+
+def shutdown_warm_rasa_processes() -> None:
+    """Shutdown all warm Rasa processes."""
+    global warm_rasa_processes
+    for warm_rasa_process in warm_rasa_processes:
+        warm_rasa_process.process.terminate()
+    warm_rasa_processes = []
+
+
+def start_rasa_process(cwd: str, arguments: List[str]) -> WarmRasaProcess:
+    """Start a Rasa process.
+
+    This will start a Rasa process with the given current working directory
+    and arguments. The process will be a warm one, meaning that it has already
+    imported all necessary modules.
+    """
+    warm_rasa_process = _get_warm_rasa_process()
+    _pass_arguments_to_process(warm_rasa_process.process, cwd, arguments)
+    return warm_rasa_process
+
+
+def _get_warm_rasa_process() -> WarmRasaProcess:
+    """Get a warm Rasa process.
+
+    This will return a warm Rasa process from the pool and create a
+    new one to replace it.
+    """
+    global warm_rasa_processes
+
+    if not warm_rasa_processes:
+        warm_rasa_processes = [_create_warm_rasa_process()]
+
+    previous_warm_rasa_process = warm_rasa_processes.pop(0)
+
+    if previous_warm_rasa_process.process.poll() is not None:
+        # process has finished (for some reason...)
+        # back up plan is to create a new one on the spot.
+        # this should not happen, but let's be safe
+        structlogger.warning(
+            "model_trainer.warm_rasa_process_finished_unexpectedly",
+            pid=previous_warm_rasa_process.process.pid,
+        )
+        previous_warm_rasa_process = _create_warm_rasa_process()
+
+    warm_rasa_processes.append(_create_warm_rasa_process())
+    return previous_warm_rasa_process
+
+
+def _pass_arguments_to_process(
+    process: subprocess.Popen, cwd: str, arguments: List[str]
+) -> None:
+    """Pass arguments to a warm Rasa process.
+
+    The process is waiting for input on stdin. We pass the current working
+    directory and the arguments to run a Rasa CLI command.
+    """
+    arguments_string = " ".join(arguments)
+    # send arguments to stdin
+    process.stdin.write(cwd.encode())  # type: ignore[union-attr]
+    process.stdin.write("\n".encode())  # type: ignore[union-attr]
+    process.stdin.write(arguments_string.encode())  # type: ignore[union-attr]
+    process.stdin.write("\n".encode())  # type: ignore[union-attr]
+    process.stdin.flush()  # type: ignore[union-attr]
+
+
+def warmup() -> None:
+    """Import all necessary modules to warm up the process.
+
+    This should include all the modules that take a long time to import.
+    We import them now, so that the training / deployment can later
+    directly start.
+    """
+    try:
+        import presidio_analyzer  # noqa: F401
+        import litellm  # noqa: F401
+        import langchain  # noqa: F401
+        import tensorflow  # noqa: F401
+        import matplotlib  # noqa: F401
+        import pandas  # noqa: F401
+        import numpy  # noqa: F401
+        import spacy  # noqa: F401
+        import rasa.validator  # noqa: F401
+    except ImportError:
+        pass
+
+
+def warm_rasa_main() -> None:
+    """Entry point for processes waiting for their command to run.
+
+    The process will wait for the current working directory and the command
+    to run. These will be send on stdin by the parent process. After receiving
+    the input, we will kick things of starting or running a bot.
+
+    Uses the normal Rasa CLI entry point (e.g. `rasa train --data ...`).
+    """
+    warmup()
+
+    cwd = input()
+
+    # this should be `train --data ...` or similar
+    cli_arguments_str = input()
+    # splits the arguments string into a list of arguments as expected by `argparse`
+    arguments = shlex.split(cli_arguments_str)
+
+    # needed to make sure the passed arguments are relative to the working directory
+    os.chdir(cwd)
+
+    main(arguments)
+
+
+if __name__ == "__main__":
+    warm_rasa_main()
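For orientation, the new module keeps a small pool of child processes that have already paid the import cost, and hands each one a working directory plus a CLI argument string over stdin before the child calls the regular Rasa entry point. A minimal sketch of how a caller might drive that pool; the project path and CLI arguments below are illustrative, not taken from the diff:

from rasa.model_manager.utils import logs_path
from rasa.model_manager.warm_rasa_process import (
    initialize_warm_rasa_process,
    shutdown_warm_rasa_processes,
    start_rasa_process,
)

# Fill the pool once at service startup so later commands skip the slow imports.
initialize_warm_rasa_process()

# Hand a warm child its working directory and CLI arguments (illustrative values);
# the child reads both lines from stdin, chdirs, and runs `rasa train --data data`.
proc = start_rasa_process(cwd="/tmp/example_bot", arguments=["train", "--data", "data"])

# The child's stdout/stderr are redirected to a log file keyed by its log_id.
print("logs at:", logs_path(proc.log_id))

# Terminate any idle warm processes on shutdown.
shutdown_warm_rasa_processes()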
rasa/model_service.py
CHANGED
@@ -8,7 +8,7 @@ from rasa.core.persistor import RemoteStorageType, get_persistor
 from rasa.core.utils import list_routes
 from rasa.model_manager import model_api
 from rasa.model_manager import config
-from rasa.model_manager.config import SERVER_BASE_URL
+from rasa.model_manager.config import SERVER_BASE_URL, SERVER_PORT
 from rasa.utils.common import configure_logging_and_warnings
 import rasa.utils.licensing
 from urllib.parse import urlparse
@@ -18,8 +18,6 @@ from rasa.utils.sanic_error_handler import register_custom_sanic_error_handler
 
 structlogger = structlog.get_logger()
 
-MODEL_SERVICE_PORT = 8000
-
 
 def url_prefix_from_base_url() -> str:
     """Return the path prefix from the base URL."""
@@ -93,7 +91,7 @@ def main() -> None:
 
     validate_model_storage_type()
 
-    structlogger.debug("model_api.starting_server", port=MODEL_SERVICE_PORT)
+    structlogger.debug("model_api.starting_server", port=SERVER_PORT)
 
     url_prefix = url_prefix_from_base_url()
     # configure the sanic application
@@ -107,7 +105,7 @@ def main() -> None:
 
     register_custom_sanic_error_handler(app)
 
-    app.run(host="0.0.0.0", port=MODEL_SERVICE_PORT, legacy=True, motd=False)
+    app.run(host="0.0.0.0", port=SERVER_PORT, legacy=True, motd=False)
 
 
 if __name__ == "__main__":
rasa/model_training.py
CHANGED
@@ -322,8 +322,10 @@ async def _train_graph(
     rasa.engine.validation.validate_coexistance_routing_setup(
         domain, model_configuration, flows
     )
-    rasa.engine.validation.validate_model_client_configuration_setup(config)
     rasa.engine.validation.validate_model_group_configuration_setup()
+    rasa.engine.validation.validate_model_client_configuration_setup_during_training_time(
+        config
+    )
     rasa.engine.validation.validate_flow_component_dependencies(
         flows, model_configuration
     )
rasa/shared/constants.py
CHANGED
@@ -149,6 +149,10 @@ AZURE_AD_TOKEN_ENV_VAR = "AZURE_AD_TOKEN"
 AZURE_API_BASE_ENV_VAR = "AZURE_API_BASE"
 AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
 AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
+AZURE_SPEECH_API_KEY_ENV_VAR = "AZURE_SPEECH_API_KEY"
+
+DEEPGRAM_API_KEY_ENV_VAR = "DEEPGRAM_API_KEY"
+CARTESIA_API_KEY_ENV_VAR = "CARTESIA_API_KEY"
 
 OPENAI_API_KEY_ENV_VAR = "OPENAI_API_KEY"
 OPENAI_API_TYPE_ENV_VAR = "OPENAI_API_TYPE"
@@ -159,6 +163,9 @@ OPENAI_API_BASE_CONFIG_KEY = "openai_api_base"
 OPENAI_API_TYPE_CONFIG_KEY = "openai_api_type"
 OPENAI_API_VERSION_CONFIG_KEY = "openai_api_version"
 
+AWS_BEDROCK_PROVIDER = "bedrock"
+AWS_SAGEMAKER_PROVIDER = "sagemaker"
+
 API_BASE_CONFIG_KEY = "api_base"
 API_TYPE_CONFIG_KEY = "api_type"
 API_VERSION_CONFIG_KEY = "api_version"
@@ -219,6 +226,14 @@ AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
 AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
 
 AWS_REGION_NAME_CONFIG_KEY = "aws_region_name"
+AWS_ACCESS_KEY_ID_CONFIG_KEY = "aws_access_key_id"
+AWS_SECRET_ACCESS_KEY_CONFIG_KEY = "aws_secret_access_key"
+AWS_SESSION_TOKEN_CONFIG_KEY = "aws_session_token"
+
+AWS_ACCESS_KEY_ID_ENV_VAR = "AWS_ACCESS_KEY_ID"
+AWS_SECRET_ACCESS_KEY_ENV_VAR = "AWS_SECRET_ACCESS_KEY"
+AWS_REGION_NAME_ENV_VAR = "AWS_REGION_NAME"
+AWS_SESSION_TOKEN_ENV_VAR = "AWS_SESSION_TOKEN"
 
 HUGGINGFACE_MULTIPROCESS_CONFIG_KEY = "multi_process"
 HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY = "cache_folder"
@@ -280,3 +295,10 @@ RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT = (
 )
 
 ROUTE_TO_CALM_SLOT = "route_session_to_calm"
+
+SENSITIVE_DATA = [
+    API_KEY,
+    AWS_ACCESS_KEY_ID_CONFIG_KEY,
+    AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+    AWS_SESSION_TOKEN_CONFIG_KEY,
+]
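The new SENSITIVE_DATA list collects config keys whose values should never be echoed back in logs or uploads. The diff does not show where the list is consumed, so the snippet below is only a hypothetical illustration of how such a list could be used to mask secrets in a config dictionary before logging; redact_sensitive_values is not a Rasa helper.

from typing import Any, Dict

from rasa.shared.constants import SENSITIVE_DATA


def redact_sensitive_values(config: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: blank out values of keys listed in SENSITIVE_DATA.
    return {key: "***" if key in SENSITIVE_DATA else value for key, value in config.items()}


# The AWS secret is masked, the non-sensitive region setting is left untouched.
print(redact_sensitive_values({"aws_secret_access_key": "abc123", "aws_region_name": "eu-west-1"}))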
rasa/shared/core/domain.py
CHANGED
@@ -196,6 +196,7 @@ class Domain:
     """
 
     validate_yaml: ClassVar[bool] = True
+    expand_env_vars: ClassVar[bool] = True
 
     @classmethod
     def empty(cls) -> Domain:
@@ -1955,8 +1956,8 @@
         """Check whether the domain is empty."""
         return self.as_dict() == Domain.empty().as_dict()
 
-    @staticmethod
-    def is_domain_file(filename: Union[Text, Path]) -> bool:
+    @classmethod
+    def is_domain_file(cls, filename: Union[Text, Path]) -> bool:
         """Checks whether the given file path is a Rasa domain file.
 
         Args:
@@ -1975,7 +1976,7 @@
             return False
 
         try:
-            content = read_yaml_file(filename)
+            content = read_yaml_file(filename, expand_env_vars=cls.expand_env_vars)
         except (RasaException, YamlSyntaxException):
             structlogger.warning(
                 "domain.cannot_load_domain_file",
@@ -2104,10 +2105,12 @@
             "domain.from_yaml.validating",
         )
         validate_raw_yaml_using_schema_file_with_responses(
-            raw_yaml_content, DOMAIN_SCHEMA_FILE
+            raw_yaml_content,
+            DOMAIN_SCHEMA_FILE,
+            expand_env_vars=cls.expand_env_vars,
         )
 
-        return read_yaml(raw_yaml_content)
+        return read_yaml(raw_yaml_content, expand_env_vars=cls.expand_env_vars)
 
 
 def warn_about_duplicates_found_during_domain_merging(
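Domain, like the other YAML readers touched in this release, now carries a class-level expand_env_vars switch that is passed down to read_yaml, so ${VAR} placeholders can either be substituted from the environment or kept literal. A small sketch of toggling it; the YAML content and environment variable are illustrative:

import os

from rasa.shared.core.domain import Domain

os.environ["GREETING_TEXT"] = "Hello from the environment!"
domain_yaml = """
version: "3.1"
responses:
  utter_greet:
  - text: "${GREETING_TEXT}"
"""

# Default behaviour: the placeholder is expanded while the YAML is read.
Domain.expand_env_vars = True
expanded = Domain.from_yaml(domain_yaml)

# Keep the placeholder literal, e.g. when round-tripping YAML rather than running it.
Domain.expand_env_vars = False
literal = Domain.from_yaml(domain_yaml)
Domain.expand_env_vars = True  # restore the default for later loads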
rasa/shared/core/flows/yaml_flows_io.py
CHANGED
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Text, Union
+from typing import Any, ClassVar, Dict, List, Optional, Text, Union
 
 import jsonschema
 import ruamel.yaml.nodes as yaml_nodes
@@ -25,6 +25,8 @@ KEY_FLOWS = "flows"
 class YAMLFlowsReader:
     """Class that reads flows information in YAML format."""
 
+    expand_env_vars: ClassVar[bool] = True
+
     @classmethod
     def read_from_file(
         cls, filename: Union[Text, Path], add_line_numbers: bool = True
@@ -217,14 +219,21 @@
             `Flow`s read from `string`.
         """
         validate_yaml_with_jsonschema(
-            string, FLOWS_SCHEMA_FILE, humanize_error=cls.humanize_flow_error
+            string,
+            FLOWS_SCHEMA_FILE,
+            humanize_error=cls.humanize_flow_error,
+            expand_env_vars=cls.expand_env_vars,
         )
         if add_line_numbers:
-            yaml_content = read_yaml(string, custom_constructor=line_number_constructor)
+            yaml_content = read_yaml(
+                string,
+                custom_constructor=line_number_constructor,
+                expand_env_vars=cls.expand_env_vars,
+            )
             yaml_content = process_yaml_content(yaml_content)
 
         else:
-            yaml_content = read_yaml(string)
+            yaml_content = read_yaml(string, expand_env_vars=cls.expand_env_vars)
 
         return FlowsList.from_json(yaml_content.get(KEY_FLOWS, {}), file_path=file_path)
 
rasa/shared/importers/importer.py
CHANGED
@@ -1,7 +1,18 @@
 import logging
 from abc import ABC, abstractmethod
 from functools import reduce
-from typing import
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Text,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 
 import importlib_resources
 
@@ -167,6 +178,7 @@ class TrainingDataImporter(ABC):
         domain_path: Optional[Text] = None,
         training_data_paths: Optional[List[Text]] = None,
         args: Optional[Dict[Text, Any]] = None,
+        expand_env_vars: bool = True,
     ) -> "TrainingDataImporter":
         """Loads a `TrainingDataImporter` instance from a dictionary."""
         from rasa.shared.importers.rasa import RasaFileImporter
@@ -182,7 +194,12 @@
         importers = [importer for importer in importers if importer]
         if not importers:
             importers = [
-                RasaFileImporter(config_path, domain_path, training_data_paths)
+                RasaFileImporter(
+                    config_path,
+                    domain_path,
+                    training_data_paths,
+                    expand_env_vars=expand_env_vars,
+                )
             ]
 
         return E2EImporter(
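The hunk above adds an expand_env_vars keyword to the importer's load-from-dict classmethod and threads it into RasaFileImporter. Assuming that classmethod is TrainingDataImporter.load_from_dict, a call that preserves ${VAR} placeholders might look like the following; the paths are illustrative:

from rasa.shared.importers.importer import TrainingDataImporter

# Load project files but keep ${VAR} placeholders in config and domain YAML literal,
# e.g. for tooling that re-serializes the files instead of training on them.
importer = TrainingDataImporter.load_from_dict(
    config_path="config.yml",
    domain_path="domain.yml",
    training_data_paths=["data"],
    expand_env_vars=False,
)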
rasa/shared/importers/rasa.py
CHANGED
@@ -29,7 +29,9 @@ class RasaFileImporter(TrainingDataImporter):
         config_file: Optional[Text] = None,
         domain_path: Optional[Text] = None,
         training_data_paths: Optional[Union[List[Text], Text]] = None,
+        expand_env_vars: bool = True,
     ):
+        self.expand_env_vars = expand_env_vars
         self._domain_path = domain_path
 
         self._nlu_files = rasa.shared.data.get_data_files(
@@ -54,7 +56,9 @@ class RasaFileImporter(TrainingDataImporter):
             logger.debug("No configuration file was provided to the RasaFileImporter.")
             return {}
 
-        config = read_model_configuration(self.config_file)
+        config = read_model_configuration(
+            self.config_file, expand_env_vars=self.expand_env_vars
+        )
         return config
 
     def get_config_file_for_auto_config(self) -> Optional[Text]:
rasa/shared/nlu/training_data/formats/rasa_yaml.py
CHANGED
@@ -1,7 +1,18 @@
 import logging
 from collections import OrderedDict
 from pathlib import Path
-from typing import
+from typing import (
+    ClassVar,
+    Text,
+    Any,
+    List,
+    Dict,
+    Tuple,
+    Union,
+    Iterator,
+    Optional,
+    Callable,
+)
 
 import rasa.shared.data
 from rasa.shared.core.domain import Domain
@@ -55,6 +66,8 @@ STRIP_SYMBOLS = "\n\r "
 class RasaYAMLReader(TrainingDataReader):
     """Reads YAML training data and creates a TrainingData object."""
 
+    expand_env_vars: ClassVar[bool] = True
+
     def __init__(self) -> None:
         super().__init__()
         self.training_examples: List[Message] = []
@@ -69,7 +82,9 @@ class RasaYAMLReader(TrainingDataReader):
        If the string is not in the right format, an exception will be raised.
        """
        try:
-            validate_raw_yaml_using_schema_file_with_responses(string, NLU_SCHEMA_FILE)
+            validate_raw_yaml_using_schema_file_with_responses(
+                string, NLU_SCHEMA_FILE, expand_env_vars=self.expand_env_vars
+            )
        except YamlException as e:
            e.filename = self.filename
            raise e
@@ -88,7 +103,7 @@ class RasaYAMLReader(TrainingDataReader):
        """
        self.validate(string)
 
-        yaml_content = read_yaml(string)
+        yaml_content = read_yaml(string, expand_env_vars=self.expand_env_vars)
 
        if not validate_training_data_format_version(yaml_content, self.filename):
            return TrainingData()
rasa/shared/providers/_utils.py
ADDED
@@ -0,0 +1,79 @@
+import structlog
+
+from rasa.shared.constants import (
+    AWS_ACCESS_KEY_ID_ENV_VAR,
+    AWS_ACCESS_KEY_ID_CONFIG_KEY,
+    AWS_SECRET_ACCESS_KEY_ENV_VAR,
+    AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+    AWS_REGION_NAME_ENV_VAR,
+    AWS_REGION_NAME_CONFIG_KEY,
+    AWS_SESSION_TOKEN_CONFIG_KEY,
+    AWS_SESSION_TOKEN_ENV_VAR,
+)
+from rasa.shared.exceptions import ProviderClientValidationError
+from litellm import validate_environment
+from rasa.shared.providers.embedding._base_litellm_embedding_client import (
+    _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
+)
+
+structlogger = structlog.get_logger()
+
+
+def validate_aws_setup_for_litellm_clients(
+    litellm_model_name: str, litellm_call_kwargs: dict, source_log: str
+) -> None:
+    """Validates the AWS setup for LiteLLM clients to ensure all required
+    environment variables or corresponding call kwargs are set.
+
+    Args:
+        litellm_model_name (str): The name of the LiteLLM model being validated.
+        litellm_call_kwargs (dict): Additional keyword arguments passed to the client,
+            which may include configuration values for AWS credentials.
+        source_log (str): The source log identifier for structured logging.
+
+    Raises:
+        ProviderClientValidationError: If any required AWS environment variable
+            or corresponding configuration key is missing.
+    """
+
+    # Mapping of environment variable names to their corresponding config keys
+    envs_to_args = {
+        AWS_ACCESS_KEY_ID_ENV_VAR: AWS_ACCESS_KEY_ID_CONFIG_KEY,
+        AWS_SECRET_ACCESS_KEY_ENV_VAR: AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
+        AWS_REGION_NAME_ENV_VAR: AWS_REGION_NAME_CONFIG_KEY,
+        AWS_SESSION_TOKEN_ENV_VAR: AWS_SESSION_TOKEN_CONFIG_KEY,
+    }
+
+    # Validate the environment setup for the model
+    validation_info = validate_environment(litellm_model_name)
+    missing_environment_variables = validation_info.get(
+        _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY, []
+    )
+    # Filter out missing environment variables that have been set trough arguments
+    # in extra parameters
+    missing_environment_variables = [
+        missing_env_var
+        for missing_env_var in missing_environment_variables
+        if litellm_call_kwargs.get(envs_to_args.get(missing_env_var)) is None
+    ]
+
+    if missing_environment_variables:
+        missing_environment_details = [
+            (
+                f"'{missing_env_var}' environment variable or "
+                f"'{envs_to_args.get(missing_env_var)}' config key"
+            )
+            for missing_env_var in missing_environment_variables
+        ]
+        event_info = (
+            f"The following environment variables or configuration keys are "
+            f"missing: "
+            f"{', '.join(missing_environment_details)}. "
+            f"These settings are required for API calls."
+        )
+        structlogger.error(
+            f"{source_log}.validate_aws_environment_variables",
+            event_info=event_info,
+            missing_environment_variables=missing_environment_variables,
+        )
+        raise ProviderClientValidationError(event_info)
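validate_aws_setup_for_litellm_clients treats an AWS credential as present if it arrives either as an AWS_* environment variable or as the matching LiteLLM call kwarg. A hedged sketch of a Bedrock-style check that relies on kwargs only; the model name and credential values are placeholders:

from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients

# Credentials supplied as extra parameters instead of AWS_* environment variables.
call_kwargs = {
    "aws_access_key_id": "AKIA...",
    "aws_secret_access_key": "...",
    "aws_session_token": "...",
    "aws_region_name": "eu-central-1",
}

# Passes silently when every AWS variable reported missing by litellm has a
# matching kwarg; otherwise it logs the gap and raises ProviderClientValidationError.
validate_aws_setup_for_litellm_clients(
    litellm_model_name="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    litellm_call_kwargs=call_kwargs,
    source_log="default_litellm_llm_client",
)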
rasa/shared/providers/embedding/default_litellm_embedding_client.py
CHANGED
@@ -1,8 +1,13 @@
 from typing import Any, Dict
 
+from rasa.shared.constants import (
+    AWS_BEDROCK_PROVIDER,
+    AWS_SAGEMAKER_PROVIDER,
+)
 from rasa.shared.providers._configs.default_litellm_client_config import (
     DefaultLiteLLMClientConfig,
 )
+from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
 from rasa.shared.providers.embedding._base_litellm_embedding_client import (
     _BaseLiteLLMEmbeddingClient,
 )
@@ -100,3 +105,22 @@ class DefaultLiteLLMEmbeddingClient(_BaseLiteLLMEmbeddingClient):
             "model": self._litellm_model_name,
             **self._litellm_extra_parameters,
         }
+
+    def validate_client_setup(self) -> None:
+        # TODO: Temporarily disable environment variable validation for AWS setup
+        # (Bedrock and SageMaker) until resolved by either:
+        # 1. An update from the LiteLLM package addressing the issue.
+        # 2. The implementation of a Bedrock client on our end.
+        # ---
+        # This fix ensures a consistent user experience for Bedrock (and
+        # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+        # parameters without triggering validation errors due to missing AWS
+        # environment variables.
+        if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+            validate_aws_setup_for_litellm_clients(
+                self._litellm_model_name,
+                self._litellm_extra_parameters,
+                "default_litellm_embedding_client",
+            )
+        else:
+            super().validate_client_setup()
rasa/shared/providers/llm/default_litellm_llm_client.py
CHANGED
@@ -1,8 +1,13 @@
 from typing import Dict, Any
 
+from rasa.shared.constants import (
+    AWS_BEDROCK_PROVIDER,
+    AWS_SAGEMAKER_PROVIDER,
+)
 from rasa.shared.providers._configs.default_litellm_client_config import (
     DefaultLiteLLMClientConfig,
 )
+from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
 from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
 
 
@@ -82,3 +87,22 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
         to the client provider and deployed model.
         """
         return self._extra_parameters
+
+    def validate_client_setup(self) -> None:
+        # TODO: Temporarily change the environment variable validation for AWS setup
+        # (Bedrock and SageMaker) until resolved by either:
+        # 1. An update from the LiteLLM package addressing the issue.
+        # 2. The implementation of a Bedrock client on our end.
+        # ---
+        # This fix ensures a consistent user experience for Bedrock (and
+        # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
+        # parameters without triggering validation errors due to missing AWS
+        # environment variables.
+        if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
+            validate_aws_setup_for_litellm_clients(
+                self._litellm_model_name,
+                self._litellm_extra_parameters,
+                "default_litellm_llm_client",
+            )
+        else:
+            super().validate_client_setup()
rasa/shared/utils/common.py
CHANGED
@@ -3,14 +3,16 @@ import functools
 import importlib
 import inspect
 import logging
+import os
 import pkgutil
 import sys
 from types import ModuleType
-from typing import Text, Dict, Optional, Any, List, Callable, Collection, Type
+from typing import Sequence, Text, Dict, Optional, Any, List, Callable, Collection, Type
 
 import rasa.shared.utils.io
+from rasa.exceptions import MissingDependencyException
 from rasa.shared.constants import DOCS_URL_MIGRATION_GUIDE
-from rasa.shared.exceptions import RasaException
+from rasa.shared.exceptions import ProviderClientValidationError, RasaException
 
 logger = logging.getLogger(__name__)
 
@@ -295,3 +297,28 @@ def warn_and_exit_if_module_path_contains_rasa_plus(
         docs=DOCS_URL_MIGRATION_GUIDE,
     )
     sys.exit(1)
+
+
+def validate_environment(
+    required_env_vars: Sequence[str],
+    required_packages: Sequence[str],
+    component_name: str,
+) -> None:
+    """Make sure all needed requirements for a component are met.
+    Args:
+        required_env_vars: List of environment variables that should be set
+        required_packages: List of packages that should be installed
+        component_name: component name that needs the requirements
+    """
+    for e in required_env_vars:
+        if not os.environ.get(e):
+            raise ProviderClientValidationError(
+                f"Missing environment variable for {component_name}: {e}"
+            )
+    for p in required_packages:
+        try:
+            importlib.import_module(p)
+        except ImportError:
+            raise MissingDependencyException(
+                f"Missing package for {component_name}: {p}"
+            )
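The new validate_environment helper gives components one place to fail fast when credentials or optional dependencies are missing. A brief usage sketch; the specific variable, package, and component names below are illustrative:

from rasa.shared.utils.common import validate_environment

# Raises ProviderClientValidationError if DEEPGRAM_API_KEY is unset, or
# MissingDependencyException if the aiohttp package cannot be imported.
validate_environment(
    required_env_vars=["DEEPGRAM_API_KEY"],
    required_packages=["aiohttp"],
    component_name="DeepgramASR",
)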