rasa-pro 3.12.13__py3-none-any.whl → 3.12.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (33)
  1. rasa/cli/llm_fine_tuning.py +11 -10
  2. rasa/core/nlg/contextual_response_rephraser.py +4 -2
  3. rasa/core/policies/enterprise_search_policy.py +7 -4
  4. rasa/core/policies/intentless_policy.py +15 -9
  5. rasa/core/utils.py +4 -0
  6. rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -3
  7. rasa/dialogue_understanding/commands/clarify_command.py +1 -1
  8. rasa/dialogue_understanding/commands/set_slot_command.py +1 -1
  9. rasa/dialogue_understanding/generator/constants.py +2 -2
  10. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +2 -2
  11. rasa/dialogue_understanding_test/du_test_runner.py +3 -21
  12. rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +2 -6
  13. rasa/llm_fine_tuning/annotation_module.py +39 -9
  14. rasa/llm_fine_tuning/conversations.py +3 -0
  15. rasa/llm_fine_tuning/llm_data_preparation_module.py +66 -49
  16. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +4 -2
  17. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +52 -44
  18. rasa/llm_fine_tuning/paraphrasing_module.py +10 -12
  19. rasa/llm_fine_tuning/storage.py +4 -4
  20. rasa/llm_fine_tuning/utils.py +63 -1
  21. rasa/shared/constants.py +3 -0
  22. rasa/shared/exceptions.py +4 -0
  23. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -0
  24. rasa/shared/providers/_configs/openai_client_config.py +4 -0
  25. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +3 -0
  26. rasa/shared/providers/llm/_base_litellm_client.py +5 -2
  27. rasa/shared/utils/llm.py +28 -0
  28. rasa/version.py +1 -1
  29. {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.14.dist-info}/METADATA +1 -1
  30. {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.14.dist-info}/RECORD +33 -33
  31. {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.14.dist-info}/NOTICE +0 -0
  32. {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.14.dist-info}/WHEEL +0 -0
  33. {rasa_pro-3.12.13.dist-info → rasa_pro-3.12.14.dist-info}/entry_points.txt +0 -0

rasa/cli/llm_fine_tuning.py CHANGED
@@ -208,10 +208,7 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
         sys.exit(0)

     flows = asyncio.run(e2e_test_runner.agent.processor.get_flows())
-    llm_command_generator_config = _get_llm_command_generator_config(e2e_test_runner)
-    llm_command_generator: Type[LLMBasedCommandGenerator] = _get_llm_command_generator(
-        e2e_test_runner
-    )
+    _validate_llm_command_generator_present(e2e_test_runner)

     # set up storage context
     storage_context = create_storage_context(StorageType.FILE, output_dir)
@@ -242,11 +239,11 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
             rephrase_config,
             args.num_rephrases,
             flows,
-            llm_command_generator,
-            llm_command_generator_config,
+            e2e_test_runner.agent,
             storage_context,
         )
     )
+
     statistics["num_passing_rephrased_user_messages"] = sum(
         [conversation.get_number_of_rephrases(True) for conversation in conversations]
     )
@@ -257,7 +254,11 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:

     # 3. create fine-tuning dataset
     log_start_of_module("LLM Data Preparation")
-    llm_fine_tuning_data = convert_to_fine_tuning_data(conversations, storage_context)
+    llm_fine_tuning_data = asyncio.run(
+        convert_to_fine_tuning_data(
+            conversations, storage_context, e2e_test_runner.agent
+        )
+    )
     statistics["num_ft_data_points"] = len(llm_fine_tuning_data)
     log_end_of_module("LLM Data Preparation", statistics)

@@ -311,9 +312,9 @@ def _get_llm_command_generator_config(e2e_test_runner: E2ETestRunner) -> Dict[st
     sys.exit(1)


-def _get_llm_command_generator(
+def _validate_llm_command_generator_present(
     e2e_test_runner: E2ETestRunner,
-) -> Type[LLMBasedCommandGenerator]:
+) -> None:
     train_schema = e2e_test_runner.agent.processor.model_metadata.train_schema  # type: ignore

     for _, node in train_schema.nodes.items():
@@ -322,7 +323,7 @@ def _get_llm_command_generator(
         ) and not node.matches_type(
             MultiStepLLMCommandGenerator, include_subtypes=True
         ):
-            return cast(Type[LLMBasedCommandGenerator], node.uses)
+            return

     rasa.shared.utils.cli.print_error(
         "The provided model is not trained using 'SingleStepLLMCommandGenerator' or "

rasa/core/nlg/contextual_response_rephraser.py CHANGED
@@ -8,12 +8,14 @@ from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
 from rasa.core.nlg.summarize import summarize_conversation
 from rasa.shared.constants import (
     LLM_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_GROUP_ID_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
     OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
@@ -57,8 +59,8 @@ DEFAULT_MAX_HISTORICAL_TURNS = 5
 DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
-    "temperature": 0.3,
-    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TEMPERATURE_CONFIG_KEY: 0.3,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 5,
 }

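A pattern that repeats across this release: string literals in the DEFAULT_LLM_CONFIG dicts are replaced with shared constants, and max_tokens gives way to max_completion_tokens, most likely tracking the OpenAI Chat Completions API, which deprecated max_tokens in favor of max_completion_tokens. Assuming the constants carry the literal keys their names suggest (e.g. PROVIDER_CONFIG_KEY = "provider", MAX_COMPLETION_TOKENS_CONFIG_KEY = "max_completion_tokens"; an inference, the constant definitions are not shown in this diff), the rephraser's default config now resolves to roughly:

    # Illustrative only; the key literals are assumptions inferred from the
    # constant names, not taken from this diff.
    DEFAULT_LLM_CONFIG = {
        "provider": "openai",
        "model": DEFAULT_OPENAI_GENERATE_MODEL_NAME,
        "temperature": 0.3,
        "max_completion_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
        "timeout": 5,
    }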

rasa/core/policies/enterprise_search_policy.py CHANGED
@@ -46,12 +46,15 @@ from rasa.graph_components.providers.forms_provider import Forms
 from rasa.graph_components.providers.responses_provider import Responses
 from rasa.shared.constants import (
     EMBEDDINGS_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
+    MAX_RETRIES_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_GROUP_ID_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
     OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.core.constants import (
@@ -135,14 +138,14 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
     TIMEOUT_CONFIG_KEY: 10,
-    "temperature": 0.0,
-    "max_tokens": 256,
-    "max_retries": 1,
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
+    MAX_RETRIES_CONFIG_KEY: 1,
 }

 DEFAULT_EMBEDDINGS_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
-    "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }

 ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"

rasa/core/policies/intentless_policy.py CHANGED
@@ -31,12 +31,14 @@ from rasa.graph_components.providers.responses_provider import Responses
 from rasa.shared.constants import (
     EMBEDDINGS_CONFIG_KEY,
     LLM_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_GROUP_ID_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
     OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.core.constants import ACTION_LISTEN_NAME
@@ -111,14 +113,14 @@ NLU_ABSTENTION_THRESHOLD = "nlu_abstention_threshold"
 DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
-    "temperature": 0.0,
-    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 5,
 }

 DEFAULT_EMBEDDINGS_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
-    "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }

 DEFAULT_INTENTLESS_PROMPT_TEMPLATE = importlib.resources.open_text(
@@ -344,8 +346,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
             # ensures that the policy will not override a deterministic policy
             # which utilizes the nlu predictions confidence (e.g. Memoization).
             NLU_ABSTENTION_THRESHOLD: 0.9,
-            LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
-            EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
             PROMPT_CONFIG_KEY: DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
         }

@@ -381,13 +381,19 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
         super().__init__(config, model_storage, resource, execution_context, featurizer)

         # Resolve LLM config
-        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
-            self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
+        self.config[LLM_CONFIG_KEY] = combine_custom_and_default_config(
+            resolve_model_client_config(
+                self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
+            ),
+            DEFAULT_LLM_CONFIG,
         )

         # Resolve embeddings config
-        self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
-            self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
+        self.config[EMBEDDINGS_CONFIG_KEY] = combine_custom_and_default_config(
+            resolve_model_client_config(
+                self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
+            ),
+            DEFAULT_EMBEDDINGS_CONFIG,
         )

         self.nlu_abstention_threshold: float = self.config[NLU_ABSTENTION_THRESHOLD]
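
The defaults also move out of get_default_config() and into __init__, where they are merged with whatever the user configured. A sketch of the intended effect, assuming combine_custom_and_default_config fills in keys that are missing from the custom config:

    # hypothetical user override, e.g. parsed from config.yml
    custom = {"model": "gpt-4o", "temperature": 0.7}

    merged = combine_custom_and_default_config(custom, DEFAULT_LLM_CONFIG)
    # "model" and "temperature" keep the user's values; "provider",
    # "max_completion_tokens", and "timeout" fall back to the defaults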

rasa/core/utils.py CHANGED
@@ -244,6 +244,10 @@ class AvailableEndpoints:
             cls._instance = cls.read_endpoints(endpoint_file)
         return cls._instance

+    @classmethod
+    def reset_instance(cls) -> None:
+        cls._instance = None
+

 def read_endpoints_from_path(
     endpoints_path: Optional[Union[Path, Text]] = None,
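
AvailableEndpoints caches a singleton; the new reset_instance classmethod clears it so the endpoints file is re-read on the next access. A plausible use is test isolation (a hypothetical fixture, not part of this diff):

    import pytest

    from rasa.core.utils import AvailableEndpoints

    @pytest.fixture(autouse=True)
    def reset_available_endpoints():
        yield
        # drop the cached singleton so each test re-reads its own endpoints file
        AvailableEndpoints.reset_instance()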

rasa/dialogue_understanding/coexistence/llm_based_router.py CHANGED
@@ -23,11 +23,14 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
 from rasa.shared.constants import (
+    LOGIT_BIAS_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
     ROUTE_TO_CALM_SLOT,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.core.trackers import DialogueStateTracker
@@ -66,9 +69,11 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
     TIMEOUT_CONFIG_KEY: 7,
-    "temperature": 0.0,
-    "max_tokens": 1,
-    "logit_bias": {str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT},
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: 1,
+    LOGIT_BIAS_CONFIG_KEY: {
+        str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT
+    },
 }

 structlogger = structlog.get_logger()
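
The router's defaults force the model to answer with exactly one token from its fixed answer set: max_completion_tokens is 1 and a +100 logit bias is put on the candidate token ids. A sketch of how such ids could be derived (assumes tiktoken; the shipped values come from the A_TO_C_TOKEN_IDS_CHATGPT constant):

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-4")
    # token ids for the single-letter answers the router expects
    a_to_c_ids = [enc.encode(letter)[0] for letter in ("A", "B", "C")]
    logit_bias = {str(token_id): 100 for token_id in a_to_c_ids}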

rasa/dialogue_understanding/commands/clarify_command.py CHANGED
@@ -119,7 +119,7 @@ class ClarifyCommand(Command):
         mapper = {
             CommandSyntaxVersion.v1: r"Clarify\(([\"\'a-zA-Z0-9_, -]*)\)",
             CommandSyntaxVersion.v2: (
-                r"""^[\s\W\d]*disambiguate flows (["'a-zA-Z0-9_, -]*)['"`]*$"""
+                r"""^[\s\W\d]*disambiguate flows (["'a-zA-Z0-9_, -]*)[\W\\n]*$"""
             ),
         }
         return mapper.get(

rasa/dialogue_understanding/commands/set_slot_command.py CHANGED
@@ -190,7 +190,7 @@ class SetSlotCommand(Command):
                 r"""SetSlot\(['"]?([a-zA-Z_][a-zA-Z0-9_-]*)['"]?, ?['"]?(.*)['"]?\)"""
             ),
             CommandSyntaxVersion.v2: (
-                r"""^[\s\W\d]*set slot ['"`]?([a-zA-Z_][a-zA-Z0-9_-]*)['"`]? ['"`]?(.+?)['"`]*$"""  # noqa: E501
+                r"""^[\s\W\d]*set slot ['"`]?([a-zA-Z_][a-zA-Z0-9_-]*)['"`]? ['"`]?(.+?)[\W\\n]*$"""  # noqa: E501
             ),
         }
         return mapper.get(
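
In both command parsers the v2 pattern's tail is relaxed from a closing-quote class to [\W\\n]*, so trailing punctuation after a generated command no longer breaks matching. A quick check with the clarify patterns copied verbatim from the hunks above:

    import re

    OLD = r"""^[\s\W\d]*disambiguate flows (["'a-zA-Z0-9_, -]*)['"`]*$"""
    NEW = r"""^[\s\W\d]*disambiguate flows (["'a-zA-Z0-9_, -]*)[\W\\n]*$"""

    text = "disambiguate flows transfer_money, check_balance."
    print(re.match(OLD, text))           # None: the trailing '.' is rejected
    print(re.match(NEW, text).group(1))  # transfer_money, check_balance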

rasa/dialogue_understanding/generator/constants.py CHANGED
@@ -1,5 +1,5 @@
 from rasa.shared.constants import (
-    MAX_TOKENS_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     OPENAI_PROVIDER,
     PROVIDER_CONFIG_KEY,
@@ -15,7 +15,7 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
     TEMPERATURE_CONFIG_KEY: 0.0,
-    MAX_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 7,
 }

rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py CHANGED
@@ -47,7 +47,7 @@ from rasa.shared.constants import (
     AWS_BEDROCK_PROVIDER,
     AZURE_OPENAI_PROVIDER,
     EMBEDDINGS_CONFIG_KEY,
-    MAX_TOKENS_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     PROMPT_TEMPLATE_CONFIG_KEY,
     ROUTE_TO_CALM_SLOT,
     TEMPERATURE_CONFIG_KEY,
@@ -81,7 +81,7 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: MODEL_NAME_GPT_4O_2024_11_20,
     TEMPERATURE_CONFIG_KEY: 0.0,
-    MAX_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 7,
 }

rasa/dialogue_understanding_test/du_test_runner.py CHANGED
@@ -33,6 +33,7 @@ from rasa.e2e_test.e2e_test_runner import E2ETestRunner
 from rasa.shared.core.events import UserUttered
 from rasa.shared.core.trackers import DialogueStateTracker
 from rasa.shared.nlu.constants import PREDICTED_COMMANDS, PROMPTS
+from rasa.shared.utils.llm import create_tracker_for_user_step
 from rasa.utils.endpoints import EndpointConfig

 structlogger = structlog.get_logger()
@@ -178,8 +179,9 @@ class DialogueUnderstandingTestRunner:
             # create and save the tracker at the time just
             # before the user message was sent
             step_sender_id = f"{sender_id}_{user_step_index}"
-            await self._create_tracker_for_user_step(
+            await create_tracker_for_user_step(
                 step_sender_id,
+                self.agent,
                 test_case_tracker,
                 user_uttered_event_indices[user_step_index],
             )
@@ -280,26 +282,6 @@

         return user_uttered_event

-    async def _create_tracker_for_user_step(
-        self,
-        step_sender_id: str,
-        test_case_tracker: DialogueStateTracker,
-        index_user_uttered_event: int,
-    ) -> None:
-        """Creates a tracker for the user step."""
-        tracker = test_case_tracker.copy()
-        # modify the sender id so that the test case tracker is not overwritten
-        tracker.sender_id = step_sender_id
-
-        if tracker.events:
-            # get timestamp of the event just before the user uttered event
-            timestamp = tracker.events[index_user_uttered_event - 1].timestamp
-            # revert the tracker to the event just before the user uttered event
-            tracker = tracker.travel_back_in_time(timestamp)
-
-        # store the tracker with the unique sender id
-        await self.agent.tracker_store.save(tracker)
-
     async def _send_user_message(
         self,
         sender_id: str,
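
The per-step tracker setup moves to rasa.shared.utils.llm, gaining an explicit agent parameter so it no longer depends on the runner instance. Based on the removed method above and the new call site, the relocated helper plausibly looks like this (a reconstruction, not the shipped code):

    async def create_tracker_for_user_step(
        step_sender_id: str,
        agent: Agent,
        test_case_tracker: DialogueStateTracker,
        index_user_uttered_event: int,
    ) -> None:
        tracker = test_case_tracker.copy()
        # modify the sender id so the test case tracker is not overwritten
        tracker.sender_id = step_sender_id
        if tracker.events:
            # rewind to the event just before the user uttered event
            timestamp = tracker.events[index_user_uttered_event - 1].timestamp
            tracker = tracker.travel_back_in_time(timestamp)
        await agent.tracker_store.save(tracker)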

rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py CHANGED
@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import List, Optional

 import structlog
@@ -24,6 +23,7 @@ from rasa.shared.core.constants import SlotMappingType
 from rasa.shared.core.events import BotUttered, SlotSet, UserUttered
 from rasa.shared.core.trackers import DialogueStateTracker
 from rasa.shared.nlu.constants import COMMANDS, ENTITIES, INTENT
+from rasa.shared.utils.llm import generate_sender_id

 structlogger = structlog.get_logger()

@@ -52,7 +52,7 @@ class TestCaseTrackerSimulator:
         self.test_case = test_case
         self.output_channel = output_channel or CollectingOutputChannel()

-        self.sender_id = self._generate_sender_id()
+        self.sender_id = generate_sender_id(self.test_case.name)

     async def simulate_test_case(
         self,
@@ -150,10 +150,6 @@
             user_uttered_event_indices=user_uttered_event_indices,
         )

-    def _generate_sender_id(self) -> str:
-        # add timestamp suffix to ensure sender_id is unique
-        return f"{self.test_case.name}_{datetime.now()}"
-
    @staticmethod
    async def _get_latest_user_uttered_event_index(
        tracker: DialogueStateTracker, user_uttered_event_indices: List[int]
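
_generate_sender_id likewise moves to rasa.shared.utils.llm as generate_sender_id, shared by the test simulator and the fine-tuning modules. Judging from the removed body, it still appends a timestamp so sender ids stay unique:

    from rasa.shared.utils.llm import generate_sender_id

    sender_id = generate_sender_id("transfer_money_happy_path")
    # e.g. "transfer_money_happy_path_2025-06-01 12:00:00.123456"
    # (format assumed from the removed f-string above)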

rasa/llm_fine_tuning/annotation_module.py CHANGED
@@ -10,7 +10,9 @@ from rasa.e2e_test.e2e_test_runner import TEST_TURNS_TYPE, E2ETestRunner
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.storage import StorageContext
 from rasa.shared.core.constants import USER
+from rasa.shared.core.events import UserUttered
 from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.exceptions import FinetuningDataPreparationException
 from rasa.shared.nlu.constants import LLM_COMMANDS, LLM_PROMPT
 from rasa.shared.utils.llm import tracker_as_readable_transcript

@@ -37,7 +39,7 @@ def annotate_e2e_tests(
     storage_context: StorageContext,
 ) -> List[Conversation]:
     with set_preparing_fine_tuning_data():
-        converations = asyncio.run(
+        conversations = asyncio.run(
             e2e_test_runner.run_tests_for_fine_tuning(
                 test_suite.test_cases,
                 test_suite.fixtures,
@@ -46,10 +48,11 @@
         )

     storage_context.write_conversations(
-        converations, ANNOTATION_MODULE_STORAGE_LOCATION
+        conversations,
+        ANNOTATION_MODULE_STORAGE_LOCATION,
     )

-    return converations
+    return conversations


 def _get_previous_actual_step_output(
@@ -80,25 +83,45 @@ def generate_conversation(
         Conversation.
     """
     steps = []
+    tracker_event_indices = [
+        i for i, event in enumerate(tracker.events) if isinstance(event, UserUttered)
+    ]
+
+    if len(test_case.steps) != len(tracker_event_indices):
+        raise FinetuningDataPreparationException(
+            "Number of test case steps and tracker events do not match."
+        )

     if assertions_used:
         # we only have user steps, extract the bot response from the bot uttered
         # events of the test turn
-        for i, original_step in enumerate(test_case.steps):
+        for i, (original_step, tracker_event_index) in enumerate(
+            zip(test_case.steps, tracker_event_indices)
+        ):
             previous_turn = _get_previous_actual_step_output(test_turns, i)
             steps.append(
                 _convert_to_conversation_step(
-                    original_step, test_turns[i], test_case.name, previous_turn
+                    original_step,
+                    test_turns[i],
+                    test_case.name,
+                    previous_turn,
+                    tracker_event_index,
                 )
             )
             steps.extend(_create_bot_test_steps(test_turns[i]))
     else:
-        for i, original_step in enumerate(test_case.steps):
+        for i, (original_step, tracker_event_index) in enumerate(
+            zip(test_case.steps, tracker_event_indices)
+        ):
             if original_step.actor == USER:
                 previous_turn = _get_previous_actual_step_output(test_turns, i)
                 steps.append(
                     _convert_to_conversation_step(
-                        original_step, test_turns[i], test_case.name, previous_turn
+                        original_step,
+                        test_turns[i],
+                        test_case.name,
+                        previous_turn,
+                        tracker_event_index,
                     )
                 )
             else:
@@ -120,7 +143,7 @@

     transcript = tracker_as_readable_transcript(tracker, max_turns=None)

-    return Conversation(test_case.name, test_case, steps, transcript)
+    return Conversation(test_case.name, test_case, steps, transcript, tracker)


 def _create_bot_test_steps(current_turn: ActualStepOutput) -> List[TestStep]:
@@ -140,6 +163,7 @@ def _convert_to_conversation_step(
     current_turn: ActualStepOutput,
     test_case_name: str,
     previous_turn: Optional[ActualStepOutput],
+    tracker_event_index: Optional[int] = None,
 ) -> Union[TestStep, ConversationStep]:
     if not current_step.text == current_turn.text or not isinstance(
         current_turn, ActualStepOutput
@@ -169,7 +193,13 @@
     commands = [Command.command_from_json(data) for data in llm_commands]
     rephrase = _should_be_rephrased(current_turn.text, previous_turn, test_case_name)

-    return ConversationStep(current_step, commands, llm_prompt, rephrase=rephrase)
+    return ConversationStep(
+        current_step,
+        commands,
+        llm_prompt,
+        rephrase=rephrase,
+        tracker_event_index=tracker_event_index,
+    )


 def _should_be_rephrased(
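
FinetuningDataPreparationException is new in this release (rasa/shared/exceptions.py grows by four lines in the file list above, though that hunk is not shown here). It presumably follows the existing pattern for Rasa exceptions, along the lines of:

    from rasa.shared.exceptions import RasaException

    class FinetuningDataPreparationException(RasaException):
        """Raised when preparing LLM fine-tuning data fails."""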

rasa/llm_fine_tuning/conversations.py CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union
 from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
 from rasa.e2e_test.e2e_test_case import TestCase, TestStep
 from rasa.shared.core.constants import USER
+from rasa.shared.core.trackers import DialogueStateTracker


 @dataclass
@@ -14,6 +15,7 @@ class ConversationStep:
     failed_rephrasings: List[str] = field(default_factory=list)
     passed_rephrasings: List[str] = field(default_factory=list)
     rephrase: bool = True
+    tracker_event_index: Optional[int] = None

     def as_dict(self) -> Dict[str, Any]:
         data = {
@@ -40,6 +42,7 @@ class Conversation:
     original_e2e_test_case: TestCase
     steps: List[Union[TestStep, ConversationStep]]
     transcript: str
+    tracker: Optional[DialogueStateTracker] = None

     def iterate_over_annotated_user_steps(
         self, rephrase: Optional[bool] = None

rasa/llm_fine_tuning/llm_data_preparation_module.py CHANGED
@@ -1,13 +1,23 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 import structlog
 from tqdm import tqdm

+from rasa.core.agent import Agent
+from rasa.core.channels import UserMessage
 from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
+from rasa.dialogue_understanding.utils import set_record_commands_and_prompts
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.storage import StorageContext
-from rasa.llm_fine_tuning.utils import commands_as_string
+from rasa.llm_fine_tuning.utils import (
+    commands_as_string,
+    make_mock_invoke_llm,
+    patch_invoke_llm_in_generators,
+)
+from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.nlu.constants import KEY_USER_PROMPT, PROMPTS
+from rasa.shared.utils.llm import generate_sender_id

 LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION = "3_llm_finetune_data/llm_ft_data.jsonl"

@@ -47,40 +57,8 @@ def _create_data_point(
     )


-def _update_prompt(
-    prompt: str,
-    original_user_steps: List[ConversationStep],
-    rephrased_user_steps: List[str],
-) -> Optional[str]:
-    if len(original_user_steps) != len(rephrased_user_steps):
-        structlogger.debug(
-            "llm_fine_tuning.llm_data_preparation_module.failed_to_update_prompt",
-            original_user_steps=[
-                step.original_test_step.text for step in original_user_steps
-            ],
-            rephrased_user_steps=rephrased_user_steps,
-        )
-        return None
-
-    updated_prompt = prompt
-    for user_step, rephrased_message in zip(original_user_steps, rephrased_user_steps):
-        # replace all occurrences of the original user message with the rephrased user
-        # message in the conversation history mentioned in the prompt
-        updated_prompt = updated_prompt.replace(
-            f"USER: {user_step.original_test_step.text}", f"USER: {rephrased_message}"
-        )
-
-    # replace the latest user message mentioned in the prompt
-    updated_prompt = updated_prompt.replace(
-        f"'''{original_user_steps[-1].original_test_step.text}'''",
-        f"'''{rephrased_user_steps[-1]}'''",
-    )
-
-    return updated_prompt
-
-
-def _convert_conversation_into_llm_data(
-    conversation: Conversation,
+async def _convert_conversation_into_llm_data(
+    conversation: Conversation, agent: Agent
 ) -> List[LLMDataExample]:
     data = []

@@ -95,18 +73,52 @@ def _convert_conversation_into_llm_data(
         # create data point for the original e2e test case
         data.append(_create_data_point(step.llm_prompt, step, conversation))

-        # create data points using the rephrasings, e.g. 'new_conversations'
-        for rephrased_user_steps in new_conversations:
-            # +1 to include the current user turn
-            prompt = _update_prompt(
-                step.llm_prompt,
-                original_user_steps[: i + 1],
-                rephrased_user_steps[: i + 1],
+    test_case_name = conversation.name
+
+    # create data points using the rephrasings, e.g. 'new_conversations'
+    for rephrased_user_steps in new_conversations:
+        sender_id = generate_sender_id(test_case_name)
+        # create a new tracker to be able to simulate the conversation from start
+        await agent.tracker_store.save(DialogueStateTracker(sender_id, slots=[]))
+        # simulate the conversation to get the prompts
+        for i, step in enumerate(original_user_steps):
+            rephrased_user_message = rephrased_user_steps[i]
+            user_message = UserMessage(rephrased_user_message, sender_id=sender_id)
+
+            expected_commands = "\n".join(
+                [command.to_dsl() for command in step.llm_commands]
+            )
+            fake_invoke_function = make_mock_invoke_llm(expected_commands)
+
+            with (
+                set_record_commands_and_prompts(),
+                patch_invoke_llm_in_generators(fake_invoke_function),
+            ):
+                await agent.handle_message(user_message)
+
+            rephrased_tracker = await agent.tracker_store.retrieve(sender_id)
+            if rephrased_tracker is None:
+                # if tracker doesn't exist, we can't create a data point
+                continue
+
+            latest_message = rephrased_tracker.latest_message
+            if latest_message is None:
+                # if there is no latest message, we don't create a data point
+                continue
+
+            # tell the type checker what we expect to find under "prompts"
+            prompts = cast(
+                Optional[List[Dict[str, Any]]], latest_message.parse_data.get(PROMPTS)
             )
-            if prompt:
+
+            if prompts:
+                # as we only use single step or compact command generator,
+                # there is always exactly one prompt
+                prompt = prompts[0]
+                user_prompt: Optional[str] = prompt.get(KEY_USER_PROMPT)
                 data.append(
                     _create_data_point(
-                        prompt, step, conversation, rephrased_user_steps[i]
+                        user_prompt, step, conversation, rephrased_user_message
                     )
                 )

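The string-surgery approach of the removed _update_prompt is replaced by replaying each rephrased conversation through the agent: a fresh tracker is saved, every rephrased user message is handled with prompt recording enabled and the real LLM call stubbed out to return the already-known commands, and the recorded prompt is then read back from the tracker. The two helpers come from rasa/llm_fine_tuning/utils.py (+63 lines in the file list, not shown here); a minimal sketch of the mock factory's idea, assuming the patched invoke_llm returns raw completion text:

    # NOT the shipped implementation; an illustration of the replay trick.
    def make_mock_invoke_llm(expected_commands: str):
        async def _fake_invoke_llm(*args, **kwargs):
            # skip the real LLM round trip and answer with the command DSL
            # that is already known to be correct for this user turn
            return expected_commands
        return _fake_invoke_llm

patch_invoke_llm_in_generators then presumably installs this callable in the command generators for the duration of the with block, so agent.handle_message records a faithful prompt without any external calls.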
@@ -149,7 +161,7 @@ def _construct_new_conversations(conversation: Conversation) -> List[List[str]]:
             current_conversation.append(step.original_test_step.text)
             continue

-        # some user steps might have less rephrasings than others
+        # some user steps might have fewer rephrasings than others
         # loop over the rephrasings
         index = i % len(step.passed_rephrasings)
         current_conversation.append(step.passed_rephrasings[index])
@@ -165,13 +177,18 @@ def _construct_new_conversations(conversation: Conversation) -> List[List[str]]:
     return new_conversations


-def convert_to_fine_tuning_data(
-    conversations: List[Conversation], storage_context: StorageContext
+async def convert_to_fine_tuning_data(
+    conversations: List[Conversation],
+    storage_context: StorageContext,
+    agent: Agent,
 ) -> List[LLMDataExample]:
     llm_data = []

     for i in tqdm(range(len(conversations))):
-        llm_data.extend(_convert_conversation_into_llm_data(conversations[i]))
+        conversation_llm_data = await _convert_conversation_into_llm_data(
+            conversations[i], agent
+        )
+        llm_data.extend(conversation_llm_data)

     storage_context.write_llm_data(
         llm_data, LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION