rasa-pro 3.11.0a3__py3-none-any.whl → 3.11.0a4.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +17 -396
- rasa/api.py +4 -0
- rasa/cli/arguments/train.py +14 -0
- rasa/cli/inspect.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/endpoints.yml +7 -2
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -2
- rasa/cli/train.py +3 -0
- rasa/constants.py +2 -0
- rasa/core/actions/action.py +75 -33
- rasa/core/actions/action_repeat_bot_messages.py +72 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/channels/socketio.py +5 -1
- rasa/core/channels/voice_ready/utils.py +6 -5
- rasa/core/channels/voice_stream/browser_audio.py +1 -1
- rasa/core/channels/voice_stream/twilio_media_streams.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +19 -2
- rasa/core/persistor.py +87 -21
- rasa/core/utils.py +53 -22
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +19 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +5 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/e2e_test/utils/io.py +2 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +18 -0
- rasa/model_manager/model_api.py +469 -0
- rasa/model_manager/runner_service.py +279 -0
- rasa/model_manager/socket_bridge.py +143 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +332 -0
- rasa/model_manager/utils.py +66 -0
- rasa/model_service.py +109 -0
- rasa/model_training.py +25 -7
- rasa/shared/constants.py +6 -0
- rasa/shared/core/constants.py +2 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +15 -3
- rasa/shared/utils/yaml.py +10 -1
- rasa/utils/endpoints.py +27 -1
- rasa/version.py +1 -1
- rasa_pro-3.11.0a4.dev1.dist-info/METADATA +197 -0
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/RECORD +48 -38
- rasa/keys +0 -1
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- rasa_pro-3.11.0a3.dist-info/METADATA +0 -576
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/entry_points.txt +0 -0
rasa/core/actions/action.py
CHANGED
|
@@ -13,6 +13,8 @@ from typing import (
|
|
|
13
13
|
cast,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
|
+
from jsonschema import Draft202012Validator
|
|
17
|
+
|
|
16
18
|
import rasa.core
|
|
17
19
|
import rasa.shared.utils.io
|
|
18
20
|
from rasa.core.actions.custom_action_executor import (
|
|
@@ -101,7 +103,6 @@ if TYPE_CHECKING:
|
|
|
101
103
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
102
104
|
from rasa.shared.core.events import IntentPrediction
|
|
103
105
|
|
|
104
|
-
|
|
105
106
|
logger = logging.getLogger(__name__)
|
|
106
107
|
|
|
107
108
|
|
|
@@ -113,6 +114,7 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
|
|
|
113
114
|
from rasa.core.actions.action_trigger_search import ActionTriggerSearch
|
|
114
115
|
from rasa.core.actions.two_stage_fallback import TwoStageFallbackAction
|
|
115
116
|
from rasa.core.actions.action_hangup import ActionHangup
|
|
117
|
+
from rasa.core.actions.action_repeat_bot_messages import ActionRepeatBotMessages
|
|
116
118
|
from rasa.dialogue_understanding.patterns.cancel import ActionCancelFlow
|
|
117
119
|
from rasa.dialogue_understanding.patterns.clarify import ActionClarifyFlows
|
|
118
120
|
from rasa.dialogue_understanding.patterns.correction import ActionCorrectFlowSlot
|
|
@@ -140,6 +142,7 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
|
|
|
140
142
|
ActionTriggerChitchat(),
|
|
141
143
|
ActionResetRouting(),
|
|
142
144
|
ActionHangup(),
|
|
145
|
+
ActionRepeatBotMessages(),
|
|
143
146
|
]
|
|
144
147
|
|
|
145
148
|
|
|
@@ -723,6 +726,77 @@ class ActionDeactivateLoop(Action):
|
|
|
723
726
|
return [ActiveLoop(None), SlotSet(REQUESTED_SLOT, None)]
|
|
724
727
|
|
|
725
728
|
|
|
729
|
+
class RemoteActionJSONValidator:
|
|
730
|
+
"""
|
|
731
|
+
A validator class for ensuring that the JSON response from a custom action executor
|
|
732
|
+
adheres to the expected schema.
|
|
733
|
+
"""
|
|
734
|
+
|
|
735
|
+
@staticmethod
|
|
736
|
+
def action_response_format_spec() -> Dict[Text, Any]:
|
|
737
|
+
"""Expected response schema for an Action endpoint.
|
|
738
|
+
|
|
739
|
+
Used for validation of the response returned from the
|
|
740
|
+
Action endpoint.
|
|
741
|
+
|
|
742
|
+
Returns:
|
|
743
|
+
Dict[Text, Any]: A dictionary representing the JSON schema for validation.
|
|
744
|
+
"""
|
|
745
|
+
schema = {
|
|
746
|
+
"type": "object",
|
|
747
|
+
"properties": {
|
|
748
|
+
"events": EVENTS_SCHEMA,
|
|
749
|
+
"responses": {"type": "array", "items": {"type": "object"}},
|
|
750
|
+
},
|
|
751
|
+
}
|
|
752
|
+
return schema
|
|
753
|
+
|
|
754
|
+
@classmethod
|
|
755
|
+
def validate(cls, result: Dict[Text, Any]) -> bool:
|
|
756
|
+
"""
|
|
757
|
+
Validate the given JSON result against the expected Action response schema.
|
|
758
|
+
|
|
759
|
+
This method uses a cached JSON schema validator to check if the provided result
|
|
760
|
+
conforms to the predefined schema.
|
|
761
|
+
|
|
762
|
+
Args:
|
|
763
|
+
result (Dict[Text, Any]): The JSON response to validate.
|
|
764
|
+
|
|
765
|
+
Returns:
|
|
766
|
+
bool: True if validation is successful.
|
|
767
|
+
|
|
768
|
+
Raises:
|
|
769
|
+
ValidationError: If the JSON response does not conform to the schema.
|
|
770
|
+
"""
|
|
771
|
+
from jsonschema import ValidationError
|
|
772
|
+
|
|
773
|
+
try:
|
|
774
|
+
validator = cls.get_action_response_validator()
|
|
775
|
+
validator.validate(
|
|
776
|
+
result, RemoteActionJSONValidator.action_response_format_spec()
|
|
777
|
+
)
|
|
778
|
+
return True
|
|
779
|
+
except ValidationError as e:
|
|
780
|
+
e.message += (
|
|
781
|
+
f". Failed to validate Action server response from API, "
|
|
782
|
+
f"make sure your response from the Action endpoint is valid. "
|
|
783
|
+
f"For more information about the format visit "
|
|
784
|
+
f"{DOCS_BASE_URL}/custom-actions"
|
|
785
|
+
)
|
|
786
|
+
raise e
|
|
787
|
+
|
|
788
|
+
@classmethod
|
|
789
|
+
@lru_cache(maxsize=1)
|
|
790
|
+
def get_action_response_validator(cls) -> Draft202012Validator:
|
|
791
|
+
"""
|
|
792
|
+
Retrieve a cached JSON schema validator for the Action response schema.
|
|
793
|
+
|
|
794
|
+
Returns:
|
|
795
|
+
Draft202012Validator: An instance of the JSON schema validator.
|
|
796
|
+
"""
|
|
797
|
+
return Draft202012Validator(cls.action_response_format_spec())
|
|
798
|
+
|
|
799
|
+
|
|
726
800
|
class RemoteAction(Action):
|
|
727
801
|
def __init__(
|
|
728
802
|
self,
|
|
@@ -781,37 +855,6 @@ class RemoteAction(Action):
|
|
|
781
855
|
f"Found url '{self.action_endpoint.url}'."
|
|
782
856
|
)
|
|
783
857
|
|
|
784
|
-
@staticmethod
|
|
785
|
-
def action_response_format_spec() -> Dict[Text, Any]:
|
|
786
|
-
"""Expected response schema for an Action endpoint.
|
|
787
|
-
|
|
788
|
-
Used for validation of the response returned from the
|
|
789
|
-
Action endpoint.
|
|
790
|
-
"""
|
|
791
|
-
schema = {
|
|
792
|
-
"type": "object",
|
|
793
|
-
"properties": {
|
|
794
|
-
"events": EVENTS_SCHEMA,
|
|
795
|
-
"responses": {"type": "array", "items": {"type": "object"}},
|
|
796
|
-
},
|
|
797
|
-
}
|
|
798
|
-
return schema
|
|
799
|
-
|
|
800
|
-
def _validate_action_result(self, result: Dict[Text, Any]) -> bool:
|
|
801
|
-
from jsonschema import ValidationError, validate
|
|
802
|
-
|
|
803
|
-
try:
|
|
804
|
-
validate(result, self.action_response_format_spec())
|
|
805
|
-
return True
|
|
806
|
-
except ValidationError as e:
|
|
807
|
-
e.message += (
|
|
808
|
-
f". Failed to validate Action server response from API, "
|
|
809
|
-
f"make sure your response from the Action endpoint is valid. "
|
|
810
|
-
f"For more information about the format visit "
|
|
811
|
-
f"{DOCS_BASE_URL}/custom-actions"
|
|
812
|
-
)
|
|
813
|
-
raise e
|
|
814
|
-
|
|
815
858
|
@staticmethod
|
|
816
859
|
async def _utter_responses(
|
|
817
860
|
responses: List[Dict[Text, Any]],
|
|
@@ -863,7 +906,6 @@ class RemoteAction(Action):
|
|
|
863
906
|
domain=domain,
|
|
864
907
|
tracker=tracker,
|
|
865
908
|
)
|
|
866
|
-
self._validate_action_result(response)
|
|
867
909
|
|
|
868
910
|
events_json = response.get("events", [])
|
|
869
911
|
responses = response.get("responses", [])
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any, List
|
|
2
|
+
|
|
3
|
+
from rasa.core.actions.action import Action
|
|
4
|
+
from rasa.core.channels import OutputChannel
|
|
5
|
+
from rasa.core.nlg import NaturalLanguageGenerator
|
|
6
|
+
from rasa.shared.core.constants import ACTION_REPEAT_BOT_MESSAGES
|
|
7
|
+
from rasa.shared.core.domain import Domain
|
|
8
|
+
from rasa.shared.core.events import Event, BotUttered, UserUttered
|
|
9
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ActionRepeatBotMessages(Action):
|
|
13
|
+
"""Action to repeat bot messages"""
|
|
14
|
+
|
|
15
|
+
def name(self) -> str:
|
|
16
|
+
"""Return the name of the action."""
|
|
17
|
+
return ACTION_REPEAT_BOT_MESSAGES
|
|
18
|
+
|
|
19
|
+
def _get_last_bot_events(self, tracker: DialogueStateTracker) -> List[Event]:
|
|
20
|
+
"""Get the last consecutive bot events before the most recent user message.
|
|
21
|
+
|
|
22
|
+
This function scans the dialogue history in reverse to find the last sequence of
|
|
23
|
+
bot responses that occurred without any user interruption. It filters out all
|
|
24
|
+
non-utterance events and stops when it encounters a user message after finding
|
|
25
|
+
bot messages.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
tracker: DialogueStateTracker containing the conversation events.
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
List[Event]: A list of consecutive BotUttered events that occurred
|
|
32
|
+
most recently, in chronological order. Returns an empty list
|
|
33
|
+
if no bot messages are found or if the last message was from
|
|
34
|
+
the user.
|
|
35
|
+
|
|
36
|
+
Example:
|
|
37
|
+
For events: [User1, Bot1, Bot2, User2, Bot4, Bot5, User3]
|
|
38
|
+
Returns: [Bot4, Bot5] (the last two bot events)
|
|
39
|
+
The elif condition doesn't break when it sees User3 event.
|
|
40
|
+
But it does at User2 event.
|
|
41
|
+
"""
|
|
42
|
+
# filter user and bot events
|
|
43
|
+
filtered = [
|
|
44
|
+
e for e in tracker.events if isinstance(e, (BotUttered, UserUttered))
|
|
45
|
+
]
|
|
46
|
+
bot_events: List[Event] = []
|
|
47
|
+
|
|
48
|
+
# find the last BotUttered events
|
|
49
|
+
for e in reversed(filtered):
|
|
50
|
+
if isinstance(e, BotUttered):
|
|
51
|
+
# insert instead of append because the list is reversed
|
|
52
|
+
bot_events.insert(0, e)
|
|
53
|
+
|
|
54
|
+
# stop if a UserUttered event is found
|
|
55
|
+
# only if we have collected some bot events already
|
|
56
|
+
# this condition skips the first N UserUttered events
|
|
57
|
+
elif bot_events:
|
|
58
|
+
break
|
|
59
|
+
|
|
60
|
+
return bot_events
|
|
61
|
+
|
|
62
|
+
async def run(
|
|
63
|
+
self,
|
|
64
|
+
output_channel: OutputChannel,
|
|
65
|
+
nlg: NaturalLanguageGenerator,
|
|
66
|
+
tracker: DialogueStateTracker,
|
|
67
|
+
domain: Domain,
|
|
68
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
69
|
+
) -> List[Event]:
|
|
70
|
+
"""Send the last bot messages to the channel again"""
|
|
71
|
+
bot_events = self._get_last_bot_events(tracker)
|
|
72
|
+
return bot_events
|
|
@@ -61,8 +61,12 @@ class E2EStubCustomActionExecutor(CustomActionExecutor):
|
|
|
61
61
|
domain: "Domain",
|
|
62
62
|
include_domain: bool = False,
|
|
63
63
|
) -> Dict[Text, Any]:
|
|
64
|
+
from rasa.core.actions.action import RemoteActionJSONValidator
|
|
65
|
+
|
|
64
66
|
structlogger.debug(
|
|
65
67
|
"action.e2e_stub_custom_action_executor.run",
|
|
66
68
|
action_name=self.action_name,
|
|
67
69
|
)
|
|
68
|
-
|
|
70
|
+
response = self.stub_custom_action.as_dict()
|
|
71
|
+
RemoteActionJSONValidator.validate(response)
|
|
72
|
+
return response
|
|
@@ -66,6 +66,8 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
66
66
|
Raises:
|
|
67
67
|
RasaException: If an error occurs while making the HTTP request.
|
|
68
68
|
"""
|
|
69
|
+
from rasa.core.actions.action import RemoteActionJSONValidator
|
|
70
|
+
|
|
69
71
|
try:
|
|
70
72
|
logger.debug(
|
|
71
73
|
"Calling action endpoint to run action '{}'.".format(self.action_name)
|
|
@@ -80,6 +82,8 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
80
82
|
if response is None:
|
|
81
83
|
response = {}
|
|
82
84
|
|
|
85
|
+
RemoteActionJSONValidator.validate(response)
|
|
86
|
+
|
|
83
87
|
return response
|
|
84
88
|
|
|
85
89
|
except ClientResponseError as e:
|
rasa/core/channels/socketio.py
CHANGED
|
@@ -37,7 +37,11 @@ class SocketBlueprint(Blueprint):
|
|
|
37
37
|
:param options: Options to be used while registering the
|
|
38
38
|
blueprint into the app.
|
|
39
39
|
"""
|
|
40
|
-
|
|
40
|
+
if self.ctx.socketio_path:
|
|
41
|
+
path = self.ctx.socketio_path
|
|
42
|
+
else:
|
|
43
|
+
path = options.get("url_prefix", "/socket.io")
|
|
44
|
+
self.ctx.sio.attach(app, path)
|
|
41
45
|
super().register(app, options)
|
|
42
46
|
|
|
43
47
|
|
|
@@ -2,16 +2,17 @@ import structlog
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
-
from rasa.utils.licensing import (
|
|
6
|
-
PRODUCT_AREA,
|
|
7
|
-
VOICE_SCOPE,
|
|
8
|
-
validate_license_from_env,
|
|
9
|
-
)
|
|
10
5
|
|
|
11
6
|
structlogger = structlog.get_logger()
|
|
12
7
|
|
|
13
8
|
|
|
14
9
|
def validate_voice_license_scope() -> None:
|
|
10
|
+
from rasa.utils.licensing import (
|
|
11
|
+
PRODUCT_AREA,
|
|
12
|
+
VOICE_SCOPE,
|
|
13
|
+
validate_license_from_env,
|
|
14
|
+
)
|
|
15
|
+
|
|
15
16
|
"""Validate that the correct license scope is present."""
|
|
16
17
|
structlogger.info(
|
|
17
18
|
f"Validating current Rasa Pro license scope which must include "
|
|
@@ -71,7 +71,7 @@ class BrowserAudioInputChannel(VoiceInputChannel):
|
|
|
71
71
|
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
|
|
72
72
|
) -> Blueprint:
|
|
73
73
|
"""Defines a Sanic bluelogger.debug."""
|
|
74
|
-
blueprint = Blueprint("
|
|
74
|
+
blueprint = Blueprint("browser_audio", __name__)
|
|
75
75
|
|
|
76
76
|
@blueprint.route("/", methods=["GET"])
|
|
77
77
|
async def health(_: Request) -> HTTPResponse:
|
|
@@ -131,7 +131,7 @@ class TwilioMediaStreamsInputChannel(VoiceInputChannel):
|
|
|
131
131
|
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
|
|
132
132
|
) -> Blueprint:
|
|
133
133
|
"""Defines a Sanic bluelogger.debug."""
|
|
134
|
-
blueprint = Blueprint("
|
|
134
|
+
blueprint = Blueprint("twilio_media_streams", __name__)
|
|
135
135
|
|
|
136
136
|
@blueprint.route("/", methods=["GET"])
|
|
137
137
|
async def health(_: Request) -> HTTPResponse:
|
|
@@ -30,6 +30,9 @@ from rasa.shared.utils.llm import (
|
|
|
30
30
|
try_instantiate_llm_client,
|
|
31
31
|
)
|
|
32
32
|
from rasa.utils.endpoints import EndpointConfig
|
|
33
|
+
from rasa.shared.utils.llm import (
|
|
34
|
+
tracker_as_readable_transcript,
|
|
35
|
+
)
|
|
33
36
|
|
|
34
37
|
from rasa.core.nlg.summarize import summarize_conversation
|
|
35
38
|
|
|
@@ -41,6 +44,8 @@ RESPONSE_REPHRASING_KEY = "rephrase"
|
|
|
41
44
|
|
|
42
45
|
RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
|
|
43
46
|
|
|
47
|
+
RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
|
|
48
|
+
|
|
44
49
|
DEFAULT_REPHRASE_ALL = False
|
|
45
50
|
|
|
46
51
|
DEFAULT_LLM_CONFIG = {
|
|
@@ -212,13 +217,25 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
212
217
|
if not (response_text := response.get(KEY_RESPONSES_TEXT)):
|
|
213
218
|
return response
|
|
214
219
|
|
|
220
|
+
prompt_template_text = self._template_for_response_rephrasing(response)
|
|
221
|
+
|
|
222
|
+
# Retrieve inputs for the dynamic prompt
|
|
223
|
+
transcript = tracker_as_readable_transcript(tracker, max_turns=5)
|
|
215
224
|
latest_message = self._last_message_if_human(tracker)
|
|
216
225
|
current_input = f"{USER}: {latest_message}" if latest_message else ""
|
|
217
226
|
|
|
218
|
-
|
|
227
|
+
# Only summarise conversation history if flagged
|
|
228
|
+
summarize_conversation_flag = response.get("metadata", {}).get(
|
|
229
|
+
RESPONSE_SUMMARISE_CONVERSATION_KEY, False
|
|
230
|
+
)
|
|
231
|
+
if summarize_conversation_flag:
|
|
232
|
+
history = await self._create_history(tracker)
|
|
233
|
+
else:
|
|
234
|
+
history = transcript
|
|
235
|
+
current_input = ""
|
|
219
236
|
|
|
220
237
|
prompt = Template(prompt_template_text).render(
|
|
221
|
-
history=
|
|
238
|
+
history=history,
|
|
222
239
|
suggested_response=response_text,
|
|
223
240
|
current_input=current_input,
|
|
224
241
|
slots=tracker.current_slot_values(),
|
rasa/core/persistor.py
CHANGED
|
@@ -8,12 +8,14 @@ from typing import TYPE_CHECKING, List, Optional, Text, Tuple, Union
|
|
|
8
8
|
|
|
9
9
|
import structlog
|
|
10
10
|
|
|
11
|
+
from rasa.exceptions import ModelNotFound
|
|
11
12
|
import rasa.shared.utils.common
|
|
12
13
|
import rasa.utils.common
|
|
13
14
|
from rasa.constants import (
|
|
14
15
|
HTTP_STATUS_FORBIDDEN,
|
|
15
16
|
HTTP_STATUS_NOT_FOUND,
|
|
16
17
|
MODEL_ARCHIVE_EXTENSION,
|
|
18
|
+
DEFAULT_BUCKET_NAME,
|
|
17
19
|
)
|
|
18
20
|
from rasa.env import (
|
|
19
21
|
AWS_ENDPOINT_URL_ENV,
|
|
@@ -28,6 +30,7 @@ from rasa.shared.utils.io import raise_warning
|
|
|
28
30
|
|
|
29
31
|
if TYPE_CHECKING:
|
|
30
32
|
from azure.storage.blob import ContainerClient
|
|
33
|
+
import botocore
|
|
31
34
|
|
|
32
35
|
structlogger = structlog.get_logger()
|
|
33
36
|
|
|
@@ -86,14 +89,15 @@ def get_persistor(storage: StorageType) -> Optional[Persistor]:
|
|
|
86
89
|
|
|
87
90
|
if storage == RemoteStorageType.AWS.value:
|
|
88
91
|
return AWSPersistor(
|
|
89
|
-
os.environ.get(BUCKET_NAME_ENV
|
|
92
|
+
os.environ.get(BUCKET_NAME_ENV, DEFAULT_BUCKET_NAME),
|
|
93
|
+
os.environ.get(AWS_ENDPOINT_URL_ENV),
|
|
90
94
|
)
|
|
91
95
|
if storage == RemoteStorageType.GCS.value:
|
|
92
|
-
return GCSPersistor(os.environ.get(BUCKET_NAME_ENV))
|
|
96
|
+
return GCSPersistor(os.environ.get(BUCKET_NAME_ENV, DEFAULT_BUCKET_NAME))
|
|
93
97
|
|
|
94
98
|
if storage == RemoteStorageType.AZURE.value:
|
|
95
99
|
return AzurePersistor(
|
|
96
|
-
os.environ.get(AZURE_CONTAINER_ENV),
|
|
100
|
+
os.environ.get(AZURE_CONTAINER_ENV, DEFAULT_BUCKET_NAME),
|
|
97
101
|
os.environ.get(AZURE_ACCOUNT_NAME_ENV),
|
|
98
102
|
os.environ.get(AZURE_ACCOUNT_KEY_ENV),
|
|
99
103
|
)
|
|
@@ -181,7 +185,7 @@ class Persistor(abc.ABC):
|
|
|
181
185
|
|
|
182
186
|
@staticmethod
|
|
183
187
|
def _create_file_key(model_path: str) -> Text:
|
|
184
|
-
"""Appends remote storage folders when provided to upload or retrieve file"""
|
|
188
|
+
"""Appends remote storage folders when provided to upload or retrieve file."""
|
|
185
189
|
bucket_object_path = os.environ.get(REMOTE_STORAGE_PATH_ENV)
|
|
186
190
|
|
|
187
191
|
# To keep the backward compatibility, if REMOTE_STORAGE_PATH is not provided,
|
|
@@ -235,8 +239,7 @@ class AWSPersistor(Persistor):
|
|
|
235
239
|
try:
|
|
236
240
|
self.s3.meta.client.head_bucket(Bucket=bucket_name)
|
|
237
241
|
except botocore.exceptions.ClientError as e:
|
|
238
|
-
error_code
|
|
239
|
-
if error_code == HTTP_STATUS_FORBIDDEN:
|
|
242
|
+
if self.error_code(e) == HTTP_STATUS_FORBIDDEN:
|
|
240
243
|
log = (
|
|
241
244
|
f"Access to the specified bucket '{bucket_name}' is forbidden. "
|
|
242
245
|
"Please make sure you have the necessary "
|
|
@@ -248,7 +251,7 @@ class AWSPersistor(Persistor):
|
|
|
248
251
|
event_info=log,
|
|
249
252
|
)
|
|
250
253
|
raise RasaException(log)
|
|
251
|
-
elif error_code == HTTP_STATUS_NOT_FOUND:
|
|
254
|
+
elif self.error_code(e) == HTTP_STATUS_NOT_FOUND:
|
|
252
255
|
log = (
|
|
253
256
|
f"The specified bucket '{bucket_name}' does not exist. "
|
|
254
257
|
"Please make sure to create the bucket first."
|
|
@@ -260,6 +263,10 @@ class AWSPersistor(Persistor):
|
|
|
260
263
|
)
|
|
261
264
|
raise RasaException(log)
|
|
262
265
|
|
|
266
|
+
@staticmethod
|
|
267
|
+
def error_code(e: botocore.exceptions.ClientError) -> int:
|
|
268
|
+
return int(e.response["Error"]["Code"])
|
|
269
|
+
|
|
263
270
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
264
271
|
"""Uploads a model persisted in the `target_dir` to s3."""
|
|
265
272
|
with open(tar_path, "rb") as f:
|
|
@@ -267,9 +274,26 @@ class AWSPersistor(Persistor):
|
|
|
267
274
|
|
|
268
275
|
def _retrieve_tar(self, model_path: Text) -> None:
|
|
269
276
|
"""Downloads a model that has previously been persisted to s3."""
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
277
|
+
import botocore
|
|
278
|
+
|
|
279
|
+
target_filename = os.path.basename(model_path)
|
|
280
|
+
try:
|
|
281
|
+
with open(target_filename, "wb") as f:
|
|
282
|
+
self.bucket.download_fileobj(model_path, f)
|
|
283
|
+
except botocore.exceptions.ClientError as e:
|
|
284
|
+
if self.error_code(e) == HTTP_STATUS_NOT_FOUND:
|
|
285
|
+
log = (
|
|
286
|
+
f"Model '{target_filename}' not found in the specified bucket "
|
|
287
|
+
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
288
|
+
f"in the bucket."
|
|
289
|
+
)
|
|
290
|
+
structlogger.error(
|
|
291
|
+
"gcp_persistor.retrieve_tar.model_not_found",
|
|
292
|
+
bucket_name=self.bucket_name,
|
|
293
|
+
target_filename=target_filename,
|
|
294
|
+
event_info=log,
|
|
295
|
+
)
|
|
296
|
+
raise ModelNotFound() from e
|
|
273
297
|
|
|
274
298
|
|
|
275
299
|
class GCSPersistor(Persistor):
|
|
@@ -294,32 +318,57 @@ class GCSPersistor(Persistor):
|
|
|
294
318
|
|
|
295
319
|
def _ensure_bucket_exists(self, bucket_name: Text) -> None:
|
|
296
320
|
from google.cloud import exceptions
|
|
321
|
+
from google.auth import exceptions as auth_exceptions
|
|
297
322
|
|
|
298
323
|
try:
|
|
299
324
|
self.storage_client.get_bucket(bucket_name)
|
|
300
|
-
except
|
|
325
|
+
except auth_exceptions.GoogleAuthError as e:
|
|
301
326
|
log = (
|
|
302
|
-
f"
|
|
303
|
-
"Please make sure
|
|
327
|
+
f"An error occurred while authenticating with Google Cloud "
|
|
328
|
+
f"Storage. Please make sure you have the necessary credentials "
|
|
329
|
+
f"to access the bucket '{bucket_name}'."
|
|
330
|
+
)
|
|
331
|
+
structlogger.error(
|
|
332
|
+
"gcp_persistor.ensure_bucket_exists.authentication_error",
|
|
333
|
+
bucket_name=bucket_name,
|
|
334
|
+
event_info=log,
|
|
335
|
+
)
|
|
336
|
+
raise RasaException(log) from e
|
|
337
|
+
except exceptions.NotFound as e:
|
|
338
|
+
log = (
|
|
339
|
+
f"The specified Google Cloud Storage bucket '{bucket_name}' "
|
|
340
|
+
f"does not exist. Please make sure to create the bucket first or "
|
|
341
|
+
f"provide an alternative valid bucket name."
|
|
304
342
|
)
|
|
305
343
|
structlogger.error(
|
|
306
344
|
"gcp_persistor.ensure_bucket_exists.bucket_not_found",
|
|
307
345
|
bucket_name=bucket_name,
|
|
308
346
|
event_info=log,
|
|
309
347
|
)
|
|
310
|
-
raise RasaException(log)
|
|
311
|
-
except exceptions.Forbidden:
|
|
348
|
+
raise RasaException(log) from e
|
|
349
|
+
except exceptions.Forbidden as e:
|
|
312
350
|
log = (
|
|
313
|
-
f"Access to the specified bucket '{bucket_name}'
|
|
314
|
-
"Please make sure you have the necessary "
|
|
315
|
-
"permission to access the bucket. "
|
|
351
|
+
f"Access to the specified Google Cloud storage bucket '{bucket_name}' "
|
|
352
|
+
f"is forbidden. Please make sure you have the necessary "
|
|
353
|
+
f"permission to access the bucket. "
|
|
316
354
|
)
|
|
317
355
|
structlogger.error(
|
|
318
356
|
"gcp_persistor.ensure_bucket_exists.bucket_access_forbidden",
|
|
319
357
|
bucket_name=bucket_name,
|
|
320
358
|
event_info=log,
|
|
321
359
|
)
|
|
322
|
-
raise RasaException(log)
|
|
360
|
+
raise RasaException(log) from e
|
|
361
|
+
except ValueError as e:
|
|
362
|
+
# bucket_name is None
|
|
363
|
+
log = (
|
|
364
|
+
"The specified Google Cloud Storage bucket name is None. Please "
|
|
365
|
+
"make sure to provide a valid bucket name."
|
|
366
|
+
)
|
|
367
|
+
structlogger.error(
|
|
368
|
+
"gcp_persistor.ensure_bucket_exists.bucket_name_none",
|
|
369
|
+
event_info=log,
|
|
370
|
+
)
|
|
371
|
+
raise RasaException(log) from e
|
|
323
372
|
|
|
324
373
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
325
374
|
"""Uploads a model persisted in the `target_dir` to GCS."""
|
|
@@ -328,8 +377,24 @@ class GCSPersistor(Persistor):
|
|
|
328
377
|
|
|
329
378
|
def _retrieve_tar(self, target_filename: Text) -> None:
|
|
330
379
|
"""Downloads a model that has previously been persisted to GCS."""
|
|
380
|
+
from google.api_core import exceptions
|
|
381
|
+
|
|
331
382
|
blob = self.bucket.blob(target_filename)
|
|
332
|
-
|
|
383
|
+
try:
|
|
384
|
+
blob.download_to_filename(target_filename)
|
|
385
|
+
except exceptions.NotFound as e:
|
|
386
|
+
log = (
|
|
387
|
+
f"Model '{target_filename}' not found in the specified bucket "
|
|
388
|
+
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
389
|
+
f"in the bucket."
|
|
390
|
+
)
|
|
391
|
+
structlogger.error(
|
|
392
|
+
"gcp_persistor.retrieve_tar.model_not_found",
|
|
393
|
+
bucket_name=self.bucket_name,
|
|
394
|
+
target_filename=target_filename,
|
|
395
|
+
event_info=log,
|
|
396
|
+
)
|
|
397
|
+
raise ModelNotFound() from e
|
|
333
398
|
|
|
334
399
|
|
|
335
400
|
class AzurePersistor(Persistor):
|
|
@@ -355,7 +420,8 @@ class AzurePersistor(Persistor):
|
|
|
355
420
|
else:
|
|
356
421
|
log = (
|
|
357
422
|
f"The specified container '{self.container_name}' does not exist."
|
|
358
|
-
"Please make sure to create the
|
|
423
|
+
"Please make sure to create the bucket first or "
|
|
424
|
+
f"provide an alternative valid bucket name."
|
|
359
425
|
)
|
|
360
426
|
structlogger.error(
|
|
361
427
|
"azure_persistor.ensure_container_exists.container_not_found",
|