rasa-pro 3.10.16__py3-none-any.whl → 3.11.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/api.py +9 -3
- rasa/cli/arguments/default_arguments.py +23 -2
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/train.py +3 -9
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +1 -1
- rasa/cli/inspect.py +8 -4
- rasa/cli/llm_fine_tuning.py +12 -15
- rasa/cli/run.py +8 -1
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/train.py +11 -53
- rasa/cli/utils.py +8 -10
- rasa/cli/x.py +1 -1
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +2 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +8 -6
- rasa/core/channels/channel.py +21 -4
- rasa/core/channels/development_inspector.py +143 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-86942a71.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-b0290676.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-f6405f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-ef61ac77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-f0411e58.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-7dcc4f3b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-e0c092d7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fba2e3ce.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-7a70b71a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-24a5f41a.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-00a59b68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-293c91fa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-07b2d68c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-bc959fbd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-3a8a5a28.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-4a350f72.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-af464fb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-0071f036.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-2f73cc83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-f014b4cc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-d2426fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-776f01a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-82e00b57.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-ea13c6bb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-1feca7e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-070c61d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-24f46263.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-c9056051.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-08abc34a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-bc74c25a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-4e5d66de.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-849c4517.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-d0fb1598.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-04d115e2.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -17
- rasa/core/channels/inspector/index.html +17 -16
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +117 -67
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +22 -0
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +110 -68
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +11 -4
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +58 -7
- rasa/core/channels/{voice_aware → voice_ready}/utils.py +16 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +71 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +13 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +77 -0
- rasa/core/channels/voice_stream/audio_bytes.py +7 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +100 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +114 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +48 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +164 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +247 -0
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/nlg/contextual_response_rephraser.py +11 -2
- rasa/{nlu → core}/persistor.py +16 -38
- rasa/core/policies/enterprise_search_policy.py +12 -15
- rasa/core/policies/flows/flow_executor.py +8 -18
- rasa/core/policies/intentless_policy.py +10 -15
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +13 -64
- rasa/core/run.py +11 -1
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +1 -11
- rasa/dialogue_understanding/coexistence/llm_based_router.py +10 -10
- rasa/dialogue_understanding/commands/__init__.py +2 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +0 -7
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -3
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +3 -28
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +4 -37
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +6 -48
- rasa/e2e_test/e2e_test_runner.py +6 -9
- rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/graph.py +3 -10
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +1 -78
- rasa/keys +1 -0
- rasa/model_training.py +13 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/server.py +1 -1
- rasa/shared/constants.py +3 -12
- rasa/shared/core/constants.py +4 -0
- rasa/shared/core/domain.py +101 -47
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flows_list.py +20 -11
- rasa/shared/core/flows/validation.py +25 -0
- rasa/shared/core/flows/yaml_flows_io.py +3 -24
- rasa/shared/importers/importer.py +40 -39
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +7 -2
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +1 -6
- rasa/shared/providers/llm/_base_litellm_client.py +11 -31
- rasa/shared/providers/llm/self_hosted_llm_client.py +3 -15
- rasa/shared/utils/common.py +3 -22
- rasa/shared/utils/io.py +0 -1
- rasa/shared/utils/llm.py +30 -27
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +44 -0
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +3 -10
- rasa/studio/upload.py +8 -17
- rasa/tracing/instrumentation/attribute_extractors.py +1 -1
- rasa/utils/io.py +66 -0
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +0 -12
- rasa/version.py +1 -1
- rasa_pro-3.11.0a1.dist-info/METADATA +576 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/RECORD +181 -164
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.10.16.dist-info/METADATA +0 -196
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/entry_points.txt +0 -0
rasa/core/policies/ted_policy.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import logging
|
|
3
|
+
|
|
4
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from collections import defaultdict
|
|
6
7
|
import contextlib
|
|
7
|
-
from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import tensorflow as tf
|
|
11
|
+
from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
|
|
11
12
|
|
|
12
|
-
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
13
|
from rasa.engine.graph import ExecutionContext
|
|
14
14
|
from rasa.engine.storage.resource import Resource
|
|
15
15
|
from rasa.engine.storage.storage import ModelStorage
|
|
@@ -49,22 +49,18 @@ from rasa.shared.core.generator import TrackerWithCachedStates
|
|
|
49
49
|
from rasa.shared.core.events import EntitiesAdded, Event
|
|
50
50
|
from rasa.shared.core.domain import Domain
|
|
51
51
|
from rasa.shared.nlu.training_data.message import Message
|
|
52
|
-
from rasa.shared.nlu.training_data.features import
|
|
53
|
-
Features,
|
|
54
|
-
save_features,
|
|
55
|
-
load_features,
|
|
56
|
-
)
|
|
52
|
+
from rasa.shared.nlu.training_data.features import Features
|
|
57
53
|
import rasa.shared.utils.io
|
|
58
54
|
import rasa.utils.io
|
|
59
55
|
from rasa.utils import train_utils
|
|
60
|
-
from rasa.utils.tensorflow.feature_array import (
|
|
61
|
-
FeatureArray,
|
|
62
|
-
serialize_nested_feature_arrays,
|
|
63
|
-
deserialize_nested_feature_arrays,
|
|
64
|
-
)
|
|
65
56
|
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
66
57
|
from rasa.utils.tensorflow import rasa_layers
|
|
67
|
-
from rasa.utils.tensorflow.model_data import
|
|
58
|
+
from rasa.utils.tensorflow.model_data import (
|
|
59
|
+
RasaModelData,
|
|
60
|
+
FeatureSignature,
|
|
61
|
+
FeatureArray,
|
|
62
|
+
Data,
|
|
63
|
+
)
|
|
68
64
|
from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
|
|
69
65
|
from rasa.utils.tensorflow.constants import (
|
|
70
66
|
LABEL,
|
|
@@ -965,32 +961,22 @@ class TEDPolicy(Policy):
|
|
|
965
961
|
model_path: Path where model is to be persisted
|
|
966
962
|
"""
|
|
967
963
|
model_filename = self._metadata_filename()
|
|
968
|
-
rasa.
|
|
969
|
-
model_path / f"{model_filename}.priority.
|
|
970
|
-
)
|
|
971
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
972
|
-
model_path / f"{model_filename}.meta.json", self.config
|
|
964
|
+
rasa.utils.io.json_pickle(
|
|
965
|
+
model_path / f"{model_filename}.priority.pkl", self.priority
|
|
973
966
|
)
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
self.data_example,
|
|
977
|
-
str(model_path / f"{model_filename}.data_example.st"),
|
|
978
|
-
str(model_path / f"{model_filename}.data_example_metadata.json"),
|
|
967
|
+
rasa.utils.io.pickle_dump(
|
|
968
|
+
model_path / f"{model_filename}.meta.pkl", self.config
|
|
979
969
|
)
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
dict(self._label_data.data) if self._label_data is not None else {},
|
|
983
|
-
str(model_path / f"{model_filename}.label_data.st"),
|
|
984
|
-
str(model_path / f"{model_filename}.label_data_metadata.json"),
|
|
970
|
+
rasa.utils.io.pickle_dump(
|
|
971
|
+
model_path / f"{model_filename}.data_example.pkl", self.data_example
|
|
985
972
|
)
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
self.fake_features, str(model_path / f"{model_filename}.fake_features.st")
|
|
973
|
+
rasa.utils.io.pickle_dump(
|
|
974
|
+
model_path / f"{model_filename}.fake_features.pkl", self.fake_features
|
|
989
975
|
)
|
|
990
|
-
rasa.
|
|
991
|
-
model_path / f"{model_filename}.
|
|
976
|
+
rasa.utils.io.pickle_dump(
|
|
977
|
+
model_path / f"{model_filename}.label_data.pkl",
|
|
978
|
+
dict(self._label_data.data) if self._label_data is not None else {},
|
|
992
979
|
)
|
|
993
|
-
|
|
994
980
|
entity_tag_specs = (
|
|
995
981
|
[tag_spec._asdict() for tag_spec in self._entity_tag_specs]
|
|
996
982
|
if self._entity_tag_specs
|
|
@@ -1008,29 +994,18 @@ class TEDPolicy(Policy):
|
|
|
1008
994
|
model_path: Path where model is to be persisted.
|
|
1009
995
|
"""
|
|
1010
996
|
tf_model_file = model_path / f"{cls._metadata_filename()}.tf_model"
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
loaded_data = deserialize_nested_feature_arrays(
|
|
1014
|
-
str(model_path / f"{cls._metadata_filename()}.data_example.st"),
|
|
1015
|
-
str(model_path / f"{cls._metadata_filename()}.data_example_metadata.json"),
|
|
997
|
+
loaded_data = rasa.utils.io.pickle_load(
|
|
998
|
+
model_path / f"{cls._metadata_filename()}.data_example.pkl"
|
|
1016
999
|
)
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
str(model_path / f"{cls._metadata_filename()}.label_data.st"),
|
|
1020
|
-
str(model_path / f"{cls._metadata_filename()}.label_data_metadata.json"),
|
|
1000
|
+
label_data = rasa.utils.io.pickle_load(
|
|
1001
|
+
model_path / f"{cls._metadata_filename()}.label_data.pkl"
|
|
1021
1002
|
)
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
# load fake features
|
|
1025
|
-
metadata = rasa.shared.utils.io.read_json_file(
|
|
1026
|
-
model_path / f"{cls._metadata_filename()}.fake_features_metadata.json"
|
|
1003
|
+
fake_features = rasa.utils.io.pickle_load(
|
|
1004
|
+
model_path / f"{cls._metadata_filename()}.fake_features.pkl"
|
|
1027
1005
|
)
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
priority = rasa.shared.utils.io.read_json_file(
|
|
1033
|
-
model_path / f"{cls._metadata_filename()}.priority.json"
|
|
1006
|
+
label_data = RasaModelData(data=label_data)
|
|
1007
|
+
priority = rasa.utils.io.json_unpickle(
|
|
1008
|
+
model_path / f"{cls._metadata_filename()}.priority.pkl"
|
|
1034
1009
|
)
|
|
1035
1010
|
entity_tag_specs = rasa.shared.utils.io.read_json_file(
|
|
1036
1011
|
model_path / f"{cls._metadata_filename()}.entity_tag_specs.json"
|
|
@@ -1048,8 +1023,8 @@ class TEDPolicy(Policy):
|
|
|
1048
1023
|
)
|
|
1049
1024
|
for tag_spec in entity_tag_specs
|
|
1050
1025
|
]
|
|
1051
|
-
model_config = rasa.
|
|
1052
|
-
model_path / f"{cls._metadata_filename()}.meta.
|
|
1026
|
+
model_config = rasa.utils.io.pickle_load(
|
|
1027
|
+
model_path / f"{cls._metadata_filename()}.meta.pkl"
|
|
1053
1028
|
)
|
|
1054
1029
|
|
|
1055
1030
|
return {
|
|
@@ -1095,7 +1070,7 @@ class TEDPolicy(Policy):
|
|
|
1095
1070
|
) -> TEDPolicy:
|
|
1096
1071
|
featurizer = TrackerFeaturizer.load(model_path)
|
|
1097
1072
|
|
|
1098
|
-
if not (model_path / f"{cls._metadata_filename()}.data_example.
|
|
1073
|
+
if not (model_path / f"{cls._metadata_filename()}.data_example.pkl").is_file():
|
|
1099
1074
|
return cls(
|
|
1100
1075
|
config,
|
|
1101
1076
|
model_storage,
|
|
@@ -5,7 +5,6 @@ from typing import Any, List, Optional, Text, Dict, Type, Union
|
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import tensorflow as tf
|
|
8
|
-
|
|
9
8
|
import rasa.utils.common
|
|
10
9
|
from rasa.engine.graph import ExecutionContext
|
|
11
10
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
@@ -17,7 +16,6 @@ from rasa.shared.core.domain import Domain
|
|
|
17
16
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
18
17
|
from rasa.shared.core.constants import SLOTS, ACTIVE_LOOP, ACTION_UNLIKELY_INTENT_NAME
|
|
19
18
|
from rasa.shared.core.events import UserUttered, ActionExecuted
|
|
20
|
-
import rasa.shared.utils.io
|
|
21
19
|
from rasa.shared.nlu.constants import (
|
|
22
20
|
INTENT,
|
|
23
21
|
TEXT,
|
|
@@ -105,6 +103,8 @@ from rasa.utils.tensorflow.constants import (
|
|
|
105
103
|
)
|
|
106
104
|
from rasa.utils.tensorflow import layers
|
|
107
105
|
from rasa.utils.tensorflow.model_data import RasaModelData, FeatureArray, Data
|
|
106
|
+
|
|
107
|
+
import rasa.utils.io as io_utils
|
|
108
108
|
from rasa.core.exceptions import RasaCoreException
|
|
109
109
|
from rasa.shared.utils import common
|
|
110
110
|
|
|
@@ -881,12 +881,9 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
|
|
|
881
881
|
model_path: Path where model is to be persisted
|
|
882
882
|
"""
|
|
883
883
|
super().persist_model_utilities(model_path)
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
save_file(
|
|
888
|
-
{str(k): np.array(v) for k, v in self.label_quantiles.items()},
|
|
889
|
-
model_path / f"{self._metadata_filename()}.label_quantiles.st",
|
|
884
|
+
io_utils.pickle_dump(
|
|
885
|
+
model_path / f"{self._metadata_filename()}.label_quantiles.pkl",
|
|
886
|
+
self.label_quantiles,
|
|
890
887
|
)
|
|
891
888
|
|
|
892
889
|
@classmethod
|
|
@@ -897,14 +894,9 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
|
|
|
897
894
|
model_path: Path where model is to be persisted.
|
|
898
895
|
"""
|
|
899
896
|
model_utilties = super()._load_model_utilities(model_path)
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
loaded_label_quantiles = load_file(
|
|
904
|
-
model_path / f"{cls._metadata_filename()}.label_quantiles.st"
|
|
897
|
+
label_quantiles = io_utils.pickle_load(
|
|
898
|
+
model_path / f"{cls._metadata_filename()}.label_quantiles.pkl"
|
|
905
899
|
)
|
|
906
|
-
label_quantiles = {int(k): list(v) for k, v in loaded_label_quantiles.items()}
|
|
907
|
-
|
|
908
900
|
model_utilties.update({"label_quantiles": label_quantiles})
|
|
909
901
|
return model_utilties
|
|
910
902
|
|
rasa/core/processor.py
CHANGED
|
@@ -101,9 +101,6 @@ logger = logging.getLogger(__name__)
|
|
|
101
101
|
structlogger = structlog.get_logger()
|
|
102
102
|
|
|
103
103
|
MAX_NUMBER_OF_PREDICTIONS = int(os.environ.get("MAX_NUMBER_OF_PREDICTIONS", "10"))
|
|
104
|
-
MAX_NUMBER_OF_PREDICTIONS_CALM = int(
|
|
105
|
-
os.environ.get("MAX_NUMBER_OF_PREDICTIONS_CALM", "1000")
|
|
106
|
-
)
|
|
107
104
|
|
|
108
105
|
|
|
109
106
|
class MessageProcessor:
|
|
@@ -117,7 +114,6 @@ class MessageProcessor:
|
|
|
117
114
|
generator: NaturalLanguageGenerator,
|
|
118
115
|
action_endpoint: Optional[EndpointConfig] = None,
|
|
119
116
|
max_number_of_predictions: int = MAX_NUMBER_OF_PREDICTIONS,
|
|
120
|
-
max_number_of_predictions_calm: int = MAX_NUMBER_OF_PREDICTIONS_CALM,
|
|
121
117
|
on_circuit_break: Optional[LambdaType] = None,
|
|
122
118
|
http_interpreter: Optional[RasaNLUHttpInterpreter] = None,
|
|
123
119
|
endpoints: Optional["AvailableEndpoints"] = None,
|
|
@@ -126,6 +122,7 @@ class MessageProcessor:
|
|
|
126
122
|
self.nlg = generator
|
|
127
123
|
self.tracker_store = tracker_store
|
|
128
124
|
self.lock_store = lock_store
|
|
125
|
+
self.max_number_of_predictions = max_number_of_predictions
|
|
129
126
|
self.on_circuit_break = on_circuit_break
|
|
130
127
|
self.action_endpoint = action_endpoint
|
|
131
128
|
self.model_filename, self.model_metadata, self.graph_runner = self._load_model(
|
|
@@ -133,10 +130,6 @@ class MessageProcessor:
|
|
|
133
130
|
)
|
|
134
131
|
self.endpoints = endpoints
|
|
135
132
|
|
|
136
|
-
self.max_number_of_predictions = max_number_of_predictions
|
|
137
|
-
self.max_number_of_predictions_calm = max_number_of_predictions_calm
|
|
138
|
-
self.is_calm_assistant = self._is_calm_assistant()
|
|
139
|
-
|
|
140
133
|
if self.model_metadata.assistant_id is None:
|
|
141
134
|
rasa.shared.utils.io.raise_warning(
|
|
142
135
|
f"The model metadata does not contain a value for the "
|
|
@@ -979,15 +972,11 @@ class MessageProcessor:
|
|
|
979
972
|
) -> int:
|
|
980
973
|
"""Select the action limit based on the tracker state.
|
|
981
974
|
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
is
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
through a long dialogue flow.
|
|
988
|
-
|
|
989
|
-
Additionally, if the `ROUTE_TO_CALM_SLOT` is present in the tracker slots,
|
|
990
|
-
the action limit is adjusted to a separate limit for CALM-based flows.
|
|
975
|
+
Usually, we want to limit the number of predictions to the number of actions
|
|
976
|
+
that have been executed in the conversation so far. However, if the
|
|
977
|
+
conversation is currently in a state where the user is correcting the flow,
|
|
978
|
+
we want to allow for more predictions to be made as we might be traversing
|
|
979
|
+
through a long flow.
|
|
991
980
|
|
|
992
981
|
Args:
|
|
993
982
|
tracker: instance of DialogueStateTracker.
|
|
@@ -995,18 +984,6 @@ class MessageProcessor:
|
|
|
995
984
|
Returns:
|
|
996
985
|
The maximum number of predictions to make.
|
|
997
986
|
"""
|
|
998
|
-
# Check if it is a CALM assistant and if so, that the `ROUTE_TO_CALM_SLOT`
|
|
999
|
-
# is either not present or set to `True`.
|
|
1000
|
-
# If it does, use the specific prediction limit for CALM assistants.
|
|
1001
|
-
# Otherwise, use the default prediction limit.
|
|
1002
|
-
if self.is_calm_assistant and (
|
|
1003
|
-
not tracker.has_coexistence_routing_slot
|
|
1004
|
-
or tracker.get_slot(ROUTE_TO_CALM_SLOT)
|
|
1005
|
-
):
|
|
1006
|
-
max_number_of_predictions = self.max_number_of_predictions_calm
|
|
1007
|
-
else:
|
|
1008
|
-
max_number_of_predictions = self.max_number_of_predictions
|
|
1009
|
-
|
|
1010
987
|
reversed_events = list(tracker.events)[::-1]
|
|
1011
988
|
is_conversation_in_flow_correction = False
|
|
1012
989
|
for e in reversed_events:
|
|
@@ -1021,10 +998,8 @@ class MessageProcessor:
|
|
|
1021
998
|
# allow for more predictions to be made as we might be traversing through
|
|
1022
999
|
# a long flow. We multiply the number of predictions by 5 to allow for
|
|
1023
1000
|
# more predictions to be made - the factor is a best guess.
|
|
1024
|
-
return max_number_of_predictions * 5
|
|
1025
|
-
|
|
1026
|
-
# Return the default
|
|
1027
|
-
return max_number_of_predictions
|
|
1001
|
+
return self.max_number_of_predictions * 5
|
|
1002
|
+
return self.max_number_of_predictions
|
|
1028
1003
|
|
|
1029
1004
|
def is_action_limit_reached(
|
|
1030
1005
|
self, tracker: DialogueStateTracker, should_predict_another_action: bool
|
|
@@ -1254,13 +1229,11 @@ class MessageProcessor:
|
|
|
1254
1229
|
tracker.update(events[0])
|
|
1255
1230
|
return self.should_predict_another_action(action.name())
|
|
1256
1231
|
except Exception:
|
|
1257
|
-
|
|
1258
|
-
"
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
f"Please check the logs of your action server for "
|
|
1263
|
-
f"more information.",
|
|
1232
|
+
logger.exception(
|
|
1233
|
+
f"Encountered an exception while running action '{action.name()}'."
|
|
1234
|
+
"Bot will continue, but the actions events are lost. "
|
|
1235
|
+
"Please check the logs of your action server for "
|
|
1236
|
+
"more information."
|
|
1264
1237
|
)
|
|
1265
1238
|
events = []
|
|
1266
1239
|
|
|
@@ -1414,27 +1387,3 @@ class MessageProcessor:
|
|
|
1414
1387
|
]
|
|
1415
1388
|
|
|
1416
1389
|
return len(filtered_commands) > 0
|
|
1417
|
-
|
|
1418
|
-
def _is_calm_assistant(self) -> bool:
|
|
1419
|
-
"""Inspects the nodes of the graph schema to determine whether
|
|
1420
|
-
any node is associated with the `FlowPolicy`, which is indicative of a
|
|
1421
|
-
CALM assistant setup.
|
|
1422
|
-
|
|
1423
|
-
Returns:
|
|
1424
|
-
bool: True if any node in the graph schema uses `FlowPolicy`.
|
|
1425
|
-
"""
|
|
1426
|
-
# Get the graph schema's nodes from the graph runner.
|
|
1427
|
-
nodes: dict[str, Any] = self.graph_runner._graph_schema.nodes # type: ignore[attr-defined]
|
|
1428
|
-
|
|
1429
|
-
flow_policy_class_path = "rasa.core.policies.flow_policy.FlowPolicy"
|
|
1430
|
-
# Iterate over the nodes and check if any node uses `FlowPolicy`.
|
|
1431
|
-
for node_name, schema_node in nodes.items():
|
|
1432
|
-
if (
|
|
1433
|
-
schema_node.uses is not None
|
|
1434
|
-
and f"{schema_node.uses.__module__}.{schema_node.uses.__name__}"
|
|
1435
|
-
== flow_policy_class_path
|
|
1436
|
-
):
|
|
1437
|
-
return True
|
|
1438
|
-
|
|
1439
|
-
# Return False if no node is found using `FlowPolicy`.
|
|
1440
|
-
return False
|
rasa/core/run.py
CHANGED
|
@@ -19,6 +19,7 @@ from typing import (
|
|
|
19
19
|
|
|
20
20
|
from sanic import Sanic
|
|
21
21
|
from sanic.worker.loader import AppLoader
|
|
22
|
+
from rasa.core.channels.development_inspector import DevelopmentInspectProxy
|
|
22
23
|
|
|
23
24
|
import rasa.core.utils
|
|
24
25
|
import rasa.shared.utils.common
|
|
@@ -32,8 +33,8 @@ from rasa.core import agent, channels, constants
|
|
|
32
33
|
from rasa.core.agent import Agent
|
|
33
34
|
from rasa.core.channels import console
|
|
34
35
|
from rasa.core.channels.channel import InputChannel
|
|
36
|
+
from rasa.core.persistor import StorageType
|
|
35
37
|
from rasa.core.utils import AvailableEndpoints
|
|
36
|
-
from rasa.nlu.persistor import StorageType
|
|
37
38
|
from rasa.plugin import plugin_manager
|
|
38
39
|
from rasa.shared.exceptions import RasaException
|
|
39
40
|
from rasa.shared.utils.yaml import read_config_file
|
|
@@ -224,6 +225,7 @@ def serve_application(
|
|
|
224
225
|
syslog_protocol: Optional[Text] = None,
|
|
225
226
|
request_timeout: Optional[int] = None,
|
|
226
227
|
server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
|
|
228
|
+
inspect: Optional[bool] = False,
|
|
227
229
|
) -> None:
|
|
228
230
|
"""Run the API entrypoint."""
|
|
229
231
|
if not channel and not credentials:
|
|
@@ -231,6 +233,13 @@ def serve_application(
|
|
|
231
233
|
|
|
232
234
|
input_channels = create_http_input_channels(channel, credentials)
|
|
233
235
|
|
|
236
|
+
if inspect:
|
|
237
|
+
logger.info("Starting development inspector.")
|
|
238
|
+
input_channels = [DevelopmentInspectProxy(ic) for ic in input_channels]
|
|
239
|
+
|
|
240
|
+
# the inspector needs the api to retrieve slots and flows
|
|
241
|
+
enable_api = True
|
|
242
|
+
|
|
234
243
|
app = configure_app(
|
|
235
244
|
input_channels,
|
|
236
245
|
cors,
|
|
@@ -311,6 +320,7 @@ async def load_agent_on_start(
|
|
|
311
320
|
endpoints=endpoints,
|
|
312
321
|
loop=loop,
|
|
313
322
|
)
|
|
323
|
+
|
|
314
324
|
logger.info("Rasa server is up and running.")
|
|
315
325
|
return app.ctx.agent
|
|
316
326
|
|
|
@@ -23,6 +23,7 @@ VAULT_TRANSIT_MOUNT_POINT_ENV_NAME = "VAULT_TRANSIT_MOUNT_POINT"
|
|
|
23
23
|
VAULT_NAMESPACE_ENV_NAME = "VAULT_NAMESPACE"
|
|
24
24
|
VAULT_DEFAULT_RASA_SECRETS_PATH = "rasa-secrets"
|
|
25
25
|
VAULT_SECRET_MANAGER_NAME = "vault"
|
|
26
|
+
VAULT_MOUNT_POINT_ENV_NAME = "VAULT_MOUNT_POINT"
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
VAULT_ENDPOINT_URL_LABEL = "url"
|
|
@@ -30,3 +31,6 @@ VAULT_ENDPOINT_TOKEN_LABEL = "token"
|
|
|
30
31
|
VAULT_ENDPOINT_SECRETS_PATH_LABEL = "secrets_path"
|
|
31
32
|
VAULT_ENDPOINT_TRANSIT_MOUNT_POINT_LABEL = "transit_mount_point"
|
|
32
33
|
VAULT_ENDPOINT_NAMESPACE_LABEL = "namespace"
|
|
34
|
+
VAULT_ENDPOINT_MOUNT_POINT_LABEL = "mount_point"
|
|
35
|
+
|
|
36
|
+
VAULT_MOUNT_POINT_DEFAULT_VALUE = "secret"
|
|
@@ -7,9 +7,11 @@ from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
|
|
|
7
7
|
from rasa.core.secrets_manager.constants import (
|
|
8
8
|
SECRET_MANAGER_ENV_NAME,
|
|
9
9
|
VAULT_DEFAULT_RASA_SECRETS_PATH,
|
|
10
|
+
VAULT_ENDPOINT_MOUNT_POINT_LABEL,
|
|
10
11
|
VAULT_ENDPOINT_NAMESPACE_LABEL,
|
|
11
12
|
VAULT_ENDPOINT_SECRETS_PATH_LABEL,
|
|
12
13
|
VAULT_ENDPOINT_TRANSIT_MOUNT_POINT_LABEL,
|
|
14
|
+
VAULT_MOUNT_POINT_ENV_NAME,
|
|
13
15
|
VAULT_NAMESPACE_ENV_NAME,
|
|
14
16
|
VAULT_RASA_SECRETS_PATH_ENV_NAME,
|
|
15
17
|
VAULT_SECRET_MANAGER_NAME,
|
|
@@ -48,6 +50,7 @@ def create(config: SecretManagerConfig) -> Optional[SecretsManager]:
|
|
|
48
50
|
transit_mount_point=vault_config.transit_mount_point,
|
|
49
51
|
secrets_path=vault_config.secrets_path,
|
|
50
52
|
namespace=vault_config.namespace,
|
|
53
|
+
mount_point=vault_config.mount_point,
|
|
51
54
|
)
|
|
52
55
|
|
|
53
56
|
return secret_manager
|
|
@@ -79,6 +82,7 @@ def read_vault_endpoint_config(
|
|
|
79
82
|
)
|
|
80
83
|
secrets_path = endpoint_config.kwargs.get(VAULT_ENDPOINT_SECRETS_PATH_LABEL)
|
|
81
84
|
namespace = endpoint_config.kwargs.get(VAULT_ENDPOINT_NAMESPACE_LABEL)
|
|
85
|
+
mount_point = endpoint_config.kwargs.get(VAULT_ENDPOINT_MOUNT_POINT_LABEL)
|
|
82
86
|
|
|
83
87
|
return VaultSecretManagerNonStrictConfig(
|
|
84
88
|
url=url,
|
|
@@ -86,6 +90,7 @@ def read_vault_endpoint_config(
|
|
|
86
90
|
transit_mount_point=transit_mount_point,
|
|
87
91
|
secrets_path=secrets_path or VAULT_DEFAULT_RASA_SECRETS_PATH,
|
|
88
92
|
namespace=namespace,
|
|
93
|
+
mount_point=mount_point,
|
|
89
94
|
)
|
|
90
95
|
|
|
91
96
|
return None
|
|
@@ -102,6 +107,7 @@ def read_vault_env_vars() -> VaultSecretManagerNonStrictConfig:
|
|
|
102
107
|
transit_mount_point = os.getenv(VAULT_TRANSIT_MOUNT_POINT_ENV_NAME)
|
|
103
108
|
secrets_path = os.getenv(VAULT_RASA_SECRETS_PATH_ENV_NAME)
|
|
104
109
|
namespace = os.getenv(VAULT_NAMESPACE_ENV_NAME)
|
|
110
|
+
mount_point = os.getenv(VAULT_MOUNT_POINT_ENV_NAME)
|
|
105
111
|
|
|
106
112
|
return VaultSecretManagerNonStrictConfig(
|
|
107
113
|
url=url,
|
|
@@ -109,6 +115,7 @@ def read_vault_env_vars() -> VaultSecretManagerNonStrictConfig:
|
|
|
109
115
|
transit_mount_point=transit_mount_point,
|
|
110
116
|
secrets_path=secrets_path,
|
|
111
117
|
namespace=namespace,
|
|
118
|
+
mount_point=mount_point,
|
|
112
119
|
)
|
|
113
120
|
|
|
114
121
|
|
|
@@ -149,6 +156,7 @@ def read_vault_config(
|
|
|
149
156
|
f"{VAULT_RASA_SECRETS_PATH_ENV_NAME} = {env_config.secrets_path}, "
|
|
150
157
|
f"{VAULT_TRANSIT_MOUNT_POINT_ENV_NAME} = {env_config.transit_mount_point}. "
|
|
151
158
|
f"{VAULT_NAMESPACE_ENV_NAME} = {env_config.namespace}. "
|
|
159
|
+
f"{VAULT_MOUNT_POINT_ENV_NAME} = {env_config.mount_point}. "
|
|
152
160
|
)
|
|
153
161
|
|
|
154
162
|
|
|
@@ -15,6 +15,7 @@ from rasa.utils.endpoints import EndpointConfig
|
|
|
15
15
|
from rasa.core.secrets_manager.constants import (
|
|
16
16
|
TRACKER_STORE_ENDPOINT_TYPE,
|
|
17
17
|
TRANSIT_KEY_FOR_ENCRYPTION_LABEL,
|
|
18
|
+
VAULT_MOUNT_POINT_DEFAULT_VALUE,
|
|
18
19
|
VAULT_SECRET_MANAGER_NAME,
|
|
19
20
|
)
|
|
20
21
|
from rasa.core.secrets_manager.endpoints import (
|
|
@@ -181,6 +182,7 @@ class VaultSecretsManager(SecretsManager):
|
|
|
181
182
|
secrets_path: Text,
|
|
182
183
|
transit_mount_point: Optional[Text] = None,
|
|
183
184
|
namespace: Optional[Text] = None,
|
|
185
|
+
mount_point: Optional[Text] = None,
|
|
184
186
|
):
|
|
185
187
|
"""Initialise the VaultSecretsManager.
|
|
186
188
|
|
|
@@ -190,11 +192,13 @@ class VaultSecretsManager(SecretsManager):
|
|
|
190
192
|
secrets_path: The path to the secrets in the vault server.
|
|
191
193
|
transit_mount_point: The mount point of the transit engine.
|
|
192
194
|
namespace: The namespace in which secrets reside in.
|
|
195
|
+
mount_point: The mount point of the kv engine.
|
|
193
196
|
"""
|
|
194
197
|
self.host = host
|
|
195
198
|
self.transit_mount_point = transit_mount_point
|
|
196
199
|
self.token = token
|
|
197
200
|
self.secrets_path = secrets_path
|
|
201
|
+
self.mount_point = mount_point or VAULT_MOUNT_POINT_DEFAULT_VALUE
|
|
198
202
|
self.namespace = namespace
|
|
199
203
|
|
|
200
204
|
# Create client
|
|
@@ -236,7 +240,7 @@ class VaultSecretsManager(SecretsManager):
|
|
|
236
240
|
"""
|
|
237
241
|
logger.info(f"Loading secrets from vault server at {self.host}.")
|
|
238
242
|
read_response = self.client.secrets.kv.read_secret_version(
|
|
239
|
-
mount_point="secret", path=self.secrets_path
|
|
243
|
+
mount_point=self.mount_point, path=self.secrets_path
|
|
240
244
|
)
|
|
241
245
|
|
|
242
246
|
secrets = read_response["data"]["data"]
|
|
@@ -455,6 +459,7 @@ class VaultSecretManagerConfig(SecretManagerConfig):
|
|
|
455
459
|
secrets_path: Text,
|
|
456
460
|
transit_mount_point: Text = "transit",
|
|
457
461
|
namespace: Optional[Text] = None,
|
|
462
|
+
mount_point: Optional[Text] = None,
|
|
458
463
|
) -> None:
|
|
459
464
|
"""Initialise the VaultSecretManagerConfig.
|
|
460
465
|
|
|
@@ -471,6 +476,7 @@ class VaultSecretManagerConfig(SecretManagerConfig):
|
|
|
471
476
|
self.secrets_path = secrets_path
|
|
472
477
|
self.transit_mount_point = transit_mount_point
|
|
473
478
|
self.namespace = namespace
|
|
479
|
+
self.mount_point = mount_point
|
|
474
480
|
|
|
475
481
|
|
|
476
482
|
@dataclass
|
|
@@ -486,6 +492,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
486
492
|
secrets_path: Optional[Text]
|
|
487
493
|
transit_mount_point: Optional[Text]
|
|
488
494
|
namespace: Optional[Text] = None
|
|
495
|
+
mount_point: Optional[Text] = None
|
|
489
496
|
|
|
490
497
|
def is_empty(self) -> bool:
|
|
491
498
|
"""Check if all the values are empty."""
|
|
@@ -495,6 +502,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
495
502
|
and (self.secrets_path is None or self.secrets_path == "")
|
|
496
503
|
and (self.transit_mount_point is None or self.transit_mount_point == "")
|
|
497
504
|
and (self.namespace is None or self.namespace == "")
|
|
505
|
+
and (self.mount_point is None or self.mount_point == "")
|
|
498
506
|
)
|
|
499
507
|
|
|
500
508
|
def is_valid(self) -> bool:
|
|
@@ -516,6 +524,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
516
524
|
and self.secrets_path != ""
|
|
517
525
|
and self._is_optional_value_valid(self.transit_mount_point)
|
|
518
526
|
and self._is_optional_value_valid(self.namespace)
|
|
527
|
+
and self._is_optional_value_valid(self.mount_point)
|
|
519
528
|
)
|
|
520
529
|
|
|
521
530
|
@staticmethod
|
|
@@ -547,6 +556,7 @@ class VaultSecretManagerNonStrictConfig:
|
|
|
547
556
|
secrets_path=self.secrets_path or other.secrets_path,
|
|
548
557
|
transit_mount_point=self.transit_mount_point or other.transit_mount_point,
|
|
549
558
|
namespace=self.namespace or other.namespace,
|
|
559
|
+
mount_point=self.mount_point or other.mount_point,
|
|
550
560
|
)
|
|
551
561
|
|
|
552
562
|
|
|
@@ -1688,7 +1688,7 @@ def run_interactive_learning(
|
|
|
1688
1688
|
p = None
|
|
1689
1689
|
|
|
1690
1690
|
app = run.configure_app(port=port, conversation_id="default", enable_api=True)
|
|
1691
|
-
endpoints = AvailableEndpoints.get_instance(server_args.get("endpoints"))
|
|
1691
|
+
endpoints = AvailableEndpoints.read_endpoints(server_args.get("endpoints"))
|
|
1692
1692
|
|
|
1693
1693
|
# before_server_start handlers make sure the agent is loaded before the
|
|
1694
1694
|
# interactive learning IO starts
|
rasa/core/utils.py
CHANGED
|
@@ -171,8 +171,6 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
|
|
|
171
171
|
class AvailableEndpoints:
|
|
172
172
|
"""Collection of configured endpoints."""
|
|
173
173
|
|
|
174
|
-
_instance = None
|
|
175
|
-
|
|
176
174
|
@classmethod
|
|
177
175
|
def read_endpoints(cls, endpoint_file: Text) -> "AvailableEndpoints":
|
|
178
176
|
"""Read the different endpoints from a yaml file."""
|
|
@@ -219,14 +217,6 @@ class AvailableEndpoints:
|
|
|
219
217
|
self.event_broker = event_broker
|
|
220
218
|
self.vector_store = vector_store
|
|
221
219
|
|
|
222
|
-
@classmethod
|
|
223
|
-
def get_instance(cls, endpoint_file: Optional[Text] = None) -> "AvailableEndpoints":
|
|
224
|
-
"""Get the singleton instance of AvailableEndpoints."""
|
|
225
|
-
# Ensure that the instance is initialized only once.
|
|
226
|
-
if cls._instance is None:
|
|
227
|
-
cls._instance = cls.read_endpoints(endpoint_file)
|
|
228
|
-
return cls._instance
|
|
229
|
-
|
|
230
220
|
|
|
231
221
|
def read_endpoints_from_path(
|
|
232
222
|
endpoints_path: Optional[Union[Path, Text]] = None,
|
|
@@ -244,7 +234,7 @@ def read_endpoints_from_path(
|
|
|
244
234
|
endpoints_config_path = cli_utils.get_validated_path(
|
|
245
235
|
endpoints_path, "endpoints", DEFAULT_ENDPOINTS_PATH, True
|
|
246
236
|
)
|
|
247
|
-
return AvailableEndpoints.get_instance(endpoints_config_path)
|
|
237
|
+
return AvailableEndpoints.read_endpoints(endpoints_config_path)
|
|
248
238
|
|
|
249
239
|
|
|
250
240
|
def _lock_store_is_multi_worker_compatible(
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import importlib
|
|
4
|
+
import os
|
|
4
5
|
from typing import Any, Dict, List, Optional
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
@@ -21,6 +22,7 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
|
21
22
|
from rasa.engine.storage.resource import Resource
|
|
22
23
|
from rasa.engine.storage.storage import ModelStorage
|
|
23
24
|
from rasa.shared.constants import (
|
|
25
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
24
26
|
ROUTE_TO_CALM_SLOT,
|
|
25
27
|
PROMPT_CONFIG_KEY,
|
|
26
28
|
PROVIDER_CONFIG_KEY,
|
|
@@ -36,6 +38,7 @@ from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
|
36
38
|
from rasa.shared.utils.llm import (
|
|
37
39
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
38
40
|
get_prompt_template,
|
|
41
|
+
llm_api_health_check,
|
|
39
42
|
llm_factory,
|
|
40
43
|
try_instantiate_llm_client,
|
|
41
44
|
)
|
|
@@ -130,12 +133,16 @@ class LLMBasedRouter(GraphComponent):
|
|
|
130
133
|
def train(self, training_data: TrainingData) -> Resource:
|
|
131
134
|
"""Train the intent classifier on a data set."""
|
|
132
135
|
# Validate llm configuration
|
|
133
|
-
try_instantiate_llm_client(
|
|
136
|
+
llm_client = try_instantiate_llm_client(
|
|
134
137
|
self.config.get(LLM_CONFIG_KEY),
|
|
135
138
|
DEFAULT_LLM_CONFIG,
|
|
136
139
|
"llm_based_router.train",
|
|
137
|
-
|
|
140
|
+
LLMBasedRouter.__name__,
|
|
138
141
|
)
|
|
142
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
143
|
+
llm_api_health_check(
|
|
144
|
+
llm_client, "llm_based_router.train", LLMBasedRouter.__name__
|
|
145
|
+
)
|
|
139
146
|
|
|
140
147
|
self.persist()
|
|
141
148
|
return self._resource
|
|
@@ -161,14 +168,7 @@ class LLMBasedRouter(GraphComponent):
|
|
|
161
168
|
"llm_based_router.load.failed", error=e, resource=resource.name
|
|
162
169
|
)
|
|
163
170
|
|
|
164
|
-
|
|
165
|
-
try_instantiate_llm_client(
|
|
166
|
-
router.config.get(LLM_CONFIG_KEY),
|
|
167
|
-
DEFAULT_LLM_CONFIG,
|
|
168
|
-
"llm_based_router.load",
|
|
169
|
-
LLMBasedRouter.__name__,
|
|
170
|
-
)
|
|
171
|
-
return router
|
|
171
|
+
return cls(config, model_storage, resource, prompt_template=prompt_template)
|
|
172
172
|
|
|
173
173
|
@classmethod
|
|
174
174
|
def create(
|
|
@@ -32,6 +32,7 @@ from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowC
|
|
|
32
32
|
from rasa.dialogue_understanding.commands.session_start_command import (
|
|
33
33
|
SessionStartCommand,
|
|
34
34
|
)
|
|
35
|
+
from rasa.dialogue_understanding.commands.session_end_command import SessionEndCommand
|
|
35
36
|
|
|
36
37
|
__all__ = [
|
|
37
38
|
"Command",
|
|
@@ -51,5 +52,6 @@ __all__ = [
|
|
|
51
52
|
"NoopCommand",
|
|
52
53
|
"ChangeFlowCommand",
|
|
53
54
|
"SessionStartCommand",
|
|
55
|
+
"SessionEndCommand",
|
|
54
56
|
"RestartCommand",
|
|
55
57
|
]
|