rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/core/actions/action.py +7 -16
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/socketio.py +23 -1
- rasa/core/nlg/contextual_response_rephraser.py +9 -62
- rasa/core/policies/enterprise_search_policy.py +12 -77
- rasa/core/policies/flows/flow_executor.py +2 -26
- rasa/core/processor.py +8 -11
- rasa/dialogue_understanding/generator/command_generator.py +49 -43
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
- rasa/dialogue_understanding/utils.py +1 -8
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +2 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/e2e_test_runner.py +9 -7
- rasa/hooks.py +9 -15
- rasa/model_manager/socket_bridge.py +2 -7
- rasa/model_manager/warm_rasa_process.py +4 -9
- rasa/plugin.py +0 -11
- rasa/shared/constants.py +2 -21
- rasa/shared/core/events.py +8 -8
- rasa/shared/nlu/constants.py +0 -3
- rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
- rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
- rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
- rasa/shared/providers/_configs/model_group_config.py +11 -5
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +14 -12
- rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
- rasa/shared/providers/llm/_base_litellm_client.py +6 -4
- rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
- rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
- rasa/shared/providers/llm/llm_client.py +4 -2
- rasa/shared/providers/llm/llm_response.py +1 -42
- rasa/shared/providers/llm/openai_llm_client.py +11 -5
- rasa/shared/providers/llm/rasa_llm_client.py +13 -5
- rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
- rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/llm.py +16 -12
- rasa/shared/utils/schemas/events.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +0 -2
- rasa/version.py +1 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
- rasa/core/channels/studio_chat.py +0 -192
- rasa/dialogue_understanding/constants.py +0 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
rasa/e2e_test/e2e_test_runner.py
CHANGED
|
@@ -1041,11 +1041,13 @@ class E2ETestRunner:
|
|
|
1041
1041
|
return conversations
|
|
1042
1042
|
|
|
1043
1043
|
@staticmethod
|
|
1044
|
-
def _action_server_is_reachable(
|
|
1044
|
+
def _action_server_is_reachable(
|
|
1045
|
+
endpoints: AvailableEndpoints, module: str = "e2e_test_runner"
|
|
1046
|
+
) -> None:
|
|
1045
1047
|
"""Calls the action server health endpoint."""
|
|
1046
1048
|
if not endpoints.action:
|
|
1047
1049
|
structlogger.debug(
|
|
1048
|
-
"
|
|
1050
|
+
f"{module}._action_server_is_reachable",
|
|
1049
1051
|
message="No action endpoint configured. Skipping the health check "
|
|
1050
1052
|
"of the action server.",
|
|
1051
1053
|
)
|
|
@@ -1053,7 +1055,7 @@ class E2ETestRunner:
|
|
|
1053
1055
|
|
|
1054
1056
|
if endpoints.action.actions_module:
|
|
1055
1057
|
structlogger.debug(
|
|
1056
|
-
"
|
|
1058
|
+
f"{module}._action_server_is_reachable",
|
|
1057
1059
|
message="Rasa server is configured to run custom actions directly. "
|
|
1058
1060
|
"Skipping the health check of the action server.",
|
|
1059
1061
|
)
|
|
@@ -1061,14 +1063,14 @@ class E2ETestRunner:
|
|
|
1061
1063
|
|
|
1062
1064
|
if not endpoints.action.url:
|
|
1063
1065
|
structlogger.debug(
|
|
1064
|
-
"
|
|
1066
|
+
f"{module}._action_server_is_reachable",
|
|
1065
1067
|
message="Action endpoint URL is not defined in the endpoint "
|
|
1066
1068
|
"configuration.",
|
|
1067
1069
|
)
|
|
1068
1070
|
return
|
|
1069
1071
|
|
|
1070
1072
|
structlogger.debug(
|
|
1071
|
-
"
|
|
1073
|
+
f"{module}._action_server_is_reachable",
|
|
1072
1074
|
message="Detected action URL in the endpoint configuration.\n"
|
|
1073
1075
|
f"Action Server URL: {endpoints.action.url}\n"
|
|
1074
1076
|
"Sending a health request to the action endpoint.",
|
|
@@ -1084,7 +1086,7 @@ class E2ETestRunner:
|
|
|
1084
1086
|
"Actions server URL is defined in your endpoint configuration as "
|
|
1085
1087
|
f"'{endpoints.action.url}'.\n"
|
|
1086
1088
|
"Please make sure your action server is running and properly "
|
|
1087
|
-
"configured. Since running
|
|
1089
|
+
"configured. Since running tests without a action server may "
|
|
1088
1090
|
f"lead to unpredictable results.\n{error}"
|
|
1089
1091
|
)
|
|
1090
1092
|
|
|
@@ -1096,7 +1098,7 @@ class E2ETestRunner:
|
|
|
1096
1098
|
)
|
|
1097
1099
|
|
|
1098
1100
|
structlogger.debug(
|
|
1099
|
-
"
|
|
1101
|
+
f"{module}._action_server_is_reachable",
|
|
1100
1102
|
message="Action endpoint has responded successfully.\n"
|
|
1101
1103
|
f"Response message: {response.text}\n"
|
|
1102
1104
|
f"Response status code: {response.status_code}.",
|
rasa/hooks.py
CHANGED
|
@@ -4,13 +4,14 @@ from typing import Optional, TYPE_CHECKING, List, Text, Union
|
|
|
4
4
|
|
|
5
5
|
import pluggy
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
7
|
+
from rasa.cli import SubParsersAction
|
|
8
|
+
from rasa.cli import x as rasa_x
|
|
9
|
+
from rasa.core.auth_retry_tracker_store import AuthRetryTrackerStore
|
|
10
|
+
from rasa.core.secrets_manager.factory import load_secret_manager
|
|
11
|
+
from rasa.tracing import config
|
|
12
|
+
from rasa.utils.endpoints import EndpointConfig
|
|
10
13
|
|
|
11
14
|
if TYPE_CHECKING:
|
|
12
|
-
from rasa.cli import SubParsersAction
|
|
13
|
-
from rasa.utils.endpoints import EndpointConfig
|
|
14
15
|
from rasa.core.brokers.broker import EventBroker
|
|
15
16
|
from rasa.core.tracker_store import TrackerStore
|
|
16
17
|
from rasa.shared.core.domain import Domain
|
|
@@ -22,15 +23,16 @@ logger = logging.getLogger(__name__)
|
|
|
22
23
|
|
|
23
24
|
@hookimpl # type: ignore[misc]
|
|
24
25
|
def refine_cli(
|
|
25
|
-
subparsers:
|
|
26
|
+
subparsers: SubParsersAction,
|
|
26
27
|
parent_parsers: List[argparse.ArgumentParser],
|
|
27
28
|
) -> None:
|
|
28
|
-
from rasa.cli import e2e_test, inspect, markers
|
|
29
|
+
from rasa.cli import e2e_test, inspect, markers, dialogue_understanding_test
|
|
29
30
|
from rasa.cli.studio import studio
|
|
30
31
|
|
|
31
32
|
from rasa.cli import license as license_cli
|
|
32
33
|
|
|
33
34
|
e2e_test.add_subparser(subparsers, parent_parsers)
|
|
35
|
+
dialogue_understanding_test.add_subparser(subparsers, parent_parsers)
|
|
34
36
|
studio.add_subparser(subparsers, parent_parsers)
|
|
35
37
|
license_cli.add_subparser(subparsers, parent_parsers)
|
|
36
38
|
markers.add_subparser(subparsers, parent_parsers)
|
|
@@ -40,9 +42,6 @@ def refine_cli(
|
|
|
40
42
|
|
|
41
43
|
@hookimpl # type: ignore[misc]
|
|
42
44
|
def configure_commandline(cmdline_arguments: argparse.Namespace) -> Optional[Text]:
|
|
43
|
-
from rasa.tracing import config
|
|
44
|
-
from rasa.cli import x as rasa_x
|
|
45
|
-
|
|
46
45
|
endpoints_file = None
|
|
47
46
|
|
|
48
47
|
if cmdline_arguments.func.__name__ == "rasa_x":
|
|
@@ -69,8 +68,6 @@ def init_telemetry(endpoints_file: Optional[Text]) -> None:
|
|
|
69
68
|
|
|
70
69
|
@hookimpl # type: ignore[misc]
|
|
71
70
|
def init_managers(endpoints_file: Optional[Text]) -> None:
|
|
72
|
-
from rasa.core.secrets_manager.factory import load_secret_manager
|
|
73
|
-
|
|
74
71
|
load_secret_manager(endpoints_file)
|
|
75
72
|
|
|
76
73
|
|
|
@@ -80,9 +77,6 @@ def create_tracker_store(
|
|
|
80
77
|
domain: "Domain",
|
|
81
78
|
event_broker: Optional["EventBroker"],
|
|
82
79
|
) -> "TrackerStore":
|
|
83
|
-
from rasa.utils.endpoints import EndpointConfig
|
|
84
|
-
from rasa.core.auth_retry_tracker_store import AuthRetryTrackerStore
|
|
85
|
-
|
|
86
80
|
if isinstance(endpoint_config, EndpointConfig):
|
|
87
81
|
return AuthRetryTrackerStore(
|
|
88
82
|
endpoint_config=endpoint_config, domain=domain, event_broker=event_broker
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import json
|
|
2
1
|
from typing import Any, Dict, Optional
|
|
3
2
|
|
|
4
3
|
from socketio import AsyncServer
|
|
@@ -93,7 +92,7 @@ def create_bridge_server(sio: AsyncServer, running_bots: Dict[str, BotSession])
|
|
|
93
92
|
|
|
94
93
|
@sio.on("*")
|
|
95
94
|
async def handle_message(event: str, sid: str, data: Dict[str, Any]) -> None:
|
|
96
|
-
"""Bridge messages between user and bot.
|
|
95
|
+
""" "Bridge messages between user and bot.
|
|
97
96
|
|
|
98
97
|
Both incoming user messages to the bot_url and
|
|
99
98
|
bot responses sent back to the client need to
|
|
@@ -112,7 +111,7 @@ async def create_bridge_client(
|
|
|
112
111
|
) -> AsyncClient:
|
|
113
112
|
"""Create a new socket bridge client.
|
|
114
113
|
|
|
115
|
-
Forwards messages
|
|
114
|
+
Forwards messages comming from the bot to the user.
|
|
116
115
|
"""
|
|
117
116
|
client = AsyncClient()
|
|
118
117
|
|
|
@@ -130,10 +129,6 @@ async def create_bridge_client(
|
|
|
130
129
|
structlogger.debug("model_runner.bot_message", deployment_id=deployment_id)
|
|
131
130
|
await sio.emit("bot_message", data, room=sid)
|
|
132
131
|
|
|
133
|
-
@client.event # type: ignore[misc]
|
|
134
|
-
async def tracker(data: Dict[str, Any]) -> None:
|
|
135
|
-
await sio.emit("tracker", json.loads(data), room=sid)
|
|
136
|
-
|
|
137
132
|
@client.event # type: ignore[misc]
|
|
138
133
|
async def disconnect() -> None:
|
|
139
134
|
structlogger.debug(
|
|
@@ -1,16 +1,12 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import shlex
|
|
3
2
|
import subprocess
|
|
4
|
-
import
|
|
5
|
-
|
|
3
|
+
from rasa.__main__ import main
|
|
4
|
+
import os
|
|
6
5
|
from typing import List
|
|
7
|
-
|
|
8
6
|
import structlog
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
import uuid
|
|
9
9
|
|
|
10
|
-
from rasa.__main__ import main
|
|
11
|
-
from rasa.dialogue_understanding.constants import (
|
|
12
|
-
RASA_RECORD_COMMANDS_AND_PROMPTS_ENV_VAR_NAME,
|
|
13
|
-
)
|
|
14
10
|
from rasa.model_manager import config
|
|
15
11
|
from rasa.model_manager.utils import ensure_base_directory_exists, logs_path
|
|
16
12
|
|
|
@@ -47,7 +43,6 @@ def _create_warm_rasa_process() -> WarmRasaProcess:
|
|
|
47
43
|
|
|
48
44
|
envs = os.environ.copy()
|
|
49
45
|
envs["RASA_TELEMETRY_ENABLED"] = "false"
|
|
50
|
-
envs[RASA_RECORD_COMMANDS_AND_PROMPTS_ENV_VAR_NAME] = "true"
|
|
51
46
|
|
|
52
47
|
log_id = uuid.uuid4().hex
|
|
53
48
|
log_path = logs_path(log_id)
|
rasa/plugin.py
CHANGED
|
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
|
|
|
14
14
|
from rasa.core.tracker_store import TrackerStore
|
|
15
15
|
from rasa.shared.core.domain import Domain
|
|
16
16
|
from rasa.utils.endpoints import EndpointConfig
|
|
17
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
hookspec = pluggy.HookspecMarker("rasa")
|
|
@@ -89,13 +88,3 @@ def after_server_stop() -> None:
|
|
|
89
88
|
Use this hook to de-initialize any resources that require explicit cleanup like,
|
|
90
89
|
thread shutdown, closing connections, etc.
|
|
91
90
|
"""
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
@hookspec # type: ignore[misc]
|
|
95
|
-
def after_new_user_message(tracker: "DialogueStateTracker") -> None:
|
|
96
|
-
"""Hook specification for after a new user message is received."""
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
@hookspec # type: ignore[misc]
|
|
100
|
-
def after_action_executed(tracker: "DialogueStateTracker") -> None:
|
|
101
|
-
"""Hook specification for after an action is executed."""
|
rasa/shared/constants.py
CHANGED
|
@@ -161,6 +161,7 @@ AZURE_AD_TOKEN_ENV_VAR = "AZURE_AD_TOKEN"
|
|
|
161
161
|
AZURE_API_BASE_ENV_VAR = "AZURE_API_BASE"
|
|
162
162
|
AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
|
|
163
163
|
AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
|
|
164
|
+
AZURE_AD_SCOPES_ENV_VAR = "AZURE_AD_SCOPES"
|
|
164
165
|
AZURE_SPEECH_API_KEY_ENV_VAR = "AZURE_SPEECH_API_KEY"
|
|
165
166
|
|
|
166
167
|
DEEPGRAM_API_KEY_ENV_VAR = "DEEPGRAM_API_KEY"
|
|
@@ -197,7 +198,6 @@ MODEL_CONFIG_KEY = "model"
|
|
|
197
198
|
MODEL_NAME_CONFIG_KEY = "model_name"
|
|
198
199
|
PROMPT_CONFIG_KEY = "prompt"
|
|
199
200
|
PROMPT_TEMPLATE_CONFIG_KEY = "prompt_template"
|
|
200
|
-
|
|
201
201
|
STREAM_CONFIG_KEY = "stream"
|
|
202
202
|
N_REPHRASES_CONFIG_KEY = "n"
|
|
203
203
|
USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY = "use_chat_completions_endpoint"
|
|
@@ -231,12 +231,6 @@ LITELLM_PARAMS_KEY = "litellm_params"
|
|
|
231
231
|
LLM_API_HEALTH_CHECK_ENV_VAR = "LLM_API_HEALTH_CHECK"
|
|
232
232
|
LLM_API_HEALTH_CHECK_DEFAULT_VALUE = "false"
|
|
233
233
|
|
|
234
|
-
AZURE_API_KEY_ENV_VAR = "AZURE_API_KEY"
|
|
235
|
-
AZURE_AD_TOKEN_ENV_VAR = "AZURE_AD_TOKEN"
|
|
236
|
-
AZURE_API_BASE_ENV_VAR = "AZURE_API_BASE"
|
|
237
|
-
AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
|
|
238
|
-
AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
|
|
239
|
-
|
|
240
234
|
AWS_REGION_NAME_CONFIG_KEY = "aws_region_name"
|
|
241
235
|
AWS_ACCESS_KEY_ID_CONFIG_KEY = "aws_access_key_id"
|
|
242
236
|
AWS_SECRET_ACCESS_KEY_CONFIG_KEY = "aws_secret_access_key"
|
|
@@ -272,17 +266,11 @@ RASA_PROVIDER = "rasa"
|
|
|
272
266
|
SELF_HOSTED_VLLM_PREFIX = "hosted_vllm"
|
|
273
267
|
SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"
|
|
274
268
|
|
|
275
|
-
SELF_HOSTED_VLLM_PREFIX = "hosted_vllm"
|
|
276
|
-
SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"
|
|
277
|
-
|
|
278
269
|
VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY = [
|
|
279
270
|
OPENAI_PROVIDER,
|
|
280
271
|
AZURE_OPENAI_PROVIDER,
|
|
281
272
|
]
|
|
282
273
|
|
|
283
|
-
SELF_HOSTED_VLLM_PREFIX = "hosted_vllm"
|
|
284
|
-
SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"
|
|
285
|
-
|
|
286
274
|
AZURE_API_TYPE = "azure"
|
|
287
275
|
OPENAI_API_TYPE = "openai"
|
|
288
276
|
|
|
@@ -320,11 +308,4 @@ SENSITIVE_DATA = [
|
|
|
320
308
|
AWS_SESSION_TOKEN_CONFIG_KEY,
|
|
321
309
|
]
|
|
322
310
|
|
|
323
|
-
|
|
324
|
-
TEXT = "text"
|
|
325
|
-
ELEMENTS = "elements"
|
|
326
|
-
QUICK_REPLIES = "quick_replies"
|
|
327
|
-
BUTTONS = "buttons"
|
|
328
|
-
ATTACHMENT = "attachment"
|
|
329
|
-
IMAGE = "image"
|
|
330
|
-
CUSTOM = "custom"
|
|
311
|
+
AZURE_AD_KEY = "azure_ad"
|
rasa/shared/core/events.py
CHANGED
|
@@ -2,10 +2,14 @@ import abc
|
|
|
2
2
|
import copy
|
|
3
3
|
import json
|
|
4
4
|
import logging
|
|
5
|
+
import structlog
|
|
5
6
|
import re
|
|
7
|
+
from abc import ABC
|
|
8
|
+
|
|
9
|
+
import jsonpickle
|
|
6
10
|
import time
|
|
7
11
|
import uuid
|
|
8
|
-
from
|
|
12
|
+
from dateutil import parser
|
|
9
13
|
from datetime import datetime
|
|
10
14
|
from typing import (
|
|
11
15
|
List,
|
|
@@ -20,14 +24,11 @@ from typing import (
|
|
|
20
24
|
Tuple,
|
|
21
25
|
TypeVar,
|
|
22
26
|
)
|
|
23
|
-
from typing import Union
|
|
24
|
-
|
|
25
|
-
import jsonpickle
|
|
26
|
-
import structlog
|
|
27
|
-
from dateutil import parser
|
|
28
27
|
|
|
29
28
|
import rasa.shared.utils.common
|
|
30
29
|
import rasa.shared.utils.io
|
|
30
|
+
from typing import Union
|
|
31
|
+
|
|
31
32
|
from rasa.shared.constants import DOCS_URL_TRAINING_DATA
|
|
32
33
|
from rasa.shared.core.constants import (
|
|
33
34
|
LOOP_NAME,
|
|
@@ -61,7 +62,7 @@ from rasa.shared.nlu.constants import (
|
|
|
61
62
|
ENTITY_ATTRIBUTE_END,
|
|
62
63
|
FULL_RETRIEVAL_INTENT_NAME_KEY,
|
|
63
64
|
)
|
|
64
|
-
|
|
65
|
+
|
|
65
66
|
|
|
66
67
|
if TYPE_CHECKING:
|
|
67
68
|
from typing_extensions import TypedDict
|
|
@@ -97,7 +98,6 @@ if TYPE_CHECKING:
|
|
|
97
98
|
ENTITIES: List[EntityPrediction],
|
|
98
99
|
"message_id": Optional[Text],
|
|
99
100
|
"metadata": Dict,
|
|
100
|
-
PROMPTS: Dict,
|
|
101
101
|
},
|
|
102
102
|
total=False,
|
|
103
103
|
)
|
rasa/shared/nlu/constants.py
CHANGED
|
@@ -6,9 +6,6 @@ PREDICTED_COMMANDS = "predicted_commands"
|
|
|
6
6
|
PROMPTS = "prompts"
|
|
7
7
|
KEY_USER_PROMPT = "user_prompt"
|
|
8
8
|
KEY_SYSTEM_PROMPT = "system_prompt"
|
|
9
|
-
KEY_LLM_RESPONSE_METADATA = "llm_response_metadata"
|
|
10
|
-
KEY_PROMPT_NAME = "prompt_name"
|
|
11
|
-
KEY_COMPONENT_NAME = "component_name"
|
|
12
9
|
LLM_COMMANDS = "llm_commands" # needed for fine-tuning
|
|
13
10
|
LLM_PROMPT = "llm_prompt" # needed for fine-tuning
|
|
14
11
|
FLOWS_FROM_SEMANTIC_SEARCH = "flows_from_semantic_search"
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
6
|
+
from azure.identity import ClientSecretCredential
|
|
7
|
+
|
|
8
|
+
from rasa.shared.providers._configs.oauth_config import OAuth
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
|
|
12
|
+
class AzureEntraIDClientCreds(OAuth):
|
|
13
|
+
client_id: str
|
|
14
|
+
client_secret: str
|
|
15
|
+
tenant_id: str
|
|
16
|
+
scopes: List[str] = field(default_factory=list)
|
|
17
|
+
|
|
18
|
+
@classmethod
|
|
19
|
+
def from_config(cls, config: Dict[str, Any]) -> AzureEntraIDClientCreds:
|
|
20
|
+
scopes = config.get("scopes")
|
|
21
|
+
if isinstance(scopes, str):
|
|
22
|
+
scopes = [scopes]
|
|
23
|
+
|
|
24
|
+
return cls(
|
|
25
|
+
client_id=config.get("client_id"),
|
|
26
|
+
client_secret=config.get("client_secret"),
|
|
27
|
+
tenant_id=config.get("tenant_id"),
|
|
28
|
+
scopes=scopes,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
def get_bearer_token(self) -> str:
|
|
32
|
+
return (
|
|
33
|
+
ClientSecretCredential(
|
|
34
|
+
client_id=self.client_id,
|
|
35
|
+
client_secret=self.client_secret,
|
|
36
|
+
tenant_id=self.tenant_id,
|
|
37
|
+
)
|
|
38
|
+
.get_token(*self.scopes)
|
|
39
|
+
.token
|
|
40
|
+
)
|