rasa-pro 3.11.0rc3__py3-none-any.whl → 3.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +9 -3
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/utils.py +1 -1
- rasa/core/channels/development_inspector.py +4 -1
- rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
- rasa/core/channels/voice_stream/asr/azure.py +11 -2
- rasa/core/channels/voice_stream/asr/deepgram.py +4 -3
- rasa/core/channels/voice_stream/tts/azure.py +3 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
- rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/persistor.py +93 -49
- rasa/core/policies/flows/flow_executor.py +18 -8
- rasa/core/processor.py +7 -5
- rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
- rasa/e2e_test/assertions.py +133 -16
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_runner.py +2 -2
- rasa/engine/loader.py +12 -0
- rasa/engine/validation.py +291 -79
- rasa/model_manager/config.py +8 -0
- rasa/model_manager/model_api.py +166 -61
- rasa/model_manager/runner_service.py +31 -26
- rasa/model_manager/trainer_service.py +14 -23
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +3 -5
- rasa/model_training.py +3 -1
- rasa/shared/constants.py +22 -0
- rasa/shared/core/domain.py +8 -5
- rasa/shared/core/flows/yaml_flows_io.py +13 -4
- rasa/shared/importers/importer.py +19 -2
- rasa/shared/importers/rasa.py +5 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/utils/common.py +29 -2
- rasa/shared/utils/health_check/health_check.py +26 -24
- rasa/shared/utils/yaml.py +116 -31
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +119 -57
- rasa/validator.py +40 -4
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +48 -46
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc3.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
rasa/__main__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
from typing import Optional, List
|
|
2
3
|
import structlog
|
|
3
4
|
import os
|
|
4
5
|
import platform
|
|
@@ -97,12 +98,17 @@ def print_version() -> None:
|
|
|
97
98
|
print(f"License Expires : {get_license_expiration_date()}")
|
|
98
99
|
|
|
99
100
|
|
|
100
|
-
def main() -> None:
|
|
101
|
-
"""Run as standalone python application.
|
|
101
|
+
def main(raw_arguments: Optional[List[str]] = None) -> None:
|
|
102
|
+
"""Run as standalone python application.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
raw_arguments: Arguments to parse. If not provided,
|
|
106
|
+
arguments will be taken from the command line.
|
|
107
|
+
"""
|
|
102
108
|
warn_if_rasa_plus_package_installed()
|
|
103
109
|
parse_last_positional_argument_as_model_path()
|
|
104
110
|
arg_parser = create_argument_parser()
|
|
105
|
-
cmdline_arguments = arg_parser.parse_args()
|
|
111
|
+
cmdline_arguments = arg_parser.parse_args(raw_arguments)
|
|
106
112
|
|
|
107
113
|
log_level = getattr(cmdline_arguments, "loglevel", None)
|
|
108
114
|
logging_config_file = getattr(cmdline_arguments, "logging_config_file", None)
|
rasa/cli/studio/upload.py
CHANGED
|
@@ -32,25 +32,10 @@ def add_subparser(
|
|
|
32
32
|
set_upload_arguments(upload_parser)
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def add_flows_param(
|
|
36
|
-
parser: argparse.ArgumentParser,
|
|
37
|
-
help_text: str = "Name of flows file to upload to Rasa Studio. Works with --calm",
|
|
38
|
-
default_path: str = "flows.yml",
|
|
39
|
-
) -> None:
|
|
40
|
-
parser.add_argument(
|
|
41
|
-
"--flows",
|
|
42
|
-
default=default_path,
|
|
43
|
-
nargs="+",
|
|
44
|
-
type=str,
|
|
45
|
-
help=help_text,
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
|
|
49
35
|
def set_upload_arguments(parser: argparse.ArgumentParser) -> None:
|
|
50
36
|
"""Add arguments for running `rasa upload`."""
|
|
51
37
|
add_data_param(parser, data_type="training")
|
|
52
38
|
add_domain_param(parser)
|
|
53
|
-
add_flows_param(parser)
|
|
54
39
|
add_config_param(parser)
|
|
55
40
|
add_endpoint_param(parser, help_text="Path to the endpoints file.")
|
|
56
41
|
|
rasa/cli/utils.py
CHANGED
|
@@ -305,7 +305,7 @@ def _validate_domain(validator: "Validator") -> bool:
|
|
|
305
305
|
valid_forms_in_stories_rules = validator.verify_forms_in_stories_rules()
|
|
306
306
|
valid_form_slots = validator.verify_form_slots()
|
|
307
307
|
valid_slot_mappings = validator.verify_slot_mappings()
|
|
308
|
-
valid_responses = validator.
|
|
308
|
+
valid_responses = validator.check_for_no_empty_parenthesis_in_responses()
|
|
309
309
|
valid_buttons = validator.validate_button_payloads()
|
|
310
310
|
return (
|
|
311
311
|
valid_domain_validity
|
|
@@ -128,9 +128,12 @@ class DevelopmentInspectProxy(InputChannel):
|
|
|
128
128
|
|
|
129
129
|
inspect_path = app.url_for(f"{app.name}.{underlying_webhook.name}.inspect")
|
|
130
130
|
|
|
131
|
+
# replace 0.0.0.0 with localhost
|
|
132
|
+
serve_location = app.serve_location.replace("0.0.0.0", "localhost")
|
|
133
|
+
|
|
131
134
|
print_info(
|
|
132
135
|
f"Development inspector for channel {self.name()} is running. To "
|
|
133
|
-
f"inspect conversations, visit {
|
|
136
|
+
f"inspect conversations, visit {serve_location}{inspect_path}"
|
|
134
137
|
)
|
|
135
138
|
|
|
136
139
|
underlying_webhook.add_websocket_route(
|
|
@@ -1,5 +1,14 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import (
|
|
3
|
+
Dict,
|
|
4
|
+
AsyncIterator,
|
|
5
|
+
Any,
|
|
6
|
+
Generic,
|
|
7
|
+
Optional,
|
|
8
|
+
Tuple,
|
|
9
|
+
Type,
|
|
10
|
+
TypeVar,
|
|
11
|
+
)
|
|
3
12
|
|
|
4
13
|
from websockets.legacy.client import WebSocketClientProtocol
|
|
5
14
|
|
|
@@ -7,6 +16,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import ASREvent
|
|
|
7
16
|
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
8
17
|
from rasa.core.channels.voice_stream.util import MergeableConfig
|
|
9
18
|
from rasa.shared.exceptions import ConnectionException
|
|
19
|
+
from rasa.shared.utils.common import validate_environment
|
|
10
20
|
|
|
11
21
|
T = TypeVar("T", bound="ASREngineConfig")
|
|
12
22
|
E = TypeVar("E", bound="ASREngine")
|
|
@@ -18,9 +28,17 @@ class ASREngineConfig(MergeableConfig):
|
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
class ASREngine(Generic[T]):
|
|
31
|
+
required_env_vars: Tuple[str, ...] = ()
|
|
32
|
+
required_packages: Tuple[str, ...] = ()
|
|
33
|
+
|
|
21
34
|
def __init__(self, config: Optional[T] = None):
|
|
22
35
|
self.config = self.get_default_config().merge(config)
|
|
23
36
|
self.asr_socket: Optional[WebSocketClientProtocol] = None
|
|
37
|
+
validate_environment(
|
|
38
|
+
self.required_env_vars,
|
|
39
|
+
self.required_packages,
|
|
40
|
+
f"ASR Engine {self.__class__.__name__}",
|
|
41
|
+
)
|
|
24
42
|
|
|
25
43
|
async def connect(self) -> None:
|
|
26
44
|
self.asr_socket = await self.open_websocket_connection()
|
|
@@ -10,6 +10,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
|
10
10
|
UserIsSpeaking,
|
|
11
11
|
)
|
|
12
12
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
13
|
+
from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
|
|
13
14
|
from rasa.shared.exceptions import ConnectionException
|
|
14
15
|
|
|
15
16
|
|
|
@@ -20,10 +21,14 @@ class AzureASRConfig(ASREngineConfig):
|
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class AzureASR(ASREngine[AzureASRConfig]):
|
|
24
|
+
required_env_vars = (AZURE_SPEECH_API_KEY_ENV_VAR,)
|
|
25
|
+
required_packages = ("azure.cognitiveservices.speech",)
|
|
26
|
+
|
|
23
27
|
def __init__(self, config: Optional[AzureASRConfig] = None):
|
|
28
|
+
super().__init__(config)
|
|
29
|
+
|
|
24
30
|
import azure.cognitiveservices.speech as speechsdk
|
|
25
31
|
|
|
26
|
-
super().__init__(config)
|
|
27
32
|
self.speech_recognizer: Optional[speechsdk.SpeechRecognizer] = None
|
|
28
33
|
self.stream: Optional[speechsdk.audio.PushAudioInputStream] = None
|
|
29
34
|
self.is_recognizing = False
|
|
@@ -31,6 +36,10 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
31
36
|
asyncio.Queue()
|
|
32
37
|
)
|
|
33
38
|
|
|
39
|
+
@staticmethod
|
|
40
|
+
def validate_environment() -> None:
|
|
41
|
+
"""Make sure all needed requirements for this component are met."""
|
|
42
|
+
|
|
34
43
|
def signal_user_is_speaking(self, event: Any) -> None:
|
|
35
44
|
"""Replace the azure event with a generic is speaking event."""
|
|
36
45
|
self.fill_queue(UserIsSpeaking())
|
|
@@ -43,7 +52,7 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
43
52
|
import azure.cognitiveservices.speech as speechsdk
|
|
44
53
|
|
|
45
54
|
speech_config = speechsdk.SpeechConfig(
|
|
46
|
-
subscription=os.environ[
|
|
55
|
+
subscription=os.environ[AZURE_SPEECH_API_KEY_ENV_VAR],
|
|
47
56
|
region=self.config.speech_region,
|
|
48
57
|
)
|
|
49
58
|
audio_format = speechsdk.audio.AudioStreamFormat(
|
|
@@ -13,8 +13,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
|
13
13
|
UserIsSpeaking,
|
|
14
14
|
)
|
|
15
15
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
16
|
-
|
|
17
|
-
DEEPGRAM_API_KEY = "DEEPGRAM_API_KEY"
|
|
16
|
+
from rasa.shared.constants import DEEPGRAM_API_KEY_ENV_VAR
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
@dataclass
|
|
@@ -28,13 +27,15 @@ class DeepgramASRConfig(ASREngineConfig):
|
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
class DeepgramASR(ASREngine[DeepgramASRConfig]):
|
|
30
|
+
required_env_vars = (DEEPGRAM_API_KEY_ENV_VAR,)
|
|
31
|
+
|
|
31
32
|
def __init__(self, config: Optional[DeepgramASRConfig] = None):
|
|
32
33
|
super().__init__(config)
|
|
33
34
|
self.accumulated_transcript = ""
|
|
34
35
|
|
|
35
36
|
async def open_websocket_connection(self) -> WebSocketClientProtocol:
|
|
36
37
|
"""Connect to the ASR system."""
|
|
37
|
-
deepgram_api_key = os.environ[
|
|
38
|
+
deepgram_api_key = os.environ[DEEPGRAM_API_KEY_ENV_VAR]
|
|
38
39
|
extra_headers = {"Authorization": f"Token {deepgram_api_key}"}
|
|
39
40
|
api_url = self._get_api_url()
|
|
40
41
|
query_params = self._get_query_params()
|
|
@@ -12,6 +12,7 @@ from rasa.core.channels.voice_stream.tts.tts_engine import (
|
|
|
12
12
|
TTSEngineConfig,
|
|
13
13
|
TTSError,
|
|
14
14
|
)
|
|
15
|
+
from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
|
|
15
16
|
from rasa.shared.exceptions import ConnectionException
|
|
16
17
|
|
|
17
18
|
|
|
@@ -25,6 +26,7 @@ class AzureTTSConfig(TTSEngineConfig):
|
|
|
25
26
|
|
|
26
27
|
class AzureTTS(TTSEngine[AzureTTSConfig]):
|
|
27
28
|
session: Optional[aiohttp.ClientSession] = None
|
|
29
|
+
required_env_vars = (AZURE_SPEECH_API_KEY_ENV_VAR,)
|
|
28
30
|
|
|
29
31
|
def __init__(self, config: Optional[AzureTTSConfig] = None):
|
|
30
32
|
super().__init__(config)
|
|
@@ -66,7 +68,7 @@ class AzureTTS(TTSEngine[AzureTTSConfig]):
|
|
|
66
68
|
|
|
67
69
|
@staticmethod
|
|
68
70
|
def get_request_headers() -> dict[str, str]:
|
|
69
|
-
azure_speech_api_key = os.environ[
|
|
71
|
+
azure_speech_api_key = os.environ[AZURE_SPEECH_API_KEY_ENV_VAR]
|
|
70
72
|
return {
|
|
71
73
|
"Ocp-Apim-Subscription-Key": azure_speech_api_key,
|
|
72
74
|
"Content-Type": "application/ssml+xml",
|
|
@@ -11,12 +11,11 @@ from rasa.core.channels.voice_stream.tts.tts_engine import (
|
|
|
11
11
|
|
|
12
12
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
13
13
|
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
14
|
+
from rasa.shared.constants import CARTESIA_API_KEY_ENV_VAR
|
|
14
15
|
from rasa.shared.exceptions import ConnectionException
|
|
15
16
|
|
|
16
17
|
structlogger = structlog.get_logger()
|
|
17
18
|
|
|
18
|
-
CARTESIA_API_KEY = "CARTESIA_API_KEY"
|
|
19
|
-
|
|
20
19
|
|
|
21
20
|
@dataclass
|
|
22
21
|
class CartesiaTTSConfig(TTSEngineConfig):
|
|
@@ -26,6 +25,7 @@ class CartesiaTTSConfig(TTSEngineConfig):
|
|
|
26
25
|
|
|
27
26
|
class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
|
|
28
27
|
session: Optional[aiohttp.ClientSession] = None
|
|
28
|
+
required_env_vars = (CARTESIA_API_KEY_ENV_VAR,)
|
|
29
29
|
|
|
30
30
|
def __init__(self, config: Optional[CartesiaTTSConfig] = None):
|
|
31
31
|
super().__init__(config)
|
|
@@ -62,7 +62,7 @@ class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
|
|
|
62
62
|
|
|
63
63
|
@staticmethod
|
|
64
64
|
def get_request_headers(config: CartesiaTTSConfig) -> dict[str, str]:
|
|
65
|
-
cartesia_api_key = os.environ[
|
|
65
|
+
cartesia_api_key = os.environ[CARTESIA_API_KEY_ENV_VAR]
|
|
66
66
|
return {
|
|
67
67
|
"Cartesia-Version": str(config.version),
|
|
68
68
|
"Content-Type": "application/json",
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from typing import AsyncIterator, Dict, Generic, Optional, Type, TypeVar
|
|
1
|
+
from typing import AsyncIterator, Dict, Generic, Optional, Tuple, Type, TypeVar
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
5
5
|
from rasa.core.channels.voice_stream.util import MergeableConfig
|
|
6
6
|
from rasa.shared.exceptions import RasaException
|
|
7
|
+
from rasa.shared.utils.common import validate_environment
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class TTSError(RasaException):
|
|
@@ -22,8 +23,16 @@ class TTSEngineConfig(MergeableConfig):
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class TTSEngine(Generic[T]):
|
|
26
|
+
required_env_vars: Tuple[str, ...] = ()
|
|
27
|
+
required_packages: Tuple[str, ...] = ()
|
|
28
|
+
|
|
25
29
|
def __init__(self, config: Optional[T] = None):
|
|
26
30
|
self.config = self.get_default_config().merge(config)
|
|
31
|
+
validate_environment(
|
|
32
|
+
self.required_env_vars,
|
|
33
|
+
self.required_packages,
|
|
34
|
+
f"TTS Engine {self.__class__.__name__}",
|
|
35
|
+
)
|
|
27
36
|
|
|
28
37
|
async def close_connection(self) -> None:
|
|
29
38
|
"""Cleanup the connection if necessary."""
|
|
@@ -62,6 +62,7 @@ class Qdrant_Store(InformationRetrieval):
|
|
|
62
62
|
embeddings=self.embeddings,
|
|
63
63
|
content_payload_key=params.get("content_payload_key", "text"),
|
|
64
64
|
metadata_payload_key=params.get("metadata_payload_key", "metadata"),
|
|
65
|
+
vector_name=params.get("vector_name", None),
|
|
65
66
|
)
|
|
66
67
|
|
|
67
68
|
async def search(
|
rasa/core/persistor.py
CHANGED
|
@@ -4,6 +4,7 @@ import abc
|
|
|
4
4
|
import os
|
|
5
5
|
import shutil
|
|
6
6
|
from enum import Enum
|
|
7
|
+
from pathlib import Path
|
|
7
8
|
from typing import TYPE_CHECKING, List, Optional, Text, Tuple, Union
|
|
8
9
|
|
|
9
10
|
import structlog
|
|
@@ -122,7 +123,8 @@ class Persistor(abc.ABC):
|
|
|
122
123
|
|
|
123
124
|
def persist(self, trained_model: str) -> None:
|
|
124
125
|
"""Uploads a trained model persisted in the `target_dir` to cloud storage."""
|
|
125
|
-
|
|
126
|
+
absolute_file_key = self._create_file_key(trained_model)
|
|
127
|
+
file_key = Path(absolute_file_key).name
|
|
126
128
|
self._persist_tar(file_key, trained_model)
|
|
127
129
|
|
|
128
130
|
def retrieve(self, model_name: Text, target_path: Text) -> Text:
|
|
@@ -141,7 +143,8 @@ class Persistor(abc.ABC):
|
|
|
141
143
|
# ensure backward compatibility
|
|
142
144
|
tar_name = self._tar_name(model_name)
|
|
143
145
|
tar_name = self._create_file_key(tar_name)
|
|
144
|
-
|
|
146
|
+
target_filename = os.path.basename(tar_name)
|
|
147
|
+
self._retrieve_tar(target_filename)
|
|
145
148
|
self._copy(os.path.basename(tar_name), target_path)
|
|
146
149
|
|
|
147
150
|
if os.path.isdir(target_path):
|
|
@@ -149,6 +152,36 @@ class Persistor(abc.ABC):
|
|
|
149
152
|
|
|
150
153
|
return target_path
|
|
151
154
|
|
|
155
|
+
def size_of_persisted_model(self, model_name: Text) -> int:
|
|
156
|
+
"""Returns the size of the model that has been persisted to cloud storage.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
model_name: The name of the model to retrieve.
|
|
160
|
+
"""
|
|
161
|
+
tar_name = model_name
|
|
162
|
+
if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
|
|
163
|
+
# ensure backward compatibility
|
|
164
|
+
tar_name = self._tar_name(model_name)
|
|
165
|
+
tar_name = self._create_file_key(tar_name)
|
|
166
|
+
target_filename = os.path.basename(tar_name)
|
|
167
|
+
return self._retrieve_tar_size(target_filename)
|
|
168
|
+
|
|
169
|
+
def _retrieve_tar_size(self, filename: Text) -> int:
|
|
170
|
+
"""Returns the size of the model that has been persisted to cloud storage."""
|
|
171
|
+
structlogger.warning(
|
|
172
|
+
"persistor.retrieve_tar_size.not_implemented",
|
|
173
|
+
filename=filename,
|
|
174
|
+
event_info=(
|
|
175
|
+
"This method should be implemented in the persistor. "
|
|
176
|
+
"The default implementation will download the model "
|
|
177
|
+
"to calculate the size. Most persistors should override "
|
|
178
|
+
"this method to avoid downloading the model and get the "
|
|
179
|
+
"size directly from the cloud storage."
|
|
180
|
+
),
|
|
181
|
+
)
|
|
182
|
+
self._retrieve_tar(filename)
|
|
183
|
+
return os.path.getsize(os.path.basename(filename))
|
|
184
|
+
|
|
152
185
|
@abc.abstractmethod
|
|
153
186
|
def _retrieve_tar(self, filename: Text) -> None:
|
|
154
187
|
"""Downloads a model previously persisted to cloud storage."""
|
|
@@ -197,10 +230,7 @@ class Persistor(abc.ABC):
|
|
|
197
230
|
f"{REMOTE_STORAGE_PATH_ENV} is deprecated and will be "
|
|
198
231
|
"removed in future versions. "
|
|
199
232
|
"Please use the -m path/to/model.tar.gz option to "
|
|
200
|
-
"specify the model path when loading a model."
|
|
201
|
-
"Or use --output and --fixed-model-name to specify the "
|
|
202
|
-
"output directory and the model name when saving a "
|
|
203
|
-
"trained model to remote storage.",
|
|
233
|
+
"specify the model path when loading a model.",
|
|
204
234
|
)
|
|
205
235
|
|
|
206
236
|
file_key = os.path.basename(model_path)
|
|
@@ -272,50 +302,48 @@ class AWSPersistor(Persistor):
|
|
|
272
302
|
with open(tar_path, "rb") as f:
|
|
273
303
|
self.s3.Object(self.bucket_name, file_key).put(Body=f)
|
|
274
304
|
|
|
275
|
-
def
|
|
305
|
+
def _retrieve_tar_size(self, model_path: Text) -> int:
|
|
306
|
+
"""Returns the size of the model that has been persisted to s3."""
|
|
307
|
+
try:
|
|
308
|
+
obj = self.s3.Object(self.bucket_name, model_path)
|
|
309
|
+
return obj.content_length
|
|
310
|
+
except Exception:
|
|
311
|
+
raise ModelNotFound()
|
|
312
|
+
|
|
313
|
+
def _retrieve_tar(self, target_filename: str) -> None:
|
|
276
314
|
"""Downloads a model that has previously been persisted to s3."""
|
|
277
315
|
from botocore import exceptions
|
|
278
316
|
|
|
279
|
-
target_filename = os.path.basename(model_path)
|
|
280
|
-
bucket_objects = list(self.bucket.objects.all())
|
|
281
|
-
|
|
282
|
-
model_found = False
|
|
283
|
-
|
|
284
317
|
log = (
|
|
285
318
|
f"Model '{target_filename}' not found in the specified bucket "
|
|
286
319
|
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
287
320
|
f"in the bucket."
|
|
288
321
|
)
|
|
289
322
|
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
323
|
+
try:
|
|
324
|
+
with open(target_filename, "wb") as f:
|
|
325
|
+
self.bucket.download_fileobj(target_filename, f)
|
|
326
|
+
|
|
293
327
|
structlogger.debug(
|
|
294
|
-
"aws_persistor.retrieve_tar.object_found", object_key=
|
|
328
|
+
"aws_persistor.retrieve_tar.object_found", object_key=target_filename
|
|
295
329
|
)
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
bucket_name=self.bucket_name,
|
|
307
|
-
target_filename=target_filename,
|
|
308
|
-
event_info=log,
|
|
309
|
-
)
|
|
310
|
-
raise ModelNotFound() from exc
|
|
311
|
-
if not model_found:
|
|
330
|
+
except exceptions.ClientError as exc:
|
|
331
|
+
if self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
|
|
332
|
+
structlogger.error(
|
|
333
|
+
"aws_persistor.retrieve_tar.model_not_found",
|
|
334
|
+
bucket_name=self.bucket_name,
|
|
335
|
+
target_filename=target_filename,
|
|
336
|
+
event_info=log,
|
|
337
|
+
)
|
|
338
|
+
raise ModelNotFound() from exc
|
|
339
|
+
except exceptions.BotoCoreError as exc:
|
|
312
340
|
structlogger.error(
|
|
313
|
-
"aws_persistor.retrieve_tar.
|
|
341
|
+
"aws_persistor.retrieve_tar.model_download_error",
|
|
314
342
|
bucket_name=self.bucket_name,
|
|
315
343
|
target_filename=target_filename,
|
|
316
344
|
event_info=log,
|
|
317
345
|
)
|
|
318
|
-
raise ModelNotFound()
|
|
346
|
+
raise ModelNotFound() from exc
|
|
319
347
|
|
|
320
348
|
|
|
321
349
|
class GCSPersistor(Persistor):
|
|
@@ -397,6 +425,14 @@ class GCSPersistor(Persistor):
|
|
|
397
425
|
blob = self.bucket.blob(file_key)
|
|
398
426
|
blob.upload_from_filename(tar_path)
|
|
399
427
|
|
|
428
|
+
def _retrieve_tar_size(self, target_filename: Text) -> int:
|
|
429
|
+
"""Returns the size of the model that has been persisted to GCS."""
|
|
430
|
+
try:
|
|
431
|
+
blob = self.bucket.blob(target_filename)
|
|
432
|
+
return blob.size
|
|
433
|
+
except Exception:
|
|
434
|
+
raise ModelNotFound()
|
|
435
|
+
|
|
400
436
|
def _retrieve_tar(self, target_filename: Text) -> None:
|
|
401
437
|
"""Downloads a model that has previously been persisted to GCS."""
|
|
402
438
|
from google.api_core import exceptions
|
|
@@ -404,6 +440,10 @@ class GCSPersistor(Persistor):
|
|
|
404
440
|
blob = self.bucket.blob(target_filename)
|
|
405
441
|
try:
|
|
406
442
|
blob.download_to_filename(target_filename)
|
|
443
|
+
|
|
444
|
+
structlogger.debug(
|
|
445
|
+
"gcs_persistor.retrieve_tar.object_found", object_key=target_filename
|
|
446
|
+
)
|
|
407
447
|
except exceptions.NotFound as exc:
|
|
408
448
|
log = (
|
|
409
449
|
f"Model '{target_filename}' not found in the specified bucket "
|
|
@@ -460,24 +500,28 @@ class AzurePersistor(Persistor):
|
|
|
460
500
|
with open(tar_path, "rb") as data:
|
|
461
501
|
self._container_client().upload_blob(name=file_key, data=data)
|
|
462
502
|
|
|
463
|
-
def
|
|
464
|
-
"""
|
|
503
|
+
def _retrieve_tar_size(self, target_filename: Text) -> int:
|
|
504
|
+
"""Returns the size of the model that has been persisted to Azure."""
|
|
465
505
|
try:
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
506
|
+
blob_client = self._container_client().get_blob_client(target_filename)
|
|
507
|
+
properties = blob_client.get_blob_properties()
|
|
508
|
+
return properties.size
|
|
509
|
+
except Exception:
|
|
510
|
+
raise ModelNotFound()
|
|
471
511
|
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
512
|
+
def _retrieve_tar(self, target_filename: Text) -> None:
|
|
513
|
+
"""Downloads a model that has previously been persisted to Azure."""
|
|
514
|
+
from azure.core.exceptions import AzureError
|
|
475
515
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
516
|
+
try:
|
|
517
|
+
with open(target_filename, "wb") as model_file:
|
|
518
|
+
blob_client = self._container_client().get_blob_client(target_filename)
|
|
519
|
+
download_stream = blob_client.download_blob()
|
|
520
|
+
model_file.write(download_stream.readall())
|
|
521
|
+
structlogger.debug(
|
|
522
|
+
"azure_persistor.retrieve_tar.blob_found", blob_name=target_filename
|
|
523
|
+
)
|
|
524
|
+
except AzureError as exc:
|
|
481
525
|
log = (
|
|
482
526
|
f"An exception occurred while trying to download "
|
|
483
527
|
f"the model '{target_filename}' in the specified container "
|
|
@@ -487,7 +487,8 @@ def validate_collect_step(
|
|
|
487
487
|
step: CollectInformationFlowStep,
|
|
488
488
|
stack: DialogueStack,
|
|
489
489
|
available_actions: List[str],
|
|
490
|
-
slots: Dict[
|
|
490
|
+
slots: Dict[str, Slot],
|
|
491
|
+
flow_name: str,
|
|
491
492
|
) -> bool:
|
|
492
493
|
"""Validate that a collect step can be executed.
|
|
493
494
|
|
|
@@ -510,12 +511,12 @@ def validate_collect_step(
|
|
|
510
511
|
slot_name=step.collect,
|
|
511
512
|
)
|
|
512
513
|
|
|
513
|
-
cancel_flow_and_push_internal_error(stack)
|
|
514
|
+
cancel_flow_and_push_internal_error(stack, flow_name)
|
|
514
515
|
|
|
515
516
|
return False
|
|
516
517
|
|
|
517
518
|
|
|
518
|
-
def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
|
|
519
|
+
def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) -> None:
|
|
519
520
|
"""Cancel the top user flow and push the internal error pattern."""
|
|
520
521
|
top_frame = stack.top()
|
|
521
522
|
|
|
@@ -527,7 +528,7 @@ def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
|
|
|
527
528
|
canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
|
|
528
529
|
stack.push(
|
|
529
530
|
CancelPatternFlowStackFrame(
|
|
530
|
-
canceled_name=
|
|
531
|
+
canceled_name=flow_name,
|
|
531
532
|
canceled_frames=canceled_frames,
|
|
532
533
|
)
|
|
533
534
|
)
|
|
@@ -539,6 +540,7 @@ def validate_custom_slot_mappings(
|
|
|
539
540
|
stack: DialogueStack,
|
|
540
541
|
tracker: DialogueStateTracker,
|
|
541
542
|
available_actions: List[str],
|
|
543
|
+
flow_name: str,
|
|
542
544
|
) -> bool:
|
|
543
545
|
"""Validate a slot with custom mappings.
|
|
544
546
|
|
|
@@ -559,7 +561,7 @@ def validate_custom_slot_mappings(
|
|
|
559
561
|
action=step.collect_action,
|
|
560
562
|
collect=step.collect,
|
|
561
563
|
)
|
|
562
|
-
cancel_flow_and_push_internal_error(stack)
|
|
564
|
+
cancel_flow_and_push_internal_error(stack, flow_name)
|
|
563
565
|
return False
|
|
564
566
|
|
|
565
567
|
return True
|
|
@@ -599,7 +601,12 @@ def run_step(
|
|
|
599
601
|
|
|
600
602
|
if isinstance(step, CollectInformationFlowStep):
|
|
601
603
|
return _run_collect_information_step(
|
|
602
|
-
available_actions,
|
|
604
|
+
available_actions,
|
|
605
|
+
initial_events,
|
|
606
|
+
stack,
|
|
607
|
+
step,
|
|
608
|
+
tracker,
|
|
609
|
+
flow.readable_name(),
|
|
603
610
|
)
|
|
604
611
|
|
|
605
612
|
elif isinstance(step, ActionFlowStep):
|
|
@@ -719,15 +726,18 @@ def _run_collect_information_step(
|
|
|
719
726
|
stack: DialogueStack,
|
|
720
727
|
step: CollectInformationFlowStep,
|
|
721
728
|
tracker: DialogueStateTracker,
|
|
729
|
+
flow_name: str,
|
|
722
730
|
) -> FlowStepResult:
|
|
723
|
-
is_step_valid = validate_collect_step(
|
|
731
|
+
is_step_valid = validate_collect_step(
|
|
732
|
+
step, stack, available_actions, tracker.slots, flow_name
|
|
733
|
+
)
|
|
724
734
|
|
|
725
735
|
if not is_step_valid:
|
|
726
736
|
# if we return any other FlowStepResult, the assistant will stay silent
|
|
727
737
|
# instead of triggering the internal error pattern
|
|
728
738
|
return ContinueFlowWithNextStep(events=initial_events)
|
|
729
739
|
is_mapping_valid = validate_custom_slot_mappings(
|
|
730
|
-
step, stack, tracker, available_actions
|
|
740
|
+
step, stack, tracker, available_actions, flow_name
|
|
731
741
|
)
|
|
732
742
|
|
|
733
743
|
if not is_mapping_valid:
|
rasa/core/processor.py
CHANGED
|
@@ -1279,11 +1279,13 @@ class MessageProcessor:
|
|
|
1279
1279
|
tracker.update(events[0])
|
|
1280
1280
|
return self.should_predict_another_action(action.name())
|
|
1281
1281
|
except Exception:
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
"
|
|
1285
|
-
"
|
|
1286
|
-
"
|
|
1282
|
+
structlogger.exception(
|
|
1283
|
+
"rasa.core.processor.run_action.exception",
|
|
1284
|
+
event_info=f"Encountered an exception while "
|
|
1285
|
+
f"running action '{action.name()}'."
|
|
1286
|
+
f"Bot will continue, but the actions events are lost. "
|
|
1287
|
+
f"Please check the logs of your action server for "
|
|
1288
|
+
f"more information.",
|
|
1287
1289
|
)
|
|
1288
1290
|
events = []
|
|
1289
1291
|
|
|
@@ -35,6 +35,7 @@ class AggregateTestStatsCalculator:
|
|
|
35
35
|
self.test_cases = test_cases
|
|
36
36
|
|
|
37
37
|
self.failed_assertion_set: Set["Assertion"] = set()
|
|
38
|
+
self.failed_test_cases_without_assertion_failure: Set[str] = set()
|
|
38
39
|
self.passed_count_mapping = {
|
|
39
40
|
subclass_type: 0
|
|
40
41
|
for subclass_type in _get_all_assertion_subclasses().keys()
|
|
@@ -89,8 +90,14 @@ class AggregateTestStatsCalculator:
|
|
|
89
90
|
passed_test_case_names = [
|
|
90
91
|
passed.test_case.name for passed in self.passed_results
|
|
91
92
|
]
|
|
93
|
+
# We filter out test cases that failed without an assertion failure
|
|
94
|
+
filtered_test_cases = [
|
|
95
|
+
test_case
|
|
96
|
+
for test_case in self.test_cases
|
|
97
|
+
if test_case.name not in self.failed_test_cases_without_assertion_failure
|
|
98
|
+
]
|
|
92
99
|
|
|
93
|
-
for test_case in
|
|
100
|
+
for test_case in filtered_test_cases:
|
|
94
101
|
if test_case.name in passed_test_case_names:
|
|
95
102
|
for step in test_case.steps:
|
|
96
103
|
if step.assertions is None:
|
|
@@ -118,6 +125,9 @@ class AggregateTestStatsCalculator:
|
|
|
118
125
|
"no_assertion_failure_in_failed_result",
|
|
119
126
|
test_case=failed.test_case.name,
|
|
120
127
|
)
|
|
128
|
+
self.failed_test_cases_without_assertion_failure.add(
|
|
129
|
+
failed.test_case.name
|
|
130
|
+
)
|
|
121
131
|
continue
|
|
122
132
|
|
|
123
133
|
self.failed_assertion_set.add(failed.assertion_failure.assertion)
|