rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/anonymization/anonymization_rule_executor.py +16 -10
- rasa/cli/data.py +16 -0
- rasa/cli/project_templates/calm/config.yml +2 -2
- rasa/cli/project_templates/calm/endpoints.yml +2 -2
- rasa/cli/utils.py +12 -0
- rasa/core/actions/action.py +84 -191
- rasa/core/actions/action_run_slot_rejections.py +16 -4
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/studio_chat.py +19 -0
- rasa/core/channels/telegram.py +42 -24
- rasa/core/channels/voice_ready/utils.py +1 -1
- rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
- rasa/core/channels/voice_stream/asr/azure.py +14 -1
- rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
- rasa/core/channels/voice_stream/audiocodes.py +264 -0
- rasa/core/channels/voice_stream/browser_audio.py +4 -1
- rasa/core/channels/voice_stream/call_state.py +3 -0
- rasa/core/channels/voice_stream/genesys.py +6 -2
- rasa/core/channels/voice_stream/tts/azure.py +9 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
- rasa/core/channels/voice_stream/voice_channel.py +23 -2
- rasa/core/constants.py +2 -0
- rasa/core/nlg/contextual_response_rephraser.py +18 -1
- rasa/core/nlg/generator.py +83 -15
- rasa/core/nlg/response.py +6 -3
- rasa/core/nlg/translate.py +55 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
- rasa/core/policies/flows/flow_executor.py +12 -5
- rasa/core/processor.py +72 -9
- rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
- rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
- rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/clarify_command.py +29 -3
- rasa/dialogue_understanding/commands/command.py +1 -16
- rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/prompt_command.py +94 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
- rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +20 -2
- rasa/dialogue_understanding/commands/utils.py +98 -4
- rasa/dialogue_understanding/generator/__init__.py +2 -0
- rasa/dialogue_understanding/generator/command_parser.py +15 -12
- rasa/dialogue_understanding/generator/constants.py +3 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +16 -2
- rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
- rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +460 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
- rasa/dialogue_understanding/patterns/collect_information.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
- rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
- rasa/dialogue_understanding/processor/command_processor.py +39 -0
- rasa/dialogue_understanding_test/du_test_case.py +28 -8
- rasa/dialogue_understanding_test/du_test_result.py +13 -9
- rasa/dialogue_understanding_test/io.py +14 -0
- rasa/e2e_test/utils/io.py +0 -37
- rasa/engine/graph.py +1 -0
- rasa/engine/language.py +140 -0
- rasa/engine/recipes/config_files/default_config.yml +4 -0
- rasa/engine/recipes/default_recipe.py +2 -0
- rasa/engine/recipes/graph_recipe.py +2 -0
- rasa/engine/storage/local_model_storage.py +1 -0
- rasa/engine/storage/storage.py +4 -1
- rasa/model_manager/runner_service.py +7 -4
- rasa/model_manager/socket_bridge.py +7 -6
- rasa/shared/constants.py +15 -13
- rasa/shared/core/constants.py +2 -0
- rasa/shared/core/flows/constants.py +11 -0
- rasa/shared/core/flows/flow.py +83 -19
- rasa/shared/core/flows/flows_yaml_schema.json +31 -3
- rasa/shared/core/flows/steps/collect.py +1 -36
- rasa/shared/core/flows/utils.py +28 -4
- rasa/shared/core/flows/validation.py +1 -1
- rasa/shared/core/slot_mappings.py +208 -5
- rasa/shared/core/slots.py +131 -1
- rasa/shared/core/trackers.py +74 -1
- rasa/shared/importers/importer.py +50 -2
- rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
- rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
- rasa/shared/providers/_configs/model_group_config.py +4 -2
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +3 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
- rasa/shared/providers/llm/_base_litellm_client.py +42 -17
- rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
- rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
- rasa/shared/providers/llm/llm_client.py +23 -7
- rasa/shared/providers/llm/openai_llm_client.py +9 -3
- rasa/shared/providers/llm/rasa_llm_client.py +11 -2
- rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
- rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/llm.py +30 -7
- rasa/shared/utils/pykwalify_extensions.py +24 -0
- rasa/shared/utils/schemas/domain.yml +26 -0
- rasa/telemetry.py +2 -1
- rasa/tracing/config.py +2 -0
- rasa/tracing/constants.py +12 -0
- rasa/tracing/instrumentation/instrumentation.py +36 -0
- rasa/tracing/instrumentation/metrics.py +41 -0
- rasa/tracing/metric_instrument_provider.py +40 -0
- rasa/validator.py +372 -7
- rasa/version.py +1 -1
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/METADATA +2 -1
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/RECORD +128 -113
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -1,10 +1,20 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
1
4
|
from dataclasses import asdict, dataclass, field
|
|
2
|
-
from typing import
|
|
5
|
+
from typing import (
|
|
6
|
+
Any,
|
|
7
|
+
Dict,
|
|
8
|
+
Optional,
|
|
9
|
+
Set,
|
|
10
|
+
)
|
|
3
11
|
|
|
4
12
|
import structlog
|
|
13
|
+
from pydantic import BaseModel
|
|
5
14
|
|
|
6
15
|
from rasa.shared.constants import (
|
|
7
16
|
API_BASE_CONFIG_KEY,
|
|
17
|
+
API_KEY,
|
|
8
18
|
API_TYPE_CONFIG_KEY,
|
|
9
19
|
API_VERSION_CONFIG_KEY,
|
|
10
20
|
AZURE_API_TYPE,
|
|
@@ -25,12 +35,22 @@ from rasa.shared.constants import (
|
|
|
25
35
|
STREAM_CONFIG_KEY,
|
|
26
36
|
TIMEOUT_CONFIG_KEY,
|
|
27
37
|
)
|
|
38
|
+
from rasa.shared.providers._configs.azure_entra_id_config import (
|
|
39
|
+
AzureEntraIDOAuthConfig,
|
|
40
|
+
AzureEntraIDOAuthType,
|
|
41
|
+
)
|
|
42
|
+
from rasa.shared.providers._configs.oauth_config import (
|
|
43
|
+
OAUTH_KEY,
|
|
44
|
+
OAUTH_TYPE_FIELD,
|
|
45
|
+
OAuth,
|
|
46
|
+
)
|
|
28
47
|
from rasa.shared.providers._configs.utils import (
|
|
29
48
|
raise_deprecation_warnings,
|
|
30
49
|
resolve_aliases,
|
|
31
50
|
validate_forbidden_keys,
|
|
32
51
|
validate_required_keys,
|
|
33
52
|
)
|
|
53
|
+
from rasa.shared.utils.common import class_from_module_path
|
|
34
54
|
|
|
35
55
|
structlogger = structlog.get_logger()
|
|
36
56
|
|
|
@@ -61,6 +81,86 @@ FORBIDDEN_KEYS = [
|
|
|
61
81
|
]
|
|
62
82
|
|
|
63
83
|
|
|
84
|
+
class OAuthConfigWrapper(OAuth, BaseModel):
|
|
85
|
+
"""Wrapper for OAuth configuration.
|
|
86
|
+
|
|
87
|
+
It's main purpose is to provide to_dict method which is used to serialize
|
|
88
|
+
the oauth configuration to the original format.
|
|
89
|
+
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
# Pydantic configuration to allow arbitrary user defined types
|
|
93
|
+
class Config:
|
|
94
|
+
arbitrary_types_allowed = True
|
|
95
|
+
|
|
96
|
+
oauth: OAuth
|
|
97
|
+
original_config: Dict[str, Any]
|
|
98
|
+
|
|
99
|
+
def get_bearer_token(self) -> str:
|
|
100
|
+
"""Returns a bearer token."""
|
|
101
|
+
return self.oauth.get_bearer_token()
|
|
102
|
+
|
|
103
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
104
|
+
"""Converts the OAuth configuration to the original format."""
|
|
105
|
+
return self.original_config
|
|
106
|
+
|
|
107
|
+
@staticmethod
|
|
108
|
+
def _valid_type_values() -> Set[str]:
|
|
109
|
+
"""Returns the valid built-in values for the `type` field in the `oauth`."""
|
|
110
|
+
return AzureEntraIDOAuthType.valid_string_values()
|
|
111
|
+
|
|
112
|
+
@classmethod
|
|
113
|
+
def from_dict(cls, oauth_config: Dict[str, Any]) -> OAuthConfigWrapper:
|
|
114
|
+
"""Initializes a dataclass from the passed config.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
oauth_config: (dict) The config from which to initialize.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
AzureOAuthConfig
|
|
121
|
+
"""
|
|
122
|
+
original_config = deepcopy(oauth_config)
|
|
123
|
+
|
|
124
|
+
oauth_type: Optional[str] = oauth_config.get(OAUTH_TYPE_FIELD, None)
|
|
125
|
+
|
|
126
|
+
if oauth_type is None:
|
|
127
|
+
message = (
|
|
128
|
+
"Oauth configuration must contain "
|
|
129
|
+
f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
|
|
130
|
+
f"following values: {OAuthConfigWrapper._valid_type_values()}, "
|
|
131
|
+
f"or to the path of module which is "
|
|
132
|
+
f"implementing {OAuth.__name__} protocol."
|
|
133
|
+
)
|
|
134
|
+
structlogger.error(
|
|
135
|
+
"azure_oauth_config.missing_oauth_type",
|
|
136
|
+
message=message,
|
|
137
|
+
)
|
|
138
|
+
raise ValueError(message)
|
|
139
|
+
|
|
140
|
+
if oauth_type in AzureEntraIDOAuthType.valid_string_values():
|
|
141
|
+
return cls(
|
|
142
|
+
oauth=AzureEntraIDOAuthConfig.from_dict(oauth_config),
|
|
143
|
+
original_config=original_config,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
module = class_from_module_path(oauth_type)
|
|
147
|
+
|
|
148
|
+
if not issubclass(module, OAuth):
|
|
149
|
+
message = (
|
|
150
|
+
f"Module {oauth_type} does not implement "
|
|
151
|
+
f"{OAuth.__name__} interface."
|
|
152
|
+
)
|
|
153
|
+
structlogger.error(
|
|
154
|
+
"azure_oauth_config.invalid_oauth_module",
|
|
155
|
+
message=message,
|
|
156
|
+
)
|
|
157
|
+
raise ValueError(message)
|
|
158
|
+
|
|
159
|
+
return cls(
|
|
160
|
+
oauth=module.from_dict(oauth_config), original_config=original_config
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
|
|
64
164
|
@dataclass
|
|
65
165
|
class AzureOpenAIClientConfig:
|
|
66
166
|
"""Parses configuration for Azure OpenAI client, resolves aliases and
|
|
@@ -80,11 +180,13 @@ class AzureOpenAIClientConfig:
|
|
|
80
180
|
# API Type is not used by LiteLLM backend, but we define
|
|
81
181
|
# it here for backward compatibility.
|
|
82
182
|
api_type: Optional[str] = AZURE_API_TYPE
|
|
83
|
-
|
|
84
183
|
# Provider is not used by LiteLLM backend, but we define it here since it's
|
|
85
184
|
# used as switch between different clients.
|
|
86
185
|
provider: str = AZURE_OPENAI_PROVIDER
|
|
87
186
|
|
|
187
|
+
# OAuth related parameters
|
|
188
|
+
oauth: Optional[OAuthConfigWrapper] = None
|
|
189
|
+
|
|
88
190
|
extra_parameters: dict = field(default_factory=dict)
|
|
89
191
|
|
|
90
192
|
def __post_init__(self) -> None:
|
|
@@ -106,7 +208,7 @@ class AzureOpenAIClientConfig:
|
|
|
106
208
|
raise ValueError(message)
|
|
107
209
|
|
|
108
210
|
@classmethod
|
|
109
|
-
def from_dict(cls, config: dict) ->
|
|
211
|
+
def from_dict(cls, config: dict) -> AzureOpenAIClientConfig:
|
|
110
212
|
"""Initializes a dataclass from the passed config.
|
|
111
213
|
|
|
112
214
|
Args:
|
|
@@ -129,6 +231,16 @@ class AzureOpenAIClientConfig:
|
|
|
129
231
|
# Validate that the forbidden keys are not present
|
|
130
232
|
validate_forbidden_keys(config, FORBIDDEN_KEYS)
|
|
131
233
|
# Init client config
|
|
234
|
+
|
|
235
|
+
cls._validate_authentication_configuration(config)
|
|
236
|
+
|
|
237
|
+
has_oauth_key = config.get(OAUTH_KEY, None) is not None
|
|
238
|
+
oauth = (
|
|
239
|
+
OAuthConfigWrapper.from_dict(config.pop(OAUTH_KEY))
|
|
240
|
+
if has_oauth_key
|
|
241
|
+
else None
|
|
242
|
+
)
|
|
243
|
+
|
|
132
244
|
this = AzureOpenAIClientConfig(
|
|
133
245
|
# Required parameters
|
|
134
246
|
deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
|
|
@@ -142,6 +254,8 @@ class AzureOpenAIClientConfig:
|
|
|
142
254
|
# in clients.
|
|
143
255
|
api_base=config.pop(API_BASE_CONFIG_KEY, None),
|
|
144
256
|
api_version=config.pop(API_VERSION_CONFIG_KEY, None),
|
|
257
|
+
# OAuth related parameters, set only if auth_type is set to 'entra_id'
|
|
258
|
+
oauth=oauth,
|
|
145
259
|
# The rest of parameters (e.g. model parameters) are considered
|
|
146
260
|
# as extra parameters (this also includes timeout).
|
|
147
261
|
extra_parameters=config,
|
|
@@ -154,12 +268,33 @@ class AzureOpenAIClientConfig:
|
|
|
154
268
|
# Extra parameters should also be on the top level
|
|
155
269
|
d.pop("extra_parameters", None)
|
|
156
270
|
d.update(self.extra_parameters)
|
|
271
|
+
|
|
272
|
+
d.pop("oauth", None)
|
|
273
|
+
d.update({"oauth": self.oauth.to_dict()} if self.oauth else {})
|
|
157
274
|
return d
|
|
158
275
|
|
|
159
276
|
@staticmethod
|
|
160
277
|
def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
|
|
161
278
|
return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
|
|
162
279
|
|
|
280
|
+
@staticmethod
|
|
281
|
+
def _validate_authentication_configuration(config: Dict[str, Any]) -> None:
|
|
282
|
+
"""Validates the authentication configuration."""
|
|
283
|
+
has_api_key = config.get(API_KEY, None) is not None
|
|
284
|
+
has_oauth_key = config.get(OAUTH_KEY, None) is not None
|
|
285
|
+
|
|
286
|
+
if has_api_key and has_oauth_key:
|
|
287
|
+
message = (
|
|
288
|
+
"Azure OpenAI client configuration cannot contain "
|
|
289
|
+
f"both '{API_KEY}' and '{OAUTH_KEY}' fields. Please provide either "
|
|
290
|
+
f"'{API_KEY}' or '{OAUTH_KEY}' fields."
|
|
291
|
+
)
|
|
292
|
+
structlogger.error(
|
|
293
|
+
"azure_openai_client_config.multiple_auth_types_specified",
|
|
294
|
+
message=message,
|
|
295
|
+
)
|
|
296
|
+
raise ValueError(message)
|
|
297
|
+
|
|
163
298
|
|
|
164
299
|
def is_azure_openai_config(config: dict) -> bool:
|
|
165
300
|
"""Check whether the configuration is meant to configure
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Protocol, runtime_checkable
|
|
2
4
|
|
|
3
5
|
|
|
@@ -9,7 +11,7 @@ class ClientConfig(Protocol):
|
|
|
9
11
|
"""
|
|
10
12
|
|
|
11
13
|
@classmethod
|
|
12
|
-
def from_dict(cls, config: dict) ->
|
|
14
|
+
def from_dict(cls, config: dict) -> ClientConfig:
|
|
13
15
|
"""
|
|
14
16
|
Initializes the client config with the given configuration.
|
|
15
17
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict
|
|
3
5
|
|
|
@@ -69,7 +71,7 @@ class DefaultLiteLLMClientConfig:
|
|
|
69
71
|
raise ValueError(message)
|
|
70
72
|
|
|
71
73
|
@classmethod
|
|
72
|
-
def from_dict(cls, config: dict) ->
|
|
74
|
+
def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
|
|
73
75
|
"""
|
|
74
76
|
Initializes a dataclass from the passed config.
|
|
75
77
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict, Optional
|
|
3
5
|
|
|
@@ -90,7 +92,7 @@ class HuggingFaceLocalEmbeddingClientConfig:
|
|
|
90
92
|
raise ValueError(message)
|
|
91
93
|
|
|
92
94
|
@classmethod
|
|
93
|
-
def from_dict(cls, config: dict) ->
|
|
95
|
+
def from_dict(cls, config: dict) -> HuggingFaceLocalEmbeddingClientConfig:
|
|
94
96
|
"""
|
|
95
97
|
Initializes a dataclass from the passed config.
|
|
96
98
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import copy
|
|
2
4
|
from dataclasses import dataclass, field
|
|
3
5
|
from typing import Any, Dict, List
|
|
@@ -120,7 +122,7 @@ class LiteLLMRouterClientConfig:
|
|
|
120
122
|
raise ValueError(message)
|
|
121
123
|
|
|
122
124
|
@classmethod
|
|
123
|
-
def from_dict(cls, config: dict) ->
|
|
125
|
+
def from_dict(cls, config: dict) -> LiteLLMRouterClientConfig:
|
|
124
126
|
"""Initializes a dataclass from the passed config.
|
|
125
127
|
|
|
126
128
|
Args:
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import List, Optional
|
|
3
5
|
|
|
@@ -41,7 +43,7 @@ class ModelConfig:
|
|
|
41
43
|
api_type: Optional[str] = None
|
|
42
44
|
|
|
43
45
|
@classmethod
|
|
44
|
-
def from_dict(cls, config: dict) ->
|
|
46
|
+
def from_dict(cls, config: dict) -> ModelConfig:
|
|
45
47
|
"""Initializes a dataclass from the passed config. The provider config param is
|
|
46
48
|
used to determine the client config class to use. The client config class takes
|
|
47
49
|
care of resolving config aliases and throwing deprecation warnings.
|
|
@@ -131,7 +133,7 @@ class ModelGroupConfig:
|
|
|
131
133
|
raise ValueError(message)
|
|
132
134
|
|
|
133
135
|
@classmethod
|
|
134
|
-
def from_dict(cls, config: dict) ->
|
|
136
|
+
def from_dict(cls, config: dict) -> ModelGroupConfig:
|
|
135
137
|
"""Initializes a dataclass from the passed config.
|
|
136
138
|
|
|
137
139
|
Args:
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Any, Dict, Type, TypeVar
|
|
3
|
+
|
|
4
|
+
OAUTH_TYPE_FIELD = "type"
|
|
5
|
+
OAUTH_KEY = "oauth"
|
|
6
|
+
|
|
7
|
+
OAuthType = TypeVar("OAuthType", bound="OAuth")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OAuth(abc.ABC):
|
|
11
|
+
"""Interface for OAuth configuration."""
|
|
12
|
+
|
|
13
|
+
@classmethod
|
|
14
|
+
@abc.abstractmethod
|
|
15
|
+
def from_dict(
|
|
16
|
+
cls: Type[OAuthType], config: Dict[str, Any]
|
|
17
|
+
) -> OAuthType: # ignore[type]
|
|
18
|
+
"""Initializes a dataclass from the passed config.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
config: (dict) The config from which to initialize.
|
|
22
|
+
|
|
23
|
+
Returns:
|
|
24
|
+
OAuth
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
@abc.abstractmethod
|
|
28
|
+
def get_bearer_token(self) -> str:
|
|
29
|
+
"""Returns a bearer token.
|
|
30
|
+
|
|
31
|
+
Bear token is used to authenticate requests to the Azure
|
|
32
|
+
Oopen AI instance's API protected by the Gateway.
|
|
33
|
+
"""
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict, Optional
|
|
3
5
|
|
|
@@ -111,7 +113,7 @@ class OpenAIClientConfig:
|
|
|
111
113
|
raise ValueError(message)
|
|
112
114
|
|
|
113
115
|
@classmethod
|
|
114
|
-
def from_dict(cls, config: dict) ->
|
|
116
|
+
def from_dict(cls, config: dict) -> OpenAIClientConfig:
|
|
115
117
|
"""
|
|
116
118
|
Initializes a dataclass from the passed config.
|
|
117
119
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Optional
|
|
3
5
|
|
|
@@ -37,7 +39,7 @@ class RasaLLMClientConfig:
|
|
|
37
39
|
extra_parameters: dict = field(default_factory=dict)
|
|
38
40
|
|
|
39
41
|
@classmethod
|
|
40
|
-
def from_dict(cls, config: dict) ->
|
|
42
|
+
def from_dict(cls, config: dict) -> RasaLLMClientConfig:
|
|
41
43
|
"""
|
|
42
44
|
Initializes a dataclass from the passed config.
|
|
43
45
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from dataclasses import asdict, dataclass, field
|
|
2
4
|
from typing import Any, Dict, Optional
|
|
3
5
|
|
|
@@ -113,7 +115,7 @@ class SelfHostedLLMClientConfig:
|
|
|
113
115
|
raise ValueError(message)
|
|
114
116
|
|
|
115
117
|
@classmethod
|
|
116
|
-
def from_dict(cls, config: dict) ->
|
|
118
|
+
def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
|
|
117
119
|
"""
|
|
118
120
|
Initializes a dataclass from the passed config.
|
|
119
121
|
|
|
@@ -19,8 +19,14 @@ from rasa.shared.constants import (
|
|
|
19
19
|
)
|
|
20
20
|
from rasa.shared.exceptions import ProviderClientValidationError
|
|
21
21
|
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
22
|
+
AzureEntraIDOAuthConfig,
|
|
22
23
|
AzureOpenAIClientConfig,
|
|
23
24
|
)
|
|
25
|
+
from rasa.shared.providers.constants import (
|
|
26
|
+
DEFAULT_AZURE_API_KEY_NAME,
|
|
27
|
+
LITE_LLM_API_KEY_FIELD,
|
|
28
|
+
LITE_LLM_AZURE_AD_TOKEN,
|
|
29
|
+
)
|
|
24
30
|
from rasa.shared.providers.embedding._base_litellm_embedding_client import (
|
|
25
31
|
_BaseLiteLLMEmbeddingClient,
|
|
26
32
|
)
|
|
@@ -41,6 +47,8 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
41
47
|
If not provided, it will be set via environment variable.
|
|
42
48
|
api_version (Optional[str]): The version of the API to use.
|
|
43
49
|
If not provided, it will be set via environment variable.
|
|
50
|
+
oauth (Optional[AzureEntraIDOAuthConfig]): Optional OAuth configuration.
|
|
51
|
+
If provided, the client will use OAuth for authentication.
|
|
44
52
|
kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
|
|
45
53
|
to the embedding model deployment.
|
|
46
54
|
|
|
@@ -57,6 +65,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
57
65
|
api_base: Optional[str] = None,
|
|
58
66
|
api_type: Optional[str] = None,
|
|
59
67
|
api_version: Optional[str] = None,
|
|
68
|
+
oauth: Optional[AzureEntraIDOAuthConfig] = None,
|
|
60
69
|
**kwargs: Any,
|
|
61
70
|
):
|
|
62
71
|
super().__init__() # type: ignore
|
|
@@ -84,7 +93,11 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
84
93
|
# Litellm does not support use of OPENAI_API_KEY, so we need to map it
|
|
85
94
|
# because of backward compatibility. However, we're first looking at
|
|
86
95
|
# AZURE_API_KEY.
|
|
87
|
-
|
|
96
|
+
|
|
97
|
+
self._oauth = oauth
|
|
98
|
+
self._api_key_env_var = (
|
|
99
|
+
self._resolve_api_key_env_var() if not self._oauth else None
|
|
100
|
+
)
|
|
88
101
|
|
|
89
102
|
self.validate_client_setup()
|
|
90
103
|
|
|
@@ -100,7 +113,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
100
113
|
return self._extra_parameters[API_KEY]
|
|
101
114
|
|
|
102
115
|
if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
|
|
103
|
-
return "${
|
|
116
|
+
return f"${{{DEFAULT_AZURE_API_KEY_NAME}}}"
|
|
104
117
|
|
|
105
118
|
if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
|
|
106
119
|
# API key can be set through OPENAI_API_KEY too,
|
|
@@ -163,6 +176,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
163
176
|
api_base=azure_openai_config.api_base,
|
|
164
177
|
api_type=azure_openai_config.api_type,
|
|
165
178
|
api_version=azure_openai_config.api_version,
|
|
179
|
+
oauth=azure_openai_config.oauth,
|
|
166
180
|
**azure_openai_config.extra_parameters,
|
|
167
181
|
)
|
|
168
182
|
|
|
@@ -177,6 +191,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
177
191
|
api_base=self.api_base,
|
|
178
192
|
api_type=self.api_type,
|
|
179
193
|
api_version=self.api_version,
|
|
194
|
+
oauth=self._oauth,
|
|
180
195
|
extra_parameters=self._extra_parameters,
|
|
181
196
|
)
|
|
182
197
|
return config.to_dict()
|
|
@@ -219,13 +234,23 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
219
234
|
|
|
220
235
|
@property
|
|
221
236
|
def _embedding_fn_args(self) -> dict:
|
|
237
|
+
auth_parameter: Dict[str, str] = {}
|
|
238
|
+
|
|
239
|
+
if self._oauth:
|
|
240
|
+
auth_parameter = {
|
|
241
|
+
**auth_parameter,
|
|
242
|
+
LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
|
|
243
|
+
}
|
|
244
|
+
elif self._api_key_env_var:
|
|
245
|
+
auth_parameter = {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
|
|
246
|
+
|
|
222
247
|
return {
|
|
223
248
|
**self._litellm_extra_parameters,
|
|
224
249
|
"model": self._litellm_model_name,
|
|
225
250
|
"api_base": self.api_base,
|
|
226
251
|
"api_type": self.api_type,
|
|
227
252
|
"api_version": self.api_version,
|
|
228
|
-
|
|
253
|
+
**auth_parameter,
|
|
229
254
|
}
|
|
230
255
|
|
|
231
256
|
@property
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from typing import Any, Dict, List
|
|
3
5
|
|
|
@@ -46,7 +48,7 @@ class LiteLLMRouterEmbeddingClient(
|
|
|
46
48
|
)
|
|
47
49
|
|
|
48
50
|
@classmethod
|
|
49
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
51
|
+
def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterEmbeddingClient:
|
|
50
52
|
"""Instantiates a LiteLLM Router Embedding client from a configuration dict.
|
|
51
53
|
|
|
52
54
|
Args:
|
|
@@ -1,15 +1,13 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from abc import abstractmethod
|
|
3
|
-
from typing import Any, Dict, List, Union
|
|
5
|
+
from typing import Any, Dict, List, Union, cast
|
|
4
6
|
|
|
5
7
|
import structlog
|
|
6
|
-
from litellm import
|
|
7
|
-
acompletion,
|
|
8
|
-
completion,
|
|
9
|
-
validate_environment,
|
|
10
|
-
)
|
|
8
|
+
from litellm import acompletion, completion, validate_environment
|
|
11
9
|
|
|
12
|
-
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
|
|
10
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY, ROLE_USER
|
|
13
11
|
from rasa.shared.exceptions import (
|
|
14
12
|
ProviderClientAPIException,
|
|
15
13
|
ProviderClientValidationError,
|
|
@@ -50,7 +48,7 @@ class _BaseLiteLLMClient:
|
|
|
50
48
|
|
|
51
49
|
@classmethod
|
|
52
50
|
@abstractmethod
|
|
53
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
51
|
+
def from_config(cls, config: Dict[str, Any]) -> _BaseLiteLLMClient:
|
|
54
52
|
pass
|
|
55
53
|
|
|
56
54
|
@property
|
|
@@ -122,12 +120,18 @@ class _BaseLiteLLMClient:
|
|
|
122
120
|
raise ProviderClientValidationError(event_info)
|
|
123
121
|
|
|
124
122
|
@suppress_logs(log_level=logging.WARNING)
|
|
125
|
-
def completion(self, messages: Union[List[str], str]) -> LLMResponse:
|
|
123
|
+
def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
|
|
126
124
|
"""Synchronously generate completions for given list of messages.
|
|
127
125
|
|
|
128
126
|
Args:
|
|
129
|
-
messages:
|
|
130
|
-
|
|
127
|
+
messages: The message can be,
|
|
128
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
129
|
+
with the following keys:
|
|
130
|
+
- content: The message content.
|
|
131
|
+
- role: The role of the message (e.g. user or system).
|
|
132
|
+
- a list of messages. Each message is a string and will be formatted
|
|
133
|
+
as a user message.
|
|
134
|
+
- a single message as a string which will be formatted as user message.
|
|
131
135
|
|
|
132
136
|
Returns:
|
|
133
137
|
List of message completions.
|
|
@@ -136,7 +140,7 @@ class _BaseLiteLLMClient:
|
|
|
136
140
|
ProviderClientAPIException: If the API request fails.
|
|
137
141
|
"""
|
|
138
142
|
try:
|
|
139
|
-
formatted_messages = self.
|
|
143
|
+
formatted_messages = self._get_formatted_messages(messages)
|
|
140
144
|
arguments = resolve_environment_variables(self._completion_fn_args)
|
|
141
145
|
response = completion(messages=formatted_messages, **arguments)
|
|
142
146
|
return self._format_response(response)
|
|
@@ -144,12 +148,20 @@ class _BaseLiteLLMClient:
|
|
|
144
148
|
raise ProviderClientAPIException(e)
|
|
145
149
|
|
|
146
150
|
@suppress_logs(log_level=logging.WARNING)
|
|
147
|
-
async def acompletion(
|
|
151
|
+
async def acompletion(
|
|
152
|
+
self, messages: Union[List[dict], List[str], str]
|
|
153
|
+
) -> LLMResponse:
|
|
148
154
|
"""Asynchronously generate completions for given list of messages.
|
|
149
155
|
|
|
150
156
|
Args:
|
|
151
|
-
messages:
|
|
152
|
-
|
|
157
|
+
messages: The message can be,
|
|
158
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
159
|
+
with the following keys:
|
|
160
|
+
- content: The message content.
|
|
161
|
+
- role: The role of the message (e.g. user or system).
|
|
162
|
+
- a list of messages. Each message is a string and will be formatted
|
|
163
|
+
as a user message.
|
|
164
|
+
- a single message as a string which will be formatted as user message.
|
|
153
165
|
|
|
154
166
|
Returns:
|
|
155
167
|
List of message completions.
|
|
@@ -158,7 +170,7 @@ class _BaseLiteLLMClient:
|
|
|
158
170
|
ProviderClientAPIException: If the API request fails.
|
|
159
171
|
"""
|
|
160
172
|
try:
|
|
161
|
-
formatted_messages = self.
|
|
173
|
+
formatted_messages = self._get_formatted_messages(messages)
|
|
162
174
|
arguments = resolve_environment_variables(self._completion_fn_args)
|
|
163
175
|
response = await acompletion(messages=formatted_messages, **arguments)
|
|
164
176
|
return self._format_response(response)
|
|
@@ -181,11 +193,24 @@ class _BaseLiteLLMClient:
|
|
|
181
193
|
)
|
|
182
194
|
raise ProviderClientAPIException(e, message)
|
|
183
195
|
|
|
196
|
+
def _get_formatted_messages(
|
|
197
|
+
self, messages: Union[List[dict], List[str], str]
|
|
198
|
+
) -> List[Dict[str, str]]:
|
|
199
|
+
"""Returns a list of formatted messages."""
|
|
200
|
+
if (
|
|
201
|
+
isinstance(messages, list)
|
|
202
|
+
and len(messages) > 0
|
|
203
|
+
and isinstance(messages[0], dict)
|
|
204
|
+
):
|
|
205
|
+
# Check if the messages are already formatted. If so, return them as is.
|
|
206
|
+
return cast(List[Dict[str, str]], messages)
|
|
207
|
+
return self._format_messages(messages)
|
|
208
|
+
|
|
184
209
|
def _format_messages(self, messages: Union[List[str], str]) -> List[Dict[str, str]]:
|
|
185
210
|
"""Formats messages (or a single message) to OpenAI format."""
|
|
186
211
|
if isinstance(messages, str):
|
|
187
212
|
messages = [messages]
|
|
188
|
-
return [{"content": message, "role":
|
|
213
|
+
return [{"content": message, "role": ROLE_USER} for message in messages]
|
|
189
214
|
|
|
190
215
|
def _format_response(self, response: Any) -> LLMResponse:
|
|
191
216
|
"""Parses the LiteLLM response to Rasa format."""
|