rasa-pro 3.12.0.dev9__py3-none-any.whl → 3.12.0.dev11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (72) hide show
  1. rasa/cli/inspect.py +20 -1
  2. rasa/cli/shell.py +3 -3
  3. rasa/core/actions/action.py +20 -7
  4. rasa/core/actions/action_handle_digressions.py +142 -0
  5. rasa/core/actions/forms.py +10 -5
  6. rasa/core/channels/__init__.py +2 -0
  7. rasa/core/channels/voice_ready/audiocodes.py +42 -23
  8. rasa/core/channels/voice_stream/browser_audio.py +1 -0
  9. rasa/core/channels/voice_stream/call_state.py +7 -1
  10. rasa/core/channels/voice_stream/genesys.py +331 -0
  11. rasa/core/channels/voice_stream/tts/azure.py +2 -1
  12. rasa/core/channels/voice_stream/tts/cartesia.py +16 -3
  13. rasa/core/channels/voice_stream/twilio_media_streams.py +2 -1
  14. rasa/core/channels/voice_stream/voice_channel.py +2 -1
  15. rasa/core/migrate.py +2 -2
  16. rasa/core/policies/flows/flow_executor.py +36 -42
  17. rasa/core/run.py +4 -3
  18. rasa/dialogue_understanding/commands/can_not_handle_command.py +2 -2
  19. rasa/dialogue_understanding/commands/cancel_flow_command.py +62 -4
  20. rasa/dialogue_understanding/commands/change_flow_command.py +2 -2
  21. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +2 -2
  22. rasa/dialogue_understanding/commands/clarify_command.py +2 -2
  23. rasa/dialogue_understanding/commands/correct_slots_command.py +11 -2
  24. rasa/dialogue_understanding/commands/handle_digressions_command.py +150 -0
  25. rasa/dialogue_understanding/commands/human_handoff_command.py +2 -2
  26. rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
  27. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +2 -2
  28. rasa/dialogue_understanding/commands/set_slot_command.py +7 -15
  29. rasa/dialogue_understanding/commands/skip_question_command.py +2 -2
  30. rasa/dialogue_understanding/commands/start_flow_command.py +43 -2
  31. rasa/dialogue_understanding/commands/utils.py +1 -1
  32. rasa/dialogue_understanding/constants.py +1 -0
  33. rasa/dialogue_understanding/generator/command_generator.py +110 -73
  34. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  35. rasa/dialogue_understanding/generator/llm_based_command_generator.py +161 -3
  36. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -2
  37. rasa/dialogue_understanding/generator/nlu_command_adapter.py +44 -3
  38. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +40 -40
  39. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +11 -19
  40. rasa/dialogue_understanding/generator/utils.py +32 -1
  41. rasa/dialogue_understanding/patterns/correction.py +13 -1
  42. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +62 -2
  43. rasa/dialogue_understanding/patterns/handle_digressions.py +81 -0
  44. rasa/dialogue_understanding/processor/command_processor.py +115 -28
  45. rasa/dialogue_understanding/utils.py +31 -0
  46. rasa/dialogue_understanding_test/README.md +50 -0
  47. rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +3 -3
  48. rasa/model_service.py +4 -0
  49. rasa/model_training.py +24 -27
  50. rasa/shared/core/constants.py +28 -3
  51. rasa/shared/core/domain.py +13 -20
  52. rasa/shared/core/events.py +13 -2
  53. rasa/shared/core/flows/flow.py +17 -0
  54. rasa/shared/core/flows/flows_yaml_schema.json +38 -0
  55. rasa/shared/core/flows/steps/collect.py +18 -1
  56. rasa/shared/core/flows/utils.py +16 -1
  57. rasa/shared/core/slot_mappings.py +144 -108
  58. rasa/shared/core/slots.py +23 -2
  59. rasa/shared/core/trackers.py +3 -1
  60. rasa/shared/nlu/constants.py +1 -0
  61. rasa/shared/providers/llm/_base_litellm_client.py +0 -40
  62. rasa/shared/utils/llm.py +1 -86
  63. rasa/shared/utils/schemas/domain.yml +0 -1
  64. rasa/telemetry.py +43 -13
  65. rasa/utils/common.py +0 -1
  66. rasa/validator.py +189 -82
  67. rasa/version.py +1 -1
  68. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev11.dist-info}/METADATA +1 -1
  69. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev11.dist-info}/RECORD +72 -68
  70. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev11.dist-info}/NOTICE +0 -0
  71. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev11.dist-info}/WHEEL +0 -0
  72. {rasa_pro-3.12.0.dev9.dist-info → rasa_pro-3.12.0.dev11.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,331 @@
1
+ import asyncio
2
+ import json
3
+ from typing import Any, Awaitable, Callable, Dict, Optional, Text
4
+
5
+ import structlog
6
+ from sanic import ( # type: ignore[attr-defined]
7
+ Blueprint,
8
+ HTTPResponse,
9
+ Request,
10
+ Websocket,
11
+ response,
12
+ )
13
+
14
+ from rasa.core.channels import UserMessage
15
+ from rasa.core.channels.voice_ready.utils import CallParameters
16
+ from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
17
+ from rasa.core.channels.voice_stream.call_state import (
18
+ call_state,
19
+ )
20
+ from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine
21
+ from rasa.core.channels.voice_stream.voice_channel import (
22
+ ContinueConversationAction,
23
+ EndConversationAction,
24
+ NewAudioAction,
25
+ VoiceChannelAction,
26
+ VoiceInputChannel,
27
+ VoiceOutputChannel,
28
+ )
29
+
30
+ # Not mentioned in the documentation but observed in Genesys's example
31
+ # https://github.com/GenesysCloudBlueprints/audioconnector-server-reference-implementation
32
+ MAXIMUM_BINARY_MESSAGE_SIZE = 64000 # 64KB
33
+ logger = structlog.get_logger(__name__)
34
+
35
+
36
+ def map_call_params(data: Dict[Text, Any]) -> CallParameters:
37
+ """Map the Genesys open-message parameters to the CallParameters dataclass."""
38
+ parameters = data["parameters"]
39
+ participant = parameters["participant"]
40
+ # sent as {"ani": "tel:+491604697810"}
41
+ ani = participant.get("ani", "")
42
+ user_phone = ani.split(":")[-1] if ani else ""
43
+
44
+ return CallParameters(
45
+ call_id=parameters.get("conversationId", ""),
46
+ user_phone=user_phone,
47
+ bot_phone=participant.get("dnis", ""),
48
+ )
49
+
50
+
51
+ class GenesysOutputChannel(VoiceOutputChannel):
52
+ @classmethod
53
+ def name(cls) -> str:
54
+ return "genesys"
55
+
56
+ async def send_audio_bytes(
57
+ self, recipient_id: str, audio_bytes: RasaAudioBytes
58
+ ) -> None:
59
+ """
60
+ Send audio bytes to the recipient with buffering.
61
+
62
+ Genesys throws a rate limit error with too many audio messages.
63
+ To avoid this, we buffer the audio messages and send them in chunks.
64
+
65
+ - global.inbound.binary.average.rate.per.second: 5
66
+ The allowed average rate per second of inbound binary data
67
+
68
+ - global.inbound.binary.max: 25
69
+ The maximum number of inbound binary data messages
70
+ that can be sent instantaneously
71
+
72
+ https://developer.genesys.cloud/organization/organization/limits#audiohook
73
+ """
74
+ call_state.audio_buffer.extend(audio_bytes)
75
+
76
+ # If we receive a non-standard chunk size, assume it's the end of a sequence
77
+ # or buffer is more than 32KB (this is half of genesys's max audio message size)
78
+ if len(audio_bytes) != 1024 or len(call_state.audio_buffer) >= (
79
+ MAXIMUM_BINARY_MESSAGE_SIZE / 2
80
+ ):
81
+ # TODO: we should send the buffer when we receive a synthesis complete event
82
+ # from TTS. This will ensure that the last audio chunk is always sent.
83
+ await self._send_audio_buffer(self.voice_websocket)
84
+
85
+ async def _send_audio_buffer(self, ws: Websocket) -> None:
86
+ """Send the audio buffer to the recipient if it's not empty."""
87
+ if call_state.audio_buffer:
88
+ buffer_bytes = bytes(call_state.audio_buffer)
89
+ await self._send_bytes_to_ws(ws, buffer_bytes)
90
+ call_state.audio_buffer.clear()
91
+
92
+ async def _send_bytes_to_ws(self, ws: Websocket, data: bytes) -> None:
93
+ """Send audio bytes to the recipient as a binary websocket message."""
94
+ if len(data) <= MAXIMUM_BINARY_MESSAGE_SIZE:
95
+ await self.voice_websocket.send(data)
96
+ else:
97
+ # split the audio into chunks
98
+ current_position = 0
99
+ while current_position < len(data):
100
+ end_position = min(
101
+ current_position + MAXIMUM_BINARY_MESSAGE_SIZE, len(data)
102
+ )
103
+ await self.voice_websocket.send(data[current_position:end_position])
104
+ current_position = end_position
105
+
106
+ async def send_marker_message(self, recipient_id: str) -> None:
107
+ """Send a message that marks positions in the audio stream."""
108
+ pass
109
+
110
+
111
+ class GenesysInputChannel(VoiceInputChannel):
112
+ @classmethod
113
+ def name(cls) -> str:
114
+ return "genesys"
115
+
116
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
117
+ super().__init__(*args, **kwargs)
118
+
119
+ def _get_next_sequence(self) -> int:
120
+ """
121
+ Get the next message sequence number
122
+ Rasa == Server
123
+ Genesys == Client
124
+
125
+ Genesys requires the server and client each maintain a
126
+ monotonically increasing message sequence number.
127
+ """
128
+ cs = call_state
129
+ cs.server_sequence_number += 1 # type: ignore[attr-defined]
130
+ return cs.server_sequence_number
131
+
132
+ def _get_last_client_sequence(self) -> int:
133
+ """Get the last client(Genesys) sequence number."""
134
+ return call_state.client_sequence_number
135
+
136
+ def _update_client_sequence(self, seq: int) -> None:
137
+ """Update the client(Genesys) sequence number."""
138
+ if seq - call_state.client_sequence_number != 1:
139
+ logger.warning(
140
+ "genesys.update_client_sequence.sequence_gap",
141
+ received_seq=seq,
142
+ last_seq=call_state.client_sequence_number,
143
+ )
144
+ call_state.client_sequence_number = seq # type: ignore[attr-defined]
145
+
146
+ def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
147
+ return RasaAudioBytes(input_bytes)
148
+
149
+ async def collect_call_parameters(
150
+ self, channel_websocket: Websocket
151
+ ) -> Optional[CallParameters]:
152
+ """Call Parameters are collected during the open event."""
153
+ async for message in channel_websocket:
154
+ data = json.loads(message)
155
+ self._update_client_sequence(data["seq"])
156
+ if data.get("type") == "open":
157
+ call_params = await self.handle_open(channel_websocket, data)
158
+ return call_params
159
+ else:
160
+ logger.error("genesys.receive.unexpected_initial_message", message=data)
161
+
162
+ return None
163
+
164
+ def map_input_message(
165
+ self,
166
+ message: Any,
167
+ ws: Websocket,
168
+ ) -> VoiceChannelAction:
169
+ # if message is binary, it's audio
170
+ if isinstance(message, bytes):
171
+ return NewAudioAction(self.channel_bytes_to_rasa_audio_bytes(message))
172
+ else:
173
+ # process text message
174
+ data = json.loads(message)
175
+ self._update_client_sequence(data["seq"])
176
+ msg_type = data.get("type")
177
+ if msg_type == "close":
178
+ logger.info("genesys.handle_close", message=data)
179
+ self.handle_close(ws, data)
180
+ return EndConversationAction()
181
+ elif msg_type == "ping":
182
+ logger.info("genesys.handle_ping", message=data)
183
+ self.handle_ping(ws, data)
184
+ elif msg_type == "playback_started":
185
+ logger.debug("genesys.handle_playback_started", message=data)
186
+ call_state.is_bot_speaking = True # type: ignore[attr-defined]
187
+ elif msg_type == "playback_completed":
188
+ logger.debug("genesys.handle_playback_completed", message=data)
189
+ call_state.is_bot_speaking = False # type: ignore[attr-defined]
190
+ if call_state.should_hangup:
191
+ logger.info("genesys.hangup")
192
+ self.disconnect(ws, data)
193
+ elif msg_type == "dtmf":
194
+ logger.info("genesys.handle_dtmf", message=data)
195
+ elif msg_type == "error":
196
+ logger.warning("genesys.handle_error", message=data)
197
+ else:
198
+ logger.warning("genesys.map_input_message.unknown_type", message=data)
199
+
200
+ return ContinueConversationAction()
201
+
202
+ def create_output_channel(
203
+ self, voice_websocket: Websocket, tts_engine: TTSEngine
204
+ ) -> VoiceOutputChannel:
205
+ return GenesysOutputChannel(
206
+ voice_websocket,
207
+ tts_engine,
208
+ self.tts_cache,
209
+ )
210
+
211
+ async def handle_open(self, ws: Websocket, message: dict) -> CallParameters:
212
+ """Handle initial open transaction from Genesys."""
213
+ call_parameters = map_call_params(message)
214
+ params = message["parameters"]
215
+ media_options = params.get("media", [])
216
+
217
+ # Send opened response
218
+ if media_options:
219
+ logger.info("genesys.handle_open", media_parameter=media_options[0])
220
+ response = {
221
+ "version": "2",
222
+ "type": "opened",
223
+ "seq": self._get_next_sequence(),
224
+ "clientseq": self._get_last_client_sequence(),
225
+ "id": message.get("id"),
226
+ "parameters": {"startPaused": False, "media": [media_options[0]]},
227
+ }
228
+ logger.debug("genesys.handle_open.opened", response=response)
229
+ await ws.send(json.dumps(response))
230
+ else:
231
+ logger.warning(
232
+ "genesys.handle_open.no_media_formats", client_message=message
233
+ )
234
+ return call_parameters
235
+
236
+ def handle_ping(self, ws: Websocket, message: dict) -> None:
237
+ """Handle ping message from Genesys."""
238
+ response = {
239
+ "version": "2",
240
+ "type": "pong",
241
+ "seq": self._get_next_sequence(),
242
+ "clientseq": message.get("seq"),
243
+ "id": message.get("id"),
244
+ "parameters": {},
245
+ }
246
+ logger.debug("genesys.handle_ping.pong", response=response)
247
+ _schedule_ws_task(ws.send(json.dumps(response)))
248
+
249
+ def handle_close(self, ws: Websocket, message: dict) -> None:
250
+ """Handle close message from Genesys."""
251
+ response = {
252
+ "version": "2",
253
+ "type": "closed",
254
+ "seq": self._get_next_sequence(),
255
+ "clientseq": self._get_last_client_sequence(),
256
+ "id": message.get("id"),
257
+ "parameters": message.get("parameters", {}),
258
+ }
259
+ logger.debug("genesys.handle_close.closed", response=response)
260
+
261
+ _schedule_ws_task(ws.send(json.dumps(response)))
262
+ _schedule_ws_task(ws.close())
263
+
264
+ def disconnect(self, ws: Websocket, data: dict) -> None:
265
+ """
266
+ Send disconnect message to Genesys.
267
+
268
+ https://developer.genesys.cloud/devapps/audiohook/protocol-reference#disconnect
269
+ It should be used to hangup the call.
270
+ Genesys will respond with a "close" message to us
271
+ that is handled by the handle_close method.
272
+ """
273
+ message = {
274
+ "version": "2",
275
+ "type": "disconnect",
276
+ "seq": self._get_next_sequence(),
277
+ "clientseq": self._get_last_client_sequence(),
278
+ "id": data.get("id"),
279
+ "parameters": {
280
+ "reason": "completed",
281
+ # arbitrary values can be sent here
282
+ },
283
+ }
284
+ logger.debug("genesys.disconnect", message=message)
285
+ _schedule_ws_task(ws.send(json.dumps(message)))
286
+
287
+ def blueprint(
288
+ self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
289
+ ) -> Blueprint:
290
+ """Defines a Sanic blueprint for the voice input channel."""
291
+ blueprint = Blueprint("genesys", __name__)
292
+
293
+ @blueprint.route("/", methods=["GET"])
294
+ async def health(_: Request) -> HTTPResponse:
295
+ return response.json({"status": "ok"})
296
+
297
+ @blueprint.websocket("/websocket") # type: ignore[misc]
298
+ async def receive(request: Request, ws: Websocket) -> None:
299
+ logger.debug(
300
+ "genesys.receive",
301
+ audiohook_session_id=request.headers.get("audiohook-session-id"),
302
+ )
303
+ # validate required headers
304
+ required_headers = [
305
+ "audiohook-organization-id",
306
+ "audiohook-correlation-id",
307
+ "audiohook-session-id",
308
+ "x-api-key",
309
+ ]
310
+
311
+ for header in required_headers:
312
+ if header not in request.headers:
313
+ await ws.close(1008, f"Missing required header: {header}")
314
+ return
315
+
316
+ # TODO: validate API key header
317
+ # process audio streaming
318
+ logger.info("genesys.receive", message="Starting audio streaming")
319
+ await self.run_audio_streaming(on_new_message, ws)
320
+
321
+ return blueprint
322
+
323
+
324
+ def _schedule_ws_task(coro: Awaitable[Any]) -> None:
325
+ """Helper function to schedule a coroutine in the event loop.
326
+
327
+ Args:
328
+ coro: The coroutine to schedule
329
+ """
330
+ loop = asyncio.get_running_loop()
331
+ loop.call_soon_threadsafe(lambda: loop.create_task(coro))
@@ -81,7 +81,8 @@ class AzureTTS(TTSEngine[AzureTTSConfig]):
81
81
  @staticmethod
82
82
  def create_request_body(text: str, conf: AzureTTSConfig) -> str:
83
83
  return f"""
84
- <speak version='1.0' xml:lang='{conf.language}'>
84
+ <speak version='1.0' xml:lang='{conf.language}' xmlns:mstts='http://www.w3.org/2001/mstts'
85
+ xmlns='http://www.w3.org/2001/10/synthesis'>
85
86
  <voice xml:lang='{conf.language}' name='{conf.voice}'>
86
87
  {text}
87
88
  </voice>
@@ -1,3 +1,5 @@
1
+ import base64
2
+ import json
1
3
  import os
2
4
  from dataclasses import dataclass
3
5
  from typing import AsyncIterator, Dict, Optional
@@ -39,7 +41,7 @@ class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
39
41
  @staticmethod
40
42
  def get_tts_endpoint() -> str:
41
43
  """Create the endpoint string for cartesia."""
42
- return "https://api.cartesia.ai/tts/bytes"
44
+ return "https://api.cartesia.ai/tts/sse"
43
45
 
44
46
  @staticmethod
45
47
  def get_request_body(text: str, config: CartesiaTTSConfig) -> Dict:
@@ -85,8 +87,19 @@ class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
85
87
  url, headers=headers, json=payload, chunked=True
86
88
  ) as response:
87
89
  if 200 <= response.status < 300:
88
- async for data in response.content.iter_chunked(1024):
89
- yield self.engine_bytes_to_rasa_audio_bytes(data)
90
+ async for chunk in response.content:
91
+ # we are looking for chunks in the response that look like
92
+ # b"data: {..., data: <base64 encoded audio bytes> ...}"
93
+ # and extract the audio bytes from that
94
+ if chunk.startswith(b"data: "):
95
+ json_bytes = chunk[5:-1]
96
+ json_data = json.loads(json_bytes.decode())
97
+ if "data" in json_data:
98
+ base64_encoded_bytes = json_data["data"]
99
+ channel_bytes = base64.b64decode(base64_encoded_bytes)
100
+ yield self.engine_bytes_to_rasa_audio_bytes(
101
+ channel_bytes
102
+ )
90
103
  return
91
104
  else:
92
105
  structlogger.error(
@@ -98,6 +98,7 @@ class TwilioMediaStreamsInputChannel(VoiceInputChannel):
98
98
  def map_input_message(
99
99
  self,
100
100
  message: Any,
101
+ ws: Websocket,
101
102
  ) -> VoiceChannelAction:
102
103
  data = json.loads(message)
103
104
  if data["event"] == "media":
@@ -142,7 +143,7 @@ class TwilioMediaStreamsInputChannel(VoiceInputChannel):
142
143
  def blueprint(
143
144
  self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
144
145
  ) -> Blueprint:
145
- """Defines a Sanic bluelogger.debug."""
146
+ """Defines a Sanic blueprint for the voice input channel."""
146
147
  blueprint = Blueprint("twilio_media_streams", __name__)
147
148
 
148
149
  @blueprint.route("/", methods=["GET"])
@@ -315,6 +315,7 @@ class VoiceInputChannel(InputChannel):
315
315
  def map_input_message(
316
316
  self,
317
317
  message: Any,
318
+ ws: Websocket,
318
319
  ) -> VoiceChannelAction:
319
320
  """Map a channel input message to a voice channel action."""
320
321
  raise NotImplementedError
@@ -340,7 +341,7 @@ class VoiceInputChannel(InputChannel):
340
341
  async def consume_audio_bytes() -> None:
341
342
  async for message in channel_websocket:
342
343
  is_bot_speaking_before = call_state.is_bot_speaking
343
- channel_action = self.map_input_message(message)
344
+ channel_action = self.map_input_message(message, channel_websocket)
344
345
  is_bot_speaking_after = call_state.is_bot_speaking
345
346
 
346
347
  if not is_bot_speaking_before and is_bot_speaking_after:
rasa/core/migrate.py CHANGED
@@ -14,7 +14,7 @@ from rasa.shared.constants import (
14
14
  )
15
15
  from rasa.shared.core.constants import (
16
16
  ACTIVE_LOOP,
17
- MAPPING_TYPE,
17
+ KEY_MAPPING_TYPE,
18
18
  REQUESTED_SLOT,
19
19
  SLOT_MAPPINGS,
20
20
  SlotMappingType,
@@ -43,7 +43,7 @@ def _create_back_up(domain_file: Path, backup_location: Path) -> Dict[Text, Any]
43
43
  def _get_updated_mapping_condition(
44
44
  condition: Dict[Text, Text], mapping: Dict[Text, Any], slot_name: Text
45
45
  ) -> Dict[Text, Text]:
46
- if mapping.get(MAPPING_TYPE) not in [
46
+ if mapping.get(KEY_MAPPING_TYPE) not in [
47
47
  str(SlotMappingType.FROM_ENTITY),
48
48
  str(SlotMappingType.FROM_TRIGGER_INTENT),
49
49
  ]:
@@ -23,6 +23,7 @@ from rasa.core.policies.flows.flow_step_result import (
23
23
  )
24
24
  from rasa.dialogue_understanding.commands import CancelFlowCommand
25
25
  from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
26
+ from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
26
27
  from rasa.dialogue_understanding.patterns.collect_information import (
27
28
  CollectInformationPatternFlowStackFrame,
28
29
  )
@@ -50,9 +51,12 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
50
51
  )
51
52
  from rasa.dialogue_understanding.stack.utils import (
52
53
  top_user_flow_frame,
54
+ user_flows_on_the_stack,
53
55
  )
54
56
  from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
55
- from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
57
+ from rasa.shared.core.constants import (
58
+ ACTION_LISTEN_NAME,
59
+ )
56
60
  from rasa.shared.core.events import (
57
61
  Event,
58
62
  FlowCompleted,
@@ -272,6 +276,28 @@ def trigger_pattern_continue_interrupted(
272
276
  return events
273
277
 
274
278
 
279
+ def trigger_pattern_clarification(
280
+ current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
281
+ ) -> None:
282
+ """Trigger the pattern to clarify which topic to continue if needed."""
283
+ if not isinstance(current_frame, UserFlowStackFrame):
284
+ return None
285
+
286
+ if current_frame.frame_type == FlowStackFrameType.CALL:
287
+ # we want to return to the flow that called the current flow
288
+ return None
289
+
290
+ pending_flows = [
291
+ flows.flow_by_id(frame.flow_id)
292
+ for frame in stack.frames
293
+ if isinstance(frame, UserFlowStackFrame)
294
+ and frame.flow_id != current_frame.flow_id
295
+ ]
296
+
297
+ flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
298
+ stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
299
+
300
+
275
301
  def trigger_pattern_completed(
276
302
  current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
277
303
  ) -> None:
@@ -540,38 +566,6 @@ def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) ->
540
566
  stack.push(InternalErrorPatternFlowStackFrame())
541
567
 
542
568
 
543
- def validate_custom_slot_mappings(
544
- step: CollectInformationFlowStep,
545
- stack: DialogueStack,
546
- tracker: DialogueStateTracker,
547
- available_actions: List[str],
548
- flow_name: str,
549
- ) -> bool:
550
- """Validate a slot with custom mappings.
551
-
552
- If invalid, trigger pattern_internal_error and return False.
553
- """
554
- slot = tracker.slots.get(step.collect, None)
555
- slot_mappings = slot.mappings if slot else []
556
- for mapping in slot_mappings:
557
- if (
558
- mapping.get("type") == SlotMappingType.CUSTOM.value
559
- and mapping.get("action") is None
560
- ):
561
- # this is a slot that must be filled by a custom action
562
- # check if collect_action exists
563
- if step.collect_action not in available_actions:
564
- structlogger.error(
565
- "flow.step.run.collect_action_not_found_for_custom_slot_mapping",
566
- action=step.collect_action,
567
- collect=step.collect,
568
- )
569
- cancel_flow_and_push_internal_error(stack, flow_name)
570
- return False
571
-
572
- return True
573
-
574
-
575
569
  def attach_stack_metadata_to_events(
576
570
  step_id: str,
577
571
  flow_id: str,
@@ -669,7 +663,15 @@ def _run_end_step(
669
663
  structlogger.debug("flow.step.run.flow_end")
670
664
  current_frame = stack.pop()
671
665
  trigger_pattern_completed(current_frame, stack, flows)
672
- resumed_events = trigger_pattern_continue_interrupted(current_frame, stack, flows)
666
+ resumed_events = []
667
+ if len(user_flows_on_the_stack(stack)) > 1:
668
+ # if there are more user flows on the stack,
669
+ # we need to trigger the pattern clarify
670
+ trigger_pattern_clarification(current_frame, stack, flows)
671
+ else:
672
+ resumed_events = trigger_pattern_continue_interrupted(
673
+ current_frame, stack, flows
674
+ )
673
675
  reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
674
676
  return ContinueFlowWithNextStep(
675
677
  events=initial_events + reset_events + resumed_events, has_flow_ended=True
@@ -760,14 +762,6 @@ def _run_collect_information_step(
760
762
  # if we return any other FlowStepResult, the assistant will stay silent
761
763
  # instead of triggering the internal error pattern
762
764
  return ContinueFlowWithNextStep(events=initial_events)
763
- is_mapping_valid = validate_custom_slot_mappings(
764
- step, stack, tracker, available_actions, flow_name
765
- )
766
-
767
- if not is_mapping_valid:
768
- # if we return any other FlowStepResult, the assistant will stay silent
769
- # instead of triggering the internal error pattern
770
- return ContinueFlowWithNextStep(events=initial_events)
771
765
 
772
766
  structlogger.debug("flow.step.run.collect")
773
767
  trigger_pattern_ask_collect_information(
rasa/core/run.py CHANGED
@@ -283,9 +283,10 @@ def serve_application(
283
283
  endpoints.lock_store if endpoints else None
284
284
  )
285
285
 
286
- telemetry.track_server_start(
287
- input_channels, endpoints, model_path, number_of_workers, enable_api
288
- )
286
+ if not inspect:
287
+ telemetry.track_server_start(
288
+ input_channels, endpoints, model_path, number_of_workers, enable_api
289
+ )
289
290
 
290
291
  rasa.utils.common.update_sanic_log_level(
291
292
  log_file, use_syslog, syslog_address, syslog_port, syslog_protocol
@@ -74,7 +74,7 @@ class CannotHandleCommand(Command):
74
74
 
75
75
  def to_dsl(self) -> str:
76
76
  """Converts the command to a DSL string."""
77
- return "cannot handle"
77
+ return "CannotHandle()"
78
78
 
79
79
  @classmethod
80
80
  def from_dsl(cls, match: re.Match, **kwargs: Any) -> CannotHandleCommand:
@@ -86,4 +86,4 @@ class CannotHandleCommand(Command):
86
86
 
87
87
  @staticmethod
88
88
  def regex_pattern() -> str:
89
- return r"^cannot handle$"
89
+ return r"CannotHandle\(\)"
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import re
4
5
  from dataclasses import dataclass
5
6
  from typing import Any, Dict, List
@@ -8,8 +9,11 @@ import structlog
8
9
 
9
10
  from rasa.dialogue_understanding.commands.command import Command
10
11
  from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
12
+ from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
11
13
  from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
12
- from rasa.dialogue_understanding.stack.frames import UserFlowStackFrame
14
+ from rasa.dialogue_understanding.stack.frames import (
15
+ UserFlowStackFrame,
16
+ )
13
17
  from rasa.dialogue_understanding.stack.frames.flow_stack_frame import FlowStackFrameType
14
18
  from rasa.dialogue_understanding.stack.utils import top_user_flow_frame
15
19
  from rasa.shared.core.events import Event, FlowCancelled
@@ -89,7 +93,8 @@ class CancelFlowCommand(Command):
89
93
  original_stack = original_tracker.stack
90
94
 
91
95
  applied_events: List[Event] = []
92
-
96
+ # capture the top frame before we push new frames onto the stack
97
+ initial_top_frame = stack.top()
93
98
  user_frame = top_user_flow_frame(original_stack)
94
99
  current_flow = user_frame.flow(all_flows) if user_frame else None
95
100
 
@@ -114,6 +119,21 @@ class CancelFlowCommand(Command):
114
119
  if user_frame:
115
120
  applied_events.append(FlowCancelled(user_frame.flow_id, user_frame.step_id))
116
121
 
122
+ if initial_top_frame and isinstance(
123
+ initial_top_frame, ClarifyPatternFlowStackFrame
124
+ ):
125
+ structlogger.debug(
126
+ "command_executor.cancel_flow.cancel_clarification_options",
127
+ clarification_options=initial_top_frame.clarification_options,
128
+ )
129
+ applied_events += cancel_all_pending_clarification_options(
130
+ initial_top_frame,
131
+ original_stack,
132
+ canceled_frames,
133
+ all_flows,
134
+ stack,
135
+ )
136
+
117
137
  return applied_events + tracker.create_stack_updated_events(stack)
118
138
 
119
139
  def __hash__(self) -> int:
@@ -124,7 +144,7 @@ class CancelFlowCommand(Command):
124
144
 
125
145
  def to_dsl(self) -> str:
126
146
  """Converts the command to a DSL string."""
127
- return "cancel"
147
+ return "CancelFlow()"
128
148
 
129
149
  @classmethod
130
150
  def from_dsl(cls, match: re.Match, **kwargs: Any) -> CancelFlowCommand:
@@ -133,4 +153,42 @@ class CancelFlowCommand(Command):
133
153
 
134
154
  @staticmethod
135
155
  def regex_pattern() -> str:
136
- return r"^cancel$"
156
+ return r"CancelFlow\(\)"
157
+
158
+
159
+ def cancel_all_pending_clarification_options(
160
+ initial_top_frame: ClarifyPatternFlowStackFrame,
161
+ original_stack: DialogueStack,
162
+ canceled_frames: List[str],
163
+ all_flows: FlowsList,
164
+ stack: DialogueStack,
165
+ ) -> List[FlowCancelled]:
166
+ """Cancel all pending clarification options.
167
+
168
+ This is a special case when the assistant asks the user to clarify
169
+ which pending digression flow to start after the completion of an active flow.
170
+ If the user chooses to cancel all options, this function takes care of
171
+ updating the stack by removing all pending flow stack frames
172
+ listed as clarification options.
173
+ """
174
+ clarification_names = set(initial_top_frame.names)
175
+ to_be_canceled_frames = []
176
+ applied_events = []
177
+ for frame in reversed(original_stack.frames):
178
+ if frame.frame_id in canceled_frames:
179
+ continue
180
+
181
+ to_be_canceled_frames.append(frame.frame_id)
182
+ if isinstance(frame, UserFlowStackFrame):
183
+ readable_flow_name = frame.flow(all_flows).readable_name()
184
+ if readable_flow_name in clarification_names:
185
+ stack.push(
186
+ CancelPatternFlowStackFrame(
187
+ canceled_name=readable_flow_name,
188
+ canceled_frames=copy.deepcopy(to_be_canceled_frames),
189
+ )
190
+ )
191
+ applied_events.append(FlowCancelled(frame.flow_id, frame.step_id))
192
+ to_be_canceled_frames.clear()
193
+
194
+ return applied_events