rasa-pro 3.9.18__py3-none-any.whl → 3.10.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +26 -57
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +26 -22
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +108 -433
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +395 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +14 -13
- rasa/cli/scaffold.py +10 -8
- rasa/cli/train.py +8 -7
- rasa/cli/utils.py +15 -0
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/inspector/dist/index.html +0 -2
- rasa/core/channels/inspector/index.html +0 -2
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +100 -44
- rasa/core/policies/flows/flow_executor.py +130 -94
- rasa/core/policies/intentless_policy.py +52 -28
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +20 -53
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/utils.py +45 -56
- rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
- rasa/dialogue_understanding/commands/utils.py +38 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +13 -14
- rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
- rasa/e2e_test/assertions.py +1181 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +491 -72
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +596 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/storage/local_model_storage.py +0 -1
- rasa/env.py +9 -0
- rasa/keys +1 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +48 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +129 -32
- rasa/server.py +45 -10
- rasa/shared/constants.py +63 -15
- rasa/shared/core/domain.py +15 -12
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +28 -10
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +60 -11
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +227 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -3
- rasa/shared/utils/llm.py +223 -200
- rasa/shared/utils/yaml.py +122 -7
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +2 -3
- rasa/telemetry.py +113 -58
- rasa/tracing/config.py +2 -3
- rasa/tracing/instrumentation/attribute_extractors.py +29 -17
- rasa/tracing/instrumentation/instrumentation.py +4 -47
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/io.py +66 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +196 -1
- rasa/version.py +1 -1
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/METADATA +47 -72
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/RECORD +186 -121
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa/utils/tensorflow/feature_array.py +0 -366
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.4.dist-info}/entry_points.txt +0 -0
|
@@ -8,6 +8,11 @@ import structlog
|
|
|
8
8
|
from jinja2 import Template
|
|
9
9
|
from pydantic import ValidationError
|
|
10
10
|
|
|
11
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
12
|
+
_LangchainEmbeddingClientAdapter,
|
|
13
|
+
)
|
|
14
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
15
|
+
|
|
11
16
|
import rasa.shared.utils.io
|
|
12
17
|
from rasa.telemetry import (
|
|
13
18
|
track_enterprise_search_policy_predict,
|
|
@@ -19,6 +24,7 @@ from rasa.core.constants import (
|
|
|
19
24
|
POLICY_MAX_HISTORY,
|
|
20
25
|
POLICY_PRIORITY,
|
|
21
26
|
SEARCH_POLICY_PRIORITY,
|
|
27
|
+
UTTER_SOURCE_METADATA_KEY,
|
|
22
28
|
)
|
|
23
29
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
24
30
|
from rasa.core.utils import AvailableEndpoints
|
|
@@ -39,13 +45,23 @@ from rasa.engine.storage.resource import Resource
|
|
|
39
45
|
from rasa.engine.storage.storage import ModelStorage
|
|
40
46
|
from rasa.graph_components.providers.forms_provider import Forms
|
|
41
47
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
48
|
+
from rasa.shared.constants import (
|
|
49
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
50
|
+
LLM_CONFIG_KEY,
|
|
51
|
+
MODEL_CONFIG_KEY,
|
|
52
|
+
MODEL_NAME_CONFIG_KEY,
|
|
53
|
+
PROMPT_CONFIG_KEY,
|
|
54
|
+
PROVIDER_CONFIG_KEY,
|
|
55
|
+
OPENAI_PROVIDER,
|
|
56
|
+
TIMEOUT_CONFIG_KEY,
|
|
57
|
+
)
|
|
42
58
|
from rasa.shared.core.constants import (
|
|
43
59
|
ACTION_CANCEL_FLOW,
|
|
44
60
|
ACTION_SEND_TEXT_NAME,
|
|
45
61
|
DEFAULT_SLOT_NAMES,
|
|
46
62
|
)
|
|
47
63
|
from rasa.shared.core.domain import Domain
|
|
48
|
-
from rasa.shared.core.events import Event
|
|
64
|
+
from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
49
65
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
50
66
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
51
67
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
@@ -59,6 +75,7 @@ from rasa.shared.utils.llm import (
|
|
|
59
75
|
llm_factory,
|
|
60
76
|
sanitize_message_for_prompt,
|
|
61
77
|
tracker_as_readable_transcript,
|
|
78
|
+
try_instantiate_llm_client,
|
|
62
79
|
)
|
|
63
80
|
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
64
81
|
from rasa.core.information_retrieval import (
|
|
@@ -70,7 +87,6 @@ from rasa.core.information_retrieval import (
|
|
|
70
87
|
|
|
71
88
|
if TYPE_CHECKING:
|
|
72
89
|
from langchain.schema.embeddings import Embeddings
|
|
73
|
-
from langchain.llms.base import BaseLLM
|
|
74
90
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
75
91
|
|
|
76
92
|
from rasa.utils.log_utils import log_llm
|
|
@@ -86,6 +102,7 @@ VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
|
86
102
|
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
87
103
|
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
88
104
|
USE_LLM_PROPERTY = "use_generative_llm"
|
|
105
|
+
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
89
106
|
|
|
90
107
|
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
91
108
|
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
@@ -96,23 +113,24 @@ DEFAULT_VECTOR_STORE = {
|
|
|
96
113
|
}
|
|
97
114
|
|
|
98
115
|
DEFAULT_LLM_CONFIG = {
|
|
99
|
-
|
|
100
|
-
|
|
116
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
117
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
118
|
+
TIMEOUT_CONFIG_KEY: 10,
|
|
101
119
|
"temperature": 0.0,
|
|
102
120
|
"max_tokens": 256,
|
|
103
|
-
"model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
104
121
|
"max_retries": 1,
|
|
105
122
|
}
|
|
106
123
|
|
|
107
124
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
108
|
-
|
|
125
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
109
126
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
110
127
|
}
|
|
111
128
|
|
|
112
|
-
EMBEDDINGS_CONFIG_KEY = "embeddings"
|
|
113
|
-
LLM_CONFIG_KEY = "llm"
|
|
114
129
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
115
130
|
|
|
131
|
+
SEARCH_RESULTS_METADATA_KEY = "search_results"
|
|
132
|
+
SEARCH_QUERY_METADATA_KEY = "search_query"
|
|
133
|
+
|
|
116
134
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
117
135
|
"rasa.core.policies", "enterprise_search_prompt_template.jinja2"
|
|
118
136
|
)
|
|
@@ -179,26 +197,42 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
179
197
|
"""Constructs a new Policy object."""
|
|
180
198
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
181
199
|
|
|
200
|
+
# Vector store object and configuration
|
|
182
201
|
self.vector_store = vector_store
|
|
183
202
|
self.vector_store_config = config.get(
|
|
184
203
|
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
185
204
|
)
|
|
186
|
-
|
|
205
|
+
# Embeddings configuration for encoding the search query
|
|
187
206
|
self.embeddings_config = self.config.get(
|
|
188
207
|
EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
|
|
189
208
|
)
|
|
209
|
+
# Maximum number of turns to include in the prompt
|
|
190
210
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
211
|
+
|
|
212
|
+
# Maximum number of messages to include in the search query
|
|
213
|
+
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
214
|
+
|
|
215
|
+
# LLM Configuration for response generation
|
|
216
|
+
self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
|
|
217
|
+
|
|
218
|
+
# boolean to enable/disable tracing of prompt tokens
|
|
195
219
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
220
|
+
|
|
221
|
+
# boolean to enable/disable the use of LLM for response generation
|
|
196
222
|
self.use_llm = self.config.get(USE_LLM_PROPERTY, True)
|
|
223
|
+
|
|
224
|
+
# boolean to enable/disable citation generation
|
|
197
225
|
self.citation_enabled = self.config.get(CITATION_ENABLED_PROPERTY, False)
|
|
226
|
+
|
|
227
|
+
self.prompt_template = prompt_template or get_prompt_template(
|
|
228
|
+
self.config.get(PROMPT_CONFIG_KEY),
|
|
229
|
+
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
230
|
+
)
|
|
198
231
|
self.citation_prompt_template = get_prompt_template(
|
|
199
|
-
self.config.get(
|
|
232
|
+
self.config.get(PROMPT_CONFIG_KEY),
|
|
200
233
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
|
|
201
234
|
)
|
|
235
|
+
# If citation is enabled, use the citation prompt template
|
|
202
236
|
if self.citation_enabled:
|
|
203
237
|
self.prompt_template = self.citation_prompt_template
|
|
204
238
|
|
|
@@ -209,9 +243,10 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
209
243
|
Returns:
|
|
210
244
|
The embedder.
|
|
211
245
|
"""
|
|
212
|
-
|
|
246
|
+
client = embedder_factory(
|
|
213
247
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
214
248
|
)
|
|
249
|
+
return _LangchainEmbeddingClientAdapter(client)
|
|
215
250
|
|
|
216
251
|
def train( # type: ignore[override]
|
|
217
252
|
self,
|
|
@@ -245,20 +280,24 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
245
280
|
# validate embedding configuration
|
|
246
281
|
try:
|
|
247
282
|
embeddings = self._create_plain_embedder(self.config)
|
|
248
|
-
except ValidationError as e:
|
|
283
|
+
except (ValidationError, Exception) as e:
|
|
284
|
+
logger.error(
|
|
285
|
+
"enterprise_search_policy.train.embedder_instantiation_failed",
|
|
286
|
+
message="Unable to instantiate the embedding client.",
|
|
287
|
+
error=e,
|
|
288
|
+
)
|
|
249
289
|
print_error_and_exit(
|
|
250
290
|
"Unable to create embedder. Please make sure you specified the "
|
|
251
291
|
f"required environment variables. Error: {e}"
|
|
252
292
|
)
|
|
253
293
|
|
|
254
294
|
# validate llm configuration
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
print_error_and_exit(f"Unable to create LLM. Error: {e}")
|
|
295
|
+
try_instantiate_llm_client(
|
|
296
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
297
|
+
DEFAULT_LLM_CONFIG,
|
|
298
|
+
"enterprise_search_policy.train",
|
|
299
|
+
"EnterpriseSearchPolicy",
|
|
300
|
+
)
|
|
262
301
|
|
|
263
302
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
264
303
|
logger.info("enterprise_search_policy.train.faiss")
|
|
@@ -275,11 +314,12 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
275
314
|
# telemetry call to track training completion
|
|
276
315
|
track_enterprise_search_policy_train_completed(
|
|
277
316
|
vector_store_type=store_type,
|
|
278
|
-
embeddings_type=self.embeddings_config.get(
|
|
279
|
-
embeddings_model=self.embeddings_config.get(
|
|
280
|
-
or self.embeddings_config.get(
|
|
281
|
-
llm_type=self.llm_config.get(
|
|
282
|
-
llm_model=self.llm_config.get(
|
|
317
|
+
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
318
|
+
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
319
|
+
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
320
|
+
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
321
|
+
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
322
|
+
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
283
323
|
citation_enabled=self.citation_enabled,
|
|
284
324
|
)
|
|
285
325
|
self.persist()
|
|
@@ -343,24 +383,31 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
343
383
|
logger.error(
|
|
344
384
|
"enterprise_search_policy._connect_vector_store_or_raise.connect_error",
|
|
345
385
|
error=e,
|
|
386
|
+
config=config,
|
|
346
387
|
)
|
|
347
388
|
raise VectorStoreConnectionError(
|
|
348
389
|
f"Unable to connect to the vector store. Error: {e}"
|
|
349
390
|
)
|
|
350
391
|
|
|
351
|
-
def
|
|
352
|
-
"""
|
|
392
|
+
def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
|
|
393
|
+
"""Prepares the search query.
|
|
394
|
+
The search query is the last N messages in the conversation history.
|
|
353
395
|
|
|
354
396
|
Args:
|
|
355
397
|
tracker: The tracker containing the conversation history up to now.
|
|
398
|
+
history: The number of messages to include in the search query.
|
|
356
399
|
|
|
357
400
|
Returns:
|
|
358
|
-
The
|
|
401
|
+
The search query.
|
|
359
402
|
"""
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
403
|
+
transcript = []
|
|
404
|
+
for event in tracker.applied_events():
|
|
405
|
+
if isinstance(event, UserUttered) or isinstance(event, BotUttered):
|
|
406
|
+
transcript.append(sanitize_message_for_prompt(event.text))
|
|
407
|
+
|
|
408
|
+
search_query = " ".join(transcript[-history:][::-1])
|
|
409
|
+
logger.debug("search_query", search_query=search_query)
|
|
410
|
+
return search_query
|
|
364
411
|
|
|
365
412
|
async def predict_action_probabilities( # type: ignore[override]
|
|
366
413
|
self,
|
|
@@ -404,7 +451,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
404
451
|
logger.error(f"{logger_key}.connection_error", error=e)
|
|
405
452
|
return self._create_prediction_internal_error(domain, tracker)
|
|
406
453
|
|
|
407
|
-
search_query = self.
|
|
454
|
+
search_query = self._prepare_search_query(
|
|
455
|
+
tracker, int(self.max_messages_in_query)
|
|
456
|
+
)
|
|
408
457
|
tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
|
|
409
458
|
|
|
410
459
|
try:
|
|
@@ -448,17 +497,23 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
448
497
|
action_metadata = {
|
|
449
498
|
"message": {
|
|
450
499
|
"text": response,
|
|
500
|
+
SEARCH_RESULTS_METADATA_KEY: [
|
|
501
|
+
result.text for result in documents.results
|
|
502
|
+
],
|
|
503
|
+
UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
|
|
504
|
+
SEARCH_QUERY_METADATA_KEY: search_query,
|
|
451
505
|
}
|
|
452
506
|
}
|
|
453
507
|
|
|
454
508
|
# telemetry call to track policy prediction
|
|
455
509
|
track_enterprise_search_policy_predict(
|
|
456
510
|
vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
|
|
457
|
-
embeddings_type=self.embeddings_config.get(
|
|
458
|
-
embeddings_model=self.embeddings_config.get(
|
|
459
|
-
or self.embeddings_config.get(
|
|
460
|
-
llm_type=self.llm_config.get(
|
|
461
|
-
llm_model=self.llm_config.get(
|
|
511
|
+
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
512
|
+
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
513
|
+
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
514
|
+
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
515
|
+
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
516
|
+
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
462
517
|
citation_enabled=self.citation_enabled,
|
|
463
518
|
)
|
|
464
519
|
return self._create_prediction(
|
|
@@ -495,10 +550,11 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
495
550
|
return prompt
|
|
496
551
|
|
|
497
552
|
async def _generate_llm_answer(
|
|
498
|
-
self, llm:
|
|
553
|
+
self, llm: LLMClient, prompt: Text
|
|
499
554
|
) -> Optional[Text]:
|
|
500
555
|
try:
|
|
501
|
-
|
|
556
|
+
llm_response = await llm.acompletion(prompt)
|
|
557
|
+
llm_answer = llm_response.choices[0]
|
|
502
558
|
except Exception as e:
|
|
503
559
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
504
560
|
# we have to catch all exceptions here
|
|
@@ -684,7 +740,7 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
684
740
|
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
685
741
|
|
|
686
742
|
prompt_template = get_prompt_template(
|
|
687
|
-
config.get(
|
|
743
|
+
config.get(PROMPT_CONFIG_KEY),
|
|
688
744
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
689
745
|
)
|
|
690
746
|
return deep_container_fingerprint([prompt_template, local_knowledge_data])
|
|
@@ -2,12 +2,14 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, Text, List, Optional
|
|
4
4
|
|
|
5
|
+
import structlog
|
|
5
6
|
from jinja2 import Template
|
|
6
|
-
from
|
|
7
|
-
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
7
|
+
from pypred import Predicate
|
|
8
8
|
from structlog.contextvars import (
|
|
9
9
|
bound_contextvars,
|
|
10
10
|
)
|
|
11
|
+
|
|
12
|
+
from rasa.core.constants import STEP_ID_METADATA_KEY, ACTIVE_FLOW_METADATA_KEY
|
|
11
13
|
from rasa.core.policies.flows.flow_exceptions import (
|
|
12
14
|
FlowCircuitBreakerTrippedException,
|
|
13
15
|
FlowException,
|
|
@@ -19,6 +21,20 @@ from rasa.core.policies.flows.flow_step_result import (
|
|
|
19
21
|
FlowStepResult,
|
|
20
22
|
PauseFlowReturnPrediction,
|
|
21
23
|
)
|
|
24
|
+
from rasa.dialogue_understanding.commands import CancelFlowCommand
|
|
25
|
+
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
26
|
+
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
27
|
+
CollectInformationPatternFlowStackFrame,
|
|
28
|
+
)
|
|
29
|
+
from rasa.dialogue_understanding.patterns.completed import (
|
|
30
|
+
CompletedPatternFlowStackFrame,
|
|
31
|
+
)
|
|
32
|
+
from rasa.dialogue_understanding.patterns.continue_interrupted import (
|
|
33
|
+
ContinueInterruptedPatternFlowStackFrame,
|
|
34
|
+
)
|
|
35
|
+
from rasa.dialogue_understanding.patterns.human_handoff import (
|
|
36
|
+
HumanHandoffPatternFlowStackFrame,
|
|
37
|
+
)
|
|
22
38
|
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
23
39
|
InternalErrorPatternFlowStackFrame,
|
|
24
40
|
)
|
|
@@ -29,24 +45,13 @@ from rasa.dialogue_understanding.stack.frames import (
|
|
|
29
45
|
DialogueStackFrame,
|
|
30
46
|
UserFlowStackFrame,
|
|
31
47
|
)
|
|
32
|
-
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
33
|
-
CollectInformationPatternFlowStackFrame,
|
|
34
|
-
)
|
|
35
|
-
from rasa.dialogue_understanding.patterns.completed import (
|
|
36
|
-
CompletedPatternFlowStackFrame,
|
|
37
|
-
)
|
|
38
|
-
from rasa.dialogue_understanding.patterns.continue_interrupted import (
|
|
39
|
-
ContinueInterruptedPatternFlowStackFrame,
|
|
40
|
-
)
|
|
41
48
|
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
|
|
42
49
|
FlowStackFrameType,
|
|
43
50
|
)
|
|
44
51
|
from rasa.dialogue_understanding.stack.utils import (
|
|
45
52
|
top_user_flow_frame,
|
|
46
53
|
)
|
|
47
|
-
|
|
48
|
-
from pypred import Predicate
|
|
49
|
-
|
|
54
|
+
from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
|
|
50
55
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
|
|
51
56
|
from rasa.shared.core.events import (
|
|
52
57
|
Event,
|
|
@@ -56,6 +61,11 @@ from rasa.shared.core.events import (
|
|
|
56
61
|
SlotSet,
|
|
57
62
|
)
|
|
58
63
|
from rasa.shared.core.flows import FlowsList
|
|
64
|
+
from rasa.shared.core.flows.flow import (
|
|
65
|
+
END_STEP,
|
|
66
|
+
Flow,
|
|
67
|
+
FlowStep,
|
|
68
|
+
)
|
|
59
69
|
from rasa.shared.core.flows.flow_step_links import (
|
|
60
70
|
StaticFlowStepLink,
|
|
61
71
|
IfFlowStepLink,
|
|
@@ -71,17 +81,11 @@ from rasa.shared.core.flows.steps import (
|
|
|
71
81
|
CollectInformationFlowStep,
|
|
72
82
|
NoOperationFlowStep,
|
|
73
83
|
)
|
|
74
|
-
from rasa.shared.core.flows.flow import (
|
|
75
|
-
END_STEP,
|
|
76
|
-
Flow,
|
|
77
|
-
FlowStep,
|
|
78
|
-
)
|
|
79
84
|
from rasa.shared.core.flows.steps.collect import SlotRejection
|
|
80
85
|
from rasa.shared.core.slots import Slot
|
|
81
86
|
from rasa.shared.core.trackers import (
|
|
82
87
|
DialogueStateTracker,
|
|
83
88
|
)
|
|
84
|
-
import structlog
|
|
85
89
|
|
|
86
90
|
structlogger = structlog.get_logger()
|
|
87
91
|
|
|
@@ -466,6 +470,10 @@ def advance_flows_until_next_action(
|
|
|
466
470
|
# make sure we really return all events that got created during the
|
|
467
471
|
# step execution of all steps (not only the last one)
|
|
468
472
|
prediction.events = gathered_events
|
|
473
|
+
prediction.metadata = {
|
|
474
|
+
ACTIVE_FLOW_METADATA_KEY: tracker.active_flow,
|
|
475
|
+
STEP_ID_METADATA_KEY: tracker.current_step_id,
|
|
476
|
+
}
|
|
469
477
|
return prediction
|
|
470
478
|
else:
|
|
471
479
|
structlogger.warning("flow.step.execution.no_action")
|
|
@@ -476,15 +484,15 @@ def validate_collect_step(
|
|
|
476
484
|
step: CollectInformationFlowStep,
|
|
477
485
|
stack: DialogueStack,
|
|
478
486
|
available_actions: List[str],
|
|
479
|
-
slots: Dict[
|
|
480
|
-
flow_name: str,
|
|
487
|
+
slots: Dict[Text, Slot],
|
|
481
488
|
) -> bool:
|
|
482
489
|
"""Validate that a collect step can be executed.
|
|
483
490
|
|
|
484
491
|
A collect step can be executed if either the `utter_ask` or the `action_ask` is
|
|
485
492
|
defined in the domain. If neither is defined, the collect step can still be
|
|
486
493
|
executed if the slot has an initial value defined in the domain, which would cause
|
|
487
|
-
the step to be skipped.
|
|
494
|
+
the step to be skipped.
|
|
495
|
+
"""
|
|
488
496
|
slot = slots.get(step.collect)
|
|
489
497
|
slot_has_initial_value_defined = slot and slot.initial_value is not None
|
|
490
498
|
if (
|
|
@@ -499,12 +507,12 @@ def validate_collect_step(
|
|
|
499
507
|
slot_name=step.collect,
|
|
500
508
|
)
|
|
501
509
|
|
|
502
|
-
cancel_flow_and_push_internal_error(stack
|
|
510
|
+
cancel_flow_and_push_internal_error(stack)
|
|
503
511
|
|
|
504
512
|
return False
|
|
505
513
|
|
|
506
514
|
|
|
507
|
-
def cancel_flow_and_push_internal_error(stack: DialogueStack
|
|
515
|
+
def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
|
|
508
516
|
"""Cancel the top user flow and push the internal error pattern."""
|
|
509
517
|
top_frame = stack.top()
|
|
510
518
|
|
|
@@ -516,7 +524,7 @@ def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) ->
|
|
|
516
524
|
canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
|
|
517
525
|
stack.push(
|
|
518
526
|
CancelPatternFlowStackFrame(
|
|
519
|
-
canceled_name=
|
|
527
|
+
canceled_name=top_frame.flow_id,
|
|
520
528
|
canceled_frames=canceled_frames,
|
|
521
529
|
)
|
|
522
530
|
)
|
|
@@ -528,7 +536,6 @@ def validate_custom_slot_mappings(
|
|
|
528
536
|
stack: DialogueStack,
|
|
529
537
|
tracker: DialogueStateTracker,
|
|
530
538
|
available_actions: List[str],
|
|
531
|
-
flow_name: str,
|
|
532
539
|
) -> bool:
|
|
533
540
|
"""Validate a slot with custom mappings.
|
|
534
541
|
|
|
@@ -549,7 +556,7 @@ def validate_custom_slot_mappings(
|
|
|
549
556
|
action=step.collect_action,
|
|
550
557
|
collect=step.collect,
|
|
551
558
|
)
|
|
552
|
-
cancel_flow_and_push_internal_error(stack
|
|
559
|
+
cancel_flow_and_push_internal_error(stack)
|
|
553
560
|
return False
|
|
554
561
|
|
|
555
562
|
return True
|
|
@@ -585,110 +592,139 @@ def run_step(
|
|
|
585
592
|
"""
|
|
586
593
|
initial_events: List[Event] = []
|
|
587
594
|
if step == flow.first_step_in_flow():
|
|
588
|
-
initial_events.append(FlowStarted(flow.id))
|
|
595
|
+
initial_events.append(FlowStarted(flow.id, metadata=stack.current_context()))
|
|
589
596
|
|
|
590
597
|
if isinstance(step, CollectInformationFlowStep):
|
|
591
598
|
return _run_collect_information_step(
|
|
592
|
-
available_actions,
|
|
593
|
-
initial_events,
|
|
594
|
-
stack,
|
|
595
|
-
step,
|
|
596
|
-
tracker,
|
|
597
|
-
flow.readable_name(),
|
|
599
|
+
available_actions, initial_events, stack, step, tracker
|
|
598
600
|
)
|
|
599
601
|
|
|
600
602
|
elif isinstance(step, ActionFlowStep):
|
|
601
603
|
if not step.action:
|
|
602
604
|
raise FlowException(f"Action not specified for step {step}")
|
|
603
|
-
|
|
604
|
-
context = {"context": stack.current_context()}
|
|
605
|
-
action_name = render_template_variables(step.action, context)
|
|
606
|
-
|
|
607
|
-
if action_name in available_actions:
|
|
608
|
-
structlogger.debug("flow.step.run.action", context=context)
|
|
609
|
-
return PauseFlowReturnPrediction(
|
|
610
|
-
FlowActionPrediction(action_name, 1.0, events=initial_events)
|
|
611
|
-
)
|
|
612
|
-
else:
|
|
613
|
-
if step.action != "validate_{{context.collect}}":
|
|
614
|
-
# do not log about non-existing validation actions of collect steps
|
|
615
|
-
utter_action_name = render_template_variables(
|
|
616
|
-
"{{context.utter}}", context
|
|
617
|
-
)
|
|
618
|
-
if utter_action_name not in available_actions:
|
|
619
|
-
structlogger.warning(
|
|
620
|
-
"flow.step.run.action.unknown", action=action_name
|
|
621
|
-
)
|
|
622
|
-
return ContinueFlowWithNextStep(events=initial_events)
|
|
605
|
+
return _run_action_step(available_actions, initial_events, stack, step)
|
|
623
606
|
|
|
624
607
|
elif isinstance(step, LinkFlowStep):
|
|
625
|
-
|
|
626
|
-
stack.push(
|
|
627
|
-
UserFlowStackFrame(
|
|
628
|
-
flow_id=step.link,
|
|
629
|
-
frame_type=FlowStackFrameType.LINK,
|
|
630
|
-
),
|
|
631
|
-
# push this below the current stack frame so that we can
|
|
632
|
-
# complete the current flow first and then continue with the
|
|
633
|
-
# linked flow
|
|
634
|
-
index=-1,
|
|
635
|
-
)
|
|
636
|
-
return ContinueFlowWithNextStep(events=initial_events)
|
|
608
|
+
return _run_link_step(initial_events, stack, step)
|
|
637
609
|
|
|
638
610
|
elif isinstance(step, CallFlowStep):
|
|
639
|
-
|
|
640
|
-
stack.push(
|
|
641
|
-
UserFlowStackFrame(
|
|
642
|
-
flow_id=step.call,
|
|
643
|
-
frame_type=FlowStackFrameType.CALL,
|
|
644
|
-
),
|
|
645
|
-
)
|
|
646
|
-
return ContinueFlowWithNextStep()
|
|
611
|
+
return _run_call_step(initial_events, stack, step)
|
|
647
612
|
|
|
648
613
|
elif isinstance(step, SetSlotsFlowStep):
|
|
649
|
-
|
|
650
|
-
slot_events: List[Event] = events_from_set_slots_step(step)
|
|
651
|
-
return ContinueFlowWithNextStep(events=initial_events + slot_events)
|
|
614
|
+
return _run_set_slot_step(initial_events, step)
|
|
652
615
|
|
|
653
616
|
elif isinstance(step, NoOperationFlowStep):
|
|
654
617
|
structlogger.debug("flow.step.run.no_operation")
|
|
655
618
|
return ContinueFlowWithNextStep(events=initial_events)
|
|
656
619
|
|
|
657
620
|
elif isinstance(step, EndFlowStep):
|
|
658
|
-
|
|
659
|
-
structlogger.debug("flow.step.run.flow_end")
|
|
660
|
-
current_frame = stack.pop()
|
|
661
|
-
trigger_pattern_completed(current_frame, stack, flows)
|
|
662
|
-
resumed_events = trigger_pattern_continue_interrupted(
|
|
663
|
-
current_frame, stack, flows
|
|
664
|
-
)
|
|
665
|
-
reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
|
|
666
|
-
return ContinueFlowWithNextStep(
|
|
667
|
-
events=initial_events + reset_events + resumed_events, has_flow_ended=True
|
|
668
|
-
)
|
|
621
|
+
return _run_end_step(flow, flows, initial_events, stack, tracker)
|
|
669
622
|
|
|
670
623
|
else:
|
|
671
624
|
raise FlowException(f"Unknown flow step type {type(step)}")
|
|
672
625
|
|
|
673
626
|
|
|
627
|
+
def _run_end_step(
|
|
628
|
+
flow: Flow,
|
|
629
|
+
flows: FlowsList,
|
|
630
|
+
initial_events: List[Event],
|
|
631
|
+
stack: DialogueStack,
|
|
632
|
+
tracker: DialogueStateTracker,
|
|
633
|
+
) -> FlowStepResult:
|
|
634
|
+
# this is the end of the flow, so we'll pop it from the stack
|
|
635
|
+
structlogger.debug("flow.step.run.flow_end")
|
|
636
|
+
current_frame = stack.pop()
|
|
637
|
+
trigger_pattern_completed(current_frame, stack, flows)
|
|
638
|
+
resumed_events = trigger_pattern_continue_interrupted(current_frame, stack, flows)
|
|
639
|
+
reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
|
|
640
|
+
return ContinueFlowWithNextStep(
|
|
641
|
+
events=initial_events + reset_events + resumed_events, has_flow_ended=True
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
def _run_set_slot_step(
|
|
646
|
+
initial_events: List[Event], step: SetSlotsFlowStep
|
|
647
|
+
) -> FlowStepResult:
|
|
648
|
+
structlogger.debug("flow.step.run.slot")
|
|
649
|
+
slot_events: List[Event] = events_from_set_slots_step(step)
|
|
650
|
+
return ContinueFlowWithNextStep(events=initial_events + slot_events)
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def _run_call_step(
|
|
654
|
+
initial_events: List[Event], stack: DialogueStack, step: CallFlowStep
|
|
655
|
+
) -> FlowStepResult:
|
|
656
|
+
structlogger.debug("flow.step.run.call")
|
|
657
|
+
stack.push(
|
|
658
|
+
UserFlowStackFrame(
|
|
659
|
+
flow_id=step.call,
|
|
660
|
+
frame_type=FlowStackFrameType.CALL,
|
|
661
|
+
),
|
|
662
|
+
)
|
|
663
|
+
return ContinueFlowWithNextStep(events=initial_events)
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def _run_link_step(
|
|
667
|
+
initial_events: List[Event], stack: DialogueStack, step: LinkFlowStep
|
|
668
|
+
) -> FlowStepResult:
|
|
669
|
+
structlogger.debug("flow.step.run.link")
|
|
670
|
+
|
|
671
|
+
if step.link == RASA_PATTERN_HUMAN_HANDOFF:
|
|
672
|
+
linked_stack_frame: DialogueStackFrame = HumanHandoffPatternFlowStackFrame()
|
|
673
|
+
else:
|
|
674
|
+
linked_stack_frame = UserFlowStackFrame(
|
|
675
|
+
flow_id=step.link,
|
|
676
|
+
frame_type=FlowStackFrameType.LINK,
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
stack.push(
|
|
680
|
+
linked_stack_frame,
|
|
681
|
+
# push this below the current stack frame so that we can
|
|
682
|
+
# complete the current flow first and then continue with the
|
|
683
|
+
# linked flow
|
|
684
|
+
index=-1,
|
|
685
|
+
)
|
|
686
|
+
|
|
687
|
+
return ContinueFlowWithNextStep(events=initial_events)
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
def _run_action_step(
|
|
691
|
+
available_actions: List[str],
|
|
692
|
+
initial_events: List[Event],
|
|
693
|
+
stack: DialogueStack,
|
|
694
|
+
step: ActionFlowStep,
|
|
695
|
+
) -> FlowStepResult:
|
|
696
|
+
context = {"context": stack.current_context()}
|
|
697
|
+
action_name = render_template_variables(step.action, context)
|
|
698
|
+
|
|
699
|
+
if action_name in available_actions:
|
|
700
|
+
structlogger.debug("flow.step.run.action", context=context)
|
|
701
|
+
return PauseFlowReturnPrediction(
|
|
702
|
+
FlowActionPrediction(action_name, 1.0, events=initial_events)
|
|
703
|
+
)
|
|
704
|
+
else:
|
|
705
|
+
if step.action != "validate_{{context.collect}}":
|
|
706
|
+
# do not log about non-existing validation actions of collect steps
|
|
707
|
+
utter_action_name = render_template_variables("{{context.utter}}", context)
|
|
708
|
+
if utter_action_name not in available_actions:
|
|
709
|
+
structlogger.warning("flow.step.run.action.unknown", action=action_name)
|
|
710
|
+
return ContinueFlowWithNextStep(events=initial_events)
|
|
711
|
+
|
|
712
|
+
|
|
674
713
|
def _run_collect_information_step(
|
|
675
714
|
available_actions: List[str],
|
|
676
715
|
initial_events: List[Event],
|
|
677
716
|
stack: DialogueStack,
|
|
678
717
|
step: CollectInformationFlowStep,
|
|
679
718
|
tracker: DialogueStateTracker,
|
|
680
|
-
flow_name: str,
|
|
681
719
|
) -> FlowStepResult:
|
|
682
|
-
is_step_valid = validate_collect_step(
|
|
683
|
-
step, stack, available_actions, tracker.slots, flow_name
|
|
684
|
-
)
|
|
720
|
+
is_step_valid = validate_collect_step(step, stack, available_actions, tracker.slots)
|
|
685
721
|
|
|
686
722
|
if not is_step_valid:
|
|
687
723
|
# if we return any other FlowStepResult, the assistant will stay silent
|
|
688
724
|
# instead of triggering the internal error pattern
|
|
689
725
|
return ContinueFlowWithNextStep(events=initial_events)
|
|
690
726
|
is_mapping_valid = validate_custom_slot_mappings(
|
|
691
|
-
step, stack, tracker, available_actions
|
|
727
|
+
step, stack, tracker, available_actions
|
|
692
728
|
)
|
|
693
729
|
|
|
694
730
|
if not is_mapping_valid:
|