rasa-pro 3.12.0.dev8__py3-none-any.whl → 3.12.0.dev10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/core/actions/action.py +17 -3
- rasa/core/actions/action_handle_digressions.py +142 -0
- rasa/core/actions/forms.py +4 -2
- rasa/core/channels/voice_ready/audiocodes.py +42 -23
- rasa/core/channels/voice_stream/tts/azure.py +2 -1
- rasa/core/migrate.py +2 -2
- rasa/core/policies/flows/flow_executor.py +33 -1
- rasa/dialogue_understanding/commands/can_not_handle_command.py +2 -2
- rasa/dialogue_understanding/commands/cancel_flow_command.py +62 -4
- rasa/dialogue_understanding/commands/change_flow_command.py +2 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +2 -2
- rasa/dialogue_understanding/commands/clarify_command.py +2 -2
- rasa/dialogue_understanding/commands/correct_slots_command.py +11 -2
- rasa/dialogue_understanding/commands/handle_digressions_command.py +150 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +2 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +2 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +7 -15
- rasa/dialogue_understanding/commands/skip_question_command.py +2 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +43 -2
- rasa/dialogue_understanding/commands/utils.py +1 -1
- rasa/dialogue_understanding/constants.py +1 -0
- rasa/dialogue_understanding/generator/command_generator.py +10 -76
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +126 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -2
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +4 -2
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +40 -40
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +11 -19
- rasa/dialogue_understanding/patterns/correction.py +13 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +62 -2
- rasa/dialogue_understanding/patterns/handle_digressions.py +81 -0
- rasa/dialogue_understanding/processor/command_processor.py +117 -28
- rasa/dialogue_understanding/utils.py +31 -0
- rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +2 -2
- rasa/shared/core/constants.py +22 -1
- rasa/shared/core/domain.py +6 -4
- rasa/shared/core/events.py +13 -2
- rasa/shared/core/flows/flow.py +17 -0
- rasa/shared/core/flows/flows_yaml_schema.json +38 -0
- rasa/shared/core/flows/steps/collect.py +18 -1
- rasa/shared/core/flows/utils.py +16 -1
- rasa/shared/core/slot_mappings.py +6 -6
- rasa/shared/core/slots.py +19 -0
- rasa/shared/core/trackers.py +3 -1
- rasa/shared/nlu/constants.py +1 -0
- rasa/shared/providers/llm/_base_litellm_client.py +0 -40
- rasa/shared/utils/llm.py +1 -80
- rasa/shared/utils/schemas/domain.yml +0 -1
- rasa/validator.py +172 -22
- rasa/version.py +1 -1
- {rasa_pro-3.12.0.dev8.dist-info → rasa_pro-3.12.0.dev10.dist-info}/METADATA +1 -1
- {rasa_pro-3.12.0.dev8.dist-info → rasa_pro-3.12.0.dev10.dist-info}/RECORD +56 -53
- {rasa_pro-3.12.0.dev8.dist-info → rasa_pro-3.12.0.dev10.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0.dev8.dist-info → rasa_pro-3.12.0.dev10.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0.dev8.dist-info → rasa_pro-3.12.0.dev10.dist-info}/entry_points.txt +0 -0
rasa/core/actions/action.py
CHANGED
|
@@ -72,10 +72,11 @@ from rasa.shared.core.constants import (
|
|
|
72
72
|
ACTION_UNLIKELY_INTENT_NAME,
|
|
73
73
|
ACTION_VALIDATE_SLOT_MAPPINGS,
|
|
74
74
|
DEFAULT_SLOT_NAMES,
|
|
75
|
+
KEY_MAPPING_TYPE,
|
|
75
76
|
KNOWLEDGE_BASE_SLOT_NAMES,
|
|
76
|
-
MAPPING_TYPE,
|
|
77
77
|
REQUESTED_SLOT,
|
|
78
78
|
USER_INTENT_OUT_OF_SCOPE,
|
|
79
|
+
SetSlotExtractor,
|
|
79
80
|
SlotMappingType,
|
|
80
81
|
)
|
|
81
82
|
from rasa.shared.core.domain import Domain
|
|
@@ -118,6 +119,10 @@ logger = logging.getLogger(__name__)
|
|
|
118
119
|
def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]:
|
|
119
120
|
"""List default actions."""
|
|
120
121
|
from rasa.core.actions.action_clean_stack import ActionCleanStack
|
|
122
|
+
from rasa.core.actions.action_handle_digressions import (
|
|
123
|
+
ActionBlockDigressions,
|
|
124
|
+
ActionContinueDigression,
|
|
125
|
+
)
|
|
121
126
|
from rasa.core.actions.action_hangup import ActionHangup
|
|
122
127
|
from rasa.core.actions.action_repeat_bot_messages import ActionRepeatBotMessages
|
|
123
128
|
from rasa.core.actions.action_run_slot_rejections import ActionRunSlotRejections
|
|
@@ -152,6 +157,8 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
|
|
|
152
157
|
ActionResetRouting(),
|
|
153
158
|
ActionHangup(),
|
|
154
159
|
ActionRepeatBotMessages(),
|
|
160
|
+
ActionBlockDigressions(),
|
|
161
|
+
ActionContinueDigression(),
|
|
155
162
|
]
|
|
156
163
|
|
|
157
164
|
|
|
@@ -940,7 +947,14 @@ class RemoteAction(Action):
|
|
|
940
947
|
)
|
|
941
948
|
|
|
942
949
|
events = rasa.shared.core.events.deserialise_events(events_json)
|
|
943
|
-
|
|
950
|
+
|
|
951
|
+
processed_events = []
|
|
952
|
+
for event in events:
|
|
953
|
+
if isinstance(event, SlotSet) and event.filled_by is None:
|
|
954
|
+
event.filled_by = SetSlotExtractor.CUSTOM.value
|
|
955
|
+
processed_events.append(event)
|
|
956
|
+
|
|
957
|
+
return cast(List[Event], bot_messages) + processed_events
|
|
944
958
|
|
|
945
959
|
def name(self) -> Text:
|
|
946
960
|
return self._name
|
|
@@ -1317,7 +1331,7 @@ class ActionExtractSlots(Action):
|
|
|
1317
1331
|
slot_events.append(SlotSet(slot.name, slot_value))
|
|
1318
1332
|
|
|
1319
1333
|
for mapping in slot.mappings:
|
|
1320
|
-
mapping_type = SlotMappingType(mapping.get(
|
|
1334
|
+
mapping_type = SlotMappingType(mapping.get(KEY_MAPPING_TYPE))
|
|
1321
1335
|
should_fill_custom_slot = mapping_type == SlotMappingType.CUSTOM
|
|
1322
1336
|
|
|
1323
1337
|
if should_fill_custom_slot:
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
|
|
7
|
+
from rasa.core.actions.action import Action, create_bot_utterance
|
|
8
|
+
from rasa.core.channels import OutputChannel
|
|
9
|
+
from rasa.core.nlg import NaturalLanguageGenerator
|
|
10
|
+
from rasa.core.utils import add_bot_utterance_metadata
|
|
11
|
+
from rasa.dialogue_understanding.patterns.continue_interrupted import (
|
|
12
|
+
ContinueInterruptedPatternFlowStackFrame,
|
|
13
|
+
)
|
|
14
|
+
from rasa.dialogue_understanding.patterns.handle_digressions import (
|
|
15
|
+
HandleDigressionsPatternFlowStackFrame,
|
|
16
|
+
)
|
|
17
|
+
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
|
|
18
|
+
FlowStackFrameType,
|
|
19
|
+
UserFlowStackFrame,
|
|
20
|
+
)
|
|
21
|
+
from rasa.shared.core.constants import (
|
|
22
|
+
ACTION_BLOCK_DIGRESSION,
|
|
23
|
+
ACTION_CONTINUE_DIGRESSION,
|
|
24
|
+
)
|
|
25
|
+
from rasa.shared.core.domain import Domain
|
|
26
|
+
from rasa.shared.core.events import Event, FlowInterrupted
|
|
27
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
28
|
+
|
|
29
|
+
structlogger = structlog.get_logger()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ActionBlockDigressions(Action):
|
|
33
|
+
"""Action which blocks an interruption and continues the current flow."""
|
|
34
|
+
|
|
35
|
+
def name(self) -> str:
|
|
36
|
+
"""Return the action name."""
|
|
37
|
+
return ACTION_BLOCK_DIGRESSION
|
|
38
|
+
|
|
39
|
+
async def run(
|
|
40
|
+
self,
|
|
41
|
+
output_channel: OutputChannel,
|
|
42
|
+
nlg: NaturalLanguageGenerator,
|
|
43
|
+
tracker: DialogueStateTracker,
|
|
44
|
+
domain: Domain,
|
|
45
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
46
|
+
) -> List[Event]:
|
|
47
|
+
"""Update the stack."""
|
|
48
|
+
structlogger.debug("action_block_digressions.run")
|
|
49
|
+
top_frame = tracker.stack.top()
|
|
50
|
+
|
|
51
|
+
if not isinstance(top_frame, HandleDigressionsPatternFlowStackFrame):
|
|
52
|
+
return []
|
|
53
|
+
|
|
54
|
+
blocked_flow_id = top_frame.interrupting_flow_id
|
|
55
|
+
frame_type = FlowStackFrameType.REGULAR
|
|
56
|
+
|
|
57
|
+
stack = tracker.stack
|
|
58
|
+
stack.push(
|
|
59
|
+
UserFlowStackFrame(flow_id=blocked_flow_id, frame_type=frame_type), 0
|
|
60
|
+
)
|
|
61
|
+
stack.push(
|
|
62
|
+
ContinueInterruptedPatternFlowStackFrame(
|
|
63
|
+
previous_flow_name=blocked_flow_id
|
|
64
|
+
),
|
|
65
|
+
1,
|
|
66
|
+
)
|
|
67
|
+
events = tracker.create_stack_updated_events(stack)
|
|
68
|
+
|
|
69
|
+
utterance = "utter_block_digressions"
|
|
70
|
+
message = await nlg.generate(
|
|
71
|
+
utterance,
|
|
72
|
+
tracker,
|
|
73
|
+
output_channel.name(),
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
if message is None:
|
|
77
|
+
structlogger.error(
|
|
78
|
+
"action_block_digressions.run.failed.finding.utter",
|
|
79
|
+
utterance=utterance,
|
|
80
|
+
)
|
|
81
|
+
else:
|
|
82
|
+
message = add_bot_utterance_metadata(
|
|
83
|
+
message, utterance, nlg, domain, tracker
|
|
84
|
+
)
|
|
85
|
+
events.append(create_bot_utterance(message))
|
|
86
|
+
|
|
87
|
+
return events
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ActionContinueDigression(Action):
|
|
91
|
+
"""Action which continues with an interruption."""
|
|
92
|
+
|
|
93
|
+
def name(self) -> str:
|
|
94
|
+
"""Return the action name."""
|
|
95
|
+
return ACTION_CONTINUE_DIGRESSION
|
|
96
|
+
|
|
97
|
+
async def run(
|
|
98
|
+
self,
|
|
99
|
+
output_channel: OutputChannel,
|
|
100
|
+
nlg: NaturalLanguageGenerator,
|
|
101
|
+
tracker: DialogueStateTracker,
|
|
102
|
+
domain: Domain,
|
|
103
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
104
|
+
) -> List[Event]:
|
|
105
|
+
"""Update the stack."""
|
|
106
|
+
structlogger.debug("action_continue_digression.run")
|
|
107
|
+
top_frame = tracker.stack.top()
|
|
108
|
+
|
|
109
|
+
if not isinstance(top_frame, HandleDigressionsPatternFlowStackFrame):
|
|
110
|
+
return []
|
|
111
|
+
|
|
112
|
+
blocked_flow_id = top_frame.interrupting_flow_id
|
|
113
|
+
frame_type = FlowStackFrameType.INTERRUPT
|
|
114
|
+
stack = tracker.stack
|
|
115
|
+
stack.push(UserFlowStackFrame(flow_id=blocked_flow_id, frame_type=frame_type))
|
|
116
|
+
|
|
117
|
+
events = [
|
|
118
|
+
FlowInterrupted(
|
|
119
|
+
flow_id=top_frame.interrupted_flow_id,
|
|
120
|
+
step_id=top_frame.interrupted_step_id,
|
|
121
|
+
)
|
|
122
|
+
] + tracker.create_stack_updated_events(stack)
|
|
123
|
+
|
|
124
|
+
utterance = "utter_continue_interruption"
|
|
125
|
+
message = await nlg.generate(
|
|
126
|
+
utterance,
|
|
127
|
+
tracker,
|
|
128
|
+
output_channel.name(),
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
if message is None:
|
|
132
|
+
structlogger.error(
|
|
133
|
+
"action_continue_digression.run.failed.finding.utter",
|
|
134
|
+
utterance=utterance,
|
|
135
|
+
)
|
|
136
|
+
else:
|
|
137
|
+
message = add_bot_utterance_metadata(
|
|
138
|
+
message, utterance, nlg, domain, tracker
|
|
139
|
+
)
|
|
140
|
+
events.append(create_bot_utterance(message))
|
|
141
|
+
|
|
142
|
+
return events
|
rasa/core/actions/forms.py
CHANGED
|
@@ -16,7 +16,7 @@ from rasa.shared.constants import UTTER_PREFIX
|
|
|
16
16
|
from rasa.shared.core.constants import (
|
|
17
17
|
ACTION_EXTRACT_SLOTS,
|
|
18
18
|
ACTION_LISTEN_NAME,
|
|
19
|
-
|
|
19
|
+
KEY_MAPPING_TYPE,
|
|
20
20
|
REQUESTED_SLOT,
|
|
21
21
|
SLOT_MAPPINGS,
|
|
22
22
|
SlotMappingType,
|
|
@@ -158,7 +158,9 @@ class FormAction(LoopAction):
|
|
|
158
158
|
domain_slots = domain.as_dict().get(KEY_SLOTS, {})
|
|
159
159
|
for slot in domain.required_slots_for_form(self.name()):
|
|
160
160
|
for slot_mapping in domain_slots.get(slot, {}).get(SLOT_MAPPINGS, []):
|
|
161
|
-
if slot_mapping.get(
|
|
161
|
+
if slot_mapping.get(KEY_MAPPING_TYPE) == str(
|
|
162
|
+
SlotMappingType.FROM_ENTITY
|
|
163
|
+
):
|
|
162
164
|
mapping_as_string = json.dumps(slot_mapping, sort_keys=True)
|
|
163
165
|
if mapping_as_string in unique_entity_slot_mappings:
|
|
164
166
|
unique_entity_slot_mappings.remove(mapping_as_string)
|
|
@@ -1,9 +1,11 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import copy
|
|
2
3
|
import json
|
|
3
4
|
import uuid
|
|
5
|
+
from collections import defaultdict
|
|
4
6
|
from dataclasses import asdict
|
|
5
7
|
from datetime import datetime, timedelta, timezone
|
|
6
|
-
from typing import Any, Awaitable, Callable, Dict, List, Optional, Text, Union
|
|
8
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Text, Union
|
|
7
9
|
|
|
8
10
|
import structlog
|
|
9
11
|
from jsonschema import ValidationError, validate
|
|
@@ -223,6 +225,16 @@ class AudiocodesInput(InputChannel):
|
|
|
223
225
|
self.scheduler_job = None
|
|
224
226
|
self.keep_alive = keep_alive
|
|
225
227
|
self.keep_alive_expiration_factor = keep_alive_expiration_factor
|
|
228
|
+
self.background_tasks: Dict[Text, Set[asyncio.Task]] = defaultdict(set)
|
|
229
|
+
|
|
230
|
+
def _create_task(self, conversation_id: Text, coro: Awaitable[Any]) -> asyncio.Task:
|
|
231
|
+
"""Create and track an asyncio task for a conversation."""
|
|
232
|
+
task: asyncio.Task = asyncio.create_task(coro)
|
|
233
|
+
self.background_tasks[conversation_id].add(task)
|
|
234
|
+
task.add_done_callback(
|
|
235
|
+
lambda t: self.background_tasks[conversation_id].discard(t)
|
|
236
|
+
)
|
|
237
|
+
return task
|
|
226
238
|
|
|
227
239
|
async def _set_scheduler_job(self) -> None:
|
|
228
240
|
if self.scheduler_job:
|
|
@@ -251,11 +263,20 @@ class AudiocodesInput(InputChannel):
|
|
|
251
263
|
)
|
|
252
264
|
now = datetime.now(timezone.utc)
|
|
253
265
|
delta = timedelta(seconds=self.keep_alive * self.keep_alive_expiration_factor)
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
266
|
+
|
|
267
|
+
# clean up conversations
|
|
268
|
+
inactive = [
|
|
269
|
+
conv_id
|
|
270
|
+
for conv_id, conv in self.conversations.items()
|
|
271
|
+
if not conv.is_active_conversation(now, delta)
|
|
272
|
+
]
|
|
273
|
+
|
|
274
|
+
# cancel tasks and remove conversations
|
|
275
|
+
for conv_id in inactive:
|
|
276
|
+
for task in self.background_tasks[conv_id]:
|
|
277
|
+
task.cancel()
|
|
278
|
+
self.background_tasks.pop(conv_id, None)
|
|
279
|
+
self.conversations.pop(conv_id, None)
|
|
259
280
|
|
|
260
281
|
def handle_start_conversation(self, body: Dict[Text, Any]) -> Dict[Text, Any]:
|
|
261
282
|
conversation_id = body["conversation"]
|
|
@@ -347,31 +368,29 @@ class AudiocodesInput(InputChannel):
|
|
|
347
368
|
structlogger.debug("audiocodes.on_activities", conversation=conversation_id)
|
|
348
369
|
conversation = self._get_conversation(request.token, conversation_id)
|
|
349
370
|
if conversation is None:
|
|
371
|
+
structlogger.warning(
|
|
372
|
+
"audiocodes.on_activities.no_conversation", request=request.json
|
|
373
|
+
)
|
|
350
374
|
return response.json({})
|
|
351
375
|
elif conversation.ws:
|
|
352
376
|
ac_output: Union[WebsocketOutput, AudiocodesOutput] = WebsocketOutput(
|
|
353
377
|
conversation.ws, conversation_id
|
|
354
378
|
)
|
|
355
|
-
|
|
356
|
-
request.json,
|
|
357
|
-
output_channel=ac_output,
|
|
358
|
-
on_new_message=on_new_message,
|
|
359
|
-
)
|
|
360
|
-
return response.json({})
|
|
379
|
+
response_json = {}
|
|
361
380
|
else:
|
|
362
381
|
# handle non websocket case where messages get returned in json
|
|
363
382
|
ac_output = AudiocodesOutput()
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
383
|
+
response_json = {
|
|
384
|
+
"conversation": conversation_id,
|
|
385
|
+
"activities": ac_output.messages,
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
# start a background task to handle activities
|
|
389
|
+
self._create_task(
|
|
390
|
+
conversation_id,
|
|
391
|
+
conversation.handle_activities(request.json, ac_output, on_new_message),
|
|
392
|
+
)
|
|
393
|
+
return response.json(response_json)
|
|
375
394
|
|
|
376
395
|
@ac_webhook.route(
|
|
377
396
|
"/conversation/<conversation_id>/disconnect", methods=["POST"]
|
|
@@ -81,7 +81,8 @@ class AzureTTS(TTSEngine[AzureTTSConfig]):
|
|
|
81
81
|
@staticmethod
|
|
82
82
|
def create_request_body(text: str, conf: AzureTTSConfig) -> str:
|
|
83
83
|
return f"""
|
|
84
|
-
<speak version='1.0' xml:lang='{conf.language}'
|
|
84
|
+
<speak version='1.0' xml:lang='{conf.language}' xmlns:mstts='http://www.w3.org/2001/mstts'
|
|
85
|
+
xmlns='http://www.w3.org/2001/10/synthesis'>
|
|
85
86
|
<voice xml:lang='{conf.language}' name='{conf.voice}'>
|
|
86
87
|
{text}
|
|
87
88
|
</voice>
|
rasa/core/migrate.py
CHANGED
|
@@ -14,7 +14,7 @@ from rasa.shared.constants import (
|
|
|
14
14
|
)
|
|
15
15
|
from rasa.shared.core.constants import (
|
|
16
16
|
ACTIVE_LOOP,
|
|
17
|
-
|
|
17
|
+
KEY_MAPPING_TYPE,
|
|
18
18
|
REQUESTED_SLOT,
|
|
19
19
|
SLOT_MAPPINGS,
|
|
20
20
|
SlotMappingType,
|
|
@@ -43,7 +43,7 @@ def _create_back_up(domain_file: Path, backup_location: Path) -> Dict[Text, Any]
|
|
|
43
43
|
def _get_updated_mapping_condition(
|
|
44
44
|
condition: Dict[Text, Text], mapping: Dict[Text, Any], slot_name: Text
|
|
45
45
|
) -> Dict[Text, Text]:
|
|
46
|
-
if mapping.get(
|
|
46
|
+
if mapping.get(KEY_MAPPING_TYPE) not in [
|
|
47
47
|
str(SlotMappingType.FROM_ENTITY),
|
|
48
48
|
str(SlotMappingType.FROM_TRIGGER_INTENT),
|
|
49
49
|
]:
|
|
@@ -23,6 +23,7 @@ from rasa.core.policies.flows.flow_step_result import (
|
|
|
23
23
|
)
|
|
24
24
|
from rasa.dialogue_understanding.commands import CancelFlowCommand
|
|
25
25
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
26
|
+
from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
|
|
26
27
|
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
27
28
|
CollectInformationPatternFlowStackFrame,
|
|
28
29
|
)
|
|
@@ -50,6 +51,7 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
|
|
|
50
51
|
)
|
|
51
52
|
from rasa.dialogue_understanding.stack.utils import (
|
|
52
53
|
top_user_flow_frame,
|
|
54
|
+
user_flows_on_the_stack,
|
|
53
55
|
)
|
|
54
56
|
from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
|
|
55
57
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
|
|
@@ -272,6 +274,28 @@ def trigger_pattern_continue_interrupted(
|
|
|
272
274
|
return events
|
|
273
275
|
|
|
274
276
|
|
|
277
|
+
def trigger_pattern_clarification(
|
|
278
|
+
current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
|
|
279
|
+
) -> None:
|
|
280
|
+
"""Trigger the pattern to clarify which topic to continue if needed."""
|
|
281
|
+
if not isinstance(current_frame, UserFlowStackFrame):
|
|
282
|
+
return None
|
|
283
|
+
|
|
284
|
+
if current_frame.frame_type == FlowStackFrameType.CALL:
|
|
285
|
+
# we want to return to the flow that called the current flow
|
|
286
|
+
return None
|
|
287
|
+
|
|
288
|
+
pending_flows = [
|
|
289
|
+
flows.flow_by_id(frame.flow_id)
|
|
290
|
+
for frame in stack.frames
|
|
291
|
+
if isinstance(frame, UserFlowStackFrame)
|
|
292
|
+
and frame.flow_id != current_frame.flow_id
|
|
293
|
+
]
|
|
294
|
+
|
|
295
|
+
flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
|
|
296
|
+
stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
|
|
297
|
+
|
|
298
|
+
|
|
275
299
|
def trigger_pattern_completed(
|
|
276
300
|
current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
|
|
277
301
|
) -> None:
|
|
@@ -669,7 +693,15 @@ def _run_end_step(
|
|
|
669
693
|
structlogger.debug("flow.step.run.flow_end")
|
|
670
694
|
current_frame = stack.pop()
|
|
671
695
|
trigger_pattern_completed(current_frame, stack, flows)
|
|
672
|
-
resumed_events =
|
|
696
|
+
resumed_events = []
|
|
697
|
+
if len(user_flows_on_the_stack(stack)) > 1:
|
|
698
|
+
# if there are more user flows on the stack,
|
|
699
|
+
# we need to trigger the pattern clarify
|
|
700
|
+
trigger_pattern_clarification(current_frame, stack, flows)
|
|
701
|
+
else:
|
|
702
|
+
resumed_events = trigger_pattern_continue_interrupted(
|
|
703
|
+
current_frame, stack, flows
|
|
704
|
+
)
|
|
673
705
|
reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
|
|
674
706
|
return ContinueFlowWithNextStep(
|
|
675
707
|
events=initial_events + reset_events + resumed_events, has_flow_ended=True
|
|
@@ -74,7 +74,7 @@ class CannotHandleCommand(Command):
|
|
|
74
74
|
|
|
75
75
|
def to_dsl(self) -> str:
|
|
76
76
|
"""Converts the command to a DSL string."""
|
|
77
|
-
return "
|
|
77
|
+
return "CannotHandle()"
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
80
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> CannotHandleCommand:
|
|
@@ -86,4 +86,4 @@ class CannotHandleCommand(Command):
|
|
|
86
86
|
|
|
87
87
|
@staticmethod
|
|
88
88
|
def regex_pattern() -> str:
|
|
89
|
-
return r"
|
|
89
|
+
return r"CannotHandle\(\)"
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import copy
|
|
3
4
|
import re
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from typing import Any, Dict, List
|
|
@@ -8,8 +9,11 @@ import structlog
|
|
|
8
9
|
|
|
9
10
|
from rasa.dialogue_understanding.commands.command import Command
|
|
10
11
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
12
|
+
from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
|
|
11
13
|
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
|
|
12
|
-
from rasa.dialogue_understanding.stack.frames import
|
|
14
|
+
from rasa.dialogue_understanding.stack.frames import (
|
|
15
|
+
UserFlowStackFrame,
|
|
16
|
+
)
|
|
13
17
|
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import FlowStackFrameType
|
|
14
18
|
from rasa.dialogue_understanding.stack.utils import top_user_flow_frame
|
|
15
19
|
from rasa.shared.core.events import Event, FlowCancelled
|
|
@@ -89,7 +93,8 @@ class CancelFlowCommand(Command):
|
|
|
89
93
|
original_stack = original_tracker.stack
|
|
90
94
|
|
|
91
95
|
applied_events: List[Event] = []
|
|
92
|
-
|
|
96
|
+
# capture the top frame before we push new frames onto the stack
|
|
97
|
+
initial_top_frame = stack.top()
|
|
93
98
|
user_frame = top_user_flow_frame(original_stack)
|
|
94
99
|
current_flow = user_frame.flow(all_flows) if user_frame else None
|
|
95
100
|
|
|
@@ -114,6 +119,21 @@ class CancelFlowCommand(Command):
|
|
|
114
119
|
if user_frame:
|
|
115
120
|
applied_events.append(FlowCancelled(user_frame.flow_id, user_frame.step_id))
|
|
116
121
|
|
|
122
|
+
if initial_top_frame and isinstance(
|
|
123
|
+
initial_top_frame, ClarifyPatternFlowStackFrame
|
|
124
|
+
):
|
|
125
|
+
structlogger.debug(
|
|
126
|
+
"command_executor.cancel_flow.cancel_clarification_options",
|
|
127
|
+
clarification_options=initial_top_frame.clarification_options,
|
|
128
|
+
)
|
|
129
|
+
applied_events += cancel_all_pending_clarification_options(
|
|
130
|
+
initial_top_frame,
|
|
131
|
+
original_stack,
|
|
132
|
+
canceled_frames,
|
|
133
|
+
all_flows,
|
|
134
|
+
stack,
|
|
135
|
+
)
|
|
136
|
+
|
|
117
137
|
return applied_events + tracker.create_stack_updated_events(stack)
|
|
118
138
|
|
|
119
139
|
def __hash__(self) -> int:
|
|
@@ -124,7 +144,7 @@ class CancelFlowCommand(Command):
|
|
|
124
144
|
|
|
125
145
|
def to_dsl(self) -> str:
|
|
126
146
|
"""Converts the command to a DSL string."""
|
|
127
|
-
return "
|
|
147
|
+
return "CancelFlow()"
|
|
128
148
|
|
|
129
149
|
@classmethod
|
|
130
150
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> CancelFlowCommand:
|
|
@@ -133,4 +153,42 @@ class CancelFlowCommand(Command):
|
|
|
133
153
|
|
|
134
154
|
@staticmethod
|
|
135
155
|
def regex_pattern() -> str:
|
|
136
|
-
return r"
|
|
156
|
+
return r"CancelFlow\(\)"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def cancel_all_pending_clarification_options(
|
|
160
|
+
initial_top_frame: ClarifyPatternFlowStackFrame,
|
|
161
|
+
original_stack: DialogueStack,
|
|
162
|
+
canceled_frames: List[str],
|
|
163
|
+
all_flows: FlowsList,
|
|
164
|
+
stack: DialogueStack,
|
|
165
|
+
) -> List[FlowCancelled]:
|
|
166
|
+
"""Cancel all pending clarification options.
|
|
167
|
+
|
|
168
|
+
This is a special case when the assistant asks the user to clarify
|
|
169
|
+
which pending digression flow to start after the completion of an active flow.
|
|
170
|
+
If the user chooses to cancel all options, this function takes care of
|
|
171
|
+
updating the stack by removing all pending flow stack frames
|
|
172
|
+
listed as clarification options.
|
|
173
|
+
"""
|
|
174
|
+
clarification_names = set(initial_top_frame.names)
|
|
175
|
+
to_be_canceled_frames = []
|
|
176
|
+
applied_events = []
|
|
177
|
+
for frame in reversed(original_stack.frames):
|
|
178
|
+
if frame.frame_id in canceled_frames:
|
|
179
|
+
continue
|
|
180
|
+
|
|
181
|
+
to_be_canceled_frames.append(frame.frame_id)
|
|
182
|
+
if isinstance(frame, UserFlowStackFrame):
|
|
183
|
+
readable_flow_name = frame.flow(all_flows).readable_name()
|
|
184
|
+
if readable_flow_name in clarification_names:
|
|
185
|
+
stack.push(
|
|
186
|
+
CancelPatternFlowStackFrame(
|
|
187
|
+
canceled_name=readable_flow_name,
|
|
188
|
+
canceled_frames=copy.deepcopy(to_be_canceled_frames),
|
|
189
|
+
)
|
|
190
|
+
)
|
|
191
|
+
applied_events.append(FlowCancelled(frame.flow_id, frame.step_id))
|
|
192
|
+
to_be_canceled_frames.clear()
|
|
193
|
+
|
|
194
|
+
return applied_events
|
|
@@ -48,7 +48,7 @@ class ChangeFlowCommand(Command):
|
|
|
48
48
|
|
|
49
49
|
def to_dsl(self) -> str:
|
|
50
50
|
"""Converts the command to a DSL string."""
|
|
51
|
-
return "
|
|
51
|
+
return "ChangeFlow()"
|
|
52
52
|
|
|
53
53
|
@staticmethod
|
|
54
54
|
def from_dsl(match: re.Match, **kwargs: Any) -> ChangeFlowCommand:
|
|
@@ -57,4 +57,4 @@ class ChangeFlowCommand(Command):
|
|
|
57
57
|
|
|
58
58
|
@staticmethod
|
|
59
59
|
def regex_pattern() -> str:
|
|
60
|
-
return r"
|
|
60
|
+
return r"ChangeFlow\(\)"
|
|
@@ -59,7 +59,7 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
|
|
|
59
59
|
|
|
60
60
|
def to_dsl(self) -> str:
|
|
61
61
|
"""Converts the command to a DSL string."""
|
|
62
|
-
return "
|
|
62
|
+
return "ChitChat()"
|
|
63
63
|
|
|
64
64
|
@classmethod
|
|
65
65
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> ChitChatAnswerCommand:
|
|
@@ -68,4 +68,4 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
|
|
|
68
68
|
|
|
69
69
|
@staticmethod
|
|
70
70
|
def regex_pattern() -> str:
|
|
71
|
-
return r"
|
|
71
|
+
return r"ChitChat\(\)"
|
|
@@ -89,7 +89,7 @@ class ClarifyCommand(Command):
|
|
|
89
89
|
|
|
90
90
|
def to_dsl(self) -> str:
|
|
91
91
|
"""Converts the command to a DSL string."""
|
|
92
|
-
return f"
|
|
92
|
+
return f"Clarify({', '.join(self.options)})"
|
|
93
93
|
|
|
94
94
|
@classmethod
|
|
95
95
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> Optional[ClarifyCommand]:
|
|
@@ -99,4 +99,4 @@ class ClarifyCommand(Command):
|
|
|
99
99
|
|
|
100
100
|
@staticmethod
|
|
101
101
|
def regex_pattern() -> str:
|
|
102
|
-
return r"
|
|
102
|
+
return r"Clarify\(([\"\'a-zA-Z0-9_, ]*)\)"
|
|
@@ -31,6 +31,7 @@ class CorrectedSlot:
|
|
|
31
31
|
|
|
32
32
|
name: str
|
|
33
33
|
value: Any
|
|
34
|
+
filled_by: Optional[str] = None
|
|
34
35
|
|
|
35
36
|
|
|
36
37
|
@dataclass
|
|
@@ -54,7 +55,9 @@ class CorrectSlotsCommand(Command):
|
|
|
54
55
|
try:
|
|
55
56
|
return CorrectSlotsCommand(
|
|
56
57
|
corrected_slots=[
|
|
57
|
-
CorrectedSlot(
|
|
58
|
+
CorrectedSlot(
|
|
59
|
+
s["name"], value=s["value"], filled_by=s.get("filled_by", None)
|
|
60
|
+
)
|
|
58
61
|
for s in data["corrected_slots"]
|
|
59
62
|
]
|
|
60
63
|
)
|
|
@@ -135,7 +138,10 @@ class CorrectSlotsCommand(Command):
|
|
|
135
138
|
proposed_slots = {}
|
|
136
139
|
for corrected_slot in self.corrected_slots:
|
|
137
140
|
if tracker.get_slot(corrected_slot.name) != corrected_slot.value:
|
|
138
|
-
proposed_slots[corrected_slot.name] =
|
|
141
|
+
proposed_slots[corrected_slot.name] = {
|
|
142
|
+
"value": corrected_slot.value,
|
|
143
|
+
"filled_by": corrected_slot.filled_by,
|
|
144
|
+
}
|
|
139
145
|
else:
|
|
140
146
|
structlogger.debug(
|
|
141
147
|
"command_executor.skip_correction.slot_already_set", command=self
|
|
@@ -240,6 +246,9 @@ class CorrectSlotsCommand(Command):
|
|
|
240
246
|
corrected_slots=proposed_slots,
|
|
241
247
|
reset_flow_id=earliest_collect.flow_id if earliest_collect else None,
|
|
242
248
|
reset_step_id=earliest_collect.step.id if earliest_collect else None,
|
|
249
|
+
new_slot_values=[
|
|
250
|
+
value.get("value") for slot, value in proposed_slots.items()
|
|
251
|
+
],
|
|
243
252
|
)
|
|
244
253
|
|
|
245
254
|
def run_command_on_tracker(
|