rasa-pro 3.10.16__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. More details were available via a link on the original registry diff page.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +11 -14
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -17
- rasa/core/channels/inspector/index.html +17 -16
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
- rasa/engine/graph.py +3 -10
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +354 -44
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -16
- rasa/version.py +1 -1
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +16 -14
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +236 -185
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
rasa/{nlu → core}/persistor.py
RENAMED
|
@@ -4,16 +4,19 @@ import abc
|
|
|
4
4
|
import os
|
|
5
5
|
import shutil
|
|
6
6
|
from enum import Enum
|
|
7
|
+
from pathlib import Path
|
|
7
8
|
from typing import TYPE_CHECKING, List, Optional, Text, Tuple, Union
|
|
8
9
|
|
|
9
10
|
import structlog
|
|
10
11
|
|
|
12
|
+
from rasa.exceptions import ModelNotFound
|
|
11
13
|
import rasa.shared.utils.common
|
|
12
14
|
import rasa.utils.common
|
|
13
15
|
from rasa.constants import (
|
|
14
16
|
HTTP_STATUS_FORBIDDEN,
|
|
15
17
|
HTTP_STATUS_NOT_FOUND,
|
|
16
18
|
MODEL_ARCHIVE_EXTENSION,
|
|
19
|
+
DEFAULT_BUCKET_NAME,
|
|
17
20
|
)
|
|
18
21
|
from rasa.env import (
|
|
19
22
|
AWS_ENDPOINT_URL_ENV,
|
|
@@ -28,6 +31,7 @@ from rasa.shared.utils.io import raise_warning
|
|
|
28
31
|
|
|
29
32
|
if TYPE_CHECKING:
|
|
30
33
|
from azure.storage.blob import ContainerClient
|
|
34
|
+
from botocore.exceptions import ClientError
|
|
31
35
|
|
|
32
36
|
structlogger = structlog.get_logger()
|
|
33
37
|
|
|
@@ -82,16 +86,19 @@ def get_persistor(storage: StorageType) -> Optional[Persistor]:
|
|
|
82
86
|
Currently, `aws`, `gcs`, `azure` and providing module paths are supported remote
|
|
83
87
|
storages.
|
|
84
88
|
"""
|
|
85
|
-
if storage
|
|
89
|
+
storage = storage.value if isinstance(storage, RemoteStorageType) else storage
|
|
90
|
+
|
|
91
|
+
if storage == RemoteStorageType.AWS.value:
|
|
86
92
|
return AWSPersistor(
|
|
87
|
-
os.environ.get(BUCKET_NAME_ENV
|
|
93
|
+
os.environ.get(BUCKET_NAME_ENV, DEFAULT_BUCKET_NAME),
|
|
94
|
+
os.environ.get(AWS_ENDPOINT_URL_ENV),
|
|
88
95
|
)
|
|
89
|
-
if storage == RemoteStorageType.GCS:
|
|
90
|
-
return GCSPersistor(os.environ.get(BUCKET_NAME_ENV))
|
|
96
|
+
if storage == RemoteStorageType.GCS.value:
|
|
97
|
+
return GCSPersistor(os.environ.get(BUCKET_NAME_ENV, DEFAULT_BUCKET_NAME))
|
|
91
98
|
|
|
92
|
-
if storage == RemoteStorageType.AZURE:
|
|
99
|
+
if storage == RemoteStorageType.AZURE.value:
|
|
93
100
|
return AzurePersistor(
|
|
94
|
-
os.environ.get(AZURE_CONTAINER_ENV),
|
|
101
|
+
os.environ.get(AZURE_CONTAINER_ENV, DEFAULT_BUCKET_NAME),
|
|
95
102
|
os.environ.get(AZURE_ACCOUNT_NAME_ENV),
|
|
96
103
|
os.environ.get(AZURE_ACCOUNT_KEY_ENV),
|
|
97
104
|
)
|
|
@@ -116,44 +123,64 @@ class Persistor(abc.ABC):
|
|
|
116
123
|
|
|
117
124
|
def persist(self, trained_model: str) -> None:
|
|
118
125
|
"""Uploads a trained model persisted in the `target_dir` to cloud storage."""
|
|
119
|
-
|
|
126
|
+
absolute_file_key = self._create_file_key(trained_model)
|
|
127
|
+
file_key = Path(absolute_file_key).name
|
|
120
128
|
self._persist_tar(file_key, trained_model)
|
|
121
129
|
|
|
122
130
|
def retrieve(self, model_name: Text, target_path: Text) -> Text:
|
|
123
131
|
"""Downloads a model that has been persisted to cloud storage.
|
|
124
132
|
|
|
125
133
|
Downloaded model will be saved to the `target_path`.
|
|
126
|
-
If `target_path` is a directory, the model will be
|
|
127
|
-
If `target_path` is a file, the model will be
|
|
134
|
+
If `target_path` is a directory, the model will be saved to that directory.
|
|
135
|
+
If `target_path` is a file, the model will be saved to that file.
|
|
128
136
|
|
|
129
137
|
Args:
|
|
130
138
|
model_name: The name of the model to retrieve.
|
|
131
|
-
target_path: The path to which the model should be
|
|
139
|
+
target_path: The path to which the model should be saved.
|
|
132
140
|
"""
|
|
133
141
|
tar_name = model_name
|
|
134
142
|
if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
|
|
135
143
|
# ensure backward compatibility
|
|
136
144
|
tar_name = self._tar_name(model_name)
|
|
137
|
-
|
|
138
|
-
|
|
145
|
+
tar_name = self._create_file_key(tar_name)
|
|
146
|
+
target_filename = os.path.basename(tar_name)
|
|
147
|
+
self._retrieve_tar(target_filename)
|
|
148
|
+
self._copy(os.path.basename(tar_name), target_path)
|
|
139
149
|
|
|
140
|
-
target_tar_file_name = os.path.basename(tar_name)
|
|
141
150
|
if os.path.isdir(target_path):
|
|
142
|
-
|
|
151
|
+
return os.path.join(target_path, model_name)
|
|
143
152
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
f"'{target_path}'.",
|
|
149
|
-
)
|
|
150
|
-
self._copy(target_tar_file_name, target_path)
|
|
153
|
+
return target_path
|
|
154
|
+
|
|
155
|
+
def size_of_persisted_model(self, model_name: Text) -> int:
|
|
156
|
+
"""Returns the size of the model that has been persisted to cloud storage.
|
|
151
157
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
158
|
+
Args:
|
|
159
|
+
model_name: The name of the model to retrieve.
|
|
160
|
+
"""
|
|
161
|
+
tar_name = model_name
|
|
162
|
+
if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
|
|
163
|
+
# ensure backward compatibility
|
|
164
|
+
tar_name = self._tar_name(model_name)
|
|
165
|
+
tar_name = self._create_file_key(tar_name)
|
|
166
|
+
target_filename = os.path.basename(tar_name)
|
|
167
|
+
return self._retrieve_tar_size(target_filename)
|
|
168
|
+
|
|
169
|
+
def _retrieve_tar_size(self, filename: Text) -> int:
|
|
170
|
+
"""Returns the size of the model that has been persisted to cloud storage."""
|
|
171
|
+
structlogger.warning(
|
|
172
|
+
"persistor.retrieve_tar_size.not_implemented",
|
|
173
|
+
filename=filename,
|
|
174
|
+
event_info=(
|
|
175
|
+
"This method should be implemented in the persistor. "
|
|
176
|
+
"The default implementation will download the model "
|
|
177
|
+
"to calculate the size. Most persistors should override "
|
|
178
|
+
"this method to avoid downloading the model and get the "
|
|
179
|
+
"size directly from the cloud storage."
|
|
180
|
+
),
|
|
155
181
|
)
|
|
156
|
-
|
|
182
|
+
self._retrieve_tar(filename)
|
|
183
|
+
return os.path.getsize(os.path.basename(filename))
|
|
157
184
|
|
|
158
185
|
@abc.abstractmethod
|
|
159
186
|
def _retrieve_tar(self, filename: Text) -> None:
|
|
@@ -175,7 +202,7 @@ class Persistor(abc.ABC):
|
|
|
175
202
|
os.path.join(dirpath, base_name),
|
|
176
203
|
"gztar",
|
|
177
204
|
root_dir=model_directory,
|
|
178
|
-
base_dir="
|
|
205
|
+
base_dir="../nlu",
|
|
179
206
|
)
|
|
180
207
|
file_key = os.path.basename(tar_name)
|
|
181
208
|
return file_key, tar_name
|
|
@@ -191,7 +218,7 @@ class Persistor(abc.ABC):
|
|
|
191
218
|
|
|
192
219
|
@staticmethod
|
|
193
220
|
def _create_file_key(model_path: str) -> Text:
|
|
194
|
-
"""Appends remote storage folders when provided to upload or retrieve file"""
|
|
221
|
+
"""Appends remote storage folders when provided to upload or retrieve file."""
|
|
195
222
|
bucket_object_path = os.environ.get(REMOTE_STORAGE_PATH_ENV)
|
|
196
223
|
|
|
197
224
|
# To keep the backward compatibility, if REMOTE_STORAGE_PATH is not provided,
|
|
@@ -203,10 +230,7 @@ class Persistor(abc.ABC):
|
|
|
203
230
|
f"{REMOTE_STORAGE_PATH_ENV} is deprecated and will be "
|
|
204
231
|
"removed in future versions. "
|
|
205
232
|
"Please use the -m path/to/model.tar.gz option to "
|
|
206
|
-
"specify the model path when loading a model."
|
|
207
|
-
"Or use --output and --fixed-model-name to specify the "
|
|
208
|
-
"output directory and the model name when saving a "
|
|
209
|
-
"trained model to remote storage.",
|
|
233
|
+
"specify the model path when loading a model.",
|
|
210
234
|
)
|
|
211
235
|
|
|
212
236
|
file_key = os.path.basename(model_path)
|
|
@@ -239,14 +263,13 @@ class AWSPersistor(Persistor):
|
|
|
239
263
|
def _ensure_bucket_exists(
|
|
240
264
|
self, bucket_name: Text, region_name: Optional[Text] = None
|
|
241
265
|
) -> None:
|
|
242
|
-
import
|
|
266
|
+
from botocore import exceptions
|
|
243
267
|
|
|
244
268
|
# noinspection PyUnresolvedReferences
|
|
245
269
|
try:
|
|
246
270
|
self.s3.meta.client.head_bucket(Bucket=bucket_name)
|
|
247
|
-
except
|
|
248
|
-
|
|
249
|
-
if error_code == HTTP_STATUS_FORBIDDEN:
|
|
271
|
+
except exceptions.ClientError as exc:
|
|
272
|
+
if self._error_code(exc) == HTTP_STATUS_FORBIDDEN:
|
|
250
273
|
log = (
|
|
251
274
|
f"Access to the specified bucket '{bucket_name}' is forbidden. "
|
|
252
275
|
"Please make sure you have the necessary "
|
|
@@ -258,7 +281,7 @@ class AWSPersistor(Persistor):
|
|
|
258
281
|
event_info=log,
|
|
259
282
|
)
|
|
260
283
|
raise RasaException(log)
|
|
261
|
-
elif
|
|
284
|
+
elif self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
|
|
262
285
|
log = (
|
|
263
286
|
f"The specified bucket '{bucket_name}' does not exist. "
|
|
264
287
|
"Please make sure to create the bucket first."
|
|
@@ -270,21 +293,57 @@ class AWSPersistor(Persistor):
|
|
|
270
293
|
)
|
|
271
294
|
raise RasaException(log)
|
|
272
295
|
|
|
296
|
+
@staticmethod
|
|
297
|
+
def _error_code(e: "ClientError") -> int:
|
|
298
|
+
return int(e.response["Error"]["Code"])
|
|
299
|
+
|
|
273
300
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
274
301
|
"""Uploads a model persisted in the `target_dir` to s3."""
|
|
275
|
-
structlogger.debug(
|
|
276
|
-
"aws_persistor.persist_tar.uploading_model",
|
|
277
|
-
event_info=f"Uploading tar archive {file_key} to "
|
|
278
|
-
f"s3 bucket '{self.bucket_name}'.",
|
|
279
|
-
)
|
|
280
302
|
with open(tar_path, "rb") as f:
|
|
281
303
|
self.s3.Object(self.bucket_name, file_key).put(Body=f)
|
|
282
304
|
|
|
283
|
-
def
|
|
305
|
+
def _retrieve_tar_size(self, model_path: Text) -> int:
|
|
306
|
+
"""Returns the size of the model that has been persisted to s3."""
|
|
307
|
+
try:
|
|
308
|
+
obj = self.s3.Object(self.bucket_name, model_path)
|
|
309
|
+
return obj.content_length
|
|
310
|
+
except Exception:
|
|
311
|
+
raise ModelNotFound()
|
|
312
|
+
|
|
313
|
+
def _retrieve_tar(self, target_filename: str) -> None:
|
|
284
314
|
"""Downloads a model that has previously been persisted to s3."""
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
315
|
+
from botocore import exceptions
|
|
316
|
+
|
|
317
|
+
log = (
|
|
318
|
+
f"Model '{target_filename}' not found in the specified bucket "
|
|
319
|
+
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
320
|
+
f"in the bucket."
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
try:
|
|
324
|
+
with open(target_filename, "wb") as f:
|
|
325
|
+
self.bucket.download_fileobj(target_filename, f)
|
|
326
|
+
|
|
327
|
+
structlogger.debug(
|
|
328
|
+
"aws_persistor.retrieve_tar.object_found", object_key=target_filename
|
|
329
|
+
)
|
|
330
|
+
except exceptions.ClientError as exc:
|
|
331
|
+
if self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
|
|
332
|
+
structlogger.error(
|
|
333
|
+
"aws_persistor.retrieve_tar.model_not_found",
|
|
334
|
+
bucket_name=self.bucket_name,
|
|
335
|
+
target_filename=target_filename,
|
|
336
|
+
event_info=log,
|
|
337
|
+
)
|
|
338
|
+
raise ModelNotFound() from exc
|
|
339
|
+
except exceptions.BotoCoreError as exc:
|
|
340
|
+
structlogger.error(
|
|
341
|
+
"aws_persistor.retrieve_tar.model_download_error",
|
|
342
|
+
bucket_name=self.bucket_name,
|
|
343
|
+
target_filename=target_filename,
|
|
344
|
+
event_info=log,
|
|
345
|
+
)
|
|
346
|
+
raise ModelNotFound() from exc
|
|
288
347
|
|
|
289
348
|
|
|
290
349
|
class GCSPersistor(Persistor):
|
|
@@ -309,42 +368,95 @@ class GCSPersistor(Persistor):
|
|
|
309
368
|
|
|
310
369
|
def _ensure_bucket_exists(self, bucket_name: Text) -> None:
|
|
311
370
|
from google.cloud import exceptions
|
|
371
|
+
from google.auth import exceptions as auth_exceptions
|
|
312
372
|
|
|
313
373
|
try:
|
|
314
374
|
self.storage_client.get_bucket(bucket_name)
|
|
315
|
-
except
|
|
375
|
+
except auth_exceptions.GoogleAuthError as exc:
|
|
376
|
+
log = (
|
|
377
|
+
f"An error occurred while authenticating with Google Cloud "
|
|
378
|
+
f"Storage. Please make sure you have the necessary credentials "
|
|
379
|
+
f"to access the bucket '{bucket_name}'."
|
|
380
|
+
)
|
|
381
|
+
structlogger.error(
|
|
382
|
+
"gcp_persistor.ensure_bucket_exists.authentication_error",
|
|
383
|
+
bucket_name=bucket_name,
|
|
384
|
+
event_info=log,
|
|
385
|
+
)
|
|
386
|
+
raise RasaException(log) from exc
|
|
387
|
+
except exceptions.NotFound as exc:
|
|
316
388
|
log = (
|
|
317
|
-
f"The specified bucket '{bucket_name}'
|
|
318
|
-
"Please make sure to create the bucket first
|
|
389
|
+
f"The specified Google Cloud Storage bucket '{bucket_name}' "
|
|
390
|
+
f"does not exist. Please make sure to create the bucket first or "
|
|
391
|
+
f"provide an alternative valid bucket name."
|
|
319
392
|
)
|
|
320
393
|
structlogger.error(
|
|
321
394
|
"gcp_persistor.ensure_bucket_exists.bucket_not_found",
|
|
322
395
|
bucket_name=bucket_name,
|
|
323
396
|
event_info=log,
|
|
324
397
|
)
|
|
325
|
-
raise RasaException(log)
|
|
326
|
-
except exceptions.Forbidden:
|
|
398
|
+
raise RasaException(log) from exc
|
|
399
|
+
except exceptions.Forbidden as exc:
|
|
327
400
|
log = (
|
|
328
|
-
f"Access to the specified bucket '{bucket_name}'
|
|
329
|
-
"Please make sure you have the necessary "
|
|
330
|
-
"
|
|
401
|
+
f"Access to the specified Google Cloud storage bucket '{bucket_name}' "
|
|
402
|
+
f"is forbidden. Please make sure you have the necessary "
|
|
403
|
+
f"permissions to access the bucket. "
|
|
331
404
|
)
|
|
332
405
|
structlogger.error(
|
|
333
406
|
"gcp_persistor.ensure_bucket_exists.bucket_access_forbidden",
|
|
334
407
|
bucket_name=bucket_name,
|
|
335
408
|
event_info=log,
|
|
336
409
|
)
|
|
337
|
-
raise RasaException(log)
|
|
410
|
+
raise RasaException(log) from exc
|
|
411
|
+
except ValueError as exc:
|
|
412
|
+
# bucket_name is None
|
|
413
|
+
log = (
|
|
414
|
+
"The specified Google Cloud Storage bucket name is None. Please "
|
|
415
|
+
"make sure to provide a valid bucket name."
|
|
416
|
+
)
|
|
417
|
+
structlogger.error(
|
|
418
|
+
"gcp_persistor.ensure_bucket_exists.bucket_name_none",
|
|
419
|
+
event_info=log,
|
|
420
|
+
)
|
|
421
|
+
raise RasaException(log) from exc
|
|
338
422
|
|
|
339
423
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
340
424
|
"""Uploads a model persisted in the `target_dir` to GCS."""
|
|
341
425
|
blob = self.bucket.blob(file_key)
|
|
342
426
|
blob.upload_from_filename(tar_path)
|
|
343
427
|
|
|
428
|
+
def _retrieve_tar_size(self, target_filename: Text) -> int:
|
|
429
|
+
"""Returns the size of the model that has been persisted to GCS."""
|
|
430
|
+
try:
|
|
431
|
+
blob = self.bucket.blob(target_filename)
|
|
432
|
+
return blob.size
|
|
433
|
+
except Exception:
|
|
434
|
+
raise ModelNotFound()
|
|
435
|
+
|
|
344
436
|
def _retrieve_tar(self, target_filename: Text) -> None:
|
|
345
437
|
"""Downloads a model that has previously been persisted to GCS."""
|
|
438
|
+
from google.api_core import exceptions
|
|
439
|
+
|
|
346
440
|
blob = self.bucket.blob(target_filename)
|
|
347
|
-
|
|
441
|
+
try:
|
|
442
|
+
blob.download_to_filename(target_filename)
|
|
443
|
+
|
|
444
|
+
structlogger.debug(
|
|
445
|
+
"gcs_persistor.retrieve_tar.object_found", object_key=target_filename
|
|
446
|
+
)
|
|
447
|
+
except exceptions.NotFound as exc:
|
|
448
|
+
log = (
|
|
449
|
+
f"Model '{target_filename}' not found in the specified bucket "
|
|
450
|
+
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
451
|
+
f"in the bucket."
|
|
452
|
+
)
|
|
453
|
+
structlogger.error(
|
|
454
|
+
"gcp_persistor.retrieve_tar.model_not_found",
|
|
455
|
+
bucket_name=self.bucket_name,
|
|
456
|
+
target_filename=target_filename,
|
|
457
|
+
event_info=log,
|
|
458
|
+
)
|
|
459
|
+
raise ModelNotFound() from exc
|
|
348
460
|
|
|
349
461
|
|
|
350
462
|
class AzurePersistor(Persistor):
|
|
@@ -370,7 +482,8 @@ class AzurePersistor(Persistor):
|
|
|
370
482
|
else:
|
|
371
483
|
log = (
|
|
372
484
|
f"The specified container '{self.container_name}' does not exist."
|
|
373
|
-
"Please make sure to create the
|
|
485
|
+
"Please make sure to create the bucket first or "
|
|
486
|
+
f"provide an alternative valid bucket name."
|
|
374
487
|
)
|
|
375
488
|
structlogger.error(
|
|
376
489
|
"azure_persistor.ensure_container_exists.container_not_found",
|
|
@@ -385,19 +498,41 @@ class AzurePersistor(Persistor):
|
|
|
385
498
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
386
499
|
"""Uploads a model persisted in the `target_dir` to Azure."""
|
|
387
500
|
with open(tar_path, "rb") as data:
|
|
388
|
-
self._container_client().upload_blob(
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
501
|
+
self._container_client().upload_blob(name=file_key, data=data)
|
|
502
|
+
|
|
503
|
+
def _retrieve_tar_size(self, target_filename: Text) -> int:
|
|
504
|
+
"""Returns the size of the model that has been persisted to Azure."""
|
|
505
|
+
try:
|
|
506
|
+
blob_client = self._container_client().get_blob_client(target_filename)
|
|
507
|
+
properties = blob_client.get_blob_properties()
|
|
508
|
+
return properties.size
|
|
509
|
+
except Exception:
|
|
510
|
+
raise ModelNotFound()
|
|
396
511
|
|
|
397
512
|
def _retrieve_tar(self, target_filename: Text) -> None:
|
|
398
513
|
"""Downloads a model that has previously been persisted to Azure."""
|
|
399
|
-
|
|
514
|
+
from azure.core.exceptions import AzureError
|
|
400
515
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
516
|
+
try:
|
|
517
|
+
with open(target_filename, "wb") as model_file:
|
|
518
|
+
blob_client = self._container_client().get_blob_client(target_filename)
|
|
519
|
+
download_stream = blob_client.download_blob()
|
|
520
|
+
model_file.write(download_stream.readall())
|
|
521
|
+
structlogger.debug(
|
|
522
|
+
"azure_persistor.retrieve_tar.blob_found", blob_name=target_filename
|
|
523
|
+
)
|
|
524
|
+
except AzureError as exc:
|
|
525
|
+
log = (
|
|
526
|
+
f"An exception occurred while trying to download "
|
|
527
|
+
f"the model '{target_filename}' in the specified container "
|
|
528
|
+
f"'{self.container_name}'. Please make sure the model exists "
|
|
529
|
+
f"in the container."
|
|
530
|
+
)
|
|
531
|
+
structlogger.error(
|
|
532
|
+
"azure_persistor.retrieve_tar.model_download_error",
|
|
533
|
+
container_name=self.container_name,
|
|
534
|
+
target_filename=target_filename,
|
|
535
|
+
event_info=log,
|
|
536
|
+
exception=exc,
|
|
537
|
+
)
|
|
538
|
+
raise ModelNotFound() from exc
|