rasa-pro 3.10.15__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -15
- rasa/core/channels/inspector/index.html +17 -14
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/engine/graph.py +0 -1
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +353 -43
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -4
- rasa/version.py +1 -1
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +14 -12
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +234 -183
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from rasa.shared.constants import (
|
|
6
|
+
RASA_PROVIDER,
|
|
7
|
+
OPENAI_PROVIDER,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.providers._configs.rasa_llm_client_config import (
|
|
10
|
+
RasaLLMClientConfig,
|
|
11
|
+
)
|
|
12
|
+
from rasa.utils.licensing import retrieve_license_from_env
|
|
13
|
+
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
structlogger = structlog.get_logger()


class RasaLLMClient(_BaseLiteLLMClient):
    """A client for interfacing with a Rasa-hosted LLM endpoint.

    Calls are made through LiteLLM's OpenAI-compatible provider (the model name
    is prefixed with the OpenAI provider name) against ``api_base``. The Rasa
    license retrieved from the environment is sent as the API key.

    Parameters:
        model (str): The model or deployment name.
        api_base (str): The base URL of the API endpoint.
        kwargs: Any: Additional configuration parameters that can include, but
            are not limited to, model parameters and LiteLLM-specific
            parameters. These parameters will be passed to the
            completion/acompletion calls. See LiteLLM's completion input
            documentation for the supported options.

    Raises:
        ProviderClientValidationError: If validation of the client setup fails.
        ProviderClientAPIException: If the API request fails.
    """

    def __init__(
        self,
        model: str,
        api_base: str,
        **kwargs: Any,
    ):
        super().__init__()  # type: ignore
        self._model = model
        self._api_base = api_base
        # Hard-coded: this client exposes no option to switch away from the
        # chat completions endpoint.
        self._use_chat_completions_endpoint = True
        self._extra_parameters = kwargs or {}

    @property
    def model(self) -> str:
        """Returns the model or deployment name."""
        return self._model

    @property
    def api_base(self) -> Optional[str]:
        """Returns the base API URL for the Rasa-hosted LLM client."""
        return self._api_base

    @property
    def provider(self) -> str:
        """Returns the provider name for the Rasa-hosted LLM client.

        Returns:
            String representing the provider name.
        """
        return RASA_PROVIDER

    @property
    def _litellm_model_name(self) -> str:
        """Returns the model name in LiteLLM format.

        The Rasa-hosted endpoint is addressed through LiteLLM's OpenAI provider,
        so the model is prefixed with ``OPENAI_PROVIDER`` rather than
        ``RASA_PROVIDER``.
        """
        return f"{OPENAI_PROVIDER}/{self._model}"

    @property
    def _litellm_extra_parameters(self) -> Dict[str, Any]:
        """Returns the extra parameters forwarded to LiteLLM."""
        return self._extra_parameters

    @property
    def config(self) -> dict:
        """Returns the client configuration as a dictionary."""
        return RasaLLMClientConfig(
            model=self._model,
            api_base=self._api_base,
            extra_parameters=self._extra_parameters,
        ).to_dict()

    @property
    def _completion_fn_args(self) -> Dict[str, Any]:
        """Returns the completion arguments for invoking a call through
        LiteLLM's completion functions.

        Extends the base class arguments with this client's ``api_base``. The
        Rasa license read from the environment is used as ``api_key``; it is
        re-read every time this property is accessed.
        """
        fn_args = super()._completion_fn_args
        fn_args.update(
            {"api_base": self.api_base, "api_key": retrieve_license_from_env()}
        )
        return fn_args

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "RasaLLMClient":
        """Instantiates a Rasa-hosted LLM client from a configuration dict.

        Args:
            config: The configuration dictionary.

        Returns:
            The instantiated client.

        Raises:
            ValueError: If the configuration is invalid. The error is logged
                and then re-raised unchanged.
        """
        try:
            client_config = RasaLLMClientConfig.from_dict(config)
        except ValueError as e:
            message = "Cannot instantiate a client from the passed configuration."
            structlogger.error(
                "rasa_llm_client.from_config.error",
                message=message,
                config=config,
                original_error=e,
            )
            raise
        return cls(
            model=client_config.model,
            api_base=client_config.api_base,
            **client_config.extra_parameters,
        )
|
|
@@ -10,13 +10,14 @@ import structlog
|
|
|
10
10
|
from rasa.shared.constants import (
|
|
11
11
|
SELF_HOSTED_VLLM_PREFIX,
|
|
12
12
|
SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
|
|
13
|
+
API_KEY,
|
|
13
14
|
)
|
|
14
15
|
from rasa.shared.providers._configs.self_hosted_llm_client_config import (
|
|
15
16
|
SelfHostedLLMClientConfig,
|
|
16
17
|
)
|
|
17
18
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
18
19
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
19
|
-
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
20
|
+
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
20
21
|
from rasa.shared.utils.io import suppress_logs
|
|
21
22
|
|
|
22
23
|
structlogger = structlog.get_logger()
|
|
@@ -61,7 +62,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
61
62
|
self._api_version = api_version
|
|
62
63
|
self._use_chat_completions_endpoint = use_chat_completions_endpoint
|
|
63
64
|
self._extra_parameters = kwargs or {}
|
|
64
|
-
self.
|
|
65
|
+
if self._extra_parameters.get(API_KEY) is None:
|
|
66
|
+
self._apply_dummy_api_key_if_missing()
|
|
65
67
|
|
|
66
68
|
@classmethod
|
|
67
69
|
def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
|
|
@@ -160,7 +162,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
160
162
|
"""Returns the value of LiteLLM's model parameter to be used in
|
|
161
163
|
completion/acompletion in LiteLLM format:
|
|
162
164
|
|
|
163
|
-
<
|
|
165
|
+
<hosted_vllm>/<model or deployment name>
|
|
164
166
|
"""
|
|
165
167
|
if self.model and f"{SELF_HOSTED_VLLM_PREFIX}/" not in self.model:
|
|
166
168
|
return f"{SELF_HOSTED_VLLM_PREFIX}/{self.model}"
|
|
@@ -259,32 +261,6 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
259
261
|
return super().completion(messages)
|
|
260
262
|
return self._text_completion(messages)
|
|
261
263
|
|
|
262
|
-
def _format_text_completion_response(self, response: Any) -> LLMResponse:
|
|
263
|
-
"""Parses the LiteLLM text completion response to Rasa format."""
|
|
264
|
-
formatted_response = LLMResponse(
|
|
265
|
-
id=response.id,
|
|
266
|
-
created=response.created,
|
|
267
|
-
choices=[choice.text for choice in response.choices],
|
|
268
|
-
model=response.model,
|
|
269
|
-
)
|
|
270
|
-
if (usage := response.usage) is not None:
|
|
271
|
-
prompt_tokens = (
|
|
272
|
-
num_tokens
|
|
273
|
-
if isinstance(num_tokens := usage.prompt_tokens, (int, float))
|
|
274
|
-
else 0
|
|
275
|
-
)
|
|
276
|
-
completion_tokens = (
|
|
277
|
-
num_tokens
|
|
278
|
-
if isinstance(num_tokens := usage.completion_tokens, (int, float))
|
|
279
|
-
else 0
|
|
280
|
-
)
|
|
281
|
-
formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
|
|
282
|
-
structlogger.debug(
|
|
283
|
-
"base_litellm_client.formatted_response",
|
|
284
|
-
formatted_response=formatted_response.to_dict(),
|
|
285
|
-
)
|
|
286
|
-
return formatted_response
|
|
287
|
-
|
|
288
264
|
@staticmethod
|
|
289
265
|
def _apply_dummy_api_key_if_missing() -> None:
|
|
290
266
|
if not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR):
|
|
@@ -5,6 +5,8 @@ from rasa.shared.constants import (
|
|
|
5
5
|
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
|
|
6
6
|
OPENAI_PROVIDER,
|
|
7
7
|
SELF_HOSTED_PROVIDER,
|
|
8
|
+
RASA_PROVIDER,
|
|
9
|
+
SELF_HOSTED_VLLM_PREFIX,
|
|
8
10
|
)
|
|
9
11
|
from rasa.shared.providers.embedding.azure_openai_embedding_client import (
|
|
10
12
|
AzureOpenAIEmbeddingClient,
|
|
@@ -24,6 +26,7 @@ from rasa.shared.providers.llm.default_litellm_llm_client import DefaultLiteLLMC
|
|
|
24
26
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
25
27
|
from rasa.shared.providers.llm.openai_llm_client import OpenAILLMClient
|
|
26
28
|
from rasa.shared.providers.llm.self_hosted_llm_client import SelfHostedLLMClient
|
|
29
|
+
from rasa.shared.providers.llm.rasa_llm_client import RasaLLMClient
|
|
27
30
|
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
28
31
|
AzureOpenAIClientConfig,
|
|
29
32
|
)
|
|
@@ -37,12 +40,15 @@ from rasa.shared.providers._configs.openai_client_config import OpenAIClientConf
|
|
|
37
40
|
from rasa.shared.providers._configs.self_hosted_llm_client_config import (
|
|
38
41
|
SelfHostedLLMClientConfig,
|
|
39
42
|
)
|
|
43
|
+
from rasa.shared.providers._configs.rasa_llm_client_config import RasaLLMClientConfig
|
|
44
|
+
|
|
40
45
|
from rasa.shared.providers._configs.client_config import ClientConfig
|
|
41
46
|
|
|
42
47
|
# Provider name -> LLM client class used to serve that provider.
_provider_to_llm_client_mapping: Dict[str, Type[LLMClient]] = dict(
    (
        (OPENAI_PROVIDER, OpenAILLMClient),
        (AZURE_OPENAI_PROVIDER, AzureOpenAILLMClient),
        (SELF_HOSTED_PROVIDER, SelfHostedLLMClient),
        (RASA_PROVIDER, RasaLLMClient),
    )
)
|
|
47
53
|
|
|
48
54
|
_provider_to_embedding_client_mapping: Dict[str, Type[EmbeddingClient]] = {
|
|
@@ -56,6 +62,15 @@ _provider_to_client_config_class_mapping: Dict[str, Type] = {
|
|
|
56
62
|
AZURE_OPENAI_PROVIDER: AzureOpenAIClientConfig,
|
|
57
63
|
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER: HuggingFaceLocalEmbeddingClientConfig,
|
|
58
64
|
SELF_HOSTED_PROVIDER: SelfHostedLLMClientConfig,
|
|
65
|
+
RASA_PROVIDER: RasaLLMClientConfig,
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Provider name -> LiteLLM model prefix, listed only for providers whose
# prefix differs from the provider name. Any provider missing here uses its
# own name as the prefix (see `get_prefix_from_provider`).
_provider_to_prefix_mapping: Dict[str, str] = dict(
    (
        (SELF_HOSTED_PROVIDER, SELF_HOSTED_VLLM_PREFIX),
        (RASA_PROVIDER, OPENAI_PROVIDER),
    )
)
|
|
60
75
|
|
|
61
76
|
|
|
@@ -73,3 +88,7 @@ def get_client_config_class_from_provider(provider: str) -> Type[ClientConfig]:
|
|
|
73
88
|
return _provider_to_client_config_class_mapping.get(
|
|
74
89
|
provider, DefaultLiteLLMClientConfig
|
|
75
90
|
)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_prefix_from_provider(provider: str) -> str:
    """Return the LiteLLM model prefix to use for ``provider``.

    Providers listed in ``_provider_to_prefix_mapping`` get their mapped
    prefix; any other provider name is returned unchanged.
    """
    if provider in _provider_to_prefix_mapping:
        return _provider_to_prefix_mapping[provider]
    return provider
|
|
File without changes
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
from typing import Any, Dict, List
|
|
2
|
+
import os
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from litellm import Router
|
|
6
|
+
|
|
7
|
+
from rasa.shared.constants import (
|
|
8
|
+
MODEL_LIST_KEY,
|
|
9
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
10
|
+
ROUTER_CONFIG_KEY,
|
|
11
|
+
SELF_HOSTED_VLLM_PREFIX,
|
|
12
|
+
SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
|
|
13
|
+
LITELLM_PARAMS_KEY,
|
|
14
|
+
API_KEY,
|
|
15
|
+
MODEL_CONFIG_KEY,
|
|
16
|
+
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.exceptions import ProviderClientValidationError
|
|
19
|
+
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
20
|
+
LiteLLMRouterClientConfig,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.utils.io import resolve_environment_variables
|
|
23
|
+
|
|
24
|
+
structlogger = structlog.get_logger()


class _BaseLiteLLMRouterClient:
    """An abstract base class for LiteLLM Router clients.

    This class defines the interface and common functionality for all the router
    clients based on LiteLLM.

    The class is made private to prevent it from being part of the public-facing
    interface, as it serves as an internal base class for specific implementations
    of router clients that are based on the LiteLLM router implementation.

    Parameters:
        model_group_id (str): The model group ID.
        model_configurations (List[Dict[str, Any]]): The list of model
            configurations. Environment variable references in them are
            resolved before they are handed to the LiteLLM ``Router``.
        router_settings (Dict[str, Any]): The router settings, passed as keyword
            arguments to the LiteLLM ``Router``.
        use_chat_completions_endpoint (bool): Whether to use the chat
            completions endpoint. Defaults to ``True``.
        kwargs (Optional[Dict[str, Any]]): Additional configuration parameters.

    Raises:
        ProviderClientValidationError: If validation of the client setup fails.
    """

    def __init__(
        self,
        model_group_id: str,
        model_configurations: List[Dict[str, Any]],
        router_settings: Dict[str, Any],
        use_chat_completions_endpoint: bool = True,
        **kwargs: Any,
    ):
        self._model_group_id = model_group_id
        self._model_configurations = model_configurations
        self._router_settings = router_settings
        self._use_chat_completions_endpoint = use_chat_completions_endpoint
        self._extra_parameters = kwargs or {}
        self.additional_client_setup()
        try:
            # Only the resolved list is handed to LiteLLM; the stored
            # configurations stay as supplied, assuming
            # resolve_environment_variables does not mutate its input
            # (NOTE(review): confirm).
            resolved_model_configurations = (
                self._resolve_env_vars_in_model_configurations()
            )
            self._router_client = Router(
                model_list=resolved_model_configurations, **router_settings
            )
        except Exception as e:
            event_info = "Cannot instantiate a router client."
            structlogger.error(
                "_base_litellm_router_client.init.error",
                event_info=event_info,
                model_group_id=model_group_id,
                models=model_configurations,
                router=router_settings,
                original_error=e,
            )
            # Chain explicitly so the original traceback is kept as the cause.
            raise ProviderClientValidationError(
                f"{event_info} Original error: {e}"
            ) from e

    def additional_client_setup(self) -> None:
        """Additional setup for the LiteLLM Router client.

        Some model configuration may target a self-hosted vLLM model, have no
        API key, and run with the vLLM API key environment variable unset. If
        so, a dummy API key is written to that (process-wide) environment
        variable. A bug in the LiteLLM library requires an API key to be set
        even when it is not needed.
        """
        for model_configuration in self.model_configurations:
            litellm_params = model_configuration[LITELLM_PARAMS_KEY]
            if (
                f"{SELF_HOSTED_VLLM_PREFIX}/" in litellm_params[MODEL_CONFIG_KEY]
                and API_KEY not in litellm_params
                and not os.getenv(SELF_HOSTED_VLLM_API_KEY_ENV_VAR)
            ):
                os.environ[SELF_HOSTED_VLLM_API_KEY_ENV_VAR] = "dummy api key"
                return

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMRouterClient":
        """Instantiates a LiteLLM Router client from a configuration dict.

        Args:
            config: (Dict[str, Any]) The configuration dictionary.

        Returns:
            An instance of the class this method is called on.

        Raises:
            KeyError, ValueError: If the configuration is invalid. The error is
                logged and then re-raised unchanged.
        """
        try:
            client_config = LiteLLMRouterClientConfig.from_dict(config)
        except (KeyError, ValueError) as e:
            message = "Cannot instantiate a client from the passed configuration."
            structlogger.error(
                "litellm_router_llm_client.from_config.error",
                message=message,
                config=config,
                original_error=e,
            )
            raise

        return cls(
            model_group_id=client_config.model_group_id,
            model_configurations=client_config.litellm_model_list,
            router_settings=client_config.litellm_router_settings,
            use_chat_completions_endpoint=client_config.use_chat_completions_endpoint,
            **client_config.extra_parameters,
        )

    @property
    def model_group_id(self) -> str:
        """Returns the model group ID for the LiteLLM Router client."""
        return self._model_group_id

    @property
    def model_configurations(self) -> List[Dict[str, Any]]:
        """Returns the model configurations for the LiteLLM Router client."""
        return self._model_configurations

    @property
    def router_settings(self) -> Dict[str, Any]:
        """Returns the router settings for the LiteLLM Router client."""
        return self._router_settings

    @property
    def router_client(self) -> Router:
        """Returns the instantiated LiteLLM Router client."""
        return self._router_client

    @property
    def use_chat_completions_endpoint(self) -> bool:
        """Returns whether to use the chat completions endpoint."""
        return self._use_chat_completions_endpoint

    @property
    def _litellm_extra_parameters(self) -> Dict[str, Any]:
        """
        Returns the extra parameters for the LiteLLM Router client.

        Returns:
            Dictionary containing the model parameters.
        """
        return self._extra_parameters

    @property
    def config(self) -> Dict:
        """Returns the configuration for the LiteLLM Router client in LiteLLM format."""
        return {
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODEL_LIST_KEY: self.model_configurations,
            ROUTER_CONFIG_KEY: self.router_settings,
            USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY: (
                self.use_chat_completions_endpoint
            ),
            **self._litellm_extra_parameters,
        }

    def _resolve_env_vars_in_model_configurations(self) -> List:
        """Returns the model configurations with environment variables resolved."""
        return [
            resolve_environment_variables(model_configuration)
            for model_configuration in self.model_configurations
        ]
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Protocol, runtime_checkable
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@runtime_checkable
|
|
5
|
+
class RouterClient(Protocol):
|
|
6
|
+
"""
|
|
7
|
+
Protocol for a Router client that specifies the interface for interacting
|
|
8
|
+
with the API.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
@classmethod
|
|
12
|
+
def from_config(cls, config: dict) -> "RouterClient":
|
|
13
|
+
"""
|
|
14
|
+
Initializes the router client with the given configuration.
|
|
15
|
+
|
|
16
|
+
This class method should be implemented to parse the given
|
|
17
|
+
configuration and create an instance of an router client.
|
|
18
|
+
"""
|
|
19
|
+
...
|
|
20
|
+
|
|
21
|
+
@property
|
|
22
|
+
def config(self) -> Dict:
|
|
23
|
+
"""
|
|
24
|
+
Returns the configuration for that the router client is initialized with.
|
|
25
|
+
|
|
26
|
+
This property should be implemented to return a dictionary containing
|
|
27
|
+
the client configuration settings for the router client.
|
|
28
|
+
"""
|
|
29
|
+
...
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def router_settings(self) -> Dict[str, Any]:
|
|
33
|
+
"""
|
|
34
|
+
Returns the router settings for the Router client.
|
|
35
|
+
|
|
36
|
+
This property should be implemented to return a dictionary containing
|
|
37
|
+
the router settings for the router client.
|
|
38
|
+
"""
|
|
39
|
+
...
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def model_group_id(self) -> str:
|
|
43
|
+
"""
|
|
44
|
+
Returns the model group ID for the Router client.
|
|
45
|
+
|
|
46
|
+
This property should be implemented to return the model group ID
|
|
47
|
+
for the router client.
|
|
48
|
+
"""
|
|
49
|
+
...
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def model_configurations(self) -> List[Dict[str, Any]]:
|
|
53
|
+
"""
|
|
54
|
+
Returns the list of model configurations for the Router client.
|
|
55
|
+
|
|
56
|
+
This property should be implemented to return the list of model configurations
|
|
57
|
+
for the router client as a list of dictionaries.
|
|
58
|
+
|
|
59
|
+
Each dictionary should contain the model configuration.
|
|
60
|
+
Ideally, the `ModelGroupConfig` should parse the model configurations
|
|
61
|
+
and generate this list of dictionaries.
|
|
62
|
+
"""
|
|
63
|
+
...
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def router_client(self) -> object:
|
|
67
|
+
"""
|
|
68
|
+
Returns the instantiated Router client.
|
|
69
|
+
|
|
70
|
+
This property should be implemented to return the instantiated
|
|
71
|
+
Router client.
|
|
72
|
+
"""
|
|
73
|
+
...
|
rasa/shared/utils/common.py
CHANGED
|
@@ -3,14 +3,16 @@ import functools
|
|
|
3
3
|
import importlib
|
|
4
4
|
import inspect
|
|
5
5
|
import logging
|
|
6
|
+
import os
|
|
6
7
|
import pkgutil
|
|
7
8
|
import sys
|
|
8
9
|
from types import ModuleType
|
|
9
|
-
from typing import Text, Dict, Optional, Any, List, Callable, Collection, Type
|
|
10
|
+
from typing import Sequence, Text, Dict, Optional, Any, List, Callable, Collection, Type
|
|
10
11
|
|
|
11
12
|
import rasa.shared.utils.io
|
|
13
|
+
from rasa.exceptions import MissingDependencyException
|
|
12
14
|
from rasa.shared.constants import DOCS_URL_MIGRATION_GUIDE
|
|
13
|
-
from rasa.shared.exceptions import RasaException
|
|
15
|
+
from rasa.shared.exceptions import ProviderClientValidationError, RasaException
|
|
14
16
|
|
|
15
17
|
logger = logging.getLogger(__name__)
|
|
16
18
|
|
|
@@ -86,31 +88,11 @@ def sort_list_of_dicts_by_first_key(dicts: List[Dict]) -> List[Dict]:
|
|
|
86
88
|
return sorted(dicts, key=lambda d: next(iter(d.keys())))
|
|
87
89
|
|
|
88
90
|
|
|
89
|
-
def lazy_property(function: Callable) -> Any:
|
|
90
|
-
"""Allows to avoid recomputing a property over and over.
|
|
91
|
-
|
|
92
|
-
The result gets stored in a local var. Computation of the property
|
|
93
|
-
will happen once, on the first call of the property. All
|
|
94
|
-
succeeding calls will use the value stored in the private property.
|
|
95
|
-
"""
|
|
96
|
-
attr_name = "_lazy_" + function.__name__
|
|
97
|
-
|
|
98
|
-
def _lazyprop(self: Any) -> Any:
|
|
99
|
-
if not hasattr(self, attr_name):
|
|
100
|
-
setattr(self, attr_name, function(self))
|
|
101
|
-
return getattr(self, attr_name)
|
|
102
|
-
|
|
103
|
-
return property(_lazyprop)
|
|
104
|
-
|
|
105
|
-
|
|
106
91
|
def cached_method(f: Callable[..., Any]) -> Callable[..., Any]:
|
|
107
92
|
"""Caches method calls based on the call's `args` and `kwargs`.
|
|
108
|
-
|
|
109
93
|
Works for `async` and `sync` methods. Don't apply this to functions.
|
|
110
|
-
|
|
111
94
|
Args:
|
|
112
95
|
f: The decorated method whose return value should be cached.
|
|
113
|
-
|
|
114
96
|
Returns:
|
|
115
97
|
The return value which the method gives for the first call with the given
|
|
116
98
|
arguments.
|
|
@@ -176,8 +158,9 @@ def transform_collection_to_sentence(collection: Collection[Text]) -> Text:
|
|
|
176
158
|
def minimal_kwargs(
|
|
177
159
|
kwargs: Dict[Text, Any], func: Callable, excluded_keys: Optional[List] = None
|
|
178
160
|
) -> Dict[Text, Any]:
|
|
179
|
-
"""Returns only the kwargs which are required by a function.
|
|
180
|
-
|
|
161
|
+
"""Returns only the kwargs which are required by a function.
|
|
162
|
+
|
|
163
|
+
Keys, contained in the exception list, are not included.
|
|
181
164
|
|
|
182
165
|
Args:
|
|
183
166
|
kwargs: All available kwargs.
|
|
@@ -209,6 +192,14 @@ def mark_as_experimental_feature(feature_name: Text) -> None:
|
|
|
209
192
|
)
|
|
210
193
|
|
|
211
194
|
|
|
195
|
+
def mark_as_beta_feature(feature_name: Text) -> None:
|
|
196
|
+
"""Warns users that they are using a beta feature."""
|
|
197
|
+
logger.warning(
|
|
198
|
+
f"🔬 Beta Feature: {feature_name} is in beta. It may have unexpected "
|
|
199
|
+
"behaviour and might be changed in the future."
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
|
|
212
203
|
def arguments_of(func: Callable) -> List[Text]:
|
|
213
204
|
"""Return the parameters of the function `func` as a list of names."""
|
|
214
205
|
import inspect
|
|
@@ -306,3 +297,28 @@ def warn_and_exit_if_module_path_contains_rasa_plus(
|
|
|
306
297
|
docs=DOCS_URL_MIGRATION_GUIDE,
|
|
307
298
|
)
|
|
308
299
|
sys.exit(1)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def validate_environment(
|
|
303
|
+
required_env_vars: Sequence[str],
|
|
304
|
+
required_packages: Sequence[str],
|
|
305
|
+
component_name: str,
|
|
306
|
+
) -> None:
|
|
307
|
+
"""Make sure all needed requirements for a component are met.
|
|
308
|
+
Args:
|
|
309
|
+
required_env_vars: List of environment variables that should be set
|
|
310
|
+
required_packages: List of packages that should be installed
|
|
311
|
+
component_name: component name that needs the requirements
|
|
312
|
+
"""
|
|
313
|
+
for e in required_env_vars:
|
|
314
|
+
if not os.environ.get(e):
|
|
315
|
+
raise ProviderClientValidationError(
|
|
316
|
+
f"Missing environment variable for {component_name}: {e}"
|
|
317
|
+
)
|
|
318
|
+
for p in required_packages:
|
|
319
|
+
try:
|
|
320
|
+
importlib.import_module(p)
|
|
321
|
+
except ImportError:
|
|
322
|
+
raise MissingDependencyException(
|
|
323
|
+
f"Missing package for {component_name}: {p}"
|
|
324
|
+
)
|
|
File without changes
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EmbeddingsHealthCheckMixin:
|
|
5
|
+
"""Mixin class that provides methods for performing embeddings health checks during
|
|
6
|
+
training and inference within components.
|
|
7
|
+
|
|
8
|
+
This mixin offers static methods that wrap the following health check functions:
|
|
9
|
+
- `perform_embeddings_health_check`
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
@staticmethod
|
|
13
|
+
def perform_embeddings_health_check(
|
|
14
|
+
custom_embeddings_config: Optional[Dict[str, Any]],
|
|
15
|
+
default_embeddings_config: Dict[str, Any],
|
|
16
|
+
log_source_method: str,
|
|
17
|
+
log_source_component: str,
|
|
18
|
+
) -> None:
|
|
19
|
+
"""Wraps the `perform_embeddings_health_check` function to enable
|
|
20
|
+
tracing and instrumentation.
|
|
21
|
+
"""
|
|
22
|
+
from rasa.shared.utils.health_check.health_check import (
|
|
23
|
+
perform_embeddings_health_check,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
perform_embeddings_health_check(
|
|
27
|
+
custom_embeddings_config,
|
|
28
|
+
default_embeddings_config,
|
|
29
|
+
log_source_method,
|
|
30
|
+
log_source_component,
|
|
31
|
+
)
|