rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/shared/utils/llm.py CHANGED
```diff
@@ -1,6 +1,18 @@
-import
-import
-from typing import
+import sys
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Text,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+    cast,
+)
+import json

 import structlog

@@ -8,39 +20,42 @@ import rasa.shared.utils.io
 from rasa.shared.constants import (
     RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
     RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
-
-    OPENAI_API_VERSION_ENV_VAR,
-    OPENAI_API_BASE_ENV_VAR,
-    REQUESTS_CA_BUNDLE_ENV_VAR,
-    OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_VERSION_CONFIG_KEY,
-    OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_TYPE_CONFIG_KEY,
-    OPENAI_API_BASE_CONFIG_KEY,
-    OPENAI_DEPLOYMENT_NAME_CONFIG_KEY,
-    OPENAI_DEPLOYMENT_CONFIG_KEY,
-    OPENAI_ENGINE_CONFIG_KEY,
-    LANGCHAIN_TYPE_CONFIG_KEY,
-    RASA_TYPE_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
 )
 from rasa.shared.core.events import BotUttered, UserUttered
 from rasa.shared.core.slots import Slot, BooleanSlot, CategoricalSlot
-from rasa.shared.engine.caching import
+from rasa.shared.engine.caching import (
+    get_local_cache_location,
+)
 from rasa.shared.exceptions import (
     FileIOException,
     FileNotFoundException,
+    ProviderClientValidationError,
+)
+from rasa.shared.providers._configs.azure_openai_client_config import (
+    is_azure_openai_config,
+)
+from rasa.shared.providers._configs.huggingface_local_embedding_client_config import (
+    is_huggingface_local_config,
+)
+from rasa.shared.providers._configs.openai_client_config import is_openai_config
+from rasa.shared.providers._configs.self_hosted_llm_client_config import (
+    is_self_hosted_config,
+)
+from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
+from rasa.shared.providers.llm.llm_client import LLMClient
+from rasa.shared.providers.mappings import (
+    get_llm_client_from_provider,
+    AZURE_OPENAI_PROVIDER,
+    OPENAI_PROVIDER,
+    SELF_HOSTED_PROVIDER,
+    get_embedding_client_from_provider,
+    HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
+    get_client_config_class_from_provider,
 )

 if TYPE_CHECKING:
-    from langchain.chat_models import AzureChatOpenAI
-    from langchain.schema.embeddings import Embeddings
-    from langchain.llms.base import BaseLLM
     from rasa.shared.core.trackers import DialogueStateTracker
-    from rasa.shared.providers.openai.clients import (
-        AioHTTPSessionAzureChatOpenAI,
-        AioHTTPSessionOpenAIChat,
-    )

 structlogger = structlog.get_logger()

@@ -52,7 +67,7 @@ DEFAULT_OPENAI_GENERATE_MODEL_NAME = "gpt-3.5-turbo"

 DEFAULT_OPENAI_CHAT_MODEL_NAME = "gpt-3.5-turbo"

-DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED = "gpt-4"
+DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED = "gpt-4-0613"

 DEFAULT_OPENAI_EMBEDDING_MODEL_NAME = "text-embedding-ada-002"

@@ -70,6 +85,94 @@ ERROR_PLACEHOLDER = {
     "default": "[User input triggered an error]",
 }

+_Factory_F = TypeVar(
+    "_Factory_F",
+    bound=Callable[[Dict[str, Any], Dict[str, Any]], Union[EmbeddingClient, LLMClient]],
+)
+_CombineConfigs_F = TypeVar(
+    "_CombineConfigs_F",
+    bound=Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
+)
+
+
+def _compute_hash_for_cache_from_configs(
+    config_x: Dict[str, Any], config_y: Dict[str, Any]
+) -> int:
+    """Get a unique hash of the default and custom configs."""
+    return hash(
+        json.dumps(config_x, sort_keys=True) + json.dumps(config_y, sort_keys=True)
+    )
+
+
+def _retrieve_from_cache(
+    cache: Dict[int, Any], unique_hash: int, function: Callable, function_kwargs: dict
+) -> Any:
+    """Retrieve the value from the cache if it exists. If it does not exist, cache it"""
+    if unique_hash in cache:
+        return cache[unique_hash]
+    else:
+        return_value = function(**function_kwargs)
+        cache[unique_hash] = return_value
+        return return_value
+
+
+def _cache_factory(function: _Factory_F) -> _Factory_F:
+    """Memoize the factory methods based on the arguments."""
+    cache: Dict[int, Union[EmbeddingClient, LLMClient]] = {}
+
+    @wraps(function)
+    def factory_method_wrapper(
+        config_x: Dict[str, Any], config_y: Dict[str, Any]
+    ) -> Union[EmbeddingClient, LLMClient]:
+        # Get a unique hash of the default and custom configs.
+        unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
+        return _retrieve_from_cache(
+            cache=cache,
+            unique_hash=unique_hash,
+            function=function,
+            function_kwargs={"custom_config": config_x, "default_config": config_y},
+        )
+
+    def clear_cache() -> None:
+        cache.clear()
+        structlogger.debug(
+            "Cleared cache for factory method",
+            function_name=function.__name__,
+        )
+
+    setattr(factory_method_wrapper, "clear_cache", clear_cache)
+    return cast(_Factory_F, factory_method_wrapper)
+
+
+def _cache_combine_custom_and_default_configs(
+    function: _CombineConfigs_F,
+) -> _CombineConfigs_F:
+    """Memoize the combine_custom_and_default_config method based on the arguments."""
+    cache: Dict[int, dict] = {}
+
+    @wraps(function)
+    def combine_configs_wrapper(
+        config_x: Dict[str, Any], config_y: Dict[str, Any]
+    ) -> dict:
+        # Get a unique hash of the default and custom configs.
+        unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
+        return _retrieve_from_cache(
+            cache=cache,
+            unique_hash=unique_hash,
+            function=function,
+            function_kwargs={"custom_config": config_x, "default_config": config_y},
+        )
+
+    def clear_cache() -> None:
+        cache.clear()
+        structlogger.debug(
+            "Cleared cache for combine_custom_and_default_config method",
+            function_name=function.__name__,
+        )
+
+    setattr(combine_configs_wrapper, "clear_cache", clear_cache)
+    return cast(_CombineConfigs_F, combine_configs_wrapper)
+

 def tracker_as_readable_transcript(
     tracker: "DialogueStateTracker",
```
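The two decorators added here share one memoization scheme: hash the JSON serialization of both config dicts (sorted keys, so dicts that compare equal hash equally), reuse the cached result, and attach a `clear_cache` hook to the wrapper. A minimal standalone sketch of the same pattern (names are illustrative, not the rasa-pro API):

```python
import json
from functools import wraps
from typing import Any, Callable, Dict


def memoize_by_config(function: Callable[..., Any]) -> Callable[..., Any]:
    """Cache results keyed on a hash of both config dicts, as _cache_factory does."""
    cache: Dict[int, Any] = {}

    @wraps(function)
    def wrapper(custom: Dict[str, Any], default: Dict[str, Any]) -> Any:
        # Serialize with sorted keys so equal dicts produce the same key.
        key = hash(
            json.dumps(custom, sort_keys=True) + json.dumps(default, sort_keys=True)
        )
        if key not in cache:
            cache[key] = function(custom, default)
        return cache[key]

    wrapper.clear_cache = cache.clear  # reset hook, mirroring clear_cache in the diff
    return wrapper


@memoize_by_config
def build_client(custom: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]:
    print("constructing client")  # runs once per distinct config pair
    return {**default, **custom}


build_client({"model": "a"}, {"provider": "openai"})
build_client({"model": "a"}, {"provider": "openai"})  # cache hit: no second print
```

Because cached values are shared across callers, returning copies matters; the `.copy()` calls added to `combine_custom_and_default_config` below keep one caller's mutations from leaking into another's config.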
```diff
@@ -138,11 +241,15 @@ def sanitize_message_for_prompt(text: Optional[str]) -> str:
     return text.replace("\n", " ") if text else ""


+@_cache_combine_custom_and_default_configs
 def combine_custom_and_default_config(
-    custom_config: Optional[Dict[
+    custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
 ) -> Dict[Text, Any]:
     """Merges the given llm config with the default config.

+    This method guarantees that the provider is set and all the deprecated keys are
+    resolved. Hence, produces only a valid client config.
+
     Only uses the default configuration arguments, if the type set in the
     custom config matches the type in the default config. Otherwise, only
     the custom config is used.
@@ -155,155 +262,96 @@ def combine_custom_and_default_config(
         The merged config.
     """
     if custom_config is None:
-        return default_config
-
-
-
-
-
-
-
+        return default_config.copy()
+
+    # Get the provider from the custom config.
+    custom_config_provider = get_provider_from_config(custom_config)
+    # We expect the provider to be set in the default configs of all Rasa components.
+    default_config_provider = default_config[PROVIDER_CONFIG_KEY]
+
+    if (
+        custom_config_provider is not None
+        and custom_config_provider != default_config_provider
+    ):
+        # Get the provider-specific config class
+        client_config_clazz = get_client_config_class_from_provider(
+            custom_config_provider
         )
+        # Checks for deprecated keys, resolves aliases and returns a valid config.
+        # This is done to ensure that the custom config is valid.
+        return client_config_clazz.from_dict(custom_config).to_dict()
+
+    # If the provider is the same in both configs
+    # OR provider is not specified in the custom config
+    # perform MERGE by overriding the default config keys and values
+    # with custom config keys and values.
+    merged_config = {**default_config.copy(), **custom_config.copy()}
+    # Check for deprecated keys, resolve aliases and return a valid config.
+    # This is done to ensure that the merged config is valid.
+    default_config_clazz = get_client_config_class_from_provider(
+        default_config_provider
+    )
+    return default_config_clazz.from_dict(merged_config).to_dict()

-
-
-
-
-
+
+def get_provider_from_config(config: dict) -> Optional[str]:
+    """Try to get the provider from the passed llm/embeddings configuration.
+    If no provider can be found, return None.
+    """
+    if not config:
+        return None
+    if is_self_hosted_config(config):
+        return SELF_HOSTED_PROVIDER
+    elif is_azure_openai_config(config):
+        return AZURE_OPENAI_PROVIDER
+    elif is_openai_config(config):
+        return OPENAI_PROVIDER
+    elif is_huggingface_local_config(config):
+        return HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
+    else:
+        return config.get(PROVIDER_CONFIG_KEY)


 def ensure_cache() -> None:
     """Ensures that the cache is initialized."""
-    import
-    from langchain.cache import SQLiteCache
+    import litellm

-    #
-    cache_location = get_local_cache_location()
+    # Ensure the cache directory exists
+    cache_location = get_local_cache_location() / "rasa-llm-cache"
     cache_location.mkdir(parents=True, exist_ok=True)

-
-
-
-
-def preprocess_config_for_azure(config: Dict[str, Any]) -> Dict[str, Any]:
-    """Preprocesses the config for Azure deployments.
-
-    This function is used to preprocess the config for Azure deployments.
-    AzureChatOpenAI does not expect the _type key, as it is not a defined parameter
-    in the class. So we need to remove it before passing the config to the class.
-    AzureChatOpenAI expects the openai_api_type key to be set instead.
-
-    Args:
-        config: The config to preprocess.
-
-    Returns:
-        The preprocessed config.
-    """
-    config["deployment_name"] = (
-        config.get(OPENAI_DEPLOYMENT_NAME_CONFIG_KEY)
-        or config.get(OPENAI_DEPLOYMENT_CONFIG_KEY)
-        or config.get(OPENAI_ENGINE_CONFIG_KEY)
-    )
-    config["openai_api_base"] = (
-        config.get(OPENAI_API_BASE_CONFIG_KEY)
-        or config.get(OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_BASE_ENV_VAR)
-    )
-    config["openai_api_type"] = (
-        config.get(OPENAI_API_TYPE_CONFIG_KEY)
-        or config.get(OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_TYPE_ENV_VAR)
-    )
-    config["openai_api_version"] = (
-        config.get(OPENAI_API_VERSION_CONFIG_KEY)
-        or config.get(OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_VERSION_ENV_VAR)
-    )
-    for keys in [
-        OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
-        OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
-        OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
-        OPENAI_DEPLOYMENT_CONFIG_KEY,
-        OPENAI_ENGINE_CONFIG_KEY,
-        LANGCHAIN_TYPE_CONFIG_KEY,
-    ]:
-        config.pop(keys, None)
-
-    return config
-
-
-def process_config_for_aiohttp_chat_openai(config: Dict[str, Any]) -> Dict[str, Any]:
-    config = config.copy()
-    config.pop(LANGCHAIN_TYPE_CONFIG_KEY)
-    return config
+    # Set diskcache as a caching option
+    litellm.cache = litellm.Cache(type="disk", disk_cache_dir=cache_location)


+@_cache_factory
 def llm_factory(
     custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
-) ->
-    "BaseLLM",
-    "AzureChatOpenAI",
-    "AioHTTPSessionAzureChatOpenAI",
-    "AioHTTPSessionOpenAIChat",
-]:
+) -> LLMClient:
     """Creates an LLM from the given config.

     Args:
         custom_config: The custom config containing values to overwrite defaults
         default_config: The default config.

-
     Returns:
-
+        Instantiated LLM based on the configuration.
     """
-    from langchain.llms.loading import load_llm_from_config
-
-    ensure_cache()
-
     config = combine_custom_and_default_config(custom_config, default_config)

-
-    # config in place...
-    structlogger.debug("llmfactory.create.llm", config=config)
-    # langchain issues a user warning when using chat models. at the same time
-    # it doesn't provide a way to instantiate a chat model directly using the
-    # config. so for now, we need to suppress the warning here. Original
-    # warning:
-    # packages/langchain/llms/openai.py:189: UserWarning: You are trying to
-    # use a chat model. This way of initializing it is no longer supported.
-    # Instead, please use: `from langchain.chat_models import ChatOpenAI
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", category=UserWarning)
-        if is_azure_config(config):
-            # Azure deployments are treated differently. This is done as the
-            # GPT-3.5 Turbo newer versions 0613 and 1106 only support the
-            # Chat Completions API.
-            from langchain.chat_models import AzureChatOpenAI
-            from rasa.shared.providers.openai.clients import (
-                AioHTTPSessionAzureChatOpenAI,
-            )
-
-            transformed_config = preprocess_config_for_azure(config.copy())
-            if os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is None:
-                return AzureChatOpenAI(**transformed_config)
-            else:
-                return AioHTTPSessionAzureChatOpenAI(**transformed_config)
-
-        if (
-            os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
-            and config.get(LANGCHAIN_TYPE_CONFIG_KEY) == "openai"
-        ):
-            from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIChat
-
-            config = process_config_for_aiohttp_chat_openai(config)
-            return AioHTTPSessionOpenAIChat(**config.copy())
+    ensure_cache()

-
+    client_clazz: Type[LLMClient] = get_llm_client_from_provider(
+        config[PROVIDER_CONFIG_KEY]
+    )
+    client = client_clazz.from_config(config)
+    return client


+@_cache_factory
 def embedder_factory(
     custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
-) ->
+) -> EmbeddingClient:
     """Creates an Embedder from the given config.

     Args:
```
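The reworked merge logic has three outcomes: no custom config returns a copy of the defaults; a custom config whose detected provider differs from the default's discards the defaults entirely; otherwise custom keys override default keys. The real implementation additionally routes the result through a provider-specific config class (`from_dict`/`to_dict`) to resolve deprecated keys and aliases, which this standalone sketch omits; `resolve_provider` and `combine` are hypothetical stand-ins:

```python
from typing import Any, Dict, Optional


def resolve_provider(config: Dict[str, Any]) -> Optional[str]:
    """Stand-in for get_provider_from_config: here, just read the explicit key."""
    return config.get("provider")


def combine(
    custom: Optional[Dict[str, Any]], default: Dict[str, Any]
) -> Dict[str, Any]:
    if custom is None:
        # Copy so the (memoized, shared) defaults are never mutated by callers.
        return default.copy()
    if resolve_provider(custom) not in (None, default["provider"]):
        # Provider switch: the defaults belong to another provider, so drop them.
        return dict(custom)
    # Same or unspecified provider: custom keys override default keys.
    return {**default, **custom}


default = {"provider": "openai", "model": "gpt-3.5-turbo", "timeout": 7}
print(combine({"model": "gpt-4-0613"}, default))
# -> {'provider': 'openai', 'model': 'gpt-4-0613', 'timeout': 7}
print(combine({"provider": "azure", "deployment": "my-gpt-4"}, default))
# -> {'provider': 'azure', 'deployment': 'my-gpt-4'}
```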
```diff
@@ -312,55 +360,17 @@ def embedder_factory(


     Returns:
-
+        Instantiated Embedder based on the configuration.
     """
-    from langchain.schema.embeddings import Embeddings
-    from langchain.embeddings import (
-        CohereEmbeddings,
-        HuggingFaceHubEmbeddings,
-        HuggingFaceInstructEmbeddings,
-        HuggingFaceEmbeddings,
-        HuggingFaceBgeEmbeddings,
-        LlamaCppEmbeddings,
-        OpenAIEmbeddings,
-        SpacyEmbeddings,
-        VertexAIEmbeddings,
-    )
-    from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIEmbeddings
-
-    type_to_embedding_cls_dict: Dict[str, Type[Embeddings]] = {
-        "azure": OpenAIEmbeddings,
-        "openai": OpenAIEmbeddings,
-        "openai-aiohttp-session": AioHTTPSessionOpenAIEmbeddings,
-        "cohere": CohereEmbeddings,
-        "spacy": SpacyEmbeddings,
-        "vertexai": VertexAIEmbeddings,
-        "huggingface_instruct": HuggingFaceInstructEmbeddings,
-        "huggingface_hub": HuggingFaceHubEmbeddings,
-        "huggingface_bge": HuggingFaceBgeEmbeddings,
-        "huggingface": HuggingFaceEmbeddings,
-        "llamacpp": LlamaCppEmbeddings,
-    }
-
     config = combine_custom_and_default_config(custom_config, default_config)
-    embedding_type = config.get(LANGCHAIN_TYPE_CONFIG_KEY)

-    if (
-        os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
-        and embedding_type is not None
-    ):
-        embedding_type = f"{embedding_type}-aiohttp-session"
-
-    structlogger.debug("llmfactory.create.embedder", config=config)
+    ensure_cache()

-
-
-
-
-
-        return embeddings_cls(**parameters)
-    else:
-        raise ValueError(f"Unsupported embeddings type '{embedding_type}'")
+    client_clazz: Type[EmbeddingClient] = get_embedding_client_from_provider(
+        config[PROVIDER_CONFIG_KEY]
+    )
+    client = client_clazz.from_config(config)
+    return client


 def get_prompt_template(
```
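With both factories now decorated by `@_cache_factory` and returning rasa's own `LLMClient`/`EmbeddingClient` wrappers (backed by litellm) instead of langchain objects, a hypothetical caller on 3.10 might look like the following. This sketch assumes rasa-pro 3.10 is installed and provider credentials (for example `OPENAI_API_KEY`) are exported; the `"provider"` and `"model"` keys follow the `PROVIDER_CONFIG_KEY` convention visible in the diff:

```python
# Hypothetical usage sketch; not code from the diff.
from rasa.shared.utils.llm import embedder_factory, llm_factory

DEFAULT_LLM_CONFIG = {"provider": "openai", "model": "gpt-3.5-turbo"}
DEFAULT_EMBEDDINGS_CONFIG = {"provider": "openai", "model": "text-embedding-ada-002"}

# Custom config without a provider key: merged over the defaults.
llm_client = llm_factory({"model": "gpt-4-0613"}, DEFAULT_LLM_CONFIG)

# No custom config: a copy of the defaults is used, and the memoizing
# decorator returns the same client instance for repeated identical calls.
embedding_client = embedder_factory(None, DEFAULT_EMBEDDINGS_CONFIG)
```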
```diff
@@ -396,9 +406,45 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
         return None


-def
-
-
-
-
-
+def try_instantiate_llm_client(
+    custom_llm_config: Optional[Dict],
+    default_llm_config: Optional[Dict],
+    log_source_function: str,
+    log_source_component: str,
+) -> None:
+    """Validate llm configuration."""
+    try:
+        llm_factory(custom_llm_config, default_llm_config)
+    except (ProviderClientValidationError, ValueError) as e:
+        structlogger.error(
+            f"{log_source_function}.llm_instantiation_failed",
+            event_info=(
+                f"Unable to create the LLM client for component - "
+                f"{log_source_component}. Please make sure you specified the required "
+                f"environment variables and configuration keys."
+            ),
+            error=e,
+        )
+        sys.exit(1)
+
+
+def try_instantiate_embedder(
+    custom_embeddings_config: Optional[Dict],
+    default_embeddings_config: Optional[Dict],
+    log_source_function: str,
+    log_source_component: str,
+) -> EmbeddingClient:
+    """Validate embeddings configuration."""
+    try:
+        return embedder_factory(custom_embeddings_config, default_embeddings_config)
+    except (ProviderClientValidationError, ValueError) as e:
+        structlogger.error(
+            f"{log_source_function}.embedder_instantiation_failed",
+            event_info=(
+                f"Unable to create the Embedding client for component - "
+                f"{log_source_component}. Please make sure you specified the required "
+                f"environment variables and configuration keys."
+            ),
+            error=e,
+        )
+        sys.exit(1)
```