rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff shows the contents of publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. See the original diff page's advisory link for more details.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/model_training.py CHANGED
@@ -1,12 +1,20 @@
1
1
  import sys
2
2
  import time
3
3
  from pathlib import Path
4
- from typing import Text, NamedTuple, Optional, List, Union, Dict, Any
4
+ from typing import Any, Dict, List, NamedTuple, Optional, Text, Union
5
5
 
6
6
  import randomname
7
7
  import structlog
8
8
 
9
9
  import rasa.engine.validation
10
+ import rasa.model
11
+ import rasa.shared.constants
12
+ import rasa.shared.exceptions
13
+ import rasa.shared.utils.cli
14
+ import rasa.shared.utils.common
15
+ import rasa.shared.utils.io
16
+ import rasa.utils.common
17
+ from rasa import telemetry
10
18
  from rasa.engine.caching import LocalTrainingCache
11
19
  from rasa.engine.recipes.recipe import Recipe
12
20
  from rasa.engine.runner.dask import DaskGraphRunner
@@ -14,19 +22,13 @@ from rasa.engine.storage.local_model_storage import LocalModelStorage
14
22
  from rasa.engine.storage.storage import ModelStorage
15
23
  from rasa.engine.training.components import FingerprintStatus
16
24
  from rasa.engine.training.graph_trainer import GraphTrainer
25
+ from rasa.nlu.persistor import RemoteStorageType, StorageType
26
+ from rasa.shared.core.domain import Domain
17
27
  from rasa.shared.core.events import SlotSet
18
28
  from rasa.shared.core.training_data.structures import StoryGraph
19
29
  from rasa.shared.data import TrainingType
30
+ from rasa.shared.exceptions import RasaException
20
31
  from rasa.shared.importers.importer import TrainingDataImporter
21
- from rasa import telemetry
22
- from rasa.shared.core.domain import Domain
23
- import rasa.utils.common
24
- import rasa.shared.utils.common
25
- import rasa.shared.utils.cli
26
- import rasa.shared.exceptions
27
- import rasa.shared.utils.io
28
- import rasa.shared.constants
29
- import rasa.model
30
32
 
31
33
  CODE_NEEDS_TO_BE_RETRAINED = 0b0001
32
34
  CODE_FORCED_TRAINING = 0b1000
@@ -153,6 +155,7 @@ async def train(
153
155
  nlu_additional_arguments: Optional[Dict] = None,
154
156
  model_to_finetune: Optional[Text] = None,
155
157
  finetuning_epoch_fraction: float = 1.0,
158
+ remote_storage: Optional[StorageType] = None,
156
159
  ) -> TrainingResult:
157
160
  """Trains a Rasa model (Core and NLU).
158
161
 
@@ -174,6 +177,7 @@ async def train(
174
177
  a directory in case the latest trained model should be used.
175
178
  finetuning_epoch_fraction: The fraction currently specified training epochs
176
179
  in the model configuration which should be used for finetuning.
180
+ remote_storage: The remote storage which should be used to store the model.
177
181
 
178
182
  Returns:
179
183
  An instance of `TrainingResult`.
@@ -253,6 +257,7 @@ async def train(
253
257
  persist_nlu_training_data=persist_nlu_training_data,
254
258
  finetuning_epoch_fraction=finetuning_epoch_fraction,
255
259
  dry_run=dry_run,
260
+ remote_storage=remote_storage,
256
261
  **(core_additional_arguments or {}),
257
262
  **(nlu_additional_arguments or {}),
258
263
  )
@@ -266,6 +271,7 @@ async def _train_graph(
266
271
  model_to_finetune: Optional[Union[Text, Path]] = None,
267
272
  force_full_training: bool = False,
268
273
  dry_run: bool = False,
274
+ remote_storage: Optional[StorageType] = None,
269
275
  **kwargs: Any,
270
276
  ) -> TrainingResult:
271
277
  if model_to_finetune:
@@ -306,6 +312,7 @@ async def _train_graph(
306
312
  rasa.engine.validation.validate_coexistance_routing_setup(
307
313
  domain, model_configuration, flows
308
314
  )
315
+ rasa.engine.validation.validate_model_client_configuration_setup(config)
309
316
  rasa.engine.validation.validate_flow_component_dependencies(
310
317
  flows, model_configuration
311
318
  )
@@ -341,12 +348,30 @@ async def _train_graph(
341
348
  force_retraining=force_full_training,
342
349
  is_finetuning=is_finetuning,
343
350
  )
344
- structlogger.info(
345
- "model_training.train.finished_training",
346
- event_info=(
347
- f"Your Rasa model is trained " f"and saved at '{full_model_path}'."
348
- ),
349
- )
351
+ if remote_storage:
352
+ push_model_to_remote_storage(full_model_path, remote_storage)
353
+ full_model_path.unlink()
354
+ remote_storage_string = (
355
+ remote_storage.value
356
+ if isinstance(remote_storage, RemoteStorageType)
357
+ else remote_storage
358
+ )
359
+ structlogger.info(
360
+ "model_training.train.finished_training",
361
+ event_info=(
362
+ f"Your Rasa model {model_name} is trained "
363
+ f"and saved at remote storage provider "
364
+ f"'{remote_storage_string}'."
365
+ ),
366
+ )
367
+ else:
368
+ structlogger.info(
369
+ "model_training.train.finished_training",
370
+ event_info=(
371
+ f"Your Rasa model is trained and saved at "
372
+ f"'{full_model_path}'."
373
+ ),
374
+ )
350
375
 
351
376
  return TrainingResult(str(full_model_path), 0)
352
377
 
@@ -534,3 +559,18 @@ async def train_nlu(
534
559
  **(additional_arguments or {}),
535
560
  )
536
561
  ).model
562
+
563
+
564
+ def push_model_to_remote_storage(model_path: Path, remote_storage: StorageType) -> None:
565
+ """push model to remote storage"""
566
+ from rasa.nlu.persistor import get_persistor
567
+
568
+ persistor = get_persistor(remote_storage)
569
+
570
+ if persistor is not None:
571
+ persistor.persist(str(model_path))
572
+
573
+ else:
574
+ raise RasaException(
575
+ f"Persistor not found for remote storage: '{remote_storage}'."
576
+ )
rasa/nlu/persistor.py CHANGED
@@ -1,13 +1,30 @@
1
+ from __future__ import annotations
2
+
1
3
  import abc
2
- import structlog
3
4
  import os
4
5
  import shutil
5
- from typing import Optional, Text, Tuple, TYPE_CHECKING
6
+ from enum import Enum
7
+ from typing import TYPE_CHECKING, List, Optional, Text, Tuple, Union
6
8
 
7
- from rasa.shared.exceptions import RasaException
9
+ import structlog
8
10
 
9
11
  import rasa.shared.utils.common
10
12
  import rasa.utils.common
13
+ from rasa.constants import (
14
+ HTTP_STATUS_FORBIDDEN,
15
+ HTTP_STATUS_NOT_FOUND,
16
+ MODEL_ARCHIVE_EXTENSION,
17
+ )
18
+ from rasa.env import (
19
+ AWS_ENDPOINT_URL_ENV,
20
+ AZURE_ACCOUNT_KEY_ENV,
21
+ AZURE_ACCOUNT_NAME_ENV,
22
+ AZURE_CONTAINER_ENV,
23
+ BUCKET_NAME_ENV,
24
+ REMOTE_STORAGE_PATH_ENV,
25
+ )
26
+ from rasa.shared.exceptions import RasaException
27
+ from rasa.shared.utils.io import raise_warning
11
28
 
12
29
  if TYPE_CHECKING:
13
30
  from azure.storage.blob import ContainerClient
@@ -15,34 +32,80 @@ if TYPE_CHECKING:
15
32
  structlogger = structlog.get_logger()
16
33
 
17
34
 
18
- def get_persistor(name: Text) -> Optional["Persistor"]:
35
+ class RemoteStorageType(Enum):
36
+ """Enum for the different remote storage types."""
37
+
38
+ AWS = "aws"
39
+ GCS = "gcs"
40
+ AZURE = "azure"
41
+
42
+ @classmethod
43
+ def list(cls) -> List[str]:
44
+ """Returns a list of all available storage types."""
45
+ return [item.value for item in cls]
46
+
47
+
48
+ """Storage can be a built-in one or a module path to a custom persistor."""
49
+ StorageType = Union[RemoteStorageType, str]
50
+
51
+
52
+ def parse_remote_storage(value: str) -> StorageType:
53
+ try:
54
+ return RemoteStorageType(value)
55
+ except ValueError:
56
+ # if the value is not a valid storage type,
57
+ # but it is a string we assume it is a custom class
58
+ # and return it as is
59
+
60
+ supported_storages_help_text = (
61
+ f"Supported storages are: {RemoteStorageType.list()} "
62
+ "or path to a python class which implements `Persistor` interface."
63
+ )
64
+
65
+ if isinstance(value, str):
66
+ if value == "":
67
+ raise RasaException(
68
+ f"The value can't be an empty string."
69
+ f" {supported_storages_help_text}"
70
+ )
71
+
72
+ return value
73
+
74
+ raise RasaException(
75
+ f"Invalid storage type '{value}'. {supported_storages_help_text}"
76
+ )
77
+
78
+
79
+ def get_persistor(storage: StorageType) -> Optional[Persistor]:
19
80
  """Returns an instance of the requested persistor.
20
81
 
21
82
  Currently, `aws`, `gcs`, `azure` and providing module paths are supported remote
22
83
  storages.
23
84
  """
24
- if name == "aws":
85
+ if storage == RemoteStorageType.AWS:
25
86
  return AWSPersistor(
26
- os.environ.get("BUCKET_NAME"), os.environ.get("AWS_ENDPOINT_URL")
87
+ os.environ.get(BUCKET_NAME_ENV), os.environ.get(AWS_ENDPOINT_URL_ENV)
27
88
  )
28
- if name == "gcs":
29
- return GCSPersistor(os.environ.get("BUCKET_NAME"))
89
+ if storage == RemoteStorageType.GCS:
90
+ return GCSPersistor(os.environ.get(BUCKET_NAME_ENV))
30
91
 
31
- if name == "azure":
92
+ if storage == RemoteStorageType.AZURE:
32
93
  return AzurePersistor(
33
- os.environ.get("AZURE_CONTAINER"),
34
- os.environ.get("AZURE_ACCOUNT_NAME"),
35
- os.environ.get("AZURE_ACCOUNT_KEY"),
94
+ os.environ.get(AZURE_CONTAINER_ENV),
95
+ os.environ.get(AZURE_ACCOUNT_NAME_ENV),
96
+ os.environ.get(AZURE_ACCOUNT_KEY_ENV),
36
97
  )
37
- if name:
98
+ # If the persistor is not a built-in one, it is assumed to be a module path
99
+ # to a persistor implementation supplied by the user.
100
+ if storage:
38
101
  try:
39
- persistor = rasa.shared.utils.common.class_from_module_path(name)
102
+ persistor = rasa.shared.utils.common.class_from_module_path(storage)
40
103
  return persistor()
41
104
  except ImportError:
42
105
  raise ImportError(
43
- f"Unknown model persistor {name}. Please make sure to "
44
- "either use an included model persistor (`aws`, `gcs` "
45
- "or `azure`) or specify the module path to an external "
106
+ f"Unknown model persistor {storage}. Please make sure to "
107
+ f"either use an included model persistor ({RemoteStorageType.list()}) "
108
+ f"or specify the module path to an external "
46
109
  "model persistor."
47
110
  )
48
111
  return None
@@ -51,24 +114,46 @@ def get_persistor(name: Text) -> Optional["Persistor"]:
51
114
  class Persistor(abc.ABC):
52
115
  """Store models in cloud and fetch them when needed."""
53
116
 
54
- def persist(self, model_directory: Text, model_name: Text) -> None:
55
- """Uploads a model persisted in the `target_dir` to cloud storage."""
56
- if not os.path.isdir(model_directory):
57
- raise ValueError(f"Target directory '{model_directory}' not found.")
117
+ def persist(self, trained_model: str) -> None:
118
+ """Uploads a trained model persisted in the `target_dir` to cloud storage."""
119
+ file_key = self._create_file_key(trained_model)
120
+ self._persist_tar(file_key, trained_model)
58
121
 
59
- file_key, tar_path = self._compress(model_directory, model_name)
60
- self._persist_tar(file_key, tar_path)
122
+ def retrieve(self, model_name: Text, target_path: Text) -> Text:
123
+ """Downloads a model that has been persisted to cloud storage.
61
124
 
62
- def retrieve(self, model_name: Text, target_path: Text) -> None:
63
- """Downloads a model that has been persisted to cloud storage."""
64
- tar_name = model_name
125
+ Downloaded model will be saved to the `target_path`.
126
+ If `target_path` is a directory, the model will be downloaded to that directory.
127
+ If `target_path` is a file, the model will be downloaded to that file.
65
128
 
66
- if not model_name.endswith("tar.gz"):
129
+ Args:
130
+ model_name: The name of the model to retrieve.
131
+ target_path: The path to which the model should be downloaded.
132
+ """
133
+ tar_name = model_name
134
+ if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
67
135
  # ensure backward compatibility
68
136
  tar_name = self._tar_name(model_name)
137
+ remote_object_path = self._create_file_key(tar_name)
138
+ self._retrieve_tar(remote_object_path)
139
+
140
+ target_tar_file_name = os.path.basename(tar_name)
141
+ if os.path.isdir(target_path):
142
+ target_path = os.path.join(target_path, target_tar_file_name)
143
+
144
+ if not os.path.exists(target_path):
145
+ structlogger.debug(
146
+ "persistor.retrieve.copy_model",
147
+ event_info=f"Copying model '{target_tar_file_name}' to "
148
+ f"'{target_path}'.",
149
+ )
150
+ self._copy(target_tar_file_name, target_path)
69
151
 
70
- self._retrieve_tar(tar_name)
71
- self._copy(os.path.basename(tar_name), target_path)
152
+ structlogger.debug(
153
+ "persistor.retrieve.model_retrieved",
154
+ event_info=f"Model retrieved and saved to '{target_path}'.",
155
+ )
156
+ return target_path
72
157
 
73
158
  @abc.abstractmethod
74
159
  def _retrieve_tar(self, filename: Text) -> None:
@@ -97,13 +182,37 @@ class Persistor(abc.ABC):
97
182
 
98
183
  @staticmethod
99
184
  def _tar_name(model_name: Text, include_extension: bool = True) -> Text:
100
- ext = ".tar.gz" if include_extension else ""
185
+ ext = f".{MODEL_ARCHIVE_EXTENSION}" if include_extension else ""
101
186
  return f"{model_name}{ext}"
102
187
 
103
188
  @staticmethod
104
189
  def _copy(compressed_path: Text, target_path: Text) -> None:
105
190
  shutil.copy2(compressed_path, target_path)
106
191
 
192
+ @staticmethod
193
+ def _create_file_key(model_path: str) -> Text:
194
+ """Appends remote storage folders when provided to upload or retrieve file"""
195
+ bucket_object_path = os.environ.get(REMOTE_STORAGE_PATH_ENV)
196
+
197
+ # To keep the backward compatibility, if REMOTE_STORAGE_PATH is not provided,
198
+ # the model_name (which might be a complete path) will be returned as it is.
199
+ if bucket_object_path is None:
200
+ return str(model_path)
201
+ else:
202
+ raise_warning(
203
+ f"{REMOTE_STORAGE_PATH_ENV} is deprecated and will be "
204
+ "removed in future versions. "
205
+ "Please use the -m path/to/model.tar.gz option to "
206
+ "specify the model path when loading a model."
207
+ "Or use --output and --fixed-model-name to specify the "
208
+ "output directory and the model name when saving a "
209
+ "trained model to remote storage.",
210
+ )
211
+
212
+ file_key = os.path.basename(model_path)
213
+ file_key = os.path.join(bucket_object_path, file_key)
214
+ return file_key
215
+
107
216
 
108
217
  class AWSPersistor(Persistor):
109
218
  """Store models on S3.
@@ -137,7 +246,7 @@ class AWSPersistor(Persistor):
137
246
  self.s3.meta.client.head_bucket(Bucket=bucket_name)
138
247
  except botocore.exceptions.ClientError as e:
139
248
  error_code = int(e.response["Error"]["Code"])
140
- if error_code == 403:
249
+ if error_code == HTTP_STATUS_FORBIDDEN:
141
250
  log = (
142
251
  f"Access to the specified bucket '{bucket_name}' is forbidden. "
143
252
  "Please make sure you have the necessary "
@@ -149,7 +258,7 @@ class AWSPersistor(Persistor):
149
258
  event_info=log,
150
259
  )
151
260
  raise RasaException(log)
152
- elif error_code == 404:
261
+ elif error_code == HTTP_STATUS_NOT_FOUND:
153
262
  log = (
154
263
  f"The specified bucket '{bucket_name}' does not exist. "
155
264
  "Please make sure to create the bucket first."
@@ -163,6 +272,11 @@ class AWSPersistor(Persistor):
163
272
 
164
273
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
165
274
  """Uploads a model persisted in the `target_dir` to s3."""
275
+ structlogger.debug(
276
+ "aws_persistor.persist_tar.uploading_model",
277
+ event_info=f"Uploading tar archive {file_key} to "
278
+ f"s3 bucket '{self.bucket_name}'.",
279
+ )
166
280
  with open(tar_path, "rb") as f:
167
281
  self.s3.Object(self.bucket_name, file_key).put(Body=f)
168
282
 
@@ -183,7 +297,7 @@ class GCSPersistor(Persistor):
183
297
  """Initialise class with client and bucket."""
184
298
  # there are no type hints in this repo for now
185
299
  # https://github.com/googleapis/python-storage/issues/393
186
- from google.cloud import storage # type: ignore[attr-defined]
300
+ from google.cloud import storage
187
301
 
188
302
  super().__init__()
189
303
 
@@ -230,7 +344,7 @@ class GCSPersistor(Persistor):
230
344
  def _retrieve_tar(self, target_filename: Text) -> None:
231
345
  """Downloads a model that has previously been persisted to GCS."""
232
346
  blob = self.bucket.blob(target_filename)
233
- blob.download_to_filename(target_filename)
347
+ blob.download_to_filename(os.path.basename(target_filename))
234
348
 
235
349
 
236
350
  class AzurePersistor(Persistor):
@@ -271,12 +385,19 @@ class AzurePersistor(Persistor):
271
385
  def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
272
386
  """Uploads a model persisted in the `target_dir` to Azure."""
273
387
  with open(tar_path, "rb") as data:
274
- self._container_client().upload_blob(name=file_key, data=data)
388
+ self._container_client().upload_blob(
389
+ name=file_key,
390
+ data=data,
391
+ # overwrite is set to True to keep in line with
392
+ # how GCS and AWS APIs work this enables easy
393
+ # updating of models in the cloud
394
+ overwrite=True,
395
+ )
275
396
 
276
397
  def _retrieve_tar(self, target_filename: Text) -> None:
277
398
  """Downloads a model that has previously been persisted to Azure."""
278
399
  blob_client = self._container_client().get_blob_client(target_filename)
279
400
 
280
- with open(target_filename, "wb") as blob:
401
+ with open(os.path.basename(target_filename), "wb") as blob:
281
402
  download_stream = blob_client.download_blob()
282
403
  blob.write(download_stream.readall())
rasa/server.py CHANGED
@@ -11,17 +11,17 @@ from http import HTTPStatus
11
11
  from inspect import isawaitable
12
12
  from pathlib import Path
13
13
  from typing import (
14
+ TYPE_CHECKING,
14
15
  Any,
15
16
  Callable,
17
+ Coroutine,
16
18
  DefaultDict,
19
+ Dict,
17
20
  List,
21
+ NoReturn,
18
22
  Optional,
19
23
  Text,
20
24
  Union,
21
- Dict,
22
- TYPE_CHECKING,
23
- NoReturn,
24
- Coroutine,
25
25
  )
26
26
 
27
27
  import aiohttp
@@ -54,15 +54,16 @@ from rasa.core.test import test
54
54
  from rasa.core.utils import AvailableEndpoints
55
55
  from rasa.nlu.emulators.emulator import Emulator
56
56
  from rasa.nlu.emulators.no_emulator import NoEmulator
57
+ from rasa.nlu.persistor import parse_remote_storage
57
58
  from rasa.nlu.test import CVEvaluationResult
58
59
  from rasa.shared.constants import (
59
- DOCS_URL_TRAINING_DATA,
60
- DOCS_BASE_URL,
61
- DEFAULT_SENDER_ID,
62
60
  DEFAULT_MODELS_PATH,
61
+ DEFAULT_SENDER_ID,
62
+ DOCS_BASE_URL,
63
+ DOCS_URL_TRAINING_DATA,
63
64
  TEST_STORIES_FILE_PREFIX,
64
65
  )
65
- from rasa.shared.core.domain import InvalidDomain, Domain
66
+ from rasa.shared.core.domain import Domain, InvalidDomain
66
67
  from rasa.shared.core.events import Event
67
68
  from rasa.shared.core.trackers import (
68
69
  DialogueStateTracker,
@@ -80,8 +81,10 @@ from rasa.utils.endpoints import EndpointConfig
80
81
 
81
82
  if TYPE_CHECKING:
82
83
  from ssl import SSLContext
84
+
85
+ from mypy_extensions import Arg, KwArg, VarArg
86
+
83
87
  from rasa.core.processor import MessageProcessor
84
- from mypy_extensions import Arg, VarArg, KwArg
85
88
 
86
89
  SanicResponse = Union[
87
90
  response.HTTPResponse, Coroutine[Any, Any, response.HTTPResponse]
@@ -532,6 +535,32 @@ def add_root_route(app: Sanic) -> None:
532
535
  """
533
536
  return response.html(html_content)
534
537
 
538
+ @app.get("/license")
539
+ async def license(request: Request) -> HTTPResponse:
540
+ """Respond with the license expiration date."""
541
+ from rasa.utils.licensing import (
542
+ get_license_expiration_date,
543
+ property_of_active_license,
544
+ )
545
+
546
+ body = {
547
+ "id": property_of_active_license(lambda active_license: active_license.jti),
548
+ "company": property_of_active_license(
549
+ lambda active_license: active_license.company
550
+ ),
551
+ "scope": property_of_active_license(
552
+ lambda active_license: active_license.scope
553
+ ),
554
+ "email": property_of_active_license(
555
+ lambda active_license: active_license.email
556
+ ),
557
+ "expires": get_license_expiration_date(),
558
+ }
559
+ return response.json(
560
+ body=body,
561
+ headers={"Content-Type": "application/json"},
562
+ )
563
+
535
564
 
536
565
  def async_if_callback_url(f: Callable[..., Coroutine]) -> Callable:
537
566
  """Decorator to enable async request handling.
@@ -1351,7 +1380,13 @@ def create_app(
1351
1380
 
1352
1381
  model_path = request.json.get("model_file", None)
1353
1382
  model_server = request.json.get("model_server", None)
1354
- remote_storage = request.json.get("remote_storage", None)
1383
+
1384
+ remote_storage_argument = request.json.get("remote_storage", None)
1385
+ remote_storage = (
1386
+ parse_remote_storage(remote_storage_argument)
1387
+ if remote_storage_argument
1388
+ else None
1389
+ )
1355
1390
 
1356
1391
  if model_server:
1357
1392
  try:
rasa/shared/constants.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from typing import List, Text
4
4
 
5
+ from rasa.shared.engine.caching import get_local_cache_location
5
6
 
6
7
  DOCS_BASE_URL = "https://rasa.com/docs/rasa-pro"
7
8
  DOCS_URL_CONCEPTS = DOCS_BASE_URL + "/concepts"
@@ -51,6 +52,8 @@ MODEL_CONFIG_SCHEMA_FILE = "shared/utils/schemas/model_config.yml"
51
52
  CONFIG_SCHEMA_FILE = "shared/utils/schemas/config.yml"
52
53
  RESPONSES_SCHEMA_FILE = "shared/nlu/training_data/schemas/responses.yml"
53
54
  SCHEMA_EXTENSIONS_FILE = "shared/utils/pykwalify_extensions.py"
55
+ ASSERTIONS_SCHEMA_FILE = "e2e_test/assertions_schema.yml"
56
+ ASSERTIONS_SCHEMA_EXTENSIONS_FILE = "e2e_test/pykwalify_extensions.py"
54
57
  LATEST_TRAINING_DATA_FORMAT_VERSION = "3.1"
55
58
 
56
59
  DOMAIN_SCHEMA_FILE = "shared/utils/schemas/domain.yml"
@@ -83,6 +86,8 @@ ENV_LOG_LEVEL_LLM_MODULE_NAMES = {
83
86
  "EnterpriseSearchPolicy": "LOG_LEVEL_LLM_ENTERPRISE_SEARCH",
84
87
  "IntentlessPolicy": "LOG_LEVEL_LLM_INTENTLESS_POLICY",
85
88
  "ContextualResponseRephraser": "LOG_LEVEL_LLM_REPHRASER",
89
+ "NLUCommandAdapter": "LOG_LEVEL_NLU_COMMAND_ADAPTER",
90
+ "LLMBasedRouter": "LOG_LEVEL_LLM_BASED_ROUTER",
86
91
  }
87
92
  TCP_PROTOCOL = "TCP"
88
93
 
@@ -106,7 +111,10 @@ CONFIG_KEYS_NLU = ["language", "pipeline"] + CONFIG_MANDATORY_COMMON_KEYS
106
111
  CONFIG_KEYS = CONFIG_KEYS_CORE + CONFIG_KEYS_NLU
107
112
  CONFIG_MANDATORY_KEYS_CORE: List[Text] = [] + CONFIG_MANDATORY_COMMON_KEYS
108
113
  CONFIG_MANDATORY_KEYS_NLU = ["language"] + CONFIG_MANDATORY_COMMON_KEYS
109
- CONFIG_MANDATORY_KEYS = CONFIG_MANDATORY_KEYS_CORE + CONFIG_MANDATORY_KEYS_NLU
114
+ # we need the list to contain unique values
115
+ CONFIG_MANDATORY_KEYS = list(
116
+ set(CONFIG_MANDATORY_KEYS_CORE + CONFIG_MANDATORY_KEYS_NLU)
117
+ )
110
118
 
111
119
  # Keys related to Forms (in the Domain)
112
120
  REQUIRED_SLOTS_KEY = "required_slots"
@@ -137,34 +145,91 @@ DIAGNOSTIC_DATA = "diagnostic_data"
137
145
  RESPONSE_CONDITION = "condition"
138
146
  CHANNEL = "channel"
139
147
 
148
+ API_KEY = "api_key"
149
+
150
+ AZURE_API_KEY_ENV_VAR = "AZURE_API_KEY"
151
+ AZURE_AD_TOKEN_ENV_VAR = "AZURE_AD_TOKEN"
152
+ AZURE_API_BASE_ENV_VAR = "AZURE_API_BASE"
153
+ AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
154
+ AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
155
+
140
156
  OPENAI_API_KEY_ENV_VAR = "OPENAI_API_KEY"
141
157
  OPENAI_API_TYPE_ENV_VAR = "OPENAI_API_TYPE"
142
158
  OPENAI_API_VERSION_ENV_VAR = "OPENAI_API_VERSION"
143
159
  OPENAI_API_BASE_ENV_VAR = "OPENAI_API_BASE"
144
160
 
161
+ OPENAI_API_BASE_CONFIG_KEY = "openai_api_base"
145
162
  OPENAI_API_TYPE_CONFIG_KEY = "openai_api_type"
146
- OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY = "api_type"
147
-
148
163
  OPENAI_API_VERSION_CONFIG_KEY = "openai_api_version"
149
- OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY = "api_version"
150
164
 
151
- OPENAI_API_BASE_CONFIG_KEY = "openai_api_base"
152
- OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY = "api_base"
165
+ API_BASE_CONFIG_KEY = "api_base"
166
+ API_TYPE_CONFIG_KEY = "api_type"
167
+ API_VERSION_CONFIG_KEY = "api_version"
168
+ LANGCHAIN_TYPE_CONFIG_KEY = "_type"
169
+ RASA_TYPE_CONFIG_KEY = "type"
170
+ PROVIDER_CONFIG_KEY = "provider"
153
171
 
154
- OPENAI_DEPLOYMENT_NAME_CONFIG_KEY = "deployment_name"
155
- OPENAI_DEPLOYMENT_CONFIG_KEY = "deployment"
156
- OPENAI_ENGINE_CONFIG_KEY = "engine"
172
+ REQUEST_TIMEOUT_CONFIG_KEY = "request_timeout" # deprecated
173
+ TIMEOUT_CONFIG_KEY = "timeout"
157
174
 
158
- RASA_TYPE_CONFIG_KEY = "type"
159
- LANGCHAIN_TYPE_CONFIG_KEY = "_type"
175
+ DEPLOYMENT_NAME_CONFIG_KEY = "deployment_name"
176
+ DEPLOYMENT_CONFIG_KEY = "deployment"
177
+ EMBEDDINGS_CONFIG_KEY = "embeddings"
178
+ ENGINE_CONFIG_KEY = "engine"
179
+ LLM_CONFIG_KEY = "llm"
180
+ MODEL_CONFIG_KEY = "model"
181
+ MODEL_NAME_CONFIG_KEY = "model_name"
182
+ PROMPT_CONFIG_KEY = "prompt"
183
+ PROMPT_TEMPLATE_CONFIG_KEY = "prompt_template"
184
+
185
+ STREAM_CONFIG_KEY = "stream"
186
+ N_REPHRASES_CONFIG_KEY = "n"
187
+ USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY = "use_chat_completions_endpoint"
188
+
189
+ AZURE_API_KEY_ENV_VAR = "AZURE_API_KEY"
190
+ AZURE_AD_TOKEN_ENV_VAR = "AZURE_AD_TOKEN"
191
+ AZURE_API_BASE_ENV_VAR = "AZURE_API_BASE"
192
+ AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
193
+ AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
194
+
195
+ HUGGINGFACE_MULTIPROCESS_CONFIG_KEY = "multi_process"
196
+ HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY = "cache_folder"
197
+ HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY = "show_progress"
198
+ HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY = "model_kwargs"
199
+ HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY = "encode_kwargs"
200
+ HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER = (
201
+ get_local_cache_location() / "huggingface_local_embeddings"
202
+ )
160
203
 
161
204
  REQUESTS_CA_BUNDLE_ENV_VAR = "REQUESTS_CA_BUNDLE"
162
205
  REQUESTS_SSL_CONTEXT_PURPOSE_ENV_VAR = "REQUESTS_SSL_CONTEXT_PURPOSE"
206
+ RASA_CA_BUNDLE_ENV_VAR = "RASA_CA_BUNDLE" # used in verify
207
+ RASA_SSL_CERTIFICATE_ENV_VAR = "RASA_SSL_CERTIFICATE" # used in cert (client side)
208
+ LITELLM_SSL_VERIFY_ENV_VAR = "SSL_VERIFY"
209
+ LITELLM_SSL_CERTIFICATE_ENV_VAR = "SSL_CERTIFICATE"
210
+
211
+ OPENAI_PROVIDER = "openai"
212
+ AZURE_OPENAI_PROVIDER = "azure"
213
+ SELF_HOSTED_PROVIDER = "self-hosted"
214
+ HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER = "huggingface_local"
215
+
216
+ VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY = [
217
+ OPENAI_PROVIDER,
218
+ AZURE_OPENAI_PROVIDER,
219
+ ]
220
+
221
+ SELF_HOSTED_VLLM_PREFIX = "hosted_vllm"
222
+ SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"
163
223
 
224
+ AZURE_API_TYPE = "azure"
225
+ OPENAI_API_TYPE = "openai"
164
226
 
165
227
  RASA_DEFAULT_FLOW_PATTERN_PREFIX = "pattern_"
166
228
  CONTEXT = "context"
167
229
 
230
+ RASA_PATTERN_INTERNAL_ERROR = "pattern_internal_error"
231
+ RASA_PATTERN_HUMAN_HANDOFF = "pattern_human_handoff"
232
+
168
233
  RASA_INTERNAL_ERROR_PREFIX = "rasa_internal_error_"
169
234
  RASA_PATTERN_INTERNAL_ERROR_DEFAULT = RASA_INTERNAL_ERROR_PREFIX + "default"
170
235
  RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG = (
@@ -185,8 +250,3 @@ RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT = (
185
250
  )
186
251
 
187
252
  ROUTE_TO_CALM_SLOT = "route_session_to_calm"
188
- EMBEDDINGS_CONFIG_KEY = "embeddings"
189
- MODEL_CONFIG_KEY = "model"
190
- MODEL_NAME_CONFIG_KEY = "model_name"
191
- PROMPT_CONFIG_KEY = "prompt"
192
- PROMPT_TEMPLATE_CONFIG_KEY = "prompt_template"