rasa-pro 3.12.6.dev2__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__init__.py +0 -6
- rasa/cli/scaffold.py +1 -1
- rasa/core/actions/action.py +38 -34
- rasa/core/actions/action_run_slot_rejections.py +1 -1
- rasa/core/channels/studio_chat.py +16 -43
- rasa/core/channels/voice_ready/audiocodes.py +46 -17
- rasa/core/information_retrieval/faiss.py +68 -7
- rasa/core/information_retrieval/information_retrieval.py +40 -2
- rasa/core/information_retrieval/milvus.py +7 -2
- rasa/core/information_retrieval/qdrant.py +7 -2
- rasa/core/nlg/contextual_response_rephraser.py +11 -27
- rasa/core/nlg/generator.py +5 -21
- rasa/core/nlg/response.py +6 -43
- rasa/core/nlg/summarize.py +1 -15
- rasa/core/nlg/translate.py +0 -8
- rasa/core/policies/enterprise_search_policy.py +64 -316
- rasa/core/policies/flows/flow_executor.py +3 -38
- rasa/core/policies/intentless_policy.py +4 -17
- rasa/core/policies/policy.py +0 -2
- rasa/core/processor.py +27 -6
- rasa/core/utils.py +53 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +4 -18
- rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
- rasa/dialogue_understanding/generator/command_generator.py +67 -0
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +7 -23
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +24 -2
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +8 -12
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
- rasa/dialogue_understanding/processor/command_processor.py +7 -65
- rasa/dialogue_understanding/stack/utils.py +0 -38
- rasa/dialogue_understanding_test/command_metric_calculation.py +7 -40
- rasa/dialogue_understanding_test/command_metrics.py +38 -0
- rasa/dialogue_understanding_test/du_test_case.py +58 -25
- rasa/dialogue_understanding_test/du_test_result.py +228 -132
- rasa/dialogue_understanding_test/du_test_runner.py +10 -1
- rasa/dialogue_understanding_test/io.py +48 -16
- rasa/document_retrieval/__init__.py +0 -0
- rasa/document_retrieval/constants.py +32 -0
- rasa/document_retrieval/document_post_processor.py +351 -0
- rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
- rasa/document_retrieval/document_retriever.py +333 -0
- rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
- rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
- rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
- rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
- rasa/document_retrieval/query_rewriter.py +234 -0
- rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
- rasa/engine/recipes/default_components.py +2 -0
- rasa/hooks.py +0 -55
- rasa/model_manager/model_api.py +1 -1
- rasa/model_manager/socket_bridge.py +0 -7
- rasa/shared/constants.py +0 -5
- rasa/shared/core/constants.py +0 -8
- rasa/shared/core/domain.py +12 -3
- rasa/shared/core/flows/flow.py +0 -17
- rasa/shared/core/flows/flows_yaml_schema.json +3 -38
- rasa/shared/core/flows/steps/collect.py +5 -18
- rasa/shared/core/flows/utils.py +1 -16
- rasa/shared/core/slot_mappings.py +11 -5
- rasa/shared/core/slots.py +1 -1
- rasa/shared/core/trackers.py +4 -10
- rasa/shared/nlu/constants.py +0 -1
- rasa/shared/providers/constants.py +0 -9
- rasa/shared/providers/llm/_base_litellm_client.py +4 -14
- rasa/shared/providers/llm/default_litellm_llm_client.py +2 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
- rasa/shared/providers/llm/llm_client.py +15 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
- rasa/shared/utils/common.py +11 -1
- rasa/shared/utils/health_check/health_check.py +1 -7
- rasa/shared/utils/llm.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +50 -17
- rasa/tracing/instrumentation/instrumentation.py +12 -12
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +1 -2
- rasa/utils/licensing.py +0 -15
- rasa/validator.py +1 -123
- rasa/version.py +1 -1
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +2 -3
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +88 -80
- rasa/core/actions/action_handle_digressions.py +0 -164
- rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
- rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
- rasa/monkey_patches.py +0 -91
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
|
@@ -1,12 +1,10 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
|
-
import json
|
|
3
2
|
import re
|
|
4
3
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
4
|
|
|
6
5
|
import dotenv
|
|
7
6
|
import structlog
|
|
8
7
|
from jinja2 import Template
|
|
9
|
-
from pydantic import ValidationError
|
|
10
8
|
|
|
11
9
|
import rasa.shared.utils.io
|
|
12
10
|
from rasa.core.constants import (
|
|
@@ -16,12 +14,9 @@ from rasa.core.constants import (
|
|
|
16
14
|
UTTER_SOURCE_METADATA_KEY,
|
|
17
15
|
)
|
|
18
16
|
from rasa.core.information_retrieval import (
|
|
19
|
-
InformationRetrieval,
|
|
20
|
-
InformationRetrievalException,
|
|
21
17
|
SearchResult,
|
|
22
|
-
|
|
18
|
+
SearchResultList,
|
|
23
19
|
)
|
|
24
|
-
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
25
20
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
26
21
|
from rasa.core.utils import AvailableEndpoints
|
|
27
22
|
from rasa.dialogue_understanding.generator.constants import (
|
|
@@ -38,6 +33,10 @@ from rasa.dialogue_understanding.stack.frames import (
|
|
|
38
33
|
PatternFlowStackFrame,
|
|
39
34
|
SearchStackFrame,
|
|
40
35
|
)
|
|
36
|
+
from rasa.document_retrieval.constants import (
|
|
37
|
+
POST_PROCESSED_DOCUMENTS_KEY,
|
|
38
|
+
SEARCH_QUERY_KEY,
|
|
39
|
+
)
|
|
41
40
|
from rasa.engine.graph import ExecutionContext
|
|
42
41
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
43
42
|
from rasa.engine.storage.resource import Resource
|
|
@@ -45,14 +44,7 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
45
44
|
from rasa.graph_components.providers.forms_provider import Forms
|
|
46
45
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
47
46
|
from rasa.shared.constants import (
|
|
48
|
-
EMBEDDINGS_CONFIG_KEY,
|
|
49
|
-
LANGFUSE_CUSTOM_METADATA_DICT,
|
|
50
|
-
LANGFUSE_METADATA_SESSION_ID,
|
|
51
|
-
LANGFUSE_METADATA_USER_ID,
|
|
52
|
-
LANGFUSE_TAGS,
|
|
53
47
|
MODEL_CONFIG_KEY,
|
|
54
|
-
MODEL_GROUP_ID_CONFIG_KEY,
|
|
55
|
-
MODEL_NAME_CONFIG_KEY,
|
|
56
48
|
OPENAI_PROVIDER,
|
|
57
49
|
PROMPT_CONFIG_KEY,
|
|
58
50
|
PROVIDER_CONFIG_KEY,
|
|
@@ -64,10 +56,10 @@ from rasa.shared.core.constants import (
|
|
|
64
56
|
DEFAULT_SLOT_NAMES,
|
|
65
57
|
)
|
|
66
58
|
from rasa.shared.core.domain import Domain
|
|
67
|
-
from rasa.shared.core.events import
|
|
59
|
+
from rasa.shared.core.events import Event
|
|
68
60
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
69
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
70
|
-
from rasa.shared.exceptions import FileIOException
|
|
61
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
62
|
+
from rasa.shared.exceptions import FileIOException
|
|
71
63
|
from rasa.shared.nlu.constants import (
|
|
72
64
|
KEY_COMPONENT_NAME,
|
|
73
65
|
KEY_LLM_RESPONSE_METADATA,
|
|
@@ -76,12 +68,8 @@ from rasa.shared.nlu.constants import (
|
|
|
76
68
|
PROMPTS,
|
|
77
69
|
)
|
|
78
70
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
79
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
80
|
-
_LangchainEmbeddingClientAdapter,
|
|
81
|
-
)
|
|
82
71
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
83
72
|
from rasa.shared.providers.llm.llm_response import LLMResponse, measure_llm_latency
|
|
84
|
-
from rasa.shared.utils.cli import print_error_and_exit
|
|
85
73
|
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
86
74
|
EmbeddingsHealthCheckMixin,
|
|
87
75
|
)
|
|
@@ -89,23 +77,13 @@ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheck
|
|
|
89
77
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
90
78
|
from rasa.shared.utils.llm import (
|
|
91
79
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
92
|
-
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
93
|
-
embedder_factory,
|
|
94
80
|
get_prompt_template,
|
|
95
81
|
llm_factory,
|
|
96
82
|
resolve_model_client_config,
|
|
97
|
-
sanitize_message_for_prompt,
|
|
98
83
|
tracker_as_readable_transcript,
|
|
99
84
|
)
|
|
100
|
-
from rasa.telemetry import (
|
|
101
|
-
track_enterprise_search_policy_predict,
|
|
102
|
-
track_enterprise_search_policy_train_completed,
|
|
103
|
-
track_enterprise_search_policy_train_started,
|
|
104
|
-
)
|
|
105
85
|
|
|
106
86
|
if TYPE_CHECKING:
|
|
107
|
-
from langchain.schema.embeddings import Embeddings
|
|
108
|
-
|
|
109
87
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
110
88
|
|
|
111
89
|
from rasa.utils.log_utils import log_llm
|
|
@@ -114,22 +92,11 @@ logger = structlog.get_logger()
|
|
|
114
92
|
|
|
115
93
|
dotenv.load_dotenv("./.env")
|
|
116
94
|
|
|
117
|
-
SOURCE_PROPERTY = "source"
|
|
118
|
-
VECTOR_STORE_TYPE_PROPERTY = "type"
|
|
119
|
-
VECTOR_STORE_PROPERTY = "vector_store"
|
|
120
|
-
VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
121
95
|
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
122
96
|
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
123
97
|
USE_LLM_PROPERTY = "use_generative_llm"
|
|
124
98
|
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
125
99
|
|
|
126
|
-
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
127
|
-
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
128
|
-
DEFAULT_VECTOR_STORE = {
|
|
129
|
-
VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
|
|
130
|
-
SOURCE_PROPERTY: "./docs",
|
|
131
|
-
VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
|
|
132
|
-
}
|
|
133
100
|
|
|
134
101
|
DEFAULT_LLM_CONFIG = {
|
|
135
102
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
@@ -140,11 +107,6 @@ DEFAULT_LLM_CONFIG = {
|
|
|
140
107
|
"max_retries": 1,
|
|
141
108
|
}
|
|
142
109
|
|
|
143
|
-
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
144
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
145
|
-
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
146
|
-
}
|
|
147
|
-
|
|
148
110
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
149
111
|
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
150
112
|
|
|
@@ -160,14 +122,6 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
|
|
|
160
122
|
)
|
|
161
123
|
|
|
162
124
|
|
|
163
|
-
class VectorStoreConnectionError(RasaException):
|
|
164
|
-
"""Exception raised for errors in connecting to the vector store."""
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
class VectorStoreConfigurationError(RasaException):
|
|
168
|
-
"""Exception raised for errors in vector store configuration."""
|
|
169
|
-
|
|
170
|
-
|
|
171
125
|
@DefaultV1Recipe.register(
|
|
172
126
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
173
127
|
)
|
|
@@ -201,7 +155,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
201
155
|
"""Returns the default config of the policy."""
|
|
202
156
|
return {
|
|
203
157
|
POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
|
|
204
|
-
VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
|
|
205
158
|
}
|
|
206
159
|
|
|
207
160
|
def __init__(
|
|
@@ -210,7 +163,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
210
163
|
model_storage: ModelStorage,
|
|
211
164
|
resource: Resource,
|
|
212
165
|
execution_context: ExecutionContext,
|
|
213
|
-
vector_store: Optional[InformationRetrieval] = None,
|
|
214
166
|
featurizer: Optional["TrackerFeaturizer"] = None,
|
|
215
167
|
prompt_template: Optional[Text] = None,
|
|
216
168
|
) -> None:
|
|
@@ -221,21 +173,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
221
173
|
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
222
174
|
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
223
175
|
)
|
|
224
|
-
# Resolve embeddings config
|
|
225
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
226
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
227
|
-
)
|
|
228
|
-
|
|
229
|
-
# Vector store object and configuration
|
|
230
|
-
self.vector_store = vector_store
|
|
231
|
-
self.vector_store_config = self.config.get(
|
|
232
|
-
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
233
|
-
)
|
|
234
|
-
|
|
235
|
-
# Embeddings configuration for encoding the search query
|
|
236
|
-
self.embeddings_config = (
|
|
237
|
-
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
238
|
-
)
|
|
239
176
|
|
|
240
177
|
# LLM Configuration for response generation
|
|
241
178
|
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
@@ -243,9 +180,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
243
180
|
# Maximum number of turns to include in the prompt
|
|
244
181
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
245
182
|
|
|
246
|
-
# Maximum number of messages to include in the search query
|
|
247
|
-
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
248
|
-
|
|
249
183
|
# boolean to enable/disable tracing of prompt tokens
|
|
250
184
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
251
185
|
|
|
@@ -267,25 +201,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
267
201
|
if self.citation_enabled:
|
|
268
202
|
self.prompt_template = self.citation_prompt_template
|
|
269
203
|
|
|
270
|
-
@classmethod
|
|
271
|
-
def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
|
|
272
|
-
"""Creates an embedder based on the given configuration.
|
|
273
|
-
|
|
274
|
-
Returns:
|
|
275
|
-
The embedder.
|
|
276
|
-
"""
|
|
277
|
-
# Copy the config so original config is not modified
|
|
278
|
-
config = config.copy()
|
|
279
|
-
# Resolve config and instantiate the embedding client
|
|
280
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
281
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
282
|
-
)
|
|
283
|
-
client = embedder_factory(
|
|
284
|
-
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
285
|
-
)
|
|
286
|
-
# Wrap the embedding client in the adapter
|
|
287
|
-
return _LangchainEmbeddingClientAdapter(client)
|
|
288
|
-
|
|
289
204
|
@classmethod
|
|
290
205
|
def _add_prompt_and_llm_response_to_latest_message(
|
|
291
206
|
cls,
|
|
@@ -350,52 +265,24 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
350
265
|
# Perform health checks for both LLM and embeddings client configs
|
|
351
266
|
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
352
267
|
|
|
353
|
-
|
|
268
|
+
# # telemetry call to track training start
|
|
269
|
+
# track_enterprise_search_policy_train_started()
|
|
270
|
+
# # telemetry call to track training completion
|
|
271
|
+
# track_enterprise_search_policy_train_completed(
|
|
272
|
+
# vector_store_type=store_type,
|
|
273
|
+
# embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
274
|
+
# embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
275
|
+
# or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
276
|
+
# embeddings_model_group_id=self.embeddings_config.get(
|
|
277
|
+
# MODEL_GROUP_ID_CONFIG_KEY
|
|
278
|
+
# ),
|
|
279
|
+
# llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
280
|
+
# llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
281
|
+
# or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
282
|
+
# llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
283
|
+
# citation_enabled=self.citation_enabled,
|
|
284
|
+
# )
|
|
354
285
|
|
|
355
|
-
# telemetry call to track training start
|
|
356
|
-
track_enterprise_search_policy_train_started()
|
|
357
|
-
|
|
358
|
-
# validate embedding configuration
|
|
359
|
-
try:
|
|
360
|
-
embeddings = self._create_plain_embedder(self.config)
|
|
361
|
-
except (ValidationError, Exception) as e:
|
|
362
|
-
logger.error(
|
|
363
|
-
"enterprise_search_policy.train.embedder_instantiation_failed",
|
|
364
|
-
message="Unable to instantiate the embedding client.",
|
|
365
|
-
error=e,
|
|
366
|
-
)
|
|
367
|
-
print_error_and_exit(
|
|
368
|
-
"Unable to create embedder. Please make sure you specified the "
|
|
369
|
-
f"required environment variables. Error: {e}"
|
|
370
|
-
)
|
|
371
|
-
|
|
372
|
-
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
373
|
-
logger.info("enterprise_search_policy.train.faiss")
|
|
374
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
375
|
-
self.vector_store = FAISS_Store(
|
|
376
|
-
docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
|
|
377
|
-
embeddings=embeddings,
|
|
378
|
-
index_path=path,
|
|
379
|
-
create_index=True,
|
|
380
|
-
)
|
|
381
|
-
else:
|
|
382
|
-
logger.info("enterprise_search_policy.train.custom", store_type=store_type)
|
|
383
|
-
|
|
384
|
-
# telemetry call to track training completion
|
|
385
|
-
track_enterprise_search_policy_train_completed(
|
|
386
|
-
vector_store_type=store_type,
|
|
387
|
-
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
388
|
-
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
389
|
-
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
390
|
-
embeddings_model_group_id=self.embeddings_config.get(
|
|
391
|
-
MODEL_GROUP_ID_CONFIG_KEY
|
|
392
|
-
),
|
|
393
|
-
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
394
|
-
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
395
|
-
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
396
|
-
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
397
|
-
citation_enabled=self.citation_enabled,
|
|
398
|
-
)
|
|
399
286
|
self.persist()
|
|
400
287
|
return self._resource
|
|
401
288
|
|
|
@@ -432,60 +319,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
432
319
|
)
|
|
433
320
|
return template_slots
|
|
434
321
|
|
|
435
|
-
def _connect_vector_store_or_raise(
|
|
436
|
-
self, endpoints: Optional[AvailableEndpoints]
|
|
437
|
-
) -> None:
|
|
438
|
-
"""Connects to the vector store or raises an exception.
|
|
439
|
-
|
|
440
|
-
Raise exceptions for the following cases:
|
|
441
|
-
- The configuration is not specified
|
|
442
|
-
- Unable to connect to the vector store
|
|
443
|
-
|
|
444
|
-
Args:
|
|
445
|
-
endpoints: Endpoints configuration.
|
|
446
|
-
"""
|
|
447
|
-
config = endpoints.vector_store if endpoints else None
|
|
448
|
-
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
449
|
-
if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
450
|
-
logger.error(
|
|
451
|
-
"enterprise_search_policy._connect_vector_store_or_raise.no_config"
|
|
452
|
-
)
|
|
453
|
-
raise VectorStoreConfigurationError(
|
|
454
|
-
"""No vector store specified. Please specify a vector
|
|
455
|
-
store in the endpoints configuration"""
|
|
456
|
-
)
|
|
457
|
-
try:
|
|
458
|
-
self.vector_store.connect(config) # type: ignore
|
|
459
|
-
except Exception as e:
|
|
460
|
-
logger.error(
|
|
461
|
-
"enterprise_search_policy._connect_vector_store_or_raise.connect_error",
|
|
462
|
-
error=e,
|
|
463
|
-
config=config,
|
|
464
|
-
)
|
|
465
|
-
raise VectorStoreConnectionError(
|
|
466
|
-
f"Unable to connect to the vector store. Error: {e}"
|
|
467
|
-
)
|
|
468
|
-
|
|
469
|
-
def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
|
|
470
|
-
"""Prepares the search query.
|
|
471
|
-
The search query is the last N messages in the conversation history.
|
|
472
|
-
|
|
473
|
-
Args:
|
|
474
|
-
tracker: The tracker containing the conversation history up to now.
|
|
475
|
-
history: The number of messages to include in the search query.
|
|
476
|
-
|
|
477
|
-
Returns:
|
|
478
|
-
The search query.
|
|
479
|
-
"""
|
|
480
|
-
transcript = []
|
|
481
|
-
for event in tracker.applied_events():
|
|
482
|
-
if isinstance(event, UserUttered) or isinstance(event, BotUttered):
|
|
483
|
-
transcript.append(sanitize_message_for_prompt(event.text))
|
|
484
|
-
|
|
485
|
-
search_query = " ".join(transcript[-history:][::-1])
|
|
486
|
-
logger.debug("search_query", search_query=search_query)
|
|
487
|
-
return search_query
|
|
488
|
-
|
|
489
322
|
async def predict_action_probabilities( # type: ignore[override]
|
|
490
323
|
self,
|
|
491
324
|
tracker: DialogueStateTracker,
|
|
@@ -509,49 +342,37 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
509
342
|
The prediction.
|
|
510
343
|
"""
|
|
511
344
|
logger_key = "enterprise_search_policy.predict_action_probabilities"
|
|
512
|
-
|
|
513
|
-
VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
|
|
514
|
-
)
|
|
515
|
-
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
345
|
+
|
|
516
346
|
if not self.supports_current_stack_frame(
|
|
517
347
|
tracker, False, False
|
|
518
348
|
) or self.should_abstain_in_coexistence(tracker, True):
|
|
519
349
|
return self._prediction(self._default_predictions(domain))
|
|
520
350
|
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
self._connect_vector_store_or_raise(endpoints)
|
|
527
|
-
except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
|
|
528
|
-
logger.error(f"{logger_key}.connection_error", error=e)
|
|
529
|
-
return self._create_prediction_internal_error(domain, tracker)
|
|
351
|
+
# retrieve documents from the latest message
|
|
352
|
+
# document retrieval happened earlier in the pipeline
|
|
353
|
+
if tracker.latest_message is None or tracker.latest_message.parse_data is None:
|
|
354
|
+
logger.info(f"{logger_key}.no_documents")
|
|
355
|
+
return self._create_prediction_cannot_handle(domain, tracker)
|
|
530
356
|
|
|
531
|
-
|
|
532
|
-
|
|
357
|
+
documents_data = tracker.latest_message.parse_data.get(
|
|
358
|
+
POST_PROCESSED_DOCUMENTS_KEY
|
|
533
359
|
)
|
|
534
|
-
tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
|
|
535
360
|
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
)
|
|
542
|
-
except InformationRetrievalException as e:
|
|
543
|
-
logger.error(f"{logger_key}.search_error", error=e)
|
|
544
|
-
return self._create_prediction_internal_error(domain, tracker)
|
|
361
|
+
if not documents_data:
|
|
362
|
+
logger.info(f"{logger_key}.no_documents")
|
|
363
|
+
return self._create_prediction_cannot_handle(domain, tracker)
|
|
364
|
+
|
|
365
|
+
documents = SearchResultList.from_dict(documents_data)
|
|
545
366
|
|
|
546
367
|
if not documents.results:
|
|
547
368
|
logger.info(f"{logger_key}.no_documents")
|
|
548
369
|
return self._create_prediction_cannot_handle(domain, tracker)
|
|
549
370
|
|
|
371
|
+
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
372
|
+
|
|
550
373
|
if self.use_llm:
|
|
551
374
|
prompt = self._render_prompt(tracker, documents.results)
|
|
552
|
-
llm_response = await self._generate_llm_answer(
|
|
553
|
-
llm, prompt, tracker.sender_id
|
|
554
|
-
)
|
|
375
|
+
llm_response = await self._generate_llm_answer(llm, prompt)
|
|
555
376
|
llm_response = LLMResponse.ensure_llm_response(llm_response)
|
|
556
377
|
|
|
557
378
|
self._add_prompt_and_llm_response_to_latest_message(
|
|
@@ -593,25 +414,29 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
593
414
|
result.text for result in documents.results
|
|
594
415
|
],
|
|
595
416
|
UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
|
|
596
|
-
SEARCH_QUERY_METADATA_KEY:
|
|
417
|
+
SEARCH_QUERY_METADATA_KEY: tracker.latest_message.parse_data.get(
|
|
418
|
+
SEARCH_QUERY_KEY
|
|
419
|
+
),
|
|
597
420
|
}
|
|
598
421
|
}
|
|
599
422
|
|
|
600
|
-
# telemetry call to track policy prediction
|
|
601
|
-
track_enterprise_search_policy_predict(
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
423
|
+
# # telemetry call to track policy prediction
|
|
424
|
+
# track_enterprise_search_policy_predict(
|
|
425
|
+
# vector_store_type=self.vector_store_config.get(
|
|
426
|
+
# VECTOR_STORE_TYPE_PROPERTY),
|
|
427
|
+
# embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
428
|
+
# embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
429
|
+
# or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
430
|
+
# embeddings_model_group_id=self.embeddings_config.get(
|
|
431
|
+
# MODEL_GROUP_ID_CONFIG_KEY
|
|
432
|
+
# ),
|
|
433
|
+
# llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
434
|
+
# llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
435
|
+
# or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
436
|
+
# llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
437
|
+
# citation_enabled=self.citation_enabled,
|
|
438
|
+
# )
|
|
439
|
+
|
|
615
440
|
return self._create_prediction(
|
|
616
441
|
domain=domain, tracker=tracker, action_metadata=action_metadata
|
|
617
442
|
)
|
|
@@ -647,26 +472,19 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
647
472
|
|
|
648
473
|
@measure_llm_latency
|
|
649
474
|
async def _generate_llm_answer(
|
|
650
|
-
self, llm: LLMClient, prompt: Text
|
|
475
|
+
self, llm: LLMClient, prompt: Text
|
|
651
476
|
) -> Optional[LLMResponse]:
|
|
652
477
|
"""Fetches an LLM completion for the provided prompt.
|
|
653
478
|
|
|
654
479
|
Args:
|
|
655
480
|
llm: The LLM client used to get the completion.
|
|
656
481
|
prompt: The prompt text to send to the model.
|
|
657
|
-
sender_id: sender_id from the tracker.
|
|
658
482
|
|
|
659
483
|
Returns:
|
|
660
484
|
An LLMResponse object, or None if the call fails.
|
|
661
485
|
"""
|
|
662
|
-
metadata = {
|
|
663
|
-
LANGFUSE_METADATA_USER_ID: self.user_id,
|
|
664
|
-
LANGFUSE_METADATA_SESSION_ID: sender_id,
|
|
665
|
-
LANGFUSE_CUSTOM_METADATA_DICT: {"component": self.__class__.__name__},
|
|
666
|
-
LANGFUSE_TAGS: [self.__class__.__name__],
|
|
667
|
-
}
|
|
668
486
|
try:
|
|
669
|
-
return await llm.acompletion(prompt
|
|
487
|
+
return await llm.acompletion(prompt)
|
|
670
488
|
except Exception as e:
|
|
671
489
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
672
490
|
# we have to catch all exceptions here
|
|
@@ -786,73 +604,19 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
786
604
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
787
605
|
)
|
|
788
606
|
|
|
789
|
-
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
790
|
-
VECTOR_STORE_TYPE_PROPERTY
|
|
791
|
-
)
|
|
792
|
-
|
|
793
|
-
embeddings = cls._create_plain_embedder(config)
|
|
794
|
-
|
|
795
607
|
logger.info("enterprise_search_policy.load", config=config)
|
|
796
|
-
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
797
|
-
# if a vector store is not specified,
|
|
798
|
-
# default to using FAISS with the index stored in the model
|
|
799
|
-
# TODO figure out a way to get path without context manager
|
|
800
|
-
with model_storage.read_from(resource) as path:
|
|
801
|
-
vector_store = FAISS_Store(
|
|
802
|
-
embeddings=embeddings,
|
|
803
|
-
index_path=path,
|
|
804
|
-
docs_folder=None,
|
|
805
|
-
create_index=False,
|
|
806
|
-
)
|
|
807
|
-
else:
|
|
808
|
-
vector_store = create_from_endpoint_config(
|
|
809
|
-
config_type=store_type,
|
|
810
|
-
embeddings=embeddings,
|
|
811
|
-
) # type: ignore
|
|
812
608
|
|
|
813
609
|
return cls(
|
|
814
610
|
config,
|
|
815
611
|
model_storage,
|
|
816
612
|
resource,
|
|
817
613
|
execution_context,
|
|
818
|
-
vector_store=vector_store,
|
|
819
614
|
prompt_template=prompt_template,
|
|
820
615
|
)
|
|
821
616
|
|
|
822
|
-
@classmethod
|
|
823
|
-
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
|
|
824
|
-
"""This is required only for local knowledge base types.
|
|
825
|
-
|
|
826
|
-
e.g. FAISS, to ensure that the graph component is retrained when the knowledge
|
|
827
|
-
base is updated.
|
|
828
|
-
"""
|
|
829
|
-
merged_config = {**cls.get_default_config(), **config}
|
|
830
|
-
|
|
831
|
-
store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
832
|
-
VECTOR_STORE_TYPE_PROPERTY
|
|
833
|
-
)
|
|
834
|
-
if store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
835
|
-
return None
|
|
836
|
-
|
|
837
|
-
source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
|
|
838
|
-
if not source:
|
|
839
|
-
return None
|
|
840
|
-
|
|
841
|
-
docs = FAISS_Store.load_documents(source)
|
|
842
|
-
|
|
843
|
-
if len(docs) == 0:
|
|
844
|
-
return None
|
|
845
|
-
|
|
846
|
-
docs_as_strings = [
|
|
847
|
-
json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
|
|
848
|
-
]
|
|
849
|
-
return sorted(docs_as_strings)
|
|
850
|
-
|
|
851
617
|
@classmethod
|
|
852
618
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
853
619
|
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
854
|
-
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
855
|
-
|
|
856
620
|
prompt_template = get_prompt_template(
|
|
857
621
|
config.get(PROMPT_CONFIG_KEY),
|
|
858
622
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
@@ -861,12 +625,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
861
625
|
llm_config = resolve_model_client_config(
|
|
862
626
|
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
863
627
|
)
|
|
864
|
-
|
|
865
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
866
|
-
)
|
|
867
|
-
return deep_container_fingerprint(
|
|
868
|
-
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
869
|
-
)
|
|
628
|
+
return deep_container_fingerprint([prompt_template, llm_config])
|
|
870
629
|
|
|
871
630
|
@staticmethod
|
|
872
631
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -971,14 +730,3 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
971
730
|
log_source_method,
|
|
972
731
|
EnterpriseSearchPolicy.__name__,
|
|
973
732
|
)
|
|
974
|
-
|
|
975
|
-
# Perform health check of the embeddings client config
|
|
976
|
-
embeddings_config = resolve_model_client_config(
|
|
977
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
978
|
-
)
|
|
979
|
-
cls.perform_embeddings_health_check(
|
|
980
|
-
embeddings_config,
|
|
981
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
982
|
-
log_source_method,
|
|
983
|
-
EnterpriseSearchPolicy.__name__,
|
|
984
|
-
)
|
|
@@ -23,7 +23,6 @@ from rasa.core.policies.flows.flow_step_result import (
|
|
|
23
23
|
)
|
|
24
24
|
from rasa.dialogue_understanding.commands import CancelFlowCommand
|
|
25
25
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
26
|
-
from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
|
|
27
26
|
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
28
27
|
CollectInformationPatternFlowStackFrame,
|
|
29
28
|
)
|
|
@@ -51,7 +50,6 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
|
|
|
51
50
|
)
|
|
52
51
|
from rasa.dialogue_understanding.stack.utils import (
|
|
53
52
|
top_user_flow_frame,
|
|
54
|
-
user_flows_on_the_stack,
|
|
55
53
|
)
|
|
56
54
|
from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
|
|
57
55
|
from rasa.shared.core.constants import (
|
|
@@ -280,33 +278,6 @@ def trigger_pattern_continue_interrupted(
|
|
|
280
278
|
return events
|
|
281
279
|
|
|
282
280
|
|
|
283
|
-
def trigger_pattern_clarification(
|
|
284
|
-
current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
|
|
285
|
-
) -> None:
|
|
286
|
-
"""Trigger the pattern to clarify which topic to continue if needed."""
|
|
287
|
-
if not isinstance(current_frame, UserFlowStackFrame):
|
|
288
|
-
return None
|
|
289
|
-
|
|
290
|
-
if current_frame.frame_type in [
|
|
291
|
-
FlowStackFrameType.CALL,
|
|
292
|
-
FlowStackFrameType.INTERRUPT,
|
|
293
|
-
]:
|
|
294
|
-
# we want to return to the flow that called
|
|
295
|
-
# the current flow or the flow that was interrupted
|
|
296
|
-
# by the current flow
|
|
297
|
-
return None
|
|
298
|
-
|
|
299
|
-
pending_flows = [
|
|
300
|
-
flows.flow_by_id(frame.flow_id)
|
|
301
|
-
for frame in stack.frames
|
|
302
|
-
if isinstance(frame, UserFlowStackFrame)
|
|
303
|
-
and frame.flow_id != current_frame.flow_id
|
|
304
|
-
]
|
|
305
|
-
|
|
306
|
-
flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
|
|
307
|
-
stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
|
|
308
|
-
|
|
309
|
-
|
|
310
281
|
def trigger_pattern_completed(
|
|
311
282
|
current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
|
|
312
283
|
) -> None:
|
|
@@ -675,15 +646,9 @@ def _run_end_step(
|
|
|
675
646
|
structlogger.debug("flow.step.run.flow_end")
|
|
676
647
|
current_frame = stack.pop()
|
|
677
648
|
trigger_pattern_completed(current_frame, stack, flows)
|
|
678
|
-
resumed_events =
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
# we need to trigger the pattern clarify
|
|
682
|
-
trigger_pattern_clarification(current_frame, stack, flows)
|
|
683
|
-
else:
|
|
684
|
-
resumed_events = trigger_pattern_continue_interrupted(
|
|
685
|
-
current_frame, stack, flows, tracker
|
|
686
|
-
)
|
|
649
|
+
resumed_events = trigger_pattern_continue_interrupted(
|
|
650
|
+
current_frame, stack, flows, tracker
|
|
651
|
+
)
|
|
687
652
|
reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
|
|
688
653
|
return ContinueFlowWithNextStep(
|
|
689
654
|
events=initial_events + reset_events + resumed_events, has_flow_ended=True
|