rasa-pro 3.13.0.dev1__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/core/actions/action.py +0 -6
- rasa/core/channels/voice_ready/audiocodes.py +52 -17
- rasa/core/channels/voice_stream/audiocodes.py +53 -9
- rasa/core/channels/voice_stream/genesys.py +146 -16
- rasa/core/information_retrieval/faiss.py +6 -1
- rasa/core/information_retrieval/information_retrieval.py +40 -2
- rasa/core/information_retrieval/milvus.py +7 -2
- rasa/core/information_retrieval/qdrant.py +7 -2
- rasa/core/policies/enterprise_search_policy.py +61 -301
- rasa/core/policies/flows/flow_executor.py +3 -38
- rasa/core/processor.py +27 -6
- rasa/core/utils.py +53 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
- rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
- rasa/dialogue_understanding/generator/command_generator.py +67 -0
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -13
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +20 -1
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
- rasa/dialogue_understanding/processor/command_processor.py +7 -65
- rasa/dialogue_understanding/stack/utils.py +0 -38
- rasa/dialogue_understanding_test/io.py +13 -8
- rasa/document_retrieval/__init__.py +0 -0
- rasa/document_retrieval/constants.py +32 -0
- rasa/document_retrieval/document_post_processor.py +351 -0
- rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
- rasa/document_retrieval/document_retriever.py +333 -0
- rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
- rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
- rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
- rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
- rasa/document_retrieval/query_rewriter.py +234 -0
- rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
- rasa/engine/recipes/default_components.py +2 -0
- rasa/shared/core/constants.py +0 -8
- rasa/shared/core/domain.py +12 -3
- rasa/shared/core/flows/flow.py +0 -17
- rasa/shared/core/flows/flows_yaml_schema.json +3 -38
- rasa/shared/core/flows/steps/collect.py +5 -18
- rasa/shared/core/flows/utils.py +1 -16
- rasa/shared/core/slot_mappings.py +11 -5
- rasa/shared/nlu/constants.py +0 -1
- rasa/shared/utils/common.py +11 -1
- rasa/shared/utils/llm.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +10 -7
- rasa/tracing/instrumentation/instrumentation.py +12 -12
- rasa/validator.py +1 -123
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +1 -1
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +55 -47
- rasa/core/actions/action_handle_digressions.py +0 -164
- rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
- rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
|
@@ -1,12 +1,10 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
|
-
import json
|
|
3
2
|
import re
|
|
4
3
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
4
|
|
|
6
5
|
import dotenv
|
|
7
6
|
import structlog
|
|
8
7
|
from jinja2 import Template
|
|
9
|
-
from pydantic import ValidationError
|
|
10
8
|
|
|
11
9
|
import rasa.shared.utils.io
|
|
12
10
|
from rasa.core.constants import (
|
|
@@ -16,12 +14,9 @@ from rasa.core.constants import (
|
|
|
16
14
|
UTTER_SOURCE_METADATA_KEY,
|
|
17
15
|
)
|
|
18
16
|
from rasa.core.information_retrieval import (
|
|
19
|
-
InformationRetrieval,
|
|
20
|
-
InformationRetrievalException,
|
|
21
17
|
SearchResult,
|
|
22
|
-
|
|
18
|
+
SearchResultList,
|
|
23
19
|
)
|
|
24
|
-
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
25
20
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
26
21
|
from rasa.core.utils import AvailableEndpoints
|
|
27
22
|
from rasa.dialogue_understanding.generator.constants import (
|
|
@@ -38,6 +33,10 @@ from rasa.dialogue_understanding.stack.frames import (
|
|
|
38
33
|
PatternFlowStackFrame,
|
|
39
34
|
SearchStackFrame,
|
|
40
35
|
)
|
|
36
|
+
from rasa.document_retrieval.constants import (
|
|
37
|
+
POST_PROCESSED_DOCUMENTS_KEY,
|
|
38
|
+
SEARCH_QUERY_KEY,
|
|
39
|
+
)
|
|
41
40
|
from rasa.engine.graph import ExecutionContext
|
|
42
41
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
43
42
|
from rasa.engine.storage.resource import Resource
|
|
@@ -45,10 +44,7 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
45
44
|
from rasa.graph_components.providers.forms_provider import Forms
|
|
46
45
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
47
46
|
from rasa.shared.constants import (
|
|
48
|
-
EMBEDDINGS_CONFIG_KEY,
|
|
49
47
|
MODEL_CONFIG_KEY,
|
|
50
|
-
MODEL_GROUP_ID_CONFIG_KEY,
|
|
51
|
-
MODEL_NAME_CONFIG_KEY,
|
|
52
48
|
OPENAI_PROVIDER,
|
|
53
49
|
PROMPT_CONFIG_KEY,
|
|
54
50
|
PROVIDER_CONFIG_KEY,
|
|
@@ -60,10 +56,10 @@ from rasa.shared.core.constants import (
|
|
|
60
56
|
DEFAULT_SLOT_NAMES,
|
|
61
57
|
)
|
|
62
58
|
from rasa.shared.core.domain import Domain
|
|
63
|
-
from rasa.shared.core.events import
|
|
59
|
+
from rasa.shared.core.events import Event
|
|
64
60
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
65
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
66
|
-
from rasa.shared.exceptions import FileIOException
|
|
61
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
62
|
+
from rasa.shared.exceptions import FileIOException
|
|
67
63
|
from rasa.shared.nlu.constants import (
|
|
68
64
|
KEY_COMPONENT_NAME,
|
|
69
65
|
KEY_LLM_RESPONSE_METADATA,
|
|
@@ -72,12 +68,8 @@ from rasa.shared.nlu.constants import (
|
|
|
72
68
|
PROMPTS,
|
|
73
69
|
)
|
|
74
70
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
75
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
76
|
-
_LangchainEmbeddingClientAdapter,
|
|
77
|
-
)
|
|
78
71
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
79
72
|
from rasa.shared.providers.llm.llm_response import LLMResponse, measure_llm_latency
|
|
80
|
-
from rasa.shared.utils.cli import print_error_and_exit
|
|
81
73
|
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
82
74
|
EmbeddingsHealthCheckMixin,
|
|
83
75
|
)
|
|
@@ -85,23 +77,13 @@ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheck
|
|
|
85
77
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
86
78
|
from rasa.shared.utils.llm import (
|
|
87
79
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
88
|
-
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
89
|
-
embedder_factory,
|
|
90
80
|
get_prompt_template,
|
|
91
81
|
llm_factory,
|
|
92
82
|
resolve_model_client_config,
|
|
93
|
-
sanitize_message_for_prompt,
|
|
94
83
|
tracker_as_readable_transcript,
|
|
95
84
|
)
|
|
96
|
-
from rasa.telemetry import (
|
|
97
|
-
track_enterprise_search_policy_predict,
|
|
98
|
-
track_enterprise_search_policy_train_completed,
|
|
99
|
-
track_enterprise_search_policy_train_started,
|
|
100
|
-
)
|
|
101
85
|
|
|
102
86
|
if TYPE_CHECKING:
|
|
103
|
-
from langchain.schema.embeddings import Embeddings
|
|
104
|
-
|
|
105
87
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
106
88
|
|
|
107
89
|
from rasa.utils.log_utils import log_llm
|
|
@@ -110,22 +92,11 @@ logger = structlog.get_logger()
|
|
|
110
92
|
|
|
111
93
|
dotenv.load_dotenv("./.env")
|
|
112
94
|
|
|
113
|
-
SOURCE_PROPERTY = "source"
|
|
114
|
-
VECTOR_STORE_TYPE_PROPERTY = "type"
|
|
115
|
-
VECTOR_STORE_PROPERTY = "vector_store"
|
|
116
|
-
VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
117
95
|
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
118
96
|
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
119
97
|
USE_LLM_PROPERTY = "use_generative_llm"
|
|
120
98
|
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
121
99
|
|
|
122
|
-
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
123
|
-
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
124
|
-
DEFAULT_VECTOR_STORE = {
|
|
125
|
-
VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
|
|
126
|
-
SOURCE_PROPERTY: "./docs",
|
|
127
|
-
VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
|
|
128
|
-
}
|
|
129
100
|
|
|
130
101
|
DEFAULT_LLM_CONFIG = {
|
|
131
102
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
@@ -136,11 +107,6 @@ DEFAULT_LLM_CONFIG = {
|
|
|
136
107
|
"max_retries": 1,
|
|
137
108
|
}
|
|
138
109
|
|
|
139
|
-
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
140
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
141
|
-
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
142
|
-
}
|
|
143
|
-
|
|
144
110
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
145
111
|
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
146
112
|
|
|
@@ -156,14 +122,6 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
|
|
|
156
122
|
)
|
|
157
123
|
|
|
158
124
|
|
|
159
|
-
class VectorStoreConnectionError(RasaException):
|
|
160
|
-
"""Exception raised for errors in connecting to the vector store."""
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
class VectorStoreConfigurationError(RasaException):
|
|
164
|
-
"""Exception raised for errors in vector store configuration."""
|
|
165
|
-
|
|
166
|
-
|
|
167
125
|
@DefaultV1Recipe.register(
|
|
168
126
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
169
127
|
)
|
|
@@ -197,7 +155,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
197
155
|
"""Returns the default config of the policy."""
|
|
198
156
|
return {
|
|
199
157
|
POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
|
|
200
|
-
VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
|
|
201
158
|
}
|
|
202
159
|
|
|
203
160
|
def __init__(
|
|
@@ -206,7 +163,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
206
163
|
model_storage: ModelStorage,
|
|
207
164
|
resource: Resource,
|
|
208
165
|
execution_context: ExecutionContext,
|
|
209
|
-
vector_store: Optional[InformationRetrieval] = None,
|
|
210
166
|
featurizer: Optional["TrackerFeaturizer"] = None,
|
|
211
167
|
prompt_template: Optional[Text] = None,
|
|
212
168
|
) -> None:
|
|
@@ -217,21 +173,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
217
173
|
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
218
174
|
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
219
175
|
)
|
|
220
|
-
# Resolve embeddings config
|
|
221
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
222
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
# Vector store object and configuration
|
|
226
|
-
self.vector_store = vector_store
|
|
227
|
-
self.vector_store_config = self.config.get(
|
|
228
|
-
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
229
|
-
)
|
|
230
|
-
|
|
231
|
-
# Embeddings configuration for encoding the search query
|
|
232
|
-
self.embeddings_config = (
|
|
233
|
-
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
234
|
-
)
|
|
235
176
|
|
|
236
177
|
# LLM Configuration for response generation
|
|
237
178
|
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
@@ -239,9 +180,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
239
180
|
# Maximum number of turns to include in the prompt
|
|
240
181
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
241
182
|
|
|
242
|
-
# Maximum number of messages to include in the search query
|
|
243
|
-
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
244
|
-
|
|
245
183
|
# boolean to enable/disable tracing of prompt tokens
|
|
246
184
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
247
185
|
|
|
@@ -263,25 +201,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
263
201
|
if self.citation_enabled:
|
|
264
202
|
self.prompt_template = self.citation_prompt_template
|
|
265
203
|
|
|
266
|
-
@classmethod
|
|
267
|
-
def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
|
|
268
|
-
"""Creates an embedder based on the given configuration.
|
|
269
|
-
|
|
270
|
-
Returns:
|
|
271
|
-
The embedder.
|
|
272
|
-
"""
|
|
273
|
-
# Copy the config so original config is not modified
|
|
274
|
-
config = config.copy()
|
|
275
|
-
# Resolve config and instantiate the embedding client
|
|
276
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
277
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
278
|
-
)
|
|
279
|
-
client = embedder_factory(
|
|
280
|
-
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
281
|
-
)
|
|
282
|
-
# Wrap the embedding client in the adapter
|
|
283
|
-
return _LangchainEmbeddingClientAdapter(client)
|
|
284
|
-
|
|
285
204
|
@classmethod
|
|
286
205
|
def _add_prompt_and_llm_response_to_latest_message(
|
|
287
206
|
cls,
|
|
@@ -346,53 +265,24 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
346
265
|
# Perform health checks for both LLM and embeddings client configs
|
|
347
266
|
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
348
267
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
# telemetry call to track training
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
#
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
)
|
|
268
|
+
# # telemetry call to track training start
|
|
269
|
+
# track_enterprise_search_policy_train_started()
|
|
270
|
+
# # telemetry call to track training completion
|
|
271
|
+
# track_enterprise_search_policy_train_completed(
|
|
272
|
+
# vector_store_type=store_type,
|
|
273
|
+
# embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
274
|
+
# embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
275
|
+
# or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
276
|
+
# embeddings_model_group_id=self.embeddings_config.get(
|
|
277
|
+
# MODEL_GROUP_ID_CONFIG_KEY
|
|
278
|
+
# ),
|
|
279
|
+
# llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
280
|
+
# llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
281
|
+
# or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
282
|
+
# llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
283
|
+
# citation_enabled=self.citation_enabled,
|
|
284
|
+
# )
|
|
367
285
|
|
|
368
|
-
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
369
|
-
logger.info("enterprise_search_policy.train.faiss")
|
|
370
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
371
|
-
self.vector_store = FAISS_Store(
|
|
372
|
-
docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
|
|
373
|
-
embeddings=embeddings,
|
|
374
|
-
index_path=path,
|
|
375
|
-
create_index=True,
|
|
376
|
-
use_llm=self.use_llm,
|
|
377
|
-
)
|
|
378
|
-
else:
|
|
379
|
-
logger.info("enterprise_search_policy.train.custom", store_type=store_type)
|
|
380
|
-
|
|
381
|
-
# telemetry call to track training completion
|
|
382
|
-
track_enterprise_search_policy_train_completed(
|
|
383
|
-
vector_store_type=store_type,
|
|
384
|
-
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
385
|
-
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
386
|
-
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
387
|
-
embeddings_model_group_id=self.embeddings_config.get(
|
|
388
|
-
MODEL_GROUP_ID_CONFIG_KEY
|
|
389
|
-
),
|
|
390
|
-
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
391
|
-
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
392
|
-
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
393
|
-
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
394
|
-
citation_enabled=self.citation_enabled,
|
|
395
|
-
)
|
|
396
286
|
self.persist()
|
|
397
287
|
return self._resource
|
|
398
288
|
|
|
@@ -429,60 +319,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
429
319
|
)
|
|
430
320
|
return template_slots
|
|
431
321
|
|
|
432
|
-
def _connect_vector_store_or_raise(
|
|
433
|
-
self, endpoints: Optional[AvailableEndpoints]
|
|
434
|
-
) -> None:
|
|
435
|
-
"""Connects to the vector store or raises an exception.
|
|
436
|
-
|
|
437
|
-
Raise exceptions for the following cases:
|
|
438
|
-
- The configuration is not specified
|
|
439
|
-
- Unable to connect to the vector store
|
|
440
|
-
|
|
441
|
-
Args:
|
|
442
|
-
endpoints: Endpoints configuration.
|
|
443
|
-
"""
|
|
444
|
-
config = endpoints.vector_store if endpoints else None
|
|
445
|
-
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
446
|
-
if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
447
|
-
logger.error(
|
|
448
|
-
"enterprise_search_policy._connect_vector_store_or_raise.no_config"
|
|
449
|
-
)
|
|
450
|
-
raise VectorStoreConfigurationError(
|
|
451
|
-
"""No vector store specified. Please specify a vector
|
|
452
|
-
store in the endpoints configuration"""
|
|
453
|
-
)
|
|
454
|
-
try:
|
|
455
|
-
self.vector_store.connect(config) # type: ignore
|
|
456
|
-
except Exception as e:
|
|
457
|
-
logger.error(
|
|
458
|
-
"enterprise_search_policy._connect_vector_store_or_raise.connect_error",
|
|
459
|
-
error=e,
|
|
460
|
-
config=config,
|
|
461
|
-
)
|
|
462
|
-
raise VectorStoreConnectionError(
|
|
463
|
-
f"Unable to connect to the vector store. Error: {e}"
|
|
464
|
-
)
|
|
465
|
-
|
|
466
|
-
def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
|
|
467
|
-
"""Prepares the search query.
|
|
468
|
-
The search query is the last N messages in the conversation history.
|
|
469
|
-
|
|
470
|
-
Args:
|
|
471
|
-
tracker: The tracker containing the conversation history up to now.
|
|
472
|
-
history: The number of messages to include in the search query.
|
|
473
|
-
|
|
474
|
-
Returns:
|
|
475
|
-
The search query.
|
|
476
|
-
"""
|
|
477
|
-
transcript = []
|
|
478
|
-
for event in tracker.applied_events():
|
|
479
|
-
if isinstance(event, UserUttered) or isinstance(event, BotUttered):
|
|
480
|
-
transcript.append(sanitize_message_for_prompt(event.text))
|
|
481
|
-
|
|
482
|
-
search_query = " ".join(transcript[-history:][::-1])
|
|
483
|
-
logger.debug("search_query", search_query=search_query)
|
|
484
|
-
return search_query
|
|
485
|
-
|
|
486
322
|
async def predict_action_probabilities( # type: ignore[override]
|
|
487
323
|
self,
|
|
488
324
|
tracker: DialogueStateTracker,
|
|
@@ -506,44 +342,34 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
506
342
|
The prediction.
|
|
507
343
|
"""
|
|
508
344
|
logger_key = "enterprise_search_policy.predict_action_probabilities"
|
|
509
|
-
|
|
510
|
-
VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
|
|
511
|
-
)
|
|
512
|
-
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
345
|
+
|
|
513
346
|
if not self.supports_current_stack_frame(
|
|
514
347
|
tracker, False, False
|
|
515
348
|
) or self.should_abstain_in_coexistence(tracker, True):
|
|
516
349
|
return self._prediction(self._default_predictions(domain))
|
|
517
350
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
self._connect_vector_store_or_raise(endpoints)
|
|
524
|
-
except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
|
|
525
|
-
logger.error(f"{logger_key}.connection_error", error=e)
|
|
526
|
-
return self._create_prediction_internal_error(domain, tracker)
|
|
351
|
+
# retrieve documents from the latest message
|
|
352
|
+
# document retrieval happened earlier in the pipeline
|
|
353
|
+
if tracker.latest_message is None or tracker.latest_message.parse_data is None:
|
|
354
|
+
logger.info(f"{logger_key}.no_documents")
|
|
355
|
+
return self._create_prediction_cannot_handle(domain, tracker)
|
|
527
356
|
|
|
528
|
-
|
|
529
|
-
|
|
357
|
+
documents_data = tracker.latest_message.parse_data.get(
|
|
358
|
+
POST_PROCESSED_DOCUMENTS_KEY
|
|
530
359
|
)
|
|
531
|
-
tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
|
|
532
360
|
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
)
|
|
539
|
-
except InformationRetrievalException as e:
|
|
540
|
-
logger.error(f"{logger_key}.search_error", error=e)
|
|
541
|
-
return self._create_prediction_internal_error(domain, tracker)
|
|
361
|
+
if not documents_data:
|
|
362
|
+
logger.info(f"{logger_key}.no_documents")
|
|
363
|
+
return self._create_prediction_cannot_handle(domain, tracker)
|
|
364
|
+
|
|
365
|
+
documents = SearchResultList.from_dict(documents_data)
|
|
542
366
|
|
|
543
367
|
if not documents.results:
|
|
544
368
|
logger.info(f"{logger_key}.no_documents")
|
|
545
369
|
return self._create_prediction_cannot_handle(domain, tracker)
|
|
546
370
|
|
|
371
|
+
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
372
|
+
|
|
547
373
|
if self.use_llm:
|
|
548
374
|
prompt = self._render_prompt(tracker, documents.results)
|
|
549
375
|
llm_response = await self._generate_llm_answer(llm, prompt)
|
|
@@ -588,25 +414,29 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
588
414
|
result.text for result in documents.results
|
|
589
415
|
],
|
|
590
416
|
UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
|
|
591
|
-
SEARCH_QUERY_METADATA_KEY:
|
|
417
|
+
SEARCH_QUERY_METADATA_KEY: tracker.latest_message.parse_data.get(
|
|
418
|
+
SEARCH_QUERY_KEY
|
|
419
|
+
),
|
|
592
420
|
}
|
|
593
421
|
}
|
|
594
422
|
|
|
595
|
-
# telemetry call to track policy prediction
|
|
596
|
-
track_enterprise_search_policy_predict(
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
423
|
+
# # telemetry call to track policy prediction
|
|
424
|
+
# track_enterprise_search_policy_predict(
|
|
425
|
+
# vector_store_type=self.vector_store_config.get(
|
|
426
|
+
# VECTOR_STORE_TYPE_PROPERTY),
|
|
427
|
+
# embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
428
|
+
# embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
429
|
+
# or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
430
|
+
# embeddings_model_group_id=self.embeddings_config.get(
|
|
431
|
+
# MODEL_GROUP_ID_CONFIG_KEY
|
|
432
|
+
# ),
|
|
433
|
+
# llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
434
|
+
# llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
435
|
+
# or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
436
|
+
# llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
437
|
+
# citation_enabled=self.citation_enabled,
|
|
438
|
+
# )
|
|
439
|
+
|
|
610
440
|
return self._create_prediction(
|
|
611
441
|
domain=domain, tracker=tracker, action_metadata=action_metadata
|
|
612
442
|
)
|
|
@@ -774,73 +604,19 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
774
604
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
775
605
|
)
|
|
776
606
|
|
|
777
|
-
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
778
|
-
VECTOR_STORE_TYPE_PROPERTY
|
|
779
|
-
)
|
|
780
|
-
|
|
781
|
-
embeddings = cls._create_plain_embedder(config)
|
|
782
|
-
|
|
783
607
|
logger.info("enterprise_search_policy.load", config=config)
|
|
784
|
-
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
785
|
-
# if a vector store is not specified,
|
|
786
|
-
# default to using FAISS with the index stored in the model
|
|
787
|
-
# TODO figure out a way to get path without context manager
|
|
788
|
-
with model_storage.read_from(resource) as path:
|
|
789
|
-
vector_store = FAISS_Store(
|
|
790
|
-
embeddings=embeddings,
|
|
791
|
-
index_path=path,
|
|
792
|
-
docs_folder=None,
|
|
793
|
-
create_index=False,
|
|
794
|
-
)
|
|
795
|
-
else:
|
|
796
|
-
vector_store = create_from_endpoint_config(
|
|
797
|
-
config_type=store_type,
|
|
798
|
-
embeddings=embeddings,
|
|
799
|
-
) # type: ignore
|
|
800
608
|
|
|
801
609
|
return cls(
|
|
802
610
|
config,
|
|
803
611
|
model_storage,
|
|
804
612
|
resource,
|
|
805
613
|
execution_context,
|
|
806
|
-
vector_store=vector_store,
|
|
807
614
|
prompt_template=prompt_template,
|
|
808
615
|
)
|
|
809
616
|
|
|
810
|
-
@classmethod
|
|
811
|
-
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
|
|
812
|
-
"""This is required only for local knowledge base types.
|
|
813
|
-
|
|
814
|
-
e.g. FAISS, to ensure that the graph component is retrained when the knowledge
|
|
815
|
-
base is updated.
|
|
816
|
-
"""
|
|
817
|
-
merged_config = {**cls.get_default_config(), **config}
|
|
818
|
-
|
|
819
|
-
store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
820
|
-
VECTOR_STORE_TYPE_PROPERTY
|
|
821
|
-
)
|
|
822
|
-
if store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
823
|
-
return None
|
|
824
|
-
|
|
825
|
-
source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
|
|
826
|
-
if not source:
|
|
827
|
-
return None
|
|
828
|
-
|
|
829
|
-
docs = FAISS_Store.load_documents(source)
|
|
830
|
-
|
|
831
|
-
if len(docs) == 0:
|
|
832
|
-
return None
|
|
833
|
-
|
|
834
|
-
docs_as_strings = [
|
|
835
|
-
json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
|
|
836
|
-
]
|
|
837
|
-
return sorted(docs_as_strings)
|
|
838
|
-
|
|
839
617
|
@classmethod
|
|
840
618
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
841
619
|
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
842
|
-
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
843
|
-
|
|
844
620
|
prompt_template = get_prompt_template(
|
|
845
621
|
config.get(PROMPT_CONFIG_KEY),
|
|
846
622
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
@@ -849,12 +625,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
849
625
|
llm_config = resolve_model_client_config(
|
|
850
626
|
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
851
627
|
)
|
|
852
|
-
|
|
853
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
854
|
-
)
|
|
855
|
-
return deep_container_fingerprint(
|
|
856
|
-
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
857
|
-
)
|
|
628
|
+
return deep_container_fingerprint([prompt_template, llm_config])
|
|
858
629
|
|
|
859
630
|
@staticmethod
|
|
860
631
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -959,14 +730,3 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
959
730
|
log_source_method,
|
|
960
731
|
EnterpriseSearchPolicy.__name__,
|
|
961
732
|
)
|
|
962
|
-
|
|
963
|
-
# Perform health check of the embeddings client config
|
|
964
|
-
embeddings_config = resolve_model_client_config(
|
|
965
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
966
|
-
)
|
|
967
|
-
cls.perform_embeddings_health_check(
|
|
968
|
-
embeddings_config,
|
|
969
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
970
|
-
log_source_method,
|
|
971
|
-
EnterpriseSearchPolicy.__name__,
|
|
972
|
-
)
|
|
@@ -23,7 +23,6 @@ from rasa.core.policies.flows.flow_step_result import (
|
|
|
23
23
|
)
|
|
24
24
|
from rasa.dialogue_understanding.commands import CancelFlowCommand
|
|
25
25
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
26
|
-
from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
|
|
27
26
|
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
28
27
|
CollectInformationPatternFlowStackFrame,
|
|
29
28
|
)
|
|
@@ -51,7 +50,6 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
|
|
|
51
50
|
)
|
|
52
51
|
from rasa.dialogue_understanding.stack.utils import (
|
|
53
52
|
top_user_flow_frame,
|
|
54
|
-
user_flows_on_the_stack,
|
|
55
53
|
)
|
|
56
54
|
from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
|
|
57
55
|
from rasa.shared.core.constants import (
|
|
@@ -280,33 +278,6 @@ def trigger_pattern_continue_interrupted(
|
|
|
280
278
|
return events
|
|
281
279
|
|
|
282
280
|
|
|
283
|
-
def trigger_pattern_clarification(
|
|
284
|
-
current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
|
|
285
|
-
) -> None:
|
|
286
|
-
"""Trigger the pattern to clarify which topic to continue if needed."""
|
|
287
|
-
if not isinstance(current_frame, UserFlowStackFrame):
|
|
288
|
-
return None
|
|
289
|
-
|
|
290
|
-
if current_frame.frame_type in [
|
|
291
|
-
FlowStackFrameType.CALL,
|
|
292
|
-
FlowStackFrameType.INTERRUPT,
|
|
293
|
-
]:
|
|
294
|
-
# we want to return to the flow that called
|
|
295
|
-
# the current flow or the flow that was interrupted
|
|
296
|
-
# by the current flow
|
|
297
|
-
return None
|
|
298
|
-
|
|
299
|
-
pending_flows = [
|
|
300
|
-
flows.flow_by_id(frame.flow_id)
|
|
301
|
-
for frame in stack.frames
|
|
302
|
-
if isinstance(frame, UserFlowStackFrame)
|
|
303
|
-
and frame.flow_id != current_frame.flow_id
|
|
304
|
-
]
|
|
305
|
-
|
|
306
|
-
flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
|
|
307
|
-
stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
|
|
308
|
-
|
|
309
|
-
|
|
310
281
|
def trigger_pattern_completed(
|
|
311
282
|
current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
|
|
312
283
|
) -> None:
|
|
@@ -675,15 +646,9 @@ def _run_end_step(
|
|
|
675
646
|
structlogger.debug("flow.step.run.flow_end")
|
|
676
647
|
current_frame = stack.pop()
|
|
677
648
|
trigger_pattern_completed(current_frame, stack, flows)
|
|
678
|
-
resumed_events =
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
# we need to trigger the pattern clarify
|
|
682
|
-
trigger_pattern_clarification(current_frame, stack, flows)
|
|
683
|
-
else:
|
|
684
|
-
resumed_events = trigger_pattern_continue_interrupted(
|
|
685
|
-
current_frame, stack, flows, tracker
|
|
686
|
-
)
|
|
649
|
+
resumed_events = trigger_pattern_continue_interrupted(
|
|
650
|
+
current_frame, stack, flows, tracker
|
|
651
|
+
)
|
|
687
652
|
reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
|
|
688
653
|
return ContinueFlowWithNextStep(
|
|
689
654
|
events=initial_events + reset_events + resumed_events, has_flow_ended=True
|
rasa/core/processor.py
CHANGED
|
@@ -76,6 +76,7 @@ from rasa.shared.core.constants import (
|
|
|
76
76
|
SLOT_SILENCE_TIMEOUT,
|
|
77
77
|
USER_INTENT_RESTART,
|
|
78
78
|
USER_INTENT_SILENCE_TIMEOUT,
|
|
79
|
+
SetSlotExtractor,
|
|
79
80
|
)
|
|
80
81
|
from rasa.shared.core.events import (
|
|
81
82
|
ActionExecuted,
|
|
@@ -766,13 +767,26 @@ class MessageProcessor:
|
|
|
766
767
|
if self.http_interpreter:
|
|
767
768
|
parse_data = await self.http_interpreter.parse(message)
|
|
768
769
|
else:
|
|
769
|
-
regex_reader = create_regex_pattern_reader(message, self.domain)
|
|
770
|
-
|
|
771
770
|
processed_message = Message({TEXT: message.text})
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
771
|
+
|
|
772
|
+
all_flows = await self.get_flows()
|
|
773
|
+
should_force_slot_command, slot_name = (
|
|
774
|
+
rasa.core.utils.should_force_slot_filling(tracker, all_flows)
|
|
775
|
+
)
|
|
776
|
+
|
|
777
|
+
if should_force_slot_command:
|
|
778
|
+
command = SetSlotCommand(
|
|
779
|
+
name=slot_name,
|
|
780
|
+
value=message.text,
|
|
781
|
+
extractor=SetSlotExtractor.COMMAND_PAYLOAD_READER.value,
|
|
775
782
|
)
|
|
783
|
+
processed_message.set(COMMANDS, [command.as_dict()], add_to_output=True)
|
|
784
|
+
else:
|
|
785
|
+
regex_reader = create_regex_pattern_reader(message, self.domain)
|
|
786
|
+
if regex_reader:
|
|
787
|
+
processed_message = regex_reader.unpack_regex_message(
|
|
788
|
+
message=processed_message, domain=self.domain
|
|
789
|
+
)
|
|
776
790
|
|
|
777
791
|
# Invalid use of slash syntax, sanitize the message before passing
|
|
778
792
|
# it to the graph
|
|
@@ -1009,7 +1023,14 @@ class MessageProcessor:
|
|
|
1009
1023
|
|
|
1010
1024
|
@staticmethod
|
|
1011
1025
|
def _should_handle_message(tracker: DialogueStateTracker) -> bool:
|
|
1012
|
-
return not tracker.is_paused() or (
|
|
1026
|
+
return not tracker.is_paused() or MessageProcessor._last_user_intent_is_restart(
|
|
1027
|
+
tracker
|
|
1028
|
+
)
|
|
1029
|
+
|
|
1030
|
+
@staticmethod
|
|
1031
|
+
def _last_user_intent_is_restart(tracker: DialogueStateTracker) -> bool:
|
|
1032
|
+
"""Check if the last user intent is a restart intent."""
|
|
1033
|
+
return (
|
|
1013
1034
|
tracker.latest_message is not None
|
|
1014
1035
|
and tracker.latest_message.intent.get(INTENT_NAME_KEY)
|
|
1015
1036
|
== USER_INTENT_RESTART
|