rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic; see the registry's release page for more details.
- rasa/anonymization/anonymization_rule_executor.py +16 -10
- rasa/cli/data.py +16 -0
- rasa/cli/project_templates/calm/config.yml +2 -2
- rasa/cli/project_templates/calm/endpoints.yml +2 -2
- rasa/cli/utils.py +12 -0
- rasa/core/actions/action.py +84 -191
- rasa/core/actions/action_run_slot_rejections.py +16 -4
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/studio_chat.py +19 -0
- rasa/core/channels/telegram.py +42 -24
- rasa/core/channels/voice_ready/utils.py +1 -1
- rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
- rasa/core/channels/voice_stream/asr/azure.py +14 -1
- rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
- rasa/core/channels/voice_stream/audiocodes.py +264 -0
- rasa/core/channels/voice_stream/browser_audio.py +4 -1
- rasa/core/channels/voice_stream/call_state.py +3 -0
- rasa/core/channels/voice_stream/genesys.py +6 -2
- rasa/core/channels/voice_stream/tts/azure.py +9 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
- rasa/core/channels/voice_stream/voice_channel.py +23 -2
- rasa/core/constants.py +2 -0
- rasa/core/nlg/contextual_response_rephraser.py +18 -1
- rasa/core/nlg/generator.py +83 -15
- rasa/core/nlg/response.py +6 -3
- rasa/core/nlg/translate.py +55 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
- rasa/core/policies/flows/flow_executor.py +12 -5
- rasa/core/processor.py +72 -9
- rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
- rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
- rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/clarify_command.py +29 -3
- rasa/dialogue_understanding/commands/command.py +1 -16
- rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/prompt_command.py +94 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
- rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +20 -2
- rasa/dialogue_understanding/commands/utils.py +98 -4
- rasa/dialogue_understanding/generator/__init__.py +2 -0
- rasa/dialogue_understanding/generator/command_parser.py +15 -12
- rasa/dialogue_understanding/generator/constants.py +3 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +16 -2
- rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
- rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +460 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
- rasa/dialogue_understanding/patterns/collect_information.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
- rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
- rasa/dialogue_understanding/processor/command_processor.py +39 -0
- rasa/dialogue_understanding_test/du_test_case.py +28 -8
- rasa/dialogue_understanding_test/du_test_result.py +13 -9
- rasa/dialogue_understanding_test/io.py +14 -0
- rasa/e2e_test/utils/io.py +0 -37
- rasa/engine/graph.py +1 -0
- rasa/engine/language.py +140 -0
- rasa/engine/recipes/config_files/default_config.yml +4 -0
- rasa/engine/recipes/default_recipe.py +2 -0
- rasa/engine/recipes/graph_recipe.py +2 -0
- rasa/engine/storage/local_model_storage.py +1 -0
- rasa/engine/storage/storage.py +4 -1
- rasa/model_manager/runner_service.py +7 -4
- rasa/model_manager/socket_bridge.py +7 -6
- rasa/shared/constants.py +15 -13
- rasa/shared/core/constants.py +2 -0
- rasa/shared/core/flows/constants.py +11 -0
- rasa/shared/core/flows/flow.py +83 -19
- rasa/shared/core/flows/flows_yaml_schema.json +31 -3
- rasa/shared/core/flows/steps/collect.py +1 -36
- rasa/shared/core/flows/utils.py +28 -4
- rasa/shared/core/flows/validation.py +1 -1
- rasa/shared/core/slot_mappings.py +208 -5
- rasa/shared/core/slots.py +131 -1
- rasa/shared/core/trackers.py +74 -1
- rasa/shared/importers/importer.py +50 -2
- rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
- rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
- rasa/shared/providers/_configs/model_group_config.py +4 -2
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +3 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
- rasa/shared/providers/llm/_base_litellm_client.py +42 -17
- rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
- rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
- rasa/shared/providers/llm/llm_client.py +23 -7
- rasa/shared/providers/llm/openai_llm_client.py +9 -3
- rasa/shared/providers/llm/rasa_llm_client.py +11 -2
- rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
- rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/llm.py +30 -7
- rasa/shared/utils/pykwalify_extensions.py +24 -0
- rasa/shared/utils/schemas/domain.yml +26 -0
- rasa/telemetry.py +2 -1
- rasa/tracing/config.py +2 -0
- rasa/tracing/constants.py +12 -0
- rasa/tracing/instrumentation/instrumentation.py +36 -0
- rasa/tracing/instrumentation/metrics.py +41 -0
- rasa/tracing/metric_instrument_provider.py +40 -0
- rasa/validator.py +372 -7
- rasa/version.py +1 -1
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/METADATA +2 -1
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/RECORD +128 -113
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -5,6 +5,10 @@ from dataclasses import dataclass
|
|
|
5
5
|
from typing import Any, Dict, List
|
|
6
6
|
|
|
7
7
|
from rasa.dialogue_understanding.commands.command import Command
|
|
8
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
9
|
+
CommandSyntaxManager,
|
|
10
|
+
CommandSyntaxVersion,
|
|
11
|
+
)
|
|
8
12
|
from rasa.shared.core.events import Event
|
|
9
13
|
from rasa.shared.core.flows import FlowsList
|
|
10
14
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
@@ -48,7 +52,14 @@ class ChangeFlowCommand(Command):
|
|
|
48
52
|
|
|
49
53
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Falls back to the default syntax version's rendering when no version (or
    one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: "ChangeFlow()",
        CommandSyntaxVersion.v2: "change",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
52
63
|
|
|
53
64
|
@staticmethod
|
|
54
65
|
def from_dsl(match: re.Match, **kwargs: Any) -> ChangeFlowCommand:
|
|
@@ -57,4 +68,11 @@ class ChangeFlowCommand(Command):
|
|
|
57
68
|
|
|
58
69
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"ChangeFlow\(\)",
        CommandSyntaxVersion.v2: r"^[^\w]*change$",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -4,6 +4,10 @@ import re
|
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from typing import Any, Dict, List
|
|
6
6
|
|
|
7
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
8
|
+
CommandSyntaxManager,
|
|
9
|
+
CommandSyntaxVersion,
|
|
10
|
+
)
|
|
7
11
|
from rasa.dialogue_understanding.commands.free_form_answer_command import (
|
|
8
12
|
FreeFormAnswerCommand,
|
|
9
13
|
)
|
|
@@ -59,7 +63,14 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
|
|
|
59
63
|
|
|
60
64
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Falls back to the default syntax version's rendering when no version (or
    one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: "ChitChat()",
        CommandSyntaxVersion.v2: "offtopic reply",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
63
74
|
|
|
64
75
|
@classmethod
|
|
65
76
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> ChitChatAnswerCommand:
|
|
@@ -68,4 +79,11 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
|
|
|
68
79
|
|
|
69
80
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"ChitChat\(\)",
        CommandSyntaxVersion.v2: r"^[^\w]*offtopic reply$",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -7,6 +7,10 @@ from typing import Any, Dict, List, Optional
|
|
|
7
7
|
import structlog
|
|
8
8
|
|
|
9
9
|
from rasa.dialogue_understanding.commands.command import Command
|
|
10
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
11
|
+
CommandSyntaxManager,
|
|
12
|
+
CommandSyntaxVersion,
|
|
13
|
+
)
|
|
10
14
|
from rasa.dialogue_understanding.commands.utils import extract_cleaned_options
|
|
11
15
|
from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
|
|
12
16
|
from rasa.shared.core.events import Event
|
|
@@ -74,7 +78,13 @@ class ClarifyCommand(Command):
|
|
|
74
78
|
|
|
75
79
|
stack = tracker.stack
|
|
76
80
|
relevant_flows = [all_flows.flow_by_id(opt) for opt in clean_options]
|
|
77
|
-
|
|
81
|
+
|
|
82
|
+
names = [
|
|
83
|
+
flow.readable_name(language=tracker.current_language)
|
|
84
|
+
for flow in relevant_flows
|
|
85
|
+
if flow is not None
|
|
86
|
+
]
|
|
87
|
+
|
|
78
88
|
stack.push(ClarifyPatternFlowStackFrame(names=names))
|
|
79
89
|
return tracker.create_stack_updated_events(stack)
|
|
80
90
|
|
|
@@ -89,7 +99,14 @@ class ClarifyCommand(Command):
|
|
|
89
99
|
|
|
90
100
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    The v1 form lists the options comma-separated inside ``Clarify(...)``; the
    v2 form lists them space-separated after ``disambiguate flows``. Falls back
    to the default syntax version's rendering when no version (or one without
    an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: f"Clarify({', '.join(self.options)})",
        CommandSyntaxVersion.v2: f"disambiguate flows {' '.join(self.options)}",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
93
110
|
|
|
94
111
|
@classmethod
|
|
95
112
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> Optional[ClarifyCommand]:
|
|
@@ -99,4 +116,13 @@ class ClarifyCommand(Command):
|
|
|
99
116
|
|
|
100
117
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches a Clarify command's DSL string.

    The single capture group holds the raw list of options (flow ids),
    including any quotes and separators. The pattern is selected by the active
    command syntax version, falling back to the default version's pattern.
    """
    # Flow ids may contain hyphens (the StartFlow pattern accepts
    # `[a-zA-Z0-9_-]+`), so "-" is part of the option character class too;
    # without it, clarifying between such flows would fail to parse. The "-"
    # is placed last in the class so it is matched literally.
    mapper = {
        CommandSyntaxVersion.v1: r"Clarify\(([\"\'a-zA-Z0-9_, -]*)\)",
        CommandSyntaxVersion.v2: (
            r"^[^\w]*disambiguate flows ([\"\'a-zA-Z0-9_, -]*)$"
        ),
    }
    return mapper.get(
        CommandSyntaxManager.get_syntax_version(),
        mapper[CommandSyntaxManager.get_default_syntax_version()],
    )
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import dataclasses
|
|
4
|
-
import re
|
|
5
4
|
from dataclasses import dataclass
|
|
6
|
-
from typing import Any, Dict, List
|
|
5
|
+
from typing import Any, Dict, List
|
|
7
6
|
|
|
8
7
|
import rasa.shared.utils.common
|
|
9
8
|
from rasa.shared.core.events import Event
|
|
@@ -85,17 +84,3 @@ class Command:
|
|
|
85
84
|
The events to apply to the tracker.
|
|
86
85
|
"""
|
|
87
86
|
raise NotImplementedError()
|
|
88
|
-
|
|
89
|
-
def to_dsl(self) -> str:
|
|
90
|
-
"""Converts the command to a DSL string."""
|
|
91
|
-
raise NotImplementedError()
|
|
92
|
-
|
|
93
|
-
@classmethod
|
|
94
|
-
def from_dsl(cls, match: re.Match, **kwargs: Any) -> Optional[Command]:
|
|
95
|
-
"""Converts the DSL string to a command."""
|
|
96
|
-
raise NotImplementedError()
|
|
97
|
-
|
|
98
|
-
@staticmethod
|
|
99
|
-
def regex_pattern() -> str:
|
|
100
|
-
"""Returns the regex pattern for the command."""
|
|
101
|
-
raise NotImplementedError()
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CommandSyntaxVersion(Enum):
    """Enumerates the supported syntax versions for LLM-predicted commands.

    Every member's value is the same string as its name.
    """

    # call-style syntax, e.g. `StartFlow(flow_id)`
    v1 = "v1"
    # word-style syntax, e.g. `start flow flow_id`
    v2 = "v2"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Module-level structured logger; CommandSyntaxManager uses it to warn when the
# syntax version is overwritten.
structlogger = structlog.get_logger()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CommandSyntaxManager:
    """Manages the command syntax version used to render and parse commands.

    The version is stored on the class itself rather than on instances, so all
    callers share one value. The class exists so that new command syntax
    versions can be introduced in the future, and it is for internal use only.
    """

    # Currently configured syntax version; None until `set_syntax_version` is
    # called (and again after `reset_syntax_version`).
    _version: Optional[CommandSyntaxVersion] = None

    @classmethod
    def set_syntax_version(cls, version: CommandSyntaxVersion) -> None:
        """Sets the command syntax version on the class itself.

        Intended to be called only once, when the LLMCommandGenerator is
        initialized, so the syntax version stays consistent for the lifetime
        of the generator. Calling it again overwrites the stored version and
        logs a warning.

        Args:
            version: The syntax version to use from now on.
        """
        if cls._version is not None:
            # `warning` is the canonical method; `warn` is a deprecated alias.
            structlogger.warning(
                "command_syntax_manager.syntax_version_already_set",
                event_info=(
                    "The command syntax version has already been set. Overwriting "
                    f"the existing version with the new version - {version}."
                ),
            )
        cls._version = version

    @classmethod
    def get_syntax_version(cls) -> Optional[CommandSyntaxVersion]:
        """Returns the stored command syntax version, or None if none is set."""
        return cls._version

    @staticmethod
    def get_default_syntax_version() -> CommandSyntaxVersion:
        """Returns the default command syntax version (currently `v1`)."""
        return CommandSyntaxVersion.v1

    @classmethod
    def reset_syntax_version(cls) -> None:
        """Resets the stored command syntax version to None (for testing)."""
        cls._version = None
|
|
@@ -7,6 +7,10 @@ from typing import Any, Dict, List
|
|
|
7
7
|
import structlog
|
|
8
8
|
|
|
9
9
|
from rasa.dialogue_understanding.commands.command import Command
|
|
10
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
11
|
+
CommandSyntaxManager,
|
|
12
|
+
CommandSyntaxVersion,
|
|
13
|
+
)
|
|
10
14
|
from rasa.dialogue_understanding.patterns.human_handoff import (
|
|
11
15
|
HumanHandoffPatternFlowStackFrame,
|
|
12
16
|
)
|
|
@@ -66,7 +70,14 @@ class HumanHandoffCommand(Command):
|
|
|
66
70
|
|
|
67
71
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Falls back to the default syntax version's rendering when no version (or
    one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: "HumanHandoff()",
        CommandSyntaxVersion.v2: "hand over",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
70
81
|
|
|
71
82
|
@classmethod
|
|
72
83
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> HumanHandoffCommand:
|
|
@@ -75,4 +86,11 @@ class HumanHandoffCommand(Command):
|
|
|
75
86
|
|
|
76
87
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"HumanHandoff\(\)",
        CommandSyntaxVersion.v2: r"^[^\w]*hand over$",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -4,6 +4,10 @@ import re
|
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from typing import Any, Dict, List
|
|
6
6
|
|
|
7
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
8
|
+
CommandSyntaxManager,
|
|
9
|
+
CommandSyntaxVersion,
|
|
10
|
+
)
|
|
7
11
|
from rasa.dialogue_understanding.commands.free_form_answer_command import (
|
|
8
12
|
FreeFormAnswerCommand,
|
|
9
13
|
)
|
|
@@ -59,7 +63,14 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
|
|
|
59
63
|
|
|
60
64
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Falls back to the default syntax version's rendering when no version (or
    one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: "SearchAndReply()",
        CommandSyntaxVersion.v2: "provide info",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
63
74
|
|
|
64
75
|
@classmethod
|
|
65
76
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> KnowledgeAnswerCommand:
|
|
@@ -68,4 +79,11 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
|
|
|
68
79
|
|
|
69
80
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"SearchAndReply\(\)",
        CommandSyntaxVersion.v2: r"^[^\w]*provide info$",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
|
5
|
+
|
|
6
|
+
from rasa.shared.core.events import Event
|
|
7
|
+
from rasa.shared.core.flows import FlowsList
|
|
8
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class PromptCommand(Protocol):
    """Structural interface for commands predicted by the LLM and used in the prompt.

    A class conforms simply by providing the members below; no inheritance is
    required. Because of `@runtime_checkable`, the protocol can also be used
    with `isinstance`, which only checks that the members exist.
    """

    @classmethod
    def command(cls) -> str:
        """Return the name of the command.

        Implementations provide this as a class method.
        """
        ...

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PromptCommand:
        """Build a command instance from `data`.

        Implementations provide this as a class method that turns the given
        dictionary into a command.
        """
        ...

    def run_command_on_tracker(
        self,
        tracker: DialogueStateTracker,
        all_flows: FlowsList,
        original_tracker: DialogueStateTracker,
    ) -> List[Event]:
        """Execute the command against `tracker`.

        Implementations return the events produced by running the command.
        """
        ...

    def __hash__(self) -> int:
        """Return a hash of the command.

        This lets commands be compared and stored in sets.
        """
        ...

    def __eq__(self, other: object) -> bool:
        """Return whether the command is equal to `other`."""
        ...

    def to_dsl(self) -> str:
        """Return the DSL string for the command.

        The DSL string is the textual representation of the command that is
        used in the prompt.
        """
        ...

    @classmethod
    def from_dsl(cls, match: re.Match, **kwargs: Any) -> Optional[PromptCommand]:
        """Build a command instance from a regex match of its DSL string.

        Implementations provide this as a class method and may return None.
        """
        ...

    @staticmethod
    def regex_pattern() -> str:
        """Return the regex pattern that matches the command in the prompt."""
        ...
|
|
@@ -5,6 +5,10 @@ from dataclasses import dataclass
|
|
|
5
5
|
from typing import Any, Dict, List
|
|
6
6
|
|
|
7
7
|
from rasa.dialogue_understanding.commands.command import Command
|
|
8
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
9
|
+
CommandSyntaxManager,
|
|
10
|
+
CommandSyntaxVersion,
|
|
11
|
+
)
|
|
8
12
|
from rasa.dialogue_understanding.patterns.repeat import (
|
|
9
13
|
RepeatBotMessagesPatternFlowStackFrame,
|
|
10
14
|
)
|
|
@@ -60,7 +64,14 @@ class RepeatBotMessagesCommand(Command):
|
|
|
60
64
|
|
|
61
65
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Falls back to the default syntax version's rendering when no version (or
    one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: "RepeatLastBotMessages()",
        CommandSyntaxVersion.v2: "repeat message",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
64
75
|
|
|
65
76
|
@classmethod
|
|
66
77
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> RepeatBotMessagesCommand:
|
|
@@ -69,4 +80,11 @@ class RepeatBotMessagesCommand(Command):
|
|
|
69
80
|
|
|
70
81
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"RepeatLastBotMessages\(\)",
        CommandSyntaxVersion.v2: r"^[^\w]*repeat message$",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -7,6 +7,10 @@ from typing import Any, Dict, List
|
|
|
7
7
|
import structlog
|
|
8
8
|
|
|
9
9
|
from rasa.dialogue_understanding.commands.command import Command
|
|
10
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
11
|
+
CommandSyntaxManager,
|
|
12
|
+
CommandSyntaxVersion,
|
|
13
|
+
)
|
|
10
14
|
from rasa.dialogue_understanding.commands.utils import (
|
|
11
15
|
clean_extracted_value,
|
|
12
16
|
get_nullable_slot_value,
|
|
@@ -162,7 +166,14 @@ class SetSlotCommand(Command):
|
|
|
162
166
|
|
|
163
167
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Both forms contain the slot name followed by the slot value. Falls back to
    the default syntax version's rendering when no version (or one without an
    entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: f"SetSlot({self.name}, {self.value})",
        CommandSyntaxVersion.v2: f"set slot {self.name} {self.value}",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
166
177
|
|
|
167
178
|
@classmethod
|
|
168
179
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> SetSlotCommand:
|
|
@@ -173,4 +184,15 @@ class SetSlotCommand(Command):
|
|
|
173
184
|
|
|
174
185
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    Group 1 captures the slot name and group 2 the (optionally quoted) value.
    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: (
            r"""SetSlot\(['"]?([a-zA-Z_][a-zA-Z0-9_-]*)['"]?, ?['"]?(.*)['"]?\)"""
        ),
        CommandSyntaxVersion.v2: (
            r"""^[^\w]*set slot ['"]?([a-zA-Z_][a-zA-Z0-9_-]*)['"]? ['"]?(.+?)['"]?$"""  # noqa: E501
        ),
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -7,6 +7,10 @@ from typing import Any, Dict, List
|
|
|
7
7
|
import structlog
|
|
8
8
|
|
|
9
9
|
from rasa.dialogue_understanding.commands.command import Command
|
|
10
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
11
|
+
CommandSyntaxManager,
|
|
12
|
+
CommandSyntaxVersion,
|
|
13
|
+
)
|
|
10
14
|
from rasa.dialogue_understanding.patterns.skip_question import (
|
|
11
15
|
SkipQuestionPatternFlowStackFrame,
|
|
12
16
|
)
|
|
@@ -75,7 +79,14 @@ class SkipQuestionCommand(Command):
|
|
|
75
79
|
|
|
76
80
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Falls back to the default syntax version's rendering when no version (or
    one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: "SkipQuestion()",
        CommandSyntaxVersion.v2: "skip question",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
79
90
|
|
|
80
91
|
@classmethod
|
|
81
92
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> SkipQuestionCommand:
|
|
@@ -84,4 +95,11 @@ class SkipQuestionCommand(Command):
|
|
|
84
95
|
|
|
85
96
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    The pattern is selected by the active command syntax version, falling back
    to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"SkipQuestion\(\)",
        CommandSyntaxVersion.v2: r"^[^\w]*skip question$",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
@@ -7,6 +7,10 @@ from typing import Any, Dict, List, Optional
|
|
|
7
7
|
import structlog
|
|
8
8
|
|
|
9
9
|
from rasa.dialogue_understanding.commands.command import Command
|
|
10
|
+
from rasa.dialogue_understanding.commands.command_syntax_manager import (
|
|
11
|
+
CommandSyntaxManager,
|
|
12
|
+
CommandSyntaxVersion,
|
|
13
|
+
)
|
|
10
14
|
from rasa.dialogue_understanding.patterns.clarify import FLOW_PATTERN_CLARIFICATION
|
|
11
15
|
from rasa.dialogue_understanding.patterns.continue_interrupted import (
|
|
12
16
|
ContinueInterruptedPatternFlowStackFrame,
|
|
@@ -119,7 +123,14 @@ class StartFlowCommand(Command):
|
|
|
119
123
|
|
|
120
124
|
def to_dsl(self) -> str:
    """Render this command as a DSL string for the active command syntax version.

    Both forms embed the flow id. Falls back to the default syntax version's
    rendering when no version (or one without an entry below) is configured.
    """
    dsl_by_version = {
        CommandSyntaxVersion.v1: f"StartFlow({self.flow})",
        CommandSyntaxVersion.v2: f"start flow {self.flow}",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback = dsl_by_version[CommandSyntaxManager.get_default_syntax_version()]
    return dsl_by_version.get(active_version, fallback)
|
|
123
134
|
|
|
124
135
|
@classmethod
|
|
125
136
|
def from_dsl(cls, match: re.Match, **kwargs: Any) -> Optional[StartFlowCommand]:
|
|
@@ -128,7 +139,14 @@ class StartFlowCommand(Command):
|
|
|
128
139
|
|
|
129
140
|
@staticmethod
def regex_pattern() -> str:
    """Return the regex pattern that matches this command's DSL string.

    Group 1 captures the flow id. The pattern is selected by the active command
    syntax version, falling back to the default version's pattern.
    """
    patterns = {
        CommandSyntaxVersion.v1: r"StartFlow\(['\"]?([a-zA-Z0-9_-]+)['\"]?\)",
        CommandSyntaxVersion.v2: r"^[^\w]*start flow ['\"]?([a-zA-Z0-9_-]+)['\"]?",
    }
    active_version = CommandSyntaxManager.get_syntax_version()
    fallback_pattern = patterns[CommandSyntaxManager.get_default_syntax_version()]
    return patterns.get(active_version, fallback_pattern)
|
|
132
150
|
|
|
133
151
|
def change_flow_frame_position_in_the_stack(
|
|
134
152
|
self, stack: DialogueStack, tracker: DialogueStateTracker
|
|
@@ -1,7 +1,18 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING, List, Optional, Union
|
|
1
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
|
|
5
|
+
from rasa.dialogue_understanding.patterns.validate_slot import (
|
|
6
|
+
ValidateSlotPatternFlowStackFrame,
|
|
7
|
+
)
|
|
8
|
+
from rasa.shared.constants import (
|
|
9
|
+
ACTION_ASK_PREFIX,
|
|
10
|
+
UTTER_ASK_PREFIX,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.core.events import Event, SlotSet
|
|
13
|
+
from rasa.shared.core.slots import Slot
|
|
14
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
15
|
+
|
|
5
16
|
if TYPE_CHECKING:
|
|
6
17
|
from rasa.dialogue_understanding.commands import StartFlowCommand
|
|
7
18
|
from rasa.shared.core.flows import FlowsList
|
|
@@ -25,10 +36,13 @@ def start_flow_by_name(
|
|
|
25
36
|
|
|
26
37
|
def extract_cleaned_options(options_str: str) -> List[str]:
    """Extract and clean options from a string.

    Commas and whitespace both act as separators, so "a, b", "a,b" and "a b"
    all yield ["a", "b"]. Surrounding double quotes and then single quotes are
    stripped from each option. Options that are empty after stripping (e.g. a
    bare '""' token) are dropped, and the result is sorted.

    Args:
        options_str: The raw option list, e.g. the options of a clarify command.

    Returns:
        The sorted list of non-empty, unquoted options.
    """
    tokens = options_str.replace(",", " ").split()
    # Emptiness must be checked *after* quote stripping; otherwise a token made
    # only of quotes would end up in the result as "".
    cleaned = (token.strip('"').strip("'") for token in tokens)
    return sorted(option for option in cleaned if option)
|
|
33
47
|
|
|
34
48
|
|
|
@@ -62,3 +76,83 @@ def get_nullable_slot_value(slot_value: str) -> Union[str, None]:
|
|
|
62
76
|
The slot value or None if the value is a none value.
|
|
63
77
|
"""
|
|
64
78
|
return slot_value if not is_none_value(slot_value) else None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def initialize_pattern_validate_slot(
    slot: Slot,
) -> Optional[ValidateSlotPatternFlowStackFrame]:
    """Build the pattern frame that validates a newly set value of `slot`.

    Args:
        slot: The slot whose value may need validation.

    Returns:
        A ValidateSlotPatternFlowStackFrame for the slot, or None when
        `slot.requires_validation()` reports that no validation is needed. The
        frame's refill utterance defaults to `utter_ask_<slot>`-style naming
        (UTTER_ASK_PREFIX + slot name) when the slot's validation config does
        not provide one.
    """
    if slot.requires_validation():
        name = slot.name
        # `slot.validation` is Optional in its type; `requires_validation()`
        # presumably guarantees it is set (verify), hence the ignores below.
        validation_config = slot.validation
        default_utter = f"{UTTER_ASK_PREFIX}{name}"
        return ValidateSlotPatternFlowStackFrame(
            validate=name,
            refill_utter=validation_config.refill_utter or default_utter,  # type: ignore[union-attr]
            refill_action=f"{ACTION_ASK_PREFIX}{name}",
            rejections=validation_config.rejections,  # type: ignore[union-attr]
        )
    return None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def create_validate_frames_from_slot_set_events(
    tracker: DialogueStateTracker,
    events: List[Event],
    validate_frames: Optional[List[ValidateSlotPatternFlowStackFrame]] = None,
    should_break: bool = False,
    update_corrected_slots: bool = False,
) -> Tuple[DialogueStateTracker, List[ValidateSlotPatternFlowStackFrame]]:
    """Process SlotSet events and create validation frames.

    Args:
        tracker: The dialogue state tracker.
        events: List of events to process.
        validate_frames: Frames collected so far. New frames are appended to
            this list in place, and the same list is returned. When omitted, a
            new list is created for this call.
        should_break: Whether to stop at the first non-SlotSet event.
            If True, break out of the event loop as soon as the first
            non-SlotSet event is encountered.
            If False, skip non-SlotSet events and process until the end.
        update_corrected_slots: Whether the slots that get a validation frame
            should also be removed from the correction frame on top of the
            stack (see `update_corrected_slots_in_correction_frame`).

    Returns:
        Tuple of (updated tracker, list of validation frames).
    """
    # A literal `[]` default would be created once and shared by every call,
    # leaking frames between unrelated calls; create a fresh list instead.
    if validate_frames is None:
        validate_frames = []

    for event in events:
        if not isinstance(event, SlotSet):
            if should_break:
                # we want to only process the most recent SlotSet events
                # so we break once we encounter a different event
                break
            continue

        slot = tracker.slots.get(event.key)
        if slot is None:
            # The tracker knows no slot with this name, so there is nothing to
            # validate (and `initialize_pattern_validate_slot` needs a slot).
            continue

        frame = initialize_pattern_validate_slot(slot)
        if frame:
            validate_frames.append(frame)
            if update_corrected_slots:
                tracker = update_corrected_slots_in_correction_frame(
                    tracker, event.key, event.value
                )

    return tracker, validate_frames
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def update_corrected_slots_in_correction_frame(
    tracker: DialogueStateTracker, slot_name: str, slot_value: Any
) -> DialogueStateTracker:
    """Remove one slot from the correction frame on top of the dialogue stack.

    Deletes `slot_name` from the top frame's `corrected_slots` and the first
    occurrence of `slot_value` from its `new_slot_values`, so that the
    CorrectionPatternFlowStackFrame keeps only valid values. The modified stack
    is then applied back to the tracker as stack-update events.

    NOTE(review): the top frame is assumed to be a correction frame; this is
    not checked here, and a missing key/value raises KeyError/ValueError.

    Returns:
        The same tracker instance, updated with the stack-update events.
    """
    dialogue_stack = tracker.stack
    correction_frame = dialogue_stack.top()

    del correction_frame.corrected_slots[slot_name]  # type: ignore[union-attr]
    correction_frame.new_slot_values.remove(slot_value)  # type: ignore[union-attr]

    # A stack frame cannot be replaced in place: pop the old top and push the
    # modified frame back before generating the update events.
    dialogue_stack.pop()
    dialogue_stack.push(correction_frame)

    tracker.update_with_events(tracker.create_stack_updated_events(dialogue_stack))
    return tracker
|