rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/shared/providers/llm/self_hosted_llm_client.py ADDED
@@ -0,0 +1,293 @@
+ from typing import Any, Dict, List, Optional, Union
+ from litellm import (
+     text_completion,
+     atext_completion,
+ )
+ import logging
+ import os
+ import structlog
+
+ from rasa.shared.constants import (
+     SELF_HOSTED_VLLM_PREFIX,
+     SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
+ )
+ from rasa.shared.providers._configs.self_hosted_llm_client_config import (
+     SelfHostedLLMClientConfig,
+ )
+ from rasa.shared.exceptions import ProviderClientAPIException
+ from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
+ from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
+ from rasa.shared.utils.io import suppress_logs
+
+ structlogger = structlog.get_logger()
+
+
+ class SelfHostedLLMClient(_BaseLiteLLMClient):
+     """A client for interfacing with Self Hosted LLM endpoints that uses
+
+     Parameters:
+         model (str): The model or deployment name.
+         provider (str): The provider of the model.
+         api_base (str): The base URL of the API endpoint.
+         api_type (Optional[str]): The type of the API endpoint.
+         api_version (Optional[str]): The version of the API endpoint.
+         use_chat_completions_endpoint (Optional[bool]): Whether to use the chat
+             completions endpoint for completions. Defaults to True.
+         kwargs: Any: Additional configuration parameters that can include, but
+             are not limited to model parameters and lite-llm specific
+             parameters. These parameters will be passed to the
+             completion/acompletion calls. To see what it can include, visit:
+
+     Raises:
+         ProviderClientValidationError: If validation of the client setup fails.
+         ProviderClientAPIException: If the API request fails.
+     """
+
+     def __init__(
+         self,
+         provider: str,
+         model: str,
+         api_base: str,
+         api_type: Optional[str] = None,
+         api_version: Optional[str] = None,
+         use_chat_completions_endpoint: Optional[bool] = True,
+         **kwargs: Any,
+     ):
+         super().__init__()  # type: ignore
+         self._provider = provider
+         self._model = model
+         self._api_base = api_base
+         self._api_type = api_type
+         self._api_version = api_version
+         self._use_chat_completions_endpoint = use_chat_completions_endpoint
+         self._extra_parameters = kwargs or {}
+         self._apply_dummy_api_key_if_missing()
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
+         try:
+             client_config = SelfHostedLLMClientConfig.from_dict(config)
+         except ValueError as e:
+             message = "Cannot instantiate a client from the passed configuration."
+             structlogger.error(
+                 "self_hosted_llm_client.from_config.error",
+                 message=message,
+                 config=config,
+                 original_error=e,
+             )
+             raise
+
+         return cls(
+             model=client_config.model,
+             provider=client_config.provider,
+             api_base=client_config.api_base,
+             api_type=client_config.api_type,
+             api_version=client_config.api_version,
+             use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
+             **client_config.extra_parameters,
+         )
+
+     @property
+     def provider(self) -> str:
+         """
+         Returns the provider name for the self hosted llm client.
+
+         Returns:
+             String representing the provider name.
+         """
+         return self._provider
+
+     @property
+     def model(self) -> str:
+         """
+         Returns the model name for the self hosted llm client.
+
+         Returns:
+             String representing the model name.
+         """
+         return self._model
+
+     @property
+     def api_base(self) -> str:
+         """
+         Returns the base URL for the API endpoint.
+
+         Returns:
+             String representing the base URL.
+         """
+         return self._api_base
+
+     @property
+     def api_type(self) -> Optional[str]:
+         """
+         Returns the type of the API endpoint. Currently only OpenAI is supported.
+
+         Returns:
+             String representing the API type.
+         """
+         return self._api_type
+
+     @property
+     def api_version(self) -> Optional[str]:
+         """
+         Returns the version of the API endpoint.
+
+         Returns:
+             String representing the API version.
+         """
+         return self._api_version
+
+     @property
+     def config(self) -> Dict:
+         """
+         Returns the configuration for the self hosted llm client.
+         Returns:
+             Dictionary containing the configuration.
+         """
+         config = SelfHostedLLMClientConfig(
+             model=self._model,
+             provider=self._provider,
+             api_base=self._api_base,
+             api_type=self._api_type,
+             api_version=self._api_version,
+             use_chat_completions_endpoint=self._use_chat_completions_endpoint,
+             extra_parameters=self._extra_parameters,
+         )
+         return config.to_dict()
+
+     @property
+     def _litellm_model_name(self) -> str:
+         """Returns the value of LiteLLM's model parameter to be used in
+         completion/acompletion in LiteLLM format:
+
+         <openai>/<model or deployment name>
+         """
+         if self.model and f"{SELF_HOSTED_VLLM_PREFIX}/" not in self.model:
+             return f"{SELF_HOSTED_VLLM_PREFIX}/{self.model}"
+         return self.model
+
+     @property
+     def _litellm_extra_parameters(self) -> Dict[str, Any]:
+         """Returns optional configuration parameters specific
+         to the client provider and deployed model.
+         """
+         return self._extra_parameters
+
+     @property
+     def _completion_fn_args(self) -> Dict[str, Any]:
+         """Returns the completion arguments for invoking a call through
+         LiteLLM's completion functions.
+         """
+         fn_args = super()._completion_fn_args
+         fn_args.update(
+             {
+                 "api_base": self.api_base,
+                 "api_version": self.api_version,
+             }
+         )
+         return fn_args
+
+     @suppress_logs(log_level=logging.WARNING)
+     def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+         """
+         Synchronously generate completions for given prompt.
+
+         Args:
+             prompt: Prompt to generate the completion for.
+         Returns:
+             List of message completions.
+         Raises:
+             ProviderClientAPIException: If the API request fails.
+         """
+         try:
+             response = text_completion(prompt=prompt, **self._completion_fn_args)
+             return self._format_text_completion_response(response)
+         except Exception as e:
+             raise ProviderClientAPIException(e)
+
+     @suppress_logs(log_level=logging.WARNING)
+     async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
+         """
+         Asynchronously generate completions for given prompt.
+
+         Args:
+             prompt: Prompt to generate the completion for.
+         Returns:
+             List of message completions.
+         Raises:
+             ProviderClientAPIException: If the API request fails.
+         """
+         try:
+             response = await atext_completion(prompt=prompt, **self._completion_fn_args)
+             return self._format_text_completion_response(response)
+         except Exception as e:
+             raise ProviderClientAPIException(e)
+
+     async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
+         """Asynchronous completion of the model with the given messages.
+
+         Method overrides the base class method to call the appropriate
+         completion method based on the configuration. If the chat completions
+         endpoint is enabled, the acompletion method is called. Otherwise, the
+         atext_completion method is called.
+
+         Args:
+             messages: The messages to be used for completion.
+
+         Returns:
+             The completion response.
+         """
+         if self._use_chat_completions_endpoint:
+             return await super().acompletion(messages)
+         return await self._atext_completion(messages)
+
+     def completion(self, messages: Union[List[str], str]) -> LLMResponse:
+         """Completion of the model with the given messages.
+
+         Method overrides the base class method to call the appropriate
+         completion method based on the configuration. If the chat completions
+         endpoint is enabled, the completion method is called. Otherwise, the
+         text_completion method is called.
+
+         Args:
+             messages: The messages to be used for completion.
+
+         Returns:
+             The completion response.
+         """
+         if self._use_chat_completions_endpoint:
+             return super().completion(messages)
+         return self._text_completion(messages)
+
+     def _format_text_completion_response(self, response: Any) -> LLMResponse:
+         """Parses the LiteLLM text completion response to Rasa format."""
+         formatted_response = LLMResponse(
+             id=response.id,
+             created=response.created,
+             choices=[choice.text for choice in response.choices],
+             model=response.model,
+         )
+         if (usage := response.usage) is not None:
+             prompt_tokens = (
+                 num_tokens
+                 if isinstance(num_tokens := usage.prompt_tokens, (int, float))
+                 else 0
+             )
+             completion_tokens = (
+                 num_tokens
+                 if isinstance(num_tokens := usage.completion_tokens, (int, float))
+                 else 0
+             )
+             formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
+         structlogger.debug(
+             "base_litellm_client.formatted_response",
+             formatted_response=formatted_response.to_dict(),
+         )
+         return formatted_response
+
+     @staticmethod
+     def _apply_dummy_api_key_if_missing() -> None:
+         if not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR):
+             os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = (
+                 "dummy_self_hosted_llm_api_key"
+             )
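
For context, a minimal usage sketch of the new client. All concrete values below are assumptions for illustration; the accepted config keys are defined by SelfHostedLLMClientConfig.from_dict, which is not part of this hunk.

from rasa.shared.providers.llm.self_hosted_llm_client import SelfHostedLLMClient

# Hypothetical config for a vLLM server exposing an OpenAI-compatible API.
client = SelfHostedLLMClient.from_config(
    {
        "provider": "self-hosted",               # assumed key/value
        "model": "meta-llama/Meta-Llama-3-8B",   # whatever model the server hosts
        "api_base": "http://localhost:8000/v1",  # assumed local endpoint
        "use_chat_completions_endpoint": False,  # route through text_completion
    }
)

# completion() dispatches to the chat or the plain text endpoint
# depending on use_chat_completions_endpoint.
response = client.completion("Hello")
print(response.choices)

Note that no API key has to be set: _apply_dummy_api_key_if_missing writes a placeholder into the vLLM key environment variable so LiteLLM does not reject the call against a keyless self-hosted server.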
rasa/shared/providers/mappings.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Dict, Type, Optional
+
+ from rasa.shared.constants import (
+     AZURE_OPENAI_PROVIDER,
+     HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
+     OPENAI_PROVIDER,
+     SELF_HOSTED_PROVIDER,
+ )
+ from rasa.shared.providers.embedding.azure_openai_embedding_client import (
+     AzureOpenAIEmbeddingClient,
+ )
+ from rasa.shared.providers.embedding.default_litellm_embedding_client import (
+     DefaultLiteLLMEmbeddingClient,
+ )
+ from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
+ from rasa.shared.providers.embedding.huggingface_local_embedding_client import (
+     HuggingFaceLocalEmbeddingClient,
+ )
+ from rasa.shared.providers.embedding.openai_embedding_client import (
+     OpenAIEmbeddingClient,
+ )
+ from rasa.shared.providers.llm.azure_openai_llm_client import AzureOpenAILLMClient
+ from rasa.shared.providers.llm.default_litellm_llm_client import DefaultLiteLLMClient
+ from rasa.shared.providers.llm.llm_client import LLMClient
+ from rasa.shared.providers.llm.openai_llm_client import OpenAILLMClient
+ from rasa.shared.providers.llm.self_hosted_llm_client import SelfHostedLLMClient
+ from rasa.shared.providers._configs.azure_openai_client_config import (
+     AzureOpenAIClientConfig,
+ )
+ from rasa.shared.providers._configs.default_litellm_client_config import (
+     DefaultLiteLLMClientConfig,
+ )
+ from rasa.shared.providers._configs.huggingface_local_embedding_client_config import (
+     HuggingFaceLocalEmbeddingClientConfig,
+ )
+ from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
+ from rasa.shared.providers._configs.self_hosted_llm_client_config import (
+     SelfHostedLLMClientConfig,
+ )
+ from rasa.shared.providers._configs.client_config import ClientConfig
+
+ _provider_to_llm_client_mapping: Dict[str, Type[LLMClient]] = {
+     OPENAI_PROVIDER: OpenAILLMClient,
+     AZURE_OPENAI_PROVIDER: AzureOpenAILLMClient,
+     SELF_HOSTED_PROVIDER: SelfHostedLLMClient,
+ }
+
+ _provider_to_embedding_client_mapping: Dict[str, Type[EmbeddingClient]] = {
+     OPENAI_PROVIDER: OpenAIEmbeddingClient,
+     AZURE_OPENAI_PROVIDER: AzureOpenAIEmbeddingClient,
+     HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER: HuggingFaceLocalEmbeddingClient,
+ }
+
+ _provider_to_client_config_class_mapping: Dict[str, Type] = {
+     OPENAI_PROVIDER: OpenAIClientConfig,
+     AZURE_OPENAI_PROVIDER: AzureOpenAIClientConfig,
+     HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER: HuggingFaceLocalEmbeddingClientConfig,
+     SELF_HOSTED_PROVIDER: SelfHostedLLMClientConfig,
+ }
+
+
+ def get_llm_client_from_provider(provider: Optional[str]) -> Type[LLMClient]:
+     return _provider_to_llm_client_mapping.get(provider, DefaultLiteLLMClient)
+
+
+ def get_embedding_client_from_provider(provider: str) -> Type[EmbeddingClient]:
+     return _provider_to_embedding_client_mapping.get(
+         provider, DefaultLiteLLMEmbeddingClient
+     )
+
+
+ def get_client_config_class_from_provider(provider: str) -> Type[ClientConfig]:
+     return _provider_to_client_config_class_mapping.get(
+         provider, DefaultLiteLLMClientConfig
+     )
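
These lookups deliberately use dict.get with a fallback instead of raising on an unknown provider, so any provider LiteLLM itself supports works without a dedicated mapping entry. A small sketch of the intended use (the provider strings are assumptions; the real values come from the *_PROVIDER constants in rasa.shared.constants):

from rasa.shared.providers.mappings import get_llm_client_from_provider

# A mapped provider resolves to its dedicated client class...
openai_client_class = get_llm_client_from_provider("openai")       # OpenAILLMClient
# ...while anything unmapped falls back to the generic LiteLLM wrapper.
other_client_class = get_llm_client_from_provider("anthropic")     # DefaultLiteLLMClient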
rasa/shared/utils/cli.py CHANGED
@@ -1,3 +1,5 @@
+ import math
+ import shutil
  import sys
  from typing import Any, Text, NoReturn
  
@@ -70,3 +72,31 @@ def print_error_and_exit(message: Text, exit_code: int = 1) -> NoReturn:
      """
      print_error(message)
      sys.exit(exit_code)
+
+
+ def pad(text: Text, char: Text = "=", min: int = 3) -> Text:
+     """Pad text to a certain length.
+
+     Uses `char` to pad the text to the specified length. If the text is longer
+     than the specified length, at least `min` are used.
+
+     The padding is applied to the left and right of the text (almost) equally.
+
+     Example:
+         >>> pad("Hello")
+         "========= Hello ========"
+         >>> pad("Hello", char="-")
+         "--------- Hello --------"
+
+     Args:
+         text: Text to pad.
+         min: Minimum length of the padding.
+         char: Character to pad with.
+
+     Returns:
+         Padded text.
+     """
+     width = shutil.get_terminal_size((80, 20)).columns
+     padding = max(width - len(text) - 2, min * 2)
+
+     return char * (padding // 2) + " " + text + " " + char * math.ceil(padding / 2)
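
A quick sketch of the new helper in action. With the 80-column fallback width and the default min=3, pad("Hello") computes padding = max(80 - 5 - 2, 6) = 73, i.e. 36 "=" on the left and 37 on the right, for an 80-character banner:

from rasa.shared.utils.cli import pad

print(pad("Hello"))              # ===...=== Hello ===...=== (full terminal width)
print(pad("results", char="-"))  # ---...--- results ---...---

Because padding is floored at min * 2, even text longer than the terminal width still gets at least three padding characters on each side.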
rasa/shared/utils/io.py CHANGED
@@ -1,12 +1,15 @@
  from collections import OrderedDict
+ from functools import wraps
+ from hashlib import md5
+ import asyncio
  import errno
  import glob
- from hashlib import md5
  import json
+ import logging
  import os
  import sys
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Text, Type, Union
+ from typing import Any, cast, Callable, Dict, List, Optional, Text, Type, TypeVar, Union
  import warnings
  import random
  import string
@@ -137,6 +140,17 @@ def read_json_file(filename: Union[Text, Path]) -> Any:
      )
  
  
+ def read_jsonl_file(file_path: Union[Text, Path]) -> List[Any]:
+     """Read JSONL from a file."""
+     content = read_file(file_path)
+     try:
+         return [json.loads(line) for line in content.splitlines()]
+     except ValueError as e:
+         raise FileIOException(
+             f"Failed to read JSONL from '{os.path.abspath(file_path)}'. Error: {e}"
+         )
+
+
  def list_directory(path: Text) -> List[Text]:
      """Returns all files and folders excluding hidden files.
  
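The new JSONL reader, presumably in support of the fine-tuning modules added in this release, parses one JSON document per line; an empty or malformed line raises FileIOException, since json.loads raises a ValueError for it. A hypothetical usage:

from rasa.shared.utils.io import read_jsonl_file

# data.jsonl (hypothetical file), one JSON object per line:
#   {"prompt": "hi", "completion": "hello"}
#   {"prompt": "bye", "completion": "goodbye"}
records = read_jsonl_file("data.jsonl")
prompts = [record["prompt"] for record in records]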
@@ -413,3 +427,52 @@ def file_as_bytes(file_path: Text) -> bytes:
      raise FileNotFoundException(
          f"Failed to read file, " f"'{os.path.abspath(file_path)}' does not exist."
      )
+
+
+ F = TypeVar("F", bound=Callable[..., Any])
+
+
+ def suppress_logs(log_level: int = logging.WARNING) -> Callable[[F], F]:
+     """Decorator to suppress logs during the execution of a function.
+
+     Args:
+         log_level: The log level to set during the execution of the function.
+
+     Returns:
+         The decorated function.
+     """
+
+     def decorator(func: F) -> F:
+         @wraps(func)
+         async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
+             # Store the original logging level and set the new level.
+             original_logging_level = logging.getLogger().getEffectiveLevel()
+             logging.getLogger().setLevel(log_level)
+             try:
+                 # Execute the async function.
+                 result = await func(*args, **kwargs)
+             finally:
+                 # Reset the logging level to the original level.
+                 logging.getLogger().setLevel(original_logging_level)
+             return result
+
+         @wraps(func)
+         def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
+             # Store the original logging level and set the new level.
+             original_logging_level = logging.getLogger().getEffectiveLevel()
+             logging.getLogger().setLevel(log_level)
+             try:
+                 # Execute the function.
+                 result = func(*args, **kwargs)
+             finally:
+                 # Reset the logging level to the original level.
+                 logging.getLogger().setLevel(original_logging_level)
+             return result
+
+         # Determine if the function is async or not
+         if asyncio.iscoroutinefunction(func):
+             return cast(F, async_wrapper)
+         else:
+             return cast(F, sync_wrapper)
+
+     return decorator
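
A minimal sketch of the decorator on a synchronous function; it works the same on async def functions via the iscoroutinefunction branch:

import logging
from rasa.shared.utils.io import suppress_logs

@suppress_logs(log_level=logging.ERROR)
def quiet() -> int:
    # WARNING is below the temporary ERROR threshold, so this is swallowed.
    logging.getLogger(__name__).warning("suppressed while quiet() runs")
    return 42

quiet()  # returns 42 without emitting the warning

One caveat worth noting: the decorator raises the level on the root logger, so all logging in the process is affected for the duration of the call, not just the decorated function's own logger.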