rasa-pro 3.10.7__py3-none-any.whl → 3.10.7.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic.

Files changed (67):
  1. rasa/api.py +8 -2
  2. rasa/cli/arguments/default_arguments.py +23 -2
  3. rasa/cli/arguments/run.py +2 -0
  4. rasa/cli/e2e_test.py +10 -8
  5. rasa/cli/inspect.py +5 -2
  6. rasa/cli/run.py +7 -0
  7. rasa/cli/studio/studio.py +1 -21
  8. rasa/cli/train.py +9 -4
  9. rasa/cli/utils.py +3 -3
  10. rasa/core/agent.py +2 -2
  11. rasa/core/brokers/kafka.py +3 -1
  12. rasa/core/brokers/pika.py +3 -1
  13. rasa/core/channels/voice_aware/utils.py +6 -5
  14. rasa/core/nlg/contextual_response_rephraser.py +11 -2
  15. rasa/{nlu → core}/persistor.py +1 -1
  16. rasa/core/policies/enterprise_search_policy.py +11 -2
  17. rasa/core/policies/intentless_policy.py +9 -2
  18. rasa/core/run.py +2 -1
  19. rasa/core/secrets_manager/constants.py +4 -0
  20. rasa/core/secrets_manager/factory.py +8 -0
  21. rasa/core/secrets_manager/vault.py +11 -1
  22. rasa/core/utils.py +30 -19
  23. rasa/dialogue_understanding/coexistence/llm_based_router.py +9 -2
  24. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  25. rasa/dialogue_understanding/generator/llm_based_command_generator.py +11 -2
  26. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
  27. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  28. rasa/e2e_test/e2e_test_runner.py +1 -1
  29. rasa/engine/graph.py +0 -1
  30. rasa/engine/recipes/config_files/default_config.yml +0 -3
  31. rasa/engine/recipes/default_recipe.py +0 -1
  32. rasa/engine/recipes/graph_recipe.py +0 -1
  33. rasa/engine/storage/local_model_storage.py +0 -1
  34. rasa/engine/storage/storage.py +1 -5
  35. rasa/model_manager/__init__.py +0 -0
  36. rasa/model_manager/config.py +12 -0
  37. rasa/model_manager/model_api.py +432 -0
  38. rasa/model_manager/runner_service.py +185 -0
  39. rasa/model_manager/socket_bridge.py +44 -0
  40. rasa/model_manager/trainer_service.py +240 -0
  41. rasa/model_manager/utils.py +27 -0
  42. rasa/model_service.py +44 -0
  43. rasa/model_training.py +11 -6
  44. rasa/server.py +1 -1
  45. rasa/shared/constants.py +2 -0
  46. rasa/shared/core/domain.py +101 -47
  47. rasa/shared/core/flows/flows_list.py +19 -6
  48. rasa/shared/core/flows/validation.py +25 -0
  49. rasa/shared/core/flows/yaml_flows_io.py +3 -24
  50. rasa/shared/importers/importer.py +32 -32
  51. rasa/shared/importers/multi_project.py +23 -11
  52. rasa/shared/importers/rasa.py +7 -2
  53. rasa/shared/importers/remote_importer.py +2 -2
  54. rasa/shared/importers/utils.py +3 -1
  55. rasa/shared/nlu/training_data/training_data.py +18 -19
  56. rasa/shared/utils/common.py +3 -22
  57. rasa/shared/utils/llm.py +28 -2
  58. rasa/shared/utils/schemas/model_config.yml +0 -10
  59. rasa/studio/auth.py +0 -4
  60. rasa/tracing/instrumentation/attribute_extractors.py +1 -1
  61. rasa/validator.py +2 -5
  62. rasa/version.py +1 -1
  63. {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/METADATA +4 -4
  64. {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/RECORD +67 -59
  65. {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/NOTICE +0 -0
  66. {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/WHEEL +0 -0
  67. {rasa_pro-3.10.7.dist-info → rasa_pro-3.10.7.dev2.dist-info}/entry_points.txt +0 -0
rasa/api.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
2
2
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Union
3
3
 
4
4
  import rasa.shared.constants
5
- from rasa.nlu.persistor import StorageType
5
+ from rasa.core.persistor import StorageType
6
6
 
7
7
  # WARNING: Be careful about adding any top level imports at this place!
8
8
  # These functions are imported in `rasa.__init__` and any top level import
@@ -14,6 +14,7 @@ from rasa.nlu.persistor import StorageType
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  from rasa.model_training import TrainingResult
17
+ from rasa.shared.importers.importer import TrainingDataImporter
17
18
 
18
19
 
19
20
  def run(
@@ -78,6 +79,7 @@ def train(
78
79
  model_to_finetune: Optional[Text] = None,
79
80
  finetuning_epoch_fraction: float = 1.0,
80
81
  remote_storage: Optional[StorageType] = None,
82
+ file_importer: Optional["TrainingDataImporter"] = None,
81
83
  ) -> "TrainingResult":
82
84
  """Runs Rasa Core and NLU training in `async` loop.
83
85
 
@@ -99,7 +101,10 @@ def train(
99
101
  a directory in case the latest trained model should be used.
100
102
  finetuning_epoch_fraction: The fraction currently specified training epochs
101
103
  in the model configuration which should be used for finetuning.
102
- remote_storage: Remote storage to use for model storage.
104
+ remote_storage: Optional name of the remote storage to
105
+ use for storing the model.
106
+ file_importer: Instance of `TrainingDataImporter` to use for training.
107
+ If it is not provided, a new instance will be created.
103
108
 
104
109
  Returns:
105
110
  An instance of `TrainingResult`.
@@ -121,6 +126,7 @@ def train(
121
126
  model_to_finetune=model_to_finetune,
122
127
  finetuning_epoch_fraction=finetuning_epoch_fraction,
123
128
  remote_storage=remote_storage,
129
+ file_importer=file_importer,
124
130
  )
125
131
  )
126
132
 
@@ -1,8 +1,9 @@
1
1
  import argparse
2
2
  import logging
3
- from typing import Optional, Text, Union
3
+ from enum import Enum
4
+ from typing import List, Optional, Text, Union
4
5
 
5
- from rasa.nlu.persistor import RemoteStorageType, StorageType, parse_remote_storage
6
+ from rasa.core.persistor import RemoteStorageType, StorageType, parse_remote_storage
6
7
  from rasa.shared.constants import (
7
8
  DEFAULT_CONFIG_PATH,
8
9
  DEFAULT_DATA_PATH,
@@ -185,3 +186,23 @@ def parse_remote_storage_arg(value: str) -> StorageType:
185
186
  return parse_remote_storage(value)
186
187
  except ValueError as e:
187
188
  raise argparse.ArgumentTypeError(str(e))
189
+
190
+
191
+ class SkipYamlValidation(Enum):
192
+ DOMAIN = "domain"
193
+
194
+ @classmethod
195
+ def list(cls) -> List[str]:
196
+ return [e.value for e in SkipYamlValidation]
197
+
198
+
199
+ def add_skip_validation_flag(
200
+ parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
201
+ ) -> None:
202
+ parser.add_argument(
203
+ "--skip-yaml-validation",
204
+ default=[],
205
+ choices=SkipYamlValidation.list(),
206
+ action="append",
207
+ help="Skip YAML validation for selected parts of the training data.",
208
+ )
rasa/cli/arguments/run.py CHANGED
@@ -5,6 +5,7 @@ from typing import Union
5
5
  from rasa.cli.arguments.default_arguments import (
6
6
  add_endpoint_param,
7
7
  add_model_param,
8
+ add_skip_validation_flag,
8
9
  add_remote_storage_param,
9
10
  )
10
11
  from rasa.core import constants
@@ -21,6 +22,7 @@ def set_run_arguments(parser: argparse.ArgumentParser) -> None:
21
22
  """Arguments for running Rasa directly using `rasa run`."""
22
23
  add_model_param(parser)
23
24
  add_server_arguments(parser)
25
+ add_skip_validation_flag(parser)
24
26
 
25
27
 
26
28
  def set_run_action_arguments(parser: argparse.ArgumentParser) -> None:
rasa/cli/e2e_test.py CHANGED
@@ -221,15 +221,17 @@ def execute_e2e_tests(args: argparse.Namespace) -> None:
221
221
  if args.e2e_results is not None:
222
222
  results_path = Path(args.e2e_results)
223
223
 
224
- passed_file = rasa.cli.utils.get_e2e_results_file_name(
225
- results_path, STATUS_PASSED
226
- )
227
- write_test_results_to_file(passed, passed_file)
224
+ if passed:
225
+ passed_file = rasa.cli.utils.get_e2e_results_file_name(
226
+ results_path, STATUS_PASSED
227
+ )
228
+ write_test_results_to_file(passed, passed_file)
228
229
 
229
- failed_file = rasa.cli.utils.get_e2e_results_file_name(
230
- results_path, STATUS_FAILED
231
- )
232
- write_test_results_to_file(failed, failed_file)
230
+ if failed:
231
+ failed_file = rasa.cli.utils.get_e2e_results_file_name(
232
+ results_path, STATUS_FAILED
233
+ )
234
+ write_test_results_to_file(failed, failed_file)
233
235
 
234
236
  aggregate_stats_calculator = AggregateTestStatsCalculator(
235
237
  passed_results=passed, failed_results=failed, test_cases=test_suite.test_cases
rasa/cli/inspect.py CHANGED
@@ -3,11 +3,12 @@ import webbrowser
3
3
  from asyncio import AbstractEventLoop
4
4
  from typing import List, Text
5
5
 
6
+ from sanic import Sanic
7
+
6
8
  from rasa.cli import SubParsersAction
7
9
  from rasa.cli.arguments import shell as arguments
10
+ from rasa.cli.arguments.default_arguments import add_skip_validation_flag
8
11
  from rasa.core import constants
9
- from sanic import Sanic
10
-
11
12
  from rasa.utils.cli import remove_argument_from_parser
12
13
 
13
14
 
@@ -33,6 +34,8 @@ def add_subparser(
33
34
  inspect_parser.set_defaults(func=inspect)
34
35
 
35
36
  arguments.set_shell_arguments(inspect_parser)
37
+ add_skip_validation_flag(inspect_parser)
38
+
36
39
  # it'd be confusing to expose those arguments to the user,
37
40
  # so we remove them
38
41
  remove_argument_from_parser(inspect_parser, "--credentials")
rasa/cli/run.py CHANGED
@@ -6,6 +6,7 @@ from typing import List, Text
6
6
  from rasa.api import run as rasa_run
7
7
  from rasa.cli import SubParsersAction
8
8
  from rasa.cli.arguments import run as arguments
9
+ from rasa.cli.arguments.default_arguments import SkipYamlValidation
9
10
  from rasa.cli.utils import get_validated_path
10
11
  from rasa.exceptions import ModelNotFound
11
12
  from rasa.shared.constants import (
@@ -15,6 +16,7 @@ from rasa.shared.constants import (
15
16
  DEFAULT_MODELS_PATH,
16
17
  DOCS_BASE_URL,
17
18
  )
19
+ from rasa.shared.core.domain import Domain
18
20
  from rasa.shared.utils.cli import print_error
19
21
 
20
22
  logger = logging.getLogger(__name__)
@@ -87,6 +89,11 @@ def run(args: argparse.Namespace) -> None:
87
89
  args.credentials, "credentials", DEFAULT_CREDENTIALS_PATH, True
88
90
  )
89
91
 
92
+ if SkipYamlValidation.DOMAIN.value in args.skip_yaml_validation:
93
+ Domain.validate_yaml = False
94
+ else:
95
+ Domain.validate_yaml = True
96
+
90
97
  if args.enable_api:
91
98
  if not args.remote_storage:
92
99
  args.model = _validate_model_path(args.model, "model", DEFAULT_MODELS_PATH)
rasa/cli/studio/studio.py CHANGED
@@ -5,7 +5,6 @@ from urllib.parse import ParseResult, urlparse
5
5
  import questionary
6
6
  from rasa.cli import SubParsersAction
7
7
 
8
- import rasa.shared.utils.cli
9
8
  import rasa.cli.studio.download
10
9
  import rasa.cli.studio.train
11
10
  import rasa.cli.studio.upload
@@ -51,15 +50,6 @@ def _add_config_subparser(
51
50
 
52
51
  studio_config_parser.set_defaults(func=create_and_store_studio_config)
53
52
 
54
- studio_config_parser.add_argument(
55
- "--disable-verify",
56
- "-x",
57
- action="store_true",
58
- default=False,
59
- help="Disable strict SSL verification for the "
60
- "Rasa Studio authentication server.",
61
- )
62
-
63
53
  # add advanced configuration flag to trigger
64
54
  # advanced configuration setup for authentication settings
65
55
  studio_config_parser.add_argument(
@@ -229,17 +219,7 @@ def _configure_studio_config(args: argparse.Namespace) -> StudioConfig:
229
219
  studio_config = _create_studio_config(
230
220
  studio_url, keycloak_url, realm_name, client_id
231
221
  )
232
-
233
- if args.disable_verify:
234
- rasa.shared.utils.cli.print_info(
235
- "Disabling SSL verification for the Rasa Studio authentication server."
236
- )
237
- studio_auth = StudioAuth(studio_config, verify=False)
238
- else:
239
- rasa.shared.utils.cli.print_info(
240
- "Enabling SSL verification for the Rasa Studio authentication server."
241
- )
242
- studio_auth = StudioAuth(studio_config, verify=True)
222
+ studio_auth = StudioAuth(studio_config)
243
223
 
244
224
  if _check_studio_auth(studio_auth):
245
225
  return studio_config
rasa/cli/train.py CHANGED
@@ -110,16 +110,20 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
110
110
  for f in args.data
111
111
  ]
112
112
 
113
+ training_data_importer = TrainingDataImporter.load_from_config(
114
+ domain_path=domain, training_data_paths=args.data, config_path=config
115
+ )
116
+
113
117
  if not args.skip_validation:
114
118
  structlogger.info(
115
119
  "cli.train.run_training",
116
120
  event_info="Started validating domain and training data...",
117
121
  )
118
- importer = TrainingDataImporter.load_from_config(
119
- domain_path=domain, training_data_paths=args.data, config_path=config
120
- )
122
+
121
123
  rasa.cli.utils.validate_files(
122
- args.fail_on_validation_warnings, args.validation_max_history, importer
124
+ args.fail_on_validation_warnings,
125
+ args.validation_max_history,
126
+ training_data_importer,
123
127
  )
124
128
 
125
129
  training_result = train_all(
@@ -138,6 +142,7 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
138
142
  model_to_finetune=_model_for_finetuning(args),
139
143
  finetuning_epoch_fraction=args.epoch_fraction,
140
144
  remote_storage=args.remote_storage,
145
+ file_importer=training_data_importer,
141
146
  )
142
147
  if training_result.code != 0 and can_exit:
143
148
  sys.exit(training_result.code)
rasa/cli/utils.py CHANGED
@@ -470,10 +470,10 @@ def get_e2e_results_file_name(
470
470
  ) -> str:
471
471
  """Returns the name of the e2e results file."""
472
472
  if results_output_path.is_dir():
473
- file_name = str(results_output_path) + f"/e2e_results_{result_type}.yml"
473
+ file_name = results_output_path / f"e2e_results_{result_type}.yml"
474
474
  else:
475
475
  parent = results_output_path.parent
476
476
  stem = results_output_path.stem
477
- file_name = str(parent) + f"/{stem}_{result_type}.yml"
477
+ file_name = parent / f"{stem}_{result_type}.yml"
478
478
 
479
- return file_name
479
+ return str(file_name)
rasa/core/agent.py CHANGED
@@ -19,6 +19,7 @@ from rasa.core.exceptions import AgentNotReady
19
19
  from rasa.core.http_interpreter import RasaNLUHttpInterpreter
20
20
  from rasa.core.lock_store import InMemoryLockStore, LockStore
21
21
  from rasa.core.nlg import NaturalLanguageGenerator, TemplatedNaturalLanguageGenerator
22
+ from rasa.core.persistor import StorageType
22
23
  from rasa.core.policies.policy import PolicyPrediction
23
24
  from rasa.core.processor import MessageProcessor
24
25
  from rasa.core.tracker_store import (
@@ -28,7 +29,6 @@ from rasa.core.tracker_store import (
28
29
  )
29
30
  from rasa.core.utils import AvailableEndpoints
30
31
  from rasa.exceptions import ModelNotFound
31
- from rasa.nlu.persistor import StorageType
32
32
  from rasa.nlu.utils import is_url
33
33
  from rasa.shared.constants import DEFAULT_SENDER_ID
34
34
  from rasa.shared.core.domain import Domain
@@ -544,7 +544,7 @@ class Agent:
544
544
 
545
545
  def load_model_from_remote_storage(self, model_name: Text) -> None:
546
546
  """Loads an Agent from remote storage."""
547
- from rasa.nlu.persistor import get_persistor
547
+ from rasa.core.persistor import get_persistor
548
548
 
549
549
  persistor = get_persistor(self.remote_storage)
550
550
 
@@ -2,6 +2,8 @@ import asyncio
2
2
  import os
3
3
  import json
4
4
  import logging
5
+ from functools import cached_property
6
+
5
7
  import structlog
6
8
  import threading
7
9
  from asyncio import AbstractEventLoop
@@ -270,7 +272,7 @@ class KafkaEventBroker(EventBroker):
270
272
  if self.producer:
271
273
  self.producer.flush()
272
274
 
273
- @rasa.shared.utils.common.lazy_property
275
+ @cached_property
274
276
  def rasa_environment(self) -> Optional[Text]:
275
277
  """Get value of the `RASA_ENVIRONMENT` environment variable."""
276
278
  return os.environ.get("RASA_ENVIRONMENT", "RASA_ENVIRONMENT_NOT_SET")
rasa/core/brokers/pika.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import asyncio
2
2
  import json
3
3
  import logging
4
+ from functools import cached_property
5
+
4
6
  import structlog
5
7
  import os
6
8
  import ssl
@@ -333,7 +335,7 @@ class PikaEventBroker(EventBroker):
333
335
  delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
334
336
  )
335
337
 
336
- @rasa.shared.utils.common.lazy_property
338
+ @cached_property
337
339
  def rasa_environment(self) -> Optional[Text]:
338
340
  """Get value of the `RASA_ENVIRONMENT` environment variable."""
339
341
  return os.environ.get("RASA_ENVIRONMENT")
@@ -1,15 +1,16 @@
1
1
  import structlog
2
2
 
3
- from rasa.utils.licensing import (
4
- PRODUCT_AREA,
5
- VOICE_SCOPE,
6
- validate_license_from_env,
7
- )
8
3
 
9
4
  structlogger = structlog.get_logger()
10
5
 
11
6
 
12
7
  def validate_voice_license_scope() -> None:
8
+ from rasa.utils.licensing import (
9
+ PRODUCT_AREA,
10
+ VOICE_SCOPE,
11
+ validate_license_from_env,
12
+ )
13
+
13
14
  """Validate that the correct license scope is present."""
14
15
  structlogger.info(
15
16
  f"Validating current Rasa Pro license scope which must include "
@@ -1,11 +1,13 @@
1
1
  from typing import Any, Dict, Optional, Text
2
2
 
3
+ import os
3
4
  import structlog
4
5
  from jinja2 import Template
5
6
 
6
7
  from rasa import telemetry
7
8
  from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
8
9
  from rasa.shared.constants import (
10
+ LLM_API_HEALTH_CHECK_ENV_VAR,
9
11
  LLM_CONFIG_KEY,
10
12
  MODEL_CONFIG_KEY,
11
13
  MODEL_NAME_CONFIG_KEY,
@@ -23,6 +25,7 @@ from rasa.shared.utils.llm import (
23
25
  USER,
24
26
  combine_custom_and_default_config,
25
27
  get_prompt_template,
28
+ llm_api_health_check,
26
29
  llm_factory,
27
30
  try_instantiate_llm_client,
28
31
  )
@@ -97,12 +100,18 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
97
100
  self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
98
101
  "trace_prompt_tokens", False
99
102
  )
100
- try_instantiate_llm_client(
103
+ llm_client = try_instantiate_llm_client(
101
104
  self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
102
105
  DEFAULT_LLM_CONFIG,
103
106
  "contextual_response_rephraser.init",
104
- "ContextualResponseRephraser",
107
+ ContextualResponseRephraser.__name__,
105
108
  )
109
+ if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
110
+ llm_api_health_check(
111
+ llm_client,
112
+ "contextual_response_rephraser.init",
113
+ ContextualResponseRephraser.__name__,
114
+ )
106
115
 
107
116
  def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
108
117
  """Returns the latest message from the tracker.
@@ -165,7 +165,7 @@ class Persistor(abc.ABC):
165
165
  os.path.join(dirpath, base_name),
166
166
  "gztar",
167
167
  root_dir=model_directory,
168
- base_dir=".",
168
+ base_dir="../nlu",
169
169
  )
170
170
  file_key = os.path.basename(tar_name)
171
171
  return file_key, tar_name
@@ -1,5 +1,6 @@
1
1
  import importlib.resources
2
2
  import json
3
+ import os
3
4
  import re
4
5
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
5
6
 
@@ -47,6 +48,7 @@ from rasa.graph_components.providers.forms_provider import Forms
47
48
  from rasa.graph_components.providers.responses_provider import Responses
48
49
  from rasa.shared.constants import (
49
50
  EMBEDDINGS_CONFIG_KEY,
51
+ LLM_API_HEALTH_CHECK_ENV_VAR,
50
52
  LLM_CONFIG_KEY,
51
53
  MODEL_CONFIG_KEY,
52
54
  MODEL_NAME_CONFIG_KEY,
@@ -72,6 +74,7 @@ from rasa.shared.utils.llm import (
72
74
  DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
73
75
  embedder_factory,
74
76
  get_prompt_template,
77
+ llm_api_health_check,
75
78
  llm_factory,
76
79
  sanitize_message_for_prompt,
77
80
  tracker_as_readable_transcript,
@@ -292,12 +295,18 @@ class EnterpriseSearchPolicy(Policy):
292
295
  )
293
296
 
294
297
  # validate llm configuration
295
- try_instantiate_llm_client(
298
+ llm_client = try_instantiate_llm_client(
296
299
  self.config.get(LLM_CONFIG_KEY),
297
300
  DEFAULT_LLM_CONFIG,
298
301
  "enterprise_search_policy.train",
299
- "EnterpriseSearchPolicy",
302
+ EnterpriseSearchPolicy.__name__,
300
303
  )
304
+ if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
305
+ llm_api_health_check(
306
+ llm_client,
307
+ "enterprise_search_policy.train",
308
+ EnterpriseSearchPolicy.__name__,
309
+ )
301
310
 
302
311
  if store_type == DEFAULT_VECTOR_STORE_TYPE:
303
312
  logger.info("enterprise_search_policy.train.faiss")
@@ -1,5 +1,6 @@
1
1
  import importlib.resources
2
2
  import math
3
+ import os
3
4
  from dataclasses import dataclass, field
4
5
  from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
5
6
 
@@ -31,6 +32,7 @@ from rasa.graph_components.providers.responses_provider import Responses
31
32
  from rasa.shared.constants import (
32
33
  REQUIRED_SLOTS_KEY,
33
34
  EMBEDDINGS_CONFIG_KEY,
35
+ LLM_API_HEALTH_CHECK_ENV_VAR,
34
36
  LLM_CONFIG_KEY,
35
37
  MODEL_CONFIG_KEY,
36
38
  MODEL_NAME_CONFIG_KEY,
@@ -67,6 +69,7 @@ from rasa.shared.utils.llm import (
67
69
  combine_custom_and_default_config,
68
70
  embedder_factory,
69
71
  get_prompt_template,
72
+ llm_api_health_check,
70
73
  llm_factory,
71
74
  sanitize_message_for_prompt,
72
75
  tracker_as_readable_transcript,
@@ -487,12 +490,16 @@ class IntentlessPolicy(Policy):
487
490
  A policy must return its resource locator so that potential children nodes
488
491
  can load the policy from the resource.
489
492
  """
490
- try_instantiate_llm_client(
493
+ llm_client = try_instantiate_llm_client(
491
494
  self.config.get(LLM_CONFIG_KEY),
492
495
  DEFAULT_LLM_CONFIG,
493
496
  "intentless_policy.train",
494
- "IntentlessPolicy",
497
+ IntentlessPolicy.__name__,
495
498
  )
499
+ if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
500
+ llm_api_health_check(
501
+ llm_client, "intentless_policy.train", IntentlessPolicy.__name__
502
+ )
496
503
 
497
504
  responses = filter_responses(responses, forms, flows or FlowsList([]))
498
505
  telemetry.track_intentless_policy_train()
rasa/core/run.py CHANGED
@@ -32,8 +32,8 @@ from rasa.core import agent, channels, constants
32
32
  from rasa.core.agent import Agent
33
33
  from rasa.core.channels import console
34
34
  from rasa.core.channels.channel import InputChannel
35
+ from rasa.core.persistor import StorageType
35
36
  from rasa.core.utils import AvailableEndpoints
36
- from rasa.nlu.persistor import StorageType
37
37
  from rasa.plugin import plugin_manager
38
38
  from rasa.shared.exceptions import RasaException
39
39
  from rasa.shared.utils.yaml import read_config_file
@@ -311,6 +311,7 @@ async def load_agent_on_start(
311
311
  endpoints=endpoints,
312
312
  loop=loop,
313
313
  )
314
+
314
315
  logger.info("Rasa server is up and running.")
315
316
  return app.ctx.agent
316
317
 
@@ -23,6 +23,7 @@ VAULT_TRANSIT_MOUNT_POINT_ENV_NAME = "VAULT_TRANSIT_MOUNT_POINT"
23
23
  VAULT_NAMESPACE_ENV_NAME = "VAULT_NAMESPACE"
24
24
  VAULT_DEFAULT_RASA_SECRETS_PATH = "rasa-secrets"
25
25
  VAULT_SECRET_MANAGER_NAME = "vault"
26
+ VAULT_MOUNT_POINT_ENV_NAME = "VAULT_MOUNT_POINT"
26
27
 
27
28
 
28
29
  VAULT_ENDPOINT_URL_LABEL = "url"
@@ -30,3 +31,6 @@ VAULT_ENDPOINT_TOKEN_LABEL = "token"
30
31
  VAULT_ENDPOINT_SECRETS_PATH_LABEL = "secrets_path"
31
32
  VAULT_ENDPOINT_TRANSIT_MOUNT_POINT_LABEL = "transit_mount_point"
32
33
  VAULT_ENDPOINT_NAMESPACE_LABEL = "namespace"
34
+ VAULT_ENDPOINT_MOUNT_POINT_LABEL = "mount_point"
35
+
36
+ VAULT_MOUNT_POINT_DEFAULT_VALUE = "secret"
@@ -7,9 +7,11 @@ from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
7
7
  from rasa.core.secrets_manager.constants import (
8
8
  SECRET_MANAGER_ENV_NAME,
9
9
  VAULT_DEFAULT_RASA_SECRETS_PATH,
10
+ VAULT_ENDPOINT_MOUNT_POINT_LABEL,
10
11
  VAULT_ENDPOINT_NAMESPACE_LABEL,
11
12
  VAULT_ENDPOINT_SECRETS_PATH_LABEL,
12
13
  VAULT_ENDPOINT_TRANSIT_MOUNT_POINT_LABEL,
14
+ VAULT_MOUNT_POINT_ENV_NAME,
13
15
  VAULT_NAMESPACE_ENV_NAME,
14
16
  VAULT_RASA_SECRETS_PATH_ENV_NAME,
15
17
  VAULT_SECRET_MANAGER_NAME,
@@ -48,6 +50,7 @@ def create(config: SecretManagerConfig) -> Optional[SecretsManager]:
48
50
  transit_mount_point=vault_config.transit_mount_point,
49
51
  secrets_path=vault_config.secrets_path,
50
52
  namespace=vault_config.namespace,
53
+ mount_point=vault_config.mount_point,
51
54
  )
52
55
 
53
56
  return secret_manager
@@ -79,6 +82,7 @@ def read_vault_endpoint_config(
79
82
  )
80
83
  secrets_path = endpoint_config.kwargs.get(VAULT_ENDPOINT_SECRETS_PATH_LABEL)
81
84
  namespace = endpoint_config.kwargs.get(VAULT_ENDPOINT_NAMESPACE_LABEL)
85
+ mount_point = endpoint_config.kwargs.get(VAULT_ENDPOINT_MOUNT_POINT_LABEL)
82
86
 
83
87
  return VaultSecretManagerNonStrictConfig(
84
88
  url=url,
@@ -86,6 +90,7 @@ def read_vault_endpoint_config(
86
90
  transit_mount_point=transit_mount_point,
87
91
  secrets_path=secrets_path or VAULT_DEFAULT_RASA_SECRETS_PATH,
88
92
  namespace=namespace,
93
+ mount_point=mount_point,
89
94
  )
90
95
 
91
96
  return None
@@ -102,6 +107,7 @@ def read_vault_env_vars() -> VaultSecretManagerNonStrictConfig:
102
107
  transit_mount_point = os.getenv(VAULT_TRANSIT_MOUNT_POINT_ENV_NAME)
103
108
  secrets_path = os.getenv(VAULT_RASA_SECRETS_PATH_ENV_NAME)
104
109
  namespace = os.getenv(VAULT_NAMESPACE_ENV_NAME)
110
+ mount_point = os.getenv(VAULT_MOUNT_POINT_ENV_NAME)
105
111
 
106
112
  return VaultSecretManagerNonStrictConfig(
107
113
  url=url,
@@ -109,6 +115,7 @@ def read_vault_env_vars() -> VaultSecretManagerNonStrictConfig:
109
115
  transit_mount_point=transit_mount_point,
110
116
  secrets_path=secrets_path,
111
117
  namespace=namespace,
118
+ mount_point=mount_point,
112
119
  )
113
120
 
114
121
 
@@ -149,6 +156,7 @@ def read_vault_config(
149
156
  f"{VAULT_RASA_SECRETS_PATH_ENV_NAME} = {env_config.secrets_path}, "
150
157
  f"{VAULT_TRANSIT_MOUNT_POINT_ENV_NAME} = {env_config.transit_mount_point}. "
151
158
  f"{VAULT_NAMESPACE_ENV_NAME} = {env_config.namespace}. "
159
+ f"{VAULT_MOUNT_POINT_ENV_NAME} = {env_config.mount_point}. "
152
160
  )
153
161
 
154
162
 
@@ -15,6 +15,7 @@ from rasa.utils.endpoints import EndpointConfig
15
15
  from rasa.core.secrets_manager.constants import (
16
16
  TRACKER_STORE_ENDPOINT_TYPE,
17
17
  TRANSIT_KEY_FOR_ENCRYPTION_LABEL,
18
+ VAULT_MOUNT_POINT_DEFAULT_VALUE,
18
19
  VAULT_SECRET_MANAGER_NAME,
19
20
  )
20
21
  from rasa.core.secrets_manager.endpoints import (
@@ -181,6 +182,7 @@ class VaultSecretsManager(SecretsManager):
181
182
  secrets_path: Text,
182
183
  transit_mount_point: Optional[Text] = None,
183
184
  namespace: Optional[Text] = None,
185
+ mount_point: Optional[Text] = None,
184
186
  ):
185
187
  """Initialise the VaultSecretsManager.
186
188
 
@@ -190,11 +192,13 @@ class VaultSecretsManager(SecretsManager):
190
192
  secrets_path: The path to the secrets in the vault server.
191
193
  transit_mount_point: The mount point of the transit engine.
192
194
  namespace: The namespace in which secrets reside in.
195
+ mount_point: The mount point of the kv engine.
193
196
  """
194
197
  self.host = host
195
198
  self.transit_mount_point = transit_mount_point
196
199
  self.token = token
197
200
  self.secrets_path = secrets_path
201
+ self.mount_point = mount_point or VAULT_MOUNT_POINT_DEFAULT_VALUE
198
202
  self.namespace = namespace
199
203
 
200
204
  # Create client
@@ -236,7 +240,7 @@ class VaultSecretsManager(SecretsManager):
236
240
  """
237
241
  logger.info(f"Loading secrets from vault server at {self.host}.")
238
242
  read_response = self.client.secrets.kv.read_secret_version(
239
- mount_point="secret", path=self.secrets_path
243
+ mount_point=self.mount_point, path=self.secrets_path
240
244
  )
241
245
 
242
246
  secrets = read_response["data"]["data"]
@@ -455,6 +459,7 @@ class VaultSecretManagerConfig(SecretManagerConfig):
455
459
  secrets_path: Text,
456
460
  transit_mount_point: Text = "transit",
457
461
  namespace: Optional[Text] = None,
462
+ mount_point: Optional[Text] = None,
458
463
  ) -> None:
459
464
  """Initialise the VaultSecretManagerConfig.
460
465
 
@@ -471,6 +476,7 @@ class VaultSecretManagerConfig(SecretManagerConfig):
471
476
  self.secrets_path = secrets_path
472
477
  self.transit_mount_point = transit_mount_point
473
478
  self.namespace = namespace
479
+ self.mount_point = mount_point
474
480
 
475
481
 
476
482
  @dataclass
@@ -486,6 +492,7 @@ class VaultSecretManagerNonStrictConfig:
486
492
  secrets_path: Optional[Text]
487
493
  transit_mount_point: Optional[Text]
488
494
  namespace: Optional[Text] = None
495
+ mount_point: Optional[Text] = None
489
496
 
490
497
  def is_empty(self) -> bool:
491
498
  """Check if all the values are empty."""
@@ -495,6 +502,7 @@ class VaultSecretManagerNonStrictConfig:
495
502
  and (self.secrets_path is None or self.secrets_path == "")
496
503
  and (self.transit_mount_point is None or self.transit_mount_point == "")
497
504
  and (self.namespace is None or self.namespace == "")
505
+ and (self.mount_point is None or self.mount_point == "")
498
506
  )
499
507
 
500
508
  def is_valid(self) -> bool:
@@ -516,6 +524,7 @@ class VaultSecretManagerNonStrictConfig:
516
524
  and self.secrets_path != ""
517
525
  and self._is_optional_value_valid(self.transit_mount_point)
518
526
  and self._is_optional_value_valid(self.namespace)
527
+ and self._is_optional_value_valid(self.mount_point)
519
528
  )
520
529
 
521
530
  @staticmethod
@@ -547,6 +556,7 @@ class VaultSecretManagerNonStrictConfig:
547
556
  secrets_path=self.secrets_path or other.secrets_path,
548
557
  transit_mount_point=self.transit_mount_point or other.transit_mount_point,
549
558
  namespace=self.namespace or other.namespace,
559
+ mount_point=self.mount_point or other.mount_point,
550
560
  )
551
561
 
552
562