rasa-pro 3.10.16__py3-none-any.whl → 3.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (240) hide show
  1. rasa/__main__.py +31 -15
  2. rasa/api.py +12 -2
  3. rasa/cli/arguments/default_arguments.py +24 -4
  4. rasa/cli/arguments/run.py +15 -0
  5. rasa/cli/arguments/shell.py +5 -1
  6. rasa/cli/arguments/train.py +17 -9
  7. rasa/cli/evaluate.py +7 -7
  8. rasa/cli/inspect.py +19 -7
  9. rasa/cli/interactive.py +1 -0
  10. rasa/cli/llm_fine_tuning.py +11 -14
  11. rasa/cli/project_templates/calm/config.yml +5 -7
  12. rasa/cli/project_templates/calm/endpoints.yml +15 -2
  13. rasa/cli/project_templates/tutorial/config.yml +8 -5
  14. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  15. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  16. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  17. rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
  18. rasa/cli/run.py +7 -0
  19. rasa/cli/scaffold.py +4 -2
  20. rasa/cli/studio/upload.py +0 -15
  21. rasa/cli/train.py +14 -53
  22. rasa/cli/utils.py +14 -11
  23. rasa/cli/x.py +7 -7
  24. rasa/constants.py +3 -1
  25. rasa/core/actions/action.py +77 -33
  26. rasa/core/actions/action_hangup.py +29 -0
  27. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
  29. rasa/core/actions/http_custom_action_executor.py +4 -0
  30. rasa/core/agent.py +2 -2
  31. rasa/core/brokers/kafka.py +3 -1
  32. rasa/core/brokers/pika.py +3 -1
  33. rasa/core/channels/__init__.py +10 -6
  34. rasa/core/channels/channel.py +41 -4
  35. rasa/core/channels/development_inspector.py +150 -46
  36. rasa/core/channels/inspector/README.md +1 -1
  37. rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  47. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
  52. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
  57. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  58. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  59. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  60. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  61. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  71. rasa/core/channels/inspector/dist/index.html +18 -17
  72. rasa/core/channels/inspector/index.html +17 -16
  73. rasa/core/channels/inspector/package.json +5 -1
  74. rasa/core/channels/inspector/src/App.tsx +118 -68
  75. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  76. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
  77. rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
  78. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
  79. rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
  80. rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
  81. rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
  82. rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
  83. rasa/core/channels/inspector/src/types.ts +21 -1
  84. rasa/core/channels/inspector/yarn.lock +94 -1
  85. rasa/core/channels/rest.py +51 -46
  86. rasa/core/channels/socketio.py +28 -1
  87. rasa/core/channels/telegram.py +1 -1
  88. rasa/core/channels/twilio.py +1 -1
  89. rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
  90. rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
  91. rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
  92. rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
  93. rasa/core/channels/voice_ready/utils.py +37 -0
  94. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  95. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  96. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  97. rasa/core/channels/voice_stream/asr/azure.py +129 -0
  98. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  99. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  100. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  101. rasa/core/channels/voice_stream/call_state.py +23 -0
  102. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  103. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  104. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  105. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  106. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  107. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  108. rasa/core/channels/voice_stream/util.py +57 -0
  109. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  110. rasa/core/information_retrieval/qdrant.py +1 -0
  111. rasa/core/nlg/contextual_response_rephraser.py +45 -17
  112. rasa/{nlu → core}/persistor.py +203 -68
  113. rasa/core/policies/enterprise_search_policy.py +119 -63
  114. rasa/core/policies/flows/flow_executor.py +15 -22
  115. rasa/core/policies/intentless_policy.py +83 -28
  116. rasa/core/processor.py +25 -0
  117. rasa/core/run.py +12 -2
  118. rasa/core/secrets_manager/constants.py +4 -0
  119. rasa/core/secrets_manager/factory.py +8 -0
  120. rasa/core/secrets_manager/vault.py +11 -1
  121. rasa/core/training/interactive.py +33 -34
  122. rasa/core/utils.py +47 -21
  123. rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
  124. rasa/dialogue_understanding/commands/__init__.py +6 -0
  125. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  126. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  127. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  128. rasa/dialogue_understanding/commands/utils.py +5 -0
  129. rasa/dialogue_understanding/generator/constants.py +2 -0
  130. rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
  131. rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
  132. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
  133. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
  134. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
  135. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
  136. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
  137. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  138. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  139. rasa/dialogue_understanding/processor/command_processor.py +21 -1
  140. rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
  141. rasa/e2e_test/assertions.py +136 -61
  142. rasa/e2e_test/assertions_schema.yml +23 -0
  143. rasa/e2e_test/e2e_test_case.py +85 -6
  144. rasa/e2e_test/e2e_test_runner.py +2 -3
  145. rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
  146. rasa/engine/graph.py +3 -10
  147. rasa/engine/loader.py +12 -0
  148. rasa/engine/recipes/config_files/default_config.yml +0 -3
  149. rasa/engine/recipes/default_recipe.py +0 -1
  150. rasa/engine/recipes/graph_recipe.py +0 -1
  151. rasa/engine/runner/dask.py +2 -2
  152. rasa/engine/storage/local_model_storage.py +12 -42
  153. rasa/engine/storage/storage.py +1 -5
  154. rasa/engine/validation.py +527 -74
  155. rasa/model_manager/__init__.py +0 -0
  156. rasa/model_manager/config.py +40 -0
  157. rasa/model_manager/model_api.py +559 -0
  158. rasa/model_manager/runner_service.py +286 -0
  159. rasa/model_manager/socket_bridge.py +146 -0
  160. rasa/model_manager/studio_jwt_auth.py +86 -0
  161. rasa/model_manager/trainer_service.py +325 -0
  162. rasa/model_manager/utils.py +87 -0
  163. rasa/model_manager/warm_rasa_process.py +187 -0
  164. rasa/model_service.py +112 -0
  165. rasa/model_training.py +42 -23
  166. rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
  167. rasa/server.py +4 -2
  168. rasa/shared/constants.py +60 -8
  169. rasa/shared/core/constants.py +13 -0
  170. rasa/shared/core/domain.py +107 -50
  171. rasa/shared/core/events.py +29 -0
  172. rasa/shared/core/flows/flow.py +5 -0
  173. rasa/shared/core/flows/flows_list.py +19 -6
  174. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  175. rasa/shared/core/flows/utils.py +39 -0
  176. rasa/shared/core/flows/validation.py +121 -0
  177. rasa/shared/core/flows/yaml_flows_io.py +15 -27
  178. rasa/shared/core/slots.py +5 -0
  179. rasa/shared/importers/importer.py +59 -41
  180. rasa/shared/importers/multi_project.py +23 -11
  181. rasa/shared/importers/rasa.py +12 -3
  182. rasa/shared/importers/remote_importer.py +196 -0
  183. rasa/shared/importers/utils.py +3 -1
  184. rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
  185. rasa/shared/nlu/training_data/training_data.py +18 -19
  186. rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
  187. rasa/shared/providers/_configs/model_group_config.py +167 -0
  188. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  189. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  190. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  191. rasa/shared/providers/_configs/utils.py +16 -0
  192. rasa/shared/providers/_utils.py +79 -0
  193. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
  194. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  195. rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
  196. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  197. rasa/shared/providers/llm/_base_litellm_client.py +34 -22
  198. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  199. rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
  200. rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
  201. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  202. rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
  203. rasa/shared/providers/mappings.py +19 -0
  204. rasa/shared/providers/router/__init__.py +0 -0
  205. rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
  206. rasa/shared/providers/router/router_client.py +73 -0
  207. rasa/shared/utils/common.py +40 -24
  208. rasa/shared/utils/health_check/__init__.py +0 -0
  209. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  210. rasa/shared/utils/health_check/health_check.py +258 -0
  211. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  212. rasa/shared/utils/io.py +27 -6
  213. rasa/shared/utils/llm.py +354 -44
  214. rasa/shared/utils/schemas/events.py +2 -0
  215. rasa/shared/utils/schemas/model_config.yml +0 -10
  216. rasa/shared/utils/yaml.py +181 -38
  217. rasa/studio/data_handler.py +3 -1
  218. rasa/studio/upload.py +160 -74
  219. rasa/telemetry.py +94 -17
  220. rasa/tracing/config.py +3 -1
  221. rasa/tracing/instrumentation/attribute_extractors.py +95 -18
  222. rasa/tracing/instrumentation/instrumentation.py +121 -0
  223. rasa/utils/common.py +5 -0
  224. rasa/utils/endpoints.py +27 -1
  225. rasa/utils/io.py +8 -16
  226. rasa/utils/log_utils.py +9 -2
  227. rasa/utils/sanic_error_handler.py +32 -0
  228. rasa/validator.py +110 -16
  229. rasa/version.py +1 -1
  230. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +16 -14
  231. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +236 -185
  232. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
  233. rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
  234. rasa/core/channels/voice_aware/utils.py +0 -20
  235. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
  236. /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
  237. /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
  238. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
  239. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
  240. {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
@@ -4,16 +4,19 @@ import abc
4
4
  import os
5
5
  import shutil
6
6
  from enum import Enum
7
+ from pathlib import Path
7
8
  from typing import TYPE_CHECKING, List, Optional, Text, Tuple, Union
8
9
 
9
10
  import structlog
10
11
 
12
+ from rasa.exceptions import ModelNotFound
11
13
  import rasa.shared.utils.common
12
14
  import rasa.utils.common
13
15
  from rasa.constants import (
14
16
  HTTP_STATUS_FORBIDDEN,
15
17
  HTTP_STATUS_NOT_FOUND,
16
18
  MODEL_ARCHIVE_EXTENSION,
19
+ DEFAULT_BUCKET_NAME,
17
20
  )
18
21
  from rasa.env import (
19
22
  AWS_ENDPOINT_URL_ENV,
@@ -28,6 +31,7 @@ from rasa.shared.utils.io import raise_warning
28
31
 
29
32
  if TYPE_CHECKING:
30
33
  from azure.storage.blob import ContainerClient
34
+ from botocore.exceptions import ClientError
31
35
 
32
36
  structlogger = structlog.get_logger()
33
37
 
@@ -82,16 +86,19 @@ def get_persistor(storage: StorageType) -> Optional[Persistor]:
82
86
  Currently, `aws`, `gcs`, `azure` and providing module paths are supported remote
83
87
  storages.
84
88
  """
85
- if storage == RemoteStorageType.AWS:
89
+ storage = storage.value if isinstance(storage, RemoteStorageType) else storage
90
+
91
+ if storage == RemoteStorageType.AWS.value:
86
92
  return AWSPersistor(
87
- os.environ.get(BUCKET_NAME_ENV), os.environ.get(AWS_ENDPOINT_URL_ENV)
93
+ os.environ.get(BUCKET_NAME_ENV, DEFAULT_BUCKET_NAME),
94
+ os.environ.get(AWS_ENDPOINT_URL_ENV),
88
95
  )
89
- if storage == RemoteStorageType.GCS:
90
- return GCSPersistor(os.environ.get(BUCKET_NAME_ENV))
96
+ if storage == RemoteStorageType.GCS.value:
97
+ return GCSPersistor(os.environ.get(BUCKET_NAME_ENV, DEFAULT_BUCKET_NAME))
91
98
 
92
- if storage == RemoteStorageType.AZURE:
99
+ if storage == RemoteStorageType.AZURE.value:
93
100
  return AzurePersistor(
94
- os.environ.get(AZURE_CONTAINER_ENV),
101
+ os.environ.get(AZURE_CONTAINER_ENV, DEFAULT_BUCKET_NAME),
95
102
  os.environ.get(AZURE_ACCOUNT_NAME_ENV),
96
103
  os.environ.get(AZURE_ACCOUNT_KEY_ENV),
97
104
  )
@@ -116,44 +123,64 @@ class Persistor(abc.ABC):
116
123
 
117
124
  def persist(self, trained_model: str) -> None:
118
125
  """Uploads a trained model persisted in the `target_dir` to cloud storage."""
119
- file_key = self._create_file_key(trained_model)
126
+ absolute_file_key = self._create_file_key(trained_model)
127
+ file_key = Path(absolute_file_key).name
120
128
  self._persist_tar(file_key, trained_model)
121
129
 
122
130
  def retrieve(self, model_name: Text, target_path: Text) -> Text:
123
131
  """Downloads a model that has been persisted to cloud storage.
124
132
 
125
133
  Downloaded model will be saved to the `target_path`.
126
- If `target_path` is a directory, the model will be downloaded to that directory.
127
- If `target_path` is a file, the model will be downloaded to that file.
134
+ If `target_path` is a directory, the model will be saved to that directory.
135
+ If `target_path` is a file, the model will be saved to that file.
128
136
 
129
137
  Args:
130
138
  model_name: The name of the model to retrieve.
131
- target_path: The path to which the model should be downloaded.
139
+ target_path: The path to which the model should be saved.
132
140
  """
133
141
  tar_name = model_name
134
142
  if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
135
143
  # ensure backward compatibility
136
144
  tar_name = self._tar_name(model_name)
137
- remote_object_path = self._create_file_key(tar_name)
138
- self._retrieve_tar(remote_object_path)
145
+ tar_name = self._create_file_key(tar_name)
146
+ target_filename = os.path.basename(tar_name)
147
+ self._retrieve_tar(target_filename)
148
+ self._copy(os.path.basename(tar_name), target_path)
139
149
 
140
- target_tar_file_name = os.path.basename(tar_name)
141
150
  if os.path.isdir(target_path):
142
- target_path = os.path.join(target_path, target_tar_file_name)
151
+ return os.path.join(target_path, model_name)
143
152
 
144
- if not os.path.exists(target_path):
145
- structlogger.debug(
146
- "persistor.retrieve.copy_model",
147
- event_info=f"Copying model '{target_tar_file_name}' to "
148
- f"'{target_path}'.",
149
- )
150
- self._copy(target_tar_file_name, target_path)
153
+ return target_path
154
+
155
+ def size_of_persisted_model(self, model_name: Text) -> int:
156
+ """Returns the size of the model that has been persisted to cloud storage.
151
157
 
152
- structlogger.debug(
153
- "persistor.retrieve.model_retrieved",
154
- event_info=f"Model retrieved and saved to '{target_path}'.",
158
+ Args:
159
+ model_name: The name of the model to retrieve.
160
+ """
161
+ tar_name = model_name
162
+ if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
163
+ # ensure backward compatibility
164
+ tar_name = self._tar_name(model_name)
165
+ tar_name = self._create_file_key(tar_name)
166
+ target_filename = os.path.basename(tar_name)
167
+ return self._retrieve_tar_size(target_filename)
168
+
169
+ def _retrieve_tar_size(self, filename: Text) -> int:
170
+ """Returns the size of the model that has been persisted to cloud storage."""
171
+ structlogger.warning(
172
+ "persistor.retrieve_tar_size.not_implemented",
173
+ filename=filename,
174
+ event_info=(
175
+ "This method should be implemented in the persistor. "
176
+ "The default implementation will download the model "
177
+ "to calculate the size. Most persistors should override "
178
+ "this method to avoid downloading the model and get the "
179
+ "size directly from the cloud storage."
180
+ ),
155
181
  )
156
- return target_path
182
+ self._retrieve_tar(filename)
183
+ return os.path.getsize(os.path.basename(filename))
157
184
 
158
185
  @abc.abstractmethod
159
186
  def _retrieve_tar(self, filename: Text) -> None:
@@ -175,7 +202,7 @@ class Persistor(abc.ABC):
175
202
  os.path.join(dirpath, base_name),
176
203
  "gztar",
177
204
  root_dir=model_directory,
178
- base_dir=".",
205
+ base_dir="../nlu",
179
206
  )
180
207
  file_key = os.path.basename(tar_name)
181
208
  return file_key, tar_name
@@ -191,7 +218,7 @@ class Persistor(abc.ABC):
191
218
 
192
219
  @staticmethod
193
220
  def _create_file_key(model_path: str) -> Text:
194
- """Appends remote storage folders when provided to upload or retrieve file"""
221
+ """Appends remote storage folders when provided to upload or retrieve file."""
195
222
  bucket_object_path = os.environ.get(REMOTE_STORAGE_PATH_ENV)
196
223
 
197
224
  # To keep the backward compatibility, if REMOTE_STORAGE_PATH is not provided,
@@ -203,10 +230,7 @@ class Persistor(abc.ABC):
203
230
  f"{REMOTE_STORAGE_PATH_ENV} is deprecated and will be "
204
231
  "removed in future versions. "
205
232
  "Please use the -m path/to/model.tar.gz option to "
206
- "specify the model path when loading a model."
207
- "Or use --output and --fixed-model-name to specify the "
208
- "output directory and the model name when saving a "
209
- "trained model to remote storage.",
233
+ "specify the model path when loading a model.",
210
234
  )
211
235
 
212
236
  file_key = os.path.basename(model_path)
@@ -239,14 +263,13 @@ class AWSPersistor(Persistor):
239
263
  def _ensure_bucket_exists(
240
264
  self, bucket_name: Text, region_name: Optional[Text] = None
241
265
  ) -> None:
242
- import botocore
266
+ from botocore import exceptions
243
267
 
244
268
  # noinspection PyUnresolvedReferences
245
269
  try:
246
270
  self.s3.meta.client.head_bucket(Bucket=bucket_name)
247
- except botocore.exceptions.ClientError as e:
248
- error_code = int(e.response["Error"]["Code"])
249
- if error_code == HTTP_STATUS_FORBIDDEN:
271
+ except exceptions.ClientError as exc:
272
+ if self._error_code(exc) == HTTP_STATUS_FORBIDDEN:
250
273
  log = (
251
274
  f"Access to the specified bucket '{bucket_name}' is forbidden. "
252
275
  "Please make sure you have the necessary "
@@ -258,7 +281,7 @@ class AWSPersistor(Persistor):
258
281
  event_info=log,
259
282
  )
260
283
  raise RasaException(log)
261
- elif error_code == HTTP_STATUS_NOT_FOUND:
284
+ elif self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
262
285
  log = (
263
286
  f"The specified bucket '{bucket_name}' does not exist. "
264
287
  "Please make sure to create the bucket first."
@@ -270,21 +293,57 @@ class AWSPersistor(Persistor):
270
293
  )
271
294
  raise RasaException(log)
272
295
 
296
+ @staticmethod
297
+ def _error_code(e: "ClientError") -> int:
298
+ return int(e.response["Error"]["Code"])
299
+
273
300
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
274
301
  """Uploads a model persisted in the `target_dir` to s3."""
275
- structlogger.debug(
276
- "aws_persistor.persist_tar.uploading_model",
277
- event_info=f"Uploading tar archive {file_key} to "
278
- f"s3 bucket '{self.bucket_name}'.",
279
- )
280
302
  with open(tar_path, "rb") as f:
281
303
  self.s3.Object(self.bucket_name, file_key).put(Body=f)
282
304
 
283
- def _retrieve_tar(self, model_path: Text) -> None:
305
+ def _retrieve_tar_size(self, model_path: Text) -> int:
306
+ """Returns the size of the model that has been persisted to s3."""
307
+ try:
308
+ obj = self.s3.Object(self.bucket_name, model_path)
309
+ return obj.content_length
310
+ except Exception:
311
+ raise ModelNotFound()
312
+
313
+ def _retrieve_tar(self, target_filename: str) -> None:
284
314
  """Downloads a model that has previously been persisted to s3."""
285
- tar_name = os.path.basename(model_path)
286
- with open(tar_name, "wb") as f:
287
- self.bucket.download_fileobj(model_path, f)
315
+ from botocore import exceptions
316
+
317
+ log = (
318
+ f"Model '{target_filename}' not found in the specified bucket "
319
+ f"'{self.bucket_name}'. Please make sure the model exists "
320
+ f"in the bucket."
321
+ )
322
+
323
+ try:
324
+ with open(target_filename, "wb") as f:
325
+ self.bucket.download_fileobj(target_filename, f)
326
+
327
+ structlogger.debug(
328
+ "aws_persistor.retrieve_tar.object_found", object_key=target_filename
329
+ )
330
+ except exceptions.ClientError as exc:
331
+ if self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
332
+ structlogger.error(
333
+ "aws_persistor.retrieve_tar.model_not_found",
334
+ bucket_name=self.bucket_name,
335
+ target_filename=target_filename,
336
+ event_info=log,
337
+ )
338
+ raise ModelNotFound() from exc
339
+ except exceptions.BotoCoreError as exc:
340
+ structlogger.error(
341
+ "aws_persistor.retrieve_tar.model_download_error",
342
+ bucket_name=self.bucket_name,
343
+ target_filename=target_filename,
344
+ event_info=log,
345
+ )
346
+ raise ModelNotFound() from exc
288
347
 
289
348
 
290
349
  class GCSPersistor(Persistor):
@@ -309,42 +368,95 @@ class GCSPersistor(Persistor):
309
368
 
310
369
  def _ensure_bucket_exists(self, bucket_name: Text) -> None:
311
370
  from google.cloud import exceptions
371
+ from google.auth import exceptions as auth_exceptions
312
372
 
313
373
  try:
314
374
  self.storage_client.get_bucket(bucket_name)
315
- except exceptions.NotFound:
375
+ except auth_exceptions.GoogleAuthError as exc:
376
+ log = (
377
+ f"An error occurred while authenticating with Google Cloud "
378
+ f"Storage. Please make sure you have the necessary credentials "
379
+ f"to access the bucket '{bucket_name}'."
380
+ )
381
+ structlogger.error(
382
+ "gcp_persistor.ensure_bucket_exists.authentication_error",
383
+ bucket_name=bucket_name,
384
+ event_info=log,
385
+ )
386
+ raise RasaException(log) from exc
387
+ except exceptions.NotFound as exc:
316
388
  log = (
317
- f"The specified bucket '{bucket_name}' does not exist. "
318
- "Please make sure to create the bucket first."
389
+ f"The specified Google Cloud Storage bucket '{bucket_name}' "
390
+ f"does not exist. Please make sure to create the bucket first or "
391
+ f"provide an alternative valid bucket name."
319
392
  )
320
393
  structlogger.error(
321
394
  "gcp_persistor.ensure_bucket_exists.bucket_not_found",
322
395
  bucket_name=bucket_name,
323
396
  event_info=log,
324
397
  )
325
- raise RasaException(log)
326
- except exceptions.Forbidden:
398
+ raise RasaException(log) from exc
399
+ except exceptions.Forbidden as exc:
327
400
  log = (
328
- f"Access to the specified bucket '{bucket_name}' is forbidden. "
329
- "Please make sure you have the necessary "
330
- "permission to access the bucket. "
401
+ f"Access to the specified Google Cloud storage bucket '{bucket_name}' "
402
+ f"is forbidden. Please make sure you have the necessary "
403
+ f"permissions to access the bucket. "
331
404
  )
332
405
  structlogger.error(
333
406
  "gcp_persistor.ensure_bucket_exists.bucket_access_forbidden",
334
407
  bucket_name=bucket_name,
335
408
  event_info=log,
336
409
  )
337
- raise RasaException(log)
410
+ raise RasaException(log) from exc
411
+ except ValueError as exc:
412
+ # bucket_name is None
413
+ log = (
414
+ "The specified Google Cloud Storage bucket name is None. Please "
415
+ "make sure to provide a valid bucket name."
416
+ )
417
+ structlogger.error(
418
+ "gcp_persistor.ensure_bucket_exists.bucket_name_none",
419
+ event_info=log,
420
+ )
421
+ raise RasaException(log) from exc
338
422
 
339
423
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
340
424
  """Uploads a model persisted in the `target_dir` to GCS."""
341
425
  blob = self.bucket.blob(file_key)
342
426
  blob.upload_from_filename(tar_path)
343
427
 
428
+ def _retrieve_tar_size(self, target_filename: Text) -> int:
429
+ """Returns the size of the model that has been persisted to GCS."""
430
+ try:
431
+ blob = self.bucket.blob(target_filename)
432
+ return blob.size
433
+ except Exception:
434
+ raise ModelNotFound()
435
+
344
436
  def _retrieve_tar(self, target_filename: Text) -> None:
345
437
  """Downloads a model that has previously been persisted to GCS."""
438
+ from google.api_core import exceptions
439
+
346
440
  blob = self.bucket.blob(target_filename)
347
- blob.download_to_filename(os.path.basename(target_filename))
441
+ try:
442
+ blob.download_to_filename(target_filename)
443
+
444
+ structlogger.debug(
445
+ "gcs_persistor.retrieve_tar.object_found", object_key=target_filename
446
+ )
447
+ except exceptions.NotFound as exc:
448
+ log = (
449
+ f"Model '{target_filename}' not found in the specified bucket "
450
+ f"'{self.bucket_name}'. Please make sure the model exists "
451
+ f"in the bucket."
452
+ )
453
+ structlogger.error(
454
+ "gcp_persistor.retrieve_tar.model_not_found",
455
+ bucket_name=self.bucket_name,
456
+ target_filename=target_filename,
457
+ event_info=log,
458
+ )
459
+ raise ModelNotFound() from exc
348
460
 
349
461
 
350
462
  class AzurePersistor(Persistor):
@@ -370,7 +482,8 @@ class AzurePersistor(Persistor):
370
482
  else:
371
483
  log = (
372
484
  f"The specified container '{self.container_name}' does not exist."
373
- "Please make sure to create the container first."
485
+ "Please make sure to create the bucket first or "
486
+ f"provide an alternative valid bucket name."
374
487
  )
375
488
  structlogger.error(
376
489
  "azure_persistor.ensure_container_exists.container_not_found",
@@ -385,19 +498,41 @@ class AzurePersistor(Persistor):
385
498
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
386
499
  """Uploads a model persisted in the `target_dir` to Azure."""
387
500
  with open(tar_path, "rb") as data:
388
- self._container_client().upload_blob(
389
- name=file_key,
390
- data=data,
391
- # overwrite is set to True to keep in line with
392
- # how GCS and AWS APIs work this enables easy
393
- # updating of models in the cloud
394
- overwrite=True,
395
- )
501
+ self._container_client().upload_blob(name=file_key, data=data)
502
+
503
+ def _retrieve_tar_size(self, target_filename: Text) -> int:
504
+ """Returns the size of the model that has been persisted to Azure."""
505
+ try:
506
+ blob_client = self._container_client().get_blob_client(target_filename)
507
+ properties = blob_client.get_blob_properties()
508
+ return properties.size
509
+ except Exception:
510
+ raise ModelNotFound()
396
511
 
397
512
  def _retrieve_tar(self, target_filename: Text) -> None:
398
513
  """Downloads a model that has previously been persisted to Azure."""
399
- blob_client = self._container_client().get_blob_client(target_filename)
514
+ from azure.core.exceptions import AzureError
400
515
 
401
- with open(os.path.basename(target_filename), "wb") as blob:
402
- download_stream = blob_client.download_blob()
403
- blob.write(download_stream.readall())
516
+ try:
517
+ with open(target_filename, "wb") as model_file:
518
+ blob_client = self._container_client().get_blob_client(target_filename)
519
+ download_stream = blob_client.download_blob()
520
+ model_file.write(download_stream.readall())
521
+ structlogger.debug(
522
+ "azure_persistor.retrieve_tar.blob_found", blob_name=target_filename
523
+ )
524
+ except AzureError as exc:
525
+ log = (
526
+ f"An exception occurred while trying to download "
527
+ f"the model '{target_filename}' in the specified container "
528
+ f"'{self.container_name}'. Please make sure the model exists "
529
+ f"in the container."
530
+ )
531
+ structlogger.error(
532
+ "azure_persistor.retrieve_tar.model_download_error",
533
+ container_name=self.container_name,
534
+ target_filename=target_filename,
535
+ event_info=log,
536
+ exception=exc,
537
+ )
538
+ raise ModelNotFound() from exc