rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/shared/providers/llm/azure_openai_llm_client.py (new file)
@@ -0,0 +1,338 @@
+ import os
+ import re
+ from typing import Dict, Any, Optional
+
+ import structlog
+
+ from rasa.shared.constants import (
+     OPENAI_API_BASE_ENV_VAR,
+     OPENAI_API_VERSION_ENV_VAR,
+     AZURE_API_BASE_ENV_VAR,
+     AZURE_API_VERSION_ENV_VAR,
+     API_BASE_CONFIG_KEY,
+     API_VERSION_CONFIG_KEY,
+     DEPLOYMENT_CONFIG_KEY,
+     AZURE_API_KEY_ENV_VAR,
+     OPENAI_API_TYPE_ENV_VAR,
+     OPENAI_API_KEY_ENV_VAR,
+     AZURE_API_TYPE_ENV_VAR,
+     AZURE_OPENAI_PROVIDER,
+ )
+ from rasa.shared.exceptions import ProviderClientValidationError
+ from rasa.shared.providers._configs.azure_openai_client_config import (
+     AzureOpenAIClientConfig,
+ )
+ from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
+ from rasa.shared.utils.io import raise_deprecation_warning
+
+ structlogger = structlog.get_logger()
+
+
+ class AzureOpenAILLMClient(_BaseLiteLLMClient):
+     """
+     A client for interfacing with Azure's OpenAI LLM deployments.
+
+     Parameters:
+         deployment (str): The deployment name.
+         model (Optional[str]): The name of the deployed model.
+         api_type (Optional[str]): The API type. If not provided, it will be
+             set via environment variable.
+         api_base (Optional[str]): The base URL for the API endpoints. If not
+             provided, it will be set via environment variable.
+         api_version (Optional[str]): The version of the API to use. If not
+             provided, it will be set via environment variable.
+         kwargs (Optional[Dict[str, Any]]): Optional configuration parameters
+             specific to the model deployment.
+
+     Raises:
+         ProviderClientValidationError: If validation of the client setup fails.
+         DeprecationWarning: If deprecated environment variables are used for
+             configuration.
+     """
+
+     def __init__(
+         self,
+         deployment: str,
+         model: Optional[str] = None,
+         api_type: Optional[str] = None,
+         api_base: Optional[str] = None,
+         api_version: Optional[str] = None,
+         **kwargs: Any,
+     ):
+         super().__init__()  # type: ignore
+         self._deployment = deployment
+         self._model = model
+         self._extra_parameters = kwargs or {}
+
+         # Set api_base with the following priority:
+         # parameter -> Azure env var -> (deprecated) OpenAI env var
+         self._api_base = (
+             api_base
+             or os.getenv(AZURE_API_BASE_ENV_VAR)
+             or os.getenv(OPENAI_API_BASE_ENV_VAR)
+         )
+
+         # Set api_version with the following priority:
+         # parameter -> Azure env var -> (deprecated) OpenAI env var
+         self._api_version = (
+             api_version
+             or os.getenv(AZURE_API_VERSION_ENV_VAR)
+             or os.getenv(OPENAI_API_VERSION_ENV_VAR)
+         )
+
+         # The API key can also be set through OPENAI_API_KEY for
+         # backward compatibility.
+         self._api_key = os.getenv(AZURE_API_KEY_ENV_VAR) or os.getenv(
+             OPENAI_API_KEY_ENV_VAR
+         )
+
+         # Not used by LiteLLM; kept for backward compatibility.
+         self._api_type = (
+             api_type
+             or os.getenv(AZURE_API_TYPE_ENV_VAR)
+             or os.getenv(OPENAI_API_TYPE_ENV_VAR)
+         )
+
+         # Raise deprecation warnings if deprecated environment variables
+         # were used to initialize any of the client settings.
+         self._raise_env_var_deprecation_warnings()
+
+         # Validate the client settings.
+         self.validate_client_setup()
+
+     def _raise_env_var_deprecation_warnings(self) -> None:
+         """Raises a deprecation warning if deprecated environment variables
+         were used to initialize any of the client settings.
+         """
+         deprecation_mapping = {
+             "API Base": {
+                 "current_value": self.api_base,
+                 "env_var": AZURE_API_BASE_ENV_VAR,
+                 "deprecated_var": OPENAI_API_BASE_ENV_VAR,
+             },
+             "API Version": {
+                 "current_value": self.api_version,
+                 "env_var": AZURE_API_VERSION_ENV_VAR,
+                 "deprecated_var": OPENAI_API_VERSION_ENV_VAR,
+             },
+             "API Key": {
+                 "current_value": self._api_key,
+                 "env_var": AZURE_API_KEY_ENV_VAR,
+                 "deprecated_var": OPENAI_API_KEY_ENV_VAR,
+             },
+         }
+
+         deprecation_warning_message = (
+             "Usage of the {deprecated_env_var} environment "
+             "variable for setting the {setting} for the Azure "
+             "OpenAI client is deprecated and will be removed "
+             "in 4.0.0. "
+         )
+         deprecation_warning_replacement_message = (
+             "Please use the {env_var} environment variable."
+         )
+
+         for setting, setting_info in deprecation_mapping.items():
+             current_value = setting_info["current_value"]
+             env_var = setting_info["env_var"]
+             deprecated_var = setting_info["deprecated_var"]
+
+             # Value is set through the non-deprecated env var.
+             if current_value == os.getenv(env_var):
+                 continue
+
+             # Value is set through the deprecated env var.
+             if current_value == os.getenv(deprecated_var):
+                 message = deprecation_warning_message.format(
+                     setting=setting, deprecated_env_var=deprecated_var
+                 )
+                 if env_var is not None:
+                     message += deprecation_warning_replacement_message.format(
+                         env_var=env_var
+                     )
+                 raise_deprecation_warning(message=message)
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAILLMClient":
+         """
+         Initializes the client from the given configuration.
+
+         Args:
+             config (Dict[str, Any]): Configuration.
+
+         Raises:
+             ValueError:
+                 If any of the required configuration keys are missing.
+                 If `api_type` has a value different from `azure`.
+
+         Returns:
+             AzureOpenAILLMClient: Initialized client.
+         """
+         try:
+             azure_openai_config = AzureOpenAIClientConfig.from_dict(config)
+         except ValueError as e:
+             message = "Cannot instantiate a client from the passed configuration."
+             structlogger.error(
+                 "azure_openai_llm_client.from_config.error",
+                 message=message,
+                 config=config,
+                 original_error=e,
+             )
+             raise
+
+         return cls(
+             azure_openai_config.deployment,
+             azure_openai_config.model,
+             azure_openai_config.api_type,
+             azure_openai_config.api_base,
+             azure_openai_config.api_version,
+             **azure_openai_config.extra_parameters,
+         )
+
+     @property
+     def config(self) -> dict:
+         """Returns the configuration of the LLM client in dictionary form."""
+         config = AzureOpenAIClientConfig(
+             deployment=self._deployment,
+             model=self._model,
+             api_base=self._api_base,
+             api_version=self._api_version,
+             api_type=self._api_type,
+             extra_parameters=self._extra_parameters,
+         )
+         return config.to_dict()
+
+     @property
+     def deployment(self) -> str:
+         return self._deployment
+
+     @property
+     def model(self) -> Optional[str]:
+         """Returns the name of the model deployed on Azure."""
+         return self._model
+
+     @property
+     def api_base(self) -> Optional[str]:
+         """Returns the API base URL for the Azure OpenAI LLM client."""
+         return self._api_base
+
+     @property
+     def api_version(self) -> Optional[str]:
+         """Returns the API version for the Azure OpenAI LLM client."""
+         return self._api_version
+
+     @property
+     def api_type(self) -> Optional[str]:
+         return self._api_type
+
+     @property
+     def _litellm_model_name(self) -> str:
+         """Returns the value of LiteLLM's model parameter to be used in
+         completion/acompletion calls, in LiteLLM format:
+
+         <provider>/<model or deployment name>
+         """
+         regex_pattern = rf"^{AZURE_OPENAI_PROVIDER}/"
+         if not re.match(regex_pattern, self._deployment):
+             return f"{AZURE_OPENAI_PROVIDER}/{self._deployment}"
+         return self._deployment
+
+     @property
+     def _litellm_extra_parameters(self) -> Dict[str, Any]:
+         return self._extra_parameters
+
+     @property
+     def _completion_fn_args(self) -> Dict[str, Any]:
+         """Returns the completion arguments for invoking a call through
+         LiteLLM's completion functions.
+         """
+         fn_args = super()._completion_fn_args
+         fn_args.update(
+             {
+                 "api_base": self.api_base,
+                 "api_version": self.api_version,
+                 "api_key": self._api_key,
+             }
+         )
+         return fn_args
+
+     def validate_client_setup(self) -> None:
+         """Validates that all required configuration parameters are set."""
+
+         def generate_event_info_for_missing_setting(
+             setting: str,
+             setting_env_var: Optional[str] = None,
+             setting_config_key: Optional[str] = None,
+         ) -> str:
+             """Generates a part of the message with instructions on what to
+             set for the missing client setting.
+             """
+             info = "Set {setting} with {options}. "
+             options = ""
+             if setting_env_var is not None:
+                 options += f"environment variable '{setting_env_var}'"
+             if setting_config_key is not None and setting_env_var is not None:
+                 options += " or "
+             if setting_config_key is not None:
+                 options += f"config key '{setting_config_key}'"
+
+             return info.format(setting=setting, options=options)
+
+         # All required settings for the Azure OpenAI client.
+         settings: Dict[str, Dict[str, Any]] = {
+             "API Base": {
+                 "current_value": self.api_base,
+                 "env_var": AZURE_API_BASE_ENV_VAR,
+                 "config_key": API_BASE_CONFIG_KEY,
+             },
+             "API Version": {
+                 "current_value": self.api_version,
+                 "env_var": AZURE_API_VERSION_ENV_VAR,
+                 "config_key": API_VERSION_CONFIG_KEY,
+             },
+             "Deployment Name": {
+                 "current_value": self.deployment,
+                 "env_var": None,
+                 "config_key": DEPLOYMENT_CONFIG_KEY,
+             },
+             "API Key": {
+                 "current_value": self._api_key,
+                 "env_var": AZURE_API_KEY_ENV_VAR,
+                 "config_key": None,
+             },
+         }
+
+         missing_settings = [
+             setting_name
+             for setting_name, setting_info in settings.items()
+             if setting_info["current_value"] is None
+         ]
+
+         if missing_settings:
+             event_info = f"Client settings not set: {', '.join(missing_settings)}. "
+
+             for missing_setting in missing_settings:
+                 event_info += generate_event_info_for_missing_setting(
+                     missing_setting,
+                     settings[missing_setting]["env_var"],
+                     settings[missing_setting]["config_key"],
+                 )
+
+             structlogger.error(
+                 "azure_openai_llm_client.not_configured",
+                 event_info=event_info,
+                 missing_settings=missing_settings,
+             )
+             raise ProviderClientValidationError(event_info)
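
The constructor above resolves each setting with the priority: explicit parameter, then the Azure environment variable, then the deprecated OpenAI one. A standalone sketch of that resolution order (the helper and the literal env var names below are illustrative, not package code):

import os

def resolve_setting(param, azure_env_var, deprecated_env_var):
    # parameter -> Azure env var -> (deprecated) OpenAI env var
    return param or os.getenv(azure_env_var) or os.getenv(deprecated_env_var)

os.environ["AZURE_API_BASE"] = "https://example.openai.azure.com"  # hypothetical
assert resolve_setting(None, "AZURE_API_BASE", "OPENAI_API_BASE") == (
    "https://example.openai.azure.com"
)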
rasa/shared/providers/llm/default_litellm_llm_client.py (new file)
@@ -0,0 +1,84 @@
+ from typing import Dict, Any
+
+ from rasa.shared.providers._configs.default_litellm_client_config import (
+     DefaultLiteLLMClientConfig,
+ )
+ from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
+
+
+ class DefaultLiteLLMClient(_BaseLiteLLMClient):
+     """A default client for interfacing with LiteLLM LLM endpoints.
+
+     Parameters:
+         provider (str): The provider name.
+         model (str): The model or deployment name.
+         kwargs (Any): Additional configuration parameters that can include,
+             but are not limited to, model parameters and LiteLLM-specific
+             parameters. These parameters are passed to the
+             completion/acompletion calls. For the supported parameters, see:
+             https://docs.litellm.ai/docs/completion/input
+
+     Raises:
+         ProviderClientValidationError: If validation of the client setup fails.
+         ProviderClientAPIException: If the API request fails.
+     """
+
+     def __init__(self, provider: str, model: str, **kwargs: Any):
+         super().__init__()  # type: ignore
+         self._provider = provider
+         self._model = model
+         self._extra_parameters = kwargs
+         self.validate_client_setup()
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "DefaultLiteLLMClient":
+         default_config = DefaultLiteLLMClientConfig.from_dict(config)
+         return cls(
+             model=default_config.model,
+             provider=default_config.provider,
+             # Pass the rest of the configuration as extra parameters.
+             **default_config.extra_parameters,
+         )
+
+     @property
+     def provider(self) -> str:
+         return self._provider
+
+     @property
+     def model(self) -> str:
+         """Returns the model name."""
+         return self._model
+
+     @property
+     def config(self) -> Dict:
+         """Returns the configuration of the default LiteLLM client.
+
+         Returns:
+             Dictionary containing the configuration.
+         """
+         config = DefaultLiteLLMClientConfig(
+             model=self._model,
+             provider=self._provider,
+             extra_parameters=self._extra_parameters,
+         )
+         return config.to_dict()
+
+     @property
+     def _litellm_model_name(self) -> str:
+         """Returns the value of LiteLLM's model parameter to be used in
+         completion/acompletion calls, in LiteLLM format:
+
+         <provider>/<model or deployment name>
+         """
+         if self.model and f"{self.provider}/" not in self.model:
+             return f"{self.provider}/{self.model}"
+         return self.model
+
+     @property
+     def _litellm_extra_parameters(self) -> Dict[str, Any]:
+         """Returns optional configuration parameters specific to the
+         client provider and the deployed model.
+         """
+         return self._extra_parameters
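
Note that `_litellm_model_name` here guards with a substring test (`f"{provider}/" not in model`), unlike the Azure and OpenAI clients, which anchor the check with a `^provider/` regex. A standalone sketch of this client's rule (values hypothetical):

provider, model = "cohere", "command-r"
litellm_model = model if f"{provider}/" in model else f"{provider}/{model}"
assert litellm_model == "cohere/command-r"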
rasa/shared/providers/llm/llm_client.py (new file)
@@ -0,0 +1,76 @@
+ from typing import Protocol, Dict, List, runtime_checkable, Union
+
+ from rasa.shared.providers.llm.llm_response import LLMResponse
+
+
+ @runtime_checkable
+ class LLMClient(Protocol):
+     """
+     Protocol for an LLM client that specifies the interface for interacting
+     with the API.
+     """
+
+     @classmethod
+     def from_config(cls, config: dict) -> "LLMClient":
+         """
+         Initializes the LLM client with the given configuration.
+
+         This class method should be implemented to parse the given
+         configuration and create an instance of an LLM client.
+         """
+         ...
+
+     @property
+     def config(self) -> Dict:
+         """
+         Returns the configuration that the LLM client was initialized with.
+
+         This property should be implemented to return a dictionary containing
+         the configuration settings for the LLM client.
+         """
+         ...
+
+     def completion(self, messages: Union[List[str], str]) -> LLMResponse:
+         """
+         Synchronously generates completions for the given messages.
+
+         This method should be implemented to take a list of messages (or a
+         single message) and return the generated completions as an
+         LLMResponse.
+
+         Args:
+             messages: List of messages or a single message to generate the
+                 completion for.
+         Returns:
+             LLMResponse
+         """
+         ...
+
+     async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
+         """
+         Asynchronously generates completions for the given messages.
+
+         This method should be implemented to take a list of messages (or a
+         single message) and return the generated completions as an
+         LLMResponse.
+
+         Args:
+             messages: List of messages or a single message to generate the
+                 completion for.
+         Returns:
+             LLMResponse
+         """
+         ...
+
+     def validate_client_setup(self, *args, **kwargs) -> None:  # type: ignore
+         """
+         Performs client setup validation.
+
+         This method should be implemented to validate whether the client can
+         be used with the parameters provided through configuration or
+         environment variables.
+
+         If there are any issues, the client should raise
+         ProviderClientValidationError.
+
+         If no validation is needed, this check can simply pass.
+         """
+         ...
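
Because `LLMClient` is decorated with `@runtime_checkable`, concrete clients do not need to inherit from it; an `isinstance` check verifies the interface structurally. A minimal sketch with a hypothetical stub (not part of the package):

class _StubClient:
    @classmethod
    def from_config(cls, config: dict) -> "_StubClient":
        return cls()

    @property
    def config(self) -> dict:
        return {}

    def completion(self, messages):
        raise NotImplementedError

    async def acompletion(self, messages):
        raise NotImplementedError

    def validate_client_setup(self, *args, **kwargs) -> None:
        pass

assert isinstance(_StubClient(), LLMClient)  # structure matches the protocol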
rasa/shared/providers/llm/llm_response.py (new file)
@@ -0,0 +1,50 @@
+ from dataclasses import dataclass, field, asdict
+ from typing import Dict, List, Optional
+
+
+ @dataclass
+ class LLMUsage:
+     prompt_tokens: int
+     """Number of prompt tokens used to generate the completion."""
+
+     completion_tokens: int
+     """Number of generated tokens."""
+
+     total_tokens: int = field(init=False)
+     """Total number of used tokens."""
+
+     def __post_init__(self) -> None:
+         self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+     def to_dict(self) -> dict:
+         """Converts the LLMUsage dataclass instance into a dictionary."""
+         return asdict(self)
+
+
+ @dataclass
+ class LLMResponse:
+     id: str
+     """A unique identifier for the completion."""
+
+     choices: List[str]
+     """The list of completion choices the model generated for the input prompt."""
+
+     created: int
+     """The Unix timestamp (in seconds) of when the completion was created."""
+
+     model: Optional[str] = None
+     """The model used for the completion."""
+
+     usage: Optional[LLMUsage] = None
+     """Optional details about the token usage for the API call."""
+
+     additional_info: Optional[Dict] = None
+     """Optional dictionary for storing additional information related to the
+     completion that may not be covered by other fields."""
+
+     def to_dict(self) -> dict:
+         """Converts the LLMResponse dataclass instance into a dictionary."""
+         result = asdict(self)
+         if self.usage:
+             result["usage"] = self.usage.to_dict()
+         return result
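
`LLMUsage.total_tokens` is excluded from `__init__` (`field(init=False)`) and derived in `__post_init__`, so callers supply only the two raw counts; `LLMResponse.to_dict` then re-serializes the nested usage object. A short sketch (field values hypothetical):

usage = LLMUsage(prompt_tokens=120, completion_tokens=30)
assert usage.total_tokens == 150

response = LLMResponse(
    id="cmpl-123",
    choices=["Hello!"],
    created=1717000000,
    usage=usage,
)
assert response.to_dict()["usage"]["total_tokens"] == 150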
rasa/shared/providers/llm/openai_llm_client.py (new file)
@@ -0,0 +1,155 @@
+ import os
+ import re
+ from typing import Dict, Any, Optional
+
+ import structlog
+
+ from rasa.shared.constants import (
+     OPENAI_API_BASE_ENV_VAR,
+     OPENAI_API_VERSION_ENV_VAR,
+     OPENAI_API_TYPE_ENV_VAR,
+     OPENAI_PROVIDER,
+ )
+ from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
+ from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
+
+ structlogger = structlog.get_logger()
+
+
+ class OpenAILLMClient(_BaseLiteLLMClient):
+     """
+     A client for interfacing with OpenAI LLMs.
+
+     Parameters:
+         model (str): The OpenAI model name.
+         api_base (Optional[str]): The base URL for the API endpoints. If not
+             provided, it will be set via environment variable.
+         api_version (Optional[str]): The version of the API to use. If not
+             provided, it will be set via environment variable.
+         api_type (Optional[str]): The API type. If not provided, it will be
+             set via environment variable.
+         kwargs (Optional[Dict[str, Any]]): Optional configuration parameters
+             specific to the model.
+
+     Raises:
+         ProviderClientValidationError: If validation of the client setup fails.
+     """
+
+     def __init__(
+         self,
+         model: str,
+         api_base: Optional[str] = None,
+         api_version: Optional[str] = None,
+         api_type: Optional[str] = None,
+         **kwargs: Any,
+     ):
+         super().__init__()  # type: ignore
+         self._model = model
+         self._api_base = api_base or os.environ.get(OPENAI_API_BASE_ENV_VAR, None)
+         self._api_version = api_version or os.environ.get(
+             OPENAI_API_VERSION_ENV_VAR, None
+         )
+
+         # Not used by LiteLLM; kept for backward compatibility.
+         self._api_type = api_type or os.environ.get(OPENAI_API_TYPE_ENV_VAR)
+
+         self._extra_parameters = kwargs or {}
+         self.validate_client_setup()
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "OpenAILLMClient":
+         """
+         Initializes the client from the given configuration.
+
+         Args:
+             config (Dict[str, Any]): Configuration.
+
+         Raises:
+             KeyError: If any of the required configuration keys are missing.
+             ValueError: If `api_type` has a value different from `openai`.
+
+         Returns:
+             OpenAILLMClient: Initialized client.
+         """
+         try:
+             openai_config = OpenAIClientConfig.from_dict(config)
+         except (KeyError, ValueError) as e:
+             message = "Cannot instantiate a client from the passed configuration."
+             structlogger.error(
+                 "openai_llm_client.from_config.error",
+                 message=message,
+                 config=config,
+                 original_error=e,
+             )
+             raise
+
+         return cls(
+             openai_config.model,
+             openai_config.api_base,
+             openai_config.api_version,
+             openai_config.api_type,
+             **openai_config.extra_parameters,
+         )
+
+     @property
+     def config(self) -> dict:
+         config = OpenAIClientConfig(
+             model=self.model,
+             api_type=self.api_type,
+             api_base=self.api_base,
+             api_version=self.api_version,
+             extra_parameters=self._litellm_extra_parameters,
+         )
+         return config.to_dict()
+
+     @property
+     def model(self) -> str:
+         return self._model
+
+     @property
+     def api_base(self) -> Optional[str]:
+         """Returns the base API URL for the OpenAI LLM client."""
+         return self._api_base
+
+     @property
+     def api_version(self) -> Optional[str]:
+         """Returns the API version for the OpenAI LLM client."""
+         return self._api_version
+
+     @property
+     def api_type(self) -> Optional[str]:
+         return self._api_type
+
+     @property
+     def _litellm_model_name(self) -> str:
+         """Returns the value of LiteLLM's model parameter to be used in
+         completion/acompletion calls, in LiteLLM format:
+
+         <provider>/<model or deployment name>
+         """
+         regex_pattern = rf"^{OPENAI_PROVIDER}/"
+         if not re.match(regex_pattern, self._model):
+             return f"{OPENAI_PROVIDER}/{self._model}"
+         return self._model
+
+     @property
+     def _litellm_extra_parameters(self) -> Dict[str, Any]:
+         return self._extra_parameters
+
+     @property
+     def _completion_fn_args(self) -> Dict[str, Any]:
+         """Returns the completion arguments for invoking a call through
+         LiteLLM's completion functions.
+         """
+         fn_args = super()._completion_fn_args
+         fn_args.update(
+             {
+                 "api_base": self.api_base,
+                 "api_version": self.api_version,
+             }
+         )
+         return fn_args
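
For comparison with the default client above, the anchored variant used by OpenAILLMClient (and AzureOpenAILLMClient) only treats the model as already prefixed when the provider name appears at the start of the string. A standalone sketch, assuming OPENAI_PROVIDER resolves to "openai":

import re

provider, model = "openai", "gpt-4"
if not re.match(rf"^{provider}/", model):
    model = f"{provider}/{model}"
assert model == "openai/gpt-4"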