rasa-pro 3.9.18__py3-none-any.whl → 3.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +26 -57
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +26 -22
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +108 -433
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +395 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +14 -13
- rasa/cli/scaffold.py +10 -8
- rasa/cli/train.py +8 -7
- rasa/cli/utils.py +15 -0
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/inspector/dist/index.html +0 -2
- rasa/core/channels/inspector/index.html +0 -2
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +100 -44
- rasa/core/policies/flows/flow_executor.py +130 -94
- rasa/core/policies/intentless_policy.py +52 -28
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +20 -53
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/utils.py +45 -56
- rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
- rasa/dialogue_understanding/commands/utils.py +38 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +13 -14
- rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
- rasa/e2e_test/assertions.py +1181 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +491 -72
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +596 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/storage/local_model_storage.py +0 -1
- rasa/env.py +9 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +48 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +129 -32
- rasa/server.py +45 -10
- rasa/shared/constants.py +63 -15
- rasa/shared/core/domain.py +15 -12
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +28 -10
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +60 -11
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +227 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -3
- rasa/shared/utils/llm.py +223 -200
- rasa/shared/utils/yaml.py +122 -7
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +2 -3
- rasa/telemetry.py +113 -58
- rasa/tracing/config.py +2 -3
- rasa/tracing/instrumentation/attribute_extractors.py +29 -17
- rasa/tracing/instrumentation/instrumentation.py +4 -47
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/io.py +66 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +195 -1
- rasa/version.py +1 -1
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +47 -72
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +185 -121
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa/utils/tensorflow/feature_array.py +0 -366
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
import structlog
|
|
3
|
+
|
|
4
|
+
from rasa.shared.constants import OPENAI_PROVIDER
|
|
5
|
+
from rasa.shared.providers._configs.self_hosted_llm_client_config import (
|
|
6
|
+
SelfHostedLLMClientConfig,
|
|
7
|
+
)
|
|
8
|
+
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
9
|
+
|
|
10
|
+
structlogger = structlog.get_logger()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SelfHostedLLMClient(_BaseLiteLLMClient):
    """A client for interfacing with Self Hosted LLM endpoints that uses
    LiteLLM, issuing requests in OpenAI-compatible format.

    Parameters:
        model (str): The model or deployment name.
        provider (str): The provider of the model.
        api_base (str): The base URL of the API endpoint.
        api_type (Optional[str]): The type of the API endpoint.
        api_version (Optional[str]): The version of the API endpoint.
        kwargs: Any: Additional configuration parameters that can include, but
            are not limited to model parameters and lite-llm specific
            parameters. These parameters will be passed to the
            completion/acompletion calls.

    Raises:
        ProviderClientValidationError: If validation of the client setup fails.
        ProviderClientAPIException: If the API request fails.
    """

    def __init__(
        self,
        provider: str,
        model: str,
        api_base: str,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        **kwargs: Any,
    ):
        super().__init__()  # type: ignore
        self._provider = provider
        self._model = model
        self._api_base = api_base
        self._api_type = api_type
        self._api_version = api_version
        # Any parameter not named explicitly above is kept and forwarded
        # verbatim to LiteLLM's completion/acompletion calls.
        self._extra_parameters = kwargs or {}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
        """Create a client from a raw configuration dictionary.

        The configuration is parsed and validated via
        `SelfHostedLLMClientConfig.from_dict`; a parsing failure is logged
        and the original `ValueError` is re-raised unchanged.
        """
        try:
            client_config = SelfHostedLLMClientConfig.from_dict(config)
        except ValueError as e:
            message = "Cannot instantiate a client from the passed configuration."
            structlogger.error(
                "self_hosted_llm_client.from_config.error",
                message=message,
                config=config,
                original_error=e,
            )
            raise

        return cls(
            model=client_config.model,
            provider=client_config.provider,
            api_base=client_config.api_base,
            api_type=client_config.api_type,
            api_version=client_config.api_version,
            **client_config.extra_parameters,
        )

    @property
    def provider(self) -> str:
        """
        Returns the provider name for the self hosted llm client.

        Returns:
            String representing the provider name.
        """
        return self._provider

    @property
    def model(self) -> str:
        """
        Returns the model name for the self hosted llm client.

        Returns:
            String representing the model name.
        """
        return self._model

    @property
    def api_base(self) -> str:
        """
        Returns the base URL for the API endpoint.

        Returns:
            String representing the base URL.
        """
        return self._api_base

    @property
    def api_type(self) -> Optional[str]:
        """
        Returns the type of the API endpoint. Currently only OpenAI is supported.

        Returns:
            String representing the API type.
        """
        return self._api_type

    @property
    def api_version(self) -> Optional[str]:
        """
        Returns the version of the API endpoint.

        Returns:
            String representing the API version.
        """
        return self._api_version

    @property
    def config(self) -> Dict:
        """
        Returns the configuration for the self hosted llm client.

        Round-trips the client state through `SelfHostedLLMClientConfig`
        so the returned dictionary matches the canonical config format.

        Returns:
            Dictionary containing the configuration.
        """
        config = SelfHostedLLMClientConfig(
            model=self._model,
            provider=self._provider,
            api_base=self._api_base,
            api_type=self._api_type,
            api_version=self._api_version,
            extra_parameters=self._extra_parameters,
        )
        return config.to_dict()

    @property
    def _litellm_model_name(self) -> str:
        """Returns the value of LiteLLM's model parameter to be used in
        completion/acompletion in LiteLLM format:

        <openai>/<model or deployment name>
        """
        # Prefix with the OpenAI provider unless the model name already
        # carries it, so LiteLLM routes the call correctly.
        if self.model and f"{OPENAI_PROVIDER}/" not in self.model:
            return f"{OPENAI_PROVIDER}/{self.model}"
        return self.model

    @property
    def _litellm_extra_parameters(self) -> Dict[str, Any]:
        """Returns optional configuration parameters specific
        to the client provider and deployed model.
        """
        return self._extra_parameters

    @property
    def _completion_fn_args(self) -> Dict[str, Any]:
        """Returns the completion arguments for invoking a call through
        LiteLLM's completion functions.

        NOTE(review): `api_type` is not forwarded here — confirm the base
        client handles it (or that LiteLLM does not need it for this route).
        """
        fn_args = super()._completion_fn_args
        fn_args.update(
            {
                "api_base": self.api_base,
                "api_version": self.api_version,
            }
        )
        return fn_args
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from typing import Dict, Type, Optional
|
|
2
|
+
|
|
3
|
+
from rasa.shared.constants import (
|
|
4
|
+
AZURE_OPENAI_PROVIDER,
|
|
5
|
+
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
|
|
6
|
+
OPENAI_PROVIDER,
|
|
7
|
+
SELF_HOSTED_PROVIDER,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.providers.embedding.azure_openai_embedding_client import (
|
|
10
|
+
AzureOpenAIEmbeddingClient,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.providers.embedding.default_litellm_embedding_client import (
|
|
13
|
+
DefaultLiteLLMEmbeddingClient,
|
|
14
|
+
)
|
|
15
|
+
from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
|
|
16
|
+
from rasa.shared.providers.embedding.huggingface_local_embedding_client import (
|
|
17
|
+
HuggingFaceLocalEmbeddingClient,
|
|
18
|
+
)
|
|
19
|
+
from rasa.shared.providers.embedding.openai_embedding_client import (
|
|
20
|
+
OpenAIEmbeddingClient,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.providers.llm.azure_openai_llm_client import AzureOpenAILLMClient
|
|
23
|
+
from rasa.shared.providers.llm.default_litellm_llm_client import DefaultLiteLLMClient
|
|
24
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
25
|
+
from rasa.shared.providers.llm.openai_llm_client import OpenAILLMClient
|
|
26
|
+
from rasa.shared.providers.llm.self_hosted_llm_client import SelfHostedLLMClient
|
|
27
|
+
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
28
|
+
AzureOpenAIClientConfig,
|
|
29
|
+
)
|
|
30
|
+
from rasa.shared.providers._configs.default_litellm_client_config import (
|
|
31
|
+
DefaultLiteLLMClientConfig,
|
|
32
|
+
)
|
|
33
|
+
from rasa.shared.providers._configs.huggingface_local_embedding_client_config import (
|
|
34
|
+
HuggingFaceLocalEmbeddingClientConfig,
|
|
35
|
+
)
|
|
36
|
+
from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
|
|
37
|
+
from rasa.shared.providers._configs.self_hosted_llm_client_config import (
|
|
38
|
+
SelfHostedLLMClientConfig,
|
|
39
|
+
)
|
|
40
|
+
from rasa.shared.providers._configs.client_config import ClientConfig
|
|
41
|
+
|
|
42
|
+
# Providers with a dedicated LLM client implementation; any other provider
# falls back to `DefaultLiteLLMClient` (see `get_llm_client_from_provider`).
_provider_to_llm_client_mapping: Dict[str, Type[LLMClient]] = {
    OPENAI_PROVIDER: OpenAILLMClient,
    AZURE_OPENAI_PROVIDER: AzureOpenAILLMClient,
    SELF_HOSTED_PROVIDER: SelfHostedLLMClient,
}

# Providers with a dedicated embedding client implementation; any other
# provider falls back to `DefaultLiteLLMEmbeddingClient`.
_provider_to_embedding_client_mapping: Dict[str, Type[EmbeddingClient]] = {
    OPENAI_PROVIDER: OpenAIEmbeddingClient,
    AZURE_OPENAI_PROVIDER: AzureOpenAIEmbeddingClient,
    HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER: HuggingFaceLocalEmbeddingClient,
}

# Providers with a dedicated config class; any other provider falls back to
# `DefaultLiteLLMClientConfig`.
_provider_to_client_config_class_mapping: Dict[str, Type[ClientConfig]] = {
    OPENAI_PROVIDER: OpenAIClientConfig,
    AZURE_OPENAI_PROVIDER: AzureOpenAIClientConfig,
    HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER: HuggingFaceLocalEmbeddingClientConfig,
    SELF_HOSTED_PROVIDER: SelfHostedLLMClientConfig,
}


def get_llm_client_from_provider(provider: Optional[str]) -> Type[LLMClient]:
    """Return the LLM client class registered for `provider`.

    Unknown (or `None`) providers fall back to `DefaultLiteLLMClient`.
    """
    return _provider_to_llm_client_mapping.get(provider, DefaultLiteLLMClient)


def get_embedding_client_from_provider(provider: str) -> Type[EmbeddingClient]:
    """Return the embedding client class registered for `provider`.

    Unknown providers fall back to `DefaultLiteLLMEmbeddingClient`.
    """
    return _provider_to_embedding_client_mapping.get(
        provider, DefaultLiteLLMEmbeddingClient
    )


def get_client_config_class_from_provider(provider: str) -> Type[ClientConfig]:
    """Return the client config class registered for `provider`.

    Unknown providers fall back to `DefaultLiteLLMClientConfig`.
    """
    return _provider_to_client_config_class_mapping.get(
        provider, DefaultLiteLLMClientConfig
    )
|
rasa/shared/utils/cli.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import shutil
|
|
1
3
|
import sys
|
|
2
4
|
from typing import Any, Text, NoReturn
|
|
3
5
|
|
|
@@ -70,3 +72,31 @@ def print_error_and_exit(message: Text, exit_code: int = 1) -> NoReturn:
|
|
|
70
72
|
"""
|
|
71
73
|
print_error(message)
|
|
72
74
|
sys.exit(exit_code)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def pad(text: Text, char: Text = "=", min: int = 3) -> Text:
    """Center `text` in a line of `char` padding sized to the terminal width.

    The padding fills the detected terminal width (falling back to 80
    columns when it cannot be determined). If the text is too long to fit,
    at least `min` padding characters are kept on each side. When the
    available padding is odd, the right side receives the extra character.

    Args:
        text: Text to pad.
        char: Character to pad with.
        min: Minimum number of padding characters per side.

    Returns:
        The padded text, e.g. ``==== Hello ====``.
    """
    columns = shutil.get_terminal_size((80, 20)).columns
    fill = max(columns - len(text) - 2, 2 * min)
    left = fill // 2
    right = fill - left

    return f"{char * left} {text} {char * right}"
|
rasa/shared/utils/io.py
CHANGED
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
from collections import OrderedDict
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from hashlib import md5
|
|
4
|
+
import asyncio
|
|
2
5
|
import errno
|
|
3
6
|
import glob
|
|
4
|
-
from hashlib import md5
|
|
5
7
|
import json
|
|
8
|
+
import logging
|
|
6
9
|
import os
|
|
7
10
|
import sys
|
|
8
11
|
from pathlib import Path
|
|
9
|
-
from typing import Any, Dict, List, Optional, Text, Type, Union
|
|
12
|
+
from typing import Any, cast, Callable, Dict, List, Optional, Text, Type, TypeVar, Union
|
|
10
13
|
import warnings
|
|
11
14
|
import random
|
|
12
15
|
import string
|
|
13
|
-
|
|
14
16
|
import portalocker
|
|
15
17
|
|
|
16
18
|
from rasa.shared.constants import (
|
|
@@ -137,6 +139,17 @@ def read_json_file(filename: Union[Text, Path]) -> Any:
|
|
|
137
139
|
)
|
|
138
140
|
|
|
139
141
|
|
|
142
|
+
def read_jsonl_file(file_path: Union[Text, Path]) -> List[Any]:
    """Read a JSON Lines file and return one decoded object per line.

    Blank lines (including a trailing newline) are skipped, since JSONL
    files commonly end with a newline and `json.loads("")` would otherwise
    raise.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        A list with the decoded JSON value of each non-blank line.

    Raises:
        FileIOException: If any line cannot be parsed as JSON.
    """
    content = read_file(file_path)
    try:
        return [
            json.loads(line) for line in content.splitlines() if line.strip()
        ]
    except ValueError as e:
        # Chain the decoding error explicitly so the offending line/position
        # stays visible in the traceback.
        raise FileIOException(
            f"Failed to read JSONL from '{os.path.abspath(file_path)}'. Error: {e}"
        ) from e
|
|
151
|
+
|
|
152
|
+
|
|
140
153
|
def list_directory(path: Text) -> List[Text]:
|
|
141
154
|
"""Returns all files and folders excluding hidden files.
|
|
142
155
|
|
|
@@ -413,3 +426,52 @@ def file_as_bytes(file_path: Text) -> bytes:
|
|
|
413
426
|
raise FileNotFoundException(
|
|
414
427
|
f"Failed to read file, " f"'{os.path.abspath(file_path)}' does not exist."
|
|
415
428
|
)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
F = TypeVar("F", bound=Callable[..., Any])


def suppress_logs(log_level: int = logging.WARNING) -> Callable[[F], F]:
    """Decorator to suppress logs during the execution of a function.

    Temporarily raises the root logger's level to `log_level` while the
    wrapped function runs and restores the previous level afterwards, even
    if the function raises. Works for both sync and async functions.

    Args:
        log_level: The log level to set during the execution of the function.

    Returns:
        The decorated function.
    """

    def decorator(func: F) -> F:
        root = logging.getLogger()

        # Coroutine functions need an async wrapper so the call is awaited
        # while the adjusted log level is in effect.
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                previous_level = root.getEffectiveLevel()
                root.setLevel(log_level)
                try:
                    return await func(*args, **kwargs)
                finally:
                    # Always restore the original level.
                    root.setLevel(previous_level)

            return cast(F, async_wrapper)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            previous_level = root.getEffectiveLevel()
            root.setLevel(log_level)
            try:
                return func(*args, **kwargs)
            finally:
                # Always restore the original level.
                root.setLevel(previous_level)

        return cast(F, sync_wrapper)

    return decorator
|