rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional, Union
|
|
3
|
+
|
|
4
|
+
import httpx
|
|
5
|
+
import litellm
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
RASA_CA_BUNDLE_ENV_VAR,
|
|
8
|
+
REQUESTS_CA_BUNDLE_ENV_VAR,
|
|
9
|
+
RASA_SSL_CERTIFICATE_ENV_VAR,
|
|
10
|
+
LITELLM_SSL_VERIFY_ENV_VAR,
|
|
11
|
+
LITELLM_SSL_CERTIFICATE_ENV_VAR,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
import structlog
|
|
15
|
+
|
|
16
|
+
from rasa.shared.utils.io import raise_deprecation_warning
|
|
17
|
+
|
|
18
|
+
# Module-level structlog logger used by the SSL configuration helpers below.
structlogger = structlog.get_logger()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def ensure_ssl_certificates_for_litellm_non_openai_based_clients() -> None:
    """Apply environment-derived SSL settings to LiteLLM's global attributes.

    Targets LiteLLM clients that do not go through the `openai` library's
    clients. The verification setting and the client certificate are resolved
    from the environment and written to `litellm.ssl_verify` and
    `litellm.ssl_certificate` respectively. A value that resolves to `None`
    leaves the corresponding LiteLLM attribute untouched.
    """
    verify_setting = _get_ssl_verify()
    client_certificate = _get_ssl_cert()

    structlogger.debug(
        "ensure_ssl_certificates_for_litellm_non_openai_based_clients",
        ssl_verify=verify_setting,
        ssl_certificate=client_certificate,
    )

    resolved_settings = (
        ("ssl_verify", verify_setting),
        ("ssl_certificate", client_certificate),
    )
    for attribute_name, value in resolved_settings:
        if value is None:
            # Nothing configured for this setting; keep LiteLLM's current value.
            continue
        setattr(litellm, attribute_name, value)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def ensure_ssl_certificates_for_litellm_openai_based_clients() -> None:
    """Apply environment-derived SSL settings to LiteLLM's httpx sessions.

    Targets LiteLLM clients that rely on the `openai` library's clients.

    When at least one SSL setting is found in the environment, new httpx
    sessions are built with it and assigned to `litellm.aclient_session` and
    `litellm.client_session`. A session attribute that already holds an httpx
    client of the matching type is left as is.
    """
    verify_setting = _get_ssl_verify()
    client_certificate = _get_ssl_cert()

    structlogger.debug(
        "ensure_ssl_certificates_for_litellm_openai_based_clients",
        ssl_verify=verify_setting,
        ssl_certificate=client_certificate,
    )

    # Only forward the options that were actually configured.
    session_kwargs = {
        option: value
        for option, value in (
            ("verify", verify_setting),
            ("cert", client_certificate),
        )
        if value is not None
    }
    if not session_kwargs:
        return

    if not isinstance(litellm.aclient_session, httpx.AsyncClient):
        litellm.aclient_session = httpx.AsyncClient(**session_kwargs)
    if not isinstance(litellm.client_session, httpx.Client):
        litellm.client_session = httpx.Client(**session_kwargs)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _get_ssl_verify() -> Optional[Union[bool, str]]:
    """Resolve the SSL verification setting from environment variables.

    Environment variable priority (ssl verify):
    1. `RASA_CA_BUNDLE`: Preferred for SSL verification.
    2. `REQUESTS_CA_BUNDLE`: Deprecated; use `RASA_CA_BUNDLE` instead.
    3. `SSL_VERIFY`: Fallback for SSL verification (LiteLLM's own variable).

    A deprecation warning is raised whenever `REQUESTS_CA_BUNDLE` is set,
    regardless of whether it ends up being the value used.

    The fallback variable may carry a boolean switch instead of a path. The
    case-insensitive values `true` / `false` are converted to real booleans so
    that they are not mistaken for a CA bundle path by the HTTP client.

    Returns:
        Path to a CA bundle / self-signed SSL certificate, a boolean when the
        fallback variable holds `true` or `false`, or None if nothing is set.
    """
    rasa_bundle = os.environ.get(RASA_CA_BUNDLE_ENV_VAR)
    requests_bundle = os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR)

    deprecation_notice = (
        "Support of the REQUESTS_CA_BUNDLE environment variable is deprecated and "
        "will be removed in Rasa Pro 4.0.0. Please set the RASA_CA_BUNDLE "
        "environment variable instead."
    )
    if requests_bundle and rasa_bundle:
        raise_deprecation_warning(
            "Both REQUESTS_CA_BUNDLE and RASA_CA_BUNDLE environment variables are set. "
            "RASA_CA_BUNDLE will be used as the SSL verification path.\n"
            + deprecation_notice
        )
    elif requests_bundle:
        raise_deprecation_warning(deprecation_notice)

    if rasa_bundle:
        return rasa_bundle
    if requests_bundle:
        # Deprecated source, still honoured until support is removed.
        return requests_bundle

    # From LiteLLM, use as a fallback.
    litellm_verify = os.environ.get(LITELLM_SSL_VERIFY_ENV_VAR)
    if not litellm_verify:
        return None

    # "true"/"false" are switches, not paths; anything else is passed through.
    boolean_switch = {"true": True, "false": False}.get(
        litellm_verify.strip().lower()
    )
    return litellm_verify if boolean_switch is None else boolean_switch
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _get_ssl_cert() -> Optional[str]:
    """Resolve the client SSL certificate path from environment variables.

    Environment variable priority (ssl certificate):
    1. `RASA_SSL_CERTIFICATE`: Preferred for client certificate.
    2. `SSL_CERTIFICATE`: Fallback for client certificate (from LiteLLM).

    Returns:
        Path to a SSL certificate or None if no SSL certificate is found.
    """
    candidates = (RASA_SSL_CERTIFICATE_ENV_VAR, LITELLM_SSL_CERTIFICATE_ENV_VAR)
    for env_var_name in candidates:
        certificate_path = os.environ.get(env_var_name)
        # Unset and empty variables are both treated as "not configured".
        if certificate_path:
            return certificate_path
    return None
|
|
File without changes
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import Any, Dict, List
|
|
3
|
+
|
|
4
|
+
import litellm
|
|
5
|
+
import logging
|
|
6
|
+
import structlog
|
|
7
|
+
from litellm import aembedding, embedding, validate_environment
|
|
8
|
+
|
|
9
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY
|
|
10
|
+
from rasa.shared.exceptions import (
|
|
11
|
+
ProviderClientAPIException,
|
|
12
|
+
ProviderClientValidationError,
|
|
13
|
+
)
|
|
14
|
+
from rasa.shared.providers._ssl_verification_utils import (
|
|
15
|
+
ensure_ssl_certificates_for_litellm_non_openai_based_clients,
|
|
16
|
+
ensure_ssl_certificates_for_litellm_openai_based_clients,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.providers.embedding.embedding_response import (
|
|
19
|
+
EmbeddingResponse,
|
|
20
|
+
EmbeddingUsage,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.utils.io import suppress_logs
|
|
23
|
+
|
|
24
|
+
# Module-level structlog logger used by the embedding client below.
structlogger = structlog.get_logger()
|
|
25
|
+
|
|
26
|
+
# Key looked up in the dict returned by `litellm.validate_environment` to
# obtain the environment variables that are reported as missing.
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _BaseLiteLLMEmbeddingClient:
    """
    An abstract base class for LiteLLM embedding clients.

    This class defines the interface and common functionality for all clients
    based on LiteLLM.

    The class is made private to prevent it from being part of the
    public-facing interface, as it serves as an internal base class
    for specific implementations of clients that are currently based on
    LiteLLM.

    By keeping it private, we ensure that only the derived, concrete
    implementations are exposed to users, maintaining a cleaner and
    more controlled API surface.
    """

    def __init__(self) -> None:
        # SSL settings must be in place before any request is issued.
        self._ensure_certificates()

    @property
    @abstractmethod
    def config(self) -> dict:
        """Returns the configuration of the embedding client in dict form."""
        pass

    @property
    @abstractmethod
    def _litellm_model_name(self) -> str:
        """Returns the model name in LiteLLM format based on the Provider/API type."""
        pass

    @property
    @abstractmethod
    def _litellm_extra_parameters(self) -> Dict[str, Any]:
        """Returns a dictionary of extra parameters which include model
        parameters as well as LiteLLM specific input parameters.
        By default, this returns an empty dictionary (no extra parameters).
        """
        return {}

    @property
    def _embedding_fn_args(self) -> Dict[str, Any]:
        """Returns the arguments to be passed to the embedding function.

        The extra parameters are spread first, so the `model` entry always
        carries the LiteLLM model name even if the extra parameters contain a
        `model` key of their own.
        """
        return {
            **self._litellm_extra_parameters,
            "model": self._litellm_model_name,
        }

    def validate_client_setup(self) -> None:
        """Perform client validation. By default the environment variables
        are validated and the `api_key` extra parameter is rejected.
        Override this method to add more validation steps.

        Raises:
            ProviderClientValidationError if validation fails.
        """
        self._validate_environment_variables()
        self._validate_api_key_not_in_config()

    def _validate_environment_variables(self) -> None:
        """Validate that the required environment variables are set.

        Raises:
            ProviderClientValidationError: If LiteLLM reports missing
                environment variables for the configured model.
        """
        validation_info = validate_environment(
            self._litellm_model_name,
            api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
        )
        if missing_environment_variables := validation_info.get(
            _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
        ):
            event_info = (
                f"Environment variables: {missing_environment_variables} "
                "not set. Required for API calls."
            )
            structlogger.error(
                "base_litellm_embedding_client.validate_environment_variables",
                event_info=event_info,
                missing_environment_variables=missing_environment_variables,
            )
            raise ProviderClientValidationError(event_info)

    def _validate_api_key_not_in_config(self) -> None:
        """Reject API keys that are passed via the `api_key` extra parameter.

        Raises:
            ProviderClientValidationError: If `api_key` is present in the
                extra parameters.
        """
        if "api_key" in self._litellm_extra_parameters:
            # The two sentences are separate literals; the trailing space keeps
            # them from running together in the final message.
            event_info = (
                "API Key is set through `api_key` extra parameter. "
                "Set API keys through environment variables."
            )
            structlogger.error(
                "base_litellm_client.validate_api_key_not_in_config",
                event_info=event_info,
            )
            raise ProviderClientValidationError(event_info)

    def validate_documents(self, documents: List[str]) -> None:
        """
        Validates a list of documents to ensure they are suitable for embedding.

        Args:
            documents: List of documents to be validated.

        Raises:
            ValueError: If any document is not a string, or is empty or
                whitespace-only.
        """
        for doc in documents:
            if not isinstance(doc, str):
                raise ValueError("All documents must be strings.")
            if not doc.strip():
                raise ValueError("Documents cannot be empty or whitespace.")

    @suppress_logs(log_level=logging.WARNING)
    def embed(self, documents: List[str]) -> EmbeddingResponse:
        """
        Embeds a list of documents synchronously.

        Args:
            documents: List of documents to be embedded.

        Returns:
            The embedding response holding the embedding vectors.

        Raises:
            ValueError: If any document is invalid.
            ProviderClientAPIException: If the API call or the parsing of its
                response raised an error.
        """
        self.validate_documents(documents)
        try:
            response = embedding(input=documents, **self._embedding_fn_args)
            return self._format_response(response)
        except Exception as e:
            raise ProviderClientAPIException(
                message="Failed to embed documents", original_exception=e
            ) from e

    @suppress_logs(log_level=logging.WARNING)
    async def aembed(self, documents: List[str]) -> EmbeddingResponse:
        """
        Embeds a list of documents asynchronously.

        Args:
            documents: List of documents to be embedded.

        Returns:
            The embedding response holding the embedding vectors.

        Raises:
            ValueError: If any document is invalid.
            ProviderClientAPIException: If the API call or the parsing of its
                response raised an error.
        """
        self.validate_documents(documents)
        try:
            response = await aembedding(input=documents, **self._embedding_fn_args)
            return self._format_response(response)
        except Exception as e:
            raise ProviderClientAPIException(
                message="Failed to embed documents", original_exception=e
            ) from e

    def _format_response(
        self, response: litellm.EmbeddingResponse
    ) -> EmbeddingResponse:
        """Parses the LiteLLM EmbeddingResponse to Rasa format.

        Note that `response.data` is sorted in place by each item's "index".

        Raises:
            ValueError: If any response data is None.
        """
        # If data is not available (None), raise a ValueError
        if response.data is None:
            message = "Failed to embed documents. Received 'None' instead of embeddings."
            structlogger.error(
                "base_litellm_client.format_response.data_is_none",
                message=message,
                response=response.to_dict(),
            )
            raise ValueError(message)

        # Sort the embeddings by the "index" key
        response.data.sort(key=lambda x: x["index"])
        # Extract the embedding vectors
        embeddings = [data["embedding"] for data in response.data]
        formatted_response = EmbeddingResponse(
            data=embeddings,
            model=response.model,
        )

        # Process additional usage information if available; counters that
        # the usage object does not expose are reported as 0.
        if response.usage:
            usage = response.usage
            formatted_response.usage = EmbeddingUsage(
                completion_tokens=getattr(usage, "completion_tokens", 0),
                prompt_tokens=getattr(usage, "prompt_tokens", 0),
                total_tokens=getattr(usage, "total_tokens", 0),
            )

        # Log the response with masked data for brevity
        log_response = formatted_response.to_dict()
        log_response["data"] = "Embedding response data not shown here for brevity."
        structlogger.debug(
            "base_litellm_client.formatted_response",
            formatted_response=log_response,
        )
        return formatted_response

    @staticmethod
    def _ensure_certificates() -> None:
        """
        Configures SSL certificates for LiteLLM. This method is invoked during
        client initialization.

        LiteLLM may utilize `openai` clients or other providers that require
        SSL verification settings through the `SSL_VERIFY` / `SSL_CERTIFICATE`
        environment variables or the `litellm.ssl_verify` /
        `litellm.ssl_certificate` global settings.

        This method ensures proper SSL configuration for both cases.
        """
        ensure_ssl_certificates_for_litellm_non_openai_based_clients()
        ensure_ssl_certificates_for_litellm_openai_based_clients()
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from langchain_core.embeddings.embeddings import Embeddings
|
|
4
|
+
|
|
5
|
+
from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class _LangchainEmbeddingClientAdapter(Embeddings):
    """
    Temporary bridge between our embedding clients and LangChain.

    Clients created through `embedder_factory` implement the EmbeddingClient
    protocol, whereas `langchain`'s vector stores expect an `Embeddings`
    instance. This adapter wraps such a client and hands back only the
    `data` part of the client's response.

    This adapter will be removed in ticket:
    https://rasahq.atlassian.net/browse/ENG-1220
    """

    def __init__(self, client: EmbeddingClient):
        # The wrapped client performs the actual embedding calls.
        self._client = client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings.
        """
        return self._client.embed(documents=texts).data

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: Text to embed.

        Returns:
            Embedding.
        """
        # The client works on batches; wrap the query and unwrap the result.
        single_item_batch = [text]
        return self._client.embed(documents=single_item_batch).data[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings.
        """
        client_response = await self._client.aembed(documents=texts)
        return client_response.data

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Args:
            text: Text to embed.

        Returns:
            Embedding.
        """
        # The client works on batches; wrap the query and unwrap the result.
        single_item_batch = [text]
        client_response = await self._client.aembed(documents=single_item_batch)
        return client_response.data[0]
|