rasa-pro 3.10.7__py3-none-any.whl → 3.10.7.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/api.py +8 -2
- rasa/cli/arguments/default_arguments.py +23 -2
- rasa/cli/arguments/run.py +2 -0
- rasa/cli/e2e_test.py +10 -8
- rasa/cli/inspect.py +5 -2
- rasa/cli/run.py +7 -0
- rasa/cli/studio/studio.py +1 -21
- rasa/cli/train.py +9 -4
- rasa/cli/utils.py +3 -3
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/voice_aware/utils.py +6 -5
- rasa/core/nlg/contextual_response_rephraser.py +11 -2
- rasa/{nlu → core}/persistor.py +1 -1
- rasa/core/policies/enterprise_search_policy.py +11 -2
- rasa/core/policies/intentless_policy.py +9 -2
- rasa/core/run.py +2 -1
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/utils.py +30 -19
- rasa/dialogue_understanding/coexistence/llm_based_router.py +9 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +11 -2
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
- rasa/e2e_test/e2e_test_runner.py +1 -1
- rasa/engine/graph.py +0 -1
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/storage/local_model_storage.py +0 -1
- rasa/engine/storage/storage.py +1 -5
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +12 -0
- rasa/model_manager/model_api.py +432 -0
- rasa/model_manager/runner_service.py +185 -0
- rasa/model_manager/socket_bridge.py +44 -0
- rasa/model_manager/trainer_service.py +240 -0
- rasa/model_manager/utils.py +27 -0
- rasa/model_service.py +44 -0
- rasa/model_training.py +11 -6
- rasa/server.py +1 -1
- rasa/shared/constants.py +2 -0
- rasa/shared/core/domain.py +101 -47
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/validation.py +25 -0
- rasa/shared/core/flows/yaml_flows_io.py +3 -24
- rasa/shared/importers/importer.py +32 -32
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +7 -2
- rasa/shared/importers/remote_importer.py +2 -2
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/utils/common.py +3 -22
- rasa/shared/utils/llm.py +28 -2
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/studio/auth.py +0 -4
- rasa/tracing/instrumentation/attribute_extractors.py +1 -1
- rasa/validator.py +2 -5
- rasa/version.py +1 -1
- {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/METADATA +4 -4
- {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/RECORD +67 -59
- {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/entry_points.txt +0 -0
rasa/api.py
CHANGED
|
@@ -2,7 +2,7 @@ import asyncio
|
|
|
2
2
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Union
|
|
3
3
|
|
|
4
4
|
import rasa.shared.constants
|
|
5
|
-
from rasa.
|
|
5
|
+
from rasa.core.persistor import StorageType
|
|
6
6
|
|
|
7
7
|
# WARNING: Be careful about adding any top level imports at this place!
|
|
8
8
|
# These functions are imported in `rasa.__init__` and any top level import
|
|
@@ -14,6 +14,7 @@ from rasa.nlu.persistor import StorageType
|
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
16
|
from rasa.model_training import TrainingResult
|
|
17
|
+
from rasa.shared.importers.importer import TrainingDataImporter
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
def run(
|
|
@@ -78,6 +79,7 @@ def train(
|
|
|
78
79
|
model_to_finetune: Optional[Text] = None,
|
|
79
80
|
finetuning_epoch_fraction: float = 1.0,
|
|
80
81
|
remote_storage: Optional[StorageType] = None,
|
|
82
|
+
file_importer: Optional["TrainingDataImporter"] = None,
|
|
81
83
|
) -> "TrainingResult":
|
|
82
84
|
"""Runs Rasa Core and NLU training in `async` loop.
|
|
83
85
|
|
|
@@ -99,7 +101,10 @@ def train(
|
|
|
99
101
|
a directory in case the latest trained model should be used.
|
|
100
102
|
finetuning_epoch_fraction: The fraction currently specified training epochs
|
|
101
103
|
in the model configuration which should be used for finetuning.
|
|
102
|
-
remote_storage:
|
|
104
|
+
remote_storage: Optional name of the remote storage to
|
|
105
|
+
use for storing the model.
|
|
106
|
+
file_importer: Instance of `TrainingDataImporter` to use for training.
|
|
107
|
+
If it is not provided, a new instance will be created.
|
|
103
108
|
|
|
104
109
|
Returns:
|
|
105
110
|
An instance of `TrainingResult`.
|
|
@@ -121,6 +126,7 @@ def train(
|
|
|
121
126
|
model_to_finetune=model_to_finetune,
|
|
122
127
|
finetuning_epoch_fraction=finetuning_epoch_fraction,
|
|
123
128
|
remote_storage=remote_storage,
|
|
129
|
+
file_importer=file_importer,
|
|
124
130
|
)
|
|
125
131
|
)
|
|
126
132
|
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import logging
|
|
3
|
-
from
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import List, Optional, Text, Union
|
|
4
5
|
|
|
5
|
-
from rasa.
|
|
6
|
+
from rasa.core.persistor import RemoteStorageType, StorageType, parse_remote_storage
|
|
6
7
|
from rasa.shared.constants import (
|
|
7
8
|
DEFAULT_CONFIG_PATH,
|
|
8
9
|
DEFAULT_DATA_PATH,
|
|
@@ -185,3 +186,23 @@ def parse_remote_storage_arg(value: str) -> StorageType:
|
|
|
185
186
|
return parse_remote_storage(value)
|
|
186
187
|
except ValueError as e:
|
|
187
188
|
raise argparse.ArgumentTypeError(str(e))
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class SkipYamlValidation(Enum):
|
|
192
|
+
DOMAIN = "domain"
|
|
193
|
+
|
|
194
|
+
@classmethod
|
|
195
|
+
def list(cls) -> List[str]:
|
|
196
|
+
return [e.value for e in SkipYamlValidation]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def add_skip_validation_flag(
|
|
200
|
+
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
|
|
201
|
+
) -> None:
|
|
202
|
+
parser.add_argument(
|
|
203
|
+
"--skip-yaml-validation",
|
|
204
|
+
default=[],
|
|
205
|
+
choices=SkipYamlValidation.list(),
|
|
206
|
+
action="append",
|
|
207
|
+
help="Skip YAML validation for selected parts of the training data.",
|
|
208
|
+
)
|
rasa/cli/arguments/run.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import Union
|
|
|
5
5
|
from rasa.cli.arguments.default_arguments import (
|
|
6
6
|
add_endpoint_param,
|
|
7
7
|
add_model_param,
|
|
8
|
+
add_skip_validation_flag,
|
|
8
9
|
add_remote_storage_param,
|
|
9
10
|
)
|
|
10
11
|
from rasa.core import constants
|
|
@@ -21,6 +22,7 @@ def set_run_arguments(parser: argparse.ArgumentParser) -> None:
|
|
|
21
22
|
"""Arguments for running Rasa directly using `rasa run`."""
|
|
22
23
|
add_model_param(parser)
|
|
23
24
|
add_server_arguments(parser)
|
|
25
|
+
add_skip_validation_flag(parser)
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
def set_run_action_arguments(parser: argparse.ArgumentParser) -> None:
|
rasa/cli/e2e_test.py
CHANGED
|
@@ -221,15 +221,17 @@ def execute_e2e_tests(args: argparse.Namespace) -> None:
|
|
|
221
221
|
if args.e2e_results is not None:
|
|
222
222
|
results_path = Path(args.e2e_results)
|
|
223
223
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
224
|
+
if passed:
|
|
225
|
+
passed_file = rasa.cli.utils.get_e2e_results_file_name(
|
|
226
|
+
results_path, STATUS_PASSED
|
|
227
|
+
)
|
|
228
|
+
write_test_results_to_file(passed, passed_file)
|
|
228
229
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
230
|
+
if failed:
|
|
231
|
+
failed_file = rasa.cli.utils.get_e2e_results_file_name(
|
|
232
|
+
results_path, STATUS_FAILED
|
|
233
|
+
)
|
|
234
|
+
write_test_results_to_file(failed, failed_file)
|
|
233
235
|
|
|
234
236
|
aggregate_stats_calculator = AggregateTestStatsCalculator(
|
|
235
237
|
passed_results=passed, failed_results=failed, test_cases=test_suite.test_cases
|
rasa/cli/inspect.py
CHANGED
|
@@ -3,11 +3,12 @@ import webbrowser
|
|
|
3
3
|
from asyncio import AbstractEventLoop
|
|
4
4
|
from typing import List, Text
|
|
5
5
|
|
|
6
|
+
from sanic import Sanic
|
|
7
|
+
|
|
6
8
|
from rasa.cli import SubParsersAction
|
|
7
9
|
from rasa.cli.arguments import shell as arguments
|
|
10
|
+
from rasa.cli.arguments.default_arguments import add_skip_validation_flag
|
|
8
11
|
from rasa.core import constants
|
|
9
|
-
from sanic import Sanic
|
|
10
|
-
|
|
11
12
|
from rasa.utils.cli import remove_argument_from_parser
|
|
12
13
|
|
|
13
14
|
|
|
@@ -33,6 +34,8 @@ def add_subparser(
|
|
|
33
34
|
inspect_parser.set_defaults(func=inspect)
|
|
34
35
|
|
|
35
36
|
arguments.set_shell_arguments(inspect_parser)
|
|
37
|
+
add_skip_validation_flag(inspect_parser)
|
|
38
|
+
|
|
36
39
|
# it'd be confusing to expose those arguments to the user,
|
|
37
40
|
# so we remove them
|
|
38
41
|
remove_argument_from_parser(inspect_parser, "--credentials")
|
rasa/cli/run.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import List, Text
|
|
|
6
6
|
from rasa.api import run as rasa_run
|
|
7
7
|
from rasa.cli import SubParsersAction
|
|
8
8
|
from rasa.cli.arguments import run as arguments
|
|
9
|
+
from rasa.cli.arguments.default_arguments import SkipYamlValidation
|
|
9
10
|
from rasa.cli.utils import get_validated_path
|
|
10
11
|
from rasa.exceptions import ModelNotFound
|
|
11
12
|
from rasa.shared.constants import (
|
|
@@ -15,6 +16,7 @@ from rasa.shared.constants import (
|
|
|
15
16
|
DEFAULT_MODELS_PATH,
|
|
16
17
|
DOCS_BASE_URL,
|
|
17
18
|
)
|
|
19
|
+
from rasa.shared.core.domain import Domain
|
|
18
20
|
from rasa.shared.utils.cli import print_error
|
|
19
21
|
|
|
20
22
|
logger = logging.getLogger(__name__)
|
|
@@ -87,6 +89,11 @@ def run(args: argparse.Namespace) -> None:
|
|
|
87
89
|
args.credentials, "credentials", DEFAULT_CREDENTIALS_PATH, True
|
|
88
90
|
)
|
|
89
91
|
|
|
92
|
+
if SkipYamlValidation.DOMAIN.value in args.skip_yaml_validation:
|
|
93
|
+
Domain.validate_yaml = False
|
|
94
|
+
else:
|
|
95
|
+
Domain.validate_yaml = True
|
|
96
|
+
|
|
90
97
|
if args.enable_api:
|
|
91
98
|
if not args.remote_storage:
|
|
92
99
|
args.model = _validate_model_path(args.model, "model", DEFAULT_MODELS_PATH)
|
rasa/cli/studio/studio.py
CHANGED
|
@@ -5,7 +5,6 @@ from urllib.parse import ParseResult, urlparse
|
|
|
5
5
|
import questionary
|
|
6
6
|
from rasa.cli import SubParsersAction
|
|
7
7
|
|
|
8
|
-
import rasa.shared.utils.cli
|
|
9
8
|
import rasa.cli.studio.download
|
|
10
9
|
import rasa.cli.studio.train
|
|
11
10
|
import rasa.cli.studio.upload
|
|
@@ -51,15 +50,6 @@ def _add_config_subparser(
|
|
|
51
50
|
|
|
52
51
|
studio_config_parser.set_defaults(func=create_and_store_studio_config)
|
|
53
52
|
|
|
54
|
-
studio_config_parser.add_argument(
|
|
55
|
-
"--disable-verify",
|
|
56
|
-
"-x",
|
|
57
|
-
action="store_true",
|
|
58
|
-
default=False,
|
|
59
|
-
help="Disable strict SSL verification for the "
|
|
60
|
-
"Rasa Studio authentication server.",
|
|
61
|
-
)
|
|
62
|
-
|
|
63
53
|
# add advanced configuration flag to trigger
|
|
64
54
|
# advanced configuration setup for authentication settings
|
|
65
55
|
studio_config_parser.add_argument(
|
|
@@ -229,17 +219,7 @@ def _configure_studio_config(args: argparse.Namespace) -> StudioConfig:
|
|
|
229
219
|
studio_config = _create_studio_config(
|
|
230
220
|
studio_url, keycloak_url, realm_name, client_id
|
|
231
221
|
)
|
|
232
|
-
|
|
233
|
-
if args.disable_verify:
|
|
234
|
-
rasa.shared.utils.cli.print_info(
|
|
235
|
-
"Disabling SSL verification for the Rasa Studio authentication server."
|
|
236
|
-
)
|
|
237
|
-
studio_auth = StudioAuth(studio_config, verify=False)
|
|
238
|
-
else:
|
|
239
|
-
rasa.shared.utils.cli.print_info(
|
|
240
|
-
"Enabling SSL verification for the Rasa Studio authentication server."
|
|
241
|
-
)
|
|
242
|
-
studio_auth = StudioAuth(studio_config, verify=True)
|
|
222
|
+
studio_auth = StudioAuth(studio_config)
|
|
243
223
|
|
|
244
224
|
if _check_studio_auth(studio_auth):
|
|
245
225
|
return studio_config
|
rasa/cli/train.py
CHANGED
|
@@ -110,16 +110,20 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
|
|
|
110
110
|
for f in args.data
|
|
111
111
|
]
|
|
112
112
|
|
|
113
|
+
training_data_importer = TrainingDataImporter.load_from_config(
|
|
114
|
+
domain_path=domain, training_data_paths=args.data, config_path=config
|
|
115
|
+
)
|
|
116
|
+
|
|
113
117
|
if not args.skip_validation:
|
|
114
118
|
structlogger.info(
|
|
115
119
|
"cli.train.run_training",
|
|
116
120
|
event_info="Started validating domain and training data...",
|
|
117
121
|
)
|
|
118
|
-
|
|
119
|
-
domain_path=domain, training_data_paths=args.data, config_path=config
|
|
120
|
-
)
|
|
122
|
+
|
|
121
123
|
rasa.cli.utils.validate_files(
|
|
122
|
-
args.fail_on_validation_warnings,
|
|
124
|
+
args.fail_on_validation_warnings,
|
|
125
|
+
args.validation_max_history,
|
|
126
|
+
training_data_importer,
|
|
123
127
|
)
|
|
124
128
|
|
|
125
129
|
training_result = train_all(
|
|
@@ -138,6 +142,7 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
|
|
|
138
142
|
model_to_finetune=_model_for_finetuning(args),
|
|
139
143
|
finetuning_epoch_fraction=args.epoch_fraction,
|
|
140
144
|
remote_storage=args.remote_storage,
|
|
145
|
+
file_importer=training_data_importer,
|
|
141
146
|
)
|
|
142
147
|
if training_result.code != 0 and can_exit:
|
|
143
148
|
sys.exit(training_result.code)
|
rasa/cli/utils.py
CHANGED
|
@@ -470,10 +470,10 @@ def get_e2e_results_file_name(
|
|
|
470
470
|
) -> str:
|
|
471
471
|
"""Returns the name of the e2e results file."""
|
|
472
472
|
if results_output_path.is_dir():
|
|
473
|
-
file_name =
|
|
473
|
+
file_name = results_output_path / f"e2e_results_{result_type}.yml"
|
|
474
474
|
else:
|
|
475
475
|
parent = results_output_path.parent
|
|
476
476
|
stem = results_output_path.stem
|
|
477
|
-
file_name =
|
|
477
|
+
file_name = parent / f"{stem}_{result_type}.yml"
|
|
478
478
|
|
|
479
|
-
return file_name
|
|
479
|
+
return str(file_name)
|
rasa/core/agent.py
CHANGED
|
@@ -19,6 +19,7 @@ from rasa.core.exceptions import AgentNotReady
|
|
|
19
19
|
from rasa.core.http_interpreter import RasaNLUHttpInterpreter
|
|
20
20
|
from rasa.core.lock_store import InMemoryLockStore, LockStore
|
|
21
21
|
from rasa.core.nlg import NaturalLanguageGenerator, TemplatedNaturalLanguageGenerator
|
|
22
|
+
from rasa.core.persistor import StorageType
|
|
22
23
|
from rasa.core.policies.policy import PolicyPrediction
|
|
23
24
|
from rasa.core.processor import MessageProcessor
|
|
24
25
|
from rasa.core.tracker_store import (
|
|
@@ -28,7 +29,6 @@ from rasa.core.tracker_store import (
|
|
|
28
29
|
)
|
|
29
30
|
from rasa.core.utils import AvailableEndpoints
|
|
30
31
|
from rasa.exceptions import ModelNotFound
|
|
31
|
-
from rasa.nlu.persistor import StorageType
|
|
32
32
|
from rasa.nlu.utils import is_url
|
|
33
33
|
from rasa.shared.constants import DEFAULT_SENDER_ID
|
|
34
34
|
from rasa.shared.core.domain import Domain
|
|
@@ -544,7 +544,7 @@ class Agent:
|
|
|
544
544
|
|
|
545
545
|
def load_model_from_remote_storage(self, model_name: Text) -> None:
|
|
546
546
|
"""Loads an Agent from remote storage."""
|
|
547
|
-
from rasa.
|
|
547
|
+
from rasa.core.persistor import get_persistor
|
|
548
548
|
|
|
549
549
|
persistor = get_persistor(self.remote_storage)
|
|
550
550
|
|
rasa/core/brokers/kafka.py
CHANGED
|
@@ -2,6 +2,8 @@ import asyncio
|
|
|
2
2
|
import os
|
|
3
3
|
import json
|
|
4
4
|
import logging
|
|
5
|
+
from functools import cached_property
|
|
6
|
+
|
|
5
7
|
import structlog
|
|
6
8
|
import threading
|
|
7
9
|
from asyncio import AbstractEventLoop
|
|
@@ -270,7 +272,7 @@ class KafkaEventBroker(EventBroker):
|
|
|
270
272
|
if self.producer:
|
|
271
273
|
self.producer.flush()
|
|
272
274
|
|
|
273
|
-
@
|
|
275
|
+
@cached_property
|
|
274
276
|
def rasa_environment(self) -> Optional[Text]:
|
|
275
277
|
"""Get value of the `RASA_ENVIRONMENT` environment variable."""
|
|
276
278
|
return os.environ.get("RASA_ENVIRONMENT", "RASA_ENVIRONMENT_NOT_SET")
|
rasa/core/brokers/pika.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import json
|
|
3
3
|
import logging
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
|
|
4
6
|
import structlog
|
|
5
7
|
import os
|
|
6
8
|
import ssl
|
|
@@ -333,7 +335,7 @@ class PikaEventBroker(EventBroker):
|
|
|
333
335
|
delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
|
|
334
336
|
)
|
|
335
337
|
|
|
336
|
-
@
|
|
338
|
+
@cached_property
|
|
337
339
|
def rasa_environment(self) -> Optional[Text]:
|
|
338
340
|
"""Get value of the `RASA_ENVIRONMENT` environment variable."""
|
|
339
341
|
return os.environ.get("RASA_ENVIRONMENT")
|
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import structlog
|
|
2
2
|
|
|
3
|
-
from rasa.utils.licensing import (
|
|
4
|
-
PRODUCT_AREA,
|
|
5
|
-
VOICE_SCOPE,
|
|
6
|
-
validate_license_from_env,
|
|
7
|
-
)
|
|
8
3
|
|
|
9
4
|
structlogger = structlog.get_logger()
|
|
10
5
|
|
|
11
6
|
|
|
12
7
|
def validate_voice_license_scope() -> None:
|
|
8
|
+
from rasa.utils.licensing import (
|
|
9
|
+
PRODUCT_AREA,
|
|
10
|
+
VOICE_SCOPE,
|
|
11
|
+
validate_license_from_env,
|
|
12
|
+
)
|
|
13
|
+
|
|
13
14
|
"""Validate that the correct license scope is present."""
|
|
14
15
|
structlogger.info(
|
|
15
16
|
f"Validating current Rasa Pro license scope which must include "
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from typing import Any, Dict, Optional, Text
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
import structlog
|
|
4
5
|
from jinja2 import Template
|
|
5
6
|
|
|
6
7
|
from rasa import telemetry
|
|
7
8
|
from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
|
|
8
9
|
from rasa.shared.constants import (
|
|
10
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
9
11
|
LLM_CONFIG_KEY,
|
|
10
12
|
MODEL_CONFIG_KEY,
|
|
11
13
|
MODEL_NAME_CONFIG_KEY,
|
|
@@ -23,6 +25,7 @@ from rasa.shared.utils.llm import (
|
|
|
23
25
|
USER,
|
|
24
26
|
combine_custom_and_default_config,
|
|
25
27
|
get_prompt_template,
|
|
28
|
+
llm_api_health_check,
|
|
26
29
|
llm_factory,
|
|
27
30
|
try_instantiate_llm_client,
|
|
28
31
|
)
|
|
@@ -97,12 +100,18 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
97
100
|
self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
|
|
98
101
|
"trace_prompt_tokens", False
|
|
99
102
|
)
|
|
100
|
-
try_instantiate_llm_client(
|
|
103
|
+
llm_client = try_instantiate_llm_client(
|
|
101
104
|
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
|
|
102
105
|
DEFAULT_LLM_CONFIG,
|
|
103
106
|
"contextual_response_rephraser.init",
|
|
104
|
-
|
|
107
|
+
ContextualResponseRephraser.__name__,
|
|
105
108
|
)
|
|
109
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
110
|
+
llm_api_health_check(
|
|
111
|
+
llm_client,
|
|
112
|
+
"contextual_response_rephraser.init",
|
|
113
|
+
ContextualResponseRephraser.__name__,
|
|
114
|
+
)
|
|
106
115
|
|
|
107
116
|
def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
|
|
108
117
|
"""Returns the latest message from the tracker.
|
rasa/{nlu → core}/persistor.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import json
|
|
3
|
+
import os
|
|
3
4
|
import re
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
6
|
|
|
@@ -47,6 +48,7 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
47
48
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
48
49
|
from rasa.shared.constants import (
|
|
49
50
|
EMBEDDINGS_CONFIG_KEY,
|
|
51
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
50
52
|
LLM_CONFIG_KEY,
|
|
51
53
|
MODEL_CONFIG_KEY,
|
|
52
54
|
MODEL_NAME_CONFIG_KEY,
|
|
@@ -72,6 +74,7 @@ from rasa.shared.utils.llm import (
|
|
|
72
74
|
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
73
75
|
embedder_factory,
|
|
74
76
|
get_prompt_template,
|
|
77
|
+
llm_api_health_check,
|
|
75
78
|
llm_factory,
|
|
76
79
|
sanitize_message_for_prompt,
|
|
77
80
|
tracker_as_readable_transcript,
|
|
@@ -292,12 +295,18 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
292
295
|
)
|
|
293
296
|
|
|
294
297
|
# validate llm configuration
|
|
295
|
-
try_instantiate_llm_client(
|
|
298
|
+
llm_client = try_instantiate_llm_client(
|
|
296
299
|
self.config.get(LLM_CONFIG_KEY),
|
|
297
300
|
DEFAULT_LLM_CONFIG,
|
|
298
301
|
"enterprise_search_policy.train",
|
|
299
|
-
|
|
302
|
+
EnterpriseSearchPolicy.__name__,
|
|
300
303
|
)
|
|
304
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
305
|
+
llm_api_health_check(
|
|
306
|
+
llm_client,
|
|
307
|
+
"enterprise_search_policy.train",
|
|
308
|
+
EnterpriseSearchPolicy.__name__,
|
|
309
|
+
)
|
|
301
310
|
|
|
302
311
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
303
312
|
logger.info("enterprise_search_policy.train.faiss")
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import math
|
|
3
|
+
import os
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
|
|
5
6
|
|
|
@@ -31,6 +32,7 @@ from rasa.graph_components.providers.responses_provider import Responses
|
|
|
31
32
|
from rasa.shared.constants import (
|
|
32
33
|
REQUIRED_SLOTS_KEY,
|
|
33
34
|
EMBEDDINGS_CONFIG_KEY,
|
|
35
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
34
36
|
LLM_CONFIG_KEY,
|
|
35
37
|
MODEL_CONFIG_KEY,
|
|
36
38
|
MODEL_NAME_CONFIG_KEY,
|
|
@@ -67,6 +69,7 @@ from rasa.shared.utils.llm import (
|
|
|
67
69
|
combine_custom_and_default_config,
|
|
68
70
|
embedder_factory,
|
|
69
71
|
get_prompt_template,
|
|
72
|
+
llm_api_health_check,
|
|
70
73
|
llm_factory,
|
|
71
74
|
sanitize_message_for_prompt,
|
|
72
75
|
tracker_as_readable_transcript,
|
|
@@ -487,12 +490,16 @@ class IntentlessPolicy(Policy):
|
|
|
487
490
|
A policy must return its resource locator so that potential children nodes
|
|
488
491
|
can load the policy from the resource.
|
|
489
492
|
"""
|
|
490
|
-
try_instantiate_llm_client(
|
|
493
|
+
llm_client = try_instantiate_llm_client(
|
|
491
494
|
self.config.get(LLM_CONFIG_KEY),
|
|
492
495
|
DEFAULT_LLM_CONFIG,
|
|
493
496
|
"intentless_policy.train",
|
|
494
|
-
|
|
497
|
+
IntentlessPolicy.__name__,
|
|
495
498
|
)
|
|
499
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
500
|
+
llm_api_health_check(
|
|
501
|
+
llm_client, "intentless_policy.train", IntentlessPolicy.__name__
|
|
502
|
+
)
|
|
496
503
|
|
|
497
504
|
responses = filter_responses(responses, forms, flows or FlowsList([]))
|
|
498
505
|
telemetry.track_intentless_policy_train()
|
rasa/core/run.py
CHANGED
|
@@ -32,8 +32,8 @@ from rasa.core import agent, channels, constants
|
|
|
32
32
|
from rasa.core.agent import Agent
|
|
33
33
|
from rasa.core.channels import console
|
|
34
34
|
from rasa.core.channels.channel import InputChannel
|
|
35
|
+
from rasa.core.persistor import StorageType
|
|
35
36
|
from rasa.core.utils import AvailableEndpoints
|
|
36
|
-
from rasa.nlu.persistor import StorageType
|
|
37
37
|
from rasa.plugin import plugin_manager
|
|
38
38
|
from rasa.shared.exceptions import RasaException
|
|
39
39
|
from rasa.shared.utils.yaml import read_config_file
|
|
@@ -311,6 +311,7 @@ async def load_agent_on_start(
|
|
|
311
311
|
endpoints=endpoints,
|
|
312
312
|
loop=loop,
|
|
313
313
|
)
|
|
314
|
+
|
|
314
315
|
logger.info("Rasa server is up and running.")
|
|
315
316
|
return app.ctx.agent
|
|
316
317
|
|
|
@@ -23,6 +23,7 @@ VAULT_TRANSIT_MOUNT_POINT_ENV_NAME = "VAULT_TRANSIT_MOUNT_POINT"
|
|
|
23
23
|
VAULT_NAMESPACE_ENV_NAME = "VAULT_NAMESPACE"
|
|
24
24
|
VAULT_DEFAULT_RASA_SECRETS_PATH = "rasa-secrets"
|
|
25
25
|
VAULT_SECRET_MANAGER_NAME = "vault"
|
|
26
|
+
VAULT_MOUNT_POINT_ENV_NAME = "VAULT_MOUNT_POINT"
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
VAULT_ENDPOINT_URL_LABEL = "url"
|
|
@@ -30,3 +31,6 @@ VAULT_ENDPOINT_TOKEN_LABEL = "token"
|
|
|
30
31
|
VAULT_ENDPOINT_SECRETS_PATH_LABEL = "secrets_path"
|
|
31
32
|
VAULT_ENDPOINT_TRANSIT_MOUNT_POINT_LABEL = "transit_mount_point"
|
|
32
33
|
VAULT_ENDPOINT_NAMESPACE_LABEL = "namespace"
|
|
34
|
+
VAULT_ENDPOINT_MOUNT_POINT_LABEL = "mount_point"
|
|
35
|
+
|
|
36
|
+
VAULT_MOUNT_POINT_DEFAULT_VALUE = "secret"
|
|
@@ -7,9 +7,11 @@ from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
|
|
|
7
7
|
from rasa.core.secrets_manager.constants import (
|
|
8
8
|
SECRET_MANAGER_ENV_NAME,
|
|
9
9
|
VAULT_DEFAULT_RASA_SECRETS_PATH,
|
|
10
|
+
VAULT_ENDPOINT_MOUNT_POINT_LABEL,
|
|
10
11
|
VAULT_ENDPOINT_NAMESPACE_LABEL,
|
|
11
12
|
VAULT_ENDPOINT_SECRETS_PATH_LABEL,
|
|
12
13
|
VAULT_ENDPOINT_TRANSIT_MOUNT_POINT_LABEL,
|
|
14
|
+
VAULT_MOUNT_POINT_ENV_NAME,
|
|
13
15
|
VAULT_NAMESPACE_ENV_NAME,
|
|
14
16
|
VAULT_RASA_SECRETS_PATH_ENV_NAME,
|
|
15
17
|
VAULT_SECRET_MANAGER_NAME,
|
|
@@ -48,6 +50,7 @@ def create(config: SecretManagerConfig) -> Optional[SecretsManager]:
|
|
|
48
50
|
transit_mount_point=vault_config.transit_mount_point,
|
|
49
51
|
secrets_path=vault_config.secrets_path,
|
|
50
52
|
namespace=vault_config.namespace,
|
|
53
|
+
mount_point=vault_config.mount_point,
|
|
51
54
|
)
|
|
52
55
|
|
|
53
56
|
return secret_manager
|
|
@@ -79,6 +82,7 @@ def read_vault_endpoint_config(
|
|
|
79
82
|
)
|
|
80
83
|
secrets_path = endpoint_config.kwargs.get(VAULT_ENDPOINT_SECRETS_PATH_LABEL)
|
|
81
84
|
namespace = endpoint_config.kwargs.get(VAULT_ENDPOINT_NAMESPACE_LABEL)
|
|
85
|
+
mount_point = endpoint_config.kwargs.get(VAULT_ENDPOINT_MOUNT_POINT_LABEL)
|
|
82
86
|
|
|
83
87
|
return VaultSecretManagerNonStrictConfig(
|
|
84
88
|
url=url,
|
|
@@ -86,6 +90,7 @@ def read_vault_endpoint_config(
|
|
|
86
90
|
transit_mount_point=transit_mount_point,
|
|
87
91
|
secrets_path=secrets_path or VAULT_DEFAULT_RASA_SECRETS_PATH,
|
|
88
92
|
namespace=namespace,
|
|
93
|
+
mount_point=mount_point,
|
|
89
94
|
)
|
|
90
95
|
|
|
91
96
|
return None
|
|
@@ -102,6 +107,7 @@ def read_vault_env_vars() -> VaultSecretManagerNonStrictConfig:
|
|
|
102
107
|
transit_mount_point = os.getenv(VAULT_TRANSIT_MOUNT_POINT_ENV_NAME)
|
|
103
108
|
secrets_path = os.getenv(VAULT_RASA_SECRETS_PATH_ENV_NAME)
|
|
104
109
|
namespace = os.getenv(VAULT_NAMESPACE_ENV_NAME)
|
|
110
|
+
mount_point = os.getenv(VAULT_MOUNT_POINT_ENV_NAME)
|
|
105
111
|
|
|
106
112
|
return VaultSecretManagerNonStrictConfig(
|
|
107
113
|
url=url,
|
|
@@ -109,6 +115,7 @@ def read_vault_env_vars() -> VaultSecretManagerNonStrictConfig:
|
|
|
109
115
|
transit_mount_point=transit_mount_point,
|
|
110
116
|
secrets_path=secrets_path,
|
|
111
117
|
namespace=namespace,
|
|
118
|
+
mount_point=mount_point,
|
|
112
119
|
)
|
|
113
120
|
|
|
114
121
|
|
|
@@ -149,6 +156,7 @@ def read_vault_config(
|
|
|
149
156
|
f"{VAULT_RASA_SECRETS_PATH_ENV_NAME} = {env_config.secrets_path}, "
|
|
150
157
|
f"{VAULT_TRANSIT_MOUNT_POINT_ENV_NAME} = {env_config.transit_mount_point}. "
|
|
151
158
|
f"{VAULT_NAMESPACE_ENV_NAME} = {env_config.namespace}. "
|
|
159
|
+
f"{VAULT_MOUNT_POINT_ENV_NAME} = {env_config.mount_point}. "
|
|
152
160
|
)
|
|
153
161
|
|
|
154
162
|
|
|
@@ -15,6 +15,7 @@ from rasa.utils.endpoints import EndpointConfig
|
|
|
15
15
|
from rasa.core.secrets_manager.constants import (
|
|
16
16
|
TRACKER_STORE_ENDPOINT_TYPE,
|
|
17
17
|
TRANSIT_KEY_FOR_ENCRYPTION_LABEL,
|
|
18
|
+
VAULT_MOUNT_POINT_DEFAULT_VALUE,
|
|
18
19
|
VAULT_SECRET_MANAGER_NAME,
|
|
19
20
|
)
|
|
20
21
|
from rasa.core.secrets_manager.endpoints import (
|
|
@@ -181,6 +182,7 @@ class VaultSecretsManager(SecretsManager):
|
|
|
181
182
|
secrets_path: Text,
|
|
182
183
|
transit_mount_point: Optional[Text] = None,
|
|
183
184
|
namespace: Optional[Text] = None,
|
|
185
|
+
mount_point: Optional[Text] = None,
|
|
184
186
|
):
|
|
185
187
|
"""Initialise the VaultSecretsManager.
|
|
186
188
|
|
|
@@ -190,11 +192,13 @@ class VaultSecretsManager(SecretsManager):
|
|
|
190
192
|
secrets_path: The path to the secrets in the vault server.
|
|
191
193
|
transit_mount_point: The mount point of the transit engine.
|
|
192
194
|
namespace: The namespace in which secrets reside in.
|
|
195
|
+
mount_point: The mount point of the kv engine.
|
|
193
196
|
"""
|
|
194
197
|
self.host = host
|
|
195
198
|
self.transit_mount_point = transit_mount_point
|
|
196
199
|
self.token = token
|
|
197
200
|
self.secrets_path = secrets_path
|
|
201
|
+
self.mount_point = mount_point or VAULT_MOUNT_POINT_DEFAULT_VALUE
|
|
198
202
|
self.namespace = namespace
|
|
199
203
|
|
|
200
204
|
# Create client
|
|
@@ -236,7 +240,7 @@ class VaultSecretsManager(SecretsManager):
|
|
|
236
240
|
"""
|
|
237
241
|
logger.info(f"Loading secrets from vault server at {self.host}.")
|
|
238
242
|
read_response = self.client.secrets.kv.read_secret_version(
|
|
239
|
-
mount_point=
|
|
243
|
+
mount_point=self.mount_point, path=self.secrets_path
|
|
240
244
|
)
|
|
241
245
|
|
|
242
246
|
secrets = read_response["data"]["data"]
|
|
@@ -455,6 +459,7 @@ class VaultSecretManagerConfig(SecretManagerConfig):
|
|
|
455
459
|
secrets_path: Text,
|
|
456
460
|
transit_mount_point: Text = "transit",
|
|
457
461
|
namespace: Optional[Text] = None,
|
|
462
|
+
mount_point: Optional[Text] = None,
|
|
458
463
|
) -> None:
|
|
459
464
|
"""Initialise the VaultSecretManagerConfig.
|
|
460
465
|
|
|
@@ -471,6 +476,7 @@ class VaultSecretManagerConfig(SecretManagerConfig):
|
|
|
471
476
|
self.secrets_path = secrets_path
|
|
472
477
|
self.transit_mount_point = transit_mount_point
|
|
473
478
|
self.namespace = namespace
|
|
479
|
+
self.mount_point = mount_point
|
|
474
480
|
|
|
475
481
|
|
|
476
482
|
@dataclass
|
|
@@ -486,6 +492,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
486
492
|
secrets_path: Optional[Text]
|
|
487
493
|
transit_mount_point: Optional[Text]
|
|
488
494
|
namespace: Optional[Text] = None
|
|
495
|
+
mount_point: Optional[Text] = None
|
|
489
496
|
|
|
490
497
|
def is_empty(self) -> bool:
|
|
491
498
|
"""Check if all the values are empty."""
|
|
@@ -495,6 +502,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
495
502
|
and (self.secrets_path is None or self.secrets_path == "")
|
|
496
503
|
and (self.transit_mount_point is None or self.transit_mount_point == "")
|
|
497
504
|
and (self.namespace is None or self.namespace == "")
|
|
505
|
+
and (self.mount_point is None or self.mount_point == "")
|
|
498
506
|
)
|
|
499
507
|
|
|
500
508
|
def is_valid(self) -> bool:
|
|
@@ -516,6 +524,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
516
524
|
and self.secrets_path != ""
|
|
517
525
|
and self._is_optional_value_valid(self.transit_mount_point)
|
|
518
526
|
and self._is_optional_value_valid(self.namespace)
|
|
527
|
+
and self._is_optional_value_valid(self.mount_point)
|
|
519
528
|
)
|
|
520
529
|
|
|
521
530
|
@staticmethod
|
|
@@ -547,6 +556,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
547
556
|
secrets_path=self.secrets_path or other.secrets_path,
|
|
548
557
|
transit_mount_point=self.transit_mount_point or other.transit_mount_point,
|
|
549
558
|
namespace=self.namespace or other.namespace,
|
|
559
|
+
mount_point=self.mount_point or other.mount_point,
|
|
550
560
|
)
|
|
551
561
|
|
|
552
562
|
|