rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/core/actions/action.py +7 -16
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/socketio.py +23 -1
- rasa/core/nlg/contextual_response_rephraser.py +9 -62
- rasa/core/policies/enterprise_search_policy.py +12 -77
- rasa/core/policies/flows/flow_executor.py +2 -26
- rasa/core/processor.py +8 -11
- rasa/dialogue_understanding/generator/command_generator.py +49 -43
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
- rasa/dialogue_understanding/utils.py +1 -8
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +2 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/e2e_test_runner.py +9 -7
- rasa/hooks.py +9 -15
- rasa/model_manager/socket_bridge.py +2 -7
- rasa/model_manager/warm_rasa_process.py +4 -9
- rasa/plugin.py +0 -11
- rasa/shared/constants.py +2 -21
- rasa/shared/core/events.py +8 -8
- rasa/shared/nlu/constants.py +0 -3
- rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
- rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
- rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
- rasa/shared/providers/_configs/model_group_config.py +11 -5
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +14 -12
- rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
- rasa/shared/providers/llm/_base_litellm_client.py +6 -4
- rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
- rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
- rasa/shared/providers/llm/llm_client.py +4 -2
- rasa/shared/providers/llm/llm_response.py +1 -42
- rasa/shared/providers/llm/openai_llm_client.py +11 -5
- rasa/shared/providers/llm/rasa_llm_client.py +13 -5
- rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
- rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/llm.py +16 -12
- rasa/shared/utils/schemas/events.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +0 -2
- rasa/version.py +1 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
- rasa/core/channels/studio_chat.py +0 -192
- rasa/dialogue_understanding/constants.py +0 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
|
@@ -4,23 +4,29 @@ from typing import Any, Dict, List, Optional
|
|
|
4
4
|
import structlog
|
|
5
5
|
|
|
6
6
|
from rasa.shared.constants import (
|
|
7
|
+
API_BASE_CONFIG_KEY,
|
|
8
|
+
API_KEY,
|
|
9
|
+
API_VERSION_CONFIG_KEY,
|
|
7
10
|
AZURE_API_BASE_ENV_VAR,
|
|
8
11
|
AZURE_API_KEY_ENV_VAR,
|
|
9
12
|
AZURE_API_TYPE_ENV_VAR,
|
|
10
13
|
AZURE_API_VERSION_ENV_VAR,
|
|
14
|
+
AZURE_OPENAI_PROVIDER,
|
|
11
15
|
OPENAI_API_BASE_ENV_VAR,
|
|
12
16
|
OPENAI_API_KEY_ENV_VAR,
|
|
13
17
|
OPENAI_API_TYPE_ENV_VAR,
|
|
14
18
|
OPENAI_API_VERSION_ENV_VAR,
|
|
15
|
-
API_BASE_CONFIG_KEY,
|
|
16
|
-
API_KEY,
|
|
17
|
-
API_VERSION_CONFIG_KEY,
|
|
18
|
-
AZURE_OPENAI_PROVIDER,
|
|
19
19
|
)
|
|
20
20
|
from rasa.shared.exceptions import ProviderClientValidationError
|
|
21
21
|
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
22
|
+
AzureEntraIDOAuthConfig,
|
|
22
23
|
AzureOpenAIClientConfig,
|
|
23
24
|
)
|
|
25
|
+
from rasa.shared.providers.constants import (
|
|
26
|
+
DEFAULT_AZURE_API_KEY_NAME,
|
|
27
|
+
LITE_LLM_API_KEY_FIELD,
|
|
28
|
+
LITE_LLM_AZURE_AD_TOKEN,
|
|
29
|
+
)
|
|
24
30
|
from rasa.shared.providers.embedding._base_litellm_embedding_client import (
|
|
25
31
|
_BaseLiteLLMEmbeddingClient,
|
|
26
32
|
)
|
|
@@ -41,6 +47,8 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
41
47
|
If not provided, it will be set via environment variable.
|
|
42
48
|
api_version (Optional[str]): The version of the API to use.
|
|
43
49
|
If not provided, it will be set via environment variable.
|
|
50
|
+
oauth (Optional[AzureEntraIDOAuthConfig]): Optional OAuth configuration. If provided,
|
|
51
|
+
the client will use OAuth for authentication.
|
|
44
52
|
kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
|
|
45
53
|
to the embedding model deployment.
|
|
46
54
|
|
|
@@ -57,6 +65,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
57
65
|
api_base: Optional[str] = None,
|
|
58
66
|
api_type: Optional[str] = None,
|
|
59
67
|
api_version: Optional[str] = None,
|
|
68
|
+
oauth: Optional[AzureEntraIDOAuthConfig] = None,
|
|
60
69
|
**kwargs: Any,
|
|
61
70
|
):
|
|
62
71
|
super().__init__() # type: ignore
|
|
@@ -84,7 +93,11 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
84
93
|
# Litellm does not support use of OPENAI_API_KEY, so we need to map it
|
|
85
94
|
# because of backward compatibility. However, we're first looking at
|
|
86
95
|
# AZURE_API_KEY.
|
|
87
|
-
|
|
96
|
+
|
|
97
|
+
self._oauth = oauth
|
|
98
|
+
self._api_key_env_var = (
|
|
99
|
+
self._resolve_api_key_env_var() if not self._oauth else None
|
|
100
|
+
)
|
|
88
101
|
|
|
89
102
|
self.validate_client_setup()
|
|
90
103
|
|
|
@@ -100,7 +113,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
100
113
|
return self._extra_parameters[API_KEY]
|
|
101
114
|
|
|
102
115
|
if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
|
|
103
|
-
return "${
|
|
116
|
+
return f"${DEFAULT_AZURE_API_KEY_NAME}"
|
|
104
117
|
|
|
105
118
|
if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
|
|
106
119
|
# API key can be set through OPENAI_API_KEY too,
|
|
@@ -163,6 +176,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
163
176
|
api_base=azure_openai_config.api_base,
|
|
164
177
|
api_type=azure_openai_config.api_type,
|
|
165
178
|
api_version=azure_openai_config.api_version,
|
|
179
|
+
oauth=azure_openai_config.oauth,
|
|
166
180
|
**azure_openai_config.extra_parameters,
|
|
167
181
|
)
|
|
168
182
|
|
|
@@ -177,6 +191,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
177
191
|
api_base=self.api_base,
|
|
178
192
|
api_type=self.api_type,
|
|
179
193
|
api_version=self.api_version,
|
|
194
|
+
oauth=self._oauth,
|
|
180
195
|
extra_parameters=self._extra_parameters,
|
|
181
196
|
)
|
|
182
197
|
return config.to_dict()
|
|
@@ -219,13 +234,21 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
|
|
|
219
234
|
|
|
220
235
|
@property
|
|
221
236
|
def _embedding_fn_args(self) -> dict:
|
|
237
|
+
auth_parameter = (
|
|
238
|
+
{
|
|
239
|
+
LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
|
|
240
|
+
}
|
|
241
|
+
if self._oauth
|
|
242
|
+
else {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
|
|
243
|
+
)
|
|
244
|
+
|
|
222
245
|
return {
|
|
223
246
|
**self._litellm_extra_parameters,
|
|
224
247
|
"model": self._litellm_model_name,
|
|
225
248
|
"api_base": self.api_base,
|
|
226
249
|
"api_type": self.api_type,
|
|
227
250
|
"api_version": self.api_version,
|
|
228
|
-
|
|
251
|
+
**auth_parameter,
|
|
229
252
|
}
|
|
230
253
|
|
|
231
254
|
@property
|
|
@@ -1,5 +1,8 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import logging
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
3
6
|
import structlog
|
|
4
7
|
|
|
5
8
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
@@ -45,7 +48,7 @@ class LiteLLMRouterEmbeddingClient(
|
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
@classmethod
|
|
48
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
51
|
+
def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterEmbeddingClient:
|
|
49
52
|
"""Instantiates a LiteLLM Router Embedding client from a configuration dict.
|
|
50
53
|
|
|
51
54
|
Args:
|
|
@@ -1,11 +1,13 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from abc import abstractmethod
|
|
3
|
-
from typing import Dict, List,
|
|
5
|
+
from typing import Any, Dict, List, Union
|
|
4
6
|
|
|
5
7
|
import structlog
|
|
6
8
|
from litellm import (
|
|
7
|
-
completion,
|
|
8
9
|
acompletion,
|
|
10
|
+
completion,
|
|
9
11
|
validate_environment,
|
|
10
12
|
)
|
|
11
13
|
|
|
@@ -19,7 +21,7 @@ from rasa.shared.providers._ssl_verification_utils import (
|
|
|
19
21
|
ensure_ssl_certificates_for_litellm_openai_based_clients,
|
|
20
22
|
)
|
|
21
23
|
from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
|
|
22
|
-
from rasa.shared.utils.io import
|
|
24
|
+
from rasa.shared.utils.io import resolve_environment_variables, suppress_logs
|
|
23
25
|
|
|
24
26
|
structlogger = structlog.get_logger()
|
|
25
27
|
|
|
@@ -50,7 +52,7 @@ class _BaseLiteLLMClient:
|
|
|
50
52
|
|
|
51
53
|
@classmethod
|
|
52
54
|
@abstractmethod
|
|
53
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
55
|
+
def from_config(cls, config: Dict[str, Any]) -> _BaseLiteLLMClient:
|
|
54
56
|
pass
|
|
55
57
|
|
|
56
58
|
@property
|
|
@@ -1,33 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
import re
|
|
3
|
-
from typing import
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
4
6
|
|
|
5
7
|
import structlog
|
|
6
8
|
|
|
7
9
|
from rasa.shared.constants import (
|
|
8
|
-
OPENAI_API_BASE_ENV_VAR,
|
|
9
|
-
OPENAI_API_VERSION_ENV_VAR,
|
|
10
|
-
AZURE_API_BASE_ENV_VAR,
|
|
11
|
-
AZURE_API_VERSION_ENV_VAR,
|
|
12
10
|
API_BASE_CONFIG_KEY,
|
|
11
|
+
API_KEY,
|
|
13
12
|
API_VERSION_CONFIG_KEY,
|
|
14
|
-
|
|
13
|
+
AZURE_API_BASE_ENV_VAR,
|
|
15
14
|
AZURE_API_KEY_ENV_VAR,
|
|
16
|
-
OPENAI_API_TYPE_ENV_VAR,
|
|
17
|
-
OPENAI_API_KEY_ENV_VAR,
|
|
18
15
|
AZURE_API_TYPE_ENV_VAR,
|
|
16
|
+
AZURE_API_VERSION_ENV_VAR,
|
|
19
17
|
AZURE_OPENAI_PROVIDER,
|
|
20
|
-
|
|
18
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
19
|
+
OPENAI_API_BASE_ENV_VAR,
|
|
20
|
+
OPENAI_API_KEY_ENV_VAR,
|
|
21
|
+
OPENAI_API_TYPE_ENV_VAR,
|
|
22
|
+
OPENAI_API_VERSION_ENV_VAR,
|
|
21
23
|
)
|
|
22
24
|
from rasa.shared.exceptions import ProviderClientValidationError
|
|
23
25
|
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
26
|
+
AzureEntraIDOAuthConfig,
|
|
24
27
|
AzureOpenAIClientConfig,
|
|
25
28
|
)
|
|
29
|
+
from rasa.shared.providers.constants import (
|
|
30
|
+
DEFAULT_AZURE_API_KEY_NAME,
|
|
31
|
+
LITE_LLM_API_BASE_FIELD,
|
|
32
|
+
LITE_LLM_API_KEY_FIELD,
|
|
33
|
+
LITE_LLM_API_VERSION_FIELD,
|
|
34
|
+
LITE_LLM_AZURE_AD_TOKEN,
|
|
35
|
+
)
|
|
26
36
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
27
37
|
from rasa.shared.utils.io import raise_deprecation_warning
|
|
28
38
|
|
|
29
39
|
structlogger = structlog.get_logger()
|
|
30
40
|
|
|
41
|
+
AZURE_CLIENT_ID = "AZURE_CLIENT_ID"
|
|
42
|
+
AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"
|
|
43
|
+
AZURE_TENANT_ID = "AZURE_TENANT_ID"
|
|
44
|
+
CLIENT_SECRET_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID)
|
|
45
|
+
|
|
46
|
+
AZURE_CLIENT_CERTIFICATE_PATH = "AZURE_CLIENT_CERTIFICATE_PATH"
|
|
47
|
+
AZURE_CLIENT_CERTIFICATE_PASSWORD = "AZURE_CLIENT_CERTIFICATE_PASSWORD"
|
|
48
|
+
AZURE_CLIENT_SEND_CERTIFICATE_CHAIN = "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN"
|
|
49
|
+
CERT_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_CERTIFICATE_PATH, AZURE_TENANT_ID)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AzureADConfig:
|
|
53
|
+
def __init__(
|
|
54
|
+
self, client_id: str, client_secret: str, tenant_id: str, scopes: str
|
|
55
|
+
) -> None:
|
|
56
|
+
self.scopes = scopes
|
|
57
|
+
|
|
31
58
|
|
|
32
59
|
class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
33
60
|
"""A client for interfacing with Azure's OpenAI LLM deployments.
|
|
@@ -41,6 +68,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
41
68
|
it will be set via environment variables.
|
|
42
69
|
api_version (Optional[str]): The version of the API to use. If not provided,
|
|
43
70
|
it will be set via environment variable.
|
|
71
|
+
|
|
44
72
|
kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
|
|
45
73
|
to the model deployment.
|
|
46
74
|
|
|
@@ -57,6 +85,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
57
85
|
api_type: Optional[str] = None,
|
|
58
86
|
api_base: Optional[str] = None,
|
|
59
87
|
api_version: Optional[str] = None,
|
|
88
|
+
oauth: Optional[AzureEntraIDOAuthConfig] = None,
|
|
60
89
|
**kwargs: Any,
|
|
61
90
|
):
|
|
62
91
|
super().__init__() # type: ignore
|
|
@@ -80,8 +109,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
80
109
|
or os.getenv(OPENAI_API_VERSION_ENV_VAR)
|
|
81
110
|
)
|
|
82
111
|
|
|
83
|
-
self._api_key_env_var = self._resolve_api_key_env_var()
|
|
84
|
-
|
|
85
112
|
# Not used by LiteLLM, here for backward compatibility
|
|
86
113
|
self._api_type = (
|
|
87
114
|
api_type
|
|
@@ -89,6 +116,19 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
89
116
|
or os.getenv(OPENAI_API_TYPE_ENV_VAR)
|
|
90
117
|
)
|
|
91
118
|
|
|
119
|
+
os.unsetenv("OPENAI_API_KEY")
|
|
120
|
+
os.unsetenv("AZURE_API_KEY")
|
|
121
|
+
|
|
122
|
+
self._oauth = oauth
|
|
123
|
+
|
|
124
|
+
if self._oauth:
|
|
125
|
+
os.unsetenv(DEFAULT_AZURE_API_KEY_NAME)
|
|
126
|
+
os.unsetenv(AZURE_API_KEY_ENV_VAR)
|
|
127
|
+
os.unsetenv(OPENAI_API_KEY_ENV_VAR)
|
|
128
|
+
self._api_key_env_var = (
|
|
129
|
+
self._resolve_api_key_env_var() if not self._oauth else None
|
|
130
|
+
)
|
|
131
|
+
|
|
92
132
|
# Run helper function to check and raise deprecation warning if
|
|
93
133
|
# deprecated environment variables were used for initialization of the
|
|
94
134
|
# client settings
|
|
@@ -157,7 +197,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
157
197
|
return self._extra_parameters[API_KEY]
|
|
158
198
|
|
|
159
199
|
if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
|
|
160
|
-
return "${
|
|
200
|
+
return f"${DEFAULT_AZURE_API_KEY_NAME}"
|
|
161
201
|
|
|
162
202
|
if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
|
|
163
203
|
# API key can be set through OPENAI_API_KEY too,
|
|
@@ -188,7 +228,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
188
228
|
)
|
|
189
229
|
|
|
190
230
|
@classmethod
|
|
191
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
231
|
+
def from_config(cls, config: Dict[str, Any]) -> AzureOpenAILLMClient:
|
|
192
232
|
"""Initializes the client from given configuration.
|
|
193
233
|
|
|
194
234
|
Args:
|
|
@@ -215,11 +255,12 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
215
255
|
raise
|
|
216
256
|
|
|
217
257
|
return cls(
|
|
218
|
-
azure_openai_config.deployment,
|
|
219
|
-
azure_openai_config.model,
|
|
220
|
-
azure_openai_config.api_type,
|
|
221
|
-
azure_openai_config.api_base,
|
|
222
|
-
azure_openai_config.api_version,
|
|
258
|
+
deployment=azure_openai_config.deployment,
|
|
259
|
+
model=azure_openai_config.model,
|
|
260
|
+
api_type=azure_openai_config.api_type,
|
|
261
|
+
api_base=azure_openai_config.api_base,
|
|
262
|
+
api_version=azure_openai_config.api_version,
|
|
263
|
+
oauth=azure_openai_config.oauth,
|
|
223
264
|
**azure_openai_config.extra_parameters,
|
|
224
265
|
)
|
|
225
266
|
|
|
@@ -234,6 +275,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
234
275
|
api_base=self._api_base,
|
|
235
276
|
api_version=self._api_version,
|
|
236
277
|
api_type=self._api_type,
|
|
278
|
+
oauth=self._oauth,
|
|
237
279
|
extra_parameters=self._extra_parameters,
|
|
238
280
|
)
|
|
239
281
|
return config.to_dict()
|
|
@@ -282,12 +324,21 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
282
324
|
"""Returns the completion arguments for invoking a call through
|
|
283
325
|
LiteLLM's completion functions.
|
|
284
326
|
"""
|
|
327
|
+
# Set the API key env var to None if OAuth is used
|
|
328
|
+
auth_parameter = (
|
|
329
|
+
{
|
|
330
|
+
LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
|
|
331
|
+
}
|
|
332
|
+
if self._oauth
|
|
333
|
+
else {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
|
|
334
|
+
)
|
|
335
|
+
|
|
285
336
|
fn_args = super()._completion_fn_args
|
|
286
337
|
fn_args.update(
|
|
287
338
|
{
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
339
|
+
LITE_LLM_API_BASE_FIELD: self.api_base,
|
|
340
|
+
LITE_LLM_API_VERSION_FIELD: self.api_version,
|
|
341
|
+
**auth_parameter,
|
|
291
342
|
}
|
|
292
343
|
)
|
|
293
344
|
return fn_args
|
|
@@ -314,41 +365,44 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
314
365
|
|
|
315
366
|
return info.format(setting=setting, options=options)
|
|
316
367
|
|
|
368
|
+
env_var_field = "env_var"
|
|
369
|
+
config_key_field = "config_key"
|
|
370
|
+
current_value_field = "current_value"
|
|
317
371
|
# All required settings for Azure OpenAI client
|
|
318
372
|
settings: Dict[str, Dict[str, Any]] = {
|
|
319
373
|
"API Base": {
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
374
|
+
current_value_field: self.api_base,
|
|
375
|
+
env_var_field: AZURE_API_BASE_ENV_VAR,
|
|
376
|
+
config_key_field: API_BASE_CONFIG_KEY,
|
|
323
377
|
},
|
|
324
378
|
"API Version": {
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
379
|
+
current_value_field: self.api_version,
|
|
380
|
+
env_var_field: AZURE_API_VERSION_ENV_VAR,
|
|
381
|
+
config_key_field: API_VERSION_CONFIG_KEY,
|
|
328
382
|
},
|
|
329
383
|
"Deployment Name": {
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
384
|
+
current_value_field: self.deployment,
|
|
385
|
+
env_var_field: None,
|
|
386
|
+
config_key_field: DEPLOYMENT_CONFIG_KEY,
|
|
333
387
|
},
|
|
334
388
|
}
|
|
335
389
|
|
|
336
390
|
missing_settings = [
|
|
337
391
|
setting_name
|
|
338
392
|
for setting_name, setting_info in settings.items()
|
|
339
|
-
if setting_info[
|
|
393
|
+
if setting_info[current_value_field] is None
|
|
340
394
|
]
|
|
341
395
|
|
|
342
396
|
if missing_settings:
|
|
343
397
|
event_info = f"Client settings not set: " f"{', '.join(missing_settings)}. "
|
|
344
398
|
|
|
345
399
|
for missing_setting in missing_settings:
|
|
346
|
-
if settings[missing_setting][
|
|
400
|
+
if settings[missing_setting][current_value_field] is not None:
|
|
347
401
|
continue
|
|
348
402
|
event_info += generate_event_info_for_missing_setting(
|
|
349
403
|
missing_setting,
|
|
350
|
-
settings[missing_setting][
|
|
351
|
-
settings[missing_setting][
|
|
404
|
+
settings[missing_setting][env_var_field],
|
|
405
|
+
settings[missing_setting][config_key_field],
|
|
352
406
|
)
|
|
353
407
|
|
|
354
408
|
structlogger.error(
|
|
@@ -1,4 +1,6 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict
|
|
2
4
|
|
|
3
5
|
from rasa.shared.constants import (
|
|
4
6
|
AWS_BEDROCK_PROVIDER,
|
|
@@ -35,7 +37,7 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
|
|
|
35
37
|
self.validate_client_setup()
|
|
36
38
|
|
|
37
39
|
@classmethod
|
|
38
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
40
|
+
def from_config(cls, config: Dict[str, Any]) -> DefaultLiteLLMClient:
|
|
39
41
|
default_config = DefaultLiteLLMClientConfig.from_dict(config)
|
|
40
42
|
return cls(
|
|
41
43
|
model=default_config.model,
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import logging
|
|
4
|
+
from typing import Any, Dict, List, Union
|
|
5
|
+
|
|
3
6
|
import structlog
|
|
4
7
|
|
|
5
8
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
6
9
|
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
7
10
|
LiteLLMRouterClientConfig,
|
|
8
11
|
)
|
|
12
|
+
from rasa.shared.providers.constants import LITE_LLM_MODEL_FIELD
|
|
9
13
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
10
14
|
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
11
15
|
from rasa.shared.providers.router._base_litellm_router_client import (
|
|
@@ -41,7 +45,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
41
45
|
)
|
|
42
46
|
|
|
43
47
|
@classmethod
|
|
44
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
48
|
+
def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterLLMClient:
|
|
45
49
|
"""Instantiates a LiteLLM Router LLM client from a configuration dict.
|
|
46
50
|
|
|
47
51
|
Args:
|
|
@@ -86,6 +90,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
86
90
|
ProviderClientAPIException: If the API request fails.
|
|
87
91
|
"""
|
|
88
92
|
try:
|
|
93
|
+
structlogger.info(
|
|
94
|
+
"litellm_router_llm_client.text_completion",
|
|
95
|
+
_completion_fn_args=self._completion_fn_args,
|
|
96
|
+
)
|
|
89
97
|
response = self.router_client.text_completion(
|
|
90
98
|
prompt=prompt, **self._completion_fn_args
|
|
91
99
|
)
|
|
@@ -106,6 +114,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
106
114
|
ProviderClientAPIException: If the API request fails.
|
|
107
115
|
"""
|
|
108
116
|
try:
|
|
117
|
+
structlogger.info(
|
|
118
|
+
"litellm_router_llm_client.atext_completion",
|
|
119
|
+
_completion_fn_args=self._completion_fn_args,
|
|
120
|
+
)
|
|
109
121
|
response = await self.router_client.atext_completion(
|
|
110
122
|
prompt=prompt, **self._completion_fn_args
|
|
111
123
|
)
|
|
@@ -135,6 +147,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
135
147
|
return self._text_completion(messages)
|
|
136
148
|
try:
|
|
137
149
|
formatted_messages = self._format_messages(messages)
|
|
150
|
+
structlogger.info(
|
|
151
|
+
"litellm_router_llm_client.completion",
|
|
152
|
+
_completion_fn_args=self._completion_fn_args,
|
|
153
|
+
)
|
|
138
154
|
response = self.router_client.completion(
|
|
139
155
|
messages=formatted_messages, **self._completion_fn_args
|
|
140
156
|
)
|
|
@@ -164,6 +180,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
164
180
|
return await self._atext_completion(messages)
|
|
165
181
|
try:
|
|
166
182
|
formatted_messages = self._format_messages(messages)
|
|
183
|
+
structlogger.info(
|
|
184
|
+
"litellm_router_llm_client.acompletion",
|
|
185
|
+
_completion_fn_args=self._completion_fn_args,
|
|
186
|
+
)
|
|
167
187
|
response = await self.router_client.acompletion(
|
|
168
188
|
messages=formatted_messages, **self._completion_fn_args
|
|
169
189
|
)
|
|
@@ -178,5 +198,5 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
178
198
|
"""
|
|
179
199
|
return {
|
|
180
200
|
**self._litellm_extra_parameters,
|
|
181
|
-
|
|
201
|
+
LITE_LLM_MODEL_FIELD: self.model_group_id,
|
|
182
202
|
}
|
|
@@ -1,4 +1,6 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Protocol, Union, runtime_checkable
|
|
2
4
|
|
|
3
5
|
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
4
6
|
|
|
@@ -11,7 +13,7 @@ class LLMClient(Protocol):
|
|
|
11
13
|
"""
|
|
12
14
|
|
|
13
15
|
@classmethod
|
|
14
|
-
def from_config(cls, config: dict) ->
|
|
16
|
+
def from_config(cls, config: dict) -> LLMClient:
|
|
15
17
|
"""
|
|
16
18
|
Initializes the llm client with the given configuration.
|
|
17
19
|
|
|
@@ -1,8 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass, field, asdict
|
|
2
|
-
from typing import Dict, List, Optional
|
|
3
|
-
import structlog
|
|
4
|
-
|
|
5
|
-
structlogger = structlog.get_logger()
|
|
2
|
+
from typing import Dict, List, Optional
|
|
6
3
|
|
|
7
4
|
|
|
8
5
|
@dataclass
|
|
@@ -19,18 +16,6 @@ class LLMUsage:
|
|
|
19
16
|
def __post_init__(self) -> None:
|
|
20
17
|
self.total_tokens = self.prompt_tokens + self.completion_tokens
|
|
21
18
|
|
|
22
|
-
@classmethod
|
|
23
|
-
def from_dict(cls, data: Dict[Text, Any]) -> "LLMUsage":
|
|
24
|
-
"""
|
|
25
|
-
Creates an LLMUsage object from a dictionary.
|
|
26
|
-
If any keys are missing, they will default to zero
|
|
27
|
-
or whatever default you prefer.
|
|
28
|
-
"""
|
|
29
|
-
return cls(
|
|
30
|
-
prompt_tokens=data.get("prompt_tokens"),
|
|
31
|
-
completion_tokens=data.get("completion_tokens"),
|
|
32
|
-
)
|
|
33
|
-
|
|
34
19
|
def to_dict(self) -> dict:
|
|
35
20
|
"""Converts the LLMUsage dataclass instance into a dictionary."""
|
|
36
21
|
return asdict(self)
|
|
@@ -57,32 +42,6 @@ class LLMResponse:
|
|
|
57
42
|
"""Optional dictionary for storing additional information related to the
|
|
58
43
|
completion that may not be covered by other fields."""
|
|
59
44
|
|
|
60
|
-
@classmethod
|
|
61
|
-
def from_dict(cls, data: Dict[Text, Any]) -> "LLMResponse":
|
|
62
|
-
"""
|
|
63
|
-
Creates an LLMResponse from a dictionary.
|
|
64
|
-
"""
|
|
65
|
-
usage_data = data.get("usage", {})
|
|
66
|
-
usage_obj = LLMUsage.from_dict(usage_data) if usage_data else None
|
|
67
|
-
|
|
68
|
-
return cls(
|
|
69
|
-
id=data["id"],
|
|
70
|
-
choices=data["choices"],
|
|
71
|
-
created=data["created"],
|
|
72
|
-
model=data.get("model"),
|
|
73
|
-
usage=usage_obj,
|
|
74
|
-
additional_info=data.get("additional_info"),
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
@classmethod
|
|
78
|
-
def ensure_llm_response(cls, response: Union[str, "LLMResponse"]) -> "LLMResponse":
|
|
79
|
-
if isinstance(response, LLMResponse):
|
|
80
|
-
return response
|
|
81
|
-
|
|
82
|
-
structlogger.warn("llm_response.deprecated_response_type", response=response)
|
|
83
|
-
data = {"id": None, "choices": [response], "created": None}
|
|
84
|
-
return LLMResponse.from_dict(data)
|
|
85
|
-
|
|
86
45
|
def to_dict(self) -> dict:
|
|
87
46
|
"""Converts the LLMResponse dataclass instance into a dictionary."""
|
|
88
47
|
result = asdict(self)
|
|
@@ -1,16 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
import re
|
|
3
|
-
from typing import
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
4
6
|
|
|
5
7
|
import structlog
|
|
6
8
|
|
|
7
9
|
from rasa.shared.constants import (
|
|
8
10
|
OPENAI_API_BASE_ENV_VAR,
|
|
9
|
-
OPENAI_API_VERSION_ENV_VAR,
|
|
10
11
|
OPENAI_API_TYPE_ENV_VAR,
|
|
12
|
+
OPENAI_API_VERSION_ENV_VAR,
|
|
11
13
|
OPENAI_PROVIDER,
|
|
12
14
|
)
|
|
13
15
|
from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
|
|
16
|
+
from rasa.shared.providers.constants import (
|
|
17
|
+
LITE_LLM_API_KEY_FIELD,
|
|
18
|
+
LITE_LLM_API_VERSION_FIELD,
|
|
19
|
+
)
|
|
14
20
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
15
21
|
|
|
16
22
|
structlogger = structlog.get_logger()
|
|
@@ -57,7 +63,7 @@ class OpenAILLMClient(_BaseLiteLLMClient):
|
|
|
57
63
|
self.validate_client_setup()
|
|
58
64
|
|
|
59
65
|
@classmethod
|
|
60
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
66
|
+
def from_config(cls, config: Dict[str, Any]) -> OpenAILLMClient:
|
|
61
67
|
"""
|
|
62
68
|
Initializes the client from given configuration.
|
|
63
69
|
|
|
@@ -148,8 +154,8 @@ class OpenAILLMClient(_BaseLiteLLMClient):
|
|
|
148
154
|
fn_args = super()._completion_fn_args
|
|
149
155
|
fn_args.update(
|
|
150
156
|
{
|
|
151
|
-
|
|
152
|
-
|
|
157
|
+
LITE_LLM_API_KEY_FIELD: self.api_base,
|
|
158
|
+
LITE_LLM_API_VERSION_FIELD: self.api_version,
|
|
153
159
|
}
|
|
154
160
|
)
|
|
155
161
|
return fn_args
|
|
@@ -1,17 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Any, Dict, Optional
|
|
2
4
|
|
|
3
5
|
import structlog
|
|
4
6
|
|
|
5
7
|
from rasa.shared.constants import (
|
|
6
|
-
RASA_PROVIDER,
|
|
7
8
|
OPENAI_PROVIDER,
|
|
9
|
+
RASA_PROVIDER,
|
|
8
10
|
)
|
|
9
11
|
from rasa.shared.providers._configs.rasa_llm_client_config import (
|
|
10
12
|
RasaLLMClientConfig,
|
|
11
13
|
)
|
|
12
|
-
from rasa.
|
|
14
|
+
from rasa.shared.providers.constants import (
|
|
15
|
+
LITE_LLM_API_BASE_FIELD,
|
|
16
|
+
LITE_LLM_API_KEY_FIELD,
|
|
17
|
+
)
|
|
13
18
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
14
|
-
|
|
19
|
+
from rasa.utils.licensing import retrieve_license_from_env
|
|
15
20
|
|
|
16
21
|
structlogger = structlog.get_logger()
|
|
17
22
|
|
|
@@ -88,12 +93,15 @@ class RasaLLMClient(_BaseLiteLLMClient):
|
|
|
88
93
|
"""
|
|
89
94
|
fn_args = super()._completion_fn_args
|
|
90
95
|
fn_args.update(
|
|
91
|
-
{
|
|
96
|
+
{
|
|
97
|
+
LITE_LLM_API_BASE_FIELD: self.api_base,
|
|
98
|
+
LITE_LLM_API_KEY_FIELD: retrieve_license_from_env(),
|
|
99
|
+
}
|
|
92
100
|
)
|
|
93
101
|
return fn_args
|
|
94
102
|
|
|
95
103
|
@classmethod
|
|
96
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
104
|
+
def from_config(cls, config: Dict[str, Any]) -> RasaLLMClient:
|
|
97
105
|
try:
|
|
98
106
|
client_config = RasaLLMClientConfig.from_dict(config)
|
|
99
107
|
except ValueError as e:
|