rasa-pro 3.12.0.dev11__py3-none-any.whl → 3.12.0.dev12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +1 -20
- rasa/cli/shell.py +3 -3
- rasa/core/actions/action.py +7 -20
- rasa/core/actions/forms.py +5 -10
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/voice_ready/audiocodes.py +23 -42
- rasa/core/channels/voice_stream/browser_audio.py +0 -1
- rasa/core/channels/voice_stream/call_state.py +1 -7
- rasa/core/channels/voice_stream/tts/azure.py +1 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +3 -16
- rasa/core/channels/voice_stream/twilio_media_streams.py +1 -2
- rasa/core/channels/voice_stream/voice_channel.py +1 -2
- rasa/core/migrate.py +2 -2
- rasa/core/policies/flows/flow_executor.py +42 -36
- rasa/core/run.py +3 -4
- rasa/dialogue_understanding/commands/can_not_handle_command.py +2 -2
- rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -62
- rasa/dialogue_understanding/commands/change_flow_command.py +2 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +2 -2
- rasa/dialogue_understanding/commands/clarify_command.py +2 -2
- rasa/dialogue_understanding/commands/correct_slots_command.py +2 -11
- rasa/dialogue_understanding/commands/human_handoff_command.py +2 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +2 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +15 -7
- rasa/dialogue_understanding/commands/skip_question_command.py +2 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +2 -43
- rasa/dialogue_understanding/commands/utils.py +1 -1
- rasa/dialogue_understanding/constants.py +0 -1
- rasa/dialogue_understanding/generator/command_generator.py +73 -110
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +3 -161
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +2 -10
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +3 -44
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +79 -53
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +19 -11
- rasa/dialogue_understanding/generator/utils.py +1 -32
- rasa/dialogue_understanding/patterns/correction.py +1 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +2 -62
- rasa/dialogue_understanding/processor/command_processor.py +28 -115
- rasa/dialogue_understanding/utils.py +0 -31
- rasa/dialogue_understanding_test/README.md +0 -50
- rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +3 -3
- rasa/model_service.py +0 -4
- rasa/model_training.py +27 -24
- rasa/shared/core/constants.py +3 -28
- rasa/shared/core/domain.py +20 -13
- rasa/shared/core/events.py +2 -13
- rasa/shared/core/flows/flow.py +0 -17
- rasa/shared/core/flows/flows_yaml_schema.json +0 -38
- rasa/shared/core/flows/steps/collect.py +1 -18
- rasa/shared/core/flows/utils.py +1 -16
- rasa/shared/core/slot_mappings.py +108 -144
- rasa/shared/core/slots.py +2 -23
- rasa/shared/core/trackers.py +1 -3
- rasa/shared/nlu/constants.py +0 -1
- rasa/shared/utils/llm.py +1 -1
- rasa/shared/utils/schemas/domain.yml +1 -0
- rasa/telemetry.py +13 -43
- rasa/utils/common.py +1 -0
- rasa/validator.py +82 -189
- rasa/version.py +1 -1
- {rasa_pro-3.12.0.dev11.dist-info → rasa_pro-3.12.0.dev12.dist-info}/METADATA +1 -1
- {rasa_pro-3.12.0.dev11.dist-info → rasa_pro-3.12.0.dev12.dist-info}/RECORD +67 -71
- rasa/core/actions/action_handle_digressions.py +0 -142
- rasa/core/channels/voice_stream/genesys.py +0 -331
- rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -150
- rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
- {rasa_pro-3.12.0.dev11.dist-info → rasa_pro-3.12.0.dev12.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0.dev11.dist-info → rasa_pro-3.12.0.dev12.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0.dev11.dist-info → rasa_pro-3.12.0.dev12.dist-info}/entry_points.txt +0 -0
|
@@ -1,331 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import json
|
|
3
|
-
from typing import Any, Awaitable, Callable, Dict, Optional, Text
|
|
4
|
-
|
|
5
|
-
import structlog
|
|
6
|
-
from sanic import ( # type: ignore[attr-defined]
|
|
7
|
-
Blueprint,
|
|
8
|
-
HTTPResponse,
|
|
9
|
-
Request,
|
|
10
|
-
Websocket,
|
|
11
|
-
response,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
from rasa.core.channels import UserMessage
|
|
15
|
-
from rasa.core.channels.voice_ready.utils import CallParameters
|
|
16
|
-
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
17
|
-
from rasa.core.channels.voice_stream.call_state import (
|
|
18
|
-
call_state,
|
|
19
|
-
)
|
|
20
|
-
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine
|
|
21
|
-
from rasa.core.channels.voice_stream.voice_channel import (
|
|
22
|
-
ContinueConversationAction,
|
|
23
|
-
EndConversationAction,
|
|
24
|
-
NewAudioAction,
|
|
25
|
-
VoiceChannelAction,
|
|
26
|
-
VoiceInputChannel,
|
|
27
|
-
VoiceOutputChannel,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
# Not mentioned in the documentation but observed in Geneys's example
|
|
31
|
-
# https://github.com/GenesysCloudBlueprints/audioconnector-server-reference-implementation
|
|
32
|
-
MAXIMUM_BINARY_MESSAGE_SIZE = 64000 # 64KB
|
|
33
|
-
logger = structlog.get_logger(__name__)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def map_call_params(data: Dict[Text, Any]) -> CallParameters:
|
|
37
|
-
"""Map the twilio stream parameters to the CallParameters dataclass."""
|
|
38
|
-
parameters = data["parameters"]
|
|
39
|
-
participant = parameters["participant"]
|
|
40
|
-
# sent as {"ani": "tel:+491604697810"}
|
|
41
|
-
ani = participant.get("ani", "")
|
|
42
|
-
user_phone = ani.split(":")[-1] if ani else ""
|
|
43
|
-
|
|
44
|
-
return CallParameters(
|
|
45
|
-
call_id=parameters.get("conversationId", ""),
|
|
46
|
-
user_phone=user_phone,
|
|
47
|
-
bot_phone=participant.get("dnis", ""),
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class GenesysOutputChannel(VoiceOutputChannel):
|
|
52
|
-
@classmethod
|
|
53
|
-
def name(cls) -> str:
|
|
54
|
-
return "genesys"
|
|
55
|
-
|
|
56
|
-
async def send_audio_bytes(
|
|
57
|
-
self, recipient_id: str, audio_bytes: RasaAudioBytes
|
|
58
|
-
) -> None:
|
|
59
|
-
"""
|
|
60
|
-
Send audio bytes to the recipient with buffering.
|
|
61
|
-
|
|
62
|
-
Genesys throws a rate limit error with too many audio messages.
|
|
63
|
-
To avoid this, we buffer the audio messages and send them in chunks.
|
|
64
|
-
|
|
65
|
-
- global.inbound.binary.average.rate.per.second: 5
|
|
66
|
-
The allowed average rate per second of inbound binary data
|
|
67
|
-
|
|
68
|
-
- global.inbound.binary.max: 25
|
|
69
|
-
The maximum number of inbound binary data messages
|
|
70
|
-
that can be sent instantaneously
|
|
71
|
-
|
|
72
|
-
https://developer.genesys.cloud/organization/organization/limits#audiohook
|
|
73
|
-
"""
|
|
74
|
-
call_state.audio_buffer.extend(audio_bytes)
|
|
75
|
-
|
|
76
|
-
# If we receive a non-standard chunk size, assume it's the end of a sequence
|
|
77
|
-
# or buffer is more than 32KB (this is half of genesys's max audio message size)
|
|
78
|
-
if len(audio_bytes) != 1024 or len(call_state.audio_buffer) >= (
|
|
79
|
-
MAXIMUM_BINARY_MESSAGE_SIZE / 2
|
|
80
|
-
):
|
|
81
|
-
# TODO: we should send the buffer when we receive a synthesis complete event
|
|
82
|
-
# from TTS. This will ensure that the last audio chunk is always sent.
|
|
83
|
-
await self._send_audio_buffer(self.voice_websocket)
|
|
84
|
-
|
|
85
|
-
async def _send_audio_buffer(self, ws: Websocket) -> None:
|
|
86
|
-
"""Send the audio buffer to the recipient if it's not empty."""
|
|
87
|
-
if call_state.audio_buffer:
|
|
88
|
-
buffer_bytes = bytes(call_state.audio_buffer)
|
|
89
|
-
await self._send_bytes_to_ws(ws, buffer_bytes)
|
|
90
|
-
call_state.audio_buffer.clear()
|
|
91
|
-
|
|
92
|
-
async def _send_bytes_to_ws(self, ws: Websocket, data: bytes) -> None:
|
|
93
|
-
"""Send audio bytes to the recipient as a binary websocket message."""
|
|
94
|
-
if len(data) <= MAXIMUM_BINARY_MESSAGE_SIZE:
|
|
95
|
-
await self.voice_websocket.send(data)
|
|
96
|
-
else:
|
|
97
|
-
# split the audio into chunks
|
|
98
|
-
current_position = 0
|
|
99
|
-
while current_position < len(data):
|
|
100
|
-
end_position = min(
|
|
101
|
-
current_position + MAXIMUM_BINARY_MESSAGE_SIZE, len(data)
|
|
102
|
-
)
|
|
103
|
-
await self.voice_websocket.send(data[current_position:end_position])
|
|
104
|
-
current_position = end_position
|
|
105
|
-
|
|
106
|
-
async def send_marker_message(self, recipient_id: str) -> None:
|
|
107
|
-
"""Send a message that marks positions in the audio stream."""
|
|
108
|
-
pass
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
class GenesysInputChannel(VoiceInputChannel):
|
|
112
|
-
@classmethod
|
|
113
|
-
def name(cls) -> str:
|
|
114
|
-
return "genesys"
|
|
115
|
-
|
|
116
|
-
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
117
|
-
super().__init__(*args, **kwargs)
|
|
118
|
-
|
|
119
|
-
def _get_next_sequence(self) -> int:
|
|
120
|
-
"""
|
|
121
|
-
Get the next message sequence number
|
|
122
|
-
Rasa == Server
|
|
123
|
-
Genesys == Client
|
|
124
|
-
|
|
125
|
-
Genesys requires the server and client each maintain a
|
|
126
|
-
monotonically increasing message sequence number.
|
|
127
|
-
"""
|
|
128
|
-
cs = call_state
|
|
129
|
-
cs.server_sequence_number += 1 # type: ignore[attr-defined]
|
|
130
|
-
return cs.server_sequence_number
|
|
131
|
-
|
|
132
|
-
def _get_last_client_sequence(self) -> int:
|
|
133
|
-
"""Get the last client(Genesys) sequence number."""
|
|
134
|
-
return call_state.client_sequence_number
|
|
135
|
-
|
|
136
|
-
def _update_client_sequence(self, seq: int) -> None:
|
|
137
|
-
"""Update the client(Genesys) sequence number."""
|
|
138
|
-
if seq - call_state.client_sequence_number != 1:
|
|
139
|
-
logger.warning(
|
|
140
|
-
"genesys.update_client_sequence.sequence_gap",
|
|
141
|
-
received_seq=seq,
|
|
142
|
-
last_seq=call_state.client_sequence_number,
|
|
143
|
-
)
|
|
144
|
-
call_state.client_sequence_number = seq # type: ignore[attr-defined]
|
|
145
|
-
|
|
146
|
-
def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
|
|
147
|
-
return RasaAudioBytes(input_bytes)
|
|
148
|
-
|
|
149
|
-
async def collect_call_parameters(
|
|
150
|
-
self, channel_websocket: Websocket
|
|
151
|
-
) -> Optional[CallParameters]:
|
|
152
|
-
"""Call Parameters are collected during the open event."""
|
|
153
|
-
async for message in channel_websocket:
|
|
154
|
-
data = json.loads(message)
|
|
155
|
-
self._update_client_sequence(data["seq"])
|
|
156
|
-
if data.get("type") == "open":
|
|
157
|
-
call_params = await self.handle_open(channel_websocket, data)
|
|
158
|
-
return call_params
|
|
159
|
-
else:
|
|
160
|
-
logger.error("genesys.receive.unexpected_initial_message", message=data)
|
|
161
|
-
|
|
162
|
-
return None
|
|
163
|
-
|
|
164
|
-
def map_input_message(
|
|
165
|
-
self,
|
|
166
|
-
message: Any,
|
|
167
|
-
ws: Websocket,
|
|
168
|
-
) -> VoiceChannelAction:
|
|
169
|
-
# if message is binary, it's audio
|
|
170
|
-
if isinstance(message, bytes):
|
|
171
|
-
return NewAudioAction(self.channel_bytes_to_rasa_audio_bytes(message))
|
|
172
|
-
else:
|
|
173
|
-
# process text message
|
|
174
|
-
data = json.loads(message)
|
|
175
|
-
self._update_client_sequence(data["seq"])
|
|
176
|
-
msg_type = data.get("type")
|
|
177
|
-
if msg_type == "close":
|
|
178
|
-
logger.info("genesys.handle_close", message=data)
|
|
179
|
-
self.handle_close(ws, data)
|
|
180
|
-
return EndConversationAction()
|
|
181
|
-
elif msg_type == "ping":
|
|
182
|
-
logger.info("genesys.handle_ping", message=data)
|
|
183
|
-
self.handle_ping(ws, data)
|
|
184
|
-
elif msg_type == "playback_started":
|
|
185
|
-
logger.debug("genesys.handle_playback_started", message=data)
|
|
186
|
-
call_state.is_bot_speaking = True # type: ignore[attr-defined]
|
|
187
|
-
elif msg_type == "playback_completed":
|
|
188
|
-
logger.debug("genesys.handle_playback_completed", message=data)
|
|
189
|
-
call_state.is_bot_speaking = False # type: ignore[attr-defined]
|
|
190
|
-
if call_state.should_hangup:
|
|
191
|
-
logger.info("genesys.hangup")
|
|
192
|
-
self.disconnect(ws, data)
|
|
193
|
-
elif msg_type == "dtmf":
|
|
194
|
-
logger.info("genesys.handle_dtmf", message=data)
|
|
195
|
-
elif msg_type == "error":
|
|
196
|
-
logger.warning("genesys.handle_error", message=data)
|
|
197
|
-
else:
|
|
198
|
-
logger.warning("genesys.map_input_message.unknown_type", message=data)
|
|
199
|
-
|
|
200
|
-
return ContinueConversationAction()
|
|
201
|
-
|
|
202
|
-
def create_output_channel(
|
|
203
|
-
self, voice_websocket: Websocket, tts_engine: TTSEngine
|
|
204
|
-
) -> VoiceOutputChannel:
|
|
205
|
-
return GenesysOutputChannel(
|
|
206
|
-
voice_websocket,
|
|
207
|
-
tts_engine,
|
|
208
|
-
self.tts_cache,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
async def handle_open(self, ws: Websocket, message: dict) -> CallParameters:
|
|
212
|
-
"""Handle initial open transaction from Genesys."""
|
|
213
|
-
call_parameters = map_call_params(message)
|
|
214
|
-
params = message["parameters"]
|
|
215
|
-
media_options = params.get("media", [])
|
|
216
|
-
|
|
217
|
-
# Send opened response
|
|
218
|
-
if media_options:
|
|
219
|
-
logger.info("genesys.handle_open", media_parameter=media_options[0])
|
|
220
|
-
response = {
|
|
221
|
-
"version": "2",
|
|
222
|
-
"type": "opened",
|
|
223
|
-
"seq": self._get_next_sequence(),
|
|
224
|
-
"clientseq": self._get_last_client_sequence(),
|
|
225
|
-
"id": message.get("id"),
|
|
226
|
-
"parameters": {"startPaused": False, "media": [media_options[0]]},
|
|
227
|
-
}
|
|
228
|
-
logger.debug("genesys.handle_open.opened", response=response)
|
|
229
|
-
await ws.send(json.dumps(response))
|
|
230
|
-
else:
|
|
231
|
-
logger.warning(
|
|
232
|
-
"genesys.handle_open.no_media_formats", client_message=message
|
|
233
|
-
)
|
|
234
|
-
return call_parameters
|
|
235
|
-
|
|
236
|
-
def handle_ping(self, ws: Websocket, message: dict) -> None:
|
|
237
|
-
"""Handle ping message from Genesys."""
|
|
238
|
-
response = {
|
|
239
|
-
"version": "2",
|
|
240
|
-
"type": "pong",
|
|
241
|
-
"seq": self._get_next_sequence(),
|
|
242
|
-
"clientseq": message.get("seq"),
|
|
243
|
-
"id": message.get("id"),
|
|
244
|
-
"parameters": {},
|
|
245
|
-
}
|
|
246
|
-
logger.debug("genesys.handle_ping.pong", response=response)
|
|
247
|
-
_schedule_ws_task(ws.send(json.dumps(response)))
|
|
248
|
-
|
|
249
|
-
def handle_close(self, ws: Websocket, message: dict) -> None:
|
|
250
|
-
"""Handle close message from Genesys."""
|
|
251
|
-
response = {
|
|
252
|
-
"version": "2",
|
|
253
|
-
"type": "closed",
|
|
254
|
-
"seq": self._get_next_sequence(),
|
|
255
|
-
"clientseq": self._get_last_client_sequence(),
|
|
256
|
-
"id": message.get("id"),
|
|
257
|
-
"parameters": message.get("parameters", {}),
|
|
258
|
-
}
|
|
259
|
-
logger.debug("genesys.handle_close.closed", response=response)
|
|
260
|
-
|
|
261
|
-
_schedule_ws_task(ws.send(json.dumps(response)))
|
|
262
|
-
_schedule_ws_task(ws.close())
|
|
263
|
-
|
|
264
|
-
def disconnect(self, ws: Websocket, data: dict) -> None:
|
|
265
|
-
"""
|
|
266
|
-
Send disconnect message to Genesys.
|
|
267
|
-
|
|
268
|
-
https://developer.genesys.cloud/devapps/audiohook/protocol-reference#disconnect
|
|
269
|
-
It should be used to hangup the call.
|
|
270
|
-
Genesys will respond with a "close" message to us
|
|
271
|
-
that is handled by the handle_close method.
|
|
272
|
-
"""
|
|
273
|
-
message = {
|
|
274
|
-
"version": "2",
|
|
275
|
-
"type": "disconnect",
|
|
276
|
-
"seq": self._get_next_sequence(),
|
|
277
|
-
"clientseq": self._get_last_client_sequence(),
|
|
278
|
-
"id": data.get("id"),
|
|
279
|
-
"parameters": {
|
|
280
|
-
"reason": "completed",
|
|
281
|
-
# arbitrary values can be sent here
|
|
282
|
-
},
|
|
283
|
-
}
|
|
284
|
-
logger.debug("genesys.disconnect", message=message)
|
|
285
|
-
_schedule_ws_task(ws.send(json.dumps(message)))
|
|
286
|
-
|
|
287
|
-
def blueprint(
|
|
288
|
-
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
|
|
289
|
-
) -> Blueprint:
|
|
290
|
-
"""Defines a Sanic blueprint for the voice input channel."""
|
|
291
|
-
blueprint = Blueprint("genesys", __name__)
|
|
292
|
-
|
|
293
|
-
@blueprint.route("/", methods=["GET"])
|
|
294
|
-
async def health(_: Request) -> HTTPResponse:
|
|
295
|
-
return response.json({"status": "ok"})
|
|
296
|
-
|
|
297
|
-
@blueprint.websocket("/websocket") # type: ignore[misc]
|
|
298
|
-
async def receive(request: Request, ws: Websocket) -> None:
|
|
299
|
-
logger.debug(
|
|
300
|
-
"genesys.receive",
|
|
301
|
-
audiohook_session_id=request.headers.get("audiohook-session-id"),
|
|
302
|
-
)
|
|
303
|
-
# validate required headers
|
|
304
|
-
required_headers = [
|
|
305
|
-
"audiohook-organization-id",
|
|
306
|
-
"audiohook-correlation-id",
|
|
307
|
-
"audiohook-session-id",
|
|
308
|
-
"x-api-key",
|
|
309
|
-
]
|
|
310
|
-
|
|
311
|
-
for header in required_headers:
|
|
312
|
-
if header not in request.headers:
|
|
313
|
-
await ws.close(1008, f"Missing required header: {header}")
|
|
314
|
-
return
|
|
315
|
-
|
|
316
|
-
# TODO: validate API key header
|
|
317
|
-
# process audio streaming
|
|
318
|
-
logger.info("genesys.receive", message="Starting audio streaming")
|
|
319
|
-
await self.run_audio_streaming(on_new_message, ws)
|
|
320
|
-
|
|
321
|
-
return blueprint
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
def _schedule_ws_task(coro: Awaitable[Any]) -> None:
|
|
325
|
-
"""Helper function to schedule a coroutine in the event loop.
|
|
326
|
-
|
|
327
|
-
Args:
|
|
328
|
-
coro: The coroutine to schedule
|
|
329
|
-
"""
|
|
330
|
-
loop = asyncio.get_running_loop()
|
|
331
|
-
loop.call_soon_threadsafe(lambda: loop.create_task(coro))
|
|
@@ -1,150 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Dict, List
|
|
5
|
-
|
|
6
|
-
import structlog
|
|
7
|
-
|
|
8
|
-
from rasa.dialogue_understanding.commands.command import Command
|
|
9
|
-
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
10
|
-
CannotHandlePatternFlowStackFrame,
|
|
11
|
-
)
|
|
12
|
-
from rasa.dialogue_understanding.patterns.handle_digressions import (
|
|
13
|
-
HandleDigressionsPatternFlowStackFrame,
|
|
14
|
-
)
|
|
15
|
-
from rasa.dialogue_understanding.stack.utils import (
|
|
16
|
-
top_flow_frame,
|
|
17
|
-
user_flows_on_the_stack,
|
|
18
|
-
)
|
|
19
|
-
from rasa.shared.core.events import Event
|
|
20
|
-
from rasa.shared.core.flows import FlowsList
|
|
21
|
-
from rasa.shared.core.flows.steps import CollectInformationFlowStep
|
|
22
|
-
from rasa.shared.core.flows.utils import ALL_LABEL
|
|
23
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
24
|
-
from rasa.shared.nlu.constants import HANDLE_DIGRESSIONS_COMMAND
|
|
25
|
-
|
|
26
|
-
structlogger = structlog.get_logger()
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@dataclass
|
|
30
|
-
class HandleDigressionsCommand(Command):
|
|
31
|
-
"""A command to handle digressions during an active flow."""
|
|
32
|
-
|
|
33
|
-
flow: str
|
|
34
|
-
"""The interrupting flow."""
|
|
35
|
-
|
|
36
|
-
@classmethod
|
|
37
|
-
def command(cls) -> str:
|
|
38
|
-
"""Returns the command type."""
|
|
39
|
-
return HANDLE_DIGRESSIONS_COMMAND
|
|
40
|
-
|
|
41
|
-
@classmethod
|
|
42
|
-
def from_dict(cls, data: Dict[str, Any]) -> HandleDigressionsCommand:
|
|
43
|
-
"""Converts the dictionary to a command.
|
|
44
|
-
|
|
45
|
-
Returns:
|
|
46
|
-
The converted dictionary.
|
|
47
|
-
"""
|
|
48
|
-
try:
|
|
49
|
-
return HandleDigressionsCommand(flow=data["flow"])
|
|
50
|
-
except KeyError as e:
|
|
51
|
-
raise ValueError(
|
|
52
|
-
f"Missing parameter '{e}' while parsing HandleDigressionsCommand."
|
|
53
|
-
) from e
|
|
54
|
-
|
|
55
|
-
def run_command_on_tracker(
|
|
56
|
-
self,
|
|
57
|
-
tracker: DialogueStateTracker,
|
|
58
|
-
all_flows: FlowsList,
|
|
59
|
-
original_tracker: DialogueStateTracker,
|
|
60
|
-
) -> List[Event]:
|
|
61
|
-
"""Runs the command on the tracker.
|
|
62
|
-
|
|
63
|
-
Args:
|
|
64
|
-
tracker: The tracker to run the command on.
|
|
65
|
-
all_flows: All flows in the assistant.
|
|
66
|
-
original_tracker: The tracker before any command was executed.
|
|
67
|
-
|
|
68
|
-
Returns:
|
|
69
|
-
The events to apply to the tracker.
|
|
70
|
-
"""
|
|
71
|
-
stack = tracker.stack
|
|
72
|
-
original_stack = original_tracker.stack
|
|
73
|
-
|
|
74
|
-
if self.flow in user_flows_on_the_stack(stack):
|
|
75
|
-
structlogger.debug(
|
|
76
|
-
"command_executor.skip_command.already_started_flow", command=self
|
|
77
|
-
)
|
|
78
|
-
return []
|
|
79
|
-
elif self.flow not in all_flows.flow_ids:
|
|
80
|
-
structlogger.debug(
|
|
81
|
-
"command_executor.push_cannot_handle.start_invalid_flow_id",
|
|
82
|
-
command=self,
|
|
83
|
-
)
|
|
84
|
-
stack.push(CannotHandlePatternFlowStackFrame())
|
|
85
|
-
return tracker.create_stack_updated_events(stack)
|
|
86
|
-
|
|
87
|
-
# this allows to include called user flows in the stack search
|
|
88
|
-
latest_user_frame = top_flow_frame(original_stack, ignore_call_frames=False)
|
|
89
|
-
|
|
90
|
-
if latest_user_frame is None:
|
|
91
|
-
structlogger.debug(
|
|
92
|
-
"command_executor.skip_command.no_top_flow", command=self
|
|
93
|
-
)
|
|
94
|
-
return []
|
|
95
|
-
|
|
96
|
-
original_top_flow = latest_user_frame.flow(all_flows)
|
|
97
|
-
current_step = original_top_flow.step_by_id(latest_user_frame.step_id)
|
|
98
|
-
if not isinstance(current_step, CollectInformationFlowStep):
|
|
99
|
-
structlogger.debug(
|
|
100
|
-
"command_executor.skip_command.not_at_a_collect_step", command=self
|
|
101
|
-
)
|
|
102
|
-
return []
|
|
103
|
-
|
|
104
|
-
ask_confirm_digressions = set(
|
|
105
|
-
current_step.ask_confirm_digressions
|
|
106
|
-
+ original_top_flow.ask_confirm_digressions
|
|
107
|
-
)
|
|
108
|
-
|
|
109
|
-
block_digressions = set(
|
|
110
|
-
current_step.block_digressions + original_top_flow.block_digressions
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
if block_digressions:
|
|
114
|
-
if ALL_LABEL in block_digressions:
|
|
115
|
-
block_digressions.remove(ALL_LABEL)
|
|
116
|
-
block_digressions.add(self.flow)
|
|
117
|
-
|
|
118
|
-
if ask_confirm_digressions:
|
|
119
|
-
if ALL_LABEL in ask_confirm_digressions:
|
|
120
|
-
ask_confirm_digressions.remove(ALL_LABEL)
|
|
121
|
-
ask_confirm_digressions.add(self.flow)
|
|
122
|
-
|
|
123
|
-
structlogger.debug(
|
|
124
|
-
"command_executor.push_handle_digressions",
|
|
125
|
-
interrupting_flow_id=self.flow,
|
|
126
|
-
interrupted_flow_id=original_top_flow.id,
|
|
127
|
-
interrupted_step_id=current_step.id,
|
|
128
|
-
ask_confirm_digressions=ask_confirm_digressions,
|
|
129
|
-
block_digressions=block_digressions,
|
|
130
|
-
)
|
|
131
|
-
stack.push(
|
|
132
|
-
HandleDigressionsPatternFlowStackFrame(
|
|
133
|
-
interrupting_flow_id=self.flow,
|
|
134
|
-
interrupted_flow_id=original_top_flow.id,
|
|
135
|
-
interrupted_step_id=current_step.id,
|
|
136
|
-
ask_confirm_digressions=ask_confirm_digressions,
|
|
137
|
-
block_digressions=block_digressions,
|
|
138
|
-
)
|
|
139
|
-
)
|
|
140
|
-
|
|
141
|
-
return tracker.create_stack_updated_events(stack)
|
|
142
|
-
|
|
143
|
-
def __hash__(self) -> int:
|
|
144
|
-
return hash(self.flow)
|
|
145
|
-
|
|
146
|
-
def __eq__(self, other: object) -> bool:
|
|
147
|
-
if not isinstance(other, HandleDigressionsCommand):
|
|
148
|
-
return False
|
|
149
|
-
|
|
150
|
-
return other.flow == self.flow
|
|
@@ -1,81 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Any, Dict, Set
|
|
5
|
-
|
|
6
|
-
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
|
|
7
|
-
from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
|
|
8
|
-
from rasa.shared.core.constants import (
|
|
9
|
-
KEY_ASK_CONFIRM_DIGRESSIONS,
|
|
10
|
-
KEY_BLOCK_DIGRESSIONS,
|
|
11
|
-
)
|
|
12
|
-
|
|
13
|
-
FLOW_PATTERN_HANDLE_DIGRESSIONS = (
|
|
14
|
-
RASA_DEFAULT_FLOW_PATTERN_PREFIX + "handle_digressions"
|
|
15
|
-
)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
@dataclass
|
|
19
|
-
class HandleDigressionsPatternFlowStackFrame(PatternFlowStackFrame):
|
|
20
|
-
"""A pattern flow stack frame that gets added if an interruption is completed."""
|
|
21
|
-
|
|
22
|
-
flow_id: str = FLOW_PATTERN_HANDLE_DIGRESSIONS
|
|
23
|
-
"""The ID of the flow."""
|
|
24
|
-
interrupting_flow_id: str = ""
|
|
25
|
-
"""The ID of the flow that interrupted the active flow."""
|
|
26
|
-
interrupted_flow_id: str = ""
|
|
27
|
-
"""The name of the active flow that was interrupted."""
|
|
28
|
-
interrupted_step_id: str = ""
|
|
29
|
-
"""The ID of the step that was interrupted."""
|
|
30
|
-
ask_confirm_digressions: Set[str] = field(default_factory=set)
|
|
31
|
-
"""The set of interrupting flow names to confirm."""
|
|
32
|
-
block_digressions: Set[str] = field(default_factory=set)
|
|
33
|
-
"""The set of interrupting flow names to block."""
|
|
34
|
-
|
|
35
|
-
@classmethod
|
|
36
|
-
def type(cls) -> str:
|
|
37
|
-
"""Returns the type of the frame."""
|
|
38
|
-
return FLOW_PATTERN_HANDLE_DIGRESSIONS
|
|
39
|
-
|
|
40
|
-
@staticmethod
|
|
41
|
-
def from_dict(data: Dict[str, Any]) -> HandleDigressionsPatternFlowStackFrame:
|
|
42
|
-
"""Creates a `DialogueStackFrame` from a dictionary.
|
|
43
|
-
|
|
44
|
-
Args:
|
|
45
|
-
data: The dictionary to create the `DialogueStackFrame` from.
|
|
46
|
-
|
|
47
|
-
Returns:
|
|
48
|
-
The created `DialogueStackFrame`.
|
|
49
|
-
"""
|
|
50
|
-
return HandleDigressionsPatternFlowStackFrame(
|
|
51
|
-
frame_id=data["frame_id"],
|
|
52
|
-
step_id=data["step_id"],
|
|
53
|
-
interrupted_step_id=data["interrupted_step_id"],
|
|
54
|
-
interrupted_flow_id=data["interrupted_flow_id"],
|
|
55
|
-
interrupting_flow_id=data["interrupting_flow_id"],
|
|
56
|
-
ask_confirm_digressions=set(data.get(KEY_ASK_CONFIRM_DIGRESSIONS, [])),
|
|
57
|
-
# This attribute must be converted to a set to enable usage
|
|
58
|
-
# of subset `contains` pypred operator in the default pattern
|
|
59
|
-
# conditional branching
|
|
60
|
-
block_digressions=set(data.get(KEY_BLOCK_DIGRESSIONS, [])),
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
def __eq__(self, other: Any) -> bool:
|
|
64
|
-
if not isinstance(other, HandleDigressionsPatternFlowStackFrame):
|
|
65
|
-
return False
|
|
66
|
-
return (
|
|
67
|
-
self.flow_id == other.flow_id
|
|
68
|
-
and self.interrupted_step_id == other.interrupted_step_id
|
|
69
|
-
and self.interrupted_flow_id == other.interrupted_flow_id
|
|
70
|
-
and self.interrupting_flow_id == other.interrupting_flow_id
|
|
71
|
-
and self.ask_confirm_digressions == other.ask_confirm_digressions
|
|
72
|
-
and self.block_digressions == other.block_digressions
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
def as_dict(self) -> Dict[str, Any]:
|
|
76
|
-
"""Returns the frame as a dictionary."""
|
|
77
|
-
data = super().as_dict()
|
|
78
|
-
# converting back to list to avoid serialization issues
|
|
79
|
-
data[KEY_ASK_CONFIRM_DIGRESSIONS] = list(self.ask_confirm_digressions)
|
|
80
|
-
data[KEY_BLOCK_DIGRESSIONS] = list(self.block_digressions)
|
|
81
|
-
return data
|
|
File without changes
|
|
File without changes
|
|
File without changes
|