rasa-pro 3.13.0.dev1__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rasa-pro might be problematic.
- rasa/core/actions/action.py +0 -6
- rasa/core/channels/voice_ready/audiocodes.py +52 -17
- rasa/core/channels/voice_stream/audiocodes.py +53 -9
- rasa/core/channels/voice_stream/genesys.py +146 -16
- rasa/core/information_retrieval/faiss.py +6 -1
- rasa/core/information_retrieval/information_retrieval.py +40 -2
- rasa/core/information_retrieval/milvus.py +7 -2
- rasa/core/information_retrieval/qdrant.py +7 -2
- rasa/core/policies/enterprise_search_policy.py +61 -301
- rasa/core/policies/flows/flow_executor.py +3 -38
- rasa/core/processor.py +27 -6
- rasa/core/utils.py +53 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
- rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
- rasa/dialogue_understanding/generator/command_generator.py +67 -0
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -13
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +20 -1
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
- rasa/dialogue_understanding/processor/command_processor.py +7 -65
- rasa/dialogue_understanding/stack/utils.py +0 -38
- rasa/dialogue_understanding_test/io.py +13 -8
- rasa/document_retrieval/__init__.py +0 -0
- rasa/document_retrieval/constants.py +32 -0
- rasa/document_retrieval/document_post_processor.py +351 -0
- rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
- rasa/document_retrieval/document_retriever.py +333 -0
- rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
- rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
- rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
- rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
- rasa/document_retrieval/query_rewriter.py +234 -0
- rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
- rasa/engine/recipes/default_components.py +2 -0
- rasa/shared/core/constants.py +0 -8
- rasa/shared/core/domain.py +12 -3
- rasa/shared/core/flows/flow.py +0 -17
- rasa/shared/core/flows/flows_yaml_schema.json +3 -38
- rasa/shared/core/flows/steps/collect.py +5 -18
- rasa/shared/core/flows/utils.py +1 -16
- rasa/shared/core/slot_mappings.py +11 -5
- rasa/shared/nlu/constants.py +0 -1
- rasa/shared/utils/common.py +11 -1
- rasa/shared/utils/llm.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +10 -7
- rasa/tracing/instrumentation/instrumentation.py +12 -12
- rasa/validator.py +1 -123
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +1 -1
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +55 -47
- rasa/core/actions/action_handle_digressions.py +0 -164
- rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
- rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
rasa/document_retrieval/query_rewriter.py
ADDED
@@ -0,0 +1,234 @@
+import importlib.resources
+from enum import Enum
+from typing import Any, Dict, Optional
+
+import structlog
+from jinja2 import Template
+
+import rasa.shared.utils.io
+from rasa.engine.storage.resource import Resource
+from rasa.engine.storage.storage import ModelStorage
+from rasa.shared.constants import (
+    LLM_CONFIG_KEY,
+    MODEL_CONFIG_KEY,
+    OPENAI_PROVIDER,
+    PROMPT_TEMPLATE_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    TEXT,
+    TIMEOUT_CONFIG_KEY,
+)
+from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.exceptions import FileIOException, ProviderClientAPIException
+from rasa.shared.nlu.training_data.message import Message
+from rasa.shared.providers.llm.llm_client import LLMClient
+from rasa.shared.providers.llm.llm_response import LLMResponse
+from rasa.shared.utils.health_check.health_check import perform_llm_health_check
+from rasa.shared.utils.health_check.llm_health_check_mixin import (
+    LLMHealthCheckMixin,
+)
+from rasa.shared.utils.llm import (
+    AI,
+    DEFAULT_OPENAI_GENERATE_MODEL_NAME,
+    DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    USER,
+    get_prompt_template,
+    llm_factory,
+    resolve_model_client_config,
+    sanitize_message_for_prompt,
+    tracker_as_readable_transcript,
+)
+
+QUERY_REWRITER_PROMPT_FILE_NAME = "query_rewriter_prompt_template.jinja2"
+MAX_TURNS = "max_turns"
+DEFAULT_LLM_CONFIG = {
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
+    "temperature": 0.3,
+    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TIMEOUT_CONFIG_KEY: 5,
+}
+DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE = importlib.resources.read_text(
+    "rasa.document_retrieval",
+    "query_rewriter_prompt_template.jinja2",
+)
+
+TYPE_CONFIG_KEY = "type"
+
+structlogger = structlog.get_logger()
+
+
+class QueryRewritingType(Enum):
+    PLAIN = "PLAIN"
+    CONCATENATED_TURNS = "CONCATENATED_TURNS"
+    REPHRASE = "REPHRASE"
+    KEYWORD_EXTRACTION = "KEYWORD_EXTRACTION"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+class QueryRewriter(LLMHealthCheckMixin):
+    @classmethod
+    def get_default_config(cls) -> Dict[str, Any]:
+        """The default config for the query rewriter."""
+        return {
+            TYPE_CONFIG_KEY: QueryRewritingType.PLAIN,
+            MAX_TURNS: 0,
+            LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
+            PROMPT_TEMPLATE_CONFIG_KEY: DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE,
+        }
+
+    def __init__(
+        self,
+        config: Dict[str, Any],
+        model_storage: ModelStorage,
+        resource: Resource,
+        prompt_template: Optional[str] = None,
+    ):
+        self.config = {**self.get_default_config(), **config}
+        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
+            self.config.get(LLM_CONFIG_KEY), QueryRewriter.__name__
+        )
+        self.prompt_template = prompt_template or get_prompt_template(
+            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
+            DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE,
+        )
+
+        self._model_storage = model_storage
+        self._resource = resource
+
+    @classmethod
+    def load(
+        cls,
+        config: Dict[str, Any],
+        model_storage: ModelStorage,
+        resource: Resource,
+        **kwargs: Any,
+    ) -> "QueryRewriter":
+        """Load query rewriter."""
+        llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
+        perform_llm_health_check(
+            llm_config,
+            DEFAULT_LLM_CONFIG,
+            "query_rewriter.load",
+            QueryRewriter.__name__,
+        )
+
+        # load prompt template
+        prompt_template = None
+        try:
+            with model_storage.read_from(resource) as path:
+                prompt_template = rasa.shared.utils.io.read_file(
+                    path / QUERY_REWRITER_PROMPT_FILE_NAME
+                )
+        except (FileNotFoundError, FileIOException) as e:
+            structlogger.warning(
+                "query_rewriter.load_prompt_template.failed",
+                error=e,
+                resource=resource.name,
+            )
+
+        return QueryRewriter(config, model_storage, resource, prompt_template)
+
+    def persist(self) -> None:
+        with self._model_storage.write_to(self._resource) as path:
+            rasa.shared.utils.io.write_text_file(
+                self.prompt_template, path / QUERY_REWRITER_PROMPT_FILE_NAME
+            )
+
+    @staticmethod
+    def _concatenate_turns(
+        message: Message, tracker: DialogueStateTracker, max_turns: int
+    ) -> str:
+        transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+        transcript += "\nUSER: " + message.get(TEXT)
+        return transcript
+
+    @staticmethod
+    async def _invoke_llm(prompt: str, llm: LLMClient) -> Optional[LLMResponse]:
+        try:
+            return await llm.acompletion(prompt)
+        except Exception as e:
+            # unfortunately, langchain does not wrap LLM exceptions which means
+            # we have to catch all exceptions here
+            structlogger.error("query_rewriter.llm.error", error=e)
+            raise ProviderClientAPIException(
+                message="LLM call exception", original_exception=e
+            )
+
+    async def _rephrase_message(
+        self, message: Message, tracker: DialogueStateTracker, max_turns: int = 5
+    ) -> str:
+        llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
+
+        transcript = tracker_as_readable_transcript(
+            tracker, max_turns=max_turns, ai_prefix="ASSISTANT"
+        )
+
+        inputs = {
+            "conversation": transcript,
+            "user_message": message.get(TEXT),
+        }
+
+        prompt = Template(self.prompt_template).render(**inputs)
+        llm_response = await self._invoke_llm(prompt, llm)
+        llm_response = LLMResponse.ensure_llm_response(llm_response)
+
+        return llm_response.choices[0]
+
+    @staticmethod
+    def _keyword_extraction(
+        message: Message, tracker: DialogueStateTracker, max_turns: int = 5
+    ) -> str:
+        import spacy
+
+        nlp = spacy.load("en_core_web_md")
+
+        transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+        transcript = transcript.replace(USER, "")
+        transcript = transcript.replace(AI, "")
+
+        doc = nlp(transcript)
+
+        keywords = set()
+        for token in doc:
+            # Extract nouns and proper nouns
+            if token.pos_ in ["NOUN", "PROPN"]:
+                keywords.add(token.lemma_)
+
+        for ent in doc.ents:
+            # Add named entities as keywords
+            keywords.add(ent.text)
+
+        # Remove stop words and punctuation
+        keywords = {
+            word
+            for word in keywords
+            if word.lower() not in nlp.Defaults.stop_words and word.isalpha()
+        }
+
+        if keywords:
+            return message.get(TEXT) + " " + " ".join(keywords)
+        else:
+            return message.get(TEXT)
+
+    async def prepare_search_query(
+        self, message: Message, tracker: DialogueStateTracker
+    ) -> str:
+        query_rewriting_type = self.config[TYPE_CONFIG_KEY]
+        max_turns: int = self.config[MAX_TURNS]
+
+        query: str
+
+        if query_rewriting_type == QueryRewritingType.CONCATENATED_TURNS.value:
+            query = self._concatenate_turns(message, tracker, max_turns)
+        elif query_rewriting_type == QueryRewritingType.KEYWORD_EXTRACTION.value:
+            query = self._keyword_extraction(message, tracker, max_turns)
+        elif query_rewriting_type == QueryRewritingType.REPHRASE.value:
+            query = await self._rephrase_message(message, tracker, max_turns)
+        elif query_rewriting_type == QueryRewritingType.PLAIN.value:
+            query = message.get(TEXT)
+        else:
+            raise ValueError(f"Invalid query rewriting type: {query_rewriting_type}")
+
+        return sanitize_message_for_prompt(query)
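
For orientation, a minimal sketch of how the configuration surface introduced above fits together. The "type" and "max_turns" keys and the QueryRewritingType values come from the diff; the "llm" and "prompt" key names and the concrete model settings are assumptions for illustration only, not values taken from the package.

from enum import Enum


# Mirrors the QueryRewritingType enum added in query_rewriter.py above.
class QueryRewritingType(Enum):
    PLAIN = "PLAIN"
    CONCATENATED_TURNS = "CONCATENATED_TURNS"
    REPHRASE = "REPHRASE"
    KEYWORD_EXTRACTION = "KEYWORD_EXTRACTION"


# Hypothetical override of get_default_config(): "type" selects the rewriting
# strategy, "max_turns" bounds how much conversation history is used. The
# "llm"/"prompt" key names and the model settings are illustrative assumptions.
example_query_rewriter_config = {
    "type": QueryRewritingType.REPHRASE.value,
    "max_turns": 5,
    "llm": {
        "provider": "openai",
        "model": "gpt-4o-mini",  # assumed model name, not the package default
        "temperature": 0.3,
        "timeout": 5,
    },
    "prompt": "query_rewriter_prompt_template.jinja2",
}

if __name__ == "__main__":
    print(example_query_rewriter_config["type"])  # REPHRASE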
rasa/engine/recipes/default_components.py
CHANGED
@@ -13,6 +13,7 @@ from rasa.dialogue_understanding.generator import (
     LLMCommandGenerator,
 )
 from rasa.dialogue_understanding.generator.nlu_command_adapter import NLUCommandAdapter
+from rasa.document_retrieval.document_retriever import DocumentRetriever
 from rasa.nlu.classifiers.diet_classifier import DIETClassifier
 from rasa.nlu.classifiers.fallback_classifier import FallbackClassifier
 from rasa.nlu.classifiers.keyword_intent_classifier import KeywordIntentClassifier
@@ -92,4 +93,5 @@ DEFAULT_COMPONENTS = [
     FlowPolicy,
     EnterpriseSearchPolicy,
     IntentlessPolicy,
+    DocumentRetriever,
 ]
rasa/shared/core/constants.py
CHANGED
@@ -52,8 +52,6 @@ ACTION_TRIGGER_CHITCHAT = "action_trigger_chitchat"
 ACTION_RESET_ROUTING = "action_reset_routing"
 ACTION_HANGUP = "action_hangup"
 ACTION_REPEAT_BOT_MESSAGES = "action_repeat_bot_messages"
-ACTION_BLOCK_DIGRESSION = "action_block_digression"
-ACTION_CONTINUE_DIGRESSION = "action_continue_digression"
 
 ACTION_METADATA_EXECUTION_SUCCESS = "execution_success"
 ACTION_METADATA_EXECUTION_ERROR_MESSAGE = "execution_error_message"
@@ -84,8 +82,6 @@ DEFAULT_ACTION_NAMES = [
     ACTION_RESET_ROUTING,
     ACTION_HANGUP,
     ACTION_REPEAT_BOT_MESSAGES,
-    ACTION_BLOCK_DIGRESSION,
-    ACTION_CONTINUE_DIGRESSION,
 ]
 
 ACTION_SHOULD_SEND_DOMAIN = "send_domain"
@@ -205,8 +201,4 @@ CLASSIFIER_NAME_FALLBACK = "FallbackClassifier"
 
 POLICIES_THAT_EXTRACT_ENTITIES = {"TEDPolicy"}
 
-# digression constants
-KEY_ASK_CONFIRM_DIGRESSIONS = "ask_confirm_digressions"
-KEY_BLOCK_DIGRESSIONS = "block_digressions"
-
 ERROR_CODE_KEY = "error_code"
rasa/shared/core/domain.py
CHANGED
@@ -1678,6 +1678,14 @@ class Domain:
         """Write domain to a file."""
         as_yaml = self.as_yaml()
         rasa.shared.utils.io.write_text_file(as_yaml, filename)
+        # run the check again on the written domain to catch any errors
+        # that may have been missed in the user defined domain files
+        structlogger.info(
+            "domain.persist.domain_written_to_file",
+            event_info="The entire domain content has been written to file.",
+            filename=filename,
+        )
+        Domain.is_domain_file(filename)
 
     def as_yaml(self) -> Text:
         """Dump the `Domain` object as a YAML string.
@@ -1972,17 +1980,18 @@ class Domain:
 
         try:
             content = read_yaml_file(filename, expand_env_vars=cls.expand_env_vars)
-        except (RasaException, YamlSyntaxException):
-            structlogger.
+        except (RasaException, YamlSyntaxException) as error:
+            structlogger.error(
                 "domain.cannot_load_domain_file",
                 file=filename,
+                error=error,
                 event_info=(
                     f"The file {filename} could not be loaded as domain file. "
                     f"You can use https://yamlchecker.com/ to validate "
                     f"the YAML syntax of your file."
                 ),
             )
-
+            raise RasaException(f"Domain could not be loaded: {error}")
 
         return any(key in content for key in ALL_DOMAIN_KEYS)
 
rasa/shared/core/flows/flow.py
CHANGED
@@ -13,10 +13,6 @@ from pypred import Predicate
 import rasa.shared.utils.io
 from rasa.engine.language import Language
 from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
-from rasa.shared.core.constants import (
-    KEY_ASK_CONFIRM_DIGRESSIONS,
-    KEY_BLOCK_DIGRESSIONS,
-)
 from rasa.shared.core.flows.constants import (
     KEY_ALWAYS_INCLUDE_IN_PROMPT,
     KEY_DESCRIPTION,
@@ -52,7 +48,6 @@ from rasa.shared.core.flows.steps.constants import (
     START_STEP,
 )
 from rasa.shared.core.flows.steps.continuation import ContinueFlowStep
-from rasa.shared.core.flows.utils import extract_digression_prop
 from rasa.shared.core.slots import Slot
 
 structlogger = structlog.get_logger()
@@ -94,10 +89,6 @@ class Flow:
     """The path to the file where the flow is stored."""
     persisted_slots: List[str] = field(default_factory=list)
     """The list of slots that should be persisted after the flow ends."""
-    ask_confirm_digressions: List[str] = field(default_factory=list)
-    """The flow ids for which the assistant should ask for confirmation."""
-    block_digressions: List[str] = field(default_factory=list)
-    """The flow ids that the assistant should block from digressing to."""
     run_pattern_completed: bool = True
     """Whether the pattern_completed flow should be run after the flow ends."""
 
@@ -138,10 +129,6 @@ class Flow:
             # data. When the model is trained, take the provided file_path.
             file_path=data.get(KEY_FILE_PATH) if KEY_FILE_PATH in data else file_path,
            persisted_slots=data.get(KEY_PERSISTED_SLOTS, []),
-            ask_confirm_digressions=extract_digression_prop(
-                KEY_ASK_CONFIRM_DIGRESSIONS, data
-            ),
-            block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
            run_pattern_completed=data.get(KEY_RUN_PATTERN_COMPLETED, True),
            translation=extract_translations(
                translation_data=data.get(KEY_TRANSLATION, {})
@@ -220,10 +207,6 @@ class Flow:
            data[KEY_FILE_PATH] = self.file_path
        if self.persisted_slots:
            data[KEY_PERSISTED_SLOTS] = self.persisted_slots
-        if self.ask_confirm_digressions:
-            data[KEY_ASK_CONFIRM_DIGRESSIONS] = self.ask_confirm_digressions
-        if self.block_digressions:
-            data[KEY_BLOCK_DIGRESSIONS] = self.block_digressions
        if self.run_pattern_completed is not None:
            data["run_pattern_completed"] = self.run_pattern_completed
        if self.translation:
rasa/shared/core/flows/flows_yaml_schema.json
CHANGED
@@ -217,15 +217,12 @@
                 "reset_after_flow_ends": {
                     "type": "boolean"
                 },
-                "ask_confirm_digressions": {
-                    "$ref": "#/$defs/ask_confirm_digressions"
-                },
-                "block_digressions": {
-                    "$ref": "#/$defs/block_digressions"
-                },
                 "utter": {
                     "type": "string"
                 },
+                "force_slot_filling": {
+                    "type": "boolean"
+                },
                 "rejections": {
                     "type": "array",
                     "schema_name": "list of rejections",
@@ -253,32 +250,6 @@
                 }
             }
         },
-        "ask_confirm_digressions": {
-            "oneOf": [
-                {
-                    "type": "boolean"
-                },
-                {
-                    "type": "array",
-                    "items": {
-                        "type": "string"
-                    }
-                }
-            ]
-        },
-        "block_digressions": {
-            "oneOf": [
-                {
-                    "type": "boolean"
-                },
-                {
-                    "type": "array",
-                    "items": {
-                        "type": "string"
-                    }
-                }
-            ]
-        },
         "flow": {
             "required": [
                 "steps",
@@ -340,12 +311,6 @@
                 "persisted_slots": {
                     "$ref": "#/$defs/persisted_slots"
                 },
-                "ask_confirm_digressions": {
-                    "$ref": "#/$defs/ask_confirm_digressions"
-                },
-                "block_digressions": {
-                    "$ref": "#/$defs/block_digressions"
-                },
                 "run_pattern_completed": {
                     "type": "boolean"
                 }
rasa/shared/core/flows/steps/collect.py
CHANGED
@@ -1,15 +1,10 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import Any, Dict, List, Set, Text
 
 from rasa.shared.constants import ACTION_ASK_PREFIX, UTTER_ASK_PREFIX
-from rasa.shared.core.constants import (
-    KEY_ASK_CONFIRM_DIGRESSIONS,
-    KEY_BLOCK_DIGRESSIONS,
-)
 from rasa.shared.core.flows.flow_step import FlowStep
-from rasa.shared.core.flows.utils import extract_digression_prop
 from rasa.shared.core.slots import SlotRejection
 
 
@@ -29,10 +24,8 @@ class CollectInformationFlowStep(FlowStep):
     """Whether to always ask the question even if the slot is already filled."""
     reset_after_flow_ends: bool = True
     """Whether to reset the slot value at the end of the flow."""
-
-    """
-    block_digressions: List[str] = field(default_factory=list)
-    """The flow id digressions that should be blocked during the flow step."""
+    force_slot_filling: bool = False
+    """Whether to keep only the SetSlot command for the collected slot."""
 
     @classmethod
     def from_json(
@@ -60,10 +53,7 @@ class CollectInformationFlowStep(FlowStep):
                 SlotRejection.from_dict(rejection)
                 for rejection in data.get("rejections", [])
             ],
-
-                KEY_ASK_CONFIRM_DIGRESSIONS, data
-            ),
-            block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
+            force_slot_filling=data.get("force_slot_filling", False),
             **base.__dict__,
         )
 
@@ -79,10 +69,7 @@ class CollectInformationFlowStep(FlowStep):
         data["ask_before_filling"] = self.ask_before_filling
         data["reset_after_flow_ends"] = self.reset_after_flow_ends
         data["rejections"] = [rejection.as_dict() for rejection in self.rejections]
-        data["
-        data["block_digressions"] = (
-            self.block_digressions if self.block_digressions else False
-        )
+        data["force_slot_filling"] = self.force_slot_filling
 
         return data
 
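
A small sketch of how the new collect-step flag above round-trips, assuming a plain dict as the step data; CollectInformationFlowStep itself needs the surrounding FlowStep machinery, so this only mirrors the data.get(...) read in from_json and the write-back in as_dict. The slot name is invented.

# Hypothetical collect-step data; "payment_method" is an invented slot name.
step_data = {
    "collect": "payment_method",
    "force_slot_filling": True,
}

# Mirrors data.get("force_slot_filling", False) in from_json().
force_slot_filling = step_data.get("force_slot_filling", False)

# Mirrors data["force_slot_filling"] = self.force_slot_filling in as_dict().
serialized = {"force_slot_filling": force_slot_filling}

print(force_slot_filling, serialized)  # True {'force_slot_filling': True}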
rasa/shared/core/flows/utils.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict,
+from typing import TYPE_CHECKING, Any, Dict, Set, Text
 
 from rasa.shared.utils.io import raise_deprecation_warning
 
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
 
 RESET_PROPERTY_NAME = "reset_after_flow_ends"
 PERSIST_PROPERTY_NAME = "persisted_slots"
-ALL_LABEL = "ALL"
 
 
 def warn_deprecated_collect_step_config() -> None:
@@ -45,20 +44,6 @@ def get_invalid_slot_persistence_config_error_message(
     )
 
 
-def extract_digression_prop(prop: str, data: Dict[str, Any]) -> List[str]:
-    """Extracts the digression property from the data.
-
-    There can be two types of properties: ask_confirm_digressions and
-    block_digressions.
-    """
-    digression_property = data.get(prop, [])
-
-    if isinstance(digression_property, bool):
-        digression_property = [ALL_LABEL] if digression_property else []
-
-    return digression_property
-
-
 def extract_translations(
     translation_data: Dict[Text, Any],
 ) -> Dict[Text, "FlowLanguageTranslation"]:
rasa/shared/core/slot_mappings.py
CHANGED
@@ -648,12 +648,14 @@ class SlotFillingManager:
         output_channel: "OutputChannel",
         nlg: "NaturalLanguageGenerator",
         recreate_tracker: bool = False,
+        slot_events: Optional[List[Event]] = None,
     ) -> List[Event]:
         from rasa.core.actions.action import RemoteAction
         from rasa.shared.core.trackers import DialogueStateTracker
         from rasa.utils.endpoints import ClientResponseError
 
-
+        validated_slot_events: List[Event] = []
+        slot_events = slot_events if slot_events is not None else []
         remote_action = RemoteAction(custom_action, self._action_endpoint)
         disallowed_types = set()
 
@@ -673,9 +675,9 @@ class SlotFillingManager:
             )
             for event in custom_events:
                 if isinstance(event, SlotSet):
-
+                    validated_slot_events.append(event)
                 elif isinstance(event, BotUttered):
-
+                    validated_slot_events.append(event)
                 else:
                     disallowed_types.add(event.type_name)
         except (RasaException, ClientResponseError) as e:
@@ -699,7 +701,7 @@ class SlotFillingManager:
                 f"updated with this event.",
             )
 
-        return
+        return validated_slot_events
 
     async def execute_validation_action(
         self,
@@ -722,7 +724,11 @@ class SlotFillingManager:
             return cast(List[Event], slot_events)
 
         validate_events = await self._run_custom_action(
-            ACTION_VALIDATE_SLOT_MAPPINGS,
+            ACTION_VALIDATE_SLOT_MAPPINGS,
+            output_channel,
+            nlg,
+            recreate_tracker=True,
+            slot_events=cast(List[Event], slot_events),
         )
         validated_slot_names = [
             event.key for event in validate_events if isinstance(event, SlotSet)
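
To make the new event handling in _run_custom_action easier to follow, here is a self-contained sketch of the filtering it performs; the stub event classes below stand in for rasa.shared.core.events and are not the real implementations.

# Stub event types standing in for rasa.shared.core.events.
class SlotSet:
    type_name = "slot"


class BotUttered:
    type_name = "bot"


class UserUttered:
    type_name = "user"


custom_events = [SlotSet(), BotUttered(), UserUttered()]

# Mirrors the loop added above: SlotSet and BotUttered events returned by the
# custom action are collected, everything else is flagged by its type name.
validated_slot_events = []
disallowed_types = set()
for event in custom_events:
    if isinstance(event, SlotSet):
        validated_slot_events.append(event)
    elif isinstance(event, BotUttered):
        validated_slot_events.append(event)
    else:
        disallowed_types.add(event.type_name)

print(len(validated_slot_events), disallowed_types)  # 2 {'user'}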
rasa/shared/nlu/constants.py
CHANGED
rasa/shared/utils/common.py
CHANGED
@@ -7,7 +7,17 @@ import os
 import pkgutil
 import sys
 from types import ModuleType
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Text,
+    Type,
+)
 
 import rasa.shared.utils.io
 from rasa.exceptions import MissingDependencyException
rasa/shared/utils/llm.py
CHANGED
@@ -762,7 +762,7 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
     if isinstance(slot, BooleanSlot):
         return str([True, False])
     if isinstance(slot, CategoricalSlot):
-        return str([v for v in slot.values if v != "__other__"])
+        return str([v for v in slot.values if v != "__other__"] + ["other"])
     else:
         return None
 
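
A quick before/after check of the changed list comprehension in allowed_values_for_slot; the slot values below are invented for illustration, and a plain list stands in for CategoricalSlot.values.

# Invented categorical slot values; "__other__" is the sentinel filtered out.
values = ["credit_card", "debit_card", "__other__"]

# Before the change: the sentinel is simply dropped.
before = str([v for v in values if v != "__other__"])

# After the change: a literal "other" option is appended to the allowed values.
after = str([v for v in values if v != "__other__"] + ["other"])

print(before)  # ['credit_card', 'debit_card']
print(after)   # ['credit_card', 'debit_card', 'other']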
rasa/tracing/instrumentation/attribute_extractors.py
CHANGED
@@ -326,7 +326,7 @@ def extract_attrs_for_command(
 def extract_llm_config(
     self: Any,
     default_llm_config: Dict[str, Any],
-    default_embeddings_config: Dict[str, Any],
+    default_embeddings_config: Optional[Dict[str, Any]],
 ) -> Dict[str, Any]:
     if isinstance(self, ContextualResponseRephraser):
         # ContextualResponseRephraser is not a graph component, so it's
@@ -346,8 +346,12 @@ def extract_llm_config(
             default_embeddings_config,
         )
     else:
-        embeddings_property =
-
+        embeddings_property = (
+            combine_custom_and_default_config(
+                config.get(EMBEDDINGS_CONFIG_KEY), default_embeddings_config
+            )
+            if default_embeddings_config is not None
+            else {}
         )
 
     attributes = {
@@ -402,7 +406,7 @@ def extract_attrs_for_contextual_response_rephraser(
         self,
         default_llm_config=DEFAULT_LLM_CONFIG,
         # rephraser is not using embeddings
-        default_embeddings_config=
+        default_embeddings_config=None,
     )
 
     return extend_attributes_with_prompt_tokens_length(self, attributes, prompt)
@@ -418,7 +422,7 @@ def extract_attrs_for_create_history(
         self,
         default_llm_config=DEFAULT_LLM_CONFIG,
         # rephraser is not using embeddings
-        default_embeddings_config=
+        default_embeddings_config=None,
     )
 
 
@@ -771,14 +775,13 @@ def extract_attrs_for_enterprise_search_generate_llm_answer(
     self: "EnterpriseSearchPolicy", llm: "BaseLLM", prompt: str
 ) -> Dict[str, Any]:
     from rasa.core.policies.enterprise_search_policy import (
-        DEFAULT_EMBEDDINGS_CONFIG,
         DEFAULT_LLM_CONFIG,
     )
 
     attributes = extract_llm_config(
         self,
         default_llm_config=DEFAULT_LLM_CONFIG,
-        default_embeddings_config=
+        default_embeddings_config=None,
    )
 
     return extend_attributes_with_prompt_tokens_length(self, attributes, prompt)