rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,124 @@
1
+ import os
2
+ from typing import Optional, Union
3
+
4
+ import httpx
5
+ import litellm
6
+ from rasa.shared.constants import (
7
+ RASA_CA_BUNDLE_ENV_VAR,
8
+ REQUESTS_CA_BUNDLE_ENV_VAR,
9
+ RASA_SSL_CERTIFICATE_ENV_VAR,
10
+ LITELLM_SSL_VERIFY_ENV_VAR,
11
+ LITELLM_SSL_CERTIFICATE_ENV_VAR,
12
+ )
13
+
14
+ import structlog
15
+
16
+ from rasa.shared.utils.io import raise_deprecation_warning
17
+
18
# Module-level structured logger; the SSL helpers in this module emit their
# resolved settings through it at debug level.
structlogger = structlog.get_logger()
19
+
20
+
21
def ensure_ssl_certificates_for_litellm_non_openai_based_clients() -> None:
    """Apply SSL settings from the environment to LiteLLM's global config.

    Intended for providers that LiteLLM does not route through the `openai`
    library's clients. The verification setting is written to
    `litellm.ssl_verify` and the client certificate to
    `litellm.ssl_certificate`. A setting that is not configured in the
    environment leaves the corresponding LiteLLM global untouched.
    """
    verify_setting = _get_ssl_verify()
    client_cert = _get_ssl_cert()

    structlogger.debug(
        "ensure_ssl_certificates_for_litellm_non_openai_based_clients",
        ssl_verify=verify_setting,
        ssl_certificate=client_cert,
    )

    # Only override LiteLLM's globals for values that were actually configured.
    for attribute_name, configured_value in (
        ("ssl_verify", verify_setting),
        ("ssl_certificate", client_cert),
    ):
        if configured_value is not None:
            setattr(litellm, attribute_name, configured_value)
40
+
41
+
42
def ensure_ssl_certificates_for_litellm_openai_based_clients() -> None:
    """Install SSL-configured httpx sessions on LiteLLM for `openai`-based clients.

    Providers that LiteLLM serves through the `openai` library use the shared
    `litellm.aclient_session` (async) and `litellm.client_session` (sync)
    sessions. When the environment configures SSL verification and/or a client
    certificate, each session is replaced by an httpx client built from those
    settings. A session that is already an httpx client of the matching kind is
    left as-is, so a previously configured session is never overwritten.
    """
    verify_setting = _get_ssl_verify()
    client_cert = _get_ssl_cert()

    structlogger.debug(
        "ensure_ssl_certificates_for_litellm_openai_based_clients",
        ssl_verify=verify_setting,
        ssl_certificate=client_cert,
    )

    # httpx constructor keyword arguments, keeping only configured values.
    httpx_kwargs = {
        keyword: configured_value
        for keyword, configured_value in (
            ("verify", verify_setting),
            ("cert", client_cert),
        )
        if configured_value is not None
    }
    if not httpx_kwargs:
        # Nothing configured in the environment: keep LiteLLM's sessions.
        return

    if not isinstance(litellm.aclient_session, httpx.AsyncClient):
        litellm.aclient_session = httpx.AsyncClient(**httpx_kwargs)
    if not isinstance(litellm.client_session, httpx.Client):
        litellm.client_session = httpx.Client(**httpx_kwargs)
71
+
72
+
73
+ def _get_ssl_verify() -> Optional[Union[bool, str]]:
74
+ """
75
+ Environment variable priority (ssl verify):
76
+ 1. `RASA_CA_BUNDLE`: Preferred for SSL verification.
77
+ 2. `REQUESTS_CA_BUNDLE`: Deprecated; use `RASA_CA_BUNDLE_ENV_VAR` instead.
78
+ 3. `SSL_VERIFY`: Fallback for SSL verification.
79
+
80
+ Returns:
81
+ Path to a self-signed SSL certificate or None if no SSL certificate is found.
82
+ """
83
+ if os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) and os.environ.get(
84
+ RASA_CA_BUNDLE_ENV_VAR
85
+ ):
86
+ raise_deprecation_warning(
87
+ "Both REQUESTS_CA_BUNDLE and RASA_CA_BUNDLE environment variables are set. "
88
+ "RASA_CA_BUNDLE will be used as the SSL verification path.\n"
89
+ "Support of the REQUESTS_CA_BUNDLE environment variable is deprecated and "
90
+ "will be removed in Rasa Pro 4.0.0. Please set the RASA_CA_BUNDLE "
91
+ "environment variable instead."
92
+ )
93
+ elif os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR):
94
+ raise_deprecation_warning(
95
+ "Support of the REQUESTS_CA_BUNDLE environment variable is deprecated and "
96
+ "will be removed in Rasa Pro 4.0.0. Please set the RASA_CA_BUNDLE "
97
+ "environment variable instead."
98
+ )
99
+
100
+ return (
101
+ os.environ.get(RASA_CA_BUNDLE_ENV_VAR)
102
+ # Deprecated
103
+ or os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR)
104
+ # From LiteLLM, use as a fallback
105
+ or os.environ.get(LITELLM_SSL_VERIFY_ENV_VAR)
106
+ or None
107
+ )
108
+
109
+
110
def _get_ssl_cert() -> Optional[str]:
    """Look up the client SSL certificate path from the environment.

    Variables are consulted in order; the first non-empty one wins:
      1. `RASA_SSL_CERTIFICATE`: preferred for the client certificate.
      2. `SSL_CERTIFICATE` (LiteLLM's variable): fallback.

    Returns:
        Path to a SSL certificate, or None when no variable provides one.
    """
    for env_var_name in (
        RASA_SSL_CERTIFICATE_ENV_VAR,
        LITELLM_SSL_CERTIFICATE_ENV_VAR,
    ):
        candidate = os.environ.get(env_var_name)
        if candidate:
            return candidate
    return None
File without changes
@@ -0,0 +1,259 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Dict, List
3
+
4
+ import litellm
5
+ import logging
6
+ import structlog
7
+ from litellm import aembedding, embedding, validate_environment
8
+
9
+ from rasa.shared.constants import API_BASE_CONFIG_KEY
10
+ from rasa.shared.exceptions import (
11
+ ProviderClientAPIException,
12
+ ProviderClientValidationError,
13
+ )
14
+ from rasa.shared.providers._ssl_verification_utils import (
15
+ ensure_ssl_certificates_for_litellm_non_openai_based_clients,
16
+ ensure_ssl_certificates_for_litellm_openai_based_clients,
17
+ )
18
+ from rasa.shared.providers.embedding.embedding_response import (
19
+ EmbeddingResponse,
20
+ EmbeddingUsage,
21
+ )
22
+ from rasa.shared.utils.io import suppress_logs
23
+
24
# Module-level structured logger used by the embedding client below.
structlogger = structlog.get_logger()

# Key looked up in the dict returned by `litellm.validate_environment`; its
# value is reported as the list of required-but-unset environment variables.
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
27
+
28
+
29
class _BaseLiteLLMEmbeddingClient:
    """
    An abstract base class for LiteLLM embedding clients.

    This class defines the interface and common functionality for all clients
    based on LiteLLM.

    The class is made private to prevent it from being part of the
    public-facing interface, as it serves as an internal base class
    for specific implementations of clients that are currently based on
    LiteLLM.

    By keeping it private, we ensure that only the derived, concrete
    implementations are exposed to users, maintaining a cleaner and
    more controlled API surface.
    """

    def __init__(self):  # type: ignore
        # Configure LiteLLM's SSL settings from the environment before this
        # client issues any embedding request.
        self._ensure_certificates()

    # NOTE: the class does not derive from `abc.ABC`, so `@abstractmethod`
    # below documents intent but is not enforced at instantiation time.

    @property
    @abstractmethod
    def config(self) -> dict:
        """Returns the configuration of the embedding client in dict form."""
        pass

    @property
    @abstractmethod
    def _litellm_model_name(self) -> str:
        """Returns the model name in LiteLLM format based on the Provider/API type."""
        pass

    @property
    @abstractmethod
    def _litellm_extra_parameters(self) -> Dict[str, Any]:
        """Returns a dictionary of extra parameters which include model
        parameters as well as LiteLLM specific input parameters.
        By default, this returns an empty dictionary (no extra parameters).
        """
        return {}

    @property
    def _embedding_fn_args(self) -> Dict[str, Any]:
        """Returns the arguments to be passed to the embedding function.

        `model` is set after the extra parameters, so a `model` key among the
        extra parameters is overridden by `_litellm_model_name`.
        """
        return {
            **self._litellm_extra_parameters,
            "model": self._litellm_model_name,
        }

    def validate_client_setup(self) -> None:
        """Perform client validation. By default, checks that the required
        environment variables are set and that no API key is passed through
        the extra parameters. Override this method to add more validation steps.

        Raises:
            ProviderClientValidationError if validation fails.
        """
        self._validate_environment_variables()
        self._validate_api_key_not_in_config()

    def _validate_environment_variables(self) -> None:
        """Validate that the required environment variables are set.

        Raises:
            ProviderClientValidationError: If LiteLLM reports missing variables.
        """
        validation_info = validate_environment(
            self._litellm_model_name,
            api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
        )
        if missing_environment_variables := validation_info.get(
            _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
        ):
            event_info = (
                f"Environment variables: {missing_environment_variables} "
                f"not set. Required for API calls."
            )
            structlogger.error(
                "base_litellm_embedding_client.validate_environment_variables",
                event_info=event_info,
                missing_environment_variables=missing_environment_variables,
            )
            raise ProviderClientValidationError(event_info)

    def _validate_api_key_not_in_config(self) -> None:
        """Reject configurations that pass the API key as an extra parameter.

        Raises:
            ProviderClientValidationError: If `api_key` is among the extra
                parameters.
        """
        if "api_key" in self._litellm_extra_parameters:
            # The trailing space keeps the two sentences from running together.
            event_info = (
                "API Key is set through `api_key` extra parameter. "
                "Set API keys through environment variables."
            )
            structlogger.error(
                "base_litellm_client.validate_api_key_not_in_config",
                event_info=event_info,
            )
            raise ProviderClientValidationError(event_info)

    def validate_documents(self, documents: List[str]) -> None:
        """
        Validates a list of documents to ensure they are suitable for embedding.

        Args:
            documents: List of documents to be validated.

        Raises:
            ValueError: If any document is not a string or is empty/whitespace.
        """
        for doc in documents:
            if not isinstance(doc, str):
                raise ValueError("All documents must be strings.")
            if not doc.strip():
                raise ValueError("Documents cannot be empty or whitespace.")

    @suppress_logs(log_level=logging.WARNING)
    def embed(self, documents: List[str]) -> EmbeddingResponse:
        """
        Embeds a list of documents synchronously.

        Args:
            documents: List of documents to be embedded.

        Returns:
            EmbeddingResponse whose `data` holds one embedding vector per
            document, ordered by the provider-reported index.

        Raises:
            ValueError: If a document fails `validate_documents`.
            ProviderClientAPIException: If API calls raised an error.
        """
        self.validate_documents(documents)
        try:
            response = embedding(input=documents, **self._embedding_fn_args)
            return self._format_response(response)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ProviderClientAPIException(
                message="Failed to embed documents", original_exception=e
            ) from e

    @suppress_logs(log_level=logging.WARNING)
    async def aembed(self, documents: List[str]) -> EmbeddingResponse:
        """
        Embeds a list of documents asynchronously.

        Args:
            documents: List of documents to be embedded.

        Returns:
            EmbeddingResponse whose `data` holds one embedding vector per
            document, ordered by the provider-reported index.

        Raises:
            ValueError: If a document fails `validate_documents`.
            ProviderClientAPIException: If API calls raised an error.
        """
        self.validate_documents(documents)
        try:
            response = await aembedding(input=documents, **self._embedding_fn_args)
            return self._format_response(response)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ProviderClientAPIException(
                message="Failed to embed documents", original_exception=e
            ) from e

    def _format_response(
        self, response: litellm.EmbeddingResponse
    ) -> EmbeddingResponse:
        """Parses the LiteLLM EmbeddingResponse to Rasa format.

        Raises:
            ValueError: If any response data is None.
        """

        # If data is not available (None), raise a ValueError
        if response.data is None:
            message = (
                "Failed to embed documents. Received 'None' " "instead of embeddings."
            )
            structlogger.error(
                "base_litellm_client.format_response.data_is_none",
                message=message,
                response=response.to_dict(),
            )
            raise ValueError(message)

        # Order the embeddings by their "index" key without mutating the
        # LiteLLM response object.
        ordered_data = sorted(response.data, key=lambda item: item["index"])
        # Extract the embedding vectors
        embeddings = [item["embedding"] for item in ordered_data]
        formatted_response = EmbeddingResponse(
            data=embeddings,
            model=response.model,
        )

        # Process additional usage information if available; attributes the
        # usage object does not define default to 0.
        usage = response.usage
        if usage:
            formatted_response.usage = EmbeddingUsage(
                completion_tokens=getattr(usage, "completion_tokens", 0),
                prompt_tokens=getattr(usage, "prompt_tokens", 0),
                total_tokens=getattr(usage, "total_tokens", 0),
            )

        # Log the response with masked data for brevity
        log_response = formatted_response.to_dict()
        log_response["data"] = "Embedding response data not shown here for brevity."
        structlogger.debug(
            "base_litellm_client.formatted_response",
            formatted_response=log_response,
        )
        return formatted_response

    @staticmethod
    def _ensure_certificates() -> None:
        """
        Configures SSL certificates for LiteLLM. This method is invoked during
        client initialization.

        LiteLLM may utilize `openai` clients or other providers that require
        SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
        environment variables or the `litellm.ssl_verify` /
        `litellm.ssl_certificate` global settings.

        This method ensures proper SSL configuration for both cases.
        """
        ensure_ssl_certificates_for_litellm_non_openai_based_clients()
        ensure_ssl_certificates_for_litellm_openai_based_clients()
@@ -0,0 +1,74 @@
1
+ from typing import List
2
+
3
+ from langchain_core.embeddings.embeddings import Embeddings
4
+
5
+ from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
6
+
7
+
8
class _LangchainEmbeddingClientAdapter(Embeddings):
    """
    Temporary adapter to bridge differences between LiteLLM and LangChain.

    Clients instantiated with `embedder_factory` follow our new EmbeddingClient
    protocol, but `langchain`'s vector stores require an `Embeddings` type
    client. This adapter extracts and returns the necessary part of the output
    from our LiteLLM-based clients.

    This adapter will be removed in ticket:
    https://rasahq.atlassian.net/browse/ENG-1220
    """

    def __init__(self, client: EmbeddingClient):
        # The wrapped client that actually produces the embeddings.
        self._client = client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents synchronously.

        Args:
            texts: Documents to embed.

        Returns:
            The `data` field of the wrapped client's response, i.e. the list
            of embedding vectors.
        """
        return self._client.embed(documents=texts).data

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text synchronously.

        Args:
            text: Query text to embed.

        Returns:
            The first (and only) vector of the wrapped client's response.
        """
        return self._client.embed(documents=[text]).data[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents asynchronously.

        Args:
            texts: Documents to embed.

        Returns:
            The `data` field of the wrapped client's response, i.e. the list
            of embedding vectors.
        """
        result = await self._client.aembed(documents=texts)
        return result.data

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text asynchronously.

        Args:
            text: Query text to embed.

        Returns:
            The first (and only) vector of the wrapped client's response.
        """
        result = await self._client.aembed(documents=[text])
        return result.data[0]