rasa-pro 3.13.0.dev9__py3-none-any.whl → 3.13.0.dev10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/export.py +2 -0
- rasa/core/exporter.py +36 -0
- rasa/core/policies/enterprise_search_policy.py +150 -236
- rasa/core/policies/enterprise_search_policy_config.py +242 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +78 -0
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +2 -2
- rasa/shared/core/flows/validation.py +9 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +2 -2
- rasa/shared/providers/_configs/default_litellm_client_config.py +1 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +1 -1
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/utils.py +0 -99
- rasa/shared/utils/configs.py +110 -0
- rasa/tracing/instrumentation/attribute_extractors.py +10 -5
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev10.dist-info}/METADATA +1 -1
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev10.dist-info}/RECORD +22 -19
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev10.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev10.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev10.dist-info}/entry_points.txt +0 -0
|
@@ -2,7 +2,7 @@ import dataclasses
|
|
|
2
2
|
import importlib.resources
|
|
3
3
|
import json
|
|
4
4
|
import re
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Dict, List,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
6
6
|
|
|
7
7
|
import dotenv
|
|
8
8
|
import structlog
|
|
@@ -12,9 +12,6 @@ from pydantic import ValidationError
|
|
|
12
12
|
import rasa.shared.utils.io
|
|
13
13
|
from rasa.core.available_endpoints import AvailableEndpoints
|
|
14
14
|
from rasa.core.constants import (
|
|
15
|
-
POLICY_MAX_HISTORY,
|
|
16
|
-
POLICY_PRIORITY,
|
|
17
|
-
SEARCH_POLICY_PRIORITY,
|
|
18
15
|
UTTER_SOURCE_METADATA_KEY,
|
|
19
16
|
)
|
|
20
17
|
from rasa.core.information_retrieval import (
|
|
@@ -24,6 +21,14 @@ from rasa.core.information_retrieval import (
|
|
|
24
21
|
create_from_endpoint_config,
|
|
25
22
|
)
|
|
26
23
|
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
24
|
+
from rasa.core.policies.enterprise_search_policy_config import (
|
|
25
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
26
|
+
DEFAULT_ENTERPRISE_SEARCH_CONFIG,
|
|
27
|
+
DEFAULT_LLM_CONFIG,
|
|
28
|
+
DEFAULT_VECTOR_STORE_TYPE,
|
|
29
|
+
SOURCE_PROPERTY,
|
|
30
|
+
EnterpriseSearchPolicyConfig,
|
|
31
|
+
)
|
|
27
32
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
28
33
|
from rasa.dialogue_understanding.generator.constants import (
|
|
29
34
|
LLM_CONFIG_KEY,
|
|
@@ -47,18 +52,11 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
47
52
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
48
53
|
from rasa.shared.constants import (
|
|
49
54
|
EMBEDDINGS_CONFIG_KEY,
|
|
50
|
-
MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
51
|
-
MAX_RETRIES_CONFIG_KEY,
|
|
52
55
|
MODEL_CONFIG_KEY,
|
|
53
56
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
54
57
|
MODEL_NAME_CONFIG_KEY,
|
|
55
|
-
OPENAI_PROVIDER,
|
|
56
|
-
PROMPT_CONFIG_KEY,
|
|
57
|
-
PROMPT_TEMPLATE_CONFIG_KEY,
|
|
58
58
|
PROVIDER_CONFIG_KEY,
|
|
59
59
|
RASA_PATTERN_CANNOT_HANDLE_NO_RELEVANT_ANSWER,
|
|
60
|
-
TEMPERATURE_CONFIG_KEY,
|
|
61
|
-
TIMEOUT_CONFIG_KEY,
|
|
62
60
|
)
|
|
63
61
|
from rasa.shared.core.constants import (
|
|
64
62
|
ACTION_CANCEL_FLOW,
|
|
@@ -93,13 +91,9 @@ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
|
93
91
|
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
94
92
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
95
93
|
from rasa.shared.utils.llm import (
|
|
96
|
-
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
97
|
-
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
98
|
-
check_prompt_config_keys_and_warn_if_deprecated,
|
|
99
94
|
embedder_factory,
|
|
100
95
|
get_prompt_template,
|
|
101
96
|
llm_factory,
|
|
102
|
-
resolve_model_client_config,
|
|
103
97
|
sanitize_message_for_prompt,
|
|
104
98
|
tracker_as_readable_transcript,
|
|
105
99
|
)
|
|
@@ -120,42 +114,6 @@ structlogger = structlog.get_logger()
|
|
|
120
114
|
|
|
121
115
|
dotenv.load_dotenv("./.env")
|
|
122
116
|
|
|
123
|
-
SOURCE_PROPERTY = "source"
|
|
124
|
-
VECTOR_STORE_TYPE_PROPERTY = "type"
|
|
125
|
-
VECTOR_STORE_PROPERTY = "vector_store"
|
|
126
|
-
VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
|
|
127
|
-
TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
|
|
128
|
-
CITATION_ENABLED_PROPERTY = "citation_enabled"
|
|
129
|
-
USE_LLM_PROPERTY = "use_generative_llm"
|
|
130
|
-
CHECK_RELEVANCY_PROPERTY = "check_relevancy"
|
|
131
|
-
MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
|
|
132
|
-
|
|
133
|
-
DEFAULT_VECTOR_STORE_TYPE = "faiss"
|
|
134
|
-
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
|
|
135
|
-
DEFAULT_VECTOR_STORE = {
|
|
136
|
-
VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
|
|
137
|
-
SOURCE_PROPERTY: "./docs",
|
|
138
|
-
VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
DEFAULT_CHECK_RELEVANCY_PROPERTY = False
|
|
142
|
-
DEFAULT_USE_LLM_PROPERTY = True
|
|
143
|
-
DEFAULT_CITATION_ENABLED_PROPERTY = False
|
|
144
|
-
|
|
145
|
-
DEFAULT_LLM_CONFIG = {
|
|
146
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
147
|
-
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
148
|
-
TIMEOUT_CONFIG_KEY: 10,
|
|
149
|
-
TEMPERATURE_CONFIG_KEY: 0.0,
|
|
150
|
-
MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
|
|
151
|
-
MAX_RETRIES_CONFIG_KEY: 1,
|
|
152
|
-
}
|
|
153
|
-
|
|
154
|
-
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
155
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
156
|
-
MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
157
|
-
}
|
|
158
|
-
|
|
159
117
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
160
118
|
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
161
119
|
|
|
@@ -228,10 +186,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
228
186
|
@staticmethod
|
|
229
187
|
def get_default_config() -> Dict[str, Any]:
|
|
230
188
|
"""Returns the default config of the policy."""
|
|
231
|
-
return
|
|
232
|
-
POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
|
|
233
|
-
VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
|
|
234
|
-
}
|
|
189
|
+
return DEFAULT_ENTERPRISE_SEARCH_CONFIG
|
|
235
190
|
|
|
236
191
|
def __init__(
|
|
237
192
|
self,
|
|
@@ -246,105 +201,71 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
246
201
|
"""Constructs a new Policy object."""
|
|
247
202
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
248
203
|
|
|
249
|
-
|
|
250
|
-
check_prompt_config_keys_and_warn_if_deprecated(
|
|
251
|
-
config, "enterprise_search_policy"
|
|
252
|
-
)
|
|
253
|
-
# Check for mutual exclusivity of extractive and generative search
|
|
254
|
-
self._check_and_warn_mutual_exclusivity_of_extractive_and_generative_search()
|
|
255
|
-
|
|
256
|
-
# Resolve LLM config
|
|
257
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
258
|
-
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
259
|
-
)
|
|
260
|
-
# Resolve embeddings config
|
|
261
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
262
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
263
|
-
)
|
|
204
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
264
205
|
|
|
265
206
|
# Vector store object and configuration
|
|
266
207
|
self.vector_store = vector_store
|
|
267
|
-
self.vector_store_config =
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
self.vector_search_threshold = self.vector_store_config.get(
|
|
271
|
-
VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
|
|
272
|
-
)
|
|
208
|
+
self.vector_store_config = parsed_config.vector_store_config
|
|
209
|
+
self.vector_search_threshold = parsed_config.vector_store_threshold
|
|
210
|
+
self.vector_store_type = parsed_config.vector_store_type
|
|
273
211
|
|
|
274
|
-
#
|
|
275
|
-
self.embeddings_config =
|
|
276
|
-
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
277
|
-
)
|
|
212
|
+
# Resolved embeddings configuration for encoding the search query
|
|
213
|
+
self.embeddings_config = parsed_config.embeddings_config
|
|
278
214
|
|
|
279
|
-
# LLM Configuration for response generation
|
|
280
|
-
self.llm_config =
|
|
215
|
+
# Resolved LLM Configuration for response generation
|
|
216
|
+
self.llm_config = parsed_config.llm_config
|
|
281
217
|
|
|
282
218
|
# Maximum number of turns to include in the prompt
|
|
283
|
-
self.max_history =
|
|
219
|
+
self.max_history = parsed_config.max_history
|
|
284
220
|
|
|
285
221
|
# Maximum number of messages to include in the search query
|
|
286
|
-
self.max_messages_in_query =
|
|
222
|
+
self.max_messages_in_query = parsed_config.max_messages_in_query
|
|
287
223
|
|
|
288
224
|
# Boolean to enable/disable tracing of prompt tokens
|
|
289
|
-
self.trace_prompt_tokens =
|
|
225
|
+
self.trace_prompt_tokens = parsed_config.trace_prompt_tokens
|
|
290
226
|
|
|
291
227
|
# Boolean to enable/disable the use of LLM for response generation
|
|
292
|
-
self.use_llm =
|
|
228
|
+
self.use_llm = parsed_config.use_generative_llm
|
|
293
229
|
|
|
294
230
|
# Boolean to enable/disable citation generation. This flag enables citation
|
|
295
231
|
# logic, but it only takes effect if `use_llm` is True.
|
|
296
|
-
self.citation_enabled =
|
|
297
|
-
CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
|
|
298
|
-
)
|
|
232
|
+
self.citation_enabled = parsed_config.enable_citation
|
|
299
233
|
|
|
300
234
|
# Boolean to enable/disable the use of relevancy check alongside answer
|
|
301
235
|
# generation. This flag enables that check, but it only takes effect if
|
|
302
236
|
# `use_llm` is True.
|
|
303
|
-
self.relevancy_check_enabled =
|
|
304
|
-
CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
|
|
305
|
-
)
|
|
237
|
+
self.relevancy_check_enabled = parsed_config.check_relevancy
|
|
306
238
|
|
|
307
239
|
# Resolve the prompt template. The prompt will only be used if the 'use_llm' is
|
|
308
240
|
# set to True.
|
|
309
|
-
self.prompt_template = prompt_template or
|
|
310
|
-
|
|
241
|
+
self.prompt_template = prompt_template or get_prompt_template(
|
|
242
|
+
jinja_file_path=parsed_config.prompt_template,
|
|
243
|
+
default_prompt_template=self._select_default_prompt_template_based_on_features(
|
|
244
|
+
parsed_config.check_relevancy, parsed_config.enable_citation
|
|
245
|
+
),
|
|
246
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
247
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
|
|
311
248
|
)
|
|
312
249
|
|
|
313
|
-
def _check_and_warn_mutual_exclusivity_of_extractive_and_generative_search(
|
|
314
|
-
self,
|
|
315
|
-
) -> None:
|
|
316
|
-
if self.config.get(
|
|
317
|
-
CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
|
|
318
|
-
) and not self.config.get(USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY):
|
|
319
|
-
structlogger.warning(
|
|
320
|
-
"enterprise_search_policy.init"
|
|
321
|
-
".relevancy_check_enabled_with_disabled_generative_search",
|
|
322
|
-
event_info=(
|
|
323
|
-
f"The config parameter '{CHECK_RELEVANCY_PROPERTY}' is set to"
|
|
324
|
-
f"'True', but the generative search is disabled (config"
|
|
325
|
-
f"parameter '{USE_LLM_PROPERTY}' is set to 'False'). As a result, "
|
|
326
|
-
"the relevancy check for the generative search will be disabled. "
|
|
327
|
-
f"To use this check, set the config parameter '{USE_LLM_PROPERTY}' "
|
|
328
|
-
f"to `True`."
|
|
329
|
-
),
|
|
330
|
-
)
|
|
331
|
-
|
|
332
250
|
@classmethod
|
|
333
|
-
def _create_plain_embedder(cls,
|
|
251
|
+
def _create_plain_embedder(cls, embeddings_config: Dict[Text, Any]) -> "Embeddings":
|
|
334
252
|
"""Creates an embedder based on the given configuration.
|
|
335
253
|
|
|
254
|
+
Args:
|
|
255
|
+
embeddings_config: A resolved embeddings configuration. Resolved means the
|
|
256
|
+
configuration is either:
|
|
257
|
+
- A reference to a model group that has already been expanded into
|
|
258
|
+
its corresponding configuration using the information from
|
|
259
|
+
`endpoints.yml`, or
|
|
260
|
+
- A full configuration for the embedder defined directly (i.e. not
|
|
261
|
+
relying on model groups or indirections).
|
|
262
|
+
|
|
336
263
|
Returns:
|
|
337
|
-
|
|
264
|
+
The embedder.
|
|
338
265
|
"""
|
|
339
266
|
# Copy the config so original config is not modified
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
343
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
344
|
-
)
|
|
345
|
-
client = embedder_factory(
|
|
346
|
-
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
347
|
-
)
|
|
267
|
+
embeddings_config = embeddings_config.copy()
|
|
268
|
+
client = embedder_factory(embeddings_config, DEFAULT_EMBEDDINGS_CONFIG)
|
|
348
269
|
# Wrap the embedding client in the adapter
|
|
349
270
|
return _LangchainEmbeddingClientAdapter(client)
|
|
350
271
|
|
|
@@ -410,16 +331,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
410
331
|
can load the policy from the resource.
|
|
411
332
|
"""
|
|
412
333
|
# Perform health checks for both LLM and embeddings client configs
|
|
413
|
-
self._perform_health_checks(
|
|
414
|
-
|
|
415
|
-
|
|
334
|
+
self._perform_health_checks(
|
|
335
|
+
self.llm_config, self.embeddings_config, "enterprise_search_policy.train"
|
|
336
|
+
)
|
|
416
337
|
|
|
417
338
|
# telemetry call to track training start
|
|
418
339
|
track_enterprise_search_policy_train_started()
|
|
419
340
|
|
|
420
341
|
# validate embedding configuration
|
|
421
342
|
try:
|
|
422
|
-
embeddings = self._create_plain_embedder(self.
|
|
343
|
+
embeddings = self._create_plain_embedder(self.embeddings_config)
|
|
423
344
|
except (ValidationError, Exception) as e:
|
|
424
345
|
structlogger.error(
|
|
425
346
|
"enterprise_search_policy.train.embedder_instantiation_failed",
|
|
@@ -431,7 +352,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
431
352
|
f"required environment variables. Error: {e}"
|
|
432
353
|
)
|
|
433
354
|
|
|
434
|
-
if
|
|
355
|
+
if self.vector_store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
435
356
|
structlogger.info("enterprise_search_policy.train.faiss")
|
|
436
357
|
with self._model_storage.write_to(self._resource) as path:
|
|
437
358
|
self.vector_store = FAISS_Store(
|
|
@@ -443,12 +364,13 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
443
364
|
)
|
|
444
365
|
else:
|
|
445
366
|
structlogger.info(
|
|
446
|
-
"enterprise_search_policy.train.custom",
|
|
367
|
+
"enterprise_search_policy.train.custom",
|
|
368
|
+
store_type=self.vector_store_type,
|
|
447
369
|
)
|
|
448
370
|
|
|
449
371
|
# telemetry call to track training completion
|
|
450
372
|
track_enterprise_search_policy_train_completed(
|
|
451
|
-
vector_store_type=
|
|
373
|
+
vector_store_type=self.vector_store_type,
|
|
452
374
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
453
375
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
454
376
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
@@ -471,8 +393,11 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
471
393
|
rasa.shared.utils.io.write_text_file(
|
|
472
394
|
self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
473
395
|
)
|
|
396
|
+
config = self.config.copy()
|
|
397
|
+
config[LLM_CONFIG_KEY] = self.llm_config
|
|
398
|
+
config[EMBEDDINGS_CONFIG_KEY] = self.embeddings_config
|
|
474
399
|
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
475
|
-
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME,
|
|
400
|
+
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, config
|
|
476
401
|
)
|
|
477
402
|
|
|
478
403
|
def _prepare_slots_for_template(
|
|
@@ -511,8 +436,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
511
436
|
endpoints: Endpoints configuration.
|
|
512
437
|
"""
|
|
513
438
|
config = endpoints.vector_store if endpoints else None
|
|
514
|
-
|
|
515
|
-
if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
439
|
+
if config is None and self.vector_store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
516
440
|
structlogger.error(
|
|
517
441
|
"enterprise_search_policy._connect_vector_store_or_raise.no_config"
|
|
518
442
|
)
|
|
@@ -673,7 +597,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
673
597
|
|
|
674
598
|
# telemetry call to track policy prediction
|
|
675
599
|
track_enterprise_search_policy_predict(
|
|
676
|
-
vector_store_type=self.
|
|
600
|
+
vector_store_type=self.vector_store_type,
|
|
677
601
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
678
602
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
679
603
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
@@ -732,7 +656,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
732
656
|
Returns:
|
|
733
657
|
An LLMResponse object, or None if the call fails.
|
|
734
658
|
"""
|
|
735
|
-
llm = llm_factory(self.
|
|
659
|
+
llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
|
|
736
660
|
try:
|
|
737
661
|
response = await llm.acompletion(prompt)
|
|
738
662
|
return LLMResponse.ensure_llm_response(response)
|
|
@@ -862,73 +786,88 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
862
786
|
**kwargs: Any,
|
|
863
787
|
) -> "EnterpriseSearchPolicy":
|
|
864
788
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
789
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
790
|
+
|
|
865
791
|
# Perform health checks for both LLM and embeddings client configs
|
|
866
|
-
cls._perform_health_checks(
|
|
792
|
+
cls._perform_health_checks(
|
|
793
|
+
parsed_config.llm_config,
|
|
794
|
+
parsed_config.embeddings_config,
|
|
795
|
+
"enterprise_search_policy.load",
|
|
796
|
+
)
|
|
867
797
|
|
|
868
|
-
prompt_template =
|
|
798
|
+
prompt_template = cls._load_prompt_template(model_storage, resource)
|
|
799
|
+
embeddings = cls._create_plain_embedder(parsed_config.embeddings_config)
|
|
800
|
+
vector_store = cls._load_vector_store(
|
|
801
|
+
embeddings,
|
|
802
|
+
parsed_config.vector_store_type,
|
|
803
|
+
parsed_config.use_generative_llm,
|
|
804
|
+
model_storage,
|
|
805
|
+
resource,
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
structlogger.info("enterprise_search_policy.load", config=config)
|
|
809
|
+
|
|
810
|
+
return cls(
|
|
811
|
+
config,
|
|
812
|
+
model_storage,
|
|
813
|
+
resource,
|
|
814
|
+
execution_context,
|
|
815
|
+
vector_store=vector_store,
|
|
816
|
+
prompt_template=prompt_template,
|
|
817
|
+
)
|
|
818
|
+
|
|
819
|
+
@classmethod
|
|
820
|
+
def _load_prompt_template(
|
|
821
|
+
cls, model_storage: ModelStorage, resource: Resource
|
|
822
|
+
) -> Optional[str]:
|
|
869
823
|
try:
|
|
870
824
|
with model_storage.read_from(resource) as path:
|
|
871
|
-
|
|
825
|
+
return rasa.shared.utils.io.read_file(
|
|
872
826
|
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
873
827
|
)
|
|
874
828
|
except (FileNotFoundError, FileIOException) as e:
|
|
875
829
|
structlogger.warning(
|
|
876
830
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
877
831
|
)
|
|
832
|
+
return None
|
|
878
833
|
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
834
|
+
@classmethod
|
|
835
|
+
def _load_vector_store(
|
|
836
|
+
cls,
|
|
837
|
+
embeddings: "Embeddings",
|
|
838
|
+
store_type: str,
|
|
839
|
+
use_generative_llm: bool,
|
|
840
|
+
model_storage: ModelStorage,
|
|
841
|
+
resource: Resource,
|
|
842
|
+
) -> InformationRetrieval:
|
|
886
843
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
887
844
|
# if a vector store is not specified,
|
|
888
845
|
# default to using FAISS with the index stored in the model
|
|
889
846
|
# TODO figure out a way to get path without context manager
|
|
890
847
|
with model_storage.read_from(resource) as path:
|
|
891
|
-
|
|
848
|
+
return FAISS_Store(
|
|
892
849
|
embeddings=embeddings,
|
|
893
850
|
index_path=path,
|
|
894
851
|
docs_folder=None,
|
|
895
852
|
create_index=False,
|
|
896
|
-
parse_as_faq_pairs=not
|
|
897
|
-
USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY
|
|
898
|
-
),
|
|
853
|
+
parse_as_faq_pairs=not use_generative_llm,
|
|
899
854
|
)
|
|
900
855
|
else:
|
|
901
|
-
|
|
856
|
+
return create_from_endpoint_config(
|
|
902
857
|
config_type=store_type,
|
|
903
858
|
embeddings=embeddings,
|
|
904
|
-
)
|
|
905
|
-
|
|
906
|
-
return cls(
|
|
907
|
-
config,
|
|
908
|
-
model_storage,
|
|
909
|
-
resource,
|
|
910
|
-
execution_context,
|
|
911
|
-
vector_store=vector_store,
|
|
912
|
-
prompt_template=prompt_template,
|
|
913
|
-
)
|
|
859
|
+
)
|
|
914
860
|
|
|
915
861
|
@classmethod
|
|
916
|
-
def _get_local_knowledge_data(
|
|
862
|
+
def _get_local_knowledge_data(
|
|
863
|
+
cls, store_type: str, source: Optional[str] = None
|
|
864
|
+
) -> Optional[List[str]]:
|
|
917
865
|
"""This is required only for local knowledge base types.
|
|
918
866
|
|
|
919
867
|
e.g. FAISS, to ensure that the graph component is retrained when the knowledge
|
|
920
868
|
base is updated.
|
|
921
869
|
"""
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
925
|
-
VECTOR_STORE_TYPE_PROPERTY
|
|
926
|
-
)
|
|
927
|
-
if store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
928
|
-
return None
|
|
929
|
-
|
|
930
|
-
source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
|
|
931
|
-
if not source:
|
|
870
|
+
if store_type != DEFAULT_VECTOR_STORE_TYPE or not source:
|
|
932
871
|
return None
|
|
933
872
|
|
|
934
873
|
docs = FAISS_Store.load_documents(source)
|
|
@@ -944,18 +883,28 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
944
883
|
@classmethod
|
|
945
884
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
946
885
|
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
947
|
-
|
|
948
|
-
config, LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON
|
|
949
|
-
)
|
|
950
|
-
|
|
951
|
-
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
886
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
952
887
|
|
|
953
|
-
|
|
954
|
-
|
|
888
|
+
# Resolve the prompt template
|
|
889
|
+
default_prompt_template = cls._select_default_prompt_template_based_on_features(
|
|
890
|
+
parsed_config.check_relevancy, parsed_config.enable_citation
|
|
955
891
|
)
|
|
956
|
-
|
|
957
|
-
|
|
892
|
+
prompt_template = get_prompt_template(
|
|
893
|
+
jinja_file_path=parsed_config.prompt_template,
|
|
894
|
+
default_prompt_template=default_prompt_template,
|
|
895
|
+
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
896
|
+
log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
|
|
897
|
+
)
|
|
898
|
+
|
|
899
|
+
# Fetch the local knowledge data in case FAISS is used
|
|
900
|
+
local_knowledge_data = cls._get_local_knowledge_data(
|
|
901
|
+
parsed_config.vector_store_type, parsed_config.vector_store_source
|
|
958
902
|
)
|
|
903
|
+
|
|
904
|
+
# Get the resolved LLM and embeddings configurations
|
|
905
|
+
llm_config = parsed_config.llm_config
|
|
906
|
+
embedding_config = parsed_config.embeddings_config
|
|
907
|
+
|
|
959
908
|
return deep_container_fingerprint(
|
|
960
909
|
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
961
910
|
)
|
|
@@ -1053,21 +1002,32 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
1053
1002
|
|
|
1054
1003
|
@classmethod
|
|
1055
1004
|
def _perform_health_checks(
|
|
1056
|
-
cls,
|
|
1005
|
+
cls,
|
|
1006
|
+
llm_config: Dict[Text, Any],
|
|
1007
|
+
embeddings_config: Dict[Text, Any],
|
|
1008
|
+
log_source_method: str,
|
|
1057
1009
|
) -> None:
|
|
1058
|
-
|
|
1059
|
-
|
|
1010
|
+
"""
|
|
1011
|
+
Perform the health checks using resolved LLM and embeddings configurations.
|
|
1012
|
+
Resolved means the configuration is either:
|
|
1013
|
+
- A reference to a model group that has already been expanded into
|
|
1014
|
+
its corresponding configuration using the information from
|
|
1015
|
+
`endpoints.yml`, or
|
|
1016
|
+
- A full configuration defined directly (i.e. not
|
|
1017
|
+
relying on model groups or indirections).
|
|
1018
|
+
|
|
1019
|
+
Args:
|
|
1020
|
+
llm_config: A resolved LLM configuration.
|
|
1021
|
+
embeddings_config: A resolved embeddings configuration.
|
|
1022
|
+
log_source_method: The method the health checks have been called from.
|
|
1023
|
+
|
|
1024
|
+
"""
|
|
1060
1025
|
cls.perform_llm_health_check(
|
|
1061
1026
|
llm_config,
|
|
1062
1027
|
DEFAULT_LLM_CONFIG,
|
|
1063
1028
|
log_source_method,
|
|
1064
1029
|
EnterpriseSearchPolicy.__name__,
|
|
1065
1030
|
)
|
|
1066
|
-
|
|
1067
|
-
# Perform health check of the embeddings client config
|
|
1068
|
-
embeddings_config = resolve_model_client_config(
|
|
1069
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
1070
|
-
)
|
|
1071
1031
|
cls.perform_embeddings_health_check(
|
|
1072
1032
|
embeddings_config,
|
|
1073
1033
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
@@ -1093,62 +1053,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
1093
1053
|
Returns:
|
|
1094
1054
|
The resolved jinja prompt template as a string.
|
|
1095
1055
|
"""
|
|
1096
|
-
|
|
1097
1056
|
# Get the feature flags
|
|
1098
|
-
|
|
1099
|
-
CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
|
|
1100
|
-
)
|
|
1101
|
-
relevancy_check_enabled = config.get(
|
|
1102
|
-
CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
|
|
1103
|
-
)
|
|
1104
|
-
|
|
1057
|
+
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
|
|
1105
1058
|
# Based on the enabled features (citation, relevancy check) fetch the
|
|
1106
1059
|
# appropriate default prompt
|
|
1107
1060
|
default_prompt = cls._select_default_prompt_template_based_on_features(
|
|
1108
|
-
|
|
1061
|
+
parsed_config.check_relevancy, parsed_config.enable_citation
|
|
1109
1062
|
)
|
|
1110
1063
|
|
|
1111
1064
|
return default_prompt
|
|
1112
1065
|
|
|
1113
|
-
@classmethod
|
|
1114
|
-
def _resolve_prompt_template(
|
|
1115
|
-
cls,
|
|
1116
|
-
config: dict,
|
|
1117
|
-
log_source_method: Literal["init", "fingerprint"],
|
|
1118
|
-
) -> str:
|
|
1119
|
-
"""
|
|
1120
|
-
Resolves the prompt template to use for the Enterprise Search Policy's
|
|
1121
|
-
generative search.
|
|
1122
|
-
|
|
1123
|
-
Checks if a custom template is provided via component's configuration. If not,
|
|
1124
|
-
it selects the appropriate default template based on the enabled features
|
|
1125
|
-
(citation and relevancy check).
|
|
1126
|
-
|
|
1127
|
-
Args:
|
|
1128
|
-
config: The component's configuration.
|
|
1129
|
-
log_source_method: The name of the method or function emitting the log for
|
|
1130
|
-
better traceability.
|
|
1131
|
-
Returns:
|
|
1132
|
-
The resolved jinja prompt template as a string.
|
|
1133
|
-
"""
|
|
1134
|
-
|
|
1135
|
-
# Read the template path from the configuration if available.
|
|
1136
|
-
# The deprecated 'prompt' has a lower priority compared to 'prompt_template'
|
|
1137
|
-
config_defined_prompt = (
|
|
1138
|
-
config.get(PROMPT_TEMPLATE_CONFIG_KEY)
|
|
1139
|
-
or config.get(PROMPT_CONFIG_KEY)
|
|
1140
|
-
or None
|
|
1141
|
-
)
|
|
1142
|
-
# Select the default prompt based on the features set in the config.
|
|
1143
|
-
default_prompt = cls.get_system_default_prompt_based_on_config(config)
|
|
1144
|
-
|
|
1145
|
-
return get_prompt_template(
|
|
1146
|
-
config_defined_prompt,
|
|
1147
|
-
default_prompt,
|
|
1148
|
-
log_source_component=EnterpriseSearchPolicy.__name__,
|
|
1149
|
-
log_source_method=log_source_method,
|
|
1150
|
-
)
|
|
1151
|
-
|
|
1152
1066
|
@classmethod
|
|
1153
1067
|
def _select_default_prompt_template_based_on_features(
|
|
1154
1068
|
cls,
|