rasa-pro 3.12.13__py3-none-any.whl → 3.12.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/llm_fine_tuning.py +11 -10
- rasa/core/nlg/contextual_response_rephraser.py +38 -11
- rasa/core/nlg/summarize.py +39 -5
- rasa/core/policies/enterprise_search_policy.py +7 -4
- rasa/core/policies/intentless_policy.py +15 -9
- rasa/core/utils.py +4 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -3
- rasa/dialogue_understanding/commands/clarify_command.py +1 -1
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -1
- rasa/dialogue_understanding/generator/constants.py +2 -2
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +2 -2
- rasa/dialogue_understanding/processor/command_processor_component.py +2 -2
- rasa/dialogue_understanding_test/du_test_runner.py +3 -21
- rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +2 -6
- rasa/engine/recipes/default_recipe.py +26 -2
- rasa/llm_fine_tuning/annotation_module.py +39 -9
- rasa/llm_fine_tuning/conversations.py +3 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +66 -49
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +4 -2
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +52 -44
- rasa/llm_fine_tuning/paraphrasing_module.py +10 -12
- rasa/llm_fine_tuning/storage.py +4 -4
- rasa/llm_fine_tuning/utils.py +63 -1
- rasa/shared/constants.py +3 -0
- rasa/shared/exceptions.py +4 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -0
- rasa/shared/providers/_configs/openai_client_config.py +4 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +3 -0
- rasa/shared/providers/llm/_base_litellm_client.py +5 -2
- rasa/shared/utils/llm.py +36 -3
- rasa/version.py +1 -1
- {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.15.dist-info}/METADATA +1 -1
- {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.15.dist-info}/RECORD +36 -36
- {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.15.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.15.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.15.dist-info}/entry_points.txt +0 -0
|
@@ -10,7 +10,9 @@ from rasa.e2e_test.e2e_test_runner import TEST_TURNS_TYPE, E2ETestRunner
|
|
|
10
10
|
from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
|
|
11
11
|
from rasa.llm_fine_tuning.storage import StorageContext
|
|
12
12
|
from rasa.shared.core.constants import USER
|
|
13
|
+
from rasa.shared.core.events import UserUttered
|
|
13
14
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
15
|
+
from rasa.shared.exceptions import FinetuningDataPreparationException
|
|
14
16
|
from rasa.shared.nlu.constants import LLM_COMMANDS, LLM_PROMPT
|
|
15
17
|
from rasa.shared.utils.llm import tracker_as_readable_transcript
|
|
16
18
|
|
|
@@ -37,7 +39,7 @@ def annotate_e2e_tests(
|
|
|
37
39
|
storage_context: StorageContext,
|
|
38
40
|
) -> List[Conversation]:
|
|
39
41
|
with set_preparing_fine_tuning_data():
|
|
40
|
-
|
|
42
|
+
conversations = asyncio.run(
|
|
41
43
|
e2e_test_runner.run_tests_for_fine_tuning(
|
|
42
44
|
test_suite.test_cases,
|
|
43
45
|
test_suite.fixtures,
|
|
@@ -46,10 +48,11 @@ def annotate_e2e_tests(
|
|
|
46
48
|
)
|
|
47
49
|
|
|
48
50
|
storage_context.write_conversations(
|
|
49
|
-
|
|
51
|
+
conversations,
|
|
52
|
+
ANNOTATION_MODULE_STORAGE_LOCATION,
|
|
50
53
|
)
|
|
51
54
|
|
|
52
|
-
return
|
|
55
|
+
return conversations
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
def _get_previous_actual_step_output(
|
|
@@ -80,25 +83,45 @@ def generate_conversation(
|
|
|
80
83
|
Conversation.
|
|
81
84
|
"""
|
|
82
85
|
steps = []
|
|
86
|
+
tracker_event_indices = [
|
|
87
|
+
i for i, event in enumerate(tracker.events) if isinstance(event, UserUttered)
|
|
88
|
+
]
|
|
89
|
+
|
|
90
|
+
if len(test_case.steps) != len(tracker_event_indices):
|
|
91
|
+
raise FinetuningDataPreparationException(
|
|
92
|
+
"Number of test case steps and tracker events do not match."
|
|
93
|
+
)
|
|
83
94
|
|
|
84
95
|
if assertions_used:
|
|
85
96
|
# we only have user steps, extract the bot response from the bot uttered
|
|
86
97
|
# events of the test turn
|
|
87
|
-
for i, original_step in enumerate(
|
|
98
|
+
for i, (original_step, tracker_event_index) in enumerate(
|
|
99
|
+
zip(test_case.steps, tracker_event_indices)
|
|
100
|
+
):
|
|
88
101
|
previous_turn = _get_previous_actual_step_output(test_turns, i)
|
|
89
102
|
steps.append(
|
|
90
103
|
_convert_to_conversation_step(
|
|
91
|
-
original_step,
|
|
104
|
+
original_step,
|
|
105
|
+
test_turns[i],
|
|
106
|
+
test_case.name,
|
|
107
|
+
previous_turn,
|
|
108
|
+
tracker_event_index,
|
|
92
109
|
)
|
|
93
110
|
)
|
|
94
111
|
steps.extend(_create_bot_test_steps(test_turns[i]))
|
|
95
112
|
else:
|
|
96
|
-
for i, original_step in enumerate(
|
|
113
|
+
for i, (original_step, tracker_event_index) in enumerate(
|
|
114
|
+
zip(test_case.steps, tracker_event_indices)
|
|
115
|
+
):
|
|
97
116
|
if original_step.actor == USER:
|
|
98
117
|
previous_turn = _get_previous_actual_step_output(test_turns, i)
|
|
99
118
|
steps.append(
|
|
100
119
|
_convert_to_conversation_step(
|
|
101
|
-
original_step,
|
|
120
|
+
original_step,
|
|
121
|
+
test_turns[i],
|
|
122
|
+
test_case.name,
|
|
123
|
+
previous_turn,
|
|
124
|
+
tracker_event_index,
|
|
102
125
|
)
|
|
103
126
|
)
|
|
104
127
|
else:
|
|
@@ -120,7 +143,7 @@ def generate_conversation(
|
|
|
120
143
|
|
|
121
144
|
transcript = tracker_as_readable_transcript(tracker, max_turns=None)
|
|
122
145
|
|
|
123
|
-
return Conversation(test_case.name, test_case, steps, transcript)
|
|
146
|
+
return Conversation(test_case.name, test_case, steps, transcript, tracker)
|
|
124
147
|
|
|
125
148
|
|
|
126
149
|
def _create_bot_test_steps(current_turn: ActualStepOutput) -> List[TestStep]:
|
|
@@ -140,6 +163,7 @@ def _convert_to_conversation_step(
|
|
|
140
163
|
current_turn: ActualStepOutput,
|
|
141
164
|
test_case_name: str,
|
|
142
165
|
previous_turn: Optional[ActualStepOutput],
|
|
166
|
+
tracker_event_index: Optional[int] = None,
|
|
143
167
|
) -> Union[TestStep, ConversationStep]:
|
|
144
168
|
if not current_step.text == current_turn.text or not isinstance(
|
|
145
169
|
current_turn, ActualStepOutput
|
|
@@ -169,7 +193,13 @@ def _convert_to_conversation_step(
|
|
|
169
193
|
commands = [Command.command_from_json(data) for data in llm_commands]
|
|
170
194
|
rephrase = _should_be_rephrased(current_turn.text, previous_turn, test_case_name)
|
|
171
195
|
|
|
172
|
-
return ConversationStep(
|
|
196
|
+
return ConversationStep(
|
|
197
|
+
current_step,
|
|
198
|
+
commands,
|
|
199
|
+
llm_prompt,
|
|
200
|
+
rephrase=rephrase,
|
|
201
|
+
tracker_event_index=tracker_event_index,
|
|
202
|
+
)
|
|
173
203
|
|
|
174
204
|
|
|
175
205
|
def _should_be_rephrased(
|
|
@@ -4,6 +4,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union
|
|
|
4
4
|
from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
|
|
5
5
|
from rasa.e2e_test.e2e_test_case import TestCase, TestStep
|
|
6
6
|
from rasa.shared.core.constants import USER
|
|
7
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
@dataclass
|
|
@@ -14,6 +15,7 @@ class ConversationStep:
|
|
|
14
15
|
failed_rephrasings: List[str] = field(default_factory=list)
|
|
15
16
|
passed_rephrasings: List[str] = field(default_factory=list)
|
|
16
17
|
rephrase: bool = True
|
|
18
|
+
tracker_event_index: Optional[int] = None
|
|
17
19
|
|
|
18
20
|
def as_dict(self) -> Dict[str, Any]:
|
|
19
21
|
data = {
|
|
@@ -40,6 +42,7 @@ class Conversation:
|
|
|
40
42
|
original_e2e_test_case: TestCase
|
|
41
43
|
steps: List[Union[TestStep, ConversationStep]]
|
|
42
44
|
transcript: str
|
|
45
|
+
tracker: Optional[DialogueStateTracker] = None
|
|
43
46
|
|
|
44
47
|
def iterate_over_annotated_user_steps(
|
|
45
48
|
self, rephrase: Optional[bool] = None
|
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import Any, Dict, List, Optional
|
|
2
|
+
from typing import Any, Dict, List, Optional, cast
|
|
3
3
|
|
|
4
4
|
import structlog
|
|
5
5
|
from tqdm import tqdm
|
|
6
6
|
|
|
7
|
+
from rasa.core.agent import Agent
|
|
8
|
+
from rasa.core.channels import UserMessage
|
|
7
9
|
from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
|
|
10
|
+
from rasa.dialogue_understanding.utils import set_record_commands_and_prompts
|
|
8
11
|
from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
|
|
9
12
|
from rasa.llm_fine_tuning.storage import StorageContext
|
|
10
|
-
from rasa.llm_fine_tuning.utils import
|
|
13
|
+
from rasa.llm_fine_tuning.utils import (
|
|
14
|
+
commands_as_string,
|
|
15
|
+
make_mock_invoke_llm,
|
|
16
|
+
patch_invoke_llm_in_generators,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
19
|
+
from rasa.shared.nlu.constants import KEY_USER_PROMPT, PROMPTS
|
|
20
|
+
from rasa.shared.utils.llm import generate_sender_id
|
|
11
21
|
|
|
12
22
|
LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION = "3_llm_finetune_data/llm_ft_data.jsonl"
|
|
13
23
|
|
|
@@ -47,40 +57,8 @@ def _create_data_point(
|
|
|
47
57
|
)
|
|
48
58
|
|
|
49
59
|
|
|
50
|
-
def
|
|
51
|
-
|
|
52
|
-
original_user_steps: List[ConversationStep],
|
|
53
|
-
rephrased_user_steps: List[str],
|
|
54
|
-
) -> Optional[str]:
|
|
55
|
-
if len(original_user_steps) != len(rephrased_user_steps):
|
|
56
|
-
structlogger.debug(
|
|
57
|
-
"llm_fine_tuning.llm_data_preparation_module.failed_to_update_prompt",
|
|
58
|
-
original_user_steps=[
|
|
59
|
-
step.original_test_step.text for step in original_user_steps
|
|
60
|
-
],
|
|
61
|
-
rephrased_user_steps=rephrased_user_steps,
|
|
62
|
-
)
|
|
63
|
-
return None
|
|
64
|
-
|
|
65
|
-
updated_prompt = prompt
|
|
66
|
-
for user_step, rephrased_message in zip(original_user_steps, rephrased_user_steps):
|
|
67
|
-
# replace all occurrences of the original user message with the rephrased user
|
|
68
|
-
# message in the conversation history mentioned in the prompt
|
|
69
|
-
updated_prompt = updated_prompt.replace(
|
|
70
|
-
f"USER: {user_step.original_test_step.text}", f"USER: {rephrased_message}"
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
# replace the latest user message mentioned in the prompt
|
|
74
|
-
updated_prompt = updated_prompt.replace(
|
|
75
|
-
f"'''{original_user_steps[-1].original_test_step.text}'''",
|
|
76
|
-
f"'''{rephrased_user_steps[-1]}'''",
|
|
77
|
-
)
|
|
78
|
-
|
|
79
|
-
return updated_prompt
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def _convert_conversation_into_llm_data(
|
|
83
|
-
conversation: Conversation,
|
|
60
|
+
async def _convert_conversation_into_llm_data(
|
|
61
|
+
conversation: Conversation, agent: Agent
|
|
84
62
|
) -> List[LLMDataExample]:
|
|
85
63
|
data = []
|
|
86
64
|
|
|
@@ -95,18 +73,52 @@ def _convert_conversation_into_llm_data(
|
|
|
95
73
|
# create data point for the original e2e test case
|
|
96
74
|
data.append(_create_data_point(step.llm_prompt, step, conversation))
|
|
97
75
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
76
|
+
test_case_name = conversation.name
|
|
77
|
+
|
|
78
|
+
# create data points using the rephrasings, e.g. 'new_conversations'
|
|
79
|
+
for rephrased_user_steps in new_conversations:
|
|
80
|
+
sender_id = generate_sender_id(test_case_name)
|
|
81
|
+
# create a new tracker to be able to simulate the conversation from start
|
|
82
|
+
await agent.tracker_store.save(DialogueStateTracker(sender_id, slots=[]))
|
|
83
|
+
# simulate the conversation to get the prompts
|
|
84
|
+
for i, step in enumerate(original_user_steps):
|
|
85
|
+
rephrased_user_message = rephrased_user_steps[i]
|
|
86
|
+
user_message = UserMessage(rephrased_user_message, sender_id=sender_id)
|
|
87
|
+
|
|
88
|
+
expected_commands = "\n".join(
|
|
89
|
+
[command.to_dsl() for command in step.llm_commands]
|
|
90
|
+
)
|
|
91
|
+
fake_invoke_function = make_mock_invoke_llm(expected_commands)
|
|
92
|
+
|
|
93
|
+
with (
|
|
94
|
+
set_record_commands_and_prompts(),
|
|
95
|
+
patch_invoke_llm_in_generators(fake_invoke_function),
|
|
96
|
+
):
|
|
97
|
+
await agent.handle_message(user_message)
|
|
98
|
+
|
|
99
|
+
rephrased_tracker = await agent.tracker_store.retrieve(sender_id)
|
|
100
|
+
if rephrased_tracker is None:
|
|
101
|
+
# if tracker doesn't exist, we can't create a data point
|
|
102
|
+
continue
|
|
103
|
+
|
|
104
|
+
latest_message = rephrased_tracker.latest_message
|
|
105
|
+
if latest_message is None:
|
|
106
|
+
# if there is no latest message, we don't create a data point
|
|
107
|
+
continue
|
|
108
|
+
|
|
109
|
+
# tell the type checker what we expect to find under "prompts"
|
|
110
|
+
prompts = cast(
|
|
111
|
+
Optional[List[Dict[str, Any]]], latest_message.parse_data.get(PROMPTS)
|
|
105
112
|
)
|
|
106
|
-
|
|
113
|
+
|
|
114
|
+
if prompts:
|
|
115
|
+
# as we only use single step or compact command generator,
|
|
116
|
+
# there is always exactly one prompt
|
|
117
|
+
prompt = prompts[0]
|
|
118
|
+
user_prompt: Optional[str] = prompt.get(KEY_USER_PROMPT)
|
|
107
119
|
data.append(
|
|
108
120
|
_create_data_point(
|
|
109
|
-
|
|
121
|
+
user_prompt, step, conversation, rephrased_user_message
|
|
110
122
|
)
|
|
111
123
|
)
|
|
112
124
|
|
|
@@ -149,7 +161,7 @@ def _construct_new_conversations(conversation: Conversation) -> List[List[str]]:
|
|
|
149
161
|
current_conversation.append(step.original_test_step.text)
|
|
150
162
|
continue
|
|
151
163
|
|
|
152
|
-
# some user steps might have
|
|
164
|
+
# some user steps might have fewer rephrasings than others
|
|
153
165
|
# loop over the rephrasings
|
|
154
166
|
index = i % len(step.passed_rephrasings)
|
|
155
167
|
current_conversation.append(step.passed_rephrasings[index])
|
|
@@ -165,13 +177,18 @@ def _construct_new_conversations(conversation: Conversation) -> List[List[str]]:
|
|
|
165
177
|
return new_conversations
|
|
166
178
|
|
|
167
179
|
|
|
168
|
-
def convert_to_fine_tuning_data(
|
|
169
|
-
conversations: List[Conversation],
|
|
180
|
+
async def convert_to_fine_tuning_data(
|
|
181
|
+
conversations: List[Conversation],
|
|
182
|
+
storage_context: StorageContext,
|
|
183
|
+
agent: Agent,
|
|
170
184
|
) -> List[LLMDataExample]:
|
|
171
185
|
llm_data = []
|
|
172
186
|
|
|
173
187
|
for i in tqdm(range(len(conversations))):
|
|
174
|
-
|
|
188
|
+
conversation_llm_data = await _convert_conversation_into_llm_data(
|
|
189
|
+
conversations[i], agent
|
|
190
|
+
)
|
|
191
|
+
llm_data.extend(conversation_llm_data)
|
|
175
192
|
|
|
176
193
|
storage_context.write_llm_data(
|
|
177
194
|
llm_data, LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION
|
|
@@ -11,10 +11,12 @@ from rasa.llm_fine_tuning.paraphrasing.rephrased_user_message import (
|
|
|
11
11
|
)
|
|
12
12
|
from rasa.shared.constants import (
|
|
13
13
|
LLM_CONFIG_KEY,
|
|
14
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
14
15
|
MODEL_CONFIG_KEY,
|
|
15
16
|
MODEL_NAME_CONFIG_KEY,
|
|
16
17
|
PROMPT_TEMPLATE_CONFIG_KEY,
|
|
17
18
|
PROVIDER_CONFIG_KEY,
|
|
19
|
+
TEMPERATURE_CONFIG_KEY,
|
|
18
20
|
TIMEOUT_CONFIG_KEY,
|
|
19
21
|
)
|
|
20
22
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
@@ -39,8 +41,8 @@ DEFAULT_LLM_CONFIG = {
|
|
|
39
41
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
40
42
|
MODEL_CONFIG_KEY: "gpt-4o-mini",
|
|
41
43
|
TIMEOUT_CONFIG_KEY: 7,
|
|
42
|
-
|
|
43
|
-
|
|
44
|
+
TEMPERATURE_CONFIG_KEY: 0.0,
|
|
45
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY: 4096,
|
|
44
46
|
}
|
|
45
47
|
|
|
46
48
|
structlogger = structlog.get_logger()
|
|
@@ -1,45 +1,45 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import List, Optional
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
|
|
5
|
+
from rasa.core.agent import Agent
|
|
6
|
+
from rasa.core.channels import UserMessage
|
|
5
7
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
6
|
-
from rasa.dialogue_understanding.generator.llm_based_command_generator import (
|
|
7
|
-
LLMBasedCommandGenerator,
|
|
8
|
-
)
|
|
9
8
|
from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
|
|
10
9
|
from rasa.llm_fine_tuning.paraphrasing.rephrased_user_message import (
|
|
11
10
|
RephrasedUserMessage,
|
|
12
11
|
)
|
|
13
12
|
from rasa.shared.core.flows import FlowsList
|
|
14
|
-
from rasa.shared.
|
|
15
|
-
from rasa.shared.utils.llm import
|
|
13
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
14
|
+
from rasa.shared.utils.llm import (
|
|
15
|
+
create_tracker_for_user_step,
|
|
16
|
+
generate_sender_id,
|
|
17
|
+
)
|
|
16
18
|
|
|
17
19
|
structlogger = structlog.get_logger()
|
|
18
20
|
|
|
19
21
|
|
|
20
22
|
class RephraseValidator:
|
|
21
|
-
def __init__(self,
|
|
22
|
-
self.llm_config = llm_config
|
|
23
|
+
def __init__(self, flows: FlowsList) -> None:
|
|
23
24
|
self.flows = flows
|
|
24
25
|
|
|
25
26
|
async def validate_rephrasings(
|
|
26
27
|
self,
|
|
28
|
+
agent: Agent,
|
|
27
29
|
rephrasings: List[RephrasedUserMessage],
|
|
28
30
|
conversation: Conversation,
|
|
29
|
-
llm_command_generator: Type[LLMBasedCommandGenerator],
|
|
30
31
|
) -> List[RephrasedUserMessage]:
|
|
31
32
|
"""Split rephrased user messages into passing and failing.
|
|
32
33
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
message. The rephase is passing if the commands match and failing otherwise.
|
|
34
|
+
Handle the rephrased messages using agent the same way the original
|
|
35
|
+
message was handled. Check if the rephrased user message is producing
|
|
36
|
+
the same commands as the original user message. The rephrase is passing
|
|
37
|
+
if the commands match and failing otherwise.
|
|
38
38
|
|
|
39
39
|
Args:
|
|
40
|
+
agent: Rasa agent
|
|
40
41
|
rephrasings: The rephrased user messages.
|
|
41
42
|
conversation: The conversation.
|
|
42
|
-
llm_command_generator: A LLM based command generator class.
|
|
43
43
|
|
|
44
44
|
Returns:
|
|
45
45
|
A list of rephrased user messages including the passing and failing
|
|
@@ -52,7 +52,11 @@ class RephraseValidator:
|
|
|
52
52
|
|
|
53
53
|
for rephrase in current_rephrasings.rephrasings:
|
|
54
54
|
if await self._validate_rephrase_is_passing(
|
|
55
|
-
|
|
55
|
+
agent,
|
|
56
|
+
rephrase,
|
|
57
|
+
step,
|
|
58
|
+
conversation.name,
|
|
59
|
+
conversation.tracker,
|
|
56
60
|
):
|
|
57
61
|
current_rephrasings.passed_rephrasings.append(rephrase)
|
|
58
62
|
else:
|
|
@@ -62,40 +66,29 @@ class RephraseValidator:
|
|
|
62
66
|
|
|
63
67
|
async def _validate_rephrase_is_passing(
|
|
64
68
|
self,
|
|
69
|
+
agent: Agent,
|
|
65
70
|
rephrase: str,
|
|
66
71
|
step: ConversationStep,
|
|
67
|
-
|
|
72
|
+
test_case_name: str,
|
|
73
|
+
tracker: DialogueStateTracker,
|
|
68
74
|
) -> bool:
|
|
69
|
-
|
|
70
|
-
rephrase, step
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
action_list = await self._invoke_llm(
|
|
74
|
-
prompt, llm_command_generator.get_default_llm_config()
|
|
75
|
+
rephrased_tracker = await self._send_rephrased_message_to_agent(
|
|
76
|
+
rephrase, step, test_case_name, agent, tracker
|
|
75
77
|
)
|
|
78
|
+
if not (rephrased_tracker and rephrased_tracker.latest_message):
|
|
79
|
+
return False
|
|
76
80
|
|
|
77
81
|
commands_from_original_utterance = step.llm_commands
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
82
|
+
|
|
83
|
+
commands_from_rephrased_utterance = [
|
|
84
|
+
Command.command_from_json(command_json)
|
|
85
|
+
for command_json in rephrased_tracker.latest_message.commands
|
|
86
|
+
]
|
|
87
|
+
|
|
81
88
|
return self._check_commands_match(
|
|
82
89
|
commands_from_original_utterance, commands_from_rephrased_utterance
|
|
83
90
|
)
|
|
84
91
|
|
|
85
|
-
async def _invoke_llm(self, prompt: str, default_llm_config: Dict[str, Any]) -> str:
|
|
86
|
-
llm = llm_factory(self.llm_config, default_llm_config)
|
|
87
|
-
|
|
88
|
-
try:
|
|
89
|
-
llm_response = await llm.acompletion(prompt)
|
|
90
|
-
return llm_response.choices[0]
|
|
91
|
-
except Exception as e:
|
|
92
|
-
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
93
|
-
# we have to catch all exceptions here
|
|
94
|
-
structlogger.error(
|
|
95
|
-
"rephrase_validator.validate_conversation.llm.error", error=e
|
|
96
|
-
)
|
|
97
|
-
raise ProviderClientAPIException(e, message="LLM call exception")
|
|
98
|
-
|
|
99
92
|
@staticmethod
|
|
100
93
|
def _check_commands_match(
|
|
101
94
|
expected_commands: List[Command], actual_commands: List[Command]
|
|
@@ -120,7 +113,22 @@ class RephraseValidator:
|
|
|
120
113
|
return True
|
|
121
114
|
|
|
122
115
|
@staticmethod
|
|
123
|
-
def
|
|
124
|
-
rephrased_user_message: str,
|
|
125
|
-
|
|
126
|
-
|
|
116
|
+
async def _send_rephrased_message_to_agent(
|
|
117
|
+
rephrased_user_message: str,
|
|
118
|
+
step: ConversationStep,
|
|
119
|
+
test_case_name: str,
|
|
120
|
+
agent: Agent,
|
|
121
|
+
tracker: DialogueStateTracker,
|
|
122
|
+
) -> Optional[DialogueStateTracker]:
|
|
123
|
+
# create a rephrased UserMessage
|
|
124
|
+
sender_id = generate_sender_id(test_case_name)
|
|
125
|
+
user_message = UserMessage(rephrased_user_message, sender_id=sender_id)
|
|
126
|
+
|
|
127
|
+
await create_tracker_for_user_step(
|
|
128
|
+
sender_id, agent, tracker, step.tracker_event_index
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
await agent.handle_message(user_message)
|
|
132
|
+
rephrased_tracker = await agent.tracker_store.retrieve(sender_id)
|
|
133
|
+
|
|
134
|
+
return rephrased_tracker
|
|
@@ -1,11 +1,9 @@
|
|
|
1
|
-
from typing import Any, Dict, List, Tuple
|
|
1
|
+
from typing import Any, Dict, List, Tuple
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from tqdm import tqdm
|
|
5
5
|
|
|
6
|
-
from rasa.
|
|
7
|
-
LLMBasedCommandGenerator,
|
|
8
|
-
)
|
|
6
|
+
from rasa.core.agent import Agent
|
|
9
7
|
from rasa.llm_fine_tuning.conversations import Conversation
|
|
10
8
|
from rasa.llm_fine_tuning.paraphrasing.conversation_rephraser import (
|
|
11
9
|
ConversationRephraser,
|
|
@@ -28,8 +26,7 @@ async def create_paraphrased_conversations(
|
|
|
28
26
|
rephrase_config: Dict[str, Any],
|
|
29
27
|
num_rephrases: int,
|
|
30
28
|
flows: FlowsList,
|
|
31
|
-
|
|
32
|
-
llm_command_generator_config: Dict[str, Any],
|
|
29
|
+
agent: Agent,
|
|
33
30
|
storage_context: StorageContext,
|
|
34
31
|
) -> Tuple[List[Conversation], Dict[str, Any]]:
|
|
35
32
|
"""Create paraphrased conversations.
|
|
@@ -42,7 +39,7 @@ async def create_paraphrased_conversations(
|
|
|
42
39
|
rephrase_config: The path to the rephrase configuration file.
|
|
43
40
|
num_rephrases: The number of rephrases to produce per user message.
|
|
44
41
|
flows: All flows.
|
|
45
|
-
|
|
42
|
+
agent: The Rasa agent.
|
|
46
43
|
storage_context: The storage context.
|
|
47
44
|
|
|
48
45
|
Returns:
|
|
@@ -50,7 +47,7 @@ async def create_paraphrased_conversations(
|
|
|
50
47
|
rephrasing.
|
|
51
48
|
"""
|
|
52
49
|
rephraser = ConversationRephraser(rephrase_config)
|
|
53
|
-
validator = RephraseValidator(
|
|
50
|
+
validator = RephraseValidator(flows)
|
|
54
51
|
|
|
55
52
|
if num_rephrases <= 0:
|
|
56
53
|
structlogger.info(
|
|
@@ -64,18 +61,19 @@ async def create_paraphrased_conversations(
|
|
|
64
61
|
rephrased_conversations: List[Conversation] = []
|
|
65
62
|
for i in tqdm(range(len(conversations))):
|
|
66
63
|
current_conversation = conversations[i]
|
|
67
|
-
|
|
68
64
|
try:
|
|
69
65
|
# rephrase all user messages even if rephrase=False is set
|
|
70
66
|
# to not confuse the LLM and get valid output
|
|
71
67
|
rephrasings = await rephraser.rephrase_conversation(
|
|
72
|
-
|
|
68
|
+
current_conversation, num_rephrases
|
|
73
69
|
)
|
|
74
70
|
# filter out the rephrasings for user messages that have rephrase=False set
|
|
75
|
-
rephrasings = _filter_rephrasings(rephrasings,
|
|
71
|
+
rephrasings = _filter_rephrasings(rephrasings, current_conversation)
|
|
76
72
|
# check if the rephrasings are still producing the same commands
|
|
77
73
|
rephrasings = await validator.validate_rephrasings(
|
|
78
|
-
|
|
74
|
+
agent,
|
|
75
|
+
rephrasings,
|
|
76
|
+
current_conversation,
|
|
79
77
|
)
|
|
80
78
|
except ProviderClientAPIException as e:
|
|
81
79
|
structlogger.error(
|
rasa/llm_fine_tuning/storage.py
CHANGED
|
@@ -96,9 +96,9 @@ class FileStorageStrategy(StorageStrategy):
|
|
|
96
96
|
file_path = self._get_file_path(storage_location)
|
|
97
97
|
self._create_output_dir(file_path)
|
|
98
98
|
|
|
99
|
-
with open(str(file_path), "w") as outfile:
|
|
99
|
+
with open(str(file_path), "w", encoding="utf-8") as outfile:
|
|
100
100
|
for example in llm_data:
|
|
101
|
-
json.dump(example.as_dict(), outfile)
|
|
101
|
+
json.dump(example.as_dict(), outfile, ensure_ascii=False)
|
|
102
102
|
outfile.write("\n")
|
|
103
103
|
|
|
104
104
|
def write_formatted_finetuning_data(
|
|
@@ -110,9 +110,9 @@ class FileStorageStrategy(StorageStrategy):
|
|
|
110
110
|
file_path = self._get_file_path(module_storage_location, file_name)
|
|
111
111
|
self._create_output_dir(file_path)
|
|
112
112
|
|
|
113
|
-
with open(str(file_path), "w") as file:
|
|
113
|
+
with open(str(file_path), "w", encoding="utf-8") as file:
|
|
114
114
|
for example in formatted_data:
|
|
115
|
-
json.dump(example.as_dict(), file)
|
|
115
|
+
json.dump(example.as_dict(), file, ensure_ascii=False)
|
|
116
116
|
file.write("\n")
|
|
117
117
|
|
|
118
118
|
def write_e2e_test_suite_to_yaml_file(
|
rasa/llm_fine_tuning/utils.py
CHANGED
|
@@ -1,7 +1,69 @@
|
|
|
1
|
-
from
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Callable, Generator, List, Union
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
2
6
|
|
|
3
7
|
from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
|
|
8
|
+
from rasa.dialogue_understanding.generator import LLMBasedCommandGenerator
|
|
9
|
+
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
10
|
+
|
|
11
|
+
structlogger = structlog.get_logger()
|
|
4
12
|
|
|
5
13
|
|
|
6
14
|
def commands_as_string(commands: List[PromptCommand], delimiter: str = "\n") -> str:
    """Render prompt commands as one string made of their DSL forms.

    Args:
        commands: The commands to serialise; each one contributes the value
            returned by its `to_dsl()` method.
        delimiter: Separator placed between two consecutive commands.

    Returns:
        The DSL representations joined by `delimiter` (an empty string when
        `commands` is empty).
    """
    rendered = (command.to_dsl() for command in commands)
    return delimiter.join(rendered)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def make_mock_invoke_llm(commands: str) -> Callable:
    """Build an async stand-in for `LLMBasedCommandGenerator.invoke_llm`.

    `commands` is captured in a closure. The returned coroutine function
    ignores the prompt it is given and always answers with an `LLMResponse`
    whose single choice is `commands`. A generator patched with it therefore
    still builds its prompt, but receives a predetermined response instead of
    calling a real LLM.

    Args:
        commands: The commands to return from the mock LLM call, e.g. the
            newline-joined DSL form of the expected commands.

    Returns:
        An async function with the signature `(self, prompt) -> LLMResponse`,
        suitable for installing as a method on a command generator class.
    """

    async def _mock_invoke_llm(
        self: LLMBasedCommandGenerator, prompt: Union[List[dict], List[str], str]
    ) -> LLMResponse:
        # Pass the values as structured fields instead of interpolating them
        # into an f-string: the prompt can be large, and this way it is not
        # formatted on every call when debug logging is disabled.
        structlogger.debug(
            "llm_fine_tuning.utils.mock_invoke_llm.llm_call_intercepted",
            commands=commands,
            prompt=prompt,
        )
        fake_response_dict = {
            "id": "",
            "choices": [commands],
            "created": int(datetime.now().timestamp()),
            "model": "mocked-llm",
        }
        return LLMResponse.from_dict(fake_response_dict)

    return _mock_invoke_llm
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@contextmanager
def patch_invoke_llm_in_generators(
    mock_impl: Callable,
    base_cls: Union[type, None] = None,
) -> Generator[None, None, None]:
    """Temporarily replace `invoke_llm` on a command generator class hierarchy.

    The mock is installed on `base_cls` and on every subclass that currently
    exists, recursively, so subclasses that override `invoke_llm` are
    intercepted as well. On exit, including when the body raises, every class
    is put back into its exact previous state:

    * a class that defined `invoke_llm` itself gets back the raw attribute from
      its own namespace (descriptors such as `staticmethod` stay intact);
    * a class that only inherited `invoke_llm` has the patched attribute
      removed again, so it keeps resolving the method through its MRO.

    Args:
        mock_impl: The replacement function. It is set as a plain class
            attribute, so it is bound like a regular method.
        base_cls: Root of the hierarchy to patch. Defaults to
            `LLMBasedCommandGenerator`.

    Yields:
        None, while the patch is active.
    """
    if base_cls is None:
        base_cls = LLMBasedCommandGenerator  # type: ignore[type-abstract]

    attribute = "invoke_llm"
    # sentinel marking classes that did not define the attribute themselves
    missing = object()
    # class -> raw value from the class's own __dict__ (or `missing`)
    originals = {}

    pending = [base_cls]
    while pending:
        cls = pending.pop()
        if cls in originals:
            # already visited, e.g. reached twice through multiple inheritance
            continue
        # Read from the class's own namespace instead of `getattr`, so we know
        # whether the class owns the attribute and keep descriptors unwrapped.
        originals[cls] = cls.__dict__.get(attribute, missing)
        pending.extend(cls.__subclasses__())

    try:
        # apply the monkey-patch everywhere
        for cls in originals:
            setattr(cls, attribute, mock_impl)
        yield
    finally:
        # restore originals (even if an exception happened)
        for cls, original in originals.items():
            if original is missing:
                # the class only inherited the method; drop our override so it
                # does not keep a stale copy in its own namespace
                if attribute in cls.__dict__:
                    delattr(cls, attribute)
            else:
                setattr(cls, attribute, original)
|
rasa/shared/constants.py
CHANGED
|
@@ -197,7 +197,10 @@ PROVIDER_CONFIG_KEY = "provider"
|
|
|
197
197
|
REQUEST_TIMEOUT_CONFIG_KEY = "request_timeout" # deprecated
|
|
198
198
|
TIMEOUT_CONFIG_KEY = "timeout"
|
|
199
199
|
|
|
200
|
+
LOGIT_BIAS_CONFIG_KEY = "logit_bias"
|
|
201
|
+
MAX_RETRIES_CONFIG_KEY = "max_retries"
|
|
200
202
|
TEMPERATURE_CONFIG_KEY = "temperature"
|
|
203
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY = "max_completion_tokens"
|
|
201
204
|
MAX_TOKENS_CONFIG_KEY = "max_tokens"
|
|
202
205
|
|
|
203
206
|
DEPLOYMENT_NAME_CONFIG_KEY = "deployment_name"
|
rasa/shared/exceptions.py
CHANGED
|
@@ -165,3 +165,7 @@ class ProviderClientAPIException(RasaException):
|
|
|
165
165
|
|
|
166
166
|
class ProviderClientValidationError(RasaException):
|
|
167
167
|
"""Raised for errors that occur during validation of the API client."""
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class FinetuningDataPreparationException(RasaException):
    """Raised when there is an error in data preparation for fine-tuning.

    Example: when annotating e2e test cases, the test case steps are paired
    with the `UserUttered` events of the recorded tracker. This exception
    is raised if the two sequences do not have the same length, because the
    pairing would otherwise be wrong.
    """
|
|
@@ -23,6 +23,8 @@ from rasa.shared.constants import (
|
|
|
23
23
|
DEPLOYMENT_NAME_CONFIG_KEY,
|
|
24
24
|
ENGINE_CONFIG_KEY,
|
|
25
25
|
LANGCHAIN_TYPE_CONFIG_KEY,
|
|
26
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
27
|
+
MAX_TOKENS_CONFIG_KEY,
|
|
26
28
|
MODEL_CONFIG_KEY,
|
|
27
29
|
MODEL_NAME_CONFIG_KEY,
|
|
28
30
|
N_REPHRASES_CONFIG_KEY,
|
|
@@ -71,6 +73,8 @@ DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
|
|
|
71
73
|
MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
|
|
72
74
|
# Timeout aliases
|
|
73
75
|
REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
|
|
76
|
+
# Max tokens aliases
|
|
77
|
+
MAX_TOKENS_CONFIG_KEY: MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
74
78
|
}
|
|
75
79
|
|
|
76
80
|
REQUIRED_KEYS = [DEPLOYMENT_CONFIG_KEY]
|