rasa-pro 3.13.0.dev2__py3-none-any.whl → 3.13.0.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/run.py +10 -6
- rasa/cli/utils.py +7 -0
- rasa/core/channels/channel.py +30 -0
- rasa/core/channels/voice_ready/jambonz.py +25 -5
- rasa/core/channels/voice_ready/jambonz_protocol.py +4 -0
- rasa/core/information_retrieval/faiss.py +7 -68
- rasa/core/information_retrieval/information_retrieval.py +2 -40
- rasa/core/information_retrieval/milvus.py +2 -7
- rasa/core/information_retrieval/qdrant.py +2 -7
- rasa/core/nlg/contextual_response_rephraser.py +3 -0
- rasa/core/policies/enterprise_search_policy.py +310 -61
- rasa/core/policies/intentless_policy.py +3 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/flow_retrieval.py +1 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +13 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +2 -24
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +22 -17
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +27 -12
- rasa/dialogue_understanding_test/io.py +8 -13
- rasa/e2e_test/utils/validation.py +3 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +3 -0
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/llm.py +70 -24
- rasa/tracing/instrumentation/attribute_extractors.py +7 -10
- rasa/tracing/instrumentation/instrumentation.py +12 -12
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/METADATA +2 -2
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/RECORD +37 -48
- rasa/document_retrieval/__init__.py +0 -0
- rasa/document_retrieval/constants.py +0 -32
- rasa/document_retrieval/document_post_processor.py +0 -351
- rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
- rasa/document_retrieval/document_retriever.py +0 -333
- rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
- rasa/document_retrieval/knowledge_base_connectors/api_connector.py +0 -39
- rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +0 -34
- rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +0 -226
- rasa/document_retrieval/query_rewriter.py +0 -234
- rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +0 -8
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev3.dist-info}/entry_points.txt +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
|
+
import json
|
|
2
3
|
import re
|
|
3
4
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
4
5
|
|
|
5
6
|
import dotenv
|
|
6
7
|
import structlog
|
|
7
8
|
from jinja2 import Template
|
|
9
|
+
from pydantic import ValidationError
|
|
8
10
|
|
|
9
11
|
import rasa.shared.utils.io
|
|
10
12
|
from rasa.core.constants import (
|
|
@@ -14,9 +16,12 @@ from rasa.core.constants import (
|
|
|
14
16
|
UTTER_SOURCE_METADATA_KEY,
|
|
15
17
|
)
|
|
16
18
|
from rasa.core.information_retrieval import (
|
|
19
|
+
InformationRetrieval,
|
|
20
|
+
InformationRetrievalException,
|
|
17
21
|
SearchResult,
|
|
18
|
-
|
|
22
|
+
create_from_endpoint_config,
|
|
19
23
|
)
|
|
24
|
+
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
20
25
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
21
26
|
from rasa.core.utils import AvailableEndpoints
|
|
22
27
|
from rasa.dialogue_understanding.generator.constants import (
|
|
@@ -33,10 +38,6 @@ from rasa.dialogue_understanding.stack.frames import (
|
|
|
33
38
|
PatternFlowStackFrame,
|
|
34
39
|
SearchStackFrame,
|
|
35
40
|
)
|
|
36
|
-
from rasa.document_retrieval.constants import (
|
|
37
|
-
POST_PROCESSED_DOCUMENTS_KEY,
|
|
38
|
-
SEARCH_QUERY_KEY,
|
|
39
|
-
)
|
|
40
41
|
from rasa.engine.graph import ExecutionContext
|
|
41
42
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
42
43
|
from rasa.engine.storage.resource import Resource
|
|
@@ -44,7 +45,10 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
44
45
|
from rasa.graph_components.providers.forms_provider import Forms
|
|
45
46
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
46
47
|
from rasa.shared.constants import (
|
|
48
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
47
49
|
MODEL_CONFIG_KEY,
|
|
50
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
51
|
+
MODEL_NAME_CONFIG_KEY,
|
|
48
52
|
OPENAI_PROVIDER,
|
|
49
53
|
PROMPT_CONFIG_KEY,
|
|
50
54
|
PROVIDER_CONFIG_KEY,
|
|
@@ -56,10 +60,10 @@ from rasa.shared.core.constants import (
|
|
|
56
60
|
DEFAULT_SLOT_NAMES,
|
|
57
61
|
)
|
|
58
62
|
from rasa.shared.core.domain import Domain
|
|
59
|
-
from rasa.shared.core.events import Event
|
|
63
|
+
from rasa.shared.core.events import BotUttered, Event, UserUttered
|
|
60
64
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
61
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
62
|
-
from rasa.shared.exceptions import FileIOException
|
|
65
|
+
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
66
|
+
from rasa.shared.exceptions import FileIOException, RasaException
|
|
63
67
|
from rasa.shared.nlu.constants import (
|
|
64
68
|
KEY_COMPONENT_NAME,
|
|
65
69
|
KEY_LLM_RESPONSE_METADATA,
|
|
@@ -68,8 +72,16 @@ from rasa.shared.nlu.constants import (
|
|
|
68
72
|
PROMPTS,
|
|
69
73
|
)
|
|
70
74
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
75
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
76
|
+
_LangchainEmbeddingClientAdapter,
|
|
77
|
+
)
|
|
71
78
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
72
79
|
from rasa.shared.providers.llm.llm_response import LLMResponse, measure_llm_latency
|
|
80
|
+
from rasa.shared.utils.cli import print_error_and_exit
|
|
81
|
+
from rasa.shared.utils.constants import (
|
|
82
|
+
LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
83
|
+
LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
84
|
+
)
|
|
73
85
|
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
74
86
|
EmbeddingsHealthCheckMixin,
|
|
75
87
|
)
|
|
@@ -77,13 +89,23 @@ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheck
|
|
|
77
89
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
78
90
|
from rasa.shared.utils.llm import (
|
|
79
91
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
92
|
+
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
93
|
+
embedder_factory,
|
|
80
94
|
get_prompt_template,
|
|
81
95
|
llm_factory,
|
|
82
96
|
resolve_model_client_config,
|
|
97
|
+
sanitize_message_for_prompt,
|
|
83
98
|
tracker_as_readable_transcript,
|
|
84
99
|
)
|
|
100
|
+
from rasa.telemetry import (
|
|
101
|
+
track_enterprise_search_policy_predict,
|
|
102
|
+
track_enterprise_search_policy_train_completed,
|
|
103
|
+
track_enterprise_search_policy_train_started,
|
|
104
|
+
)
|
|
85
105
|
|
|
86
106
|
if TYPE_CHECKING:
|
|
107
|
+
from langchain.schema.embeddings import Embeddings
|
|
108
|
+
|
|
87
109
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
88
110
|
|
|
89
111
|
from rasa.utils.log_utils import log_llm
|
|
@@ -92,11 +114,22 @@ logger = structlog.get_logger()
|
|
|
92
114
|
|
|
93
115
|
dotenv.load_dotenv("./.env")
|
|
94
116
|
|
|
117
|
+
SOURCE_PROPERTY = "source"
|
|
118
|
+
VECTOR_STORE_TYPE_PROPERTY = "type"
|
|
119
|
+
VECTOR_STORE_PROPERTY = "vector_store"
|
|
120
|
+
VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
95
121
|
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
96
122
|
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
97
123
|
USE_LLM_PROPERTY = "use_generative_llm"
|
|
98
124
|
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
99
125
|
|
|
126
|
+
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
127
|
+
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
128
|
+
DEFAULT_VECTOR_STORE = {
|
|
129
|
+
VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
|
|
130
|
+
SOURCE_PROPERTY: "./docs",
|
|
131
|
+
VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
|
|
132
|
+
}
|
|
100
133
|
|
|
101
134
|
DEFAULT_LLM_CONFIG = {
|
|
102
135
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
@@ -107,6 +140,11 @@ DEFAULT_LLM_CONFIG = {
|
|
|
107
140
|
"max_retries": 1,
|
|
108
141
|
}
|
|
109
142
|
|
|
143
|
+
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
144
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
145
|
+
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
146
|
+
}
|
|
147
|
+
|
|
110
148
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
111
149
|
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
112
150
|
|
|
@@ -122,6 +160,14 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
|
|
|
122
160
|
)
|
|
123
161
|
|
|
124
162
|
|
|
163
|
+
class VectorStoreConnectionError(RasaException):
|
|
164
|
+
"""Exception raised for errors in connecting to the vector store."""
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class VectorStoreConfigurationError(RasaException):
|
|
168
|
+
"""Exception raised for errors in vector store configuration."""
|
|
169
|
+
|
|
170
|
+
|
|
125
171
|
@DefaultV1Recipe.register(
|
|
126
172
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
127
173
|
)
|
|
@@ -155,6 +201,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
155
201
|
"""Returns the default config of the policy."""
|
|
156
202
|
return {
|
|
157
203
|
POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
|
|
204
|
+
VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
|
|
158
205
|
}
|
|
159
206
|
|
|
160
207
|
def __init__(
|
|
@@ -163,6 +210,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
163
210
|
model_storage: ModelStorage,
|
|
164
211
|
resource: Resource,
|
|
165
212
|
execution_context: ExecutionContext,
|
|
213
|
+
vector_store: Optional[InformationRetrieval] = None,
|
|
166
214
|
featurizer: Optional["TrackerFeaturizer"] = None,
|
|
167
215
|
prompt_template: Optional[Text] = None,
|
|
168
216
|
) -> None:
|
|
@@ -173,6 +221,21 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
173
221
|
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
174
222
|
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
175
223
|
)
|
|
224
|
+
# Resolve embeddings config
|
|
225
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
226
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
# Vector store object and configuration
|
|
230
|
+
self.vector_store = vector_store
|
|
231
|
+
self.vector_store_config = self.config.get(
|
|
232
|
+
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
# Embeddings configuration for encoding the search query
|
|
236
|
+
self.embeddings_config = (
|
|
237
|
+
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
238
|
+
)
|
|
176
239
|
|
|
177
240
|
# LLM Configuration for response generation
|
|
178
241
|
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
@@ -180,6 +243,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
180
243
|
# Maximum number of turns to include in the prompt
|
|
181
244
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
182
245
|
|
|
246
|
+
# Maximum number of messages to include in the search query
|
|
247
|
+
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
248
|
+
|
|
183
249
|
# boolean to enable/disable tracing of prompt tokens
|
|
184
250
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
185
251
|
|
|
@@ -192,15 +258,38 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
192
258
|
self.prompt_template = prompt_template or get_prompt_template(
|
|
193
259
|
self.config.get(PROMPT_CONFIG_KEY),
|
|
194
260
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
261
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
262
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
195
263
|
)
|
|
196
264
|
self.citation_prompt_template = get_prompt_template(
|
|
197
265
|
self.config.get(PROMPT_CONFIG_KEY),
|
|
198
266
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE,
|
|
267
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
268
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
199
269
|
)
|
|
200
270
|
# If citation is enabled, use the citation prompt template
|
|
201
271
|
if self.citation_enabled:
|
|
202
272
|
self.prompt_template = self.citation_prompt_template
|
|
203
273
|
|
|
274
|
+
@classmethod
|
|
275
|
+
def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
|
|
276
|
+
"""Creates an embedder based on the given configuration.
|
|
277
|
+
|
|
278
|
+
Returns:
|
|
279
|
+
The embedder.
|
|
280
|
+
"""
|
|
281
|
+
# Copy the config so original config is not modified
|
|
282
|
+
config = config.copy()
|
|
283
|
+
# Resolve config and instantiate the embedding client
|
|
284
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
285
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
286
|
+
)
|
|
287
|
+
client = embedder_factory(
|
|
288
|
+
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
289
|
+
)
|
|
290
|
+
# Wrap the embedding client in the adapter
|
|
291
|
+
return _LangchainEmbeddingClientAdapter(client)
|
|
292
|
+
|
|
204
293
|
@classmethod
|
|
205
294
|
def _add_prompt_and_llm_response_to_latest_message(
|
|
206
295
|
cls,
|
|
@@ -265,24 +354,52 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
265
354
|
# Perform health checks for both LLM and embeddings client configs
|
|
266
355
|
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
267
356
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
# vector_store_type=store_type,
|
|
273
|
-
# embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
274
|
-
# embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
275
|
-
# or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
276
|
-
# embeddings_model_group_id=self.embeddings_config.get(
|
|
277
|
-
# MODEL_GROUP_ID_CONFIG_KEY
|
|
278
|
-
# ),
|
|
279
|
-
# llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
280
|
-
# llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
281
|
-
# or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
282
|
-
# llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
283
|
-
# citation_enabled=self.citation_enabled,
|
|
284
|
-
# )
|
|
357
|
+
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
358
|
+
|
|
359
|
+
# telemetry call to track training start
|
|
360
|
+
track_enterprise_search_policy_train_started()
|
|
285
361
|
|
|
362
|
+
# validate embedding configuration
|
|
363
|
+
try:
|
|
364
|
+
embeddings = self._create_plain_embedder(self.config)
|
|
365
|
+
except (ValidationError, Exception) as e:
|
|
366
|
+
logger.error(
|
|
367
|
+
"enterprise_search_policy.train.embedder_instantiation_failed",
|
|
368
|
+
message="Unable to instantiate the embedding client.",
|
|
369
|
+
error=e,
|
|
370
|
+
)
|
|
371
|
+
print_error_and_exit(
|
|
372
|
+
"Unable to create embedder. Please make sure you specified the "
|
|
373
|
+
f"required environment variables. Error: {e}"
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
377
|
+
logger.info("enterprise_search_policy.train.faiss")
|
|
378
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
379
|
+
self.vector_store = FAISS_Store(
|
|
380
|
+
docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
|
|
381
|
+
embeddings=embeddings,
|
|
382
|
+
index_path=path,
|
|
383
|
+
create_index=True,
|
|
384
|
+
)
|
|
385
|
+
else:
|
|
386
|
+
logger.info("enterprise_search_policy.train.custom", store_type=store_type)
|
|
387
|
+
|
|
388
|
+
# telemetry call to track training completion
|
|
389
|
+
track_enterprise_search_policy_train_completed(
|
|
390
|
+
vector_store_type=store_type,
|
|
391
|
+
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
392
|
+
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
393
|
+
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
394
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
395
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
396
|
+
),
|
|
397
|
+
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
398
|
+
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
399
|
+
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
400
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
401
|
+
citation_enabled=self.citation_enabled,
|
|
402
|
+
)
|
|
286
403
|
self.persist()
|
|
287
404
|
return self._resource
|
|
288
405
|
|
|
@@ -319,6 +436,60 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
319
436
|
)
|
|
320
437
|
return template_slots
|
|
321
438
|
|
|
439
|
+
def _connect_vector_store_or_raise(
|
|
440
|
+
self, endpoints: Optional[AvailableEndpoints]
|
|
441
|
+
) -> None:
|
|
442
|
+
"""Connects to the vector store or raises an exception.
|
|
443
|
+
|
|
444
|
+
Raise exceptions for the following cases:
|
|
445
|
+
- The configuration is not specified
|
|
446
|
+
- Unable to connect to the vector store
|
|
447
|
+
|
|
448
|
+
Args:
|
|
449
|
+
endpoints: Endpoints configuration.
|
|
450
|
+
"""
|
|
451
|
+
config = endpoints.vector_store if endpoints else None
|
|
452
|
+
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
453
|
+
if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
454
|
+
logger.error(
|
|
455
|
+
"enterprise_search_policy._connect_vector_store_or_raise.no_config"
|
|
456
|
+
)
|
|
457
|
+
raise VectorStoreConfigurationError(
|
|
458
|
+
"""No vector store specified. Please specify a vector
|
|
459
|
+
store in the endpoints configuration"""
|
|
460
|
+
)
|
|
461
|
+
try:
|
|
462
|
+
self.vector_store.connect(config) # type: ignore
|
|
463
|
+
except Exception as e:
|
|
464
|
+
logger.error(
|
|
465
|
+
"enterprise_search_policy._connect_vector_store_or_raise.connect_error",
|
|
466
|
+
error=e,
|
|
467
|
+
config=config,
|
|
468
|
+
)
|
|
469
|
+
raise VectorStoreConnectionError(
|
|
470
|
+
f"Unable to connect to the vector store. Error: {e}"
|
|
471
|
+
)
|
|
472
|
+
|
|
473
|
+
def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
|
|
474
|
+
"""Prepares the search query.
|
|
475
|
+
The search query is the last N messages in the conversation history.
|
|
476
|
+
|
|
477
|
+
Args:
|
|
478
|
+
tracker: The tracker containing the conversation history up to now.
|
|
479
|
+
history: The number of messages to include in the search query.
|
|
480
|
+
|
|
481
|
+
Returns:
|
|
482
|
+
The search query.
|
|
483
|
+
"""
|
|
484
|
+
transcript = []
|
|
485
|
+
for event in tracker.applied_events():
|
|
486
|
+
if isinstance(event, UserUttered) or isinstance(event, BotUttered):
|
|
487
|
+
transcript.append(sanitize_message_for_prompt(event.text))
|
|
488
|
+
|
|
489
|
+
search_query = " ".join(transcript[-history:][::-1])
|
|
490
|
+
logger.debug("search_query", search_query=search_query)
|
|
491
|
+
return search_query
|
|
492
|
+
|
|
322
493
|
async def predict_action_probabilities( # type: ignore[override]
|
|
323
494
|
self,
|
|
324
495
|
tracker: DialogueStateTracker,
|
|
@@ -342,34 +513,44 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
342
513
|
The prediction.
|
|
343
514
|
"""
|
|
344
515
|
logger_key = "enterprise_search_policy.predict_action_probabilities"
|
|
345
|
-
|
|
516
|
+
vector_search_threshold = self.vector_store_config.get(
|
|
517
|
+
VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
|
|
518
|
+
)
|
|
519
|
+
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
346
520
|
if not self.supports_current_stack_frame(
|
|
347
521
|
tracker, False, False
|
|
348
522
|
) or self.should_abstain_in_coexistence(tracker, True):
|
|
349
523
|
return self._prediction(self._default_predictions(domain))
|
|
350
524
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
logger.info(f"{logger_key}.no_documents")
|
|
355
|
-
return self._create_prediction_cannot_handle(domain, tracker)
|
|
525
|
+
if not self.vector_store:
|
|
526
|
+
logger.error(f"{logger_key}.no_vector_store")
|
|
527
|
+
return self._create_prediction_internal_error(domain, tracker)
|
|
356
528
|
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
)
|
|
529
|
+
try:
|
|
530
|
+
self._connect_vector_store_or_raise(endpoints)
|
|
531
|
+
except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
|
|
532
|
+
logger.error(f"{logger_key}.connection_error", error=e)
|
|
533
|
+
return self._create_prediction_internal_error(domain, tracker)
|
|
360
534
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
535
|
+
search_query = self._prepare_search_query(
|
|
536
|
+
tracker, int(self.max_messages_in_query)
|
|
537
|
+
)
|
|
538
|
+
tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
|
|
364
539
|
|
|
365
|
-
|
|
540
|
+
try:
|
|
541
|
+
documents = await self.vector_store.search(
|
|
542
|
+
query=search_query,
|
|
543
|
+
tracker_state=tracker_state,
|
|
544
|
+
threshold=vector_search_threshold,
|
|
545
|
+
)
|
|
546
|
+
except InformationRetrievalException as e:
|
|
547
|
+
logger.error(f"{logger_key}.search_error", error=e)
|
|
548
|
+
return self._create_prediction_internal_error(domain, tracker)
|
|
366
549
|
|
|
367
550
|
if not documents.results:
|
|
368
551
|
logger.info(f"{logger_key}.no_documents")
|
|
369
552
|
return self._create_prediction_cannot_handle(domain, tracker)
|
|
370
553
|
|
|
371
|
-
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
372
|
-
|
|
373
554
|
if self.use_llm:
|
|
374
555
|
prompt = self._render_prompt(tracker, documents.results)
|
|
375
556
|
llm_response = await self._generate_llm_answer(llm, prompt)
|
|
@@ -414,29 +595,25 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
414
595
|
result.text for result in documents.results
|
|
415
596
|
],
|
|
416
597
|
UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
|
|
417
|
-
SEARCH_QUERY_METADATA_KEY:
|
|
418
|
-
SEARCH_QUERY_KEY
|
|
419
|
-
),
|
|
598
|
+
SEARCH_QUERY_METADATA_KEY: search_query,
|
|
420
599
|
}
|
|
421
600
|
}
|
|
422
601
|
|
|
423
|
-
#
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
# )
|
|
439
|
-
|
|
602
|
+
# telemetry call to track policy prediction
|
|
603
|
+
track_enterprise_search_policy_predict(
|
|
604
|
+
vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
|
|
605
|
+
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
606
|
+
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
607
|
+
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
608
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
609
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
610
|
+
),
|
|
611
|
+
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
612
|
+
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
613
|
+
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
614
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
615
|
+
citation_enabled=self.citation_enabled,
|
|
616
|
+
)
|
|
440
617
|
return self._create_prediction(
|
|
441
618
|
domain=domain, tracker=tracker, action_metadata=action_metadata
|
|
442
619
|
)
|
|
@@ -604,28 +781,89 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
604
781
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
605
782
|
)
|
|
606
783
|
|
|
784
|
+
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
785
|
+
VECTOR_STORE_TYPE_PROPERTY
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
embeddings = cls._create_plain_embedder(config)
|
|
789
|
+
|
|
607
790
|
logger.info("enterprise_search_policy.load", config=config)
|
|
791
|
+
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
792
|
+
# if a vector store is not specified,
|
|
793
|
+
# default to using FAISS with the index stored in the model
|
|
794
|
+
# TODO figure out a way to get path without context manager
|
|
795
|
+
with model_storage.read_from(resource) as path:
|
|
796
|
+
vector_store = FAISS_Store(
|
|
797
|
+
embeddings=embeddings,
|
|
798
|
+
index_path=path,
|
|
799
|
+
docs_folder=None,
|
|
800
|
+
create_index=False,
|
|
801
|
+
)
|
|
802
|
+
else:
|
|
803
|
+
vector_store = create_from_endpoint_config(
|
|
804
|
+
config_type=store_type,
|
|
805
|
+
embeddings=embeddings,
|
|
806
|
+
) # type: ignore
|
|
608
807
|
|
|
609
808
|
return cls(
|
|
610
809
|
config,
|
|
611
810
|
model_storage,
|
|
612
811
|
resource,
|
|
613
812
|
execution_context,
|
|
813
|
+
vector_store=vector_store,
|
|
614
814
|
prompt_template=prompt_template,
|
|
615
815
|
)
|
|
616
816
|
|
|
817
|
+
@classmethod
|
|
818
|
+
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
|
|
819
|
+
"""This is required only for local knowledge base types.
|
|
820
|
+
|
|
821
|
+
e.g. FAISS, to ensure that the graph component is retrained when the knowledge
|
|
822
|
+
base is updated.
|
|
823
|
+
"""
|
|
824
|
+
merged_config = {**cls.get_default_config(), **config}
|
|
825
|
+
|
|
826
|
+
store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
827
|
+
VECTOR_STORE_TYPE_PROPERTY
|
|
828
|
+
)
|
|
829
|
+
if store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
830
|
+
return None
|
|
831
|
+
|
|
832
|
+
source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
|
|
833
|
+
if not source:
|
|
834
|
+
return None
|
|
835
|
+
|
|
836
|
+
docs = FAISS_Store.load_documents(source)
|
|
837
|
+
|
|
838
|
+
if len(docs) == 0:
|
|
839
|
+
return None
|
|
840
|
+
|
|
841
|
+
docs_as_strings = [
|
|
842
|
+
json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
|
|
843
|
+
]
|
|
844
|
+
return sorted(docs_as_strings)
|
|
845
|
+
|
|
617
846
|
@classmethod
|
|
618
847
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
619
848
|
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
849
|
+
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
850
|
+
|
|
620
851
|
prompt_template = get_prompt_template(
|
|
621
852
|
config.get(PROMPT_CONFIG_KEY),
|
|
622
853
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
854
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
855
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
623
856
|
)
|
|
624
857
|
|
|
625
858
|
llm_config = resolve_model_client_config(
|
|
626
859
|
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
627
860
|
)
|
|
628
|
-
|
|
861
|
+
embedding_config = resolve_model_client_config(
|
|
862
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
863
|
+
)
|
|
864
|
+
return deep_container_fingerprint(
|
|
865
|
+
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
866
|
+
)
|
|
629
867
|
|
|
630
868
|
@staticmethod
|
|
631
869
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -730,3 +968,14 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
730
968
|
log_source_method,
|
|
731
969
|
EnterpriseSearchPolicy.__name__,
|
|
732
970
|
)
|
|
971
|
+
|
|
972
|
+
# Perform health check of the embeddings client config
|
|
973
|
+
embeddings_config = resolve_model_client_config(
|
|
974
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
975
|
+
)
|
|
976
|
+
cls.perform_embeddings_health_check(
|
|
977
|
+
embeddings_config,
|
|
978
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
979
|
+
log_source_method,
|
|
980
|
+
EnterpriseSearchPolicy.__name__,
|
|
981
|
+
)
|
|
@@ -58,6 +58,7 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
58
58
|
_LangchainEmbeddingClientAdapter,
|
|
59
59
|
)
|
|
60
60
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
61
|
+
from rasa.shared.utils.constants import LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON
|
|
61
62
|
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
62
63
|
EmbeddingsHealthCheckMixin,
|
|
63
64
|
)
|
|
@@ -939,6 +940,8 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
939
940
|
prompt_template = get_prompt_template(
|
|
940
941
|
config.get(PROMPT_CONFIG_KEY),
|
|
941
942
|
DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
|
|
943
|
+
log_source_component=IntentlessPolicy.__name__,
|
|
944
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
942
945
|
)
|
|
943
946
|
|
|
944
947
|
llm_config = resolve_model_client_config(
|
|
@@ -35,6 +35,10 @@ from rasa.shared.exceptions import FileIOException, InvalidConfigException
|
|
|
35
35
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
36
36
|
from rasa.shared.nlu.training_data.message import Message
|
|
37
37
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
38
|
+
from rasa.shared.utils.constants import (
|
|
39
|
+
LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
40
|
+
LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
41
|
+
)
|
|
38
42
|
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
39
43
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
40
44
|
from rasa.shared.utils.llm import (
|
|
@@ -107,6 +111,8 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
107
111
|
or get_prompt_template(
|
|
108
112
|
config.get(PROMPT_CONFIG_KEY),
|
|
109
113
|
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
114
|
+
log_source_component=LLMBasedRouter.__name__,
|
|
115
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
110
116
|
).strip()
|
|
111
117
|
)
|
|
112
118
|
|
|
@@ -318,6 +324,8 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
318
324
|
prompt_template = get_prompt_template(
|
|
319
325
|
config.get(PROMPT_CONFIG_KEY),
|
|
320
326
|
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
327
|
+
log_source_component=LLMBasedRouter.__name__,
|
|
328
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
321
329
|
)
|
|
322
330
|
|
|
323
331
|
llm_config = resolve_model_client_config(
|
|
@@ -65,7 +65,7 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
|
|
|
65
65
|
"""Converts the command to a DSL string."""
|
|
66
66
|
mapper = {
|
|
67
67
|
CommandSyntaxVersion.v1: "SearchAndReply()",
|
|
68
|
-
CommandSyntaxVersion.v2: "
|
|
68
|
+
CommandSyntaxVersion.v2: "provide info",
|
|
69
69
|
}
|
|
70
70
|
return mapper.get(
|
|
71
71
|
CommandSyntaxManager.get_syntax_version(),
|
|
@@ -81,7 +81,7 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
|
|
|
81
81
|
def regex_pattern() -> str:
|
|
82
82
|
mapper = {
|
|
83
83
|
CommandSyntaxVersion.v1: r"SearchAndReply\(\)",
|
|
84
|
-
CommandSyntaxVersion.v2: r"""^[\s\W\d]*
|
|
84
|
+
CommandSyntaxVersion.v2: r"""^[\s\W\d]*provide info['"`]*$""",
|
|
85
85
|
}
|
|
86
86
|
return mapper.get(
|
|
87
87
|
CommandSyntaxManager.get_syntax_version(),
|
|
@@ -169,7 +169,7 @@ def _parse_standard_commands(
|
|
|
169
169
|
commands: List[Command] = []
|
|
170
170
|
for command_clz in standard_commands:
|
|
171
171
|
pattern = _get_compiled_pattern(command_clz.regex_pattern())
|
|
172
|
-
if match := pattern.search(action
|
|
172
|
+
if match := pattern.search(action):
|
|
173
173
|
parsed_command = command_clz.from_dsl(match, **kwargs)
|
|
174
174
|
if _additional_parsing_fn := _get_additional_parsing_logic(command_clz):
|
|
175
175
|
parsed_command = _additional_parsing_fn(parsed_command, flows, **kwargs)
|
|
@@ -52,7 +52,6 @@ from rasa.shared.utils.llm import (
|
|
|
52
52
|
USER,
|
|
53
53
|
allowed_values_for_slot,
|
|
54
54
|
embedder_factory,
|
|
55
|
-
get_prompt_template,
|
|
56
55
|
resolve_model_client_config,
|
|
57
56
|
tracker_as_readable_transcript,
|
|
58
57
|
)
|
|
@@ -103,9 +102,7 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
103
102
|
self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
104
103
|
)
|
|
105
104
|
self.vector_store: Optional[FAISS] = None
|
|
106
|
-
self.flow_document_template =
|
|
107
|
-
None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
108
|
-
)
|
|
105
|
+
self.flow_document_template = DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
109
106
|
self._model_storage = model_storage
|
|
110
107
|
self._resource = resource
|
|
111
108
|
|