rasa-pro 3.9.18__py3-none-any.whl → 3.10.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (190) hide show
  1. README.md +26 -57
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/inspector/dist/index.html +0 -2
  32. rasa/core/channels/inspector/index.html +0 -2
  33. rasa/core/channels/voice_aware/__init__.py +0 -0
  34. rasa/core/channels/voice_aware/jambonz.py +103 -0
  35. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  36. rasa/core/channels/voice_aware/utils.py +20 -0
  37. rasa/core/channels/voice_native/__init__.py +0 -0
  38. rasa/core/constants.py +6 -1
  39. rasa/core/featurizers/single_state_featurizer.py +1 -22
  40. rasa/core/featurizers/tracker_featurizers.py +18 -115
  41. rasa/core/information_retrieval/faiss.py +7 -4
  42. rasa/core/information_retrieval/information_retrieval.py +8 -0
  43. rasa/core/information_retrieval/milvus.py +9 -2
  44. rasa/core/information_retrieval/qdrant.py +1 -1
  45. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  46. rasa/core/nlg/summarize.py +4 -3
  47. rasa/core/policies/enterprise_search_policy.py +100 -44
  48. rasa/core/policies/flows/flow_executor.py +130 -94
  49. rasa/core/policies/intentless_policy.py +52 -28
  50. rasa/core/policies/ted_policy.py +33 -58
  51. rasa/core/policies/unexpected_intent_policy.py +7 -15
  52. rasa/core/processor.py +20 -53
  53. rasa/core/run.py +5 -4
  54. rasa/core/tracker_store.py +8 -4
  55. rasa/core/utils.py +45 -56
  56. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  57. rasa/dialogue_understanding/commands/__init__.py +4 -0
  58. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  59. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  60. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  61. rasa/dialogue_understanding/commands/utils.py +38 -0
  62. rasa/dialogue_understanding/generator/constants.py +10 -3
  63. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  64. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  65. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  66. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  67. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  68. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  69. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  70. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  71. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  72. rasa/e2e_test/assertions.py +1181 -0
  73. rasa/e2e_test/assertions_schema.yml +106 -0
  74. rasa/e2e_test/constants.py +20 -0
  75. rasa/e2e_test/e2e_config.py +220 -0
  76. rasa/e2e_test/e2e_config_schema.yml +26 -0
  77. rasa/e2e_test/e2e_test_case.py +131 -8
  78. rasa/e2e_test/e2e_test_converter.py +363 -0
  79. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  80. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  81. rasa/e2e_test/e2e_test_result.py +26 -6
  82. rasa/e2e_test/e2e_test_runner.py +491 -72
  83. rasa/e2e_test/e2e_test_schema.yml +96 -0
  84. rasa/e2e_test/pykwalify_extensions.py +39 -0
  85. rasa/e2e_test/stub_custom_action.py +70 -0
  86. rasa/e2e_test/utils/__init__.py +0 -0
  87. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  88. rasa/e2e_test/utils/io.py +596 -0
  89. rasa/e2e_test/utils/validation.py +80 -0
  90. rasa/engine/recipes/default_components.py +0 -2
  91. rasa/engine/storage/local_model_storage.py +0 -1
  92. rasa/env.py +9 -0
  93. rasa/keys +1 -0
  94. rasa/llm_fine_tuning/__init__.py +0 -0
  95. rasa/llm_fine_tuning/annotation_module.py +241 -0
  96. rasa/llm_fine_tuning/conversations.py +144 -0
  97. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  98. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  99. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  100. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  101. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  102. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  104. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  105. rasa/llm_fine_tuning/storage.py +174 -0
  106. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  107. rasa/model_training.py +48 -16
  108. rasa/nlu/classifiers/diet_classifier.py +25 -38
  109. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  110. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  111. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  112. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  113. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  114. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  115. rasa/nlu/persistor.py +129 -32
  116. rasa/server.py +45 -10
  117. rasa/shared/constants.py +63 -15
  118. rasa/shared/core/domain.py +15 -12
  119. rasa/shared/core/events.py +28 -2
  120. rasa/shared/core/flows/flow.py +208 -13
  121. rasa/shared/core/flows/flow_path.py +84 -0
  122. rasa/shared/core/flows/flows_list.py +28 -10
  123. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  124. rasa/shared/core/flows/validation.py +112 -25
  125. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  126. rasa/shared/core/trackers.py +6 -0
  127. rasa/shared/core/training_data/visualization.html +2 -2
  128. rasa/shared/exceptions.py +4 -0
  129. rasa/shared/importers/importer.py +60 -11
  130. rasa/shared/importers/remote_importer.py +196 -0
  131. rasa/shared/nlu/constants.py +2 -0
  132. rasa/shared/nlu/training_data/features.py +2 -120
  133. rasa/shared/providers/_configs/__init__.py +0 -0
  134. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  135. rasa/shared/providers/_configs/client_config.py +57 -0
  136. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  137. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  138. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  139. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  140. rasa/shared/providers/_configs/utils.py +101 -0
  141. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  142. rasa/shared/providers/embedding/__init__.py +0 -0
  143. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  144. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  145. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  146. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  147. rasa/shared/providers/embedding/embedding_client.py +90 -0
  148. rasa/shared/providers/embedding/embedding_response.py +41 -0
  149. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  150. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  151. rasa/shared/providers/llm/__init__.py +0 -0
  152. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  153. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  154. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  155. rasa/shared/providers/llm/llm_client.py +76 -0
  156. rasa/shared/providers/llm/llm_response.py +50 -0
  157. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  158. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  159. rasa/shared/providers/mappings.py +75 -0
  160. rasa/shared/utils/cli.py +30 -0
  161. rasa/shared/utils/io.py +65 -3
  162. rasa/shared/utils/llm.py +223 -200
  163. rasa/shared/utils/yaml.py +122 -7
  164. rasa/studio/download.py +19 -13
  165. rasa/studio/train.py +2 -3
  166. rasa/studio/upload.py +2 -3
  167. rasa/telemetry.py +113 -58
  168. rasa/tracing/config.py +2 -3
  169. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  170. rasa/tracing/instrumentation/instrumentation.py +4 -47
  171. rasa/utils/common.py +18 -19
  172. rasa/utils/endpoints.py +7 -4
  173. rasa/utils/io.py +66 -0
  174. rasa/utils/json_utils.py +60 -0
  175. rasa/utils/licensing.py +9 -1
  176. rasa/utils/ml_utils.py +4 -2
  177. rasa/utils/tensorflow/model_data.py +193 -2
  178. rasa/validator.py +196 -1
  179. rasa/version.py +1 -1
  180. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/METADATA +47 -72
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/RECORD +186 -121
  182. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  183. rasa/shared/providers/openai/clients.py +0 -43
  184. rasa/shared/providers/openai/session_handler.py +0 -110
  185. rasa/utils/tensorflow/feature_array.py +0 -366
  186. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  187. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  188. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/NOTICE +0 -0
  189. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/WHEEL +0 -0
  190. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,101 @@
1
+ import structlog
2
+ from rasa.shared.utils.io import raise_deprecation_warning
3
+
4
+ structlogger = structlog.get_logger()
5
+
6
+
7
def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
    """
    Rename deprecated alias keys in a configuration to their standard keys.

    Args:
        config: Dictionary containing the configuration.
        deprecated_alias_mapping: Dictionary mapping each alias to the
            standard key that replaces it.

    Returns:
        A shallow copy of `config` in which every alias that was present has
        been removed and its value stored under the standard key. The input
        dictionary itself is left untouched.
    """
    resolved = config.copy()

    for alias in deprecated_alias_mapping:
        # The alias, not the standard key, decides whether anything happens:
        # the standard key is expected to be present already (it is part of
        # the default component configurations), so when the alias is given
        # its value takes over the standard key.
        if alias not in resolved:
            continue
        resolved[deprecated_alias_mapping[alias]] = resolved.pop(alias)

    return resolved
30
+
31
+
32
def raise_deprecation_warnings(config: dict, deprecated_alias_mapping: dict) -> None:
    """
    Report every deprecated key that is used in the configuration.

    For each alias of `deprecated_alias_mapping` that appears in `config`,
    `raise_deprecation_warning` is called with a message naming the alias
    and the standard key that should be used instead.

    Args:
        config: Dictionary containing the configuration.
        deprecated_alias_mapping: Dictionary mapping deprecated keys to
            their standard keys.

    Note:
        How the warning surfaces (emitted vs. raised) is decided by
        `raise_deprecation_warning`, which is defined outside this module.
    """
    # Lazily pair up each deprecated key in use with its replacement.
    deprecated_in_use = (
        (alias, standard_key)
        for alias, standard_key in deprecated_alias_mapping.items()
        if alias in config
    )

    for alias, standard_key in deprecated_in_use:
        warning_text = (
            f"'{alias}' is deprecated and will be removed in "
            f"4.0.0. Use '{standard_key}' instead."
        )
        raise_deprecation_warning(message=warning_text)
52
+
53
+
54
def validate_required_keys(config: dict, required_keys: list) -> None:
    """
    Check that every required key is present in the configuration.

    Args:
        config: Dictionary containing the configuration.
        required_keys: List of keys that must be present in the config.

    Raises:
        ValueError: If at least one required key is absent. The error is
            also logged before it is raised.
    """
    absent = [name for name in required_keys if name not in config]

    # Nothing is missing: the configuration is valid.
    if not absent:
        return

    message = f"Missing required keys '{absent}' for configuration."
    structlogger.error(
        "validate_required_keys",
        message=message,
        missing_keys=absent,
        config=config,
    )
    raise ValueError(message)
75
+
76
+
77
def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
    """
    Check that the configuration uses none of the forbidden keys.

    Args:
        config: Dictionary containing the configuration.
        forbidden_keys: List of keys that are forbidden in the config.

    Raises:
        ValueError: If at least one forbidden key is present. The error is
            also logged before it is raised.
    """
    # Config keys stay on the left-hand side so the resulting set (and hence
    # its representation inside the message) is built the same way as before.
    offending = set(config) & set(forbidden_keys)

    if not offending:
        return

    message = f"Forbidden keys '{offending}' present " f"in the configuration."
    structlogger.error(
        "validate_forbidden_keys",
        message=message,
        forbidden_keys=offending,
        config=config,
    )
    raise ValueError(message)
@@ -0,0 +1,124 @@
1
+ import os
2
+ from typing import Optional, Union
3
+
4
+ import httpx
5
+ import litellm
6
+ from rasa.shared.constants import (
7
+ RASA_CA_BUNDLE_ENV_VAR,
8
+ REQUESTS_CA_BUNDLE_ENV_VAR,
9
+ RASA_SSL_CERTIFICATE_ENV_VAR,
10
+ LITELLM_SSL_VERIFY_ENV_VAR,
11
+ LITELLM_SSL_CERTIFICATE_ENV_VAR,
12
+ )
13
+
14
+ import structlog
15
+
16
+ from rasa.shared.utils.io import raise_deprecation_warning
17
+
18
+ structlogger = structlog.get_logger()
19
+
20
+
21
def ensure_ssl_certificates_for_litellm_non_openai_based_clients() -> None:
    """
    Apply the SSL settings found in the environment to LiteLLM's globals.

    This covers clients that do not go through the `openai` library: the
    resolved values are written to `litellm.ssl_verify` and
    `litellm.ssl_certificate`. A setting that resolves to `None` leaves the
    corresponding LiteLLM attribute untouched.
    """
    ssl_settings = {
        "ssl_verify": _get_ssl_verify(),
        "ssl_certificate": _get_ssl_cert(),
    }

    structlogger.debug(
        "ensure_ssl_certificates_for_litellm_non_openai_based_clients",
        **ssl_settings,
    )

    # The dictionary keys double as the names of the LiteLLM attributes.
    for attribute, value in ssl_settings.items():
        if value is not None:
            setattr(litellm, attribute, value)
40
+
41
+
42
def ensure_ssl_certificates_for_litellm_openai_based_clients() -> None:
    """
    Apply the SSL settings found in the environment to LiteLLM's HTTP sessions.

    This covers clients that go through the `openai` library. When at least
    one SSL setting is available, `litellm.aclient_session` and
    `litellm.client_session` are each replaced by a freshly configured
    `httpx` client, unless the attribute already holds a client of the
    matching type. Without any SSL setting nothing is changed.
    """
    ssl_verify = _get_ssl_verify()
    ssl_certificate = _get_ssl_cert()

    structlogger.debug(
        "ensure_ssl_certificates_for_litellm_openai_based_clients",
        ssl_verify=ssl_verify,
        ssl_certificate=ssl_certificate,
    )

    # Only forward the options that were actually resolved.
    candidates = (("verify", ssl_verify), ("cert", ssl_certificate))
    client_args = {name: value for name, value in candidates if value is not None}

    if not client_args:
        return

    # Sessions that were set up earlier are kept as they are.
    if not isinstance(litellm.aclient_session, httpx.AsyncClient):
        litellm.aclient_session = httpx.AsyncClient(**client_args)
    if not isinstance(litellm.client_session, httpx.Client):
        litellm.client_session = httpx.Client(**client_args)
71
+
72
+
73
def _get_ssl_verify() -> Optional[Union[bool, str]]:
    """
    Resolve the SSL verification setting from the environment.

    Environment variable priority (ssl verify):
    1. `RASA_CA_BUNDLE`: Preferred for SSL verification.
    2. `REQUESTS_CA_BUNDLE`: Deprecated; use `RASA_CA_BUNDLE` instead.
    3. `SSL_VERIFY`: Fallback for SSL verification (LiteLLM's variable).

    A deprecation warning is issued whenever `REQUESTS_CA_BUNDLE` is set.

    Returns:
        The value of the first variable above that is set to a non-empty
        string, or None if none of them is. For the `SSL_VERIFY` fallback the
        values "true" / "false" (case-insensitive) are converted to the
        corresponding boolean, so that verification can be switched on or off
        instead of the text being treated as a CA bundle path; any other
        value is returned unchanged.
    """
    rasa_ca_bundle = os.environ.get(RASA_CA_BUNDLE_ENV_VAR)
    requests_ca_bundle = os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR)

    if requests_ca_bundle and rasa_ca_bundle:
        raise_deprecation_warning(
            "Both REQUESTS_CA_BUNDLE and RASA_CA_BUNDLE environment variables are set. "
            "RASA_CA_BUNDLE will be used as the SSL verification path.\n"
            "Support of the REQUESTS_CA_BUNDLE environment variable is deprecated and "
            "will be removed in Rasa Pro 4.0.0. Please set the RASA_CA_BUNDLE "
            "environment variable instead."
        )
    elif requests_ca_bundle:
        raise_deprecation_warning(
            "Support of the REQUESTS_CA_BUNDLE environment variable is deprecated and "
            "will be removed in Rasa Pro 4.0.0. Please set the RASA_CA_BUNDLE "
            "environment variable instead."
        )

    if rasa_ca_bundle:
        return rasa_ca_bundle

    # Deprecated
    if requests_ca_bundle:
        return requests_ca_bundle

    # From LiteLLM, use as a fallback
    litellm_ssl_verify = os.environ.get(LITELLM_SSL_VERIFY_ENV_VAR)
    if not litellm_ssl_verify:
        return None

    # Environment values are always strings. A literal "false" / "true" is
    # meant as a flag, not as a file path, so hand it on as a real boolean.
    normalized = litellm_ssl_verify.strip().lower()
    if normalized == "true":
        return True
    if normalized == "false":
        return False
    return litellm_ssl_verify
108
+
109
+
110
def _get_ssl_cert() -> Optional[str]:
    """
    Resolve the client SSL certificate setting from the environment.

    Environment variable priority (ssl certificate):
    1. `RASA_SSL_CERTIFICATE`: Preferred for client certificate.
    2. `SSL_CERTIFICATE`: Fallback for client certificate (LiteLLM's variable).

    Returns:
        The value of the first variable above that is set to a non-empty
        string, or None if neither is.
    """
    lookup_order = (
        RASA_SSL_CERTIFICATE_ENV_VAR,
        # From LiteLLM, use as a fallback
        LITELLM_SSL_CERTIFICATE_ENV_VAR,
    )
    for env_var in lookup_order:
        certificate = os.environ.get(env_var)
        if certificate:
            return certificate
    return None
File without changes
@@ -0,0 +1,254 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Dict, List
3
+
4
+ import litellm
5
+ import logging
6
+ import structlog
7
+ from litellm import aembedding, embedding, validate_environment
8
+ from rasa.shared.exceptions import (
9
+ ProviderClientAPIException,
10
+ ProviderClientValidationError,
11
+ )
12
+ from rasa.shared.providers._ssl_verification_utils import (
13
+ ensure_ssl_certificates_for_litellm_non_openai_based_clients,
14
+ ensure_ssl_certificates_for_litellm_openai_based_clients,
15
+ )
16
+ from rasa.shared.providers.embedding.embedding_response import (
17
+ EmbeddingResponse,
18
+ EmbeddingUsage,
19
+ )
20
+ from rasa.shared.utils.io import suppress_logs
21
+
22
+ structlogger = structlog.get_logger()
23
+
24
+ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
25
+
26
+
27
class _BaseLiteLLMEmbeddingClient:
    """
    An abstract base class for LiteLLM embedding clients.

    This class defines the interface and common functionality for all clients
    based on LiteLLM.

    The class is made private to prevent it from being part of the
    public-facing interface, as it serves as an internal base class
    for specific implementations of clients that are currently based on
    LiteLLM.

    By keeping it private, we ensure that only the derived, concrete
    implementations are exposed to users, maintaining a cleaner and
    more controlled API surface.
    """

    def __init__(self):  # type: ignore
        self._ensure_certificates()

    @property
    @abstractmethod
    def config(self) -> dict:
        """Returns the configuration of the embedding client in dict form."""
        pass

    @property
    @abstractmethod
    def _litellm_model_name(self) -> str:
        """Returns the model name in LiteLLM format based on the Provider/API type."""
        pass

    @property
    @abstractmethod
    def _litellm_extra_parameters(self) -> Dict[str, Any]:
        """Returns a dictionary of extra parameters which include model
        parameters as well as LiteLLM specific input parameters.
        By default, this returns an empty dictionary (no extra parameters).
        """
        return {}

    @property
    def _embedding_fn_args(self) -> Dict[str, Any]:
        """Returns the arguments to be passed to the embedding function.

        The extra parameters are spread first, so the `model` entry always
        holds the LiteLLM model name of this client.
        """
        return {
            **self._litellm_extra_parameters,
            "model": self._litellm_model_name,
        }

    def validate_client_setup(self) -> None:
        """Perform client validation. By default the environment variables
        are validated and the extra parameters are checked for an API key.
        Override this method to add more validation steps.

        Raises:
            ProviderClientValidationError if validation fails.
        """
        self._validate_environment_variables()
        self._validate_api_key_not_in_config()

    def _validate_environment_variables(self) -> None:
        """Validate that the required environment variables are set.

        Raises:
            ProviderClientValidationError: If LiteLLM's environment validation
                reports missing keys for the model of this client.
        """
        validation_info = validate_environment(self._litellm_model_name)
        if missing_environment_variables := validation_info.get(
            _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
        ):
            event_info = (
                f"Environment variables: {missing_environment_variables} "
                f"not set. Required for API calls."
            )
            structlogger.error(
                "base_litellm_embedding_client.validate_environment_variables",
                event_info=event_info,
                missing_environment_variables=missing_environment_variables,
            )
            raise ProviderClientValidationError(event_info)

    def _validate_api_key_not_in_config(self) -> None:
        """Validate that no API key is passed through the extra parameters.

        Raises:
            ProviderClientValidationError: If `api_key` is one of the extra
                parameters.
        """
        if "api_key" in self._litellm_extra_parameters:
            # The two sentences are separate literals; the trailing space
            # keeps them from running together in the final message.
            event_info = (
                "API Key is set through `api_key` extra parameter. "
                "Set API keys through environment variables."
            )
            structlogger.error(
                "base_litellm_client.validate_api_key_not_in_config",
                event_info=event_info,
            )
            raise ProviderClientValidationError(event_info)

    def validate_documents(self, documents: List[str]) -> None:
        """
        Validates a list of documents to ensure they are suitable for embedding.

        Args:
            documents: List of documents to be validated.

        Raises:
            ValueError: If any document is not a string, or is empty or
                consists of whitespace only.
        """
        for doc in documents:
            if not isinstance(doc, str):
                raise ValueError("All documents must be strings.")
            if not doc.strip():
                raise ValueError("Documents cannot be empty or whitespace.")

    @suppress_logs(log_level=logging.WARNING)
    def embed(self, documents: List[str]) -> EmbeddingResponse:
        """
        Embeds a list of documents synchronously.

        Args:
            documents: List of documents to be embedded.

        Returns:
            The embedding response holding the embedding vectors.

        Raises:
            ValueError: If the documents fail validation.
            ProviderClientAPIException: If the API call or the processing of
                its response raised an error.
        """
        self.validate_documents(documents)
        try:
            response = embedding(input=documents, **self._embedding_fn_args)
            return self._format_response(response)
        except Exception as e:
            # Chain explicitly so the original failure is shown as the cause.
            raise ProviderClientAPIException(
                message="Failed to embed documents", original_exception=e
            ) from e

    @suppress_logs(log_level=logging.WARNING)
    async def aembed(self, documents: List[str]) -> EmbeddingResponse:
        """
        Embeds a list of documents asynchronously.

        Args:
            documents: List of documents to be embedded.

        Returns:
            The embedding response holding the embedding vectors.

        Raises:
            ValueError: If the documents fail validation.
            ProviderClientAPIException: If the API call or the processing of
                its response raised an error.
        """
        self.validate_documents(documents)
        try:
            response = await aembedding(input=documents, **self._embedding_fn_args)
            return self._format_response(response)
        except Exception as e:
            # Chain explicitly so the original failure is shown as the cause.
            raise ProviderClientAPIException(
                message="Failed to embed documents", original_exception=e
            ) from e

    def _format_response(
        self, response: litellm.EmbeddingResponse
    ) -> EmbeddingResponse:
        """Parses the LiteLLM EmbeddingResponse to Rasa format.

        Raises:
            ValueError: If any response data is None.
        """

        # If data is not available (None), raise a ValueError
        if response.data is None:
            message = "Failed to embed documents. Received 'None' instead of embeddings."
            structlogger.error(
                "base_litellm_client.format_response.data_is_none",
                message=message,
                response=response.to_dict(),
            )
            raise ValueError(message)

        # Sort the embeddings (in place) by the "index" key
        response.data.sort(key=lambda x: x["index"])
        # Extract the embedding vectors
        embeddings = [data["embedding"] for data in response.data]
        formatted_response = EmbeddingResponse(
            data=embeddings,
            model=response.model,
        )

        # Process additional usage information if available; a counter that
        # the usage object does not carry is reported as 0.
        if response.usage:
            formatted_response.usage = EmbeddingUsage(
                completion_tokens=getattr(response.usage, "completion_tokens", 0),
                prompt_tokens=getattr(response.usage, "prompt_tokens", 0),
                total_tokens=getattr(response.usage, "total_tokens", 0),
            )

        # Log the response with masked data for brevity
        log_response = formatted_response.to_dict()
        log_response["data"] = "Embedding response data not shown here for brevity."
        structlogger.debug(
            "base_litellm_client.formatted_response",
            formatted_response=log_response,
        )
        return formatted_response

    @staticmethod
    def _ensure_certificates() -> None:
        """
        Configures SSL certificates for LiteLLM. This method is invoked during
        client initialization.

        LiteLLM may utilize `openai` clients or other providers that require
        SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
        environment variables or the `litellm.ssl_verify` /
        `litellm.ssl_certificate` global settings.

        This method ensures proper SSL configuration for both cases.
        """
        ensure_ssl_certificates_for_litellm_non_openai_based_clients()
        ensure_ssl_certificates_for_litellm_openai_based_clients()
@@ -0,0 +1,74 @@
1
+ from typing import List
2
+
3
+ from langchain_core.embeddings.embeddings import Embeddings
4
+
5
+ from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
6
+
7
+
8
class _LangchainEmbeddingClientAdapter(Embeddings):
    """
    Temporary bridge between our LiteLLM-based clients and LangChain.

    Clients instantiated with `embedder_factory` follow our new EmbeddingClient
    protocol, whereas `langchain`'s vector stores expect an `Embeddings` type
    client. This adapter wraps such a client and hands back only the part of
    its output that LangChain needs, i.e. the embedding vectors.

    This adapter will be removed in ticket:
    https://rasahq.atlassian.net/browse/ENG-1220
    """

    def __init__(self, client: EmbeddingClient):
        # The wrapped client that performs the actual embedding calls.
        self._client = client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings.
        """
        return self._client.embed(documents=texts).data

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: Text to embed.

        Returns:
            Embedding.
        """
        # The query is sent as a one-element batch; unwrap its only vector.
        return self._client.embed(documents=[text]).data[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings.
        """
        result = await self._client.aembed(documents=texts)
        return result.data

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Args:
            text: Text to embed.

        Returns:
            Embedding.
        """
        # The query is sent as a one-element batch; unwrap its only vector.
        result = await self._client.aembed(documents=[text])
        return result.data[0]