rasa-pro 3.10.15__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -15
- rasa/core/channels/inspector/index.html +17 -14
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/engine/graph.py +0 -1
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +353 -43
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -4
- rasa/version.py +1 -1
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +14 -12
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +234 -183
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from rasa.shared.constants import (
|
|
6
|
+
API_BASE_CONFIG_KEY,
|
|
7
|
+
API_KEY,
|
|
8
|
+
API_TYPE_CONFIG_KEY,
|
|
9
|
+
API_VERSION_CONFIG_KEY,
|
|
10
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
11
|
+
PROVIDER_CONFIG_KEY,
|
|
12
|
+
MODEL_CONFIG_KEY,
|
|
13
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
14
|
+
MODELS_CONFIG_KEY,
|
|
15
|
+
MODEL_GROUPS_CONFIG_KEY,
|
|
16
|
+
EXTRA_PARAMETERS_KEY,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.providers.mappings import get_client_config_class_from_provider
|
|
19
|
+
|
|
20
|
+
structlogger = structlog.get_logger()


@dataclass
class ModelConfig:
    """Parses a single model entry of a model group config.

    Raises:
        ValueError: If the provider config key is missing in the config.
    """

    provider: str
    model: Optional[str] = None
    deployment: Optional[str] = None
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    api_version: Optional[str] = None
    # Every key that is not one of the named fields above ends up here.
    extra_parameters: dict = field(default_factory=dict)
    # Retained for backward compatibility with older configurations,
    # but intentionally not included in extra_parameters
    api_type: Optional[str] = None

    @classmethod
    def from_dict(cls, config: dict) -> "ModelConfig":
        """Initializes a dataclass from the passed config.

        The provider config param is used to determine the client config class
        to use. The client config class takes care of resolving config aliases
        and throwing deprecation warnings.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            ModelConfig (or an instance of the subclass this is called on).
        """
        # Local import; NOTE(review): presumably avoids a circular import with
        # rasa.shared.utils.llm — confirm before moving it to module level.
        from rasa.shared.utils.llm import get_provider_from_config

        # Get the provider from config, this also infers the provider from
        # deprecated configurations
        provider = get_provider_from_config(config)

        # Retrieve the client configuration class for the specified provider.
        client_config_clazz = get_client_config_class_from_provider(provider)

        # Instantiate the client config object in order to resolve deprecated
        # aliases and throw deprecation warnings.
        client_config_obj = client_config_clazz.from_dict(config)

        # Convert back to a dictionary so the known keys can be extracted.
        client_config = client_config_obj.to_dict()

        # Check for provider after resolving all aliases
        if PROVIDER_CONFIG_KEY not in client_config:
            raise ValueError(
                f"Missing required key '{PROVIDER_CONFIG_KEY}' in "
                f"'{MODELS_CONFIG_KEY}' config."
            )

        # Known keys are popped off so that whatever remains in
        # `client_config` becomes `extra_parameters`. `cls` is used (instead of
        # a hard-coded class name) so subclasses construct themselves.
        return cls(
            provider=client_config.pop(PROVIDER_CONFIG_KEY, None),
            model=client_config.pop(MODEL_CONFIG_KEY, None),
            deployment=client_config.pop(DEPLOYMENT_CONFIG_KEY, None),
            api_type=client_config.pop(API_TYPE_CONFIG_KEY, None),
            api_base=client_config.pop(API_BASE_CONFIG_KEY, None),
            api_key=client_config.pop(API_KEY, None),
            api_version=client_config.pop(API_VERSION_CONFIG_KEY, None),
            extra_parameters=client_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        Extra parameters are merged into the top level and keys whose value is
        ``None`` are dropped.
        """
        d = asdict(self)

        # Extra parameters should also be on the top level
        d.pop(EXTRA_PARAMETERS_KEY, None)
        d.update(self.extra_parameters)

        # Remove keys with None values
        return {key: value for key, value in d.items() if value is not None}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass
class ModelGroupConfig:
    """Parses the models config. The models config is a list of model configs.

    Raises:
        ValueError: If the model group ID is None or if the models list is empty.
    """

    model_group_id: str
    models: List[ModelConfig]

    def __post_init__(self) -> None:
        """Validates the group ID and the models list after construction."""
        if self.model_group_id is None:
            message = "Model group ID cannot be set to None."
            structlogger.error(
                "model_group_config.validation_error",
                message=message,
                model_group_id=self.model_group_id,
            )
            raise ValueError(message)
        if not self.models:
            message = "Models cannot be empty."
            structlogger.error(
                "model_group_config.validation_error",
                message=message,
                model_group_id=self.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "ModelGroupConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, the models entry is
                not a list, or the models list is empty.

        Returns:
            ModelGroupConfig
        """
        if MODELS_CONFIG_KEY not in config:
            raise ValueError(
                f"Missing required key '{MODELS_CONFIG_KEY}' in "
                f"'{MODEL_GROUPS_CONFIG_KEY}' config."
            )

        raw_models = config[MODELS_CONFIG_KEY]
        # A YAML key with no value (`models:`) yields None; treat it as an
        # empty list so `__post_init__` reports the documented ValueError
        # instead of crashing with a TypeError during iteration.
        if raw_models is None:
            raw_models = []
        if not isinstance(raw_models, list):
            raise ValueError(
                f"Key '{MODELS_CONFIG_KEY}' in '{MODEL_GROUPS_CONFIG_KEY}' "
                f"config must be a list of model configs."
            )

        models_config = [
            ModelConfig.from_dict(model_config) for model_config in raw_models
        ]

        return cls(
            model_group_id=config.get(MODEL_GROUP_ID_CONFIG_KEY),
            models=models_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        d = {
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODELS_CONFIG_KEY: [model.to_dict() for model in self.models],
        }
        return d
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
RASA_PROVIDER,
|
|
9
|
+
PROVIDER_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.providers._configs.utils import (
|
|
13
|
+
validate_required_keys,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY, API_BASE_CONFIG_KEY]

structlogger = structlog.get_logger()


@dataclass
class RasaLLMClientConfig:
    """Parses configuration for a Rasa Hosted LiteLLM client.

    Checks that the required keys are present.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
    """

    model: Optional[str]
    api_base: Optional[str]
    # Provider is not used by LiteLLM backend, but we define it here since it's
    # used as switch between different clients.
    provider: str = RASA_PROVIDER

    # All config keys other than those in REQUIRED_KEYS.
    extra_parameters: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Raised in cases of invalid configuration:
                - If any of the required configuration keys are missing.

        Returns:
            RasaLLMClientConfig
        """
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)

        # Everything that is not a required key is passed through unchanged.
        extra_parameters = {k: v for k, v in config.items() if k not in REQUIRED_KEYS}

        return cls(
            model=config.get(MODEL_CONFIG_KEY),
            api_base=config.get(API_BASE_CONFIG_KEY),
            provider=config.get(PROVIDER_CONFIG_KEY, RASA_PROVIDER),
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        Extra parameters are merged into the top level of the result.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d
|
|
@@ -99,3 +99,19 @@ def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
|
|
|
99
99
|
config=config,
|
|
100
100
|
)
|
|
101
101
|
raise ValueError(message)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_provider_prefixed_model_name(provider: str, model: str) -> str:
    """Returns the model name with the provider prefixed.

    The prefix ``"<provider>/"`` is added only if ``model`` does not already
    start with it. A falsy ``model`` (empty string or ``None``) is returned
    unchanged.

    Args:
        provider: The provider of the model.
        model: The model name.

    Returns:
        The model name with the provider prefixed.
    """
    prefix = f"{provider}/"
    # A prefix check (rather than a substring check) ensures that e.g.
    # provider "ai" still prefixes the model "openai/gpt-4".
    if model and not model.startswith(prefix):
        return f"{prefix}{model}"
    return model
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import structlog
|
|
2
|
+
|
|
3
|
+
from rasa.shared.constants import (
|
|
4
|
+
AWS_ACCESS_KEY_ID_ENV_VAR,
|
|
5
|
+
AWS_ACCESS_KEY_ID_CONFIG_KEY,
|
|
6
|
+
AWS_SECRET_ACCESS_KEY_ENV_VAR,
|
|
7
|
+
AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
|
|
8
|
+
AWS_REGION_NAME_ENV_VAR,
|
|
9
|
+
AWS_REGION_NAME_CONFIG_KEY,
|
|
10
|
+
AWS_SESSION_TOKEN_CONFIG_KEY,
|
|
11
|
+
AWS_SESSION_TOKEN_ENV_VAR,
|
|
12
|
+
)
|
|
13
|
+
from rasa.shared.exceptions import ProviderClientValidationError
|
|
14
|
+
from litellm import validate_environment
|
|
15
|
+
from rasa.shared.providers.embedding._base_litellm_embedding_client import (
|
|
16
|
+
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
structlogger = structlog.get_logger()


def validate_aws_setup_for_litellm_clients(
    litellm_model_name: str, litellm_call_kwargs: dict, source_log: str
) -> None:
    """Validates the AWS setup for LiteLLM clients to ensure all required
    environment variables or corresponding call kwargs are set.

    Args:
        litellm_model_name (str): The name of the LiteLLM model being validated.
        litellm_call_kwargs (dict): Additional keyword arguments passed to the client,
            which may include configuration values for AWS credentials.
        source_log (str): The source log identifier for structured logging.

    Raises:
        ProviderClientValidationError: If any required AWS environment variable
            or corresponding configuration key is missing.
    """
    # Mapping of environment variable names to their corresponding config keys
    envs_to_args = {
        AWS_ACCESS_KEY_ID_ENV_VAR: AWS_ACCESS_KEY_ID_CONFIG_KEY,
        AWS_SECRET_ACCESS_KEY_ENV_VAR: AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
        AWS_REGION_NAME_ENV_VAR: AWS_REGION_NAME_CONFIG_KEY,
        AWS_SESSION_TOKEN_ENV_VAR: AWS_SESSION_TOKEN_CONFIG_KEY,
    }

    # Validate the environment setup for the model
    validation_info = validate_environment(litellm_model_name)
    reported_missing = validation_info.get(_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY, [])

    # Filter out missing environment variables that have been set through
    # arguments in extra parameters. Variables without a config-key
    # counterpart can only be satisfied via the environment.
    missing_environment_variables = []
    for missing_env_var in reported_missing:
        config_key = envs_to_args.get(missing_env_var)
        if config_key is None or litellm_call_kwargs.get(config_key) is None:
            missing_environment_variables.append(missing_env_var)

    if not missing_environment_variables:
        return

    missing_environment_details = []
    for missing_env_var in missing_environment_variables:
        config_key = envs_to_args.get(missing_env_var)
        if config_key is None:
            # Avoid rendering "'None' config key" for unmapped variables.
            missing_environment_details.append(
                f"'{missing_env_var}' environment variable"
            )
        else:
            missing_environment_details.append(
                f"'{missing_env_var}' environment variable or "
                f"'{config_key}' config key"
            )

    event_info = (
        "The following environment variables or configuration keys are "
        "missing: "
        f"{', '.join(missing_environment_details)}. "
        "These settings are required for API calls."
    )
    structlogger.error(
        f"{source_log}.validate_aws_environment_variables",
        event_info=event_info,
        missing_environment_variables=missing_environment_variables,
    )
    raise ProviderClientValidationError(event_info)
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from abc import abstractmethod
|
|
2
3
|
from typing import Any, Dict, List
|
|
3
4
|
|
|
4
5
|
import litellm
|
|
5
|
-
import logging
|
|
6
6
|
import structlog
|
|
7
7
|
from litellm import aembedding, embedding, validate_environment
|
|
8
8
|
|
|
9
|
-
from rasa.shared.constants import API_BASE_CONFIG_KEY
|
|
9
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
|
|
10
10
|
from rasa.shared.exceptions import (
|
|
11
11
|
ProviderClientAPIException,
|
|
12
12
|
ProviderClientValidationError,
|
|
@@ -19,7 +19,7 @@ from rasa.shared.providers.embedding.embedding_response import (
|
|
|
19
19
|
EmbeddingResponse,
|
|
20
20
|
EmbeddingUsage,
|
|
21
21
|
)
|
|
22
|
-
from rasa.shared.utils.io import suppress_logs
|
|
22
|
+
from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
|
|
23
23
|
|
|
24
24
|
structlogger = structlog.get_logger()
|
|
25
25
|
|
|
@@ -27,8 +27,7 @@ _VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY = "missing_keys"
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class _BaseLiteLLMEmbeddingClient:
|
|
30
|
-
"""
|
|
31
|
-
An abstract base class for LiteLLM embedding clients.
|
|
30
|
+
"""An abstract base class for LiteLLM embedding clients.
|
|
32
31
|
|
|
33
32
|
This class defines the interface and common functionality for all clients
|
|
34
33
|
based on LiteLLM.
|
|
@@ -83,12 +82,12 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
83
82
|
ProviderClientValidationError if validation fails.
|
|
84
83
|
"""
|
|
85
84
|
self._validate_environment_variables()
|
|
86
|
-
self._validate_api_key_not_in_config()
|
|
87
85
|
|
|
88
86
|
def _validate_environment_variables(self) -> None:
|
|
89
87
|
"""Validate that the required environment variables are set."""
|
|
90
88
|
validation_info = validate_environment(
|
|
91
89
|
self._litellm_model_name,
|
|
90
|
+
api_key=self._litellm_extra_parameters.get(API_KEY),
|
|
92
91
|
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
93
92
|
)
|
|
94
93
|
if missing_environment_variables := validation_info.get(
|
|
@@ -105,21 +104,8 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
105
104
|
)
|
|
106
105
|
raise ProviderClientValidationError(event_info)
|
|
107
106
|
|
|
108
|
-
def _validate_api_key_not_in_config(self) -> None:
|
|
109
|
-
if "api_key" in self._litellm_extra_parameters:
|
|
110
|
-
event_info = (
|
|
111
|
-
"API Key is set through `api_key` extra parameter."
|
|
112
|
-
"Set API keys through environment variables."
|
|
113
|
-
)
|
|
114
|
-
structlogger.error(
|
|
115
|
-
"base_litellm_client.validate_api_key_not_in_config",
|
|
116
|
-
event_info=event_info,
|
|
117
|
-
)
|
|
118
|
-
raise ProviderClientValidationError(event_info)
|
|
119
|
-
|
|
120
107
|
def validate_documents(self, documents: List[str]) -> None:
|
|
121
|
-
"""
|
|
122
|
-
Validates a list of documents to ensure they are suitable for embedding.
|
|
108
|
+
"""Validates a list of documents to ensure they are suitable for embedding.
|
|
123
109
|
|
|
124
110
|
Args:
|
|
125
111
|
documents: List of documents to be validated.
|
|
@@ -135,8 +121,7 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
135
121
|
|
|
136
122
|
@suppress_logs(log_level=logging.WARNING)
|
|
137
123
|
def embed(self, documents: List[str]) -> EmbeddingResponse:
|
|
138
|
-
"""
|
|
139
|
-
Embeds a list of documents synchronously.
|
|
124
|
+
"""Embeds a list of documents synchronously.
|
|
140
125
|
|
|
141
126
|
Args:
|
|
142
127
|
documents: List of documents to be embedded.
|
|
@@ -149,7 +134,8 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
149
134
|
"""
|
|
150
135
|
self.validate_documents(documents)
|
|
151
136
|
try:
|
|
152
|
-
|
|
137
|
+
arguments = resolve_environment_variables(self._embedding_fn_args)
|
|
138
|
+
response = embedding(input=documents, **arguments)
|
|
153
139
|
return self._format_response(response)
|
|
154
140
|
except Exception as e:
|
|
155
141
|
raise ProviderClientAPIException(
|
|
@@ -158,8 +144,7 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
158
144
|
|
|
159
145
|
@suppress_logs(log_level=logging.WARNING)
|
|
160
146
|
async def aembed(self, documents: List[str]) -> EmbeddingResponse:
|
|
161
|
-
"""
|
|
162
|
-
Embeds a list of documents asynchronously.
|
|
147
|
+
"""Embeds a list of documents asynchronously.
|
|
163
148
|
|
|
164
149
|
Args:
|
|
165
150
|
documents: List of documents to be embedded.
|
|
@@ -172,7 +157,8 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
172
157
|
"""
|
|
173
158
|
self.validate_documents(documents)
|
|
174
159
|
try:
|
|
175
|
-
|
|
160
|
+
arguments = resolve_environment_variables(self._embedding_fn_args)
|
|
161
|
+
response = await aembedding(input=documents, **arguments)
|
|
176
162
|
return self._format_response(response)
|
|
177
163
|
except Exception as e:
|
|
178
164
|
raise ProviderClientAPIException(
|
|
@@ -187,7 +173,6 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
187
173
|
Raises:
|
|
188
174
|
ValueError: If any response data is None.
|
|
189
175
|
"""
|
|
190
|
-
|
|
191
176
|
# If data is not available (None), raise a ValueError
|
|
192
177
|
if response.data is None:
|
|
193
178
|
message = (
|
|
@@ -244,8 +229,7 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
244
229
|
|
|
245
230
|
@staticmethod
|
|
246
231
|
def _ensure_certificates() -> None:
|
|
247
|
-
"""
|
|
248
|
-
Configures SSL certificates for LiteLLM. This method is invoked during
|
|
232
|
+
"""Configures SSL certificates for LiteLLM. This method is invoked during
|
|
249
233
|
client initialization.
|
|
250
234
|
|
|
251
235
|
LiteLLM may utilize `openai` clients or other providers that require
|
|
@@ -1,5 +1,6 @@
|
|
|
1
|
-
from typing import Any, Dict, List, Optional
|
|
2
1
|
import os
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
3
4
|
import structlog
|
|
4
5
|
|
|
5
6
|
from rasa.shared.constants import (
|
|
@@ -42,6 +43,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
42
43
|
If not provided, it will be set via environment variable.
|
|
43
44
|
kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
|
|
44
45
|
to the embedding model deployment.
|
|
46
|
+
|
|
45
47
|
Raises:
|
|
46
48
|
ProviderClientValidationError: If validation of the client setup fails.
|
|
47
49
|
DeprecationWarning: If deprecated environment variables are used for
|
|
@@ -60,6 +62,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
60
62
|
super().__init__() # type: ignore
|
|
61
63
|
self._deployment = deployment
|
|
62
64
|
self._model = model
|
|
65
|
+
self._extra_parameters = kwargs or {}
|
|
63
66
|
|
|
64
67
|
# Set api_base with the following priority:
|
|
65
68
|
# parameter -> Azure Env Var -> (deprecated) OpenAI Env Var
|
|
@@ -81,17 +84,55 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
81
84
|
# Litellm does not support use of OPENAI_API_KEY, so we need to map it
|
|
82
85
|
# because of backward compatibility. However, we're first looking at
|
|
83
86
|
# AZURE_API_KEY.
|
|
84
|
-
self.
|
|
85
|
-
OPENAI_API_KEY_ENV_VAR
|
|
86
|
-
)
|
|
87
|
+
self._api_key_env_var = self._resolve_api_key_env_var()
|
|
87
88
|
|
|
88
|
-
self._extra_parameters = kwargs or {}
|
|
89
89
|
self.validate_client_setup()
|
|
90
90
|
|
|
91
|
+
def _resolve_api_key_env_var(self) -> str:
|
|
92
|
+
"""Resolves the environment variable to use for the API key.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
str: The env variable in dollar syntax format to use for the API key.
|
|
96
|
+
"""
|
|
97
|
+
if API_KEY in self._extra_parameters:
|
|
98
|
+
# API key is set to an env var in the config itself
|
|
99
|
+
# in case the model is defined in the endpoints.yml
|
|
100
|
+
return self._extra_parameters[API_KEY]
|
|
101
|
+
|
|
102
|
+
if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
|
|
103
|
+
return "${AZURE_API_KEY}"
|
|
104
|
+
|
|
105
|
+
if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
|
|
106
|
+
# API key can be set through OPENAI_API_KEY too,
|
|
107
|
+
# because of the backward compatibility
|
|
108
|
+
raise_deprecation_warning(
|
|
109
|
+
message=(
|
|
110
|
+
f"Usage of '{OPENAI_API_KEY_ENV_VAR}' environment variable "
|
|
111
|
+
"for setting the API key of "
|
|
112
|
+
"Azure OpenAI client is deprecated and will "
|
|
113
|
+
"be removed in 4.0.0. Please "
|
|
114
|
+
f"use '{AZURE_API_KEY_ENV_VAR}' instead."
|
|
115
|
+
)
|
|
116
|
+
)
|
|
117
|
+
return "${OPENAI_API_KEY}"
|
|
118
|
+
|
|
119
|
+
structlogger.error(
|
|
120
|
+
"azure_openai_embedding_client.api_key_not_set",
|
|
121
|
+
event_info=(
|
|
122
|
+
"API key not set, it is required for API calls. "
|
|
123
|
+
f"Set it either via the environment variable "
|
|
124
|
+
f"'{AZURE_API_KEY_ENV_VAR}' or directly"
|
|
125
|
+
f"via the config key '{API_KEY}'."
|
|
126
|
+
),
|
|
127
|
+
)
|
|
128
|
+
raise ProviderClientValidationError(
|
|
129
|
+
f"Missing required environment variable/config key '{API_KEY}' for "
|
|
130
|
+
f"API calls."
|
|
131
|
+
)
|
|
132
|
+
|
|
91
133
|
@classmethod
|
|
92
134
|
def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAIEmbeddingClient":
|
|
93
|
-
"""
|
|
94
|
-
Initializes the client from given configuration.
|
|
135
|
+
"""Initializes the client from given configuration.
|
|
95
136
|
|
|
96
137
|
Args:
|
|
97
138
|
config (Dict[str, Any]): Configuration.
|
|
@@ -142,8 +183,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
142
183
|
|
|
143
184
|
@property
|
|
144
185
|
def model(self) -> Optional[str]:
|
|
145
|
-
"""
|
|
146
|
-
Returns the name of the model deployed on Azure. If model name is not
|
|
186
|
+
"""Returns the name of the model deployed on Azure. If model name is not
|
|
147
187
|
provided, returns "N/A".
|
|
148
188
|
"""
|
|
149
189
|
return self._model
|
|
@@ -170,8 +210,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
170
210
|
|
|
171
211
|
@property
|
|
172
212
|
def _litellm_extra_parameters(self) -> Dict[str, Any]:
|
|
173
|
-
"""
|
|
174
|
-
Returns the model parameters for the azure openai embedding client.
|
|
213
|
+
"""Returns the model parameters for the azure openai embedding client.
|
|
175
214
|
|
|
176
215
|
Returns:
|
|
177
216
|
Dictionary containing the model parameters.
|
|
@@ -186,7 +225,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
186
225
|
"api_base": self.api_base,
|
|
187
226
|
"api_type": self.api_type,
|
|
188
227
|
"api_version": self.api_version,
|
|
189
|
-
"api_key": self.
|
|
228
|
+
"api_key": self._api_key_env_var,
|
|
190
229
|
}
|
|
191
230
|
|
|
192
231
|
@property
|
|
@@ -197,8 +236,9 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
197
236
|
return self.deployment
|
|
198
237
|
|
|
199
238
|
def validate_client_setup(self) -> None:
|
|
200
|
-
"""Perform client validation.
|
|
201
|
-
|
|
239
|
+
"""Perform client validation.
|
|
240
|
+
|
|
241
|
+
By default, only environment variables are validated.
|
|
202
242
|
|
|
203
243
|
Raises:
|
|
204
244
|
ProviderClientValidationError if validation fails.
|
|
@@ -214,13 +254,6 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
214
254
|
"current_value": self.api_base,
|
|
215
255
|
"new_env_key": AZURE_API_BASE_ENV_VAR,
|
|
216
256
|
},
|
|
217
|
-
{
|
|
218
|
-
"param_name": "API key",
|
|
219
|
-
"config_key": API_KEY,
|
|
220
|
-
"deprecated_env_key": OPENAI_API_KEY_ENV_VAR,
|
|
221
|
-
"current_value": self._api_key,
|
|
222
|
-
"new_env_key": AZURE_API_KEY_ENV_VAR,
|
|
223
|
-
},
|
|
224
257
|
{
|
|
225
258
|
"param_name": "API version",
|
|
226
259
|
"config_key": API_VERSION_CONFIG_KEY,
|
|
@@ -1,8 +1,13 @@
|
|
|
1
1
|
from typing import Any, Dict
|
|
2
2
|
|
|
3
|
+
from rasa.shared.constants import (
|
|
4
|
+
AWS_BEDROCK_PROVIDER,
|
|
5
|
+
AWS_SAGEMAKER_PROVIDER,
|
|
6
|
+
)
|
|
3
7
|
from rasa.shared.providers._configs.default_litellm_client_config import (
|
|
4
8
|
DefaultLiteLLMClientConfig,
|
|
5
9
|
)
|
|
10
|
+
from rasa.shared.providers._utils import validate_aws_setup_for_litellm_clients
|
|
6
11
|
from rasa.shared.providers.embedding._base_litellm_embedding_client import (
|
|
7
12
|
_BaseLiteLLMEmbeddingClient,
|
|
8
13
|
)
|
|
@@ -100,3 +105,22 @@ class DefaultLiteLLMEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
100
105
|
"model": self._litellm_model_name,
|
|
101
106
|
**self._litellm_extra_parameters,
|
|
102
107
|
}
|
|
108
|
+
|
|
109
|
+
def validate_client_setup(self) -> None:
|
|
110
|
+
# TODO: Temporarily disable environment variable validation for AWS setup
|
|
111
|
+
# (Bedrock and SageMaker) until resolved by either:
|
|
112
|
+
# 1. An update from the LiteLLM package addressing the issue.
|
|
113
|
+
# 2. The implementation of a Bedrock client on our end.
|
|
114
|
+
# ---
|
|
115
|
+
# This fix ensures a consistent user experience for Bedrock (and
|
|
116
|
+
# SageMaker) in Rasa by allowing AWS secrets to be provided as extra
|
|
117
|
+
# parameters without triggering validation errors due to missing AWS
|
|
118
|
+
# environment variables.
|
|
119
|
+
if self.provider.lower() in [AWS_BEDROCK_PROVIDER, AWS_SAGEMAKER_PROVIDER]:
|
|
120
|
+
validate_aws_setup_for_litellm_clients(
|
|
121
|
+
self._litellm_model_name,
|
|
122
|
+
self._litellm_extra_parameters,
|
|
123
|
+
"default_litellm_embedding_client",
|
|
124
|
+
)
|
|
125
|
+
else:
|
|
126
|
+
super().validate_client_setup()
|