rasa-pro 3.9.18__py3-none-any.whl → 3.10.3__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic.

Files changed (189)
  1. README.md +26 -57
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/inspector/dist/index.html +0 -2
  32. rasa/core/channels/inspector/index.html +0 -2
  33. rasa/core/channels/voice_aware/__init__.py +0 -0
  34. rasa/core/channels/voice_aware/jambonz.py +103 -0
  35. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  36. rasa/core/channels/voice_aware/utils.py +20 -0
  37. rasa/core/channels/voice_native/__init__.py +0 -0
  38. rasa/core/constants.py +6 -1
  39. rasa/core/featurizers/single_state_featurizer.py +1 -22
  40. rasa/core/featurizers/tracker_featurizers.py +18 -115
  41. rasa/core/information_retrieval/faiss.py +7 -4
  42. rasa/core/information_retrieval/information_retrieval.py +8 -0
  43. rasa/core/information_retrieval/milvus.py +9 -2
  44. rasa/core/information_retrieval/qdrant.py +1 -1
  45. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  46. rasa/core/nlg/summarize.py +4 -3
  47. rasa/core/policies/enterprise_search_policy.py +100 -44
  48. rasa/core/policies/flows/flow_executor.py +130 -94
  49. rasa/core/policies/intentless_policy.py +52 -28
  50. rasa/core/policies/ted_policy.py +33 -58
  51. rasa/core/policies/unexpected_intent_policy.py +7 -15
  52. rasa/core/processor.py +20 -53
  53. rasa/core/run.py +5 -4
  54. rasa/core/tracker_store.py +8 -4
  55. rasa/core/utils.py +45 -56
  56. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  57. rasa/dialogue_understanding/commands/__init__.py +4 -0
  58. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  59. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  60. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  61. rasa/dialogue_understanding/commands/utils.py +38 -0
  62. rasa/dialogue_understanding/generator/constants.py +10 -3
  63. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  64. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  65. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  66. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  67. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  68. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  69. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  70. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  71. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  72. rasa/e2e_test/assertions.py +1181 -0
  73. rasa/e2e_test/assertions_schema.yml +106 -0
  74. rasa/e2e_test/constants.py +20 -0
  75. rasa/e2e_test/e2e_config.py +220 -0
  76. rasa/e2e_test/e2e_config_schema.yml +26 -0
  77. rasa/e2e_test/e2e_test_case.py +131 -8
  78. rasa/e2e_test/e2e_test_converter.py +363 -0
  79. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  80. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  81. rasa/e2e_test/e2e_test_result.py +26 -6
  82. rasa/e2e_test/e2e_test_runner.py +491 -72
  83. rasa/e2e_test/e2e_test_schema.yml +96 -0
  84. rasa/e2e_test/pykwalify_extensions.py +39 -0
  85. rasa/e2e_test/stub_custom_action.py +70 -0
  86. rasa/e2e_test/utils/__init__.py +0 -0
  87. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  88. rasa/e2e_test/utils/io.py +596 -0
  89. rasa/e2e_test/utils/validation.py +80 -0
  90. rasa/engine/recipes/default_components.py +0 -2
  91. rasa/engine/storage/local_model_storage.py +0 -1
  92. rasa/env.py +9 -0
  93. rasa/llm_fine_tuning/__init__.py +0 -0
  94. rasa/llm_fine_tuning/annotation_module.py +241 -0
  95. rasa/llm_fine_tuning/conversations.py +144 -0
  96. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  97. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  98. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  99. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  100. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  101. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  102. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  103. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  104. rasa/llm_fine_tuning/storage.py +174 -0
  105. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  106. rasa/model_training.py +48 -16
  107. rasa/nlu/classifiers/diet_classifier.py +25 -38
  108. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  109. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  110. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  111. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  112. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  113. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  114. rasa/nlu/persistor.py +129 -32
  115. rasa/server.py +45 -10
  116. rasa/shared/constants.py +63 -15
  117. rasa/shared/core/domain.py +15 -12
  118. rasa/shared/core/events.py +28 -2
  119. rasa/shared/core/flows/flow.py +208 -13
  120. rasa/shared/core/flows/flow_path.py +84 -0
  121. rasa/shared/core/flows/flows_list.py +28 -10
  122. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  123. rasa/shared/core/flows/validation.py +112 -25
  124. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  125. rasa/shared/core/trackers.py +6 -0
  126. rasa/shared/core/training_data/visualization.html +2 -2
  127. rasa/shared/exceptions.py +4 -0
  128. rasa/shared/importers/importer.py +60 -11
  129. rasa/shared/importers/remote_importer.py +196 -0
  130. rasa/shared/nlu/constants.py +2 -0
  131. rasa/shared/nlu/training_data/features.py +2 -120
  132. rasa/shared/providers/_configs/__init__.py +0 -0
  133. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  134. rasa/shared/providers/_configs/client_config.py +57 -0
  135. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  136. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  137. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  138. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  139. rasa/shared/providers/_configs/utils.py +101 -0
  140. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  141. rasa/shared/providers/embedding/__init__.py +0 -0
  142. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  143. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  144. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  145. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  146. rasa/shared/providers/embedding/embedding_client.py +90 -0
  147. rasa/shared/providers/embedding/embedding_response.py +41 -0
  148. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  149. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  150. rasa/shared/providers/llm/__init__.py +0 -0
  151. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  152. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  153. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  154. rasa/shared/providers/llm/llm_client.py +76 -0
  155. rasa/shared/providers/llm/llm_response.py +50 -0
  156. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  157. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  158. rasa/shared/providers/mappings.py +75 -0
  159. rasa/shared/utils/cli.py +30 -0
  160. rasa/shared/utils/io.py +65 -3
  161. rasa/shared/utils/llm.py +223 -200
  162. rasa/shared/utils/yaml.py +122 -7
  163. rasa/studio/download.py +19 -13
  164. rasa/studio/train.py +2 -3
  165. rasa/studio/upload.py +2 -3
  166. rasa/telemetry.py +113 -58
  167. rasa/tracing/config.py +2 -3
  168. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  169. rasa/tracing/instrumentation/instrumentation.py +4 -47
  170. rasa/utils/common.py +18 -19
  171. rasa/utils/endpoints.py +7 -4
  172. rasa/utils/io.py +66 -0
  173. rasa/utils/json_utils.py +60 -0
  174. rasa/utils/licensing.py +9 -1
  175. rasa/utils/ml_utils.py +4 -2
  176. rasa/utils/tensorflow/model_data.py +193 -2
  177. rasa/validator.py +195 -1
  178. rasa/version.py +1 -1
  179. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +47 -72
  180. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +185 -121
  181. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  182. rasa/shared/providers/openai/clients.py +0 -43
  183. rasa/shared/providers/openai/session_handler.py +0 -110
  184. rasa/utils/tensorflow/feature_array.py +0 -366
  185. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  186. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  187. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
  188. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
  189. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0

rasa/shared/importers/remote_importer.py (new file)
@@ -0,0 +1,196 @@
+import os
+from typing import Dict, List, Optional, Text, Union
+
+import structlog
+from tarsafe import TarSafe
+
+import rasa.shared.core.flows.yaml_flows_io
+import rasa.shared.data
+import rasa.shared.utils.common
+import rasa.shared.utils.io
+from rasa.nlu.persistor import StorageType
+from rasa.shared.core.domain import Domain, InvalidDomain
+from rasa.shared.core.flows import FlowsList
+from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
+    YAMLStoryReader,
+)
+from rasa.shared.core.training_data.structures import StoryGraph
+from rasa.shared.exceptions import RasaException
+from rasa.shared.importers import utils
+from rasa.shared.importers.importer import TrainingDataImporter
+from rasa.shared.nlu.training_data.training_data import TrainingData
+from rasa.shared.utils.yaml import read_model_configuration
+
+structlogger = structlog.get_logger()
+
+TRAINING_DATA_ARCHIVE = "training_data.tar.gz"
+
+
+class RemoteTrainingDataImporter(TrainingDataImporter):
+    """Remote `TrainingFileImporter` implementation.
+
+    Fetches training data from a remote storage and extracts it to a local directory.
+    Extracted training data is then used to load flows, NLU, stories,
+    domain, and config files.
+    """
+
+    def __init__(
+        self,
+        config_file: Optional[Text] = None,
+        domain_path: Optional[Text] = None,
+        training_data_paths: Optional[Union[List[Text], Text]] = None,
+        project_directory: Optional[Text] = None,
+        remote_storage: Optional[StorageType] = None,
+        training_data_path: Optional[Text] = None,
+    ):
+        """Initializes `RemoteTrainingDataImporter`.
+
+        Args:
+            config_file: Path to the model configuration file.
+            domain_path: Path to the domain file.
+            training_data_paths: List of paths to the training data files.
+            project_directory: Path to the project directory.
+            remote_storage: Storage to use to load the training data.
+            training_data_path: Path to the training data.
+        """
+        self.remote_storage = remote_storage
+        self.training_data_path = training_data_path
+
+        self.extracted_path = self._fetch_and_extract_training_archive(
+            TRAINING_DATA_ARCHIVE, self.training_data_path
+        )
+
+        self._nlu_files = rasa.shared.data.get_data_files(
+            self.extracted_path, rasa.shared.data.is_nlu_file
+        )
+        self._story_files = rasa.shared.data.get_data_files(
+            self.extracted_path, YAMLStoryReader.is_stories_file
+        )
+        self._flow_files = rasa.shared.data.get_data_files(
+            self.extracted_path, rasa.shared.core.flows.yaml_flows_io.is_flows_file
+        )
+        self._conversation_test_files = rasa.shared.data.get_data_files(
+            self.extracted_path, YAMLStoryReader.is_test_stories_file
+        )
+
+        self.config_file = config_file
+
+    def _fetch_training_archive(
+        self, training_file: str, training_data_path: Optional[str] = None
+    ) -> str:
+        """Fetches training files from remote storage."""
+        from rasa.nlu.persistor import get_persistor
+
+        persistor = get_persistor(self.remote_storage)
+        if persistor is None:
+            raise RasaException(
+                f"Could not find a persistor for "
+                f"the storage type '{self.remote_storage}'."
+            )
+
+        return persistor.retrieve(training_file, training_data_path)
+
+    def _fetch_and_extract_training_archive(
+        self, training_file: str, training_data_path: Optional[Text] = None
+    ) -> Optional[str]:
+        """Fetches and extracts training files from remote storage.
+
+        If the `training_data_path` is not provided, the training
+        data is extracted to the current working directory.
+
+        Args:
+            training_file: Name of the training data archive file.
+            training_data_path: Path to the training data.
+
+        Returns:
+            Path to the extracted training data.
+        """
+
+        if training_data_path is None:
+            training_data_path = os.path.join(os.getcwd(), "data")
+
+        if os.path.isfile(training_data_path):
+            raise ValueError(
+                f"Training data path '{training_data_path}' is a file. "
+                f"Please provide a directory path."
+            )
+
+        structlogger.debug(
+            "rasa.importers.remote_training_data_importer.fetch_training_archive",
+            training_data_path=training_data_path,
+        )
+        training_archive_file_path = self._fetch_training_archive(
+            training_file, training_data_path
+        )
+
+        if not os.path.isfile(training_archive_file_path):
+            raise FileNotFoundError(
+                f"Training data archive '{training_archive_file_path}' not found. "
+                f"Please make sure to provide the correct path."
+            )
+
+        structlogger.debug(
+            "rasa.importers.remote_training_data_importer.extract_training_archive",
+            training_archive_file_path=training_archive_file_path,
+            training_data_path=training_data_path,
+        )
+        with TarSafe.open(training_archive_file_path, "r:gz") as tar:
+            tar.extractall(path=training_data_path)
+
+        structlogger.debug(
+            "rasa.importers.remote_training_data_importer.remove_downloaded_archive",
+            training_data_path=training_data_path,
+        )
+        os.remove(training_archive_file_path)
+        return training_data_path
+
+    def get_config(self) -> Dict:
+        """Retrieves model config (see parent class for full docstring)."""
+        if not self.config_file or not os.path.exists(self.config_file):
+            structlogger.debug(
+                "rasa.importers.remote_training_data_importer.no_config_file",
+                message="No configuration file was provided to the RasaFileImporter.",
+            )
+            return {}
+
+        config = read_model_configuration(self.config_file)
+        return config
+
+    @rasa.shared.utils.common.cached_method
+    def get_config_file_for_auto_config(self) -> Optional[Text]:
+        """Returns config file path for auto-config only if there is a single one."""
+        return self.config_file
+
+    def get_stories(self, exclusion_percentage: Optional[int] = None) -> StoryGraph:
+        """Retrieves training stories / rules (see parent class for full docstring)."""
+        return utils.story_graph_from_paths(
+            self._story_files, self.get_domain(), exclusion_percentage
+        )
+
+    def get_flows(self) -> FlowsList:
+        """Retrieves training stories / rules (see parent class for full docstring)."""
+        return utils.flows_from_paths(self._flow_files)
+
+    def get_conversation_tests(self) -> StoryGraph:
+        """Retrieves conversation test stories (see parent class for full docstring)."""
+        return utils.story_graph_from_paths(
+            self._conversation_test_files, self.get_domain()
+        )
+
+    def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
+        """Retrieves NLU training data (see parent class for full docstring)."""
+        return utils.training_data_from_paths(self._nlu_files, language)
+
+    def get_domain(self) -> Domain:
+        """Retrieves model domain (see parent class for full docstring)."""
+        domain = Domain.empty()
+        domain_path = f"{self.extracted_path}"
+        try:
+            domain = Domain.load(domain_path)
+        except InvalidDomain as e:
+            rasa.shared.utils.io.raise_warning(
+                f"Loading domain from '{domain_path}' failed. Using "
+                f"empty domain. Error: '{e}'"
+            )
+
+        return domain
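
For orientation, here is a minimal usage sketch of the new importer. It is not taken from the package: the storage name, the paths, and the assumption that the persistor downloads training_data.tar.gz into the given directory are inferred from the hunk above.

# Hypothetical usage of RemoteTrainingDataImporter (values are placeholders).
from rasa.shared.importers.remote_importer import RemoteTrainingDataImporter

importer = RemoteTrainingDataImporter(
    config_file="config.yml",      # read lazily via get_config()
    remote_storage="aws",          # assumed storage type; a persistor must exist for it
    training_data_path="./data",   # archive is downloaded and extracted here
)

# The constructor has already fetched and unpacked training_data.tar.gz,
# so the standard importer accessors read from the extracted directory.
domain = importer.get_domain()
flows = importer.get_flows()
nlu_data = importer.get_nlu_data(language="en")
stories = importer.get_stories()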

rasa/shared/nlu/constants.py
@@ -2,6 +2,8 @@ TEXT = "text"
 TEXT_TOKENS = "text_tokens"
 INTENT = "intent"
 COMMANDS = "commands"
+LLM_COMMANDS = "llm_commands"  # needed for fine-tuning
+LLM_PROMPT = "llm_prompt"  # needed for fine-tuning
 FLOWS_FROM_SEMANTIC_SEARCH = "flows_from_semantic_search"
 FLOWS_IN_PROMPT = "flows_in_prompt"
 NOT_INTENT = "not_intent"

rasa/shared/nlu/training_data/features.py
@@ -1,133 +1,15 @@
 from __future__ import annotations
-
-import itertools
-from dataclasses import dataclass
 from typing import Iterable, Union, Text, Optional, List, Any, Tuple, Dict, Set
+import itertools
 
 import numpy as np
 import scipy.sparse
-from safetensors.numpy import save_file, load_file
 
-import rasa.shared.nlu.training_data.util
 import rasa.shared.utils.io
+import rasa.shared.nlu.training_data.util
 from rasa.shared.nlu.constants import FEATURE_TYPE_SEQUENCE, FEATURE_TYPE_SENTENCE
 
 
-@dataclass
-class FeatureMetadata:
-    data_type: str
-    attribute: str
-    origin: Union[str, List[str]]
-    is_sparse: bool
-    shape: tuple
-    safetensors_key: str
-
-
-def save_features(
-    features_dict: Dict[Text, List[Features]], file_name: str
-) -> Dict[str, Any]:
-    """Save a dictionary of Features lists to disk using safetensors.
-
-    Args:
-        features_dict: Dictionary mapping strings to lists of Features objects
-        file_name: File to save the features to
-
-    Returns:
-        The metadata to reconstruct the features.
-    """
-    # All tensors are stored in a single safetensors file
-    tensors_to_save = {}
-    # Metadata will be stored separately
-    metadata = {}
-
-    for key, features_list in features_dict.items():
-        feature_metadata_list = []
-
-        for idx, feature in enumerate(features_list):
-            # Create a unique key for this tensor in the safetensors file
-            safetensors_key = f"{key}_{idx}"
-
-            # Convert sparse matrices to dense if needed
-            if feature.is_sparse():
-                # For sparse matrices, use the COO format
-                coo = feature.features.tocoo()  # type:ignore[union-attr]
-                # Save data, row indices and col indices separately
-                tensors_to_save[f"{safetensors_key}_data"] = coo.data
-                tensors_to_save[f"{safetensors_key}_row"] = coo.row
-                tensors_to_save[f"{safetensors_key}_col"] = coo.col
-            else:
-                tensors_to_save[safetensors_key] = feature.features
-
-            # Store metadata
-            metadata_item = FeatureMetadata(
-                data_type=feature.type,
-                attribute=feature.attribute,
-                origin=feature.origin,
-                is_sparse=feature.is_sparse(),
-                shape=feature.features.shape,
-                safetensors_key=safetensors_key,
-            )
-            feature_metadata_list.append(vars(metadata_item))
-
-        metadata[key] = feature_metadata_list
-
-    # Save tensors
-    save_file(tensors_to_save, file_name)
-
-    return metadata
-
-
-def load_features(
-    filename: str, metadata: Dict[str, Any]
-) -> Dict[Text, List[Features]]:
-    """Load Features dictionary from disk.
-
-    Args:
-        filename: File name of the safetensors file.
-        metadata: Metadata to reconstruct the features.
-
-    Returns:
-        Dictionary mapping strings to lists of Features objects
-    """
-    # Load tensors
-    tensors = load_file(filename)
-
-    # Reconstruct the features dictionary
-    features_dict: Dict[Text, List[Features]] = {}
-
-    for key, feature_metadata_list in metadata.items():
-        features_list = []
-
-        for meta in feature_metadata_list:
-            safetensors_key = meta["safetensors_key"]
-
-            if meta["is_sparse"]:
-                # Reconstruct sparse matrix from COO format
-                data = tensors[f"{safetensors_key}_data"]
-                row = tensors[f"{safetensors_key}_row"]
-                col = tensors[f"{safetensors_key}_col"]
-
-                features_matrix = scipy.sparse.coo_matrix(
-                    (data, (row, col)), shape=tuple(meta["shape"])
-                ).tocsr()  # Convert back to CSR format
-            else:
-                features_matrix = tensors[safetensors_key]
-
-            # Reconstruct Features object
-            features = Features(
-                features=features_matrix,
-                feature_type=meta["data_type"],
-                attribute=meta["attribute"],
-                origin=meta["origin"],
-            )
-
-            features_list.append(features)
-
-        features_dict[key] = features_list
-
-    return features_dict
-
-
 class Features:
     """Stores the features produced by any featurizer."""
 

rasa/shared/providers/_configs/azure_openai_client_config.py (new file)
@@ -0,0 +1,181 @@
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Optional
+
+import structlog
+
+from rasa.shared.constants import (
+    MODEL_CONFIG_KEY,
+    MODEL_NAME_CONFIG_KEY,
+    OPENAI_API_BASE_CONFIG_KEY,
+    API_BASE_CONFIG_KEY,
+    OPENAI_API_TYPE_CONFIG_KEY,
+    API_TYPE_CONFIG_KEY,
+    OPENAI_API_VERSION_CONFIG_KEY,
+    API_VERSION_CONFIG_KEY,
+    DEPLOYMENT_CONFIG_KEY,
+    DEPLOYMENT_NAME_CONFIG_KEY,
+    ENGINE_CONFIG_KEY,
+    RASA_TYPE_CONFIG_KEY,
+    LANGCHAIN_TYPE_CONFIG_KEY,
+    STREAM_CONFIG_KEY,
+    N_REPHRASES_CONFIG_KEY,
+    REQUEST_TIMEOUT_CONFIG_KEY,
+    TIMEOUT_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    AZURE_OPENAI_PROVIDER,
+    AZURE_API_TYPE,
+)
+from rasa.shared.providers._configs.utils import (
+    resolve_aliases,
+    raise_deprecation_warnings,
+    validate_required_keys,
+    validate_forbidden_keys,
+)
+
+structlogger = structlog.get_logger()
+
+DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
+    # Deployment name aliases
+    DEPLOYMENT_NAME_CONFIG_KEY: DEPLOYMENT_CONFIG_KEY,
+    ENGINE_CONFIG_KEY: DEPLOYMENT_CONFIG_KEY,
+    # Provider aliases
+    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
+    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
+    # API type aliases
+    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
+    # API base aliases
+    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
+    # API version aliases
+    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
+    # Model name aliases
+    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
+    # Timeout aliases
+    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
+}
+
+REQUIRED_KEYS = [DEPLOYMENT_CONFIG_KEY]
+
+FORBIDDEN_KEYS = [
+    STREAM_CONFIG_KEY,
+    N_REPHRASES_CONFIG_KEY,
+]
+
+
+@dataclass
+class AzureOpenAIClientConfig:
+    """Parses configuration for Azure OpenAI client, resolves aliases and
+    raises deprecation warnings.
+
+    Raises:
+        ValueError: Raised in cases of invalid configuration:
+            - If any of the required configuration keys are missing.
+            - If `api_type` has a value different from `azure`.
+    """
+
+    deployment: str
+
+    model: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    # API Type is not used by LiteLLM backend, but we define
+    # it here for backward compatibility.
+    api_type: Optional[str] = AZURE_API_TYPE
+
+    # Provider is not used by LiteLLM backend, but we define it here since it's
+    # used as switch between different clients.
+    provider: str = AZURE_OPENAI_PROVIDER
+
+    extra_parameters: dict = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        if self.provider != AZURE_OPENAI_PROVIDER:
+            message = f"Provider must be set to '{AZURE_OPENAI_PROVIDER}'."
+            structlogger.error(
+                "azure_openai_client_config.validation_error",
+                message=message,
+                provider=self.provider,
+            )
+            raise ValueError(message)
+        if self.deployment is None:
+            message = "Deployment cannot be set to None."
+            structlogger.error(
+                "azure_openai_client_config.validation_error",
+                message=message,
+                deployment=self.deployment,
+            )
+            raise ValueError(message)
+
+    @classmethod
+    def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
+        """
+        Initializes a dataclass from the passed config.
+
+        Args:
+            config: (dict) The config from which to initialize.
+
+        Raises:
+            ValueError: Raised in cases of invalid configuration:
+                - If any of the required configuration keys are missing.
+                - If `api_type` has a value different from `azure`.
+
+        Returns:
+            AzureOpenAIClientConfig
+        """
+        # Check for deprecated keys
+        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
+        # Resolve any potential aliases
+        config = cls.resolve_config_aliases(config)
+        # Validate that required keys are set
+        validate_required_keys(config, REQUIRED_KEYS)
+        # Validate that the forbidden keys are not present
+        validate_forbidden_keys(config, FORBIDDEN_KEYS)
+        # Init client config
+        this = AzureOpenAIClientConfig(
+            # Required parameters
+            deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
+            # Pop the 'provider' key. Currently, it's *optional* because of
+            # backward compatibility with older versions.
+            provider=config.pop(PROVIDER_CONFIG_KEY, AZURE_OPENAI_PROVIDER),
+            # Optional
+            api_type=config.pop(API_TYPE_CONFIG_KEY, AZURE_API_TYPE),
+            model=config.pop(MODEL_CONFIG_KEY, None),
+            # Optional, can also be set through environment variables
+            # in clients.
+            api_base=config.pop(API_BASE_CONFIG_KEY, None),
+            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
+            # The rest of parameters (e.g. model parameters) are considered
+            # as extra parameters (this also includes timeout).
+            extra_parameters=config,
+        )
+        return this
+
+    def to_dict(self) -> dict:
+        """Converts the config instance into a dictionary."""
+        d = asdict(self)
+        # Extra parameters should also be on the top level
+        d.pop("extra_parameters", None)
+        d.update(self.extra_parameters)
+        return d
+
+    @staticmethod
+    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
+        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
+
+
+def is_azure_openai_config(config: dict) -> bool:
+    """Check whether the configuration is meant to configure
+    an Azure OpenAI client.
+    """
+    # Resolve any aliases that are specific to Azure OpenAI configuration
+    config = AzureOpenAIClientConfig.resolve_config_aliases(config)
+
+    # Case: Configuration contains `provider: azure`.
+    if config.get(PROVIDER_CONFIG_KEY) == AZURE_OPENAI_PROVIDER:
+        return True
+
+    # Case: Configuration contains `deployment` key
+    # (specific to Azure OpenAI configuration)
+    if config.get(DEPLOYMENT_CONFIG_KEY) is not None:
+        return True
+
+    return False
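
To make the alias handling above concrete, a small hypothetical example follows. It assumes resolve_aliases simply renames deprecated keys to their standard counterparts and that leftover keys end up in extra_parameters, which is what the code above suggests; the values are placeholders.

# Hypothetical round trip through AzureOpenAIClientConfig (not from the package).
from rasa.shared.constants import ENGINE_CONFIG_KEY, MODEL_NAME_CONFIG_KEY
from rasa.shared.providers._configs.azure_openai_client_config import (
    AzureOpenAIClientConfig,
    is_azure_openai_config,
)

legacy_config = {
    ENGINE_CONFIG_KEY: "my-gpt-4-deployment",  # deprecated alias for `deployment`
    MODEL_NAME_CONFIG_KEY: "gpt-4",            # deprecated alias for `model`
    "temperature": 0.3,                        # unknown keys stay as extra parameters
}

assert is_azure_openai_config(dict(legacy_config))  # resolved `deployment` key is enough

config = AzureOpenAIClientConfig.from_dict(dict(legacy_config))
assert config.deployment == "my-gpt-4-deployment"
assert config.model == "gpt-4"
assert config.extra_parameters == {"temperature": 0.3}

# to_dict() flattens extra parameters back onto the top level.
assert config.to_dict()["temperature"] == 0.3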

rasa/shared/providers/_configs/client_config.py (new file)
@@ -0,0 +1,57 @@
+from typing import Protocol, runtime_checkable
+
+
+@runtime_checkable
+class ClientConfig(Protocol):
+    """
+    Protocol for the client config that specifies the interface for interacting
+    with the API.
+    """
+
+    @classmethod
+    def from_dict(cls, config: dict) -> "ClientConfig":
+        """
+        Initializes the client config with the given configuration.
+
+        This class method should be implemented to parse the given
+        configuration and create an instance of an client config.
+
+        Args:
+            config: (dict) The config from which to initialize.
+
+        Raises:
+            ValueError: Config is missing required keys.
+
+        Returns:
+            ClientConfig
+        """
+        ...
+
+    def to_dict(self) -> dict:
+        """
+        Returns the configuration for that the client config is initialized with.
+
+        This method should be implemented to return a dictionary containing
+        the configuration settings for the client config.
+
+        Returns:
+            dictionary containing the configuration settings for the client config.
+        """
+        ...
+
+    @staticmethod
+    def resolve_config_aliases(config: dict) -> dict:
+        """
+        Resolve any potential aliases in the configuration.
+
+        This method should be implemented to resolve any potential aliases in the
+        configuration.
+
+        Args:
+            config: (dict) The config from which to initialize.
+
+        Returns:
+            dictionary containing the resolved configuration settings for the
+            client config.
+        """
+        ...
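
Because the protocol is decorated with @runtime_checkable, concrete config classes such as the Azure OpenAI one above can be recognised with a plain isinstance check. Below is a minimal hypothetical implementation for illustration only; the class and its keys are invented and not part of the package.

# Hypothetical config class that structurally satisfies the ClientConfig protocol.
from dataclasses import dataclass

from rasa.shared.providers._configs.client_config import ClientConfig


@dataclass
class DummyClientConfig:
    model: str

    @classmethod
    def from_dict(cls, config: dict) -> "DummyClientConfig":
        config = cls.resolve_config_aliases(dict(config))
        return cls(model=config["model"])

    def to_dict(self) -> dict:
        return {"model": self.model}

    @staticmethod
    def resolve_config_aliases(config: dict) -> dict:
        # Map a made-up deprecated key onto the standard one.
        if "model_name" in config:
            config["model"] = config.pop("model_name")
        return config


config = DummyClientConfig.from_dict({"model_name": "gpt-4"})
assert isinstance(config, ClientConfig)  # structural check, methods only
assert config.to_dict() == {"model": "gpt-4"}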