rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
|
@@ -1,28 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
1
4
|
from abc import ABC, abstractmethod
|
|
2
5
|
from functools import reduce
|
|
3
|
-
from typing import
|
|
4
|
-
import logging
|
|
6
|
+
from typing import Any, Dict, List, Optional, Set, Text, Tuple, Type, Union, cast
|
|
5
7
|
|
|
6
8
|
import importlib_resources
|
|
7
9
|
|
|
8
10
|
import rasa.shared.constants
|
|
9
|
-
from rasa.shared.core.flows import FlowsList
|
|
10
|
-
import rasa.shared.utils.common
|
|
11
11
|
import rasa.shared.core.constants
|
|
12
|
+
import rasa.shared.utils.common
|
|
12
13
|
import rasa.shared.utils.io
|
|
13
14
|
from rasa.shared.core.domain import (
|
|
14
|
-
|
|
15
|
+
IS_RETRIEVAL_INTENT_KEY,
|
|
16
|
+
KEY_ACTIONS,
|
|
15
17
|
KEY_E2E_ACTIONS,
|
|
16
18
|
KEY_INTENTS,
|
|
17
19
|
KEY_RESPONSES,
|
|
18
|
-
|
|
20
|
+
Domain,
|
|
19
21
|
)
|
|
20
22
|
from rasa.shared.core.events import ActionExecuted, UserUttered
|
|
23
|
+
from rasa.shared.core.flows import FlowsList
|
|
21
24
|
from rasa.shared.core.training_data.structures import StoryGraph
|
|
25
|
+
from rasa.shared.nlu.constants import ACTION_NAME, ENTITIES
|
|
22
26
|
from rasa.shared.nlu.training_data.message import Message
|
|
23
27
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
24
|
-
from rasa.shared.nlu.constants import ENTITIES, ACTION_NAME
|
|
25
|
-
from rasa.shared.core.domain import IS_RETRIEVAL_INTENT_KEY
|
|
26
28
|
from rasa.shared.utils.yaml import read_config_file
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
@@ -114,7 +116,7 @@ class TrainingDataImporter(ABC):
|
|
|
114
116
|
domain_path: Optional[Text] = None,
|
|
115
117
|
training_data_paths: Optional[List[Text]] = None,
|
|
116
118
|
args: Optional[Dict[Text, Any]] = {},
|
|
117
|
-
) ->
|
|
119
|
+
) -> TrainingDataImporter:
|
|
118
120
|
"""Loads a `TrainingDataImporter` instance from a configuration file."""
|
|
119
121
|
config = read_config_file(config_path)
|
|
120
122
|
return TrainingDataImporter.load_from_dict(
|
|
@@ -127,7 +129,7 @@ class TrainingDataImporter(ABC):
|
|
|
127
129
|
domain_path: Optional[Text] = None,
|
|
128
130
|
training_data_paths: Optional[List[Text]] = None,
|
|
129
131
|
args: Optional[Dict[Text, Any]] = {},
|
|
130
|
-
) ->
|
|
132
|
+
) -> TrainingDataImporter:
|
|
131
133
|
"""Loads core `TrainingDataImporter` instance.
|
|
132
134
|
|
|
133
135
|
Instance loaded from configuration file will only read Core training data.
|
|
@@ -143,7 +145,7 @@ class TrainingDataImporter(ABC):
|
|
|
143
145
|
domain_path: Optional[Text] = None,
|
|
144
146
|
training_data_paths: Optional[List[Text]] = None,
|
|
145
147
|
args: Optional[Dict[Text, Any]] = {},
|
|
146
|
-
) ->
|
|
148
|
+
) -> TrainingDataImporter:
|
|
147
149
|
"""Loads nlu `TrainingDataImporter` instance.
|
|
148
150
|
|
|
149
151
|
Instance loaded from configuration file will only read NLU training data.
|
|
@@ -165,8 +167,8 @@ class TrainingDataImporter(ABC):
|
|
|
165
167
|
config_path: Optional[Text] = None,
|
|
166
168
|
domain_path: Optional[Text] = None,
|
|
167
169
|
training_data_paths: Optional[List[Text]] = None,
|
|
168
|
-
args: Optional[Dict[Text, Any]] =
|
|
169
|
-
) ->
|
|
170
|
+
args: Optional[Dict[Text, Any]] = None,
|
|
171
|
+
) -> TrainingDataImporter:
|
|
170
172
|
"""Loads a `TrainingDataImporter` instance from a dictionary."""
|
|
171
173
|
from rasa.shared.importers.rasa import RasaFileImporter
|
|
172
174
|
|
|
@@ -194,8 +196,8 @@ class TrainingDataImporter(ABC):
|
|
|
194
196
|
config_path: Text,
|
|
195
197
|
domain_path: Optional[Text] = None,
|
|
196
198
|
training_data_paths: Optional[List[Text]] = None,
|
|
197
|
-
args: Optional[Dict[Text, Any]] =
|
|
198
|
-
) -> Optional[
|
|
199
|
+
args: Optional[Dict[Text, Any]] = None,
|
|
200
|
+
) -> Optional[TrainingDataImporter]:
|
|
199
201
|
from rasa.shared.importers.multi_project import MultiProjectImporter
|
|
200
202
|
from rasa.shared.importers.rasa import RasaFileImporter
|
|
201
203
|
|
|
@@ -216,7 +218,6 @@ class TrainingDataImporter(ABC):
|
|
|
216
218
|
constructor_arguments = rasa.shared.utils.common.minimal_kwargs(
|
|
217
219
|
{**importer_config, **(args or {})}, importer_class
|
|
218
220
|
)
|
|
219
|
-
|
|
220
221
|
return importer_class(
|
|
221
222
|
config_path,
|
|
222
223
|
domain_path,
|
|
@@ -232,6 +233,26 @@ class TrainingDataImporter(ABC):
|
|
|
232
233
|
"""Returns text representation of object."""
|
|
233
234
|
return self.__class__.__name__
|
|
234
235
|
|
|
236
|
+
def get_user_flows(self) -> FlowsList:
    """Return the user-defined flows to be used for training.

    This base implementation is only a placeholder; FlowSyncImporter and
    E2EImporter are the importers that provide it.

    Returns:
        `FlowsList` containing all loaded flows.

    Raises:
        NotImplementedError: Always, for importers that do not override it.
    """
    raise NotImplementedError()
|
|
245
|
+
|
|
246
|
+
def get_user_domain(self) -> Domain:
    """Return the user-defined domain to be used for training.

    This base implementation is only a placeholder; FlowSyncImporter and
    E2EImporter are the importers that provide it.

    Returns:
        `Domain`.

    Raises:
        NotImplementedError: Always, for importers that do not override it.
    """
    raise NotImplementedError()
|
|
255
|
+
|
|
235
256
|
|
|
236
257
|
class NluDataImporter(TrainingDataImporter):
|
|
237
258
|
"""Importer that skips any Core-related file reading."""
|
|
@@ -448,6 +469,10 @@ class FlowSyncImporter(PassThroughImporter):
|
|
|
448
469
|
|
|
449
470
|
return self.merge_with_default_flows(flows)
|
|
450
471
|
|
|
472
|
+
@rasa.shared.utils.common.cached_method
def get_user_flows(self) -> FlowsList:
    """Return the flows exactly as the wrapped importer provides them.

    Unlike the merged flows produced elsewhere in this class, this method
    does not itself merge in any default flows.
    """
    wrapped_importer = self._importer
    return wrapped_importer.get_flows()
|
|
475
|
+
|
|
451
476
|
@rasa.shared.utils.common.cached_method
|
|
452
477
|
def get_domain(self) -> Domain:
|
|
453
478
|
"""Merge existing domain with properties of flows."""
|
|
@@ -476,6 +501,11 @@ class FlowSyncImporter(PassThroughImporter):
|
|
|
476
501
|
)
|
|
477
502
|
return domain
|
|
478
503
|
|
|
504
|
+
@rasa.shared.utils.common.cached_method
def get_user_domain(self) -> Domain:
    """Return the domain of the wrapped importer, i.e. only the user-defined
    domain, without the flow-derived properties merged in by `get_domain`.
    """
    wrapped_importer = self._importer
    return wrapped_importer.get_domain()
|
|
508
|
+
|
|
479
509
|
|
|
480
510
|
class ResponsesSyncImporter(PassThroughImporter):
|
|
481
511
|
"""Importer that syncs `responses` between Domain and NLU training data.
|
|
@@ -602,6 +632,15 @@ class E2EImporter(PassThroughImporter):
|
|
|
602
632
|
- adds potential end-to-end bot messages from stories as actions to the domain
|
|
603
633
|
"""
|
|
604
634
|
|
|
635
|
+
@rasa.shared.utils.common.cached_method
def get_user_flows(self) -> FlowsList:
    """Return the user-defined flows of the wrapped `FlowSyncImporter`.

    Raises:
        NotImplementedError: If the wrapped importer is not a
            `FlowSyncImporter`.
    """
    wrapped = self._importer
    if isinstance(wrapped, FlowSyncImporter):
        return wrapped.get_user_flows()
    raise NotImplementedError(
        "Accessing user flows is only supported with FlowSyncImporter."
    )
|
|
643
|
+
|
|
605
644
|
@rasa.shared.utils.common.cached_method
|
|
606
645
|
def get_domain(self) -> Domain:
|
|
607
646
|
"""Retrieves model domain (see parent class for full docstring)."""
|
|
@@ -610,6 +649,15 @@ class E2EImporter(PassThroughImporter):
|
|
|
610
649
|
|
|
611
650
|
return original.merge(e2e_domain)
|
|
612
651
|
|
|
652
|
+
@rasa.shared.utils.common.cached_method
def get_user_domain(self) -> Domain:
    """Return only the user-defined domain of the wrapped `FlowSyncImporter`.

    Raises:
        NotImplementedError: If the wrapped importer is not a
            `FlowSyncImporter`.
    """
    wrapped = self._importer
    if isinstance(wrapped, FlowSyncImporter):
        return wrapped.get_user_domain()
    raise NotImplementedError(
        "Accessing user domain is only supported with FlowSyncImporter."
    )
|
|
660
|
+
|
|
613
661
|
def _get_domain_with_e2e_actions(self) -> Domain:
|
|
614
662
|
stories = self.get_stories()
|
|
615
663
|
|
rasa/shared/nlu/constants.py
CHANGED
|
@@ -2,6 +2,8 @@ TEXT = "text"
|
|
|
2
2
|
# String constants from rasa/shared/nlu/constants.py (excerpt only).
# NOTE(review): presumably used as attribute/key names on NLU messages — confirm.
TEXT_TOKENS = "text_tokens"
INTENT = "intent"
COMMANDS = "commands"
LLM_COMMANDS = "llm_commands"  # needed for fine-tuning
LLM_PROMPT = "llm_prompt"  # needed for fine-tuning
FLOWS_FROM_SEMANTIC_SEARCH = "flows_from_semantic_search"
FLOWS_IN_PROMPT = "flows_in_prompt"
NOT_INTENT = "not_intent"
|
|
File without changes
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
+
API_BASE_CONFIG_KEY,
|
|
11
|
+
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
+
API_TYPE_CONFIG_KEY,
|
|
13
|
+
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
+
API_VERSION_CONFIG_KEY,
|
|
15
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
16
|
+
DEPLOYMENT_NAME_CONFIG_KEY,
|
|
17
|
+
ENGINE_CONFIG_KEY,
|
|
18
|
+
RASA_TYPE_CONFIG_KEY,
|
|
19
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
20
|
+
STREAM_CONFIG_KEY,
|
|
21
|
+
N_REPHRASES_CONFIG_KEY,
|
|
22
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
23
|
+
TIMEOUT_CONFIG_KEY,
|
|
24
|
+
PROVIDER_CONFIG_KEY,
|
|
25
|
+
AZURE_OPENAI_PROVIDER,
|
|
26
|
+
AZURE_API_TYPE,
|
|
27
|
+
)
|
|
28
|
+
from rasa.shared.providers._configs.utils import (
|
|
29
|
+
resolve_aliases,
|
|
30
|
+
raise_deprecation_warnings,
|
|
31
|
+
validate_required_keys,
|
|
32
|
+
validate_forbidden_keys,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
structlogger = structlog.get_logger()

# Maps each deprecated configuration key (alias) to the standard key that
# replaces it. `AzureOpenAIClientConfig.from_dict` passes it to
# `raise_deprecation_warnings`, and `resolve_config_aliases` passes it to
# `resolve_aliases`.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Deployment name aliases
    DEPLOYMENT_NAME_CONFIG_KEY: DEPLOYMENT_CONFIG_KEY,
    ENGINE_CONFIG_KEY: DEPLOYMENT_CONFIG_KEY,
    # Provider aliases
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
    # API type aliases
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
    # API base aliases
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
    # API version aliases
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
    # Model name aliases
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present after alias resolution
# (checked via `validate_required_keys`).
REQUIRED_KEYS = [DEPLOYMENT_CONFIG_KEY]

# Keys that must not appear in the configuration
# (checked via `validate_forbidden_keys`).
FORBIDDEN_KEYS = [
    STREAM_CONFIG_KEY,
    N_REPHRASES_CONFIG_KEY,
]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class AzureOpenAIClientConfig:
    """Parses configuration for Azure OpenAI client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `provider` has a value different from
              `AZURE_OPENAI_PROVIDER`.
            - If `deployment` is set to `None`.
    """

    deployment: str

    model: Optional[str]
    api_base: Optional[str]
    api_version: Optional[str]
    # API Type is not used by LiteLLM backend, but we define
    # it here for backward compatibility.
    api_type: Optional[str] = AZURE_API_TYPE

    # Provider is not used by LiteLLM backend, but we define it here since it's
    # used as switch between different clients.
    provider: str = AZURE_OPENAI_PROVIDER

    # All remaining keys (e.g. model parameters, timeout). `to_dict` flattens
    # them back onto the top level.
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates the parsed values.

        Raises:
            ValueError: If `provider` is not `AZURE_OPENAI_PROVIDER` or
                `deployment` is `None`.
        """
        if self.provider != AZURE_OPENAI_PROVIDER:
            message = f"Provider must be set to '{AZURE_OPENAI_PROVIDER}'."
            structlogger.error(
                "azure_openai_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.deployment is None:
            message = "Deployment cannot be set to None."
            structlogger.error(
                "azure_openai_client_config.validation_error",
                message=message,
                deployment=self.deployment,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
        """Initializes a dataclass from the passed config.

        The passed dictionary itself is not modified.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Raised in cases of invalid configuration:
                - If any of the required configuration keys are missing.
                - If `provider` has a value different from
                  `AZURE_OPENAI_PROVIDER`.
                - If `deployment` is set to `None`.

        Returns:
            AzureOpenAIClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases. Work on a shallow copy so that the
        # `pop` calls below can never mutate the caller's dictionary, whether
        # or not `resolve_aliases` returns a new object.
        config = cls.resolve_config_aliases(dict(config))
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Init client config; `cls` keeps this constructor subclass-friendly.
        return cls(
            # Required parameters
            deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
            # Pop the 'provider' key. Currently, it's *optional* because of
            # backward compatibility with older versions.
            provider=config.pop(PROVIDER_CONFIG_KEY, AZURE_OPENAI_PROVIDER),
            # Optional
            api_type=config.pop(API_TYPE_CONFIG_KEY, AZURE_API_TYPE),
            model=config.pop(MODEL_CONFIG_KEY, None),
            # Optional, can also be set through environment variables
            # in clients.
            api_base=config.pop(API_BASE_CONFIG_KEY, None),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters (this also includes timeout).
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The keys of `extra_parameters` are placed on the top level of the
        result instead of under an `extra_parameters` key.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Maps deprecated alias keys in `config` to their standard keys."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def is_azure_openai_config(config: dict) -> bool:
    """Tell whether `config` describes an Azure OpenAI client.

    Azure-specific aliases are resolved first. The config qualifies when it
    explicitly sets `provider` to the Azure OpenAI provider. It also
    qualifies when it sets no provider at all but does set a `deployment`,
    which is specific to Azure OpenAI configuration.
    """
    resolved = AzureOpenAIClientConfig.resolve_config_aliases(config)
    provider = resolved.get(PROVIDER_CONFIG_KEY)

    if provider is None:
        # No explicit provider: fall back to the Azure-only `deployment` key.
        return resolved.get(DEPLOYMENT_CONFIG_KEY) is not None

    return provider == AZURE_OPENAI_PROVIDER
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from typing import Protocol, runtime_checkable
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@runtime_checkable
class ClientConfig(Protocol):
    """Structural interface that every client configuration class provides.

    Any class that exposes `from_dict`, `to_dict` and `resolve_config_aliases`
    satisfies this protocol. Thanks to `runtime_checkable`, it also passes
    `isinstance` checks against it.
    """

    @staticmethod
    def resolve_config_aliases(config: dict) -> dict:
        """Resolve any potential aliases in the configuration.

        Implementations replace alias keys in `config` with their standard
        counterparts.

        Args:
            config: (dict) The configuration whose aliases should be resolved.

        Returns:
            A dictionary containing the resolved configuration settings for
            the client config.
        """
        ...

    @classmethod
    def from_dict(cls, config: dict) -> "ClientConfig":
        """Build a client config from the given configuration.

        Implementations parse `config` and create an instance of a client
        config from it.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            ClientConfig
        """
        ...

    def to_dict(self) -> dict:
        """Return the configuration the client config was initialized with.

        Returns:
            A dictionary containing the configuration settings of the client
            config.
        """
        ...
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from dataclasses import asdict, dataclass, field
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
from rasa.shared.constants import (
|
|
7
|
+
MODEL_CONFIG_KEY,
|
|
8
|
+
MODEL_NAME_CONFIG_KEY,
|
|
9
|
+
STREAM_CONFIG_KEY,
|
|
10
|
+
N_REPHRASES_CONFIG_KEY,
|
|
11
|
+
PROVIDER_CONFIG_KEY,
|
|
12
|
+
TIMEOUT_CONFIG_KEY,
|
|
13
|
+
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
14
|
+
)
|
|
15
|
+
from rasa.shared.providers._configs.utils import (
|
|
16
|
+
validate_required_keys,
|
|
17
|
+
validate_forbidden_keys,
|
|
18
|
+
resolve_aliases,
|
|
19
|
+
raise_deprecation_warnings,
|
|
20
|
+
)
|
|
21
|
+
import rasa.shared.utils.cli
|
|
22
|
+
|
|
23
|
+
structlogger = structlog.get_logger()


# Maps each deprecated configuration key (alias) to the standard key that
# replaces it. `DefaultLiteLLMClientConfig.from_dict` passes it to
# `raise_deprecation_warnings`, and `resolve_config_aliases` passes it to
# `resolve_aliases`.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    # Timeout aliases
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
}

# Keys that must be present after alias resolution
# (checked via `validate_required_keys`).
REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]

# Keys that must not appear in the configuration
# (checked via `validate_forbidden_keys`).
FORBIDDEN_KEYS = [
    STREAM_CONFIG_KEY,
    N_REPHRASES_CONFIG_KEY,
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class DefaultLiteLLMClientConfig:
    """Parses configuration for default LiteLLM client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `model` or `provider` is set to `None`.
    """

    model: str
    provider: str
    # All remaining keys (e.g. model parameters). `to_dict` flattens them
    # back onto the top level.
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validates that `model` and `provider` are set.

        Raises:
            ValueError: If `model` or `provider` is `None`.
        """
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "default_litellm_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)
        if self.provider is None:
            message = "Provider cannot be set to None."
            structlogger.error(
                "default_litellm_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "DefaultLiteLLMClientConfig":
        """
        Initializes a dataclass from the passed config.

        The passed dictionary itself is not modified.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys.

        Returns:
            DefaultLiteLLMClientConfig
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Raise error for using `model_name` instead of `model`
        cls.check_and_error_for_model_name_in_config(config)
        # Resolve any potential aliases. Work on a shallow copy so that the
        # `pop` calls below can never mutate the caller's dictionary, whether
        # or not `resolve_aliases` returns a new object.
        config = cls.resolve_config_aliases(dict(config))
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # `cls` keeps this constructor subclass-friendly.
        return cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The keys of `extra_parameters` are placed on the top level of the
        result instead of under an `extra_parameters` key.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Maps deprecated alias keys in `config` to their standard keys."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)

    @staticmethod
    def check_and_error_for_model_name_in_config(config: Dict[str, Any]) -> None:
        """Reject configs that set the unsupported `model_name` key.

        This applies when `model_name` is set and `model` is not. It logs an
        error and calls `rasa.shared.utils.cli.print_error_and_exit`.
        """
        if config.get(MODEL_NAME_CONFIG_KEY) and not config.get(MODEL_CONFIG_KEY):
            event_info = (
                f"Unsupported parameter - {MODEL_NAME_CONFIG_KEY} is set. Please use "
                f"{MODEL_CONFIG_KEY} instead."
            )
            structlogger.error(
                "default_litellm_client_config.unsupported_parameter_in_config",
                event_info=event_info,
                config=config,
            )
            rasa.shared.utils.cli.print_error_and_exit(event_info)
|