rasa-pro 3.12.0rc2__py3-none-any.whl → 3.12.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- rasa/cli/dialogue_understanding_test.py +5 -8
- rasa/cli/llm_fine_tuning.py +47 -12
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/audiocodes.py +19 -6
- rasa/core/channels/voice_stream/call_state.py +3 -9
- rasa/core/channels/voice_stream/genesys.py +40 -55
- rasa/core/channels/voice_stream/voice_channel.py +61 -39
- rasa/core/tracker_store.py +123 -34
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -0
- rasa/dialogue_understanding/commands/utils.py +1 -4
- rasa/dialogue_understanding/generator/command_parser.py +41 -0
- rasa/dialogue_understanding/generator/constants.py +7 -2
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +9 -2
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +29 -48
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_fallback_other_models_template.jinja2 +57 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +23 -50
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +76 -24
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +32 -18
- rasa/dialogue_understanding/processor/command_processor.py +39 -19
- rasa/dialogue_understanding/stack/utils.py +11 -6
- rasa/engine/language.py +67 -25
- rasa/llm_fine_tuning/conversations.py +3 -31
- rasa/llm_fine_tuning/llm_data_preparation_module.py +5 -3
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +18 -13
- rasa/llm_fine_tuning/paraphrasing_module.py +6 -2
- rasa/llm_fine_tuning/train_test_split_module.py +27 -27
- rasa/llm_fine_tuning/utils.py +7 -0
- rasa/shared/constants.py +4 -0
- rasa/shared/core/domain.py +2 -0
- rasa/shared/providers/_configs/azure_entra_id_config.py +8 -8
- rasa/shared/providers/llm/litellm_router_llm_client.py +1 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +38 -7
- rasa/shared/utils/llm.py +69 -13
- rasa/telemetry.py +13 -3
- rasa/tracing/instrumentation/attribute_extractors.py +2 -5
- rasa/validator.py +2 -2
- rasa/version.py +1 -1
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/METADATA +1 -1
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/RECORD +42 -41
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +0 -68
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/entry_points.txt +0 -0
rasa/llm_fine_tuning/llm_data_preparation_module.py
CHANGED

@@ -4,8 +4,10 @@ from typing import Any, Dict, List, Optional
 import structlog
 from tqdm import tqdm
 
+from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.storage import StorageContext
+from rasa.llm_fine_tuning.utils import commands_as_string
 
 LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION = "3_llm_finetune_data/llm_ft_data.jsonl"
 
@@ -15,7 +17,7 @@ structlogger = structlog.get_logger()
 @dataclass
 class LLMDataExample:
     prompt: str
-    output:
+    output: List[PromptCommand]
     original_test_name: str
     original_user_utterance: str
     rephrased_user_utterance: str
@@ -23,7 +25,7 @@ class LLMDataExample:
     def as_dict(self) -> Dict[str, Any]:
         return {
             "prompt": self.prompt,
-            "output": self.output,
+            "output": commands_as_string(self.output),
             "original_test_name": self.original_test_name,
             "original_user_utterance": self.original_user_utterance,
             "rephrased_user_utterance": self.rephrased_user_utterance,
@@ -38,7 +40,7 @@ def _create_data_point(
 ) -> LLMDataExample:
     return LLMDataExample(
         prompt,
-        step.
+        step.llm_commands,
         conversation.get_full_name(),
         step.original_test_step.text,
         rephrased_user_message,
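The change above switches LLMDataExample.output from a pre-serialized string to a list of PromptCommand objects that are only rendered to text when the example is written out. A minimal, self-contained sketch of that pattern (the classes and the commands_as_string helper below are simplified stand-ins, not the actual Rasa implementations):

from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class StartFlowCommand:  # hypothetical stand-in for a PromptCommand subclass
    flow: str

    def to_dsl(self) -> str:
        return f"StartFlow({self.flow})"


def commands_as_string(commands: List[Any]) -> str:
    # Assumption: each command knows how to render itself for the prompt DSL.
    return ", ".join(command.to_dsl() for command in commands)


@dataclass
class DataExample:
    prompt: str
    output: List[Any]  # previously a plain string

    def as_dict(self) -> Dict[str, Any]:
        # Serialization happens only at the boundary, as in the diff above.
        return {"prompt": self.prompt, "output": commands_as_string(self.output)}


print(DataExample("book me a trip", [StartFlowCommand("book_trip")]).as_dict())
# {'prompt': 'book me a trip', 'output': 'StartFlow(book_trip)'}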
rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py
CHANGED

@@ -1,18 +1,18 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Type
 
 import structlog
 
 from rasa.dialogue_understanding.commands import Command, SetSlotCommand
-from rasa.dialogue_understanding.generator import
+from rasa.dialogue_understanding.generator.llm_based_command_generator import (
+    LLMBasedCommandGenerator,
+)
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.paraphrasing.rephrased_user_message import (
     RephrasedUserMessage,
 )
 from rasa.shared.core.flows import FlowsList
 from rasa.shared.exceptions import ProviderClientAPIException
-from rasa.shared.utils.llm import
-    llm_factory,
-)
+from rasa.shared.utils.llm import llm_factory
 
 structlogger = structlog.get_logger()
 
@@ -26,6 +26,7 @@ class RephraseValidator:
         self,
         rephrasings: List[RephrasedUserMessage],
         conversation: Conversation,
+        llm_command_generator: Type[LLMBasedCommandGenerator],
     ) -> List[RephrasedUserMessage]:
         """Split rephrased user messages into passing and failing.
 
@@ -38,6 +39,7 @@ class RephraseValidator:
         Args:
             rephrasings: The rephrased user messages.
             conversation: The conversation.
+            llm_command_generator: A LLM based command generator class.
 
         Returns:
             A list of rephrased user messages including the passing and failing
@@ -49,7 +51,9 @@ class RephraseValidator:
             current_rephrasings = rephrasings[i]
 
             for rephrase in current_rephrasings.rephrasings:
-                if await self._validate_rephrase_is_passing(
+                if await self._validate_rephrase_is_passing(
+                    rephrase, step, llm_command_generator
+                ):
                     current_rephrasings.passed_rephrasings.append(rephrase)
                 else:
                     current_rephrasings.failed_rephrasings.append(rephrase)
@@ -60,25 +64,26 @@ class RephraseValidator:
         self,
         rephrase: str,
         step: ConversationStep,
+        llm_command_generator: Type[LLMBasedCommandGenerator],
     ) -> bool:
         prompt = self._update_prompt(
             rephrase, step.original_test_step.text, step.llm_prompt
         )
 
-        action_list = await self._invoke_llm(
+        action_list = await self._invoke_llm(
+            prompt, llm_command_generator.get_default_llm_config()
+        )
 
         commands_from_original_utterance = step.llm_commands
-        commands_from_rephrased_utterance = (
-
+        commands_from_rephrased_utterance = llm_command_generator.parse_commands(  # type: ignore
+            action_list, None, self.flows
         )
         return self._check_commands_match(
             commands_from_original_utterance, commands_from_rephrased_utterance
         )
 
-    async def _invoke_llm(self, prompt: str) -> str:
-
-
-        llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
+    async def _invoke_llm(self, prompt: str, default_llm_config: Dict[str, Any]) -> str:
+        llm = llm_factory(self.llm_config, default_llm_config)
 
         try:
             llm_response = await llm.acompletion(prompt)
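The validator now receives the command generator class itself, pulls the LLM defaults from its get_default_llm_config() classmethod, and re-parses the LLM output with the generator's own parse_commands before comparing against the commands from the original utterance. A hedged, self-contained sketch of that check (helper names and the string-based comparison below are simplifications, not the Rasa signatures):

import asyncio
from typing import Awaitable, Callable, List


async def rephrase_is_passing(
    original_commands: List[str],
    rephrased_prompt: str,
    invoke_llm: Callable[[str], Awaitable[str]],
    parse_commands: Callable[[str], List[str]],
) -> bool:
    # Re-run the rephrased prompt through the same LLM + parser pipeline
    # and require the resulting command list to match the original one.
    llm_output = await invoke_llm(rephrased_prompt)
    rephrased_commands = parse_commands(llm_output)
    return sorted(original_commands) == sorted(rephrased_commands)


async def _demo() -> None:
    async def fake_llm(prompt: str) -> str:  # stand-in for llm.acompletion
        return "StartFlow(book_trip)"

    def parse(text: str) -> List[str]:
        return [line.strip() for line in text.splitlines() if line.strip()]

    result = await rephrase_is_passing(
        ["StartFlow(book_trip)"], "I want to book a journey", fake_llm, parse
    )
    print(result)  # True


asyncio.run(_demo())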
rasa/llm_fine_tuning/paraphrasing_module.py
CHANGED

@@ -1,8 +1,11 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Type
 
 import structlog
 from tqdm import tqdm
 
+from rasa.dialogue_understanding.generator.llm_based_command_generator import (
+    LLMBasedCommandGenerator,
+)
 from rasa.llm_fine_tuning.conversations import Conversation
 from rasa.llm_fine_tuning.paraphrasing.conversation_rephraser import (
     ConversationRephraser,
@@ -25,6 +28,7 @@ async def create_paraphrased_conversations(
     rephrase_config: Dict[str, Any],
     num_rephrases: int,
     flows: FlowsList,
+    llm_command_generator: Type[LLMBasedCommandGenerator],
     llm_command_generator_config: Dict[str, Any],
     storage_context: StorageContext,
 ) -> Tuple[List[Conversation], Dict[str, Any]]:
@@ -71,7 +75,7 @@ async def create_paraphrased_conversations(
         rephrasings = _filter_rephrasings(rephrasings, conversations[i])
         # check if the rephrasings are still producing the same commands
         rephrasings = await validator.validate_rephrasings(
-            rephrasings, current_conversation
+            rephrasings, current_conversation, llm_command_generator
        )
     except ProviderClientAPIException as e:
         structlogger.error(
rasa/llm_fine_tuning/train_test_split_module.py
CHANGED

@@ -1,27 +1,18 @@
 import random
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Dict, List, Protocol, Set, Tuple
+from typing import Any, Dict, List, Protocol, Set, Tuple, Type
 
 import structlog
 
+from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
 from rasa.e2e_test.e2e_test_case import TestSuite
 from rasa.llm_fine_tuning.llm_data_preparation_module import LLMDataExample
 from rasa.llm_fine_tuning.storage import StorageContext
+from rasa.llm_fine_tuning.utils import commands_as_string
 
 TRAIN_TEST_MODULE_STORAGE_LOCATION = "4_train_test_split"
 
-SUPPORTED_COMMANDS = [
-    "SetSlot",
-    "StartFlow",
-    "CancelFlow",
-    "ChitChat",
-    "SkipQuestion",
-    "SearchAndReply",
-    "HumanHandoff",
-    "Clarify",
-]
-
 INSTRUCTION_DATA_FORMAT = "instruction"
 CONVERSATIONAL_DATA_FORMAT = "conversational"
 
@@ -77,17 +68,19 @@ class ConversationalDataFormat(DataExampleFormat):
         }
 
 
-def _get_command_types_covered_by_llm_data_point(
+def _get_command_types_covered_by_llm_data_point(
+    data_point: LLMDataExample,
+) -> Set[Type[PromptCommand]]:
     """Get the command types covered by the LLM data point.
 
     This function returns the set of command types from the output present in a
-    LLMDataExample object. Eg: The function returns {'
-    LLMDataExample.output is '
+    LLMDataExample object. Eg: The function returns {'SetSlotCommand',
+    'StartFlowCommand'} when the LLMDataExample.output is 'SetSlotCommand(slot, abc),
+    SetSlotCommand(slot, cde), StartFlowCommand(xyz)'.
     """
     commands_covered = set()
-    for command in
-
-        commands_covered.add(command)
+    for command in data_point.output:
+        commands_covered.add(command.__class__)
     return commands_covered
 
 
@@ -146,14 +139,18 @@ def _get_minimum_test_case_groups_to_cover_all_commands(
     {
         "test_case_name": "t1",
         "data_examples": [],
-        "commands": {"
+        "commands": {"SetSlotCommand", "CancelFlowCommand"}
     },
-    {
-
+    {
+        "test_case_name": "t2",
+        "data_examples": [],
+        "commands": {"CancelFlowCommand"}
+    },
+    {"test_case_name": "t3", "data_examples": [], "commands": {"StartFlowCommand"}},
     {
         "test_case_name": "t4",
         "data_examples": [],
-        "commands": {"
+        "commands": {"SetSlotCommand", "StartFlowCommand"}
     },
     ]
 
@@ -166,7 +163,7 @@ def _get_minimum_test_case_groups_to_cover_all_commands(
         command for test_group in grouped_data for command in test_group[KEY_COMMANDS]
     )
     selected_test_cases = []
-    covered_commands: Set[
+    covered_commands: Set[Type[PromptCommand]] = set()
 
     while covered_commands != all_commands:
         # Find the test case group that covers the most number of uncovered commands
@@ -187,7 +184,7 @@ def _get_minimum_test_case_groups_to_cover_all_commands(
 
     structlogger.info(
         "llm_fine_tuning.train_test_split_module.command_coverage_in_train_dataset",
-        covered_commands=covered_commands,
+        covered_commands=[command.__name__ for command in covered_commands],
     )
     return selected_test_cases
 
@@ -205,7 +202,10 @@ def _get_finetuning_data_in_instruction_data_format(
     data: List[Dict[str, Any]],
 ) -> List[DataExampleFormat]:
     return [
-        InstructionDataFormat(
+        InstructionDataFormat(
+            llm_data_example.prompt,
+            commands_as_string(llm_data_example.output),
+        )
         for test_group in data
         for llm_data_example in test_group[KEY_DATA_EXAMPLES]
     ]
@@ -232,7 +232,7 @@ def _get_finetuning_data_in_conversational_data_format(
         [
             ConversationalMessageDataFormat("user", llm_data_example.prompt),
             ConversationalMessageDataFormat(
-                "assistant", llm_data_example.output
+                "assistant", commands_as_string(llm_data_example.output)
             ),
         ]
     )
@@ -271,7 +271,7 @@ def _check_and_log_missing_validation_dataset_command_coverage(
     structlogger.warning(
         "llm_fine_tuning.train_test_split_module.missing_commands_in_validation_dat"
        "aset",
-        missing_commands=missing_commands,
+        missing_commands=[command.__name__ for command in missing_commands],
     )
 
 
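The removal of the hard-coded SUPPORTED_COMMANDS list goes together with command-type coverage now being derived from the command classes themselves, while the train split is still assembled greedily so that every observed command type stays covered. A small standalone sketch of that greedy selection, reusing the t1-t4 example from the docstring above (plain strings stand in for the command classes):

from typing import Any, Dict, List, Set

grouped_data: List[Dict[str, Any]] = [
    {"test_case_name": "t1", "commands": {"SetSlotCommand", "CancelFlowCommand"}},
    {"test_case_name": "t2", "commands": {"CancelFlowCommand"}},
    {"test_case_name": "t3", "commands": {"StartFlowCommand"}},
    {"test_case_name": "t4", "commands": {"SetSlotCommand", "StartFlowCommand"}},
]

all_commands: Set[str] = set().union(*(group["commands"] for group in grouped_data))
covered: Set[str] = set()
selected: List[str] = []

while covered != all_commands:
    # Pick the group that adds the largest number of still-uncovered commands.
    best = max(grouped_data, key=lambda group: len(group["commands"] - covered))
    selected.append(best["test_case_name"])
    covered |= best["commands"]

print(selected)  # ['t1', 't3'] -- two groups already cover all three command types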
rasa/shared/constants.py
CHANGED

@@ -194,6 +194,9 @@ PROVIDER_CONFIG_KEY = "provider"
 REQUEST_TIMEOUT_CONFIG_KEY = "request_timeout"  # deprecated
 TIMEOUT_CONFIG_KEY = "timeout"
 
+TEMPERATURE_CONFIG_KEY = "temperature"
+MAX_TOKENS_CONFIG_KEY = "max_tokens"
+
 DEPLOYMENT_NAME_CONFIG_KEY = "deployment_name"
 DEPLOYMENT_CONFIG_KEY = "deployment"
 EMBEDDINGS_CONFIG_KEY = "embeddings"
@@ -264,6 +267,7 @@ LITELLM_SSL_CERTIFICATE_ENV_VAR = "SSL_CERTIFICATE"
 
 OPENAI_PROVIDER = "openai"
 AZURE_OPENAI_PROVIDER = "azure"
+ANTHROPIC_PROVIDER = "anthropic"
 SELF_HOSTED_PROVIDER = "self-hosted"
 HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER = "huggingface_local"
 RASA_PROVIDER = "rasa"
rasa/shared/core/domain.py
CHANGED

@@ -52,6 +52,7 @@ from rasa.shared.core.constants import (
     SlotMappingType,
 )
 from rasa.shared.core.events import SlotSet, UserUttered
+from rasa.shared.core.flows.constants import KEY_TRANSLATION
 from rasa.shared.core.slots import (
     AnySlot,
     CategoricalSlot,
@@ -117,6 +118,7 @@ RESPONSE_KEYS_TO_INTERPOLATE = [
     KEY_RESPONSES_BUTTONS,
     KEY_RESPONSES_ATTACHMENT,
     KEY_RESPONSES_QUICK_REPLIES,
+    KEY_TRANSLATION,
 ]
 
 ALL_DOMAIN_KEYS = [
rasa/shared/providers/_configs/azure_entra_id_config.py
CHANGED

@@ -8,7 +8,7 @@ from functools import lru_cache
 from typing import Any, Callable, Dict, List, Optional, Set, Type
 
 import structlog
-from azure.core.credentials import
+from azure.core.credentials import TokenCredential
 from azure.identity import (
     CertificateCredential,
     ClientSecretCredential,
@@ -77,7 +77,7 @@ class AzureEntraIDTokenProviderConfig(abc.ABC):
     """Interface for Azure Entra ID OAuth credential configuration."""
 
     @abc.abstractmethod
-    def create_azure_token_provider(self) ->
+    def create_azure_token_provider(self) -> TokenCredential:
         """Create an Azure Entra ID token provider."""
         ...
 
@@ -159,7 +159,7 @@ class AzureEntraIDClientCredentialsConfig(AzureEntraIDTokenProviderConfig, BaseM
         ),
     )
 
-    def create_azure_token_provider(self) ->
+    def create_azure_token_provider(self) -> TokenCredential:
         """Create a ClientSecretCredential for Azure Entra ID."""
         return create_azure_entra_id_client_credentials(
             client_id=self.client_id,
@@ -286,7 +286,7 @@ class AzureEntraIDClientCertificateConfig(AzureEntraIDTokenProviderConfig, BaseM
         ),
     )
 
-    def create_azure_token_provider(self) ->
+    def create_azure_token_provider(self) -> TokenCredential:
         """Creates a CertificateCredential for Azure Entra ID."""
         return create_azure_entra_id_certificate_credentials(
             client_id=self.client_id,
@@ -369,7 +369,7 @@ class AzureEntraIDDefaultCredentialsConfig(AzureEntraIDTokenProviderConfig, Base
         """
         return cls(authority_host=config.pop(AZURE_AUTHORITY_FIELD, None))
 
-    def create_azure_token_provider(self) ->
+    def create_azure_token_provider(self) -> TokenCredential:
         """Creates a DefaultAzureCredential."""
         return create_azure_entra_id_default_credentials(
             authority_host=self.authority_host
@@ -530,12 +530,12 @@ class AzureEntraIDOAuthConfig(OAuth, BaseModel):
         azure_oauth_class = AzureEntraIDOAuthConfig._get_azure_oauth_by_type(oauth_type)
         return azure_oauth_class.from_dict(oauth_config)
 
-    def
+    def create_azure_credential(
         self,
-    ) ->
+    ) -> TokenCredential:
         """Create an Azure Entra ID client which can be used to get a bearer token."""
         return self.azure_entra_id_token_provider_config.create_azure_token_provider()
 
     def get_bearer_token(self) -> str:
         """Returns a bearer token."""
-        return self.
+        return self.create_azure_credential().get_token(*self.scopes).token
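The credential factories are now annotated to return azure.core.credentials.TokenCredential, and get_bearer_token() simply asks that credential for a token over the configured scopes. For illustration, a minimal azure-identity snippet showing that pattern (placeholder tenant/client values; a real Entra ID app registration is needed for get_token to succeed):

from azure.identity import ClientSecretCredential

# Any azure.identity credential implements the TokenCredential protocol.
credential = ClientSecretCredential(
    tenant_id="<tenant-id>",
    client_id="<client-id>",
    client_secret="<client-secret>",
)

# get_token returns an AccessToken; .token is the bearer string sent with requests.
scope = "https://cognitiveservices.azure.com/.default"  # example scope
bearer_token = credential.get_token(scope).token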
rasa/shared/providers/llm/litellm_router_llm_client.py
CHANGED

@@ -198,6 +198,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
         """Returns the completion arguments for invoking a call through
         LiteLLM's completion functions.
         """
+
         return {
             **self._litellm_extra_parameters,
             LITE_LLM_MODEL_FIELD: self.model_group_id,
rasa/shared/providers/router/_base_litellm_router_client.py
CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+from copy import deepcopy
 from typing import Any, Dict, List
 
 import structlog
@@ -18,6 +19,7 @@ from rasa.shared.constants import (
     USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
 )
 from rasa.shared.exceptions import ProviderClientValidationError
+from rasa.shared.providers._configs.azure_entra_id_config import AzureEntraIDOAuthConfig
 from rasa.shared.providers._configs.litellm_router_client_config import (
     LiteLLMRouterClientConfig,
 )
@@ -61,12 +63,8 @@ class _BaseLiteLLMRouterClient:
         self._extra_parameters = kwargs or {}
         self.additional_client_setup()
         try:
-
-
-            )
-            self._router_client = Router(
-                model_list=resolved_model_configurations, **router_settings
-            )
+            # We instantiate a router client here to validate the configuration.
+            self._router_client = self._create_router_client()
         except Exception as e:
             event_info = "Cannot instantiate a router client."
             structlogger.error(
@@ -145,6 +143,14 @@ class _BaseLiteLLMRouterClient:
     @property
     def router_client(self) -> Router:
         """Returns the instantiated LiteLLM Router client."""
+        # In case oauth is used, due to a bug in LiteLLM,
+        # azure_ad_token_provider is not working as expected.
+        # To work around this, we create a new client every
+        # time we need to make a call which will
+        # ensure that the token is always fresh.
+        # GitHub issue for LiteLLM: https://github.com/BerriAI/litellm/issues/4417
+        if self._has_oauth():
+            return self._create_router_client()
         return self._router_client
 
     @property
@@ -175,11 +181,36 @@ class _BaseLiteLLMRouterClient:
             **self._litellm_extra_parameters,
         }
 
+    def _create_router_client(self) -> Router:
+        resolved_model_configurations = self._resolve_env_vars_in_model_configurations()
+        return Router(model_list=resolved_model_configurations, **self.router_settings)
+
+    def _has_oauth(self) -> bool:
+        for model_configuration in self.model_configurations:
+            if model_configuration.get("litellm_params", {}).get("oauth", None):
+                return True
+        return False
+
     def _resolve_env_vars_in_model_configurations(self) -> List:
         model_configuration_with_resolved_keys = []
         for model_configuration in self.model_configurations:
             resolved_model_configuration = resolve_environment_variables(
-                model_configuration
+                deepcopy(model_configuration)
             )
+
+            if not isinstance(resolved_model_configuration, dict):
+                continue
+
+            lite_llm_params = resolved_model_configuration.get("litellm_params", {})
+            if lite_llm_params.get("oauth", None):
+                oauth_config_dict = lite_llm_params.pop("oauth")
+                oauth_config = AzureEntraIDOAuthConfig.from_dict(oauth_config_dict)
+                credential = oauth_config.create_azure_credential()
+                # token_provider = get_bearer_token_provider(
+                #     credential, *oauth_config.scopes
+                # )
+                resolved_model_configuration["litellm_params"]["azure_ad_token"] = (
+                    credential.get_token(*oauth_config.scopes).token
+                )
             model_configuration_with_resolved_keys.append(resolved_model_configuration)
         return model_configuration_with_resolved_keys
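The router client now injects a short-lived azure_ad_token into each model configuration and, whenever OAuth is configured, rebuilds the LiteLLM Router on every access so that token can never go stale. A hedged sketch of that shape (hypothetical wrapper, not the Rasa class):

from typing import Any, Callable, Dict, List


class RouterWrapper:
    def __init__(
        self,
        model_configurations: List[Dict[str, Any]],
        build_client: Callable[[List[Dict[str, Any]]], Any],
        fetch_token: Callable[[], str],
    ) -> None:
        self._model_configurations = model_configurations
        self._build_client = build_client
        self._fetch_token = fetch_token
        # Built once up front, mainly to validate the configuration.
        self._client = self._create_client()

    def _has_oauth(self) -> bool:
        return any(
            cfg.get("litellm_params", {}).get("oauth")
            for cfg in self._model_configurations
        )

    def _create_client(self) -> Any:
        resolved = []
        for cfg in self._model_configurations:
            cfg = {**cfg, "litellm_params": dict(cfg.get("litellm_params", {}))}
            if cfg["litellm_params"].pop("oauth", None):
                # Inject a freshly fetched bearer token for this client instance.
                cfg["litellm_params"]["azure_ad_token"] = self._fetch_token()
            resolved.append(cfg)
        return self._build_client(resolved)

    @property
    def client(self) -> Any:
        # Rebuild on every access when OAuth is in play; reuse otherwise.
        return self._create_client() if self._has_oauth() else self._client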
rasa/shared/utils/llm.py
CHANGED

@@ -667,38 +667,94 @@ def get_prompt_template(
     """
     try:
         if jinja_file_path is not None:
-
+            prompt_template = rasa.shared.utils.io.read_file(jinja_file_path)
+            structlogger.info(
+                "utils.llm.get_prompt_template.custom_prompt_template_read_successfull",
+                event_info=(
+                    f"Custom prompt template read successfully from "
+                    f"`{jinja_file_path}`."
+                ),
+                prompt_file_path=jinja_file_path,
+            )
+            return prompt_template
     except (FileIOException, FileNotFoundException):
         structlogger.warning(
-            "
-
+            "utils.llm.get_prompt_template.failed_to_read_custom_prompt_template",
+            event_info=(
+                "Failed to read custom prompt template. Using default template instead."
+            ),
         )
     return default_prompt_template
 
 
 def get_default_prompt_template_based_on_model(
-
+    llm_config: Dict[str, Any],
     model_prompt_mapping: Dict[str, Any],
+    default_prompt_path: str,
     fallback_prompt_path: str,
 ) -> Text:
     """Returns the default prompt template based on the model name.
 
     Args:
-
+        llm_config: The model config.
         model_prompt_mapping: The mapping of model name to prompt template.
-
+        default_prompt_path: The default prompt path of the component.
+        fallback_prompt_path: The fallback prompt path for all other models
+            that do not have a mapping in the model_prompt_mapping.
 
     Returns:
         The default prompt template.
     """
-
-    if MODELS_CONFIG_KEY in
-
-    provider =
-    model =
+    _llm_config = deepcopy(llm_config)
+    if MODELS_CONFIG_KEY in _llm_config:
+        _llm_config = _llm_config[MODELS_CONFIG_KEY][0]
+    provider = _llm_config.get(PROVIDER_CONFIG_KEY)
+    model = _llm_config.get(MODEL_CONFIG_KEY)
+    if not model:
+        # If the model is not defined, we default to the default prompt template.
+        structlogger.info(
+            "utils.llm.get_default_prompt_template_based_on_model.using_default_prompt_template",
+            event_info=(
+                f"Model not defined in the config. Default prompt template read from"
+                f" - `{default_prompt_path}`."
+            ),
+            default_prompt_path=default_prompt_path,
+        )
+        return importlib.resources.read_text(
+            DEFAULT_PROMPT_PACKAGE_NAME, default_prompt_path
+        )
+
     model_name = model if provider and provider in model else f"{provider}/{model}"
-    prompt_file_path
-
+    if prompt_file_path := model_prompt_mapping.get(model_name):
+        # If the model is found in the mapping, we use the model-specific prompt
+        # template.
+        structlogger.info(
+            "utils.llm.get_default_prompt_template_based_on_model.using_model_specific_prompt_template",
+            event_info=(
+                f"Using model-specific default prompt template. Default prompt "
+                f"template read from - `{prompt_file_path}`."
+            ),
+            default_prompt_path=prompt_file_path,
+            model_name=model_name,
+        )
+        return importlib.resources.read_text(
+            DEFAULT_PROMPT_PACKAGE_NAME, prompt_file_path
+        )
+
+    # If the model is not found in the mapping, we default to the fallback prompt
+    # template.
+    structlogger.info(
+        "utils.llm.get_default_prompt_template_based_on_model.using_fallback_prompt_template",
+        event_info=(
+            f"Model not found in the model prompt mapping. Fallback prompt template "
+            f"read from - `{fallback_prompt_path}`."
+        ),
+        fallback_prompt_path=fallback_prompt_path,
+        model_name=model_name,
+    )
+    return importlib.resources.read_text(
+        DEFAULT_PROMPT_PACKAGE_NAME, fallback_prompt_path
    )
 
 
 def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
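get_default_prompt_template_based_on_model now resolves a template in three steps: no model configured, use the component's default template; model present in the mapping, use the model-specific template; anything else, use the fallback template. A compact sketch of that selection order (the mapping keys and the default path below are illustrative, not the component's actual constants):

from typing import Any, Dict, Optional

MODEL_PROMPT_MAPPING = {
    "openai/gpt-4o-2024-11-20": "command_prompt_v2_gpt_4o_2024_11_20_template.jinja2",
    "anthropic/claude-3-5-sonnet-20240620": "command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2",
}

DEFAULT_PROMPT = "component_default_template.jinja2"  # hypothetical path for illustration
FALLBACK_PROMPT = "command_prompt_v2_fallback_other_models_template.jinja2"


def pick_prompt_path(llm_config: Dict[str, Any]) -> str:
    config = llm_config["models"][0] if "models" in llm_config else llm_config
    provider: Optional[str] = config.get("provider")
    model: Optional[str] = config.get("model")
    if not model:
        return DEFAULT_PROMPT  # no model configured -> component default
    model_name = model if provider and provider in model else f"{provider}/{model}"
    # model-specific template if mapped, otherwise the fallback template
    return MODEL_PROMPT_MAPPING.get(model_name, FALLBACK_PROMPT)


print(pick_prompt_path({"provider": "openai", "model": "gpt-4o-2024-11-20"}))
print(pick_prompt_path({}))                                             # default
print(pick_prompt_path({"provider": "cohere", "model": "command-r"}))   # fallback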
rasa/telemetry.py
CHANGED

@@ -15,7 +15,7 @@ from collections import defaultdict
 from datetime import datetime
 from functools import wraps
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Text, Tuple
+from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Type, cast
 
 import importlib_resources
 import requests
@@ -1126,12 +1126,12 @@ def _get_llm_command_generator_config(config: Dict[str, Any]) -> Optional[Dict]:
     """
     from rasa.dialogue_understanding.generator import (
         CompactLLMCommandGenerator,
+        LLMBasedCommandGenerator,
         LLMCommandGenerator,
         MultiStepLLMCommandGenerator,
         SingleStepLLMCommandGenerator,
     )
     from rasa.dialogue_understanding.generator.constants import (
-        DEFAULT_LLM_CONFIG,
         FLOW_RETRIEVAL_KEY,
         LLM_CONFIG_KEY,
     )
@@ -1162,6 +1162,12 @@ def _get_llm_command_generator_config(config: Dict[str, Any]) -> Optional[Dict]:
 
     def extract_llm_command_generator_llm_client_settings(component: Dict) -> Dict:
         """Extracts settings related to LLM command generator."""
+        component_class_lookup = {
+            LLMCommandGenerator.__name__: LLMCommandGenerator,
+            SingleStepLLMCommandGenerator.__name__: SingleStepLLMCommandGenerator,
+            MultiStepLLMCommandGenerator.__name__: MultiStepLLMCommandGenerator,
+            CompactLLMCommandGenerator.__name__: CompactLLMCommandGenerator,
+        }
         llm_config = component.get(LLM_CONFIG_KEY, {})
         # Config at this stage is not yet resolved, so read from `model_group`
         llm_model_group_id = llm_config.get(MODEL_GROUP_CONFIG_KEY)
@@ -1169,7 +1175,11 @@ def _get_llm_command_generator_config(config: Dict[str, Any]) -> Optional[Dict]:
             MODEL_NAME_CONFIG_KEY
         )
         if llm_model_group_id is None and llm_model_name is None:
-
+            component_clz = cast(
+                Type[LLMBasedCommandGenerator],
+                component_class_lookup[component["name"]],
+            )
+            llm_model_name = component_clz.get_default_llm_config()[MODEL_CONFIG_KEY]
 
         custom_prompt_used = (
             PROMPT_CONFIG_KEY in component or PROMPT_TEMPLATE_CONFIG_KEY in component
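With the shared DEFAULT_LLM_CONFIG import gone, telemetry falls back to the generator class named in the component config and reads the reported model from that class's get_default_llm_config() classmethod. A small illustrative sketch of that lookup (stand-in class, placeholder config keys and default values):

from typing import Any, Dict, Optional, Type


class CompactLLMCommandGenerator:  # stand-in for the real Rasa component
    @classmethod
    def get_default_llm_config(cls) -> Dict[str, Any]:
        # Placeholder values, not the component's real defaults.
        return {"provider": "openai", "model": "example-default-model"}


COMPONENT_CLASS_LOOKUP: Dict[str, Type[Any]] = {
    CompactLLMCommandGenerator.__name__: CompactLLMCommandGenerator,
}


def model_name_for_telemetry(component: Dict[str, Any]) -> Optional[str]:
    llm_config = component.get("llm", {})
    model_name = llm_config.get("model") or llm_config.get("model_name")
    if llm_config.get("model_group") is None and model_name is None:
        # Resolve the component class by its configured name and use its defaults.
        clz = COMPONENT_CLASS_LOOKUP[component["name"]]
        model_name = clz.get_default_llm_config()["model"]
    return model_name


print(model_name_for_telemetry({"name": "CompactLLMCommandGenerator"}))
# example-default-model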
rasa/tracing/instrumentation/attribute_extractors.py
CHANGED

@@ -58,9 +58,7 @@ from rasa.shared.core.trackers import DialogueStateTracker
 from rasa.shared.core.training_data.structures import StoryGraph
 from rasa.shared.importers.importer import TrainingDataImporter
 from rasa.shared.nlu.constants import INTENT_NAME_KEY, SET_SLOT_COMMAND
-from rasa.shared.utils.llm import
-    combine_custom_and_default_config,
-)
+from rasa.shared.utils.llm import combine_custom_and_default_config
 from rasa.tracing.constants import (
     PROMPT_TOKEN_LENGTH_ATTRIBUTE_NAME,
     REQUEST_BODY_SIZE_IN_BYTES_ATTRIBUTE_NAME,
@@ -375,14 +373,13 @@ def extract_attrs_for_llm_based_command_generator(
     self: "LLMBasedCommandGenerator",
     prompt: str,
 ) -> Dict[str, Any]:
-    from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
     from rasa.dialogue_understanding.generator.flow_retrieval import (
         DEFAULT_EMBEDDINGS_CONFIG,
     )
 
     attributes = extract_llm_config(
         self,
-        default_llm_config=
+        default_llm_config=self.get_default_llm_config(),
         default_embeddings_config=DEFAULT_EMBEDDINGS_CONFIG,
     )