rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/core/actions/action.py +7 -16
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/socketio.py +23 -1
- rasa/core/nlg/contextual_response_rephraser.py +9 -62
- rasa/core/policies/enterprise_search_policy.py +12 -77
- rasa/core/policies/flows/flow_executor.py +2 -26
- rasa/core/processor.py +8 -11
- rasa/dialogue_understanding/generator/command_generator.py +49 -43
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
- rasa/dialogue_understanding/utils.py +1 -8
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +2 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/e2e_test_runner.py +9 -7
- rasa/hooks.py +9 -15
- rasa/model_manager/socket_bridge.py +2 -7
- rasa/model_manager/warm_rasa_process.py +4 -9
- rasa/plugin.py +0 -11
- rasa/shared/constants.py +2 -21
- rasa/shared/core/events.py +8 -8
- rasa/shared/nlu/constants.py +0 -3
- rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
- rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
- rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
- rasa/shared/providers/_configs/model_group_config.py +11 -5
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +14 -12
- rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
- rasa/shared/providers/llm/_base_litellm_client.py +6 -4
- rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
- rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
- rasa/shared/providers/llm/llm_client.py +4 -2
- rasa/shared/providers/llm/llm_response.py +1 -42
- rasa/shared/providers/llm/openai_llm_client.py +11 -5
- rasa/shared/providers/llm/rasa_llm_client.py +13 -5
- rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
- rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/llm.py +16 -12
- rasa/shared/utils/schemas/events.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +0 -2
- rasa/version.py +1 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
- rasa/core/channels/studio_chat.py +0 -192
- rasa/dialogue_understanding/constants.py +0 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
|
@@ -1,36 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
1
4
|
from dataclasses import asdict, dataclass, field
|
|
2
|
-
from typing import
|
|
5
|
+
from typing import (
|
|
6
|
+
Any,
|
|
7
|
+
Dict,
|
|
8
|
+
Optional,
|
|
9
|
+
Set,
|
|
10
|
+
)
|
|
3
11
|
|
|
4
12
|
import structlog
|
|
5
13
|
|
|
6
14
|
from rasa.shared.constants import (
|
|
7
|
-
MODEL_CONFIG_KEY,
|
|
8
|
-
MODEL_NAME_CONFIG_KEY,
|
|
9
|
-
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
15
|
API_BASE_CONFIG_KEY,
|
|
11
|
-
|
|
16
|
+
API_KEY,
|
|
12
17
|
API_TYPE_CONFIG_KEY,
|
|
13
|
-
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
18
|
API_VERSION_CONFIG_KEY,
|
|
19
|
+
AZURE_API_TYPE,
|
|
20
|
+
AZURE_OPENAI_PROVIDER,
|
|
15
21
|
DEPLOYMENT_CONFIG_KEY,
|
|
16
22
|
DEPLOYMENT_NAME_CONFIG_KEY,
|
|
17
23
|
ENGINE_CONFIG_KEY,
|
|
18
|
-
RASA_TYPE_CONFIG_KEY,
|
|
19
24
|
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
20
|
-
|
|
25
|
+
MODEL_CONFIG_KEY,
|
|
26
|
+
MODEL_NAME_CONFIG_KEY,
|
|
21
27
|
N_REPHRASES_CONFIG_KEY,
|
|
28
|
+
OPENAI_API_BASE_CONFIG_KEY,
|
|
29
|
+
OPENAI_API_TYPE_CONFIG_KEY,
|
|
30
|
+
OPENAI_API_VERSION_CONFIG_KEY,
|
|
31
|
+
PROVIDER_CONFIG_KEY,
|
|
32
|
+
RASA_TYPE_CONFIG_KEY,
|
|
22
33
|
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
34
|
+
STREAM_CONFIG_KEY,
|
|
23
35
|
TIMEOUT_CONFIG_KEY,
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
36
|
+
)
|
|
37
|
+
from rasa.shared.providers._configs.azure_entra_id_config import (
|
|
38
|
+
AzureEntraIDOAuthConfig,
|
|
39
|
+
AzureEntraIDOAuthType,
|
|
40
|
+
)
|
|
41
|
+
from rasa.shared.providers._configs.oauth_config import (
|
|
42
|
+
OAUTH_KEY,
|
|
43
|
+
OAUTH_TYPE_FIELD,
|
|
44
|
+
OAuth,
|
|
27
45
|
)
|
|
28
46
|
from rasa.shared.providers._configs.utils import (
|
|
29
|
-
resolve_aliases,
|
|
30
47
|
raise_deprecation_warnings,
|
|
31
|
-
|
|
48
|
+
resolve_aliases,
|
|
32
49
|
validate_forbidden_keys,
|
|
50
|
+
validate_required_keys,
|
|
33
51
|
)
|
|
52
|
+
from rasa.shared.utils.common import class_from_module_path
|
|
34
53
|
|
|
35
54
|
structlogger = structlog.get_logger()
|
|
36
55
|
|
|
@@ -61,6 +80,76 @@ FORBIDDEN_KEYS = [
|
|
|
61
80
|
]
|
|
62
81
|
|
|
63
82
|
|
|
83
|
+
@dataclass
class OAuthConfigWrapper(OAuth):
    """Wrapper for OAuth configuration.

    Its main purpose is to keep the user-provided `oauth` section next to the
    resolved OAuth implementation, so that `to_dict` can serialize the
    configuration back to its original format.
    """

    # Resolved OAuth implementation: either the built-in Azure Entra ID config
    # or a user-provided class implementing the `OAuth` interface.
    oauth: OAuth
    # Deep copy of the `oauth` section exactly as it was passed to `from_config`.
    original_config: Dict[str, Any]

    def get_bearer_token(self) -> str:
        """Returns a bearer token by delegating to the wrapped OAuth object."""
        return self.oauth.get_bearer_token()

    def to_dict(self) -> Dict[str, Any]:
        """Converts the OAuth configuration to the original format.

        A deep copy is returned so that callers mutating the result (e.g. when
        embedding it into a larger serialized client config) cannot alter the
        configuration stored on this wrapper.
        """
        return deepcopy(self.original_config)

    @staticmethod
    def _valid_type_values() -> Set[str]:
        """Returns the valid built-in values for the `type` field in the `oauth`."""
        return AzureEntraIDOAuthType.valid_string_values()

    @classmethod
    def from_config(cls, oauth_config: Dict[str, Any]) -> OAuthConfigWrapper:
        """Initializes the wrapper from the passed `oauth` config section.

        The `type` field selects the implementation: one of the built-in Azure
        Entra ID types, or a module path to a class implementing `OAuth`.

        Args:
            oauth_config: (dict) The config from which to initialize.

        Returns:
            OAuthConfigWrapper holding the resolved OAuth implementation and a
            deep copy of `oauth_config`.

        Raises:
            ValueError: if `oauth_config` is not a dict, the `type` field is
                missing or not a string, or the referenced module does not
                implement the `OAuth` interface.
        """
        if not isinstance(oauth_config, dict):
            message = (
                "Oauth configuration must be a mapping, got "
                f"'{type(oauth_config).__name__}'."
            )
            structlogger.error(
                "azure_oauth_config.invalid_oauth_config",
                message=message,
            )
            raise ValueError(message)

        # Snapshot before any implementation's `from_config` can mutate the dict.
        original_config = deepcopy(oauth_config)

        oauth_type = oauth_config.get(OAUTH_TYPE_FIELD, None)

        # A non-string value can neither match a built-in type nor be a module
        # path, so treat it the same as a missing `type` field.
        if not isinstance(oauth_type, str):
            message = (
                "Oauth configuration must contain "
                f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
                f"following values: {OAuthConfigWrapper._valid_type_values()}, "
                f"or to the path of module which is implementing {OAuth.__name__} protocol."
            )
            structlogger.error(
                "azure_oauth_config.missing_oauth_type",
                message=message,
            )
            raise ValueError(message)

        oauth: OAuth
        if oauth_type in cls._valid_type_values():
            oauth = AzureEntraIDOAuthConfig.from_config(oauth_config)
        else:
            module = class_from_module_path(oauth_type)

            # `issubclass` raises TypeError for non-class objects, so check
            # that a class was actually resolved before testing the interface.
            if not (isinstance(module, type) and issubclass(module, OAuth)):
                message = f"Module {oauth_type} does not implement {OAuth.__name__} interface."
                structlogger.error(
                    "azure_oauth_config.invalid_oauth_module",
                    message=message,
                )
                raise ValueError(message)

            oauth = module.from_config(oauth_config)

        return cls(oauth=oauth, original_config=original_config)
|
|
151
|
+
|
|
152
|
+
|
|
64
153
|
@dataclass
|
|
65
154
|
class AzureOpenAIClientConfig:
|
|
66
155
|
"""Parses configuration for Azure OpenAI client, resolves aliases and
|
|
@@ -80,11 +169,13 @@ class AzureOpenAIClientConfig:
|
|
|
80
169
|
# API Type is not used by LiteLLM backend, but we define
|
|
81
170
|
# it here for backward compatibility.
|
|
82
171
|
api_type: Optional[str] = AZURE_API_TYPE
|
|
83
|
-
|
|
84
172
|
# Provider is not used by LiteLLM backend, but we define it here since it's
|
|
85
173
|
# used as switch between different clients.
|
|
86
174
|
provider: str = AZURE_OPENAI_PROVIDER
|
|
87
175
|
|
|
176
|
+
# OAuth related parameters
|
|
177
|
+
oauth: Optional[OAuthConfigWrapper] = None
|
|
178
|
+
|
|
88
179
|
extra_parameters: dict = field(default_factory=dict)
|
|
89
180
|
|
|
90
181
|
def __post_init__(self) -> None:
|
|
@@ -106,7 +197,7 @@ class AzureOpenAIClientConfig:
|
|
|
106
197
|
raise ValueError(message)
|
|
107
198
|
|
|
108
199
|
@classmethod
|
|
109
|
-
def from_dict(cls, config: dict) ->
|
|
200
|
+
def from_dict(cls, config: dict) -> AzureOpenAIClientConfig:
|
|
110
201
|
"""Initializes a dataclass from the passed config.
|
|
111
202
|
|
|
112
203
|
Args:
|
|
@@ -129,6 +220,26 @@ class AzureOpenAIClientConfig:
|
|
|
129
220
|
# Validate that the forbidden keys are not present
|
|
130
221
|
validate_forbidden_keys(config, FORBIDDEN_KEYS)
|
|
131
222
|
# Init client config
|
|
223
|
+
|
|
224
|
+
has_api_key = config.get(API_KEY, None) is not None
|
|
225
|
+
has_oauth_key = config.get(OAUTH_KEY, None) is not None
|
|
226
|
+
|
|
227
|
+
if has_api_key and has_oauth_key:
|
|
228
|
+
message = (
|
|
229
|
+
"Azure OpenAI client configuration cannot contain "
|
|
230
|
+
f"both '{API_KEY}' and '{OAUTH_KEY}' fields. Please provide either "
|
|
231
|
+
f"'{API_KEY}' or '{OAUTH_KEY}' fields."
|
|
232
|
+
)
|
|
233
|
+
structlogger.error(
|
|
234
|
+
"azure_openai_client_config.multiple_auth_types_specified",
|
|
235
|
+
message=message,
|
|
236
|
+
)
|
|
237
|
+
raise ValueError(message)
|
|
238
|
+
|
|
239
|
+
oauth = None
|
|
240
|
+
if has_oauth_key:
|
|
241
|
+
oauth = OAuthConfigWrapper.from_config(config.pop(OAUTH_KEY))
|
|
242
|
+
|
|
132
243
|
this = AzureOpenAIClientConfig(
|
|
133
244
|
# Required parameters
|
|
134
245
|
deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
|
|
@@ -142,6 +253,8 @@ class AzureOpenAIClientConfig:
|
|
|
142
253
|
# in clients.
|
|
143
254
|
api_base=config.pop(API_BASE_CONFIG_KEY, None),
|
|
144
255
|
api_version=config.pop(API_VERSION_CONFIG_KEY, None),
|
|
256
|
+
# OAuth related parameters, set only if auth_type is set to 'entra_id'
|
|
257
|
+
oauth=oauth,
|
|
145
258
|
# The rest of parameters (e.g. model parameters) are considered
|
|
146
259
|
# as extra parameters (this also includes timeout).
|
|
147
260
|
extra_parameters=config,
|
|
@@ -154,6 +267,9 @@ class AzureOpenAIClientConfig:
|
|
|
154
267
|
# Extra parameters should also be on the top level
|
|
155
268
|
d.pop("extra_parameters", None)
|
|
156
269
|
d.update(self.extra_parameters)
|
|
270
|
+
|
|
271
|
+
d.pop("oauth", None)
|
|
272
|
+
d.update({"oauth": self.oauth.to_dict()} if self.oauth else {})
|
|
157
273
|
return d
|
|
158
274
|
|
|
159
275
|
@staticmethod
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Protocol, runtime_checkable
|
|
2
4
|
|
|
3
5
|
|
|
@@ -9,7 +11,7 @@ class ClientConfig(Protocol):
|
|
|
9
11
|
"""
|
|
10
12
|
|
|
11
13
|
@classmethod
|
|
12
|
-
def from_dict(cls, config: dict) ->
|
|
14
|
+
def from_dict(cls, config: dict) -> ClientConfig:
|
|
13
15
|
"""
|
|
14
16
|
Initializes the client config with the given configuration.
|
|
15
17
|
|
|
@@ -1,24 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
5
7
|
|
|
8
|
+
import rasa.shared.utils.cli
|
|
6
9
|
from rasa.shared.constants import (
|
|
7
10
|
MODEL_CONFIG_KEY,
|
|
8
11
|
MODEL_NAME_CONFIG_KEY,
|
|
9
|
-
STREAM_CONFIG_KEY,
|
|
10
12
|
N_REPHRASES_CONFIG_KEY,
|
|
11
13
|
PROVIDER_CONFIG_KEY,
|
|
12
|
-
TIMEOUT_CONFIG_KEY,
|
|
13
14
|
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
15
|
+
STREAM_CONFIG_KEY,
|
|
16
|
+
TIMEOUT_CONFIG_KEY,
|
|
14
17
|
)
|
|
15
18
|
from rasa.shared.providers._configs.utils import (
|
|
16
|
-
validate_required_keys,
|
|
17
|
-
validate_forbidden_keys,
|
|
18
|
-
resolve_aliases,
|
|
19
19
|
raise_deprecation_warnings,
|
|
20
|
+
resolve_aliases,
|
|
21
|
+
validate_forbidden_keys,
|
|
22
|
+
validate_required_keys,
|
|
20
23
|
)
|
|
21
|
-
import rasa.shared.utils.cli
|
|
22
24
|
|
|
23
25
|
structlogger = structlog.get_logger()
|
|
24
26
|
|
|
@@ -69,7 +71,7 @@ class DefaultLiteLLMClientConfig:
|
|
|
69
71
|
raise ValueError(message)
|
|
70
72
|
|
|
71
73
|
@classmethod
|
|
72
|
-
def from_dict(cls, config: dict) ->
|
|
74
|
+
def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
|
|
73
75
|
"""
|
|
74
76
|
Initializes a dataclass from the passed config.
|
|
75
77
|
|
|
@@ -1,27 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict, Optional
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
5
7
|
|
|
6
8
|
from rasa.shared.constants import (
|
|
7
|
-
MODEL_CONFIG_KEY,
|
|
8
|
-
MODEL_NAME_CONFIG_KEY,
|
|
9
|
-
RASA_TYPE_CONFIG_KEY,
|
|
10
|
-
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
11
|
-
HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
|
|
12
9
|
HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
|
|
13
|
-
HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
|
|
14
|
-
HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
|
|
15
10
|
HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY,
|
|
16
11
|
HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER,
|
|
17
|
-
PROVIDER_CONFIG_KEY,
|
|
18
12
|
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
|
|
19
|
-
|
|
13
|
+
HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
|
|
14
|
+
HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
|
|
15
|
+
HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
|
|
16
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
+
MODEL_CONFIG_KEY,
|
|
18
|
+
MODEL_NAME_CONFIG_KEY,
|
|
19
|
+
PROVIDER_CONFIG_KEY,
|
|
20
|
+
RASA_TYPE_CONFIG_KEY,
|
|
20
21
|
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
22
|
+
TIMEOUT_CONFIG_KEY,
|
|
21
23
|
)
|
|
22
24
|
from rasa.shared.providers._configs.utils import (
|
|
23
|
-
resolve_aliases,
|
|
24
25
|
raise_deprecation_warnings,
|
|
26
|
+
resolve_aliases,
|
|
25
27
|
validate_required_keys,
|
|
26
28
|
)
|
|
27
29
|
from rasa.shared.utils.io import raise_deprecation_warning
|
|
@@ -90,7 +92,7 @@ class HuggingFaceLocalEmbeddingClientConfig:
|
|
|
90
92
|
raise ValueError(message)
|
|
91
93
|
|
|
92
94
|
@classmethod
|
|
93
|
-
def from_dict(cls, config: dict) ->
|
|
95
|
+
def from_dict(cls, config: dict) -> HuggingFaceLocalEmbeddingClientConfig:
|
|
94
96
|
"""
|
|
95
97
|
Initializes a dataclass from the passed config.
|
|
96
98
|
|
|
@@ -1,29 +1,31 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import copy
|
|
2
4
|
from dataclasses import dataclass, field
|
|
3
5
|
from typing import Any, Dict, List
|
|
4
6
|
|
|
5
7
|
import structlog
|
|
8
|
+
|
|
6
9
|
from rasa.shared.constants import (
|
|
7
|
-
ROUTER_CONFIG_KEY,
|
|
8
|
-
MODELS_CONFIG_KEY,
|
|
9
|
-
MODEL_GROUP_ID_CONFIG_KEY,
|
|
10
|
-
MODEL_NAME_CONFIG_KEY,
|
|
11
|
-
LITELLM_PARAMS_KEY,
|
|
12
|
-
PROVIDER_CONFIG_KEY,
|
|
13
|
-
DEPLOYMENT_CONFIG_KEY,
|
|
14
10
|
API_TYPE_CONFIG_KEY,
|
|
11
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
12
|
+
LITELLM_PARAMS_KEY,
|
|
15
13
|
MODEL_CONFIG_KEY,
|
|
14
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
16
15
|
MODEL_LIST_KEY,
|
|
16
|
+
MODEL_NAME_CONFIG_KEY,
|
|
17
|
+
MODELS_CONFIG_KEY,
|
|
18
|
+
PROVIDER_CONFIG_KEY,
|
|
19
|
+
ROUTER_CONFIG_KEY,
|
|
17
20
|
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
|
|
18
21
|
)
|
|
19
22
|
from rasa.shared.providers._configs.model_group_config import (
|
|
20
|
-
ModelGroupConfig,
|
|
21
23
|
ModelConfig,
|
|
24
|
+
ModelGroupConfig,
|
|
22
25
|
)
|
|
23
26
|
from rasa.shared.providers.mappings import get_prefix_from_provider
|
|
24
27
|
from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
|
|
25
28
|
|
|
26
|
-
|
|
27
29
|
structlogger = structlog.get_logger()
|
|
28
30
|
|
|
29
31
|
_LITELLM_UNSUPPORTED_KEYS = [
|
|
@@ -120,7 +122,7 @@ class LiteLLMRouterClientConfig:
|
|
|
120
122
|
raise ValueError(message)
|
|
121
123
|
|
|
122
124
|
@classmethod
|
|
123
|
-
def from_dict(cls, config: dict) ->
|
|
125
|
+
def from_dict(cls, config: dict) -> LiteLLMRouterClientConfig:
|
|
124
126
|
"""Initializes a dataclass from the passed config.
|
|
125
127
|
|
|
126
128
|
Args:
|
|
@@ -1,19 +1,25 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import List, Optional
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
7
|
+
|
|
5
8
|
from rasa.shared.constants import (
|
|
6
9
|
API_BASE_CONFIG_KEY,
|
|
7
10
|
API_KEY,
|
|
8
11
|
API_TYPE_CONFIG_KEY,
|
|
9
12
|
API_VERSION_CONFIG_KEY,
|
|
10
13
|
DEPLOYMENT_CONFIG_KEY,
|
|
11
|
-
|
|
14
|
+
EXTRA_PARAMETERS_KEY,
|
|
12
15
|
MODEL_CONFIG_KEY,
|
|
13
16
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
14
|
-
MODELS_CONFIG_KEY,
|
|
15
17
|
MODEL_GROUPS_CONFIG_KEY,
|
|
16
|
-
|
|
18
|
+
MODELS_CONFIG_KEY,
|
|
19
|
+
PROVIDER_CONFIG_KEY,
|
|
20
|
+
)
|
|
21
|
+
from rasa.shared.providers._configs._lite_llm_config.lite_llm_config_adapter import (
|
|
22
|
+
to_lite_llm_config,
|
|
17
23
|
)
|
|
18
24
|
from rasa.shared.providers.mappings import get_client_config_class_from_provider
|
|
19
25
|
|
|
@@ -40,7 +46,7 @@ class ModelConfig:
|
|
|
40
46
|
api_type: Optional[str] = None
|
|
41
47
|
|
|
42
48
|
@classmethod
|
|
43
|
-
def from_dict(cls, config: dict) ->
|
|
49
|
+
def from_dict(cls, config: dict) -> ModelConfig:
|
|
44
50
|
"""Initializes a dataclass from the passed config. The provider config param is
|
|
45
51
|
used to determine the client config class to use. The client config class takes
|
|
46
52
|
care of resolving config aliases and throwing deprecation warnings.
|
|
@@ -130,7 +136,7 @@ class ModelGroupConfig:
|
|
|
130
136
|
raise ValueError(message)
|
|
131
137
|
|
|
132
138
|
@classmethod
|
|
133
|
-
def from_dict(cls, config: dict) ->
|
|
139
|
+
def from_dict(cls, config: dict) -> ModelGroupConfig:
|
|
134
140
|
"""Initializes a dataclass from the passed config.
|
|
135
141
|
|
|
136
142
|
Args:
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Any, Dict, TypeVar
|
|
3
|
+
|
|
4
|
+
# Key inside the `oauth` config section that selects the OAuth implementation.
OAUTH_TYPE_FIELD = "type"
# Key of the `oauth` section inside a client configuration.
OAUTH_KEY = "oauth"

OAuthType = TypeVar("OAuthType", bound="OAuth")


class OAuth(abc.ABC):
    """Interface for OAuth configuration.

    Implementations are constructed via `from_config` and must be able to
    produce a bearer token via `get_bearer_token`.
    """

    @classmethod
    @abc.abstractmethod
    def from_config(cls: "type[OAuthType]", config: Dict[str, Any]) -> OAuthType:
        """Initializes an instance of the implementing class from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            An instance of the implementing class.
        """
        ...

    @abc.abstractmethod
    def get_bearer_token(self) -> str:
        """Returns a bearer token.

        The bearer token is used to authenticate requests to the Azure OpenAI
        instance's API protected by the Gateway.
        """
        ...
|
|
@@ -1,32 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict, Optional
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
5
7
|
|
|
6
8
|
from rasa.shared.constants import (
|
|
9
|
+
API_BASE_CONFIG_KEY,
|
|
10
|
+
API_TYPE_CONFIG_KEY,
|
|
11
|
+
API_VERSION_CONFIG_KEY,
|
|
12
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
7
13
|
MODEL_CONFIG_KEY,
|
|
8
14
|
MODEL_NAME_CONFIG_KEY,
|
|
15
|
+
N_REPHRASES_CONFIG_KEY,
|
|
9
16
|
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
-
|
|
17
|
+
OPENAI_API_TYPE,
|
|
11
18
|
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
-
API_TYPE_CONFIG_KEY,
|
|
13
19
|
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
-
|
|
20
|
+
OPENAI_PROVIDER,
|
|
21
|
+
PROVIDER_CONFIG_KEY,
|
|
15
22
|
RASA_TYPE_CONFIG_KEY,
|
|
16
|
-
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
-
STREAM_CONFIG_KEY,
|
|
18
|
-
N_REPHRASES_CONFIG_KEY,
|
|
19
23
|
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
24
|
+
STREAM_CONFIG_KEY,
|
|
20
25
|
TIMEOUT_CONFIG_KEY,
|
|
21
|
-
PROVIDER_CONFIG_KEY,
|
|
22
|
-
OPENAI_API_TYPE,
|
|
23
|
-
OPENAI_PROVIDER,
|
|
24
26
|
)
|
|
25
27
|
from rasa.shared.providers._configs.utils import (
|
|
26
|
-
resolve_aliases,
|
|
27
|
-
validate_required_keys,
|
|
28
28
|
raise_deprecation_warnings,
|
|
29
|
+
resolve_aliases,
|
|
29
30
|
validate_forbidden_keys,
|
|
31
|
+
validate_required_keys,
|
|
30
32
|
)
|
|
31
33
|
|
|
32
34
|
structlogger = structlog.get_logger()
|
|
@@ -111,7 +113,7 @@ class OpenAIClientConfig:
|
|
|
111
113
|
raise ValueError(message)
|
|
112
114
|
|
|
113
115
|
@classmethod
|
|
114
|
-
def from_dict(cls, config: dict) ->
|
|
116
|
+
def from_dict(cls, config: dict) -> OpenAIClientConfig:
|
|
115
117
|
"""
|
|
116
118
|
Initializes a dataclass from the passed config.
|
|
117
119
|
|
|
@@ -1,13 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Optional
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
5
7
|
|
|
6
8
|
from rasa.shared.constants import (
|
|
9
|
+
API_BASE_CONFIG_KEY,
|
|
7
10
|
MODEL_CONFIG_KEY,
|
|
8
|
-
RASA_PROVIDER,
|
|
9
11
|
PROVIDER_CONFIG_KEY,
|
|
10
|
-
|
|
12
|
+
RASA_PROVIDER,
|
|
11
13
|
)
|
|
12
14
|
from rasa.shared.providers._configs.utils import (
|
|
13
15
|
validate_required_keys,
|
|
@@ -37,7 +39,7 @@ class RasaLLMClientConfig:
|
|
|
37
39
|
extra_parameters: dict = field(default_factory=dict)
|
|
38
40
|
|
|
39
41
|
@classmethod
|
|
40
|
-
def from_dict(cls, config: dict) ->
|
|
42
|
+
def from_dict(cls, config: dict) -> RasaLLMClientConfig:
|
|
41
43
|
"""
|
|
42
44
|
Initializes a dataclass from the passed config.
|
|
43
45
|
|
|
@@ -1,29 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict, Optional
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
5
7
|
|
|
6
8
|
from rasa.shared.constants import (
|
|
9
|
+
API_BASE_CONFIG_KEY,
|
|
10
|
+
API_TYPE_CONFIG_KEY,
|
|
11
|
+
API_VERSION_CONFIG_KEY,
|
|
12
|
+
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
7
13
|
MODEL_CONFIG_KEY,
|
|
8
14
|
MODEL_NAME_CONFIG_KEY,
|
|
15
|
+
N_REPHRASES_CONFIG_KEY,
|
|
9
16
|
OPENAI_API_BASE_CONFIG_KEY,
|
|
10
|
-
API_BASE_CONFIG_KEY,
|
|
11
17
|
OPENAI_API_TYPE_CONFIG_KEY,
|
|
12
|
-
API_TYPE_CONFIG_KEY,
|
|
13
18
|
OPENAI_API_VERSION_CONFIG_KEY,
|
|
14
|
-
|
|
19
|
+
OPENAI_PROVIDER,
|
|
20
|
+
PROVIDER_CONFIG_KEY,
|
|
15
21
|
RASA_TYPE_CONFIG_KEY,
|
|
16
|
-
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
17
|
-
STREAM_CONFIG_KEY,
|
|
18
|
-
N_REPHRASES_CONFIG_KEY,
|
|
19
22
|
REQUEST_TIMEOUT_CONFIG_KEY,
|
|
20
|
-
TIMEOUT_CONFIG_KEY,
|
|
21
|
-
PROVIDER_CONFIG_KEY,
|
|
22
|
-
OPENAI_PROVIDER,
|
|
23
23
|
SELF_HOSTED_PROVIDER,
|
|
24
|
+
STREAM_CONFIG_KEY,
|
|
25
|
+
TIMEOUT_CONFIG_KEY,
|
|
24
26
|
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
|
|
25
27
|
)
|
|
26
|
-
|
|
27
28
|
from rasa.shared.providers._configs.utils import (
|
|
28
29
|
raise_deprecation_warnings,
|
|
29
30
|
resolve_aliases,
|
|
@@ -114,7 +115,7 @@ class SelfHostedLLMClientConfig:
|
|
|
114
115
|
raise ValueError(message)
|
|
115
116
|
|
|
116
117
|
@classmethod
|
|
117
|
-
def from_dict(cls, config: dict) ->
|
|
118
|
+
def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
|
|
118
119
|
"""
|
|
119
120
|
Initializes a dataclass from the passed config.
|
|
120
121
|
|