rasa-pro 3.10.15__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -15
- rasa/core/channels/inspector/index.html +17 -14
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/engine/graph.py +0 -1
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +353 -43
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -4
- rasa/version.py +1 -1
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +14 -12
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +234 -183
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from typing import Any, Dict, List
|
|
2
|
+
import logging
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
6
|
+
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
7
|
+
LiteLLMRouterClientConfig,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.providers.embedding._base_litellm_embedding_client import (
|
|
10
|
+
_BaseLiteLLMEmbeddingClient,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.providers.embedding.embedding_response import EmbeddingResponse
|
|
13
|
+
from rasa.shared.providers.router._base_litellm_router_client import (
|
|
14
|
+
_BaseLiteLLMRouterClient,
|
|
15
|
+
)
|
|
16
|
+
from rasa.shared.utils.io import suppress_logs
|
|
17
|
+
|
|
18
|
+
structlogger = structlog.get_logger()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LiteLLMRouterEmbeddingClient(
|
|
22
|
+
_BaseLiteLLMRouterClient, _BaseLiteLLMEmbeddingClient
|
|
23
|
+
):
|
|
24
|
+
"""A client for interfacing with LiteLLM Router Embedding endpoints.
|
|
25
|
+
|
|
26
|
+
Parameters:
|
|
27
|
+
model_group_id (str): The model group ID.
|
|
28
|
+
model_configurations (List[Dict[str, Any]]): The list of model configurations.
|
|
29
|
+
router_settings (Dict[str, Any]): The router settings.
|
|
30
|
+
kwargs (Optional[Dict[str, Any]]): Additional configuration parameters.
|
|
31
|
+
|
|
32
|
+
Raises:
|
|
33
|
+
ProviderClientValidationError: If validation of the client setup fails.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
model_group_id: str,
|
|
39
|
+
model_configurations: List[Dict[str, Any]],
|
|
40
|
+
router_settings: Dict[str, Any],
|
|
41
|
+
**kwargs: Any,
|
|
42
|
+
):
|
|
43
|
+
super().__init__(
|
|
44
|
+
model_group_id, model_configurations, router_settings, **kwargs
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
@classmethod
|
|
48
|
+
def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterEmbeddingClient":
|
|
49
|
+
"""Instantiates a LiteLLM Router Embedding client from a configuration dict.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
config: (Dict[str, Any]) The configuration dictionary.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
LiteLLMRouterEmbeddingClient: The instantiated LiteLLM Router Embedding client.
|
|
56
|
+
|
|
57
|
+
Raises:
|
|
58
|
+
ValueError: If the configuration is invalid.
|
|
59
|
+
"""
|
|
60
|
+
try:
|
|
61
|
+
client_config = LiteLLMRouterClientConfig.from_dict(config)
|
|
62
|
+
except (KeyError, ValueError) as e:
|
|
63
|
+
message = "Cannot instantiate a client from the passed configuration."
|
|
64
|
+
structlogger.error(
|
|
65
|
+
"litellm_router_llm_client.from_config.error",
|
|
66
|
+
message=message,
|
|
67
|
+
config=config,
|
|
68
|
+
original_error=e,
|
|
69
|
+
)
|
|
70
|
+
raise
|
|
71
|
+
|
|
72
|
+
return cls(
|
|
73
|
+
model_group_id=client_config.model_group_id,
|
|
74
|
+
model_configurations=client_config.litellm_model_list,
|
|
75
|
+
router_settings=client_config.litellm_router_settings,
|
|
76
|
+
**client_config.extra_parameters,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
@suppress_logs(log_level=logging.WARNING)
|
|
80
|
+
def embed(self, documents: List[str]) -> EmbeddingResponse:
|
|
81
|
+
"""
|
|
82
|
+
Embeds a list of documents synchronously.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
documents: List of documents to be embedded.
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
List of embedding vectors.
|
|
89
|
+
|
|
90
|
+
Raises:
|
|
91
|
+
ProviderClientAPIException: If API calls raised an error.
|
|
92
|
+
"""
|
|
93
|
+
self.validate_documents(documents)
|
|
94
|
+
try:
|
|
95
|
+
response = self.router_client.embedding(
|
|
96
|
+
input=documents, **self._embedding_fn_args
|
|
97
|
+
)
|
|
98
|
+
return self._format_response(response)
|
|
99
|
+
except Exception as e:
|
|
100
|
+
raise ProviderClientAPIException(
|
|
101
|
+
message="Failed to embed documents", original_exception=e
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
@suppress_logs(log_level=logging.WARNING)
|
|
105
|
+
async def aembed(self, documents: List[str]) -> EmbeddingResponse:
|
|
106
|
+
"""
|
|
107
|
+
Embeds a list of documents asynchronously.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
documents: List of documents to be embedded.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
List of embedding vectors.
|
|
114
|
+
|
|
115
|
+
Raises:
|
|
116
|
+
ProviderClientAPIException: If API calls raised an error.
|
|
117
|
+
"""
|
|
118
|
+
self.validate_documents(documents)
|
|
119
|
+
try:
|
|
120
|
+
response = await self.router_client.aembedding(
|
|
121
|
+
input=documents, **self._embedding_fn_args
|
|
122
|
+
)
|
|
123
|
+
return self._format_response(response)
|
|
124
|
+
except Exception as e:
|
|
125
|
+
raise ProviderClientAPIException(
|
|
126
|
+
message="Failed to embed documents", original_exception=e
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
@property
|
|
130
|
+
def _embedding_fn_args(self) -> Dict[str, Any]:
|
|
131
|
+
"""Returns the arguments to be passed to the embedding function."""
|
|
132
|
+
return {
|
|
133
|
+
**self._litellm_extra_parameters,
|
|
134
|
+
"model": self._model_group_id,
|
|
135
|
+
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from abc import abstractmethod
|
|
2
3
|
from typing import Dict, List, Any, Union
|
|
3
|
-
import logging
|
|
4
4
|
|
|
5
5
|
import structlog
|
|
6
6
|
from litellm import (
|
|
@@ -9,7 +9,7 @@ from litellm import (
|
|
|
9
9
|
validate_environment,
|
|
10
10
|
)
|
|
11
11
|
|
|
12
|
-
from rasa.shared.constants import API_BASE_CONFIG_KEY
|
|
12
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
|
|
13
13
|
from rasa.shared.exceptions import (
|
|
14
14
|
ProviderClientAPIException,
|
|
15
15
|
ProviderClientValidationError,
|
|
@@ -19,7 +19,7 @@ from rasa.shared.providers._ssl_verification_utils import (
|
|
|
19
19
|
ensure_ssl_certificates_for_litellm_openai_based_clients,
|
|
20
20
|
)
|
|
21
21
|
from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
|
|
22
|
-
from rasa.shared.utils.io import suppress_logs
|
|
22
|
+
from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
|
|
23
23
|
|
|
24
24
|
structlogger = structlog.get_logger()
|
|
25
25
|
|
|
@@ -99,12 +99,12 @@ class _BaseLiteLLMClient:
|
|
|
99
99
|
ProviderClientValidationError if validation fails.
|
|
100
100
|
"""
|
|
101
101
|
self._validate_environment_variables()
|
|
102
|
-
self._validate_api_key_not_in_config()
|
|
103
102
|
|
|
104
103
|
def _validate_environment_variables(self) -> None:
|
|
105
104
|
"""Validate that the required environment variables are set."""
|
|
106
105
|
validation_info = validate_environment(
|
|
107
106
|
self._litellm_model_name,
|
|
107
|
+
api_key=self._litellm_extra_parameters.get(API_KEY),
|
|
108
108
|
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
109
109
|
)
|
|
110
110
|
if missing_environment_variables := validation_info.get(
|
|
@@ -121,18 +121,6 @@ class _BaseLiteLLMClient:
|
|
|
121
121
|
)
|
|
122
122
|
raise ProviderClientValidationError(event_info)
|
|
123
123
|
|
|
124
|
-
def _validate_api_key_not_in_config(self) -> None:
|
|
125
|
-
if "api_key" in self._litellm_extra_parameters:
|
|
126
|
-
event_info = (
|
|
127
|
-
"API Key is set through `api_key` extra parameter."
|
|
128
|
-
"Set API keys through environment variables."
|
|
129
|
-
)
|
|
130
|
-
structlogger.error(
|
|
131
|
-
"base_litellm_client.validate_api_key_not_in_config",
|
|
132
|
-
event_info=event_info,
|
|
133
|
-
)
|
|
134
|
-
raise ProviderClientValidationError(event_info)
|
|
135
|
-
|
|
136
124
|
@suppress_logs(log_level=logging.WARNING)
|
|
137
125
|
def completion(self, messages: Union[List[str], str]) -> LLMResponse:
|
|
138
126
|
"""Synchronously generate completions for given list of messages.
|
|
@@ -149,9 +137,8 @@ class _BaseLiteLLMClient:
|
|
|
149
137
|
"""
|
|
150
138
|
try:
|
|
151
139
|
formatted_messages = self._format_messages(messages)
|
|
152
|
-
response = completion(
|
|
153
|
-
messages=formatted_messages, **self._completion_fn_args
|
|
154
|
-
)
|
|
140
|
+
arguments = resolve_environment_variables(self._completion_fn_args)
|
|
141
|
+
response = completion(messages=formatted_messages, **arguments)
|
|
155
142
|
return self._format_response(response)
|
|
156
143
|
except Exception as e:
|
|
157
144
|
raise ProviderClientAPIException(e)
|
|
@@ -172,9 +159,8 @@ class _BaseLiteLLMClient:
|
|
|
172
159
|
"""
|
|
173
160
|
try:
|
|
174
161
|
formatted_messages = self._format_messages(messages)
|
|
175
|
-
response = await acompletion(
|
|
176
|
-
messages=formatted_messages, **self._completion_fn_args
|
|
177
|
-
)
|
|
162
|
+
arguments = resolve_environment_variables(self._completion_fn_args)
|
|
163
|
+
response = await acompletion(messages=formatted_messages, **arguments)
|
|
178
164
|
return self._format_response(response)
|
|
179
165
|
except Exception as e:
|
|
180
166
|
message = ""
|
|
@@ -235,6 +221,32 @@ class _BaseLiteLLMClient:
|
|
|
235
221
|
)
|
|
236
222
|
return formatted_response
|
|
237
223
|
|
|
224
|
+
def _format_text_completion_response(self, response: Any) -> LLMResponse:
|
|
225
|
+
"""Parses the LiteLLM text completion response to Rasa format."""
|
|
226
|
+
formatted_response = LLMResponse(
|
|
227
|
+
id=response.id,
|
|
228
|
+
created=response.created,
|
|
229
|
+
choices=[choice.text for choice in response.choices],
|
|
230
|
+
model=response.model,
|
|
231
|
+
)
|
|
232
|
+
if (usage := response.usage) is not None:
|
|
233
|
+
prompt_tokens = (
|
|
234
|
+
num_tokens
|
|
235
|
+
if isinstance(num_tokens := usage.prompt_tokens, (int, float))
|
|
236
|
+
else 0
|
|
237
|
+
)
|
|
238
|
+
completion_tokens = (
|
|
239
|
+
num_tokens
|
|
240
|
+
if isinstance(num_tokens := usage.completion_tokens, (int, float))
|
|
241
|
+
else 0
|
|
242
|
+
)
|
|
243
|
+
formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
|
|
244
|
+
structlogger.debug(
|
|
245
|
+
"base_litellm_client.formatted_response",
|
|
246
|
+
formatted_response=formatted_response.to_dict(),
|
|
247
|
+
)
|
|
248
|
+
return formatted_response
|
|
249
|
+
|
|
238
250
|
@staticmethod
|
|
239
251
|
def _ensure_certificates() -> None:
|
|
240
252
|
"""Configures SSL certificates for LiteLLM. This method is invoked during
|
|
@@ -17,6 +17,7 @@ from rasa.shared.constants import (
|
|
|
17
17
|
OPENAI_API_KEY_ENV_VAR,
|
|
18
18
|
AZURE_API_TYPE_ENV_VAR,
|
|
19
19
|
AZURE_OPENAI_PROVIDER,
|
|
20
|
+
API_KEY,
|
|
20
21
|
)
|
|
21
22
|
from rasa.shared.exceptions import ProviderClientValidationError
|
|
22
23
|
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
@@ -29,8 +30,7 @@ structlogger = structlog.get_logger()
|
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
32
|
-
"""
|
|
33
|
-
A client for interfacing with Azure's OpenAI LLM deployments.
|
|
33
|
+
"""A client for interfacing with Azure's OpenAI LLM deployments.
|
|
34
34
|
|
|
35
35
|
Parameters:
|
|
36
36
|
deployment (str): The deployment name.
|
|
@@ -80,11 +80,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
80
80
|
or os.getenv(OPENAI_API_VERSION_ENV_VAR)
|
|
81
81
|
)
|
|
82
82
|
|
|
83
|
-
# API key can be set through OPENAI_API_KEY too,
|
|
84
|
-
# because of the backward compatibility
|
|
85
|
-
self._api_key = os.getenv(AZURE_API_KEY_ENV_VAR) or os.getenv(
|
|
86
|
-
OPENAI_API_KEY_ENV_VAR
|
|
87
|
-
)
|
|
83
|
+
self._api_key_env_var = self._resolve_api_key_env_var()
|
|
88
84
|
|
|
89
85
|
# Not used by LiteLLM, here for backward compatibility
|
|
90
86
|
self._api_type = (
|
|
@@ -117,11 +113,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
117
113
|
"env_var": AZURE_API_VERSION_ENV_VAR,
|
|
118
114
|
"deprecated_var": OPENAI_API_VERSION_ENV_VAR,
|
|
119
115
|
},
|
|
120
|
-
"API Key": {
|
|
121
|
-
"current_value": self._api_key,
|
|
122
|
-
"env_var": AZURE_API_KEY_ENV_VAR,
|
|
123
|
-
"deprecated_var": OPENAI_API_KEY_ENV_VAR,
|
|
124
|
-
},
|
|
125
116
|
}
|
|
126
117
|
|
|
127
118
|
deprecation_warning_message = (
|
|
@@ -154,10 +145,51 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
154
145
|
)
|
|
155
146
|
raise_deprecation_warning(message=message)
|
|
156
147
|
|
|
148
|
+
def _resolve_api_key_env_var(self) -> str:
|
|
149
|
+
"""Resolves the environment variable to use for the API key.
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
str: The env variable in dollar syntax format to use for the API key.
|
|
153
|
+
"""
|
|
154
|
+
if API_KEY in self._extra_parameters:
|
|
155
|
+
# API key is set to an env var in the config itself
|
|
156
|
+
# in case the model is defined in the endpoints.yml
|
|
157
|
+
return self._extra_parameters[API_KEY]
|
|
158
|
+
|
|
159
|
+
if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
|
|
160
|
+
return "${AZURE_API_KEY}"
|
|
161
|
+
|
|
162
|
+
if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
|
|
163
|
+
# API key can be set through OPENAI_API_KEY too,
|
|
164
|
+
# because of the backward compatibility
|
|
165
|
+
raise_deprecation_warning(
|
|
166
|
+
message=(
|
|
167
|
+
f"Usage of '{OPENAI_API_KEY_ENV_VAR}' environment variable "
|
|
168
|
+
"for setting the API key for Azure OpenAI "
|
|
169
|
+
"client is deprecated and will be removed "
|
|
170
|
+
f"in 4.0.0. Please use '{AZURE_API_KEY_ENV_VAR}' "
|
|
171
|
+
"environment variable."
|
|
172
|
+
)
|
|
173
|
+
)
|
|
174
|
+
return "${OPENAI_API_KEY}"
|
|
175
|
+
|
|
176
|
+
structlogger.error(
|
|
177
|
+
"azure_openai_llm_client.api_key_not_set",
|
|
178
|
+
event_info=(
|
|
179
|
+
"API key not set, it is required for API calls. "
|
|
180
|
+
f"Set it either via the environment variable"
|
|
181
|
+
f"'{AZURE_API_KEY_ENV_VAR}' or directly"
|
|
182
|
+
f"via the config key '{API_KEY}'."
|
|
183
|
+
),
|
|
184
|
+
)
|
|
185
|
+
raise ProviderClientValidationError(
|
|
186
|
+
f"Missing required environment variable/config key '{API_KEY}' for "
|
|
187
|
+
f"API calls."
|
|
188
|
+
)
|
|
189
|
+
|
|
157
190
|
@classmethod
|
|
158
191
|
def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAILLMClient":
|
|
159
|
-
"""
|
|
160
|
-
Initializes the client from given configuration.
|
|
192
|
+
"""Initializes the client from given configuration.
|
|
161
193
|
|
|
162
194
|
Args:
|
|
163
195
|
config (Dict[str, Any]): Configuration.
|
|
@@ -212,23 +244,17 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
212
244
|
|
|
213
245
|
@property
|
|
214
246
|
def model(self) -> Optional[str]:
|
|
215
|
-
"""
|
|
216
|
-
Returns the name of the model deployed on Azure.
|
|
217
|
-
"""
|
|
247
|
+
"""Returns the name of the model deployed on Azure."""
|
|
218
248
|
return self._model
|
|
219
249
|
|
|
220
250
|
@property
|
|
221
251
|
def api_base(self) -> Optional[str]:
|
|
222
|
-
"""
|
|
223
|
-
Returns the API base URL for the Azure OpenAI llm client.
|
|
224
|
-
"""
|
|
252
|
+
"""Returns the API base URL for the Azure OpenAI llm client."""
|
|
225
253
|
return self._api_base
|
|
226
254
|
|
|
227
255
|
@property
|
|
228
256
|
def api_version(self) -> Optional[str]:
|
|
229
|
-
"""
|
|
230
|
-
Returns the API version for the Azure OpenAI llm client.
|
|
231
|
-
"""
|
|
257
|
+
"""Returns the API version for the Azure OpenAI llm client."""
|
|
232
258
|
return self._api_version
|
|
233
259
|
|
|
234
260
|
@property
|
|
@@ -261,7 +287,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
261
287
|
{
|
|
262
288
|
"api_base": self.api_base,
|
|
263
289
|
"api_version": self.api_version,
|
|
264
|
-
"api_key": self._api_key,
|
|
290
|
+
"api_key": self._api_key_env_var,
|
|
265
291
|
}
|
|
266
292
|
)
|
|
267
293
|
return fn_args
|
|
@@ -305,11 +331,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
305
331
|
"env_var": None,
|
|
306
332
|
"config_key": DEPLOYMENT_CONFIG_KEY,
|
|
307
333
|
},
|
|
308
|
-
"API Key": {
|
|
309
|
-
"current_value": self._api_key,
|
|
310
|
-
"env_var": AZURE_API_KEY_ENV_VAR,
|
|
311
|
-
"config_key": None,
|
|
312
|
-
},
|
|
313
334
|
}
|
|
314
335
|
|
|
315
336
|
missing_settings = [
|
|
@@ -1,8 +1,13 @@
|
|
|
1
1
|
from typing import Dict, Any
|
|
2
2
|
|
|
3
|
+
from rasa.shared.constants import (
|
|
4
|
+
AWS_BEDROCK_PROVIDER,
|
|
5
|
+
AWS_SAGEMAKER_PROVIDER,
|
|
6
|
+
)
|
|
3
7
|
from rasa.shared.providers._configs.default_litellm_client_config import (
|
|
4
8
|
DefaultLiteLLMClientConfig,
|
|
5
9
|
)
|
|
10
|
+
from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
|
|
6
11
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
7
12
|
|
|
8
13
|
|
|
@@ -82,3 +87,22 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
|
|
|
82
87
|
to the client provider and deployed model.
|
|
83
88
|
"""
|
|
84
89
|
return self._extra_parameters
|
|
90
|
+
|
|
91
|
+
def validate_client_setup(self) -> None:
|
|
92
|
+
# TODO: Temporarily change the environment variable validation for AWS setup
|
|
93
|
+
# (Bedrock and SageMaker) until resolved by either:
|
|
94
|
+
# 1. An update from the LiteLLM package addressing the issue.
|
|
95
|
+
# 2. The implementation of a Bedrock client on our end.
|
|
96
|
+
# ---
|
|
97
|
+
# This fix ensures a consistent user experience for Bedrock (and
|
|
98
|
+
# SageMaker) in Rasa by allowing AWS secrets to be provided as extra
|
|
99
|
+
# parameters without triggering validation errors due to missing AWS
|
|
100
|
+
# environment variables.
|
|
101
|
+
if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
|
|
102
|
+
validate_aws_setup_for_litellm_clients(
|
|
103
|
+
self._litellm_model_name,
|
|
104
|
+
self._litellm_extra_parameters,
|
|
105
|
+
"default_litellm_llm_client",
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
super().validate_client_setup()
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Union
|
|
2
|
+
import logging
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
6
|
+
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
7
|
+
LiteLLMRouterClientConfig,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
10
|
+
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
11
|
+
from rasa.shared.providers.router._base_litellm_router_client import (
|
|
12
|
+
_BaseLiteLLMRouterClient,
|
|
13
|
+
)
|
|
14
|
+
from rasa.shared.utils.io import suppress_logs
|
|
15
|
+
|
|
16
|
+
structlogger = structlog.get_logger()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
    """A client for interfacing with LiteLLM Router LLM endpoints.

    Parameters:
        model_group_id (str): The model group ID.
        model_configurations (List[Dict[str, Any]]): The list of model configurations.
        router_settings (Dict[str, Any]): The router settings.
        kwargs (Optional[Dict[str, Any]]): Additional configuration parameters.

    Raises:
        ProviderClientValidationError: If validation of the client setup fails.
    """

    def __init__(
        self,
        model_group_id: str,
        model_configurations: List[Dict[str, Any]],
        router_settings: Dict[str, Any],
        **kwargs: Any,
    ):
        """Forward all arguments unchanged to the parent initializer."""
        super().__init__(
            model_group_id, model_configurations, router_settings, **kwargs
        )

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterLLMClient":
        """Instantiates a LiteLLM Router LLM client from a configuration dict.

        Args:
            config: (Dict[str, Any]) The configuration dictionary.

        Returns:
            LiteLLMRouterLLMClient: The instantiated LiteLLM Router LLM client.

        Raises:
            KeyError: If parsing the configuration raises a `KeyError`;
                the error is logged and re-raised unchanged.
            ValueError: If parsing the configuration raises a `ValueError`;
                the error is logged and re-raised unchanged.
        """
        try:
            client_config = LiteLLMRouterClientConfig.from_dict(config)
        except (KeyError, ValueError) as e:
            message = "Cannot instantiate a client from the passed configuration."
            # NOTE(review): the complete config dict is attached to the log
            # event - confirm it cannot carry credentials at this point.
            structlogger.error(
                "litellm_router_llm_client.from_config.error",
                message=message,
                config=config,
                original_error=e,
            )
            # Bare `raise` keeps the original exception type and traceback.
            raise

        return cls(
            model_group_id=client_config.model_group_id,
            model_configurations=client_config.litellm_model_list,
            router_settings=client_config.litellm_router_settings,
            use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
            **client_config.extra_parameters,
        )

    @suppress_logs(log_level=logging.WARNING)
    def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
        """Synchronously generate completions for given prompt.

        Args:
            prompt: Prompt to generate the completion for.

        Returns:
            The formatted text completion response as an `LLMResponse`.

        Raises:
            ProviderClientAPIException: If the API request or the response
                formatting fails; the original error is set as its cause.
        """
        try:
            response = self.router_client.text_completion(
                prompt=prompt, **self._completion_fn_args
            )
            return self._format_text_completion_response(response)
        except Exception as e:
            # Chain explicitly so the root cause is kept as `__cause__`.
            raise ProviderClientAPIException(e) from e

    @suppress_logs(log_level=logging.WARNING)
    async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
        """Asynchronously generate completions for given prompt.

        Args:
            prompt: Prompt to generate the completion for.

        Returns:
            The formatted text completion response as an `LLMResponse`.

        Raises:
            ProviderClientAPIException: If the API request or the response
                formatting fails; the original error is set as its cause.
        """
        try:
            response = await self.router_client.atext_completion(
                prompt=prompt, **self._completion_fn_args
            )
            return self._format_text_completion_response(response)
        except Exception as e:
            # Chain explicitly so the root cause is kept as `__cause__`.
            raise ProviderClientAPIException(e) from e

    @suppress_logs(log_level=logging.WARNING)
    def completion(self, messages: Union[List[str], str]) -> LLMResponse:
        """Synchronously generate completions for given list of messages.

        Method overrides the base class method to call the appropriate
        completion method based on the configuration. If the chat completions
        endpoint is enabled, the completion method is called. Otherwise, the
        text_completion method is called.

        Args:
            messages: List of messages or a single message to generate the
                completion for.

        Returns:
            The formatted completion response as an `LLMResponse`.

        Raises:
            ProviderClientAPIException: If the API request or the
                message/response formatting fails; the original error is set
                as its cause.
        """
        # Without the chat completions endpoint the raw messages are sent
        # through the text completion path instead.
        if not self._use_chat_completions_endpoint:
            return self._text_completion(messages)
        try:
            formatted_messages = self._format_messages(messages)
            response = self.router_client.completion(
                messages=formatted_messages, **self._completion_fn_args
            )
            return self._format_response(response)
        except Exception as e:
            # Chain explicitly so the root cause is kept as `__cause__`.
            raise ProviderClientAPIException(e) from e

    @suppress_logs(log_level=logging.WARNING)
    async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
        """Asynchronously generate completions for given list of messages.

        Method overrides the base class method to call the appropriate
        completion method based on the configuration. If the chat completions
        endpoint is enabled, the completion method is called. Otherwise, the
        text_completion method is called.

        Args:
            messages: List of messages or a single message to generate the
                completion for.

        Returns:
            The formatted completion response as an `LLMResponse`.

        Raises:
            ProviderClientAPIException: If the API request or the
                message/response formatting fails; the original error is set
                as its cause.
        """
        # Without the chat completions endpoint the raw messages are sent
        # through the text completion path instead.
        if not self._use_chat_completions_endpoint:
            return await self._atext_completion(messages)
        try:
            formatted_messages = self._format_messages(messages)
            response = await self.router_client.acompletion(
                messages=formatted_messages, **self._completion_fn_args
            )
            return self._format_response(response)
        except Exception as e:
            # Chain explicitly so the root cause is kept as `__cause__`.
            raise ProviderClientAPIException(e) from e

    @property
    def _completion_fn_args(self) -> Dict[str, Any]:
        """Returns the completion arguments for invoking a call through
        LiteLLM's completion functions.

        The extra parameters are spread first, so the "model" entry always
        ends up being the model group ID, even if the extra parameters
        contain a "model" key of their own.
        """
        return {
            **self._litellm_extra_parameters,
            "model": self.model_group_id,
        }
|