rasa-pro 3.10.15__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. See the registry's advisory for this release for more details.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -15
- rasa/core/channels/inspector/index.html +17 -14
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/engine/graph.py +0 -1
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +353 -43
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -4
- rasa/version.py +1 -1
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +14 -12
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +234 -183
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
rasa/shared/importers/rasa.py
CHANGED
|
@@ -6,7 +6,6 @@ import rasa.shared.core.flows.yaml_flows_io
|
|
|
6
6
|
from rasa.shared.core.flows import FlowsList
|
|
7
7
|
|
|
8
8
|
import rasa.shared.data
|
|
9
|
-
import rasa.shared.utils.common
|
|
10
9
|
import rasa.shared.utils.io
|
|
11
10
|
from rasa.shared.core.training_data.structures import StoryGraph
|
|
12
11
|
from rasa.shared.importers import utils
|
|
@@ -16,6 +15,7 @@ from rasa.shared.core.domain import InvalidDomain, Domain
|
|
|
16
15
|
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
|
|
17
16
|
YAMLStoryReader,
|
|
18
17
|
)
|
|
18
|
+
from rasa.shared.utils.common import cached_method
|
|
19
19
|
from rasa.shared.utils.yaml import read_model_configuration
|
|
20
20
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
@@ -29,7 +29,9 @@ class RasaFileImporter(TrainingDataImporter):
|
|
|
29
29
|
config_file: Optional[Text] = None,
|
|
30
30
|
domain_path: Optional[Text] = None,
|
|
31
31
|
training_data_paths: Optional[Union[List[Text], Text]] = None,
|
|
32
|
+
expand_env_vars: bool = True,
|
|
32
33
|
):
|
|
34
|
+
self.expand_env_vars = expand_env_vars
|
|
33
35
|
self._domain_path = domain_path
|
|
34
36
|
|
|
35
37
|
self._nlu_files = rasa.shared.data.get_data_files(
|
|
@@ -47,40 +49,47 @@ class RasaFileImporter(TrainingDataImporter):
|
|
|
47
49
|
|
|
48
50
|
self.config_file = config_file
|
|
49
51
|
|
|
52
|
+
@cached_method
|
|
50
53
|
def get_config(self) -> Dict:
|
|
51
54
|
"""Retrieves model config (see parent class for full docstring)."""
|
|
52
55
|
if not self.config_file or not os.path.exists(self.config_file):
|
|
53
56
|
logger.debug("No configuration file was provided to the RasaFileImporter.")
|
|
54
57
|
return {}
|
|
55
58
|
|
|
56
|
-
config = read_model_configuration(
|
|
59
|
+
config = read_model_configuration(
|
|
60
|
+
self.config_file, expand_env_vars=self.expand_env_vars
|
|
61
|
+
)
|
|
57
62
|
return config
|
|
58
63
|
|
|
59
|
-
@rasa.shared.utils.common.cached_method
|
|
60
64
|
def get_config_file_for_auto_config(self) -> Optional[Text]:
|
|
61
65
|
"""Returns config file path for auto-config only if there is a single one."""
|
|
62
66
|
return self.config_file
|
|
63
67
|
|
|
68
|
+
@cached_method
|
|
64
69
|
def get_stories(self, exclusion_percentage: Optional[int] = None) -> StoryGraph:
|
|
65
70
|
"""Retrieves training stories / rules (see parent class for full docstring)."""
|
|
66
71
|
return utils.story_graph_from_paths(
|
|
67
72
|
self._story_files, self.get_domain(), exclusion_percentage
|
|
68
73
|
)
|
|
69
74
|
|
|
75
|
+
@cached_method
|
|
70
76
|
def get_flows(self) -> FlowsList:
|
|
71
77
|
"""Retrieves training stories / rules (see parent class for full docstring)."""
|
|
72
78
|
return utils.flows_from_paths(self._flow_files)
|
|
73
79
|
|
|
80
|
+
@cached_method
|
|
74
81
|
def get_conversation_tests(self) -> StoryGraph:
|
|
75
82
|
"""Retrieves conversation test stories (see parent class for full docstring)."""
|
|
76
83
|
return utils.story_graph_from_paths(
|
|
77
84
|
self._conversation_test_files, self.get_domain()
|
|
78
85
|
)
|
|
79
86
|
|
|
87
|
+
@cached_method
|
|
80
88
|
def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
|
|
81
89
|
"""Retrieves NLU training data (see parent class for full docstring)."""
|
|
82
90
|
return utils.training_data_from_paths(self._nlu_files, language)
|
|
83
91
|
|
|
92
|
+
@cached_method
|
|
84
93
|
def get_domain(self) -> Domain:
|
|
85
94
|
"""Retrieves model domain (see parent class for full docstring)."""
|
|
86
95
|
domain = Domain.empty()
|
|
rasa/shared/importers/remote_importer.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict, List, Optional, Text, Union
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from tarsafe import TarSafe
|
|
6
|
+
|
|
7
|
+
import rasa.shared.core.flows.yaml_flows_io
|
|
8
|
+
import rasa.shared.data
|
|
9
|
+
import rasa.shared.utils.common
|
|
10
|
+
import rasa.shared.utils.io
|
|
11
|
+
from rasa.core.persistor import StorageType
|
|
12
|
+
from rasa.shared.core.domain import Domain, InvalidDomain
|
|
13
|
+
from rasa.shared.core.flows import FlowsList
|
|
14
|
+
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
|
|
15
|
+
YAMLStoryReader,
|
|
16
|
+
)
|
|
17
|
+
from rasa.shared.core.training_data.structures import StoryGraph
|
|
18
|
+
from rasa.shared.exceptions import RasaException
|
|
19
|
+
from rasa.shared.importers import utils
|
|
20
|
+
from rasa.shared.importers.importer import TrainingDataImporter
|
|
21
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
22
|
+
from rasa.shared.utils.yaml import read_model_configuration
|
|
23
|
+
|
|
24
|
+
structlogger = structlog.get_logger()
|
|
25
|
+
|
|
26
|
+
TRAINING_DATA_ARCHIVE = "training_data.tar.gz"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RemoteTrainingDataImporter(TrainingDataImporter):
|
|
30
|
+
"""Remote `TrainingFileImporter` implementation.
|
|
31
|
+
|
|
32
|
+
Fetches training data from a remote storage and extracts it to a local directory.
|
|
33
|
+
Extracted training data is then used to load flows, NLU, stories,
|
|
34
|
+
domain, and config files.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
config_file: Optional[Text] = None,
|
|
40
|
+
domain_path: Optional[Text] = None,
|
|
41
|
+
training_data_paths: Optional[Union[List[Text], Text]] = None,
|
|
42
|
+
project_directory: Optional[Text] = None,
|
|
43
|
+
remote_storage: Optional[StorageType] = None,
|
|
44
|
+
training_data_path: Optional[Text] = None,
|
|
45
|
+
):
|
|
46
|
+
"""Initializes `RemoteTrainingDataImporter`.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
config_file: Path to the model configuration file.
|
|
50
|
+
domain_path: Path to the domain file.
|
|
51
|
+
training_data_paths: List of paths to the training data files.
|
|
52
|
+
project_directory: Path to the project directory.
|
|
53
|
+
remote_storage: Storage to use to load the training data.
|
|
54
|
+
training_data_path: Path to the training data.
|
|
55
|
+
"""
|
|
56
|
+
self.remote_storage = remote_storage
|
|
57
|
+
self.training_data_path = training_data_path
|
|
58
|
+
|
|
59
|
+
self.extracted_path = self._fetch_and_extract_training_archive(
|
|
60
|
+
TRAINING_DATA_ARCHIVE, self.training_data_path
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
self._nlu_files = rasa.shared.data.get_data_files(
|
|
64
|
+
self.extracted_path, rasa.shared.data.is_nlu_file
|
|
65
|
+
)
|
|
66
|
+
self._story_files = rasa.shared.data.get_data_files(
|
|
67
|
+
self.extracted_path, YAMLStoryReader.is_stories_file
|
|
68
|
+
)
|
|
69
|
+
self._flow_files = rasa.shared.data.get_data_files(
|
|
70
|
+
self.extracted_path, rasa.shared.core.flows.yaml_flows_io.is_flows_file
|
|
71
|
+
)
|
|
72
|
+
self._conversation_test_files = rasa.shared.data.get_data_files(
|
|
73
|
+
self.extracted_path, YAMLStoryReader.is_test_stories_file
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
self.config_file = config_file
|
|
77
|
+
|
|
78
|
+
def _fetch_training_archive(
|
|
79
|
+
self, training_file: str, training_data_path: Optional[str] = None
|
|
80
|
+
) -> str:
|
|
81
|
+
"""Fetches training files from remote storage."""
|
|
82
|
+
from rasa.core.persistor import get_persistor
|
|
83
|
+
|
|
84
|
+
persistor = get_persistor(self.remote_storage)
|
|
85
|
+
if persistor is None:
|
|
86
|
+
raise RasaException(
|
|
87
|
+
f"Could not find a persistor for "
|
|
88
|
+
f"the storage type '{self.remote_storage}'."
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
return persistor.retrieve(training_file, training_data_path)
|
|
92
|
+
|
|
93
|
+
def _fetch_and_extract_training_archive(
|
|
94
|
+
self, training_file: str, training_data_path: Optional[Text] = None
|
|
95
|
+
) -> Optional[str]:
|
|
96
|
+
"""Fetches and extracts training files from remote storage.
|
|
97
|
+
|
|
98
|
+
If the `training_data_path` is not provided, the training
|
|
99
|
+
data is extracted to the current working directory.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
training_file: Name of the training data archive file.
|
|
103
|
+
training_data_path: Path to the training data.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
Path to the extracted training data.
|
|
107
|
+
"""
|
|
108
|
+
|
|
109
|
+
if training_data_path is None:
|
|
110
|
+
training_data_path = os.path.join(os.getcwd(), "data")
|
|
111
|
+
|
|
112
|
+
if os.path.isfile(training_data_path):
|
|
113
|
+
raise ValueError(
|
|
114
|
+
f"Training data path '{training_data_path}' is a file. "
|
|
115
|
+
f"Please provide a directory path."
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
structlogger.debug(
|
|
119
|
+
"rasa.importers.remote_training_data_importer.fetch_training_archive",
|
|
120
|
+
training_data_path=training_data_path,
|
|
121
|
+
)
|
|
122
|
+
training_archive_file_path = self._fetch_training_archive(
|
|
123
|
+
training_file, training_data_path
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
if not os.path.isfile(training_archive_file_path):
|
|
127
|
+
raise FileNotFoundError(
|
|
128
|
+
f"Training data archive '{training_archive_file_path}' not found. "
|
|
129
|
+
f"Please make sure to provide the correct path."
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
structlogger.debug(
|
|
133
|
+
"rasa.importers.remote_training_data_importer.extract_training_archive",
|
|
134
|
+
training_archive_file_path=training_archive_file_path,
|
|
135
|
+
training_data_path=training_data_path,
|
|
136
|
+
)
|
|
137
|
+
with TarSafe.open(training_archive_file_path, "r:gz") as tar:
|
|
138
|
+
tar.extractall(path=training_data_path)
|
|
139
|
+
|
|
140
|
+
structlogger.debug(
|
|
141
|
+
"rasa.importers.remote_training_data_importer.remove_downloaded_archive",
|
|
142
|
+
training_data_path=training_data_path,
|
|
143
|
+
)
|
|
144
|
+
os.remove(training_archive_file_path)
|
|
145
|
+
return training_data_path
|
|
146
|
+
|
|
147
|
+
def get_config(self) -> Dict:
|
|
148
|
+
"""Retrieves model config (see parent class for full docstring)."""
|
|
149
|
+
if not self.config_file or not os.path.exists(self.config_file):
|
|
150
|
+
structlogger.debug(
|
|
151
|
+
"rasa.importers.remote_training_data_importer.no_config_file",
|
|
152
|
+
message="No configuration file was provided to the RasaFileImporter.",
|
|
153
|
+
)
|
|
154
|
+
return {}
|
|
155
|
+
|
|
156
|
+
config = read_model_configuration(self.config_file)
|
|
157
|
+
return config
|
|
158
|
+
|
|
159
|
+
@rasa.shared.utils.common.cached_method
|
|
160
|
+
def get_config_file_for_auto_config(self) -> Optional[Text]:
|
|
161
|
+
"""Returns config file path for auto-config only if there is a single one."""
|
|
162
|
+
return self.config_file
|
|
163
|
+
|
|
164
|
+
def get_stories(self, exclusion_percentage: Optional[int] = None) -> StoryGraph:
|
|
165
|
+
"""Retrieves training stories / rules (see parent class for full docstring)."""
|
|
166
|
+
return utils.story_graph_from_paths(
|
|
167
|
+
self._story_files, self.get_domain(), exclusion_percentage
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
def get_flows(self) -> FlowsList:
|
|
171
|
+
"""Retrieves training stories / rules (see parent class for full docstring)."""
|
|
172
|
+
return utils.flows_from_paths(self._flow_files)
|
|
173
|
+
|
|
174
|
+
def get_conversation_tests(self) -> StoryGraph:
|
|
175
|
+
"""Retrieves conversation test stories (see parent class for full docstring)."""
|
|
176
|
+
return utils.story_graph_from_paths(
|
|
177
|
+
self._conversation_test_files, self.get_domain()
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
|
|
181
|
+
"""Retrieves NLU training data (see parent class for full docstring)."""
|
|
182
|
+
return utils.training_data_from_paths(self._nlu_files, language)
|
|
183
|
+
|
|
184
|
+
def get_domain(self) -> Domain:
|
|
185
|
+
"""Retrieves model domain (see parent class for full docstring)."""
|
|
186
|
+
domain = Domain.empty()
|
|
187
|
+
domain_path = f"{self.extracted_path}"
|
|
188
|
+
try:
|
|
189
|
+
domain = Domain.load(domain_path)
|
|
190
|
+
except InvalidDomain as e:
|
|
191
|
+
rasa.shared.utils.io.raise_warning(
|
|
192
|
+
f"Loading domain from '{domain_path}' failed. Using "
|
|
193
|
+
f"empty domain. Error: '{e}'"
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
return domain
|
rasa/shared/importers/utils.py
CHANGED
|
@@ -29,6 +29,8 @@ def flows_from_paths(files: List[Text]) -> FlowsList:
|
|
|
29
29
|
|
|
30
30
|
flows = FlowsList(underlying_flows=[])
|
|
31
31
|
for file in files:
|
|
32
|
-
flows = flows.merge(
|
|
32
|
+
flows = flows.merge(
|
|
33
|
+
YAMLFlowsReader.read_from_file(file), ignore_duplicates=False
|
|
34
|
+
)
|
|
33
35
|
flows.validate()
|
|
34
36
|
return flows
|
|
rasa/shared/nlu/training_data/formats/rasa_yaml.py
CHANGED
|
@@ -1,7 +1,18 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from collections import OrderedDict
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import (
|
|
5
|
+
ClassVar,
|
|
6
|
+
Text,
|
|
7
|
+
Any,
|
|
8
|
+
List,
|
|
9
|
+
Dict,
|
|
10
|
+
Tuple,
|
|
11
|
+
Union,
|
|
12
|
+
Iterator,
|
|
13
|
+
Optional,
|
|
14
|
+
Callable,
|
|
15
|
+
)
|
|
5
16
|
|
|
6
17
|
import rasa.shared.data
|
|
7
18
|
from rasa.shared.core.domain import Domain
|
|
@@ -55,6 +66,8 @@ STRIP_SYMBOLS = "\n\r "
|
|
|
55
66
|
class RasaYAMLReader(TrainingDataReader):
|
|
56
67
|
"""Reads YAML training data and creates a TrainingData object."""
|
|
57
68
|
|
|
69
|
+
expand_env_vars: ClassVar[bool] = True
|
|
70
|
+
|
|
58
71
|
def __init__(self) -> None:
|
|
59
72
|
super().__init__()
|
|
60
73
|
self.training_examples: List[Message] = []
|
|
@@ -69,7 +82,9 @@ class RasaYAMLReader(TrainingDataReader):
|
|
|
69
82
|
If the string is not in the right format, an exception will be raised.
|
|
70
83
|
"""
|
|
71
84
|
try:
|
|
72
|
-
validate_raw_yaml_using_schema_file_with_responses(
|
|
85
|
+
validate_raw_yaml_using_schema_file_with_responses(
|
|
86
|
+
string, NLU_SCHEMA_FILE, expand_env_vars=self.expand_env_vars
|
|
87
|
+
)
|
|
73
88
|
except YamlException as e:
|
|
74
89
|
e.filename = self.filename
|
|
75
90
|
raise e
|
|
@@ -88,7 +103,7 @@ class RasaYAMLReader(TrainingDataReader):
|
|
|
88
103
|
"""
|
|
89
104
|
self.validate(string)
|
|
90
105
|
|
|
91
|
-
yaml_content = read_yaml(string)
|
|
106
|
+
yaml_content = read_yaml(string, expand_env_vars=self.expand_env_vars)
|
|
92
107
|
|
|
93
108
|
if not validate_training_data_format_version(yaml_content, self.filename):
|
|
94
109
|
return TrainingData()
|
|
rasa/shared/nlu/training_data/training_data.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
from functools import cached_property
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
import random
|
|
5
6
|
from collections import Counter, OrderedDict
|
|
@@ -9,7 +10,6 @@ from typing import Any, Dict, List, Optional, Set, Text, Tuple, Callable
|
|
|
9
10
|
import operator
|
|
10
11
|
|
|
11
12
|
import rasa.shared.data
|
|
12
|
-
from rasa.shared.utils.common import lazy_property
|
|
13
13
|
import rasa.shared.utils.io
|
|
14
14
|
from rasa.shared.nlu.constants import (
|
|
15
15
|
RESPONSE,
|
|
@@ -202,7 +202,7 @@ class TrainingData:
|
|
|
202
202
|
|
|
203
203
|
return list(OrderedDict.fromkeys(examples))
|
|
204
204
|
|
|
205
|
-
@
|
|
205
|
+
@cached_property
|
|
206
206
|
def nlu_examples(self) -> List[Message]:
|
|
207
207
|
"""Return examples which have come from NLU training data.
|
|
208
208
|
|
|
@@ -215,32 +215,32 @@ class TrainingData:
|
|
|
215
215
|
ex for ex in self.training_examples if not ex.is_core_or_domain_message()
|
|
216
216
|
]
|
|
217
217
|
|
|
218
|
-
@
|
|
218
|
+
@cached_property
|
|
219
219
|
def intent_examples(self) -> List[Message]:
|
|
220
220
|
"""Returns the list of examples that have intent."""
|
|
221
221
|
return [ex for ex in self.nlu_examples if ex.get(INTENT)]
|
|
222
222
|
|
|
223
|
-
@
|
|
223
|
+
@cached_property
|
|
224
224
|
def response_examples(self) -> List[Message]:
|
|
225
225
|
"""Returns the list of examples that have response."""
|
|
226
226
|
return [ex for ex in self.nlu_examples if ex.get(INTENT_RESPONSE_KEY)]
|
|
227
227
|
|
|
228
|
-
@
|
|
228
|
+
@cached_property
|
|
229
229
|
def entity_examples(self) -> List[Message]:
|
|
230
230
|
"""Returns the list of examples that have entities."""
|
|
231
231
|
return [ex for ex in self.nlu_examples if ex.get(ENTITIES)]
|
|
232
232
|
|
|
233
|
-
@
|
|
233
|
+
@cached_property
|
|
234
234
|
def intents(self) -> Set[Text]:
|
|
235
235
|
"""Returns the set of intents in the training data."""
|
|
236
236
|
return {ex.get(INTENT) for ex in self.training_examples} - {None}
|
|
237
237
|
|
|
238
|
-
@
|
|
238
|
+
@cached_property
|
|
239
239
|
def action_names(self) -> Set[Text]:
|
|
240
240
|
"""Returns the set of action names in the training data."""
|
|
241
241
|
return {ex.get(ACTION_NAME) for ex in self.training_examples} - {None}
|
|
242
242
|
|
|
243
|
-
@
|
|
243
|
+
@cached_property
|
|
244
244
|
def retrieval_intents(self) -> Set[Text]:
|
|
245
245
|
"""Returns the total number of response types in the training data."""
|
|
246
246
|
return {
|
|
@@ -249,13 +249,13 @@ class TrainingData:
|
|
|
249
249
|
if ex.get(INTENT_RESPONSE_KEY)
|
|
250
250
|
}
|
|
251
251
|
|
|
252
|
@cached_property
def number_of_examples_per_intent(self) -> Dict[Text, int]:
    """Count the NLU examples per intent.

    Keys appear in order of first occurrence; examples without an intent are
    tallied under ``None``.
    """
    counts: Dict[Text, int] = {}
    for example in self.nlu_examples:
        intent = example.get(INTENT)
        counts[intent] = counts.get(intent, 0) + 1
    return counts
|
|
257
257
|
|
|
258
|
-
@
|
|
258
|
+
@cached_property
|
|
259
259
|
def number_of_examples_per_response(self) -> Dict[Text, int]:
|
|
260
260
|
"""Calculates the number of examples per response."""
|
|
261
261
|
responses = [
|
|
@@ -265,12 +265,12 @@ class TrainingData:
|
|
|
265
265
|
]
|
|
266
266
|
return dict(Counter(responses))
|
|
267
267
|
|
|
268
|
@cached_property
def entities(self) -> Set[Text]:
    """Return the distinct entity types found in the training data."""
    entity_types = set()
    for entity in self.sorted_entities():
        entity_types.add(entity.get(ENTITY_ATTRIBUTE_TYPE))
    return entity_types
|
|
272
272
|
|
|
273
|
-
@
|
|
273
|
+
@cached_property
|
|
274
274
|
def entity_roles(self) -> Set[Text]:
|
|
275
275
|
"""Returns the set of entity roles in the training data."""
|
|
276
276
|
entity_types = {
|
|
@@ -280,7 +280,7 @@ class TrainingData:
|
|
|
280
280
|
}
|
|
281
281
|
return entity_types - {NO_ENTITY_TAG}
|
|
282
282
|
|
|
283
|
-
@
|
|
283
|
+
@cached_property
|
|
284
284
|
def entity_groups(self) -> Set[Text]:
|
|
285
285
|
"""Returns the set of entity groups in the training data."""
|
|
286
286
|
entity_types = {
|
|
@@ -299,7 +299,7 @@ class TrainingData:
|
|
|
299
299
|
|
|
300
300
|
return entity_groups_used or entity_roles_used
|
|
301
301
|
|
|
302
|
-
@
|
|
302
|
+
@cached_property
|
|
303
303
|
def number_of_examples_per_entity(self) -> Dict[Text, int]:
|
|
304
304
|
"""Calculates the number of examples per entity."""
|
|
305
305
|
entities = []
|
|
@@ -426,8 +426,9 @@ class TrainingData:
|
|
|
426
426
|
def persist(
|
|
427
427
|
self, dir_name: Text, filename: Text = DEFAULT_TRAINING_DATA_OUTPUT_PATH
|
|
428
428
|
) -> Dict[Text, Any]:
|
|
429
|
-
"""Persists this training data to disk
|
|
430
|
-
|
|
429
|
+
"""Persists this training data to disk.
|
|
430
|
+
|
|
431
|
+
Returns: necessary information to load it again.
|
|
431
432
|
"""
|
|
432
433
|
if not os.path.exists(dir_name):
|
|
433
434
|
os.makedirs(dir_name)
|
|
@@ -498,9 +499,7 @@ class TrainingData:
|
|
|
498
499
|
def train_test_split(
|
|
499
500
|
self, train_frac: float = 0.8, random_seed: Optional[int] = None
|
|
500
501
|
) -> Tuple["TrainingData", "TrainingData"]:
|
|
501
|
-
"""Split into a training and test dataset,
|
|
502
|
-
preserving the fraction of examples per intent.
|
|
503
|
-
"""
|
|
502
|
+
"""Split into a training and test dataset, preserving the fraction of examples per intent.""" # noqa: E501
|
|
504
503
|
# collect all nlu data
|
|
505
504
|
test, train = self.split_nlu_examples(train_frac, random_seed)
|
|
506
505
|
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Dict, List
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
ROUTER_CONFIG_KEY,
|
|
8
|
+
MODELS_CONFIG_KEY,
|
|
9
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
10
|
+
MODEL_NAME_CONFIG_KEY,
|
|
11
|
+
LITELLM_PARAMS_KEY,
|
|
12
|
+
PROVIDER_CONFIG_KEY,
|
|
13
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
14
|
+
API_TYPE_CONFIG_KEY,
|
|
15
|
+
MODEL_CONFIG_KEY,
|
|
16
|
+
MODEL_LIST_KEY,
|
|
17
|
+
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
|
|
18
|
+
)
|
|
19
|
+
from rasa.shared.providers._configs.model_group_config import (
|
|
20
|
+
ModelGroupConfig,
|
|
21
|
+
ModelConfig,
|
|
22
|
+
)
|
|
23
|
+
from rasa.shared.providers.mappings import get_prefix_from_provider
|
|
24
|
+
from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
structlogger = structlog.get_logger()

# Rasa-specific model config keys that must not be forwarded to LiteLLM as
# `litellm_params`. Only used for membership tests, so an immutable set is used.
_LITELLM_UNSUPPORTED_KEYS = frozenset(
    {
        PROVIDER_CONFIG_KEY,
        DEPLOYMENT_CONFIG_KEY,
        API_TYPE_CONFIG_KEY,
        USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
    }
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class LiteLLMRouterClientConfig:
    """Parses configuration for a LiteLLM Router client.

    The configuration is expected to be in the following format:

    {
        "id": "model_group_id",
        "models": [
            {
                "provider": "provider_name",
                "model": "model_name",
                "api_base": "api_base",
                "api_key": "api_key",
                "api_version": "api_version",
            },
            {
                "provider": "provider_name",
                "model": "model_name",
            },
        ],
        "router": {}
    }

    This configuration is converted into the LiteLLM required format:

    {
        "id": "model_group_id",
        "model_list": [
            {
                "model_name": "model_group_id",
                "litellm_params": {
                    "model": "provider_name/model_name",
                    "api_base": "api_base",
                    "api_key": "api_key",
                    "api_version": "api_version",
                },
            },
            {
                "model_name": "model_group_id",
                "litellm_params": {
                    "model": "provider_name/model_name",
                },
            },
        ],
        "router": {},
    }

    Top-level keys other than the model group id, the models and the router
    are kept in `extra_parameters` and merged into the output of `to_dict`
    and `to_litellm_dict`. The `use_chat_completions_endpoint` flag is read
    from the router settings and stripped from the settings handed to LiteLLM.

    Raises:
        ValueError: If the router settings are missing or empty.
            `ModelGroupConfig.from_dict` may additionally raise for an invalid
            model group.
    """

    # Parsed model group (id + models) part of the configuration.
    _model_group_config: ModelGroupConfig
    # Router settings as provided by the user; may still contain the
    # `use_chat_completions_endpoint` flag.
    router: Dict[str, Any]
    _use_chat_completions_endpoint: bool = True
    # Remaining top-level keys that belong neither to the model group nor to
    # the router.
    extra_parameters: dict = field(default_factory=dict)

    @property
    def model_group_id(self) -> str:
        """Id of the model group; also used as LiteLLM `model_name`."""
        return self._model_group_config.model_group_id

    @property
    def models(self) -> List[ModelConfig]:
        """Model configs belonging to this model group."""
        return self._model_group_config.models

    @property
    def litellm_model_list(self) -> List[Dict[str, Any]]:
        """Models in LiteLLM `model_list` format (rebuilt on every access)."""
        return self._convert_models_to_litellm_model_list()

    @property
    def litellm_router_settings(self) -> Dict[str, Any]:
        """Router settings without Rasa-only keys (a fresh copy)."""
        return self._convert_router_to_litellm_router_settings()

    @property
    def use_chat_completions_endpoint(self) -> bool:
        """Whether the chat completions endpoint should be used."""
        return self._use_chat_completions_endpoint

    def __post_init__(self) -> None:
        if not self.router:
            message = "Router cannot be empty."
            structlogger.error(
                "litellm_router_client_config.validation_error",
                message=message,
                model_group_id=self._model_group_config.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize. It is not
                mutated.

        Raises:
            ValueError: Config is missing required keys (e.g. the router
                settings are absent, `null` or empty).

        Returns:
            LiteLLMRouterClientConfig
        """
        model_group_config = ModelGroupConfig.from_dict(config)

        # Copy config to avoid mutating the original
        config_copy = copy.deepcopy(config)
        # Pop the keys used by ModelGroupConfig
        config_copy.pop(MODEL_GROUP_ID_CONFIG_KEY, None)
        config_copy.pop(MODELS_CONFIG_KEY, None)
        # Get the router settings. An explicit `router: null` (e.g. an empty
        # YAML key) is normalised to an empty dict so that __post_init__
        # reports it as a ValueError instead of `.get` below failing with an
        # AttributeError on None.
        router_settings = config_copy.pop(ROUTER_CONFIG_KEY, None) or {}
        # Get the use_chat_completions_endpoint setting
        use_chat_completions_endpoint = router_settings.get(
            USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY, True
        )
        # The rest is considered as extra parameters
        extra_parameters = config_copy

        # `cls` (not the hard-coded class name) so subclasses construct
        # instances of themselves.
        return cls(
            _model_group_config=model_group_config,
            router=router_settings,
            _use_chat_completions_endpoint=use_chat_completions_endpoint,
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The router settings and extra parameters are deep-copied so that
        mutating the returned dictionary does not alter this config instance.
        """
        d = self._model_group_config.to_dict()
        d[ROUTER_CONFIG_KEY] = copy.deepcopy(self.router)
        if self.extra_parameters:
            d.update(copy.deepcopy(self.extra_parameters))
        return d

    def to_litellm_dict(self) -> dict:
        """Converts the config into the format expected by the LiteLLM Router.

        Extra parameters are spread first, so the model group id, the model
        list and the router settings take precedence over same-named extra keys.
        """
        return {
            **self.extra_parameters,
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODEL_LIST_KEY: self._convert_models_to_litellm_model_list(),
            ROUTER_CONFIG_KEY: self._convert_router_to_litellm_router_settings(),
        }

    def _convert_router_to_litellm_router_settings(self) -> Dict[str, Any]:
        """Returns a copy of the router settings without Rasa-only keys."""
        _router_settings_copy = copy.deepcopy(self.router)
        _router_settings_copy.pop(USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY, None)
        return _router_settings_copy

    def _convert_models_to_litellm_model_list(self) -> List[Dict[str, Any]]:
        """Converts the model configs into LiteLLM `model_list` entries.

        Every entry uses the model group id as its `model_name`, and its
        `litellm_params` contain the provider-prefixed model name plus every
        remaining non-None, LiteLLM-supported parameter of the model config.
        """
        litellm_model_list = []

        for model_config_object in self.models:
            # Convert the model config to a dict representation
            litellm_model_config = model_config_object.to_dict()

            provider = litellm_model_config[PROVIDER_CONFIG_KEY]

            # Get the litellm prefixing for the provider
            prefix = get_prefix_from_provider(provider)

            # Determine whether to use model or deployment key based on the provider.
            litellm_model_name = (
                litellm_model_config[DEPLOYMENT_CONFIG_KEY]
                if provider in DEPLOYMENT_CENTRIC_PROVIDERS
                else litellm_model_config[MODEL_CONFIG_KEY]
            )

            # Set 'model' to a provider prefixed model name e.g. openai/gpt-4.
            # NOTE(review): this is a substring check, not a prefix check; a
            # name that merely contains "<prefix>/" somewhere is left as-is —
            # confirm whether `startswith` was intended.
            litellm_model_config[MODEL_CONFIG_KEY] = (
                litellm_model_name
                if f"{prefix}/" in litellm_model_name
                else f"{prefix}/{litellm_model_name}"
            )

            # Drop keys LiteLLM does not understand, as well as every
            # parameter whose value is None.
            litellm_model_config = {
                key: value
                for key, value in litellm_model_config.items()
                if key not in _LITELLM_UNSUPPORTED_KEYS and value is not None
            }

            litellm_model_list_item = {
                MODEL_NAME_CONFIG_KEY: self.model_group_id,
                LITELLM_PARAMS_KEY: litellm_model_config,
            }

            litellm_model_list.append(litellm_model_list_item)

        return litellm_model_list
|