rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,234 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ RASA_TYPE_CONFIG_KEY,
10
+ LANGCHAIN_TYPE_CONFIG_KEY,
11
+ HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
12
+ HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
13
+ HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
14
+ HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
15
+ HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY,
16
+ HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER,
17
+ PROVIDER_CONFIG_KEY,
18
+ HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
19
+ TIMEOUT_CONFIG_KEY,
20
+ REQUEST_TIMEOUT_CONFIG_KEY,
21
+ )
22
+ from rasa.shared.providers._configs.utils import (
23
+ resolve_aliases,
24
+ raise_deprecation_warnings,
25
+ validate_required_keys,
26
+ )
27
+ from rasa.shared.utils.io import raise_deprecation_warning
28
+
29
structlogger = structlog.get_logger()

# Deprecated value of the `type` / `_type` switch keys. It is translated to
# `provider: huggingface_local` by `_resolve_huggingface_deprecated_switch_value`.
DEPRECATED_HUGGINGFACE_TYPE = "huggingface"

# Deprecated config keys mapped to the standard key that replaces them.
# NOTE: `resolve_aliases` walks this mapping in insertion order and lets the
# alias value overwrite the standard key, so entry order matters when several
# aliases of the same standard key are present (the later entry wins).
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present (after alias resolution) for `from_dict`.
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]
44
+
45
+
46
@dataclass
class HuggingFaceLocalEmbeddingClientConfig:
    """Parses configuration for HuggingFace local embeddings client, resolves
    aliases and raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `provider` has a value different from `huggingface_local`.
              The deprecated `type: huggingface` / `_type: huggingface`
              switches are translated to it by `from_dict`.
            - If `model` is None.
    """

    model: str

    multi_process: Optional[bool]
    cache_folder: Optional[str]
    show_progress: Optional[bool]

    # Provider is not actually used by sentence-transformers, but we define
    # it here because it's used as a switch denominator for HuggingFace
    # local embedding client.
    provider: str = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER

    model_kwargs: dict = field(default_factory=dict)
    encode_kwargs: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates the parsed values.

        Raises:
            ValueError: If `provider` is not the local HuggingFace provider
                or if `model` is None.
        """
        if self.provider != HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER:
            # The validated field is `provider`, so name it in the message;
            # pointing users at "API type" sent them to the wrong key.
            message = (
                f"Provider must be set to '{HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}'."
            )
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "HuggingFaceLocalEmbeddingClientConfig":
        """
        Initializes a dataclass from the passed config.

        Deprecated keys and the deprecated `type` / `_type: huggingface`
        switch trigger deprecation warnings and are resolved to their
        standard form first. Keys that this class does not consume are
        ignored. The passed `config` is not modified (aliases are resolved
        on a copy).

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys or contains an
                invalid provider / model.

        Returns:
            HuggingFaceLocalEmbeddingClientConfig
        """
        # Check for usage of deprecated switching key and value:
        # 1. type: huggingface
        # 2. _type: huggingface
        _raise_deprecation_warning_for_huggingface_deprecated_switch_value(config)
        # Check for other deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases (returns a copy, so the pops below do
        # not touch the caller's dict)
        config = cls.resolve_config_aliases(config)
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)
        # Use `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # Optional
            multi_process=config.pop(HUGGINGFACE_MULTIPROCESS_CONFIG_KEY, False),
            cache_folder=config.pop(
                HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
                str(HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER),
            ),
            show_progress=config.pop(HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY, False),
            model_kwargs=config.pop(HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY, {}),
            encode_kwargs=config.pop(HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY, {}),
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        return asdict(self)

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Returns a copy of `config` with the deprecated HuggingFace switch
        and all deprecated key aliases resolved to their standard keys."""
        config = _resolve_huggingface_deprecated_switch_value(config)
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
140
+
141
+
142
def is_huggingface_local_config(config: dict) -> bool:
    """Tell whether `config` selects the local HuggingFace embedding client.

    Both the current switch (`provider: huggingface_local`) and the deprecated
    ones (`type: huggingface`, `_type: huggingface`) are recognised, because
    the switch value and all key aliases are resolved before the check.

    Args:
        config: Raw client configuration; it is not modified (resolution
            works on a copy).

    Returns:
        True if the resolved provider is the local HuggingFace provider.
    """
    resolved = HuggingFaceLocalEmbeddingClientConfig.resolve_config_aliases(config)
    selected_provider = resolved.get(PROVIDER_CONFIG_KEY)
    return bool(selected_provider == HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER)
162
+
163
+
164
def _raise_deprecation_warning_for_huggingface_deprecated_switch_value(
    config: dict,
) -> None:
    """Emit a deprecation warning for every `<switch key>: huggingface` entry.

    The switch keys inspected are the `type` and `_type` aliases of
    `provider`; the suggested replacement is `provider: huggingface_local`.

    Args:
        config: Raw client configuration; it is only read.
    """
    replacement = f"`{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}`"
    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        # A missing key yields None, which never equals the deprecated value.
        if config.get(switch_key) == DEPRECATED_HUGGINGFACE_TYPE:
            raise_deprecation_warning(
                message=(
                    f"Configuration `{switch_key}: {DEPRECATED_HUGGINGFACE_TYPE}` "
                    "is deprecated and will be removed in 4.0.0. "
                    f"Please use {replacement} instead."
                )
            )
186
+
187
+
188
def _resolve_huggingface_deprecated_switch_value(config: dict) -> dict:
    """
    Translate the deprecated HuggingFace switch into the `provider` key.

    `type: huggingface` and `_type: huggingface` are both deprecated in favour
    of `provider: huggingface_local`. Each deprecated switch found is removed
    from the result and `provider` is set to the local HuggingFace provider
    (overwriting any existing `provider` value).

    Args:
        config: given config; it is not modified.

    Returns:
        A shallow copy of `config` with the switch mechanism resolved.
    """
    resolved = config.copy()

    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        uses_deprecated_switch = (
            switch_key in resolved
            and resolved[switch_key] == DEPRECATED_HUGGINGFACE_TYPE
        )
        if not uses_deprecated_switch:
            continue

        resolved[PROVIDER_CONFIG_KEY] = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
        del resolved[switch_key]

        structlogger.debug(
            "HuggingFaceLocalEmbeddingClientConfig"
            "._resolve_huggingface_deprecated_switch_value",
            message=(
                f"Switching `{switch_key}: {DEPRECATED_HUGGINGFACE_TYPE}` "
                f"to `{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}`."
            ),
            new_config=resolved,
        )

    return resolved
@@ -0,0 +1,175 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ OPENAI_API_BASE_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ OPENAI_API_TYPE_CONFIG_KEY,
12
+ API_TYPE_CONFIG_KEY,
13
+ OPENAI_API_VERSION_CONFIG_KEY,
14
+ API_VERSION_CONFIG_KEY,
15
+ RASA_TYPE_CONFIG_KEY,
16
+ LANGCHAIN_TYPE_CONFIG_KEY,
17
+ STREAM_CONFIG_KEY,
18
+ N_REPHRASES_CONFIG_KEY,
19
+ REQUEST_TIMEOUT_CONFIG_KEY,
20
+ TIMEOUT_CONFIG_KEY,
21
+ PROVIDER_CONFIG_KEY,
22
+ OPENAI_PROVIDER,
23
+ OPENAI_API_TYPE,
24
+ )
25
+ from rasa.shared.providers._configs.utils import (
26
+ resolve_aliases,
27
+ validate_required_keys,
28
+ raise_deprecation_warnings,
29
+ validate_forbidden_keys,
30
+ )
31
+
32
structlogger = structlog.get_logger()


# Deprecated config keys mapped to the standard key that replaces them.
# NOTE: `resolve_aliases` walks this mapping in insertion order and lets the
# alias value overwrite the standard key, so entry order matters when several
# aliases of the same standard key are present (the later entry wins).
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # API type aliases
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
    # API base aliases
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
    # API version aliases
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present (after alias resolution) for `from_dict`.
REQUIRED_KEYS = [MODEL_CONFIG_KEY]

# Keys that `from_dict` rejects with a ValueError if present.
FORBIDDEN_KEYS = [
    STREAM_CONFIG_KEY,
    N_REPHRASES_CONFIG_KEY,
]
57
+
58
+
59
@dataclass
class OpenAIClientConfig:
    """Parses configuration for OpenAI client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If any of the forbidden configuration keys are present.
            - If `api_type` has a value different from `openai`.
            - If `provider` has a value different from the OpenAI provider.
            - If `model` is None.
    """

    model: str
    api_base: Optional[str]
    api_version: Optional[str]

    # API Type is not actually used by LiteLLM backend, but we define
    # it here for backward compatibility.
    api_type: str = OPENAI_API_TYPE

    # Provider is not used by LiteLLM backend, but we define
    # it here since it's used as switch between different
    # clients
    provider: str = OPENAI_PROVIDER

    # Every remaining config key (e.g. model parameters, timeout).
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates the parsed values.

        Raises:
            ValueError: If `api_type` or `provider` is not the OpenAI value,
                or if `model` is None.
        """
        # In case of OpenAI hosting, it doesn't make sense
        # for API type to be anything other than 'openai'
        if self.api_type != OPENAI_API_TYPE:
            message = f"API type must be set to '{OPENAI_API_TYPE}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)
        if self.provider != OPENAI_PROVIDER:
            message = f"Provider must be set to '{OPENAI_PROVIDER}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "OpenAIClientConfig":
        """
        Initializes a dataclass from the passed config.

        Deprecated keys trigger deprecation warnings and are resolved to
        their standard keys first. Keys not consumed by a named field are
        collected in `extra_parameters`. The passed `config` is not modified
        (aliases are resolved on a copy).

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, contains forbidden
                keys or has invalid values.

        Returns:
            OpenAIClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases (returns a copy)
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Use `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            # Pop the 'provider' key. Currently, it's *optional* because of
            # backward compatibility with older versions.
            provider=config.pop(PROVIDER_CONFIG_KEY, OPENAI_PROVIDER),
            # Optional parameters
            api_base=config.pop(API_BASE_CONFIG_KEY, None),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_API_TYPE),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters (this also includes timeout).
            extra_parameters=config,
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary in which the
        entries of `extra_parameters` appear as top-level keys."""
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Returns a copy of `config` with deprecated key aliases renamed to
        their standard keys."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
162
+
163
+
164
def is_openai_config(config: dict) -> bool:
    """Decide whether `config` targets the OpenAI client.

    Deprecated aliases of the `provider` key are resolved on a copy of the
    configuration first, so older spellings are recognised as well.

    Args:
        config: Raw client configuration; it is not modified.

    Returns:
        True when the resolved `provider` equals the OpenAI provider.
    """
    resolved_config = OpenAIClientConfig.resolve_config_aliases(config)
    selected_provider = resolved_config.get(PROVIDER_CONFIG_KEY)
    return bool(selected_provider == OPENAI_PROVIDER)
@@ -0,0 +1,176 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ OPENAI_API_BASE_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ OPENAI_API_TYPE_CONFIG_KEY,
12
+ API_TYPE_CONFIG_KEY,
13
+ OPENAI_API_VERSION_CONFIG_KEY,
14
+ API_VERSION_CONFIG_KEY,
15
+ RASA_TYPE_CONFIG_KEY,
16
+ LANGCHAIN_TYPE_CONFIG_KEY,
17
+ STREAM_CONFIG_KEY,
18
+ N_REPHRASES_CONFIG_KEY,
19
+ REQUEST_TIMEOUT_CONFIG_KEY,
20
+ TIMEOUT_CONFIG_KEY,
21
+ PROVIDER_CONFIG_KEY,
22
+ OPENAI_PROVIDER,
23
+ SELF_HOSTED_PROVIDER,
24
+ USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
25
+ )
26
+ from rasa.shared.providers._configs.utils import (
27
+ raise_deprecation_warnings,
28
+ resolve_aliases,
29
+ validate_forbidden_keys,
30
+ validate_required_keys,
31
+ )
32
+
33
structlogger = structlog.get_logger()


# Deprecated config keys mapped to the standard key that replaces them.
# NOTE: `resolve_aliases` walks this mapping in insertion order and lets the
# alias value overwrite the standard key, so entry order matters when several
# aliases of the same standard key are present (the later entry wins).
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # API type aliases
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
    # API base aliases
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
    # API version aliases
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present (after alias resolution) for `from_dict`.
REQUIRED_KEYS = [API_BASE_CONFIG_KEY, MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]

# Keys that `from_dict` rejects with a ValueError if present.
FORBIDDEN_KEYS = [
    STREAM_CONFIG_KEY,
    N_REPHRASES_CONFIG_KEY,
]
58
+
59
+
60
@dataclass
class SelfHostedLLMClientConfig:
    """Parses configuration for Self Hosted LiteLLM client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If any of the forbidden configuration keys are present.
            - If `model`, `provider` or `api_base` is None.
            - If `api_type` is not the OpenAI value (only OpenAI-compatible
              endpoints are currently supported).
    """

    model: str
    provider: str
    api_base: str
    api_version: Optional[str] = None
    api_type: Optional[str] = OPENAI_PROVIDER
    use_chat_completions_endpoint: Optional[bool] = True
    # Every remaining config key (e.g. model parameters).
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates the parsed values.

        Raises:
            ValueError: If `model`, `provider` or `api_base` is None, or if
                `api_type` is not the OpenAI value.
        """
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)
        if self.provider is None:
            message = "Provider cannot be set to None."
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.api_base is None:
            message = "API base cannot be set to None."
            # Log the field that actually failed validation (was `provider`).
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                api_base=self.api_base,
            )
            raise ValueError(message)
        if self.api_type != OPENAI_PROVIDER:
            message = (
                f"Currently supports only {OPENAI_PROVIDER} endpoints. "
                f"API type must be set to '{OPENAI_PROVIDER}'."
            )
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "SelfHostedLLMClientConfig":
        """
        Initializes a dataclass from the passed config.

        Deprecated keys trigger deprecation warnings and are resolved to
        their standard keys first. Keys not consumed by a named field are
        collected in `extra_parameters`. The passed `config` is not modified
        (aliases are resolved on a copy).

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, contains forbidden
                keys or has invalid values.

        Returns:
            SelfHostedLLMClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases (returns a copy)
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Use `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            api_base=config.pop(API_BASE_CONFIG_KEY),
            # Optional parameters
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_PROVIDER),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            use_chat_completions_endpoint=config.pop(
                USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY, True
            ),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary in which the
        entries of `extra_parameters` appear as top-level keys."""
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Returns a copy of `config` with deprecated key aliases renamed to
        their standard keys."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
165
+
166
+
167
def is_self_hosted_config(config: dict) -> bool:
    """Decide whether `config` targets a self-hosted LLM client.

    Deprecated aliases of the `provider` key are resolved on a copy of the
    configuration first, so older spellings are recognised as well.

    Args:
        config: Raw client configuration; it is not modified.

    Returns:
        True when the resolved `provider` equals the self-hosted provider.
    """
    resolved_config = SelfHostedLLMClientConfig.resolve_config_aliases(config)
    selected_provider = resolved_config.get(PROVIDER_CONFIG_KEY)
    return bool(selected_provider == SELF_HOSTED_PROVIDER)
@@ -0,0 +1,101 @@
1
+ import structlog
2
+ from rasa.shared.utils.io import raise_deprecation_warning
3
+
4
# Module-level structured logger; the validation helpers below use it to log
# configuration errors before raising ValueError.
structlogger = structlog.get_logger()
5
+
6
+
7
def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
    """
    Return a copy of `config` in which alias keys are renamed to standard keys.

    Mapping entries are processed in iteration order. When an alias is
    present, its value replaces whatever is stored under the standard key,
    and the alias key itself is dropped. The input `config` is left untouched.

    Args:
        config: Dictionary containing the configuration.
        deprecated_alias_mapping: Dictionary mapping aliases to
            their standard keys.

    Returns:
        New dictionary containing the processed configuration.
    """
    resolved = config.copy()

    for alias, standard_key in deprecated_alias_mapping.items():
        if alias not in resolved:
            continue
        # The standard key is usually already set by default component
        # configurations, so an explicitly given alias must take precedence.
        alias_value = resolved.pop(alias)
        resolved[standard_key] = alias_value

    return resolved
30
+
31
+
32
def raise_deprecation_warnings(config: dict, deprecated_alias_mapping: dict) -> None:
    """
    Emit one deprecation warning per deprecated key used in the configuration.

    Args:
        config: Dictionary containing the configuration (only read).
        deprecated_alias_mapping: Dictionary mapping deprecated keys to
            their standard keys.

    Raises:
        DeprecationWarning: If any deprecated key is found in the config.
    """
    used_aliases = (
        (alias, standard_key)
        for alias, standard_key in deprecated_alias_mapping.items()
        if alias in config
    )
    for alias, standard_key in used_aliases:
        warning_text = (
            f"'{alias}' is deprecated and will be removed in "
            f"4.0.0. Use '{standard_key}' instead."
        )
        raise_deprecation_warning(message=warning_text)
52
+
53
+
54
def validate_required_keys(config: dict, required_keys: list) -> None:
    """
    Ensure every key in `required_keys` is present in `config`.

    Missing keys are reported in the order they appear in `required_keys`.

    Args:
        config: Dictionary containing the configuration.
        required_keys: List of keys that must be present in the config.

    Raises:
        ValueError: If any required key is missing.
    """
    missing_keys = [key for key in required_keys if key not in config]
    if not missing_keys:
        return

    message = f"Missing required keys '{missing_keys}' for configuration."
    structlogger.error(
        "validate_required_keys",
        message=message,
        missing_keys=missing_keys,
        config=config,
    )
    raise ValueError(message)
75
+
76
+
77
def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
    """
    Validates that the passed config doesn't contain any forbidden keys.

    Offending keys are reported in the order they appear in `forbidden_keys`,
    with duplicates removed. This keeps the error message and the structured
    log stable across runs and consistent with `validate_required_keys`.

    Args:
        config: Dictionary containing the configuration.
        forbidden_keys: List of keys that are forbidden in the config.

    Raises:
        ValueError: If any forbidden key is present.
    """
    # Ordered, de-duplicated selection instead of a set intersection: set
    # iteration order for str keys depends on hash randomisation, so the
    # reported keys could appear in a different order on every run.
    forbidden_keys_in_config = list(
        dict.fromkeys(key for key in forbidden_keys if key in config)
    )

    if forbidden_keys_in_config:
        message = (
            f"Forbidden keys '{forbidden_keys_in_config}' present "
            f"in the configuration."
        )
        structlogger.error(
            "validate_forbidden_keys",
            message=message,
            forbidden_keys=forbidden_keys_in_config,
            config=config,
        )
        raise ValueError(message)