rasa-pro 3.12.0.dev9 (py3-none-any.whl) → 3.12.0.dev10 (py3-none-any.whl)

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (56):
  1. rasa/core/actions/action.py +17 -3
  2. rasa/core/actions/action_handle_digressions.py +142 -0
  3. rasa/core/actions/forms.py +4 -2
  4. rasa/core/channels/voice_ready/audiocodes.py +42 -23
  5. rasa/core/channels/voice_stream/tts/azure.py +2 -1
  6. rasa/core/migrate.py +2 -2
  7. rasa/core/policies/flows/flow_executor.py +33 -1
  8. rasa/dialogue_understanding/commands/can_not_handle_command.py +2 -2
  9. rasa/dialogue_understanding/commands/cancel_flow_command.py +62 -4
  10. rasa/dialogue_understanding/commands/change_flow_command.py +2 -2
  11. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +2 -2
  12. rasa/dialogue_understanding/commands/clarify_command.py +2 -2
  13. rasa/dialogue_understanding/commands/correct_slots_command.py +11 -2
  14. rasa/dialogue_understanding/commands/handle_digressions_command.py +150 -0
  15. rasa/dialogue_understanding/commands/human_handoff_command.py +2 -2
  16. rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
  17. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +2 -2
  18. rasa/dialogue_understanding/commands/set_slot_command.py +7 -15
  19. rasa/dialogue_understanding/commands/skip_question_command.py +2 -2
  20. rasa/dialogue_understanding/commands/start_flow_command.py +43 -2
  21. rasa/dialogue_understanding/commands/utils.py +1 -1
  22. rasa/dialogue_understanding/constants.py +1 -0
  23. rasa/dialogue_understanding/generator/command_generator.py +10 -76
  24. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  25. rasa/dialogue_understanding/generator/llm_based_command_generator.py +126 -2
  26. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -2
  27. rasa/dialogue_understanding/generator/nlu_command_adapter.py +4 -2
  28. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +40 -40
  29. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +11 -19
  30. rasa/dialogue_understanding/patterns/correction.py +13 -1
  31. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +62 -2
  32. rasa/dialogue_understanding/patterns/handle_digressions.py +81 -0
  33. rasa/dialogue_understanding/processor/command_processor.py +117 -28
  34. rasa/dialogue_understanding/utils.py +31 -0
  35. rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +2 -2
  36. rasa/shared/core/constants.py +22 -1
  37. rasa/shared/core/domain.py +6 -4
  38. rasa/shared/core/events.py +13 -2
  39. rasa/shared/core/flows/flow.py +17 -0
  40. rasa/shared/core/flows/flows_yaml_schema.json +38 -0
  41. rasa/shared/core/flows/steps/collect.py +18 -1
  42. rasa/shared/core/flows/utils.py +16 -1
  43. rasa/shared/core/slot_mappings.py +6 -6
  44. rasa/shared/core/slots.py +19 -0
  45. rasa/shared/core/trackers.py +3 -1
  46. rasa/shared/nlu/constants.py +1 -0
  47. rasa/shared/providers/llm/_base_litellm_client.py +0 -40
  48. rasa/shared/utils/llm.py +1 -86
  49. rasa/shared/utils/schemas/domain.yml +0 -1
  50. rasa/validator.py +172 -22
  51. rasa/version.py +1 -1
  52. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev10.dist-info}/METADATA +1 -1
  53. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev10.dist-info}/RECORD +56 -53
  54. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev10.dist-info}/NOTICE +0 -0
  55. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev10.dist-info}/WHEEL +0 -0
  56. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev10.dist-info}/entry_points.txt +0 -0
@@ -72,10 +72,11 @@ from rasa.shared.core.constants import (
72
72
  ACTION_UNLIKELY_INTENT_NAME,
73
73
  ACTION_VALIDATE_SLOT_MAPPINGS,
74
74
  DEFAULT_SLOT_NAMES,
75
+ KEY_MAPPING_TYPE,
75
76
  KNOWLEDGE_BASE_SLOT_NAMES,
76
- MAPPING_TYPE,
77
77
  REQUESTED_SLOT,
78
78
  USER_INTENT_OUT_OF_SCOPE,
79
+ SetSlotExtractor,
79
80
  SlotMappingType,
80
81
  )
81
82
  from rasa.shared.core.domain import Domain
@@ -118,6 +119,10 @@ logger = logging.getLogger(__name__)
118
119
  def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]:
119
120
  """List default actions."""
120
121
  from rasa.core.actions.action_clean_stack import ActionCleanStack
122
+ from rasa.core.actions.action_handle_digressions import (
123
+ ActionBlockDigressions,
124
+ ActionContinueDigression,
125
+ )
121
126
  from rasa.core.actions.action_hangup import ActionHangup
122
127
  from rasa.core.actions.action_repeat_bot_messages import ActionRepeatBotMessages
123
128
  from rasa.core.actions.action_run_slot_rejections import ActionRunSlotRejections
@@ -152,6 +157,8 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
152
157
  ActionResetRouting(),
153
158
  ActionHangup(),
154
159
  ActionRepeatBotMessages(),
160
+ ActionBlockDigressions(),
161
+ ActionContinueDigression(),
155
162
  ]
156
163
 
157
164
 
@@ -940,7 +947,14 @@ class RemoteAction(Action):
940
947
  )
941
948
 
942
949
  events = rasa.shared.core.events.deserialise_events(events_json)
943
- return cast(List[Event], bot_messages) + events
950
+
951
+ processed_events = []
952
+ for event in events:
953
+ if isinstance(event, SlotSet) and event.filled_by is None:
954
+ event.filled_by = SetSlotExtractor.CUSTOM.value
955
+ processed_events.append(event)
956
+
957
+ return cast(List[Event], bot_messages) + processed_events
944
958
 
945
959
  def name(self) -> Text:
946
960
  return self._name
@@ -1317,7 +1331,7 @@ class ActionExtractSlots(Action):
1317
1331
  slot_events.append(SlotSet(slot.name, slot_value))
1318
1332
 
1319
1333
  for mapping in slot.mappings:
1320
- mapping_type = SlotMappingType(mapping.get(MAPPING_TYPE))
1334
+ mapping_type = SlotMappingType(mapping.get(KEY_MAPPING_TYPE))
1321
1335
  should_fill_custom_slot = mapping_type == SlotMappingType.CUSTOM
1322
1336
 
1323
1337
  if should_fill_custom_slot:
@@ -0,0 +1,142 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import structlog
6
+
7
+ from rasa.core.actions.action import Action, create_bot_utterance
8
+ from rasa.core.channels import OutputChannel
9
+ from rasa.core.nlg import NaturalLanguageGenerator
10
+ from rasa.core.utils import add_bot_utterance_metadata
11
+ from rasa.dialogue_understanding.patterns.continue_interrupted import (
12
+ ContinueInterruptedPatternFlowStackFrame,
13
+ )
14
+ from rasa.dialogue_understanding.patterns.handle_digressions import (
15
+ HandleDigressionsPatternFlowStackFrame,
16
+ )
17
+ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
18
+ FlowStackFrameType,
19
+ UserFlowStackFrame,
20
+ )
21
+ from rasa.shared.core.constants import (
22
+ ACTION_BLOCK_DIGRESSION,
23
+ ACTION_CONTINUE_DIGRESSION,
24
+ )
25
+ from rasa.shared.core.domain import Domain
26
+ from rasa.shared.core.events import Event, FlowInterrupted
27
+ from rasa.shared.core.trackers import DialogueStateTracker
28
+
29
+ structlogger = structlog.get_logger()
30
+
31
+
32
+ class ActionBlockDigressions(Action):
33
+ """Action which blocks an interruption and continues the current flow."""
34
+
35
+ def name(self) -> str:
36
+ """Return the action name."""
37
+ return ACTION_BLOCK_DIGRESSION
38
+
39
+ async def run(
40
+ self,
41
+ output_channel: OutputChannel,
42
+ nlg: NaturalLanguageGenerator,
43
+ tracker: DialogueStateTracker,
44
+ domain: Domain,
45
+ metadata: Optional[Dict[str, Any]] = None,
46
+ ) -> List[Event]:
47
+ """Update the stack."""
48
+ structlogger.debug("action_block_digressions.run")
49
+ top_frame = tracker.stack.top()
50
+
51
+ if not isinstance(top_frame, HandleDigressionsPatternFlowStackFrame):
52
+ return []
53
+
54
+ blocked_flow_id = top_frame.interrupting_flow_id
55
+ frame_type = FlowStackFrameType.REGULAR
56
+
57
+ stack = tracker.stack
58
+ stack.push(
59
+ UserFlowStackFrame(flow_id=blocked_flow_id, frame_type=frame_type), 0
60
+ )
61
+ stack.push(
62
+ ContinueInterruptedPatternFlowStackFrame(
63
+ previous_flow_name=blocked_flow_id
64
+ ),
65
+ 1,
66
+ )
67
+ events = tracker.create_stack_updated_events(stack)
68
+
69
+ utterance = "utter_block_digressions"
70
+ message = await nlg.generate(
71
+ utterance,
72
+ tracker,
73
+ output_channel.name(),
74
+ )
75
+
76
+ if message is None:
77
+ structlogger.error(
78
+ "action_block_digressions.run.failed.finding.utter",
79
+ utterance=utterance,
80
+ )
81
+ else:
82
+ message = add_bot_utterance_metadata(
83
+ message, utterance, nlg, domain, tracker
84
+ )
85
+ events.append(create_bot_utterance(message))
86
+
87
+ return events
88
+
89
+
90
+ class ActionContinueDigression(Action):
91
+ """Action which continues with an interruption."""
92
+
93
+ def name(self) -> str:
94
+ """Return the action name."""
95
+ return ACTION_CONTINUE_DIGRESSION
96
+
97
+ async def run(
98
+ self,
99
+ output_channel: OutputChannel,
100
+ nlg: NaturalLanguageGenerator,
101
+ tracker: DialogueStateTracker,
102
+ domain: Domain,
103
+ metadata: Optional[Dict[str, Any]] = None,
104
+ ) -> List[Event]:
105
+ """Update the stack."""
106
+ structlogger.debug("action_continue_digression.run")
107
+ top_frame = tracker.stack.top()
108
+
109
+ if not isinstance(top_frame, HandleDigressionsPatternFlowStackFrame):
110
+ return []
111
+
112
+ blocked_flow_id = top_frame.interrupting_flow_id
113
+ frame_type = FlowStackFrameType.INTERRUPT
114
+ stack = tracker.stack
115
+ stack.push(UserFlowStackFrame(flow_id=blocked_flow_id, frame_type=frame_type))
116
+
117
+ events = [
118
+ FlowInterrupted(
119
+ flow_id=top_frame.interrupted_flow_id,
120
+ step_id=top_frame.interrupted_step_id,
121
+ )
122
+ ] + tracker.create_stack_updated_events(stack)
123
+
124
+ utterance = "utter_continue_interruption"
125
+ message = await nlg.generate(
126
+ utterance,
127
+ tracker,
128
+ output_channel.name(),
129
+ )
130
+
131
+ if message is None:
132
+ structlogger.error(
133
+ "action_continue_digression.run.failed.finding.utter",
134
+ utterance=utterance,
135
+ )
136
+ else:
137
+ message = add_bot_utterance_metadata(
138
+ message, utterance, nlg, domain, tracker
139
+ )
140
+ events.append(create_bot_utterance(message))
141
+
142
+ return events
@@ -16,7 +16,7 @@ from rasa.shared.constants import UTTER_PREFIX
16
16
  from rasa.shared.core.constants import (
17
17
  ACTION_EXTRACT_SLOTS,
18
18
  ACTION_LISTEN_NAME,
19
- MAPPING_TYPE,
19
+ KEY_MAPPING_TYPE,
20
20
  REQUESTED_SLOT,
21
21
  SLOT_MAPPINGS,
22
22
  SlotMappingType,
@@ -158,7 +158,9 @@ class FormAction(LoopAction):
158
158
  domain_slots = domain.as_dict().get(KEY_SLOTS, {})
159
159
  for slot in domain.required_slots_for_form(self.name()):
160
160
  for slot_mapping in domain_slots.get(slot, {}).get(SLOT_MAPPINGS, []):
161
- if slot_mapping.get(MAPPING_TYPE) == str(SlotMappingType.FROM_ENTITY):
161
+ if slot_mapping.get(KEY_MAPPING_TYPE) == str(
162
+ SlotMappingType.FROM_ENTITY
163
+ ):
162
164
  mapping_as_string = json.dumps(slot_mapping, sort_keys=True)
163
165
  if mapping_as_string in unique_entity_slot_mappings:
164
166
  unique_entity_slot_mappings.remove(mapping_as_string)
@@ -1,9 +1,11 @@
1
+ import asyncio
1
2
  import copy
2
3
  import json
3
4
  import uuid
5
+ from collections import defaultdict
4
6
  from dataclasses import asdict
5
7
  from datetime import datetime, timedelta, timezone
6
- from typing import Any, Awaitable, Callable, Dict, List, Optional, Text, Union
8
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Text, Union
7
9
 
8
10
  import structlog
9
11
  from jsonschema import ValidationError, validate
@@ -223,6 +225,16 @@ class AudiocodesInput(InputChannel):
223
225
  self.scheduler_job = None
224
226
  self.keep_alive = keep_alive
225
227
  self.keep_alive_expiration_factor = keep_alive_expiration_factor
228
+ self.background_tasks: Dict[Text, Set[asyncio.Task]] = defaultdict(set)
229
+
230
+ def _create_task(self, conversation_id: Text, coro: Awaitable[Any]) -> asyncio.Task:
231
+ """Create and track an asyncio task for a conversation."""
232
+ task: asyncio.Task = asyncio.create_task(coro)
233
+ self.background_tasks[conversation_id].add(task)
234
+ task.add_done_callback(
235
+ lambda t: self.background_tasks[conversation_id].discard(t)
236
+ )
237
+ return task
226
238
 
227
239
  async def _set_scheduler_job(self) -> None:
228
240
  if self.scheduler_job:
@@ -251,11 +263,20 @@ class AudiocodesInput(InputChannel):
251
263
  )
252
264
  now = datetime.now(timezone.utc)
253
265
  delta = timedelta(seconds=self.keep_alive * self.keep_alive_expiration_factor)
254
- self.conversations = {
255
- k: v
256
- for k, v in self.conversations.items()
257
- if v.is_active_conversation(now, delta)
258
- }
266
+
267
+ # clean up conversations
268
+ inactive = [
269
+ conv_id
270
+ for conv_id, conv in self.conversations.items()
271
+ if not conv.is_active_conversation(now, delta)
272
+ ]
273
+
274
+ # cancel tasks and remove conversations
275
+ for conv_id in inactive:
276
+ for task in self.background_tasks[conv_id]:
277
+ task.cancel()
278
+ self.background_tasks.pop(conv_id, None)
279
+ self.conversations.pop(conv_id, None)
259
280
 
260
281
  def handle_start_conversation(self, body: Dict[Text, Any]) -> Dict[Text, Any]:
261
282
  conversation_id = body["conversation"]
@@ -347,31 +368,29 @@ class AudiocodesInput(InputChannel):
347
368
  structlogger.debug("audiocodes.on_activities", conversation=conversation_id)
348
369
  conversation = self._get_conversation(request.token, conversation_id)
349
370
  if conversation is None:
371
+ structlogger.warning(
372
+ "audiocodes.on_activities.no_conversation", request=request.json
373
+ )
350
374
  return response.json({})
351
375
  elif conversation.ws:
352
376
  ac_output: Union[WebsocketOutput, AudiocodesOutput] = WebsocketOutput(
353
377
  conversation.ws, conversation_id
354
378
  )
355
- await conversation.handle_activities(
356
- request.json,
357
- output_channel=ac_output,
358
- on_new_message=on_new_message,
359
- )
360
- return response.json({})
379
+ response_json = {}
361
380
  else:
362
381
  # handle non websocket case where messages get returned in json
363
382
  ac_output = AudiocodesOutput()
364
- await conversation.handle_activities(
365
- request.json,
366
- output_channel=ac_output,
367
- on_new_message=on_new_message,
368
- )
369
- return response.json(
370
- {
371
- "conversation": conversation_id,
372
- "activities": ac_output.messages,
373
- }
374
- )
383
+ response_json = {
384
+ "conversation": conversation_id,
385
+ "activities": ac_output.messages,
386
+ }
387
+
388
+ # start a background task to handle activities
389
+ self._create_task(
390
+ conversation_id,
391
+ conversation.handle_activities(request.json, ac_output, on_new_message),
392
+ )
393
+ return response.json(response_json)
375
394
 
376
395
  @ac_webhook.route(
377
396
  "/conversation/<conversation_id>/disconnect", methods=["POST"]
@@ -81,7 +81,8 @@ class AzureTTS(TTSEngine[AzureTTSConfig]):
81
81
  @staticmethod
82
82
  def create_request_body(text: str, conf: AzureTTSConfig) -> str:
83
83
  return f"""
84
- <speak version='1.0' xml:lang='{conf.language}'>
84
+ <speak version='1.0' xml:lang='{conf.language}' xmlns:mstts='http://www.w3.org/2001/mstts'
85
+ xmlns='http://www.w3.org/2001/10/synthesis'>
85
86
  <voice xml:lang='{conf.language}' name='{conf.voice}'>
86
87
  {text}
87
88
  </voice>
rasa/core/migrate.py CHANGED
@@ -14,7 +14,7 @@ from rasa.shared.constants import (
14
14
  )
15
15
  from rasa.shared.core.constants import (
16
16
  ACTIVE_LOOP,
17
- MAPPING_TYPE,
17
+ KEY_MAPPING_TYPE,
18
18
  REQUESTED_SLOT,
19
19
  SLOT_MAPPINGS,
20
20
  SlotMappingType,
@@ -43,7 +43,7 @@ def _create_back_up(domain_file: Path, backup_location: Path) -> Dict[Text, Any]
43
43
  def _get_updated_mapping_condition(
44
44
  condition: Dict[Text, Text], mapping: Dict[Text, Any], slot_name: Text
45
45
  ) -> Dict[Text, Text]:
46
- if mapping.get(MAPPING_TYPE) not in [
46
+ if mapping.get(KEY_MAPPING_TYPE) not in [
47
47
  str(SlotMappingType.FROM_ENTITY),
48
48
  str(SlotMappingType.FROM_TRIGGER_INTENT),
49
49
  ]:
@@ -23,6 +23,7 @@ from rasa.core.policies.flows.flow_step_result import (
23
23
  )
24
24
  from rasa.dialogue_understanding.commands import CancelFlowCommand
25
25
  from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
26
+ from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
26
27
  from rasa.dialogue_understanding.patterns.collect_information import (
27
28
  CollectInformationPatternFlowStackFrame,
28
29
  )
@@ -50,6 +51,7 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
50
51
  )
51
52
  from rasa.dialogue_understanding.stack.utils import (
52
53
  top_user_flow_frame,
54
+ user_flows_on_the_stack,
53
55
  )
54
56
  from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
55
57
  from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
@@ -272,6 +274,28 @@ def trigger_pattern_continue_interrupted(
272
274
  return events
273
275
 
274
276
 
277
+ def trigger_pattern_clarification(
278
+ current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
279
+ ) -> None:
280
+ """Trigger the pattern to clarify which topic to continue if needed."""
281
+ if not isinstance(current_frame, UserFlowStackFrame):
282
+ return None
283
+
284
+ if current_frame.frame_type == FlowStackFrameType.CALL:
285
+ # we want to return to the flow that called the current flow
286
+ return None
287
+
288
+ pending_flows = [
289
+ flows.flow_by_id(frame.flow_id)
290
+ for frame in stack.frames
291
+ if isinstance(frame, UserFlowStackFrame)
292
+ and frame.flow_id != current_frame.flow_id
293
+ ]
294
+
295
+ flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
296
+ stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
297
+
298
+
275
299
  def trigger_pattern_completed(
276
300
  current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
277
301
  ) -> None:
@@ -669,7 +693,15 @@ def _run_end_step(
669
693
  structlogger.debug("flow.step.run.flow_end")
670
694
  current_frame = stack.pop()
671
695
  trigger_pattern_completed(current_frame, stack, flows)
672
- resumed_events = trigger_pattern_continue_interrupted(current_frame, stack, flows)
696
+ resumed_events = []
697
+ if len(user_flows_on_the_stack(stack)) > 1:
698
+ # if there are more user flows on the stack,
699
+ # we need to trigger the pattern clarify
700
+ trigger_pattern_clarification(current_frame, stack, flows)
701
+ else:
702
+ resumed_events = trigger_pattern_continue_interrupted(
703
+ current_frame, stack, flows
704
+ )
673
705
  reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
674
706
  return ContinueFlowWithNextStep(
675
707
  events=initial_events + reset_events + resumed_events, has_flow_ended=True
@@ -74,7 +74,7 @@ class CannotHandleCommand(Command):
74
74
 
75
75
  def to_dsl(self) -> str:
76
76
  """Converts the command to a DSL string."""
77
- return "cannot handle"
77
+ return "CannotHandle()"
78
78
 
79
79
  @classmethod
80
80
  def from_dsl(cls, match: re.Match, **kwargs: Any) -> CannotHandleCommand:
@@ -86,4 +86,4 @@ class CannotHandleCommand(Command):
86
86
 
87
87
  @staticmethod
88
88
  def regex_pattern() -> str:
89
- return r"^cannot handle$"
89
+ return r"CannotHandle\(\)"
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import re
4
5
  from dataclasses import dataclass
5
6
  from typing import Any, Dict, List
@@ -8,8 +9,11 @@ import structlog
8
9
 
9
10
  from rasa.dialogue_understanding.commands.command import Command
10
11
  from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
12
+ from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
11
13
  from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
12
- from rasa.dialogue_understanding.stack.frames import UserFlowStackFrame
14
+ from rasa.dialogue_understanding.stack.frames import (
15
+ UserFlowStackFrame,
16
+ )
13
17
  from rasa.dialogue_understanding.stack.frames.flow_stack_frame import FlowStackFrameType
14
18
  from rasa.dialogue_understanding.stack.utils import top_user_flow_frame
15
19
  from rasa.shared.core.events import Event, FlowCancelled
@@ -89,7 +93,8 @@ class CancelFlowCommand(Command):
89
93
  original_stack = original_tracker.stack
90
94
 
91
95
  applied_events: List[Event] = []
92
-
96
+ # capture the top frame before we push new frames onto the stack
97
+ initial_top_frame = stack.top()
93
98
  user_frame = top_user_flow_frame(original_stack)
94
99
  current_flow = user_frame.flow(all_flows) if user_frame else None
95
100
 
@@ -114,6 +119,21 @@ class CancelFlowCommand(Command):
114
119
  if user_frame:
115
120
  applied_events.append(FlowCancelled(user_frame.flow_id, user_frame.step_id))
116
121
 
122
+ if initial_top_frame and isinstance(
123
+ initial_top_frame, ClarifyPatternFlowStackFrame
124
+ ):
125
+ structlogger.debug(
126
+ "command_executor.cancel_flow.cancel_clarification_options",
127
+ clarification_options=initial_top_frame.clarification_options,
128
+ )
129
+ applied_events += cancel_all_pending_clarification_options(
130
+ initial_top_frame,
131
+ original_stack,
132
+ canceled_frames,
133
+ all_flows,
134
+ stack,
135
+ )
136
+
117
137
  return applied_events + tracker.create_stack_updated_events(stack)
118
138
 
119
139
  def __hash__(self) -> int:
@@ -124,7 +144,7 @@ class CancelFlowCommand(Command):
124
144
 
125
145
  def to_dsl(self) -> str:
126
146
  """Converts the command to a DSL string."""
127
- return "cancel"
147
+ return "CancelFlow()"
128
148
 
129
149
  @classmethod
130
150
  def from_dsl(cls, match: re.Match, **kwargs: Any) -> CancelFlowCommand:
@@ -133,4 +153,42 @@ class CancelFlowCommand(Command):
133
153
 
134
154
  @staticmethod
135
155
  def regex_pattern() -> str:
136
- return r"^cancel$"
156
+ return r"CancelFlow\(\)"
157
+
158
+
159
+ def cancel_all_pending_clarification_options(
160
+ initial_top_frame: ClarifyPatternFlowStackFrame,
161
+ original_stack: DialogueStack,
162
+ canceled_frames: List[str],
163
+ all_flows: FlowsList,
164
+ stack: DialogueStack,
165
+ ) -> List[FlowCancelled]:
166
+ """Cancel all pending clarification options.
167
+
168
+ This is a special case when the assistant asks the user to clarify
169
+ which pending digression flow to start after the completion of an active flow.
170
+ If the user chooses to cancel all options, this function takes care of
171
+ updating the stack by removing all pending flow stack frames
172
+ listed as clarification options.
173
+ """
174
+ clarification_names = set(initial_top_frame.names)
175
+ to_be_canceled_frames = []
176
+ applied_events = []
177
+ for frame in reversed(original_stack.frames):
178
+ if frame.frame_id in canceled_frames:
179
+ continue
180
+
181
+ to_be_canceled_frames.append(frame.frame_id)
182
+ if isinstance(frame, UserFlowStackFrame):
183
+ readable_flow_name = frame.flow(all_flows).readable_name()
184
+ if readable_flow_name in clarification_names:
185
+ stack.push(
186
+ CancelPatternFlowStackFrame(
187
+ canceled_name=readable_flow_name,
188
+ canceled_frames=copy.deepcopy(to_be_canceled_frames),
189
+ )
190
+ )
191
+ applied_events.append(FlowCancelled(frame.flow_id, frame.step_id))
192
+ to_be_canceled_frames.clear()
193
+
194
+ return applied_events
@@ -48,7 +48,7 @@ class ChangeFlowCommand(Command):
48
48
 
49
49
  def to_dsl(self) -> str:
50
50
  """Converts the command to a DSL string."""
51
- return "change"
51
+ return "ChangeFlow()"
52
52
 
53
53
  @staticmethod
54
54
  def from_dsl(match: re.Match, **kwargs: Any) -> ChangeFlowCommand:
@@ -57,4 +57,4 @@ class ChangeFlowCommand(Command):
57
57
 
58
58
  @staticmethod
59
59
  def regex_pattern() -> str:
60
- return r"^change"
60
+ return r"ChangeFlow\(\)"
@@ -59,7 +59,7 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
59
59
 
60
60
  def to_dsl(self) -> str:
61
61
  """Converts the command to a DSL string."""
62
- return "chat"
62
+ return "ChitChat()"
63
63
 
64
64
  @classmethod
65
65
  def from_dsl(cls, match: re.Match, **kwargs: Any) -> ChitChatAnswerCommand:
@@ -68,4 +68,4 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
68
68
 
69
69
  @staticmethod
70
70
  def regex_pattern() -> str:
71
- return r"^chat$"
71
+ return r"ChitChat\(\)"
@@ -89,7 +89,7 @@ class ClarifyCommand(Command):
89
89
 
90
90
  def to_dsl(self) -> str:
91
91
  """Converts the command to a DSL string."""
92
- return f"clarify {' '.join(self.options)}"
92
+ return f"Clarify({', '.join(self.options)})"
93
93
 
94
94
  @classmethod
95
95
  def from_dsl(cls, match: re.Match, **kwargs: Any) -> Optional[ClarifyCommand]:
@@ -99,4 +99,4 @@ class ClarifyCommand(Command):
99
99
 
100
100
  @staticmethod
101
101
  def regex_pattern() -> str:
102
- return r"^clarify([\"\'a-zA-Z0-9_, ]*)$"
102
+ return r"Clarify\(([\"\'a-zA-Z0-9_, ]*)\)"
@@ -31,6 +31,7 @@ class CorrectedSlot:
31
31
 
32
32
  name: str
33
33
  value: Any
34
+ filled_by: Optional[str] = None
34
35
 
35
36
 
36
37
  @dataclass
@@ -54,7 +55,9 @@ class CorrectSlotsCommand(Command):
54
55
  try:
55
56
  return CorrectSlotsCommand(
56
57
  corrected_slots=[
57
- CorrectedSlot(s["name"], value=s["value"])
58
+ CorrectedSlot(
59
+ s["name"], value=s["value"], filled_by=s.get("filled_by", None)
60
+ )
58
61
  for s in data["corrected_slots"]
59
62
  ]
60
63
  )
@@ -135,7 +138,10 @@ class CorrectSlotsCommand(Command):
135
138
  proposed_slots = {}
136
139
  for corrected_slot in self.corrected_slots:
137
140
  if tracker.get_slot(corrected_slot.name) != corrected_slot.value:
138
- proposed_slots[corrected_slot.name] = corrected_slot.value
141
+ proposed_slots[corrected_slot.name] = {
142
+ "value": corrected_slot.value,
143
+ "filled_by": corrected_slot.filled_by,
144
+ }
139
145
  else:
140
146
  structlogger.debug(
141
147
  "command_executor.skip_correction.slot_already_set", command=self
@@ -240,6 +246,9 @@ class CorrectSlotsCommand(Command):
240
246
  corrected_slots=proposed_slots,
241
247
  reset_flow_id=earliest_collect.flow_id if earliest_collect else None,
242
248
  reset_step_id=earliest_collect.step.id if earliest_collect else None,
249
+ new_slot_values=[
250
+ value.get("value") for slot, value in proposed_slots.items()
251
+ ],
243
252
  )
244
253
 
245
254
  def run_command_on_tracker(