rasa-pro 3.13.0.dev9__py3-none-any.whl → 3.13.0.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/export.py +2 -0
- rasa/cli/studio/download.py +3 -9
- rasa/cli/studio/link.py +1 -2
- rasa/cli/studio/pull.py +3 -2
- rasa/cli/studio/push.py +1 -1
- rasa/cli/studio/train.py +0 -1
- rasa/core/exporter.py +36 -0
- rasa/core/policies/enterprise_search_policy.py +151 -240
- rasa/core/policies/enterprise_search_policy_config.py +242 -0
- rasa/core/policies/enterprise_search_prompt_with_relevancy_check_and_citation_template.jinja2 +6 -5
- rasa/core/utils.py +11 -2
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/generator/command_generator.py +11 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +78 -0
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +2 -2
- rasa/dialogue_understanding/processor/command_processor.py +5 -5
- rasa/shared/core/flows/validation.py +9 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +2 -2
- rasa/shared/providers/_configs/default_litellm_client_config.py +1 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +1 -1
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/utils.py +0 -99
- rasa/shared/utils/configs.py +110 -0
- rasa/shared/utils/constants.py +0 -3
- rasa/shared/utils/pykwalify_extensions.py +0 -9
- rasa/studio/constants.py +1 -0
- rasa/studio/download.py +164 -0
- rasa/studio/link.py +1 -1
- rasa/studio/{download/flows.py → pull/data.py} +2 -131
- rasa/studio/{download → pull}/domains.py +1 -1
- rasa/studio/pull/pull.py +235 -0
- rasa/studio/push.py +5 -0
- rasa/studio/train.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +10 -5
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/METADATA +1 -1
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/RECORD +43 -40
- rasa/studio/download/download.py +0 -416
- rasa/studio/pull.py +0 -94
- /rasa/studio/{download → pull}/__init__.py +0 -0
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/entry_points.txt +0 -0
|
@@ -2,7 +2,7 @@ import dataclasses
|
|
|
2
2
|
import importlib.resources
|
|
3
3
|
import json
|
|
4
4
|
import re
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Dict, List,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
6
6
|
|
|
7
7
|
import dotenv
|
|
8
8
|
import structlog
|
|
@@ -12,9 +12,6 @@ from pydantic import ValidationError
|
|
|
12
12
|
import rasa.shared.utils.io
|
|
13
13
|
from rasa.core.available_endpoints import AvailableEndpoints
|
|
14
14
|
from rasa.core.constants import (
|
|
15
|
-
POLICY_MAX_HISTORY,
|
|
16
|
-
POLICY_PRIORITY,
|
|
17
|
-
SEARCH_POLICY_PRIORITY,
|
|
18
15
|
UTTER_SOURCE_METADATA_KEY,
|
|
19
16
|
)
|
|
20
17
|
from rasa.core.information_retrieval import (
|
|
@@ -24,6 +21,14 @@ from rasa.core.information_retrieval import (
|
|
|
24
21
|
create_from_endpoint_config,
|
|
25
22
|
)
|
|
26
23
|
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
24
|
+
from rasa.core.policies.enterprise_search_policy_config import (
|
|
25
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
26
|
+
DEFAULT_ENTERPRISE_SEARCH_CONFIG,
|
|
27
|
+
DEFAULT_LLM_CONFIG,
|
|
28
|
+
DEFAULT_VECTOR_STORE_TYPE,
|
|
29
|
+
SOURCE_PROPERTY,
|
|
30
|
+
EnterpriseSearchPolicyConfig,
|
|
31
|
+
)
|
|
27
32
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
28
33
|
from rasa.dialogue_understanding.generator.constants import (
|
|
29
34
|
LLM_CONFIG_KEY,
|
|
@@ -47,18 +52,11 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
47
52
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
48
53
|
from rasa.shared.constants import (
|
|
49
54
|
EMBEDDINGS_CONFIG_KEY,
|
|
50
|
-
MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
51
|
-
MAX_RETRIES_CONFIG_KEY,
|
|
52
55
|
MODEL_CONFIG_KEY,
|
|
53
56
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
54
57
|
MODEL_NAME_CONFIG_KEY,
|
|
55
|
-
OPENAI_PROVIDER,
|
|
56
|
-
PROMPT_CONFIG_KEY,
|
|
57
|
-
PROMPT_TEMPLATE_CONFIG_KEY,
|
|
58
58
|
PROVIDER_CONFIG_KEY,
|
|
59
59
|
RASA_PATTERN_CANNOT_HANDLE_NO_RELEVANT_ANSWER,
|
|
60
|
-
TEMPERATURE_CONFIG_KEY,
|
|
61
|
-
TIMEOUT_CONFIG_KEY,
|
|
62
60
|
)
|
|
63
61
|
from rasa.shared.core.constants import (
|
|
64
62
|
ACTION_CANCEL_FLOW,
|
|
@@ -93,13 +91,9 @@ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
|
93
91
|
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
94
92
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
95
93
|
from rasa.shared.utils.llm import (
|
|
96
|
-
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
97
|
-
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
98
|
-
check_prompt_config_keys_and_warn_if_deprecated,
|
|
99
94
|
embedder_factory,
|
|
100
95
|
get_prompt_template,
|
|
101
96
|
llm_factory,
|
|
102
|
-
resolve_model_client_config,
|
|
103
97
|
sanitize_message_for_prompt,
|
|
104
98
|
tracker_as_readable_transcript,
|
|
105
99
|
)
|
|
@@ -120,42 +114,6 @@ structlogger = structlog.get_logger()
|
|
|
120
114
|
|
|
121
115
|
dotenv.load_dotenv("./.env")
|
|
122
116
|
|
|
123
|
-
SOURCE_PROPERTY = "source"
|
|
124
|
-
VECTOR_STORE_TYPE_PROPERTY = "type"
|
|
125
|
-
VECTOR_STORE_PROPERTY = "vector_store"
|
|
126
|
-
VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
127
|
-
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
128
|
-
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
129
|
-
USE_LLM_PROPERTY = "use_generative_llm"
|
|
130
|
-
CHECK_RELEVANCY_PROPERTY = "check_relevancy"
|
|
131
|
-
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
132
|
-
|
|
133
|
-
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
134
|
-
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
135
|
-
DEFAULT_VECTOR_STORE = {
|
|
136
|
-
VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
|
|
137
|
-
SOURCE_PROPERTY: "./docs",
|
|
138
|
-
VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
DEFAULT_CHECK_RELEVANCY_PROPERTY = False
|
|
142
|
-
DEFAULT_USE_LLM_PROPERTY = True
|
|
143
|
-
DEFAULT_CITATION_ENABLED_PROPERTY = False
|
|
144
|
-
|
|
145
|
-
DEFAULT_LLM_CONFIG = {
|
|
146
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
147
|
-
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
148
|
-
TIMEOUT_CONFIG_KEY: 10,
|
|
149
|
-
TEMPERATURE_CONFIG_KEY: 0.0,
|
|
150
|
-
MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
|
|
151
|
-
MAX_RETRIES_CONFIG_KEY: 1,
|
|
152
|
-
}
|
|
153
|
-
|
|
154
|
-
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
155
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
156
|
-
MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
157
|
-
}
|
|
158
|
-
|
|
159
117
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
160
118
|
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
161
119
|
|
|
@@ -177,10 +135,7 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_RELEVANCY_CHECK_AND_CITATION_TEMPLATE = (
|
|
|
177
135
|
)
|
|
178
136
|
)
|
|
179
137
|
|
|
180
|
-
|
|
181
|
-
_ENTERPRISE_SEARCH_ANSWER_NOT_RELEVANT_PATTERN = re.compile(
|
|
182
|
-
r"\[NO_RELEVANT_ANSWER_FOUND\]"
|
|
183
|
-
)
|
|
138
|
+
_ENTERPRISE_SEARCH_ANSWER_NOT_RELEVANT_PATTERN = re.compile(r"\[NO_RAG_ANSWER\]")
|
|
184
139
|
|
|
185
140
|
|
|
186
141
|
class VectorStoreConnectionError(RasaException):
|
|
@@ -228,10 +183,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
228
183
|
@staticmethod
|
|
229
184
|
def get_default_config() -> Dict[str, Any]:
|
|
230
185
|
"""Returns the default config of the policy."""
|
|
231
|
-
return
|
|
232
|
-
POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
|
|
233
|
-
VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
|
|
234
|
-
}
|
|
186
|
+
return DEFAULT_ENTERPRISE_SEARCH_CONFIG
|
|
235
187
|
|
|
236
188
|
def __init__(
|
|
237
189
|
self,
|
|
@@ -246,105 +198,71 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
246
198
|
"""Constructs a new Policy object."""
|
|
247
199
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
248
200
|
|
|
249
|
-
|
|
250
|
-
check_prompt_config_keys_and_warn_if_deprecated(
|
|
251
|
-
config, "enterprise_search_policy"
|
|
252
|
-
)
|
|
253
|
-
# Check for mutual exclusivity of extractive and generative search
|
|
254
|
-
self._check_and_warn_mutual_exclusivity_of_extractive_and_generative_search()
|
|
255
|
-
|
|
256
|
-
# Resolve LLM config
|
|
257
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
258
|
-
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
259
|
-
)
|
|
260
|
-
# Resolve embeddings config
|
|
261
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
262
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
263
|
-
)
|
|
201
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
264
202
|
|
|
265
203
|
# Vector store object and configuration
|
|
266
204
|
self.vector_store = vector_store
|
|
267
|
-
self.vector_store_config =
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
self.vector_search_threshold = self.vector_store_config.get(
|
|
271
|
-
VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
|
|
272
|
-
)
|
|
205
|
+
self.vector_store_config = parsed_config.vector_store_config
|
|
206
|
+
self.vector_search_threshold = parsed_config.vector_store_threshold
|
|
207
|
+
self.vector_store_type = parsed_config.vector_store_type
|
|
273
208
|
|
|
274
|
-
#
|
|
275
|
-
self.embeddings_config =
|
|
276
|
-
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
277
|
-
)
|
|
209
|
+
# Resolved embeddings configuration for encoding the search query
|
|
210
|
+
self.embeddings_config = parsed_config.embeddings_config
|
|
278
211
|
|
|
279
|
-
# LLM Configuration for response generation
|
|
280
|
-
self.llm_config =
|
|
212
|
+
# Resolved LLM Configuration for response generation
|
|
213
|
+
self.llm_config = parsed_config.llm_config
|
|
281
214
|
|
|
282
215
|
# Maximum number of turns to include in the prompt
|
|
283
|
-
self.max_history =
|
|
216
|
+
self.max_history = parsed_config.max_history
|
|
284
217
|
|
|
285
218
|
# Maximum number of messages to include in the search query
|
|
286
|
-
self.max_messages_in_query =
|
|
219
|
+
self.max_messages_in_query = parsed_config.max_messages_in_query
|
|
287
220
|
|
|
288
221
|
# Boolean to enable/disable tracing of prompt tokens
|
|
289
|
-
self.trace_prompt_tokens =
|
|
222
|
+
self.trace_prompt_tokens = parsed_config.trace_prompt_tokens
|
|
290
223
|
|
|
291
224
|
# Boolean to enable/disable the use of LLM for response generation
|
|
292
|
-
self.use_llm =
|
|
225
|
+
self.use_llm = parsed_config.use_generative_llm
|
|
293
226
|
|
|
294
227
|
# Boolean to enable/disable citation generation. This flag enables citation
|
|
295
228
|
# logic, but it only takes effect if `use_llm` is True.
|
|
296
|
-
self.citation_enabled =
|
|
297
|
-
CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
|
|
298
|
-
)
|
|
229
|
+
self.citation_enabled = parsed_config.enable_citation
|
|
299
230
|
|
|
300
231
|
# Boolean to enable/disable the use of relevancy check alongside answer
|
|
301
232
|
# generation. This flag enables citation logic, but it only takes effect if
|
|
302
233
|
# `use_llm` is True.
|
|
303
|
-
self.relevancy_check_enabled =
|
|
304
|
-
CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
|
|
305
|
-
)
|
|
234
|
+
self.relevancy_check_enabled = parsed_config.check_relevancy
|
|
306
235
|
|
|
307
236
|
# Resolve the prompt template. The prompt will only be used if the 'use_llm' is
|
|
308
237
|
# set to True.
|
|
309
|
-
self.prompt_template = prompt_template or
|
|
310
|
-
|
|
238
|
+
self.prompt_template = prompt_template or get_prompt_template(
|
|
239
|
+
jinja_file_path=parsed_config.prompt_template,
|
|
240
|
+
default_prompt_template=self._select_default_prompt_template_based_on_features(
|
|
241
|
+
parsed_config.check_relevancy, parsed_config.enable_citation
|
|
242
|
+
),
|
|
243
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
244
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
311
245
|
)
|
|
312
246
|
|
|
313
|
-
def _check_and_warn_mutual_exclusivity_of_extractive_and_generative_search(
|
|
314
|
-
self,
|
|
315
|
-
) -> None:
|
|
316
|
-
if self.config.get(
|
|
317
|
-
CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
|
|
318
|
-
) and not self.config.get(USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY):
|
|
319
|
-
structlogger.warning(
|
|
320
|
-
"enterprise_search_policy.init"
|
|
321
|
-
".relevancy_check_enabled_with_disabled_generative_search",
|
|
322
|
-
event_info=(
|
|
323
|
-
f"The config parameter '{CHECK_RELEVANCY_PROPERTY}' is set to"
|
|
324
|
-
f"'True', but the generative search is disabled (config"
|
|
325
|
-
f"parameter '{USE_LLM_PROPERTY}' is set to 'False'). As a result, "
|
|
326
|
-
"the relevancy check for the generative search will be disabled. "
|
|
327
|
-
f"To use this check, set the config parameter '{USE_LLM_PROPERTY}' "
|
|
328
|
-
f"to `True`."
|
|
329
|
-
),
|
|
330
|
-
)
|
|
331
|
-
|
|
332
247
|
@classmethod
|
|
333
|
-
def _create_plain_embedder(cls,
|
|
248
|
+
def _create_plain_embedder(cls, embeddings_config: Dict[Text, Any]) -> "Embeddings":
|
|
334
249
|
"""Creates an embedder based on the given configuration.
|
|
335
250
|
|
|
251
|
+
Args:
|
|
252
|
+
embeddings_config: A resolved embeddings configuration. Resolved means the
|
|
253
|
+
configuration is either:
|
|
254
|
+
- A reference to a model group that has already been expanded into
|
|
255
|
+
its corresponding configuration using the information from
|
|
256
|
+
`endpoints.yml`, or
|
|
257
|
+
- A full configuration for the embedder defined directly (i.e. not
|
|
258
|
+
relying on model groups or indirections).
|
|
259
|
+
|
|
336
260
|
Returns:
|
|
337
|
-
|
|
261
|
+
The embedder.
|
|
338
262
|
"""
|
|
339
263
|
# Copy the config so original config is not modified
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
343
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
344
|
-
)
|
|
345
|
-
client = embedder_factory(
|
|
346
|
-
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
347
|
-
)
|
|
264
|
+
embeddings_config = embeddings_config.copy()
|
|
265
|
+
client = embedder_factory(embeddings_config, DEFAULT_EMBEDDINGS_CONFIG)
|
|
348
266
|
# Wrap the embedding client in the adapter
|
|
349
267
|
return _LangchainEmbeddingClientAdapter(client)
|
|
350
268
|
|
|
@@ -410,16 +328,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
410
328
|
can load the policy from the resource.
|
|
411
329
|
"""
|
|
412
330
|
# Perform health checks for both LLM and embeddings client configs
|
|
413
|
-
self._perform_health_checks(
|
|
414
|
-
|
|
415
|
-
|
|
331
|
+
self._perform_health_checks(
|
|
332
|
+
self.llm_config, self.embeddings_config, "enterprise_search_policy.train"
|
|
333
|
+
)
|
|
416
334
|
|
|
417
335
|
# telemetry call to track training start
|
|
418
336
|
track_enterprise_search_policy_train_started()
|
|
419
337
|
|
|
420
338
|
# validate embedding configuration
|
|
421
339
|
try:
|
|
422
|
-
embeddings = self._create_plain_embedder(self.
|
|
340
|
+
embeddings = self._create_plain_embedder(self.embeddings_config)
|
|
423
341
|
except (ValidationError, Exception) as e:
|
|
424
342
|
structlogger.error(
|
|
425
343
|
"enterprise_search_policy.train.embedder_instantiation_failed",
|
|
@@ -431,7 +349,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
431
349
|
f"required environment variables. Error: {e}"
|
|
432
350
|
)
|
|
433
351
|
|
|
434
|
-
if
|
|
352
|
+
if self.vector_store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
435
353
|
structlogger.info("enterprise_search_policy.train.faiss")
|
|
436
354
|
with self._model_storage.write_to(self._resource) as path:
|
|
437
355
|
self.vector_store = FAISS_Store(
|
|
@@ -443,12 +361,13 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
443
361
|
)
|
|
444
362
|
else:
|
|
445
363
|
structlogger.info(
|
|
446
|
-
"enterprise_search_policy.train.custom",
|
|
364
|
+
"enterprise_search_policy.train.custom",
|
|
365
|
+
store_type=self.vector_store_type,
|
|
447
366
|
)
|
|
448
367
|
|
|
449
368
|
# telemetry call to track training completion
|
|
450
369
|
track_enterprise_search_policy_train_completed(
|
|
451
|
-
vector_store_type=
|
|
370
|
+
vector_store_type=self.vector_store_type,
|
|
452
371
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
453
372
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
454
373
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
@@ -471,8 +390,11 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
471
390
|
rasa.shared.utils.io.write_text_file(
|
|
472
391
|
self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
473
392
|
)
|
|
393
|
+
config = self.config.copy()
|
|
394
|
+
config[LLM_CONFIG_KEY] = self.llm_config
|
|
395
|
+
config[EMBEDDINGS_CONFIG_KEY] = self.embeddings_config
|
|
474
396
|
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
475
|
-
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME,
|
|
397
|
+
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, config
|
|
476
398
|
)
|
|
477
399
|
|
|
478
400
|
def _prepare_slots_for_template(
|
|
@@ -511,8 +433,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
511
433
|
endpoints: Endpoints configuration.
|
|
512
434
|
"""
|
|
513
435
|
config = endpoints.vector_store if endpoints else None
|
|
514
|
-
|
|
515
|
-
if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
436
|
+
if config is None and self.vector_store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
516
437
|
structlogger.error(
|
|
517
438
|
"enterprise_search_policy._connect_vector_store_or_raise.no_config"
|
|
518
439
|
)
|
|
@@ -673,7 +594,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
673
594
|
|
|
674
595
|
# telemetry call to track policy prediction
|
|
675
596
|
track_enterprise_search_policy_predict(
|
|
676
|
-
vector_store_type=self.
|
|
597
|
+
vector_store_type=self.vector_store_type,
|
|
677
598
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
678
599
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
679
600
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
@@ -732,7 +653,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
732
653
|
Returns:
|
|
733
654
|
An LLMResponse object, or None if the call fails.
|
|
734
655
|
"""
|
|
735
|
-
llm = llm_factory(self.
|
|
656
|
+
llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
|
|
736
657
|
try:
|
|
737
658
|
response = await llm.acompletion(prompt)
|
|
738
659
|
return LLMResponse.ensure_llm_response(response)
|
|
@@ -862,73 +783,88 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
862
783
|
**kwargs: Any,
|
|
863
784
|
) -> "EnterpriseSearchPolicy":
|
|
864
785
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
786
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
787
|
+
|
|
865
788
|
# Perform health checks for both LLM and embeddings client configs
|
|
866
|
-
cls._perform_health_checks(
|
|
789
|
+
cls._perform_health_checks(
|
|
790
|
+
parsed_config.llm_config,
|
|
791
|
+
parsed_config.embeddings_config,
|
|
792
|
+
"enterprise_search_policy.load",
|
|
793
|
+
)
|
|
794
|
+
|
|
795
|
+
prompt_template = cls._load_prompt_template(model_storage, resource)
|
|
796
|
+
embeddings = cls._create_plain_embedder(parsed_config.embeddings_config)
|
|
797
|
+
vector_store = cls._load_vector_store(
|
|
798
|
+
embeddings,
|
|
799
|
+
parsed_config.vector_store_type,
|
|
800
|
+
parsed_config.use_generative_llm,
|
|
801
|
+
model_storage,
|
|
802
|
+
resource,
|
|
803
|
+
)
|
|
804
|
+
|
|
805
|
+
structlogger.info("enterprise_search_policy.load", config=config)
|
|
806
|
+
|
|
807
|
+
return cls(
|
|
808
|
+
config,
|
|
809
|
+
model_storage,
|
|
810
|
+
resource,
|
|
811
|
+
execution_context,
|
|
812
|
+
vector_store=vector_store,
|
|
813
|
+
prompt_template=prompt_template,
|
|
814
|
+
)
|
|
867
815
|
|
|
868
|
-
|
|
816
|
+
@classmethod
|
|
817
|
+
def _load_prompt_template(
|
|
818
|
+
cls, model_storage: ModelStorage, resource: Resource
|
|
819
|
+
) -> Optional[str]:
|
|
869
820
|
try:
|
|
870
821
|
with model_storage.read_from(resource) as path:
|
|
871
|
-
|
|
822
|
+
return rasa.shared.utils.io.read_file(
|
|
872
823
|
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
873
824
|
)
|
|
874
825
|
except (FileNotFoundError, FileIOException) as e:
|
|
875
826
|
structlogger.warning(
|
|
876
827
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
877
828
|
)
|
|
829
|
+
return None
|
|
878
830
|
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
831
|
+
@classmethod
|
|
832
|
+
def _load_vector_store(
|
|
833
|
+
cls,
|
|
834
|
+
embeddings: "Embeddings",
|
|
835
|
+
store_type: str,
|
|
836
|
+
use_generative_llm: bool,
|
|
837
|
+
model_storage: ModelStorage,
|
|
838
|
+
resource: Resource,
|
|
839
|
+
) -> InformationRetrieval:
|
|
886
840
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
887
841
|
# if a vector store is not specified,
|
|
888
842
|
# default to using FAISS with the index stored in the model
|
|
889
843
|
# TODO figure out a way to get path without context manager
|
|
890
844
|
with model_storage.read_from(resource) as path:
|
|
891
|
-
|
|
845
|
+
return FAISS_Store(
|
|
892
846
|
embeddings=embeddings,
|
|
893
847
|
index_path=path,
|
|
894
848
|
docs_folder=None,
|
|
895
849
|
create_index=False,
|
|
896
|
-
parse_as_faq_pairs=not
|
|
897
|
-
USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY
|
|
898
|
-
),
|
|
850
|
+
parse_as_faq_pairs=not use_generative_llm,
|
|
899
851
|
)
|
|
900
852
|
else:
|
|
901
|
-
|
|
853
|
+
return create_from_endpoint_config(
|
|
902
854
|
config_type=store_type,
|
|
903
855
|
embeddings=embeddings,
|
|
904
|
-
)
|
|
905
|
-
|
|
906
|
-
return cls(
|
|
907
|
-
config,
|
|
908
|
-
model_storage,
|
|
909
|
-
resource,
|
|
910
|
-
execution_context,
|
|
911
|
-
vector_store=vector_store,
|
|
912
|
-
prompt_template=prompt_template,
|
|
913
|
-
)
|
|
856
|
+
)
|
|
914
857
|
|
|
915
858
|
@classmethod
|
|
916
|
-
def _get_local_knowledge_data(
|
|
859
|
+
def _get_local_knowledge_data(
|
|
860
|
+
cls, store_type: str, source: Optional[str] = None
|
|
861
|
+
) -> Optional[List[str]]:
|
|
917
862
|
"""This is required only for local knowledge base types.
|
|
918
863
|
|
|
919
864
|
e.g. FAISS, to ensure that the graph component is retrained when the knowledge
|
|
920
865
|
base is updated.
|
|
921
866
|
"""
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
925
|
-
VECTOR_STORE_TYPE_PROPERTY
|
|
926
|
-
)
|
|
927
|
-
if store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
928
|
-
return None
|
|
929
|
-
|
|
930
|
-
source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
|
|
931
|
-
if not source:
|
|
867
|
+
if store_type != DEFAULT_VECTOR_STORE_TYPE or not source:
|
|
932
868
|
return None
|
|
933
869
|
|
|
934
870
|
docs = FAISS_Store.load_documents(source)
|
|
@@ -944,18 +880,28 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
944
880
|
@classmethod
|
|
945
881
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
946
882
|
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
947
|
-
|
|
948
|
-
config, LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON
|
|
949
|
-
)
|
|
950
|
-
|
|
951
|
-
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
883
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
952
884
|
|
|
953
|
-
|
|
954
|
-
|
|
885
|
+
# Resolve the prompt template
|
|
886
|
+
default_prompt_template = cls._select_default_prompt_template_based_on_features(
|
|
887
|
+
parsed_config.check_relevancy, parsed_config.enable_citation
|
|
955
888
|
)
|
|
956
|
-
|
|
957
|
-
|
|
889
|
+
prompt_template = get_prompt_template(
|
|
890
|
+
jinja_file_path=parsed_config.prompt_template,
|
|
891
|
+
default_prompt_template=default_prompt_template,
|
|
892
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
893
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
894
|
+
)
|
|
895
|
+
|
|
896
|
+
# Fetch the local knowledge data in case FAISS is used
|
|
897
|
+
local_knowledge_data = cls._get_local_knowledge_data(
|
|
898
|
+
parsed_config.vector_store_type, parsed_config.vector_store_source
|
|
958
899
|
)
|
|
900
|
+
|
|
901
|
+
# Get the resolved LLM and embeddings configurations
|
|
902
|
+
llm_config = parsed_config.llm_config
|
|
903
|
+
embedding_config = parsed_config.embeddings_config
|
|
904
|
+
|
|
959
905
|
return deep_container_fingerprint(
|
|
960
906
|
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
961
907
|
)
|
|
@@ -1053,21 +999,32 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
1053
999
|
|
|
1054
1000
|
@classmethod
|
|
1055
1001
|
def _perform_health_checks(
|
|
1056
|
-
cls,
|
|
1002
|
+
cls,
|
|
1003
|
+
llm_config: Dict[Text, Any],
|
|
1004
|
+
embeddings_config: Dict[Text, Any],
|
|
1005
|
+
log_source_method: str,
|
|
1057
1006
|
) -> None:
|
|
1058
|
-
|
|
1059
|
-
|
|
1007
|
+
"""
|
|
1008
|
+
Perform the health checks using resolved LLM and embeddings configurations.
|
|
1009
|
+
Resolved means the configuration is either:
|
|
1010
|
+
- A reference to a model group that has already been expanded into
|
|
1011
|
+
its corresponding configuration using the information from
|
|
1012
|
+
`endpoints.yml`, or
|
|
1013
|
+
- A full configuration for the embedder defined directly (i.e. not
|
|
1014
|
+
relying on model groups or indirections).
|
|
1015
|
+
|
|
1016
|
+
Args:
|
|
1017
|
+
llm_config: A resolved LLM configuration.
|
|
1018
|
+
embeddings_config: A resolved embeddings configuration.
|
|
1019
|
+
log_source_method: The method health checks has been called from.
|
|
1020
|
+
|
|
1021
|
+
"""
|
|
1060
1022
|
cls.perform_llm_health_check(
|
|
1061
1023
|
llm_config,
|
|
1062
1024
|
DEFAULT_LLM_CONFIG,
|
|
1063
1025
|
log_source_method,
|
|
1064
1026
|
EnterpriseSearchPolicy.__name__,
|
|
1065
1027
|
)
|
|
1066
|
-
|
|
1067
|
-
# Perform health check of the embeddings client config
|
|
1068
|
-
embeddings_config = resolve_model_client_config(
|
|
1069
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
1070
|
-
)
|
|
1071
1028
|
cls.perform_embeddings_health_check(
|
|
1072
1029
|
embeddings_config,
|
|
1073
1030
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
@@ -1093,62 +1050,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
1093
1050
|
Returns:
|
|
1094
1051
|
The resolved jinja prompt template as a string.
|
|
1095
1052
|
"""
|
|
1096
|
-
|
|
1097
1053
|
# Get the feature flags
|
|
1098
|
-
|
|
1099
|
-
CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
|
|
1100
|
-
)
|
|
1101
|
-
relevancy_check_enabled = config.get(
|
|
1102
|
-
CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
|
|
1103
|
-
)
|
|
1104
|
-
|
|
1054
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
1105
1055
|
# Based on the enabled features (citation, relevancy check) fetch the
|
|
1106
1056
|
# appropriate default prompt
|
|
1107
1057
|
default_prompt = cls._select_default_prompt_template_based_on_features(
|
|
1108
|
-
|
|
1058
|
+
parsed_config.check_relevancy, parsed_config.enable_citation
|
|
1109
1059
|
)
|
|
1110
1060
|
|
|
1111
1061
|
return default_prompt
|
|
1112
1062
|
|
|
1113
|
-
@classmethod
|
|
1114
|
-
def _resolve_prompt_template(
|
|
1115
|
-
cls,
|
|
1116
|
-
config: dict,
|
|
1117
|
-
log_source_method: Literal["init", "fingerprint"],
|
|
1118
|
-
) -> str:
|
|
1119
|
-
"""
|
|
1120
|
-
Resolves the prompt template to use for the Enterprise Search Policy's
|
|
1121
|
-
generative search.
|
|
1122
|
-
|
|
1123
|
-
Checks if a custom template is provided via component's configuration. If not,
|
|
1124
|
-
it selects the appropriate default template based on the enabled features
|
|
1125
|
-
(citation and relevancy check).
|
|
1126
|
-
|
|
1127
|
-
Args:
|
|
1128
|
-
config: The component's configuration.
|
|
1129
|
-
log_source_method: The name of the method or function emitting the log for
|
|
1130
|
-
better traceability.
|
|
1131
|
-
Returns:
|
|
1132
|
-
The resolved jinja prompt template as a string.
|
|
1133
|
-
"""
|
|
1134
|
-
|
|
1135
|
-
# Read the template path from the configuration if available.
|
|
1136
|
-
# The deprecated 'prompt' has a lower priority compared to 'prompt_template'
|
|
1137
|
-
config_defined_prompt = (
|
|
1138
|
-
config.get(PROMPT_TEMPLATE_CONFIG_KEY)
|
|
1139
|
-
or config.get(PROMPT_CONFIG_KEY)
|
|
1140
|
-
or None
|
|
1141
|
-
)
|
|
1142
|
-
# Select the default prompt based on the features set in the config.
|
|
1143
|
-
default_prompt = cls.get_system_default_prompt_based_on_config(config)
|
|
1144
|
-
|
|
1145
|
-
return get_prompt_template(
|
|
1146
|
-
config_defined_prompt,
|
|
1147
|
-
default_prompt,
|
|
1148
|
-
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
1149
|
-
log_source_method=log_source_method,
|
|
1150
|
-
)
|
|
1151
|
-
|
|
1152
1063
|
@classmethod
|
|
1153
1064
|
def _select_default_prompt_template_based_on_features(
|
|
1154
1065
|
cls,
|