rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
|
@@ -8,6 +8,11 @@ import structlog
|
|
|
8
8
|
from jinja2 import Template
|
|
9
9
|
from pydantic import ValidationError
|
|
10
10
|
|
|
11
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
12
|
+
_LangchainEmbeddingClientAdapter,
|
|
13
|
+
)
|
|
14
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
15
|
+
|
|
11
16
|
import rasa.shared.utils.io
|
|
12
17
|
from rasa.telemetry import (
|
|
13
18
|
track_enterprise_search_policy_predict,
|
|
@@ -19,6 +24,7 @@ from rasa.core.constants import (
|
|
|
19
24
|
POLICY_MAX_HISTORY,
|
|
20
25
|
POLICY_PRIORITY,
|
|
21
26
|
SEARCH_POLICY_PRIORITY,
|
|
27
|
+
UTTER_SOURCE_METADATA_KEY,
|
|
22
28
|
)
|
|
23
29
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
24
30
|
from rasa.core.utils import AvailableEndpoints
|
|
@@ -39,13 +45,23 @@ from rasa.engine.storage.resource import Resource
|
|
|
39
45
|
from rasa.engine.storage.storage import ModelStorage
|
|
40
46
|
from rasa.graph_components.providers.forms_provider import Forms
|
|
41
47
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
48
|
+
from rasa.shared.constants import (
|
|
49
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
50
|
+
LLM_CONFIG_KEY,
|
|
51
|
+
MODEL_CONFIG_KEY,
|
|
52
|
+
MODEL_NAME_CONFIG_KEY,
|
|
53
|
+
PROMPT_CONFIG_KEY,
|
|
54
|
+
PROVIDER_CONFIG_KEY,
|
|
55
|
+
OPENAI_PROVIDER,
|
|
56
|
+
TIMEOUT_CONFIG_KEY,
|
|
57
|
+
)
|
|
42
58
|
from rasa.shared.core.constants import (
|
|
43
59
|
ACTION_CANCEL_FLOW,
|
|
44
60
|
ACTION_SEND_TEXT_NAME,
|
|
45
61
|
DEFAULT_SLOT_NAMES,
|
|
46
62
|
)
|
|
47
63
|
from rasa.shared.core.domain import Domain
|
|
48
|
-
from rasa.shared.core.events import Event
|
|
64
|
+
from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
49
65
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
50
66
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
51
67
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
@@ -59,6 +75,8 @@ from rasa.shared.utils.llm import (
|
|
|
59
75
|
llm_factory,
|
|
60
76
|
sanitize_message_for_prompt,
|
|
61
77
|
tracker_as_readable_transcript,
|
|
78
|
+
try_instantiate_llm_client,
|
|
79
|
+
try_instantiate_embedder,
|
|
62
80
|
)
|
|
63
81
|
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
64
82
|
from rasa.core.information_retrieval import (
|
|
@@ -70,7 +88,6 @@ from rasa.core.information_retrieval import (
|
|
|
70
88
|
|
|
71
89
|
if TYPE_CHECKING:
|
|
72
90
|
from langchain.schema.embeddings import Embeddings
|
|
73
|
-
from langchain.llms.base import BaseLLM
|
|
74
91
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
75
92
|
|
|
76
93
|
from rasa.utils.log_utils import log_llm
|
|
@@ -86,6 +103,7 @@ VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
|
86
103
|
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
87
104
|
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
88
105
|
USE_LLM_PROPERTY = "use_generative_llm"
|
|
106
|
+
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
89
107
|
|
|
90
108
|
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
91
109
|
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
@@ -96,23 +114,24 @@ DEFAULT_VECTOR_STORE = {
|
|
|
96
114
|
}
|
|
97
115
|
|
|
98
116
|
DEFAULT_LLM_CONFIG = {
|
|
99
|
-
|
|
100
|
-
|
|
117
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
118
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
119
|
+
TIMEOUT_CONFIG_KEY: 10,
|
|
101
120
|
"temperature": 0.0,
|
|
102
121
|
"max_tokens": 256,
|
|
103
|
-
"model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
104
122
|
"max_retries": 1,
|
|
105
123
|
}
|
|
106
124
|
|
|
107
125
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
108
|
-
|
|
126
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
109
127
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
110
128
|
}
|
|
111
129
|
|
|
112
|
-
EMBEDDINGS_CONFIG_KEY = "embeddings"
|
|
113
|
-
LLM_CONFIG_KEY = "llm"
|
|
114
130
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
115
131
|
|
|
132
|
+
SEARCH_RESULTS_METADATA_KEY = "search_results"
|
|
133
|
+
SEARCH_QUERY_METADATA_KEY = "search_query"
|
|
134
|
+
|
|
116
135
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
117
136
|
"rasa.core.policies", "enterprise_search_prompt_template.jinja2"
|
|
118
137
|
)
|
|
@@ -179,26 +198,42 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
179
198
|
"""Constructs a new Policy object."""
|
|
180
199
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
181
200
|
|
|
201
|
+
# Vector store object and configuration
|
|
182
202
|
self.vector_store = vector_store
|
|
183
203
|
self.vector_store_config = config.get(
|
|
184
204
|
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
185
205
|
)
|
|
186
|
-
|
|
206
|
+
# Embeddings configuration for encoding the search query
|
|
187
207
|
self.embeddings_config = self.config.get(
|
|
188
208
|
EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
|
|
189
209
|
)
|
|
210
|
+
# Maximum number of turns to include in the prompt
|
|
190
211
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
212
|
+
|
|
213
|
+
# Maximum number of messages to include in the search query
|
|
214
|
+
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
215
|
+
|
|
216
|
+
# LLM Configuration for response generation
|
|
217
|
+
self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
|
|
218
|
+
|
|
219
|
+
# boolean to enable/disable tracing of prompt tokens
|
|
195
220
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
221
|
+
|
|
222
|
+
# boolean to enable/disable the use of LLM for response generation
|
|
196
223
|
self.use_llm = self.config.get(USE_LLM_PROPERTY, True)
|
|
224
|
+
|
|
225
|
+
# boolean to enable/disable citation generation
|
|
197
226
|
self.citation_enabled = self.config.get(CITATION_ENABLED_PROPERTY, False)
|
|
227
|
+
|
|
228
|
+
self.prompt_template = prompt_template or get_prompt_template(
|
|
229
|
+
self.config.get(PROMPT_CONFIG_KEY),
|
|
230
|
+
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
231
|
+
)
|
|
198
232
|
self.citation_prompt_template = get_prompt_template(
|
|
199
|
-
self.config.get(
|
|
233
|
+
self.config.get(PROMPT_CONFIG_KEY),
|
|
200
234
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
|
|
201
235
|
)
|
|
236
|
+
# If citation is enabled, use the citation prompt template
|
|
202
237
|
if self.citation_enabled:
|
|
203
238
|
self.prompt_template = self.citation_prompt_template
|
|
204
239
|
|
|
@@ -209,9 +244,10 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
209
244
|
Returns:
|
|
210
245
|
The embedder.
|
|
211
246
|
"""
|
|
212
|
-
|
|
247
|
+
client = embedder_factory(
|
|
213
248
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
214
249
|
)
|
|
250
|
+
return _LangchainEmbeddingClientAdapter(client)
|
|
215
251
|
|
|
216
252
|
def train( # type: ignore[override]
|
|
217
253
|
self,
|
|
@@ -245,20 +281,24 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
245
281
|
# validate embedding configuration
|
|
246
282
|
try:
|
|
247
283
|
embeddings = self._create_plain_embedder(self.config)
|
|
248
|
-
except ValidationError as e:
|
|
284
|
+
except (ValidationError, Exception) as e:
|
|
285
|
+
logger.error(
|
|
286
|
+
"enterprise_search_policy.train.embedder_instantiation_failed",
|
|
287
|
+
message="Unable to instantiate the embedding client.",
|
|
288
|
+
error=e,
|
|
289
|
+
)
|
|
249
290
|
print_error_and_exit(
|
|
250
291
|
"Unable to create embedder. Please make sure you specified the "
|
|
251
292
|
f"required environment variables. Error: {e}"
|
|
252
293
|
)
|
|
253
294
|
|
|
254
295
|
# validate llm configuration
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
print_error_and_exit(f"Unable to create LLM. Error: {e}")
|
|
296
|
+
try_instantiate_llm_client(
|
|
297
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
298
|
+
DEFAULT_LLM_CONFIG,
|
|
299
|
+
"enterprise_search_policy.train",
|
|
300
|
+
"EnterpriseSearchPolicy",
|
|
301
|
+
)
|
|
262
302
|
|
|
263
303
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
264
304
|
logger.info("enterprise_search_policy.train.faiss")
|
|
@@ -275,11 +315,12 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
275
315
|
# telemetry call to track training completion
|
|
276
316
|
track_enterprise_search_policy_train_completed(
|
|
277
317
|
vector_store_type=store_type,
|
|
278
|
-
embeddings_type=self.embeddings_config.get(
|
|
279
|
-
embeddings_model=self.embeddings_config.get(
|
|
280
|
-
or self.embeddings_config.get(
|
|
281
|
-
llm_type=self.llm_config.get(
|
|
282
|
-
llm_model=self.llm_config.get(
|
|
318
|
+
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
319
|
+
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
320
|
+
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
321
|
+
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
322
|
+
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
323
|
+
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
283
324
|
citation_enabled=self.citation_enabled,
|
|
284
325
|
)
|
|
285
326
|
self.persist()
|
|
@@ -343,24 +384,31 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
343
384
|
logger.error(
|
|
344
385
|
"enterprise_search_policy._connect_vector_store_or_raise.connect_error",
|
|
345
386
|
error=e,
|
|
387
|
+
config=config,
|
|
346
388
|
)
|
|
347
389
|
raise VectorStoreConnectionError(
|
|
348
390
|
f"Unable to connect to the vector store. Error: {e}"
|
|
349
391
|
)
|
|
350
392
|
|
|
351
|
-
def
|
|
352
|
-
"""
|
|
393
|
+
def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
|
|
394
|
+
"""Prepares the search query.
|
|
395
|
+
The search query is the last N messages in the conversation history.
|
|
353
396
|
|
|
354
397
|
Args:
|
|
355
398
|
tracker: The tracker containing the conversation history up to now.
|
|
399
|
+
history: The number of messages to include in the search query.
|
|
356
400
|
|
|
357
401
|
Returns:
|
|
358
|
-
The
|
|
402
|
+
The search query.
|
|
359
403
|
"""
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
404
|
+
transcript = []
|
|
405
|
+
for event in tracker.applied_events():
|
|
406
|
+
if isinstance(event, UserUttered) or isinstance(event, BotUttered):
|
|
407
|
+
transcript.append(sanitize_message_for_prompt(event.text))
|
|
408
|
+
|
|
409
|
+
search_query = " ".join(transcript[-history:][::-1])
|
|
410
|
+
logger.debug("search_query", search_query=search_query)
|
|
411
|
+
return search_query
|
|
364
412
|
|
|
365
413
|
async def predict_action_probabilities( # type: ignore[override]
|
|
366
414
|
self,
|
|
@@ -404,7 +452,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
404
452
|
logger.error(f"{logger_key}.connection_error", error=e)
|
|
405
453
|
return self._create_prediction_internal_error(domain, tracker)
|
|
406
454
|
|
|
407
|
-
search_query = self.
|
|
455
|
+
search_query = self._prepare_search_query(
|
|
456
|
+
tracker, int(self.max_messages_in_query)
|
|
457
|
+
)
|
|
408
458
|
tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
|
|
409
459
|
|
|
410
460
|
try:
|
|
@@ -448,17 +498,23 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
448
498
|
action_metadata = {
|
|
449
499
|
"message": {
|
|
450
500
|
"text": response,
|
|
501
|
+
SEARCH_RESULTS_METADATA_KEY: [
|
|
502
|
+
result.text for result in documents.results
|
|
503
|
+
],
|
|
504
|
+
UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
|
|
505
|
+
SEARCH_QUERY_METADATA_KEY: search_query,
|
|
451
506
|
}
|
|
452
507
|
}
|
|
453
508
|
|
|
454
509
|
# telemetry call to track policy prediction
|
|
455
510
|
track_enterprise_search_policy_predict(
|
|
456
511
|
vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
|
|
457
|
-
embeddings_type=self.embeddings_config.get(
|
|
458
|
-
embeddings_model=self.embeddings_config.get(
|
|
459
|
-
or self.embeddings_config.get(
|
|
460
|
-
llm_type=self.llm_config.get(
|
|
461
|
-
llm_model=self.llm_config.get(
|
|
512
|
+
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
513
|
+
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
514
|
+
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
515
|
+
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
516
|
+
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
517
|
+
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
462
518
|
citation_enabled=self.citation_enabled,
|
|
463
519
|
)
|
|
464
520
|
return self._create_prediction(
|
|
@@ -495,10 +551,11 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
495
551
|
return prompt
|
|
496
552
|
|
|
497
553
|
async def _generate_llm_answer(
|
|
498
|
-
self, llm:
|
|
554
|
+
self, llm: LLMClient, prompt: Text
|
|
499
555
|
) -> Optional[Text]:
|
|
500
556
|
try:
|
|
501
|
-
|
|
557
|
+
llm_response = await llm.acompletion(prompt)
|
|
558
|
+
llm_answer = llm_response.choices[0]
|
|
502
559
|
except Exception as e:
|
|
503
560
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
504
561
|
# we have to catch all exceptions here
|
|
@@ -605,6 +662,18 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
605
662
|
execution_context: ExecutionContext,
|
|
606
663
|
**kwargs: Any,
|
|
607
664
|
) -> "EnterpriseSearchPolicy":
|
|
665
|
+
try_instantiate_llm_client(
|
|
666
|
+
config.get(LLM_CONFIG_KEY),
|
|
667
|
+
DEFAULT_LLM_CONFIG,
|
|
668
|
+
"enterprise_search_policy.load",
|
|
669
|
+
EnterpriseSearchPolicy.__name__,
|
|
670
|
+
)
|
|
671
|
+
try_instantiate_embedder(
|
|
672
|
+
config.get(EMBEDDINGS_CONFIG_KEY),
|
|
673
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
674
|
+
"enterprise_search_policy.load",
|
|
675
|
+
EnterpriseSearchPolicy.__name__,
|
|
676
|
+
)
|
|
608
677
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
609
678
|
prompt_template = None
|
|
610
679
|
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
@@ -639,7 +708,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
639
708
|
logger.warning(
|
|
640
709
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
641
710
|
)
|
|
642
|
-
|
|
643
711
|
return cls(
|
|
644
712
|
config,
|
|
645
713
|
model_storage,
|
|
@@ -684,7 +752,7 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
684
752
|
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
685
753
|
|
|
686
754
|
prompt_template = get_prompt_template(
|
|
687
|
-
config.get(
|
|
755
|
+
config.get(PROMPT_CONFIG_KEY),
|
|
688
756
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
689
757
|
)
|
|
690
758
|
return deep_container_fingerprint([prompt_template, local_knowledge_data])
|
|
@@ -2,12 +2,14 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, Text, List, Optional
|
|
4
4
|
|
|
5
|
+
import structlog
|
|
5
6
|
from jinja2 import Template
|
|
6
|
-
from
|
|
7
|
-
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
7
|
+
from pypred import Predicate
|
|
8
8
|
from structlog.contextvars import (
|
|
9
9
|
bound_contextvars,
|
|
10
10
|
)
|
|
11
|
+
|
|
12
|
+
from rasa.core.constants import STEP_ID_METADATA_KEY, ACTIVE_FLOW_METADATA_KEY
|
|
11
13
|
from rasa.core.policies.flows.flow_exceptions import (
|
|
12
14
|
FlowCircuitBreakerTrippedException,
|
|
13
15
|
FlowException,
|
|
@@ -19,6 +21,20 @@ from rasa.core.policies.flows.flow_step_result import (
|
|
|
19
21
|
FlowStepResult,
|
|
20
22
|
PauseFlowReturnPrediction,
|
|
21
23
|
)
|
|
24
|
+
from rasa.dialogue_understanding.commands import CancelFlowCommand
|
|
25
|
+
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
26
|
+
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
27
|
+
CollectInformationPatternFlowStackFrame,
|
|
28
|
+
)
|
|
29
|
+
from rasa.dialogue_understanding.patterns.completed import (
|
|
30
|
+
CompletedPatternFlowStackFrame,
|
|
31
|
+
)
|
|
32
|
+
from rasa.dialogue_understanding.patterns.continue_interrupted import (
|
|
33
|
+
ContinueInterruptedPatternFlowStackFrame,
|
|
34
|
+
)
|
|
35
|
+
from rasa.dialogue_understanding.patterns.human_handoff import (
|
|
36
|
+
HumanHandoffPatternFlowStackFrame,
|
|
37
|
+
)
|
|
22
38
|
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
23
39
|
InternalErrorPatternFlowStackFrame,
|
|
24
40
|
)
|
|
@@ -29,24 +45,13 @@ from rasa.dialogue_understanding.stack.frames import (
|
|
|
29
45
|
DialogueStackFrame,
|
|
30
46
|
UserFlowStackFrame,
|
|
31
47
|
)
|
|
32
|
-
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
33
|
-
CollectInformationPatternFlowStackFrame,
|
|
34
|
-
)
|
|
35
|
-
from rasa.dialogue_understanding.patterns.completed import (
|
|
36
|
-
CompletedPatternFlowStackFrame,
|
|
37
|
-
)
|
|
38
|
-
from rasa.dialogue_understanding.patterns.continue_interrupted import (
|
|
39
|
-
ContinueInterruptedPatternFlowStackFrame,
|
|
40
|
-
)
|
|
41
48
|
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
|
|
42
49
|
FlowStackFrameType,
|
|
43
50
|
)
|
|
44
51
|
from rasa.dialogue_understanding.stack.utils import (
|
|
45
52
|
top_user_flow_frame,
|
|
46
53
|
)
|
|
47
|
-
|
|
48
|
-
from pypred import Predicate
|
|
49
|
-
|
|
54
|
+
from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
|
|
50
55
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
|
|
51
56
|
from rasa.shared.core.events import (
|
|
52
57
|
Event,
|
|
@@ -56,6 +61,11 @@ from rasa.shared.core.events import (
|
|
|
56
61
|
SlotSet,
|
|
57
62
|
)
|
|
58
63
|
from rasa.shared.core.flows import FlowsList
|
|
64
|
+
from rasa.shared.core.flows.flow import (
|
|
65
|
+
END_STEP,
|
|
66
|
+
Flow,
|
|
67
|
+
FlowStep,
|
|
68
|
+
)
|
|
59
69
|
from rasa.shared.core.flows.flow_step_links import (
|
|
60
70
|
StaticFlowStepLink,
|
|
61
71
|
IfFlowStepLink,
|
|
@@ -71,17 +81,11 @@ from rasa.shared.core.flows.steps import (
|
|
|
71
81
|
CollectInformationFlowStep,
|
|
72
82
|
NoOperationFlowStep,
|
|
73
83
|
)
|
|
74
|
-
from rasa.shared.core.flows.flow import (
|
|
75
|
-
END_STEP,
|
|
76
|
-
Flow,
|
|
77
|
-
FlowStep,
|
|
78
|
-
)
|
|
79
84
|
from rasa.shared.core.flows.steps.collect import SlotRejection
|
|
80
85
|
from rasa.shared.core.slots import Slot
|
|
81
86
|
from rasa.shared.core.trackers import (
|
|
82
87
|
DialogueStateTracker,
|
|
83
88
|
)
|
|
84
|
-
import structlog
|
|
85
89
|
|
|
86
90
|
structlogger = structlog.get_logger()
|
|
87
91
|
|
|
@@ -466,6 +470,10 @@ def advance_flows_until_next_action(
|
|
|
466
470
|
# make sure we really return all events that got created during the
|
|
467
471
|
# step execution of all steps (not only the last one)
|
|
468
472
|
prediction.events = gathered_events
|
|
473
|
+
prediction.metadata = {
|
|
474
|
+
ACTIVE_FLOW_METADATA_KEY: tracker.active_flow,
|
|
475
|
+
STEP_ID_METADATA_KEY: tracker.current_step_id,
|
|
476
|
+
}
|
|
469
477
|
return prediction
|
|
470
478
|
else:
|
|
471
479
|
structlogger.warning("flow.step.execution.no_action")
|
|
@@ -484,7 +492,8 @@ def validate_collect_step(
|
|
|
484
492
|
A collect step can be executed if either the `utter_ask` or the `action_ask` is
|
|
485
493
|
defined in the domain. If neither is defined, the collect step can still be
|
|
486
494
|
executed if the slot has an initial value defined in the domain, which would cause
|
|
487
|
-
the step to be skipped.
|
|
495
|
+
the step to be skipped.
|
|
496
|
+
"""
|
|
488
497
|
slot = slots.get(step.collect)
|
|
489
498
|
slot_has_initial_value_defined = slot and slot.initial_value is not None
|
|
490
499
|
if (
|
|
@@ -585,7 +594,7 @@ def run_step(
|
|
|
585
594
|
"""
|
|
586
595
|
initial_events: List[Event] = []
|
|
587
596
|
if step == flow.first_step_in_flow():
|
|
588
|
-
initial_events.append(FlowStarted(flow.id))
|
|
597
|
+
initial_events.append(FlowStarted(flow.id, metadata=stack.current_context()))
|
|
589
598
|
|
|
590
599
|
if isinstance(step, CollectInformationFlowStep):
|
|
591
600
|
return _run_collect_information_step(
|
|
@@ -600,77 +609,114 @@ def run_step(
|
|
|
600
609
|
elif isinstance(step, ActionFlowStep):
|
|
601
610
|
if not step.action:
|
|
602
611
|
raise FlowException(f"Action not specified for step {step}")
|
|
603
|
-
|
|
604
|
-
context = {"context": stack.current_context()}
|
|
605
|
-
action_name = render_template_variables(step.action, context)
|
|
606
|
-
|
|
607
|
-
if action_name in available_actions:
|
|
608
|
-
structlogger.debug("flow.step.run.action", context=context)
|
|
609
|
-
return PauseFlowReturnPrediction(
|
|
610
|
-
FlowActionPrediction(action_name, 1.0, events=initial_events)
|
|
611
|
-
)
|
|
612
|
-
else:
|
|
613
|
-
if step.action != "validate_{{context.collect}}":
|
|
614
|
-
# do not log about non-existing validation actions of collect steps
|
|
615
|
-
utter_action_name = render_template_variables(
|
|
616
|
-
"{{context.utter}}", context
|
|
617
|
-
)
|
|
618
|
-
if utter_action_name not in available_actions:
|
|
619
|
-
structlogger.warning(
|
|
620
|
-
"flow.step.run.action.unknown", action=action_name
|
|
621
|
-
)
|
|
622
|
-
return ContinueFlowWithNextStep(events=initial_events)
|
|
612
|
+
return _run_action_step(available_actions, initial_events, stack, step)
|
|
623
613
|
|
|
624
614
|
elif isinstance(step, LinkFlowStep):
|
|
625
|
-
|
|
626
|
-
stack.push(
|
|
627
|
-
UserFlowStackFrame(
|
|
628
|
-
flow_id=step.link,
|
|
629
|
-
frame_type=FlowStackFrameType.LINK,
|
|
630
|
-
),
|
|
631
|
-
# push this below the current stack frame so that we can
|
|
632
|
-
# complete the current flow first and then continue with the
|
|
633
|
-
# linked flow
|
|
634
|
-
index=-1,
|
|
635
|
-
)
|
|
636
|
-
return ContinueFlowWithNextStep(events=initial_events)
|
|
615
|
+
return _run_link_step(initial_events, stack, step)
|
|
637
616
|
|
|
638
617
|
elif isinstance(step, CallFlowStep):
|
|
639
|
-
|
|
640
|
-
stack.push(
|
|
641
|
-
UserFlowStackFrame(
|
|
642
|
-
flow_id=step.call,
|
|
643
|
-
frame_type=FlowStackFrameType.CALL,
|
|
644
|
-
),
|
|
645
|
-
)
|
|
646
|
-
return ContinueFlowWithNextStep()
|
|
618
|
+
return _run_call_step(initial_events, stack, step)
|
|
647
619
|
|
|
648
620
|
elif isinstance(step, SetSlotsFlowStep):
|
|
649
|
-
|
|
650
|
-
slot_events: List[Event] = events_from_set_slots_step(step)
|
|
651
|
-
return ContinueFlowWithNextStep(events=initial_events + slot_events)
|
|
621
|
+
return _run_set_slot_step(initial_events, step)
|
|
652
622
|
|
|
653
623
|
elif isinstance(step, NoOperationFlowStep):
|
|
654
624
|
structlogger.debug("flow.step.run.no_operation")
|
|
655
625
|
return ContinueFlowWithNextStep(events=initial_events)
|
|
656
626
|
|
|
657
627
|
elif isinstance(step, EndFlowStep):
|
|
658
|
-
|
|
659
|
-
structlogger.debug("flow.step.run.flow_end")
|
|
660
|
-
current_frame = stack.pop()
|
|
661
|
-
trigger_pattern_completed(current_frame, stack, flows)
|
|
662
|
-
resumed_events = trigger_pattern_continue_interrupted(
|
|
663
|
-
current_frame, stack, flows
|
|
664
|
-
)
|
|
665
|
-
reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
|
|
666
|
-
return ContinueFlowWithNextStep(
|
|
667
|
-
events=initial_events + reset_events + resumed_events, has_flow_ended=True
|
|
668
|
-
)
|
|
628
|
+
return _run_end_step(flow, flows, initial_events, stack, tracker)
|
|
669
629
|
|
|
670
630
|
else:
|
|
671
631
|
raise FlowException(f"Unknown flow step type {type(step)}")
|
|
672
632
|
|
|
673
633
|
|
|
634
|
+
def _run_end_step(
    flow: Flow,
    flows: FlowsList,
    initial_events: List[Event],
    stack: DialogueStack,
    tracker: DialogueStateTracker,
) -> FlowStepResult:
    """Handle an end step: pop the finished flow's frame and wrap up.

    Args:
        flow: The flow whose end step is being run; passed to
            `reset_scoped_slots`.
        flows: All flows; passed to the pattern-trigger helpers.
        initial_events: Events collected before this step; they come first in
            the returned event list.
        stack: The dialogue stack. Its top frame is popped here.
        tracker: The conversation tracker; passed to `reset_scoped_slots`.

    Returns:
        A `ContinueFlowWithNextStep` flagged with `has_flow_ended=True` whose
        events are `initial_events`, then the slot-reset events, then the
        events from `trigger_pattern_continue_interrupted`.
    """
    structlogger.debug("flow.step.run.flow_end")

    # The flow is over, so its frame leaves the stack before the follow-up
    # patterns get a chance to push their own frames.
    finished_frame = stack.pop()
    trigger_pattern_completed(finished_frame, stack, flows)
    resumed_events = trigger_pattern_continue_interrupted(
        finished_frame, stack, flows
    )
    reset_events: List[Event] = reset_scoped_slots(finished_frame, flow, tracker)

    all_events = initial_events + reset_events + resumed_events
    return ContinueFlowWithNextStep(events=all_events, has_flow_ended=True)
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def _run_set_slot_step(
    initial_events: List[Event], step: SetSlotsFlowStep
) -> FlowStepResult:
    """Handle a set-slots step by emitting the events derived from it.

    Args:
        initial_events: Events collected before this step; they precede the
            slot events in the result.
        step: The set-slots step handed to `events_from_set_slots_step`.

    Returns:
        A `ContinueFlowWithNextStep` carrying `initial_events` followed by the
        events produced for the step.
    """
    structlogger.debug("flow.step.run.slot")
    return ContinueFlowWithNextStep(
        events=initial_events + events_from_set_slots_step(step)
    )
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
def _run_call_step(
    initial_events: List[Event], stack: DialogueStack, step: CallFlowStep
) -> FlowStepResult:
    """Handle a call step by pushing a CALL frame for the called flow.

    Args:
        initial_events: Events collected before this step; returned unchanged.
        stack: The dialogue stack that receives the new frame.
        step: The call step; `step.call` is used as the new frame's flow id.

    Returns:
        A `ContinueFlowWithNextStep` carrying `initial_events`.
    """
    structlogger.debug("flow.step.run.call")

    called_frame = UserFlowStackFrame(
        flow_id=step.call,
        frame_type=FlowStackFrameType.CALL,
    )
    stack.push(called_frame)

    return ContinueFlowWithNextStep(events=initial_events)
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def _run_link_step(
    initial_events: List[Event], stack: DialogueStack, step: LinkFlowStep
) -> FlowStepResult:
    """Handle a link step by queuing the linked flow beneath the current frame.

    Args:
        initial_events: Events collected before this step; returned unchanged.
        stack: The dialogue stack that receives the linked frame.
        step: The link step; `step.link` names the target. When it equals
            `RASA_PATTERN_HUMAN_HANDOFF`, a `HumanHandoffPatternFlowStackFrame`
            is pushed instead of a user flow frame.

    Returns:
        A `ContinueFlowWithNextStep` carrying `initial_events`.
    """
    structlogger.debug("flow.step.run.link")

    target_frame: DialogueStackFrame = (
        HumanHandoffPatternFlowStackFrame()
        if step.link == RASA_PATTERN_HUMAN_HANDOFF
        else UserFlowStackFrame(
            flow_id=step.link,
            frame_type=FlowStackFrameType.LINK,
        )
    )

    # index=-1 puts the frame below the current one, so the current flow
    # completes first and the linked flow continues afterwards.
    stack.push(target_frame, index=-1)

    return ContinueFlowWithNextStep(events=initial_events)
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def _run_action_step(
    available_actions: List[str],
    initial_events: List[Event],
    stack: DialogueStack,
    step: ActionFlowStep,
) -> FlowStepResult:
    """Handle an action step: predict the action if it is available.

    The step's action name is a template rendered against the stack's
    current context before it is looked up in `available_actions`.

    Args:
        available_actions: Names of the actions that may be predicted.
        initial_events: Events collected before this step; attached to the
            result either way.
        stack: The dialogue stack supplying the rendering context.
        step: The action step whose `action` template is rendered.

    Returns:
        A `PauseFlowReturnPrediction` wrapping a `FlowActionPrediction` with
        score 1.0 when the rendered action is available; otherwise a
        `ContinueFlowWithNextStep` carrying `initial_events`.
    """
    template_context = {"context": stack.current_context()}
    action_name = render_template_variables(step.action, template_context)

    if action_name in available_actions:
        structlogger.debug("flow.step.run.action", context=template_context)
        prediction = FlowActionPrediction(action_name, 1.0, events=initial_events)
        return PauseFlowReturnPrediction(prediction)

    # do not log about non-existing validation actions of collect steps
    is_collect_validation = step.action == "validate_{{context.collect}}"
    if not is_collect_validation:
        utter_action_name = render_template_variables(
            "{{context.utter}}", template_context
        )
        if utter_action_name not in available_actions:
            structlogger.warning("flow.step.run.action.unknown", action=action_name)

    return ContinueFlowWithNextStep(events=initial_events)
|
|
718
|
+
|
|
719
|
+
|
|
674
720
|
def _run_collect_information_step(
|
|
675
721
|
available_actions: List[str],
|
|
676
722
|
initial_events: List[Event],
|