rasa-pro 3.12.12.dev1__py3-none-any.whl → 3.12.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of rasa-pro has been flagged as a potentially problematic release.

Files changed (37)
  1. rasa/cli/llm_fine_tuning.py +11 -10
  2. rasa/core/nlg/contextual_response_rephraser.py +4 -2
  3. rasa/core/policies/enterprise_search_policy.py +7 -4
  4. rasa/core/policies/intentless_policy.py +15 -9
  5. rasa/core/run.py +7 -2
  6. rasa/core/utils.py +4 -0
  7. rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -3
  8. rasa/dialogue_understanding/commands/clarify_command.py +2 -2
  9. rasa/dialogue_understanding/commands/set_slot_command.py +1 -1
  10. rasa/dialogue_understanding/generator/constants.py +2 -2
  11. rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -1
  12. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +2 -2
  13. rasa/dialogue_understanding_test/du_test_runner.py +3 -21
  14. rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +2 -6
  15. rasa/llm_fine_tuning/annotation_module.py +39 -9
  16. rasa/llm_fine_tuning/conversations.py +3 -0
  17. rasa/llm_fine_tuning/llm_data_preparation_module.py +66 -49
  18. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +4 -2
  19. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +52 -44
  20. rasa/llm_fine_tuning/paraphrasing_module.py +10 -12
  21. rasa/llm_fine_tuning/storage.py +4 -4
  22. rasa/llm_fine_tuning/utils.py +63 -1
  23. rasa/server.py +6 -2
  24. rasa/shared/constants.py +3 -0
  25. rasa/shared/exceptions.py +4 -0
  26. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -0
  27. rasa/shared/providers/_configs/openai_client_config.py +4 -0
  28. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +3 -0
  29. rasa/shared/providers/llm/_base_litellm_client.py +5 -2
  30. rasa/shared/utils/llm.py +28 -0
  31. rasa/telemetry.py +1 -1
  32. rasa/version.py +1 -1
  33. {rasa_pro-3.12.12.dev1.dist-info → rasa_pro-3.12.14.dist-info}/METADATA +3 -3
  34. {rasa_pro-3.12.12.dev1.dist-info → rasa_pro-3.12.14.dist-info}/RECORD +37 -37
  35. {rasa_pro-3.12.12.dev1.dist-info → rasa_pro-3.12.14.dist-info}/NOTICE +0 -0
  36. {rasa_pro-3.12.12.dev1.dist-info → rasa_pro-3.12.14.dist-info}/WHEEL +0 -0
  37. {rasa_pro-3.12.12.dev1.dist-info → rasa_pro-3.12.14.dist-info}/entry_points.txt +0 -0
rasa/llm_fine_tuning/llm_data_preparation_module.py CHANGED
@@ -1,13 +1,23 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 
 import structlog
 from tqdm import tqdm
 
+from rasa.core.agent import Agent
+from rasa.core.channels import UserMessage
 from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
+from rasa.dialogue_understanding.utils import set_record_commands_and_prompts
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.storage import StorageContext
-from rasa.llm_fine_tuning.utils import commands_as_string
+from rasa.llm_fine_tuning.utils import (
+    commands_as_string,
+    make_mock_invoke_llm,
+    patch_invoke_llm_in_generators,
+)
+from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.nlu.constants import KEY_USER_PROMPT, PROMPTS
+from rasa.shared.utils.llm import generate_sender_id
 
 LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION = "3_llm_finetune_data/llm_ft_data.jsonl"
 
@@ -47,40 +57,8 @@ def _create_data_point(
     )
 
 
-def _update_prompt(
-    prompt: str,
-    original_user_steps: List[ConversationStep],
-    rephrased_user_steps: List[str],
-) -> Optional[str]:
-    if len(original_user_steps) != len(rephrased_user_steps):
-        structlogger.debug(
-            "llm_fine_tuning.llm_data_preparation_module.failed_to_update_prompt",
-            original_user_steps=[
-                step.original_test_step.text for step in original_user_steps
-            ],
-            rephrased_user_steps=rephrased_user_steps,
-        )
-        return None
-
-    updated_prompt = prompt
-    for user_step, rephrased_message in zip(original_user_steps, rephrased_user_steps):
-        # replace all occurrences of the original user message with the rephrased user
-        # message in the conversation history mentioned in the prompt
-        updated_prompt = updated_prompt.replace(
-            f"USER: {user_step.original_test_step.text}", f"USER: {rephrased_message}"
-        )
-
-    # replace the latest user message mentioned in the prompt
-    updated_prompt = updated_prompt.replace(
-        f"'''{original_user_steps[-1].original_test_step.text}'''",
-        f"'''{rephrased_user_steps[-1]}'''",
-    )
-
-    return updated_prompt
-
-
-def _convert_conversation_into_llm_data(
-    conversation: Conversation,
+async def _convert_conversation_into_llm_data(
+    conversation: Conversation, agent: Agent
 ) -> List[LLMDataExample]:
     data = []
 
@@ -95,18 +73,52 @@ def _convert_conversation_into_llm_data
         # create data point for the original e2e test case
         data.append(_create_data_point(step.llm_prompt, step, conversation))
 
-        # create data points using the rephrasings, e.g. 'new_conversations'
-        for rephrased_user_steps in new_conversations:
-            # +1 to include the current user turn
-            prompt = _update_prompt(
-                step.llm_prompt,
-                original_user_steps[: i + 1],
-                rephrased_user_steps[: i + 1],
+    test_case_name = conversation.name
+
+    # create data points using the rephrasings, e.g. 'new_conversations'
+    for rephrased_user_steps in new_conversations:
+        sender_id = generate_sender_id(test_case_name)
+        # create a new tracker to be able to simulate the conversation from start
+        await agent.tracker_store.save(DialogueStateTracker(sender_id, slots=[]))
+        # simulate the conversation to get the prompts
+        for i, step in enumerate(original_user_steps):
+            rephrased_user_message = rephrased_user_steps[i]
+            user_message = UserMessage(rephrased_user_message, sender_id=sender_id)
+
+            expected_commands = "\n".join(
+                [command.to_dsl() for command in step.llm_commands]
+            )
+            fake_invoke_function = make_mock_invoke_llm(expected_commands)
+
+            with (
+                set_record_commands_and_prompts(),
+                patch_invoke_llm_in_generators(fake_invoke_function),
+            ):
+                await agent.handle_message(user_message)
+
+            rephrased_tracker = await agent.tracker_store.retrieve(sender_id)
+            if rephrased_tracker is None:
+                # if tracker doesn't exist, we can't create a data point
+                continue
+
+            latest_message = rephrased_tracker.latest_message
+            if latest_message is None:
+                # if there is no latest message, we don't create a data point
+                continue
+
+            # tell the type checker what we expect to find under "prompts"
+            prompts = cast(
+                Optional[List[Dict[str, Any]]], latest_message.parse_data.get(PROMPTS)
             )
-            if prompt:
+
+            if prompts:
+                # as we only use single step or compact command generator,
+                # there is always exactly one prompt
+                prompt = prompts[0]
+                user_prompt: Optional[str] = prompt.get(KEY_USER_PROMPT)
                 data.append(
                     _create_data_point(
-                        prompt, step, conversation, rephrased_user_steps[i]
+                        user_prompt, step, conversation, rephrased_user_message
                     )
                 )
 
@@ -149,7 +161,7 @@ def _construct_new_conversations(conversation: Conversation) -> List[List[str]]:
             current_conversation.append(step.original_test_step.text)
             continue
 
-        # some user steps might have less rephrasings than others
+        # some user steps might have fewer rephrasings than others
         # loop over the rephrasings
         index = i % len(step.passed_rephrasings)
         current_conversation.append(step.passed_rephrasings[index])
@@ -165,13 +177,18 @@ def _construct_new_conversations(conversation: Conversation) -> List[List[str]]:
     return new_conversations
 
 
-def convert_to_fine_tuning_data(
-    conversations: List[Conversation], storage_context: StorageContext
+async def convert_to_fine_tuning_data(
+    conversations: List[Conversation],
+    storage_context: StorageContext,
+    agent: Agent,
 ) -> List[LLMDataExample]:
     llm_data = []
 
     for i in tqdm(range(len(conversations))):
-        llm_data.extend(_convert_conversation_into_llm_data(conversations[i]))
+        conversation_llm_data = await _convert_conversation_into_llm_data(
+            conversations[i], agent
+        )
+        llm_data.extend(conversation_llm_data)
 
     storage_context.write_llm_data(
         llm_data, LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION
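
Data preparation is now driven through a loaded agent instead of prompt string surgery. To situate the new signature, here is a minimal driver sketch; the model path is hypothetical, and `conversations`/`storage_context` stand in for the outputs of the earlier annotation and paraphrasing modules:

```python
import asyncio

from rasa.core.agent import load_agent
from rasa.llm_fine_tuning.llm_data_preparation_module import (
    convert_to_fine_tuning_data,
)


async def main() -> None:
    agent = await load_agent(model_path="models/model.tar.gz")  # hypothetical path
    # placeholders: produced by the annotation and paraphrasing modules
    conversations, storage_context = ..., ...
    llm_data = await convert_to_fine_tuning_data(conversations, storage_context, agent)
    print(f"prepared {len(llm_data)} fine-tuning examples")


asyncio.run(main())
```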
rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py CHANGED
@@ -11,10 +11,12 @@ from rasa.llm_fine_tuning.paraphrasing.rephrased_user_message import (
 )
 from rasa.shared.constants import (
     LLM_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
     PROMPT_TEMPLATE_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.exceptions import ProviderClientAPIException
@@ -39,8 +41,8 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: "gpt-4o-mini",
     TIMEOUT_CONFIG_KEY: 7,
-    "temperature": 0.0,
-    "max_tokens": 4096,
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: 4096,
 }
 
 structlogger = structlog.get_logger()
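
Note that this is not just a constant substitution: the temperature key keeps its literal value, while the token-limit key is renamed from `max_tokens` to `max_completion_tokens`. A quick check of what the constants resolve to (their values appear in the rasa/shared/constants.py hunk further down):

```python
from rasa.shared.constants import (
    MAX_COMPLETION_TOKENS_CONFIG_KEY,
    TEMPERATURE_CONFIG_KEY,
)

# values per the constants.py change below
assert TEMPERATURE_CONFIG_KEY == "temperature"
assert MAX_COMPLETION_TOKENS_CONFIG_KEY == "max_completion_tokens"
```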
rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py CHANGED
@@ -1,45 +1,45 @@
-from typing import Any, Dict, List, Type
+from typing import List, Optional
 
 import structlog
 
+from rasa.core.agent import Agent
+from rasa.core.channels import UserMessage
 from rasa.dialogue_understanding.commands import Command, SetSlotCommand
-from rasa.dialogue_understanding.generator.llm_based_command_generator import (
-    LLMBasedCommandGenerator,
-)
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.paraphrasing.rephrased_user_message import (
     RephrasedUserMessage,
 )
 from rasa.shared.core.flows import FlowsList
-from rasa.shared.exceptions import ProviderClientAPIException
-from rasa.shared.utils.llm import llm_factory
+from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.utils.llm import (
+    create_tracker_for_user_step,
+    generate_sender_id,
+)
 
 structlogger = structlog.get_logger()
 
 
 class RephraseValidator:
-    def __init__(self, llm_config: Dict[str, Any], flows: FlowsList) -> None:
-        self.llm_config = llm_config
+    def __init__(self, flows: FlowsList) -> None:
         self.flows = flows
 
     async def validate_rephrasings(
         self,
+        agent: Agent,
         rephrasings: List[RephrasedUserMessage],
         conversation: Conversation,
-        llm_command_generator: Type[LLMBasedCommandGenerator],
     ) -> List[RephrasedUserMessage]:
         """Split rephrased user messages into passing and failing.
 
-        Call an LLM using the same config of the former trained model with an updated
-        prompt from the original user message (replace all occurrences of the original
-        user message with the rephrased user message). Check if the
-        rephrased user message is producing the same commands as the original user
-        message. The rephase is passing if the commands match and failing otherwise.
+        Handle the rephrased messages using the agent the same way the original
+        message was handled. Check if the rephrased user message is producing
+        the same commands as the original user message. The rephrase is passing
+        if the commands match and failing otherwise.
 
         Args:
+            agent: Rasa agent
            rephrasings: The rephrased user messages.
            conversation: The conversation.
-           llm_command_generator: A LLM based command generator class.
 
        Returns:
            A list of rephrased user messages including the passing and failing
@@ -52,7 +52,11 @@ class RephraseValidator:
 
            for rephrase in current_rephrasings.rephrasings:
                if await self._validate_rephrase_is_passing(
-                   rephrase, step, llm_command_generator
+                   agent,
+                   rephrase,
+                   step,
+                   conversation.name,
+                   conversation.tracker,
                ):
                    current_rephrasings.passed_rephrasings.append(rephrase)
                else:
@@ -62,40 +66,29 @@ class RephraseValidator:
 
     async def _validate_rephrase_is_passing(
         self,
+        agent: Agent,
         rephrase: str,
         step: ConversationStep,
-        llm_command_generator: Type[LLMBasedCommandGenerator],
+        test_case_name: str,
+        tracker: DialogueStateTracker,
     ) -> bool:
-        prompt = self._update_prompt(
-            rephrase, step.original_test_step.text, step.llm_prompt
-        )
-
-        action_list = await self._invoke_llm(
-            prompt, llm_command_generator.get_default_llm_config()
+        rephrased_tracker = await self._send_rephrased_message_to_agent(
+            rephrase, step, test_case_name, agent, tracker
         )
+        if not (rephrased_tracker and rephrased_tracker.latest_message):
+            return False
 
         commands_from_original_utterance = step.llm_commands
-        commands_from_rephrased_utterance = llm_command_generator.parse_commands(  # type: ignore
-            action_list, None, self.flows
-        )
+
+        commands_from_rephrased_utterance = [
+            Command.command_from_json(command_json)
+            for command_json in rephrased_tracker.latest_message.commands
+        ]
+
         return self._check_commands_match(
            commands_from_original_utterance, commands_from_rephrased_utterance
        )
 
-    async def _invoke_llm(self, prompt: str, default_llm_config: Dict[str, Any]) -> str:
-        llm = llm_factory(self.llm_config, default_llm_config)
-
-        try:
-            llm_response = await llm.acompletion(prompt)
-            return llm_response.choices[0]
-        except Exception as e:
-            # unfortunately, langchain does not wrap LLM exceptions which means
-            # we have to catch all exceptions here
-            structlogger.error(
-                "rephrase_validator.validate_conversation.llm.error", error=e
-            )
-            raise ProviderClientAPIException(e, message="LLM call exception")
-
     @staticmethod
     def _check_commands_match(
         expected_commands: List[Command], actual_commands: List[Command]
@@ -120,7 +113,22 @@ class RephraseValidator:
         return True
 
     @staticmethod
-    def _update_prompt(
-        rephrased_user_message: str, original_user_message: str, prompt: str
-    ) -> str:
-        return prompt.replace(original_user_message, rephrased_user_message)
+    async def _send_rephrased_message_to_agent(
+        rephrased_user_message: str,
+        step: ConversationStep,
+        test_case_name: str,
+        agent: Agent,
+        tracker: DialogueStateTracker,
+    ) -> Optional[DialogueStateTracker]:
+        # create a rephrased UserMessage
+        sender_id = generate_sender_id(test_case_name)
+        user_message = UserMessage(rephrased_user_message, sender_id=sender_id)
+
+        await create_tracker_for_user_step(
+            sender_id, agent, tracker, step.tracker_event_index
+        )
+
+        await agent.handle_message(user_message)
+        rephrased_tracker = await agent.tracker_store.retrieve(sender_id)
+
+        return rephrased_tracker
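
The validator no longer needs an LLM config or a command generator class; it replays each rephrase through the agent and compares the commands recorded on the tracker. A sketch of the new call shape, assuming `flows`, `agent`, `rephrasings`, and `conversation` have been prepared by the surrounding pipeline (run inside an async context):

```python
from rasa.llm_fine_tuning.paraphrasing.rephrase_validator import RephraseValidator

validator = RephraseValidator(flows)
validated = await validator.validate_rephrasings(agent, rephrasings, conversation)
for message in validated:
    # each RephrasedUserMessage now carries its passing and failing variants
    print(message.passed_rephrasings, message.failed_rephrasings)
```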
rasa/llm_fine_tuning/paraphrasing_module.py CHANGED
@@ -1,11 +1,9 @@
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Tuple
 
 import structlog
 from tqdm import tqdm
 
-from rasa.dialogue_understanding.generator.llm_based_command_generator import (
-    LLMBasedCommandGenerator,
-)
+from rasa.core.agent import Agent
 from rasa.llm_fine_tuning.conversations import Conversation
 from rasa.llm_fine_tuning.paraphrasing.conversation_rephraser import (
     ConversationRephraser,
@@ -28,8 +26,7 @@ async def create_paraphrased_conversations(
     rephrase_config: Dict[str, Any],
     num_rephrases: int,
     flows: FlowsList,
-    llm_command_generator: Type[LLMBasedCommandGenerator],
-    llm_command_generator_config: Dict[str, Any],
+    agent: Agent,
     storage_context: StorageContext,
 ) -> Tuple[List[Conversation], Dict[str, Any]]:
     """Create paraphrased conversations.
@@ -42,7 +39,7 @@ async def create_paraphrased_conversations(
         rephrase_config: The path to the rephrase configuration file.
         num_rephrases: The number of rephrases to produce per user message.
         flows: All flows.
-        llm_command_generator_config: The configuration of the trained model.
+        agent: The Rasa agent.
         storage_context: The storage context.
 
     Returns:
@@ -50,7 +47,7 @@ async def create_paraphrased_conversations(
         rephrasing.
     """
     rephraser = ConversationRephraser(rephrase_config)
-    validator = RephraseValidator(llm_command_generator_config, flows)
+    validator = RephraseValidator(flows)
 
     if num_rephrases <= 0:
         structlogger.info(
@@ -64,18 +61,19 @@ async def create_paraphrased_conversations(
     rephrased_conversations: List[Conversation] = []
     for i in tqdm(range(len(conversations))):
         current_conversation = conversations[i]
-
         try:
             # rephrase all user messages even if rephrase=False is set
             # to not confuse the LLM and get valid output
             rephrasings = await rephraser.rephrase_conversation(
-                conversations[i], num_rephrases
+                current_conversation, num_rephrases
             )
             # filter out the rephrasings for user messages that have rephrase=False set
-            rephrasings = _filter_rephrasings(rephrasings, conversations[i])
+            rephrasings = _filter_rephrasings(rephrasings, current_conversation)
             # check if the rephrasings are still producing the same commands
             rephrasings = await validator.validate_rephrasings(
-                rephrasings, current_conversation, llm_command_generator
+                agent,
+                rephrasings,
+                current_conversation,
            )
        except ProviderClientAPIException as e:
            structlogger.error(
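
Callers now pass the loaded agent instead of a generator class plus its config. A sketch of the new call, assuming the first positional argument is still the conversation list as in the previous signature (the hunk does not show it):

```python
rephrased, stats = await create_paraphrased_conversations(
    conversations,
    rephrase_config,
    num_rephrases=3,
    flows=flows,
    agent=agent,
    storage_context=storage_context,
)
```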
rasa/llm_fine_tuning/storage.py CHANGED
@@ -96,9 +96,9 @@ class FileStorageStrategy(StorageStrategy):
         file_path = self._get_file_path(storage_location)
         self._create_output_dir(file_path)
 
-        with open(str(file_path), "w") as outfile:
+        with open(str(file_path), "w", encoding="utf-8") as outfile:
             for example in llm_data:
-                json.dump(example.as_dict(), outfile)
+                json.dump(example.as_dict(), outfile, ensure_ascii=False)
                 outfile.write("\n")
 
     def write_formatted_finetuning_data(
@@ -110,9 +110,9 @@ class FileStorageStrategy(StorageStrategy):
         file_path = self._get_file_path(module_storage_location, file_name)
         self._create_output_dir(file_path)
 
-        with open(str(file_path), "w") as file:
+        with open(str(file_path), "w", encoding="utf-8") as file:
             for example in formatted_data:
-                json.dump(example.as_dict(), file)
+                json.dump(example.as_dict(), file, ensure_ascii=False)
                 file.write("\n")
 
     def write_e2e_test_suite_to_yaml_file(
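
The effect of the two new arguments is easiest to see on a non-ASCII example; this is plain stdlib behavior, independent of Rasa:

```python
import json

example = {"text": "Grüße aus Köln"}

with open("llm_ft_data.jsonl", "w", encoding="utf-8") as outfile:
    json.dump(example, outfile, ensure_ascii=False)
    outfile.write("\n")

# ensure_ascii=False writes the line as {"text": "Grüße aus Köln"};
# with the default ensure_ascii=True it would be escaped to
# {"text": "Gr\u00fc\u00dfe aus K\u00f6ln"}
```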
rasa/llm_fine_tuning/utils.py CHANGED
@@ -1,7 +1,69 @@
-from typing import List
+from contextlib import contextmanager
+from datetime import datetime
+from typing import Callable, Generator, List, Union
+
+import structlog
 
 from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
+from rasa.dialogue_understanding.generator import LLMBasedCommandGenerator
+from rasa.shared.providers.llm.llm_response import LLMResponse
+
+structlogger = structlog.get_logger()
 
 
 def commands_as_string(commands: List[PromptCommand], delimiter: str = "\n") -> str:
     return delimiter.join([command.to_dsl() for command in commands])
+
+
+def make_mock_invoke_llm(commands: str) -> Callable:
+    """Capture the `commands` in a closure so the resulting async function
+    can use it as its response.
+
+    Args:
+        commands: The commands to return from the mock LLM call.
+    """
+
+    async def _mock_invoke_llm(
+        self: LLMBasedCommandGenerator, prompt: Union[List[dict], List[str], str]
+    ) -> LLMResponse:
+        structlogger.debug(
+            f"LLM call intercepted, response mocked. "
+            f"Responding with the following commands: '{commands}' "
+            f"to the prompt: {prompt}"
+        )
+        fake_response_dict = {
+            "id": "",
+            "choices": [commands],
+            "created": int(datetime.now().timestamp()),
+            "model": "mocked-llm",
+        }
+        return LLMResponse.from_dict(fake_response_dict)
+
+    return _mock_invoke_llm
+
+
+@contextmanager
+def patch_invoke_llm_in_generators(mock_impl: Callable) -> Generator:
+    """Replace CommandGenerator.invoke_llm in the base class AND in all
+    current subclasses (recursively). Everything is restored on exit.
+    """
+    originals = {}
+
+    def collect(cls: type[LLMBasedCommandGenerator]) -> None:
+        # store current attribute, then recurse
+        originals[cls] = cls.invoke_llm
+        for sub in cls.__subclasses__():
+            collect(sub)
+
+    # collect every existing subclass of CommandGenerator
+    collect(LLMBasedCommandGenerator)  # type: ignore[type-abstract]
+
+    try:
+        # apply the monkey-patch everywhere
+        for cls in originals:
+            cls.invoke_llm = mock_impl  # type: ignore[assignment]
+        yield
+    finally:
+        # restore originals (even if an exception happened)
+        for cls, orig in originals.items():
+            cls.invoke_llm = orig  # type: ignore[assignment]
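
Combined, the two helpers let a conversation be replayed deterministically: any generator call inside the patched block returns the canned commands instead of making a real LLM round trip. A minimal usage sketch; the command string is an illustrative example of the command DSL:

```python
from rasa.llm_fine_tuning.utils import (
    make_mock_invoke_llm,
    patch_invoke_llm_in_generators,
)

mock_invoke = make_mock_invoke_llm("SetSlot(city, Berlin)")

with patch_invoke_llm_in_generators(mock_invoke):
    # every LLMBasedCommandGenerator subclass now answers with the canned
    # commands; message handling inside this block never hits a real LLM
    ...

# on exit the original invoke_llm implementations are restored,
# even if the block raised
```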
rasa/server.py CHANGED
@@ -522,12 +522,15 @@ def configure_cors(
     )
 
 
-def add_root_route(app: Sanic) -> None:
+def add_root_route(app: Sanic, is_inspector_enabled: bool = False) -> None:
     """Add '/' route to return hello."""
 
     @app.get("/")
     async def hello(request: Request) -> HTTPResponse:
         """Check if the server is running and responds with the version."""
+        if not is_inspector_enabled:
+            return response.text("Hello from Rasa: " + rasa.__version__)
+
         html_content = f"""
         <html>
         <body>
@@ -688,6 +691,7 @@ def create_app(
     jwt_private_key: Optional[Text] = None,
     jwt_method: Text = "HS256",
     endpoints: Optional[AvailableEndpoints] = None,
+    is_inspector_enabled: bool = False,
 ) -> Sanic:
     """Class representing a Rasa HTTP server."""
     app = Sanic("rasa_server")
@@ -733,7 +737,7 @@ def create_app(
     ) -> HTTPResponse:
         return response.json(exception.error_info, status=exception.status)
 
-    add_root_route(app)
+    add_root_route(app, is_inspector_enabled)
 
     @app.get("/version")
     async def version(request: Request) -> HTTPResponse:
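
The practical effect: unless the inspector is enabled, the root route now returns a plain-text greeting instead of the inspector's HTML landing page. A quick check against a locally running server, assuming Rasa's default port:

```python
import requests

# assumes a Rasa server started locally on the default port
resp = requests.get("http://localhost:5005/")
print(resp.text)  # "Hello from Rasa: <version>" unless the inspector is enabled
```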
rasa/shared/constants.py CHANGED
@@ -197,7 +197,10 @@ PROVIDER_CONFIG_KEY = "provider"
 REQUEST_TIMEOUT_CONFIG_KEY = "request_timeout"  # deprecated
 TIMEOUT_CONFIG_KEY = "timeout"
 
+LOGIT_BIAS_CONFIG_KEY = "logit_bias"
+MAX_RETRIES_CONFIG_KEY = "max_retries"
 TEMPERATURE_CONFIG_KEY = "temperature"
+MAX_COMPLETION_TOKENS_CONFIG_KEY = "max_completion_tokens"
 MAX_TOKENS_CONFIG_KEY = "max_tokens"
 
 DEPLOYMENT_NAME_CONFIG_KEY = "deployment_name"
rasa/shared/exceptions.py CHANGED
@@ -165,3 +165,7 @@ class ProviderClientAPIException(RasaException):
 
 class ProviderClientValidationError(RasaException):
     """Raised for errors that occur during validation of the API client."""
+
+
+class FinetuningDataPreparationException(RasaException):
+    """Raised when there is an error in data preparation for fine-tuning."""
rasa/shared/providers/_configs/azure_openai_client_config.py CHANGED
@@ -23,6 +23,8 @@ from rasa.shared.constants import (
     DEPLOYMENT_NAME_CONFIG_KEY,
     ENGINE_CONFIG_KEY,
     LANGCHAIN_TYPE_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
+    MAX_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
     N_REPHRASES_CONFIG_KEY,
@@ -71,6 +73,8 @@ DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
     MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
     # Timeout aliases
     REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
+    # Max tokens aliases
+    MAX_TOKENS_CONFIG_KEY: MAX_COMPLETION_TOKENS_CONFIG_KEY,
 }
 
 REQUIRED_KEYS = [DEPLOYMENT_CONFIG_KEY]
rasa/shared/providers/_configs/openai_client_config.py CHANGED
@@ -10,6 +10,8 @@ from rasa.shared.constants import (
     API_TYPE_CONFIG_KEY,
     API_VERSION_CONFIG_KEY,
     LANGCHAIN_TYPE_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
+    MAX_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
     N_REPHRASES_CONFIG_KEY,
@@ -48,6 +50,8 @@ DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
     OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
     # Timeout aliases
     REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
+    # Max tokens aliases
+    MAX_TOKENS_CONFIG_KEY: MAX_COMPLETION_TOKENS_CONFIG_KEY,
 }
 
 REQUIRED_KEYS = [MODEL_CONFIG_KEY]
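
Both client configs gain the same alias entry, so a deprecated `max_tokens` key is mapped onto `max_completion_tokens`. A sketch of what the resolution amounts to; the actual resolution helper is not part of this diff, so the dict comprehension below is only an approximation:

```python
from rasa.shared.providers._configs.openai_client_config import (
    DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING,
)

raw_config = {"model": "gpt-4o-mini", "max_tokens": 512}

# approximate the alias resolution: rename deprecated keys, keep the rest
resolved = {
    DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING.get(key, key): value
    for key, value in raw_config.items()
}
assert resolved == {"model": "gpt-4o-mini", "max_completion_tokens": 512}
```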
rasa/shared/providers/embedding/_base_litellm_embedding_client.py CHANGED
@@ -70,7 +70,10 @@ class _BaseLiteLLMEmbeddingClient:
     def _embedding_fn_args(self) -> Dict[str, Any]:
         """Returns the arguments to be passed to the embedding function."""
         return {
+            # Parameters set through config, can override drop_params
             **self._litellm_extra_parameters,
+            # Model name is constructed in the LiteLLM format from the provided config
+            # Non-overridable to ensure consistency
             "model": self._litellm_model_name,
         }
 
rasa/shared/providers/llm/_base_litellm_client.py CHANGED
@@ -84,12 +84,15 @@ class _BaseLiteLLMClient:
     @property
     def _completion_fn_args(self) -> dict:
         return {
-            **self._litellm_extra_parameters,
-            "model": self._litellm_model_name,
             # Since all providers covered by LiteLLM use the OpenAI format, but
             # not all support every OpenAI parameter, raise an exception if
             # provider/model uses unsupported parameter
             "drop_params": False,
+            # All other parameters set through config, can override drop_params
+            **self._litellm_extra_parameters,
+            # Model name is constructed in the LiteLLM format from the provided config
+            # Non-overridable to ensure consistency
+            "model": self._litellm_model_name,
         }
 
     def validate_client_setup(self) -> None:
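
The reordering matters because later entries in a dict literal win on key collisions: config-supplied parameters can now override `drop_params`, while the model name is listed last and therefore always wins. A plain-Python illustration of that ordering rule:

```python
extra = {"drop_params": True, "model": "should-not-win"}

args = {
    "drop_params": False,  # default, overridable by the unpacked config
    **extra,  # config-supplied parameters
    "model": "openai/gpt-4o-mini",  # always wins: listed last
}

assert args == {"drop_params": True, "model": "openai/gpt-4o-mini"}
```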
rasa/shared/utils/llm.py CHANGED
@@ -2,6 +2,7 @@ import importlib.resources
 import json
 import logging
 from copy import deepcopy
+from datetime import datetime
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
@@ -64,6 +65,7 @@ from rasa.shared.providers.mappings import (
 from rasa.shared.utils.constants import LOG_COMPONENT_SOURCE_METHOD_INIT
 
 if TYPE_CHECKING:
+    from rasa.core.agent import Agent
     from rasa.shared.core.trackers import DialogueStateTracker
 
 
@@ -886,3 +888,29 @@ def resolve_model_client_config(
         )
 
     return model_group[0]
+
+
+def generate_sender_id(test_case_name: str) -> str:
+    # add timestamp suffix to ensure sender_id is unique
+    return f"{test_case_name}_{datetime.now()}"
+
+
+async def create_tracker_for_user_step(
+    step_sender_id: str,
+    agent: "Agent",
+    test_case_tracker: "DialogueStateTracker",
+    index_user_uttered_event: int,
+) -> None:
+    """Creates a tracker for the user step."""
+    tracker = test_case_tracker.copy()
+    # modify the sender id so that the original tracker is not overwritten
+    tracker.sender_id = step_sender_id
+
+    if tracker.events:
+        # get the timestamp of the event just before the user uttered event
+        timestamp = tracker.events[index_user_uttered_event - 1].timestamp
+        # revert the tracker to the event just before the user uttered event
+        tracker = tracker.travel_back_in_time(timestamp)
+
+    # store the tracker with the unique sender id
+    await agent.tracker_store.save(tracker)
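
For illustration, `generate_sender_id` simply suffixes the current timestamp, so repeated runs of the same test case get distinct trackers; the exact value shown below is an example:

```python
from rasa.shared.utils.llm import generate_sender_id

sender_id = generate_sender_id("transfer_money_happy_path")
# e.g. "transfer_money_happy_path_2025-06-30 14:21:07.123456"
assert sender_id.startswith("transfer_money_happy_path_")
```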