rasa-pro 3.13.0.dev9__py3-none-any.whl → 3.13.0.dev10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
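
At a glance, this release refactors the Enterprise Search policy so that configuration handling moves into a new rasa.core.policies.enterprise_search_policy_config module. The module-level constants and the default LLM, embeddings, and vector-store configs are deleted from the policy file and imported from the new module instead, and the policy parses its raw config once via EnterpriseSearchPolicyConfig.from_dict() rather than reading the raw dict throughout. Health checks and the embedder factory now receive pre-resolved configs, load() is split into _load_prompt_template() and _load_vector_store() helpers, persist() writes the resolved LLM/embeddings configs into config.json, and the private _resolve_prompt_template() helper and the extractive/generative mutual-exclusivity warning are removed from this file (presumably absorbed by the new config module's parsing and validation).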


@@ -2,7 +2,7 @@ import dataclasses
 import importlib.resources
 import json
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Text
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
 
 import dotenv
 import structlog
@@ -12,9 +12,6 @@ from pydantic import ValidationError
 import rasa.shared.utils.io
 from rasa.core.available_endpoints import AvailableEndpoints
 from rasa.core.constants import (
-    POLICY_MAX_HISTORY,
-    POLICY_PRIORITY,
-    SEARCH_POLICY_PRIORITY,
     UTTER_SOURCE_METADATA_KEY,
 )
 from rasa.core.information_retrieval import (
@@ -24,6 +21,14 @@ from rasa.core.information_retrieval import (
     create_from_endpoint_config,
 )
 from rasa.core.information_retrieval.faiss import FAISS_Store
+from rasa.core.policies.enterprise_search_policy_config import (
+    DEFAULT_EMBEDDINGS_CONFIG,
+    DEFAULT_ENTERPRISE_SEARCH_CONFIG,
+    DEFAULT_LLM_CONFIG,
+    DEFAULT_VECTOR_STORE_TYPE,
+    SOURCE_PROPERTY,
+    EnterpriseSearchPolicyConfig,
+)
 from rasa.core.policies.policy import Policy, PolicyPrediction
 from rasa.dialogue_understanding.generator.constants import (
     LLM_CONFIG_KEY,
@@ -47,18 +52,11 @@ from rasa.graph_components.providers.forms_provider import Forms
 from rasa.graph_components.providers.responses_provider import Responses
 from rasa.shared.constants import (
     EMBEDDINGS_CONFIG_KEY,
-    MAX_COMPLETION_TOKENS_CONFIG_KEY,
-    MAX_RETRIES_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     MODEL_GROUP_ID_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
-    OPENAI_PROVIDER,
-    PROMPT_CONFIG_KEY,
-    PROMPT_TEMPLATE_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
     RASA_PATTERN_CANNOT_HANDLE_NO_RELEVANT_ANSWER,
-    TEMPERATURE_CONFIG_KEY,
-    TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.core.constants import (
     ACTION_CANCEL_FLOW,
@@ -93,13 +91,9 @@ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
 from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
 from rasa.shared.utils.io import deep_container_fingerprint
 from rasa.shared.utils.llm import (
-    DEFAULT_OPENAI_CHAT_MODEL_NAME,
-    DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
-    check_prompt_config_keys_and_warn_if_deprecated,
     embedder_factory,
     get_prompt_template,
     llm_factory,
-    resolve_model_client_config,
     sanitize_message_for_prompt,
     tracker_as_readable_transcript,
 )
@@ -120,42 +114,6 @@ structlogger = structlog.get_logger()
 
 dotenv.load_dotenv("./.env")
 
-SOURCE_PROPERTY = "source"
-VECTOR_STORE_TYPE_PROPERTY = "type"
-VECTOR_STORE_PROPERTY = "vector_store"
-VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
-TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
-CITATION_ENABLED_PROPERTY = "citation_enabled"
-USE_LLM_PROPERTY = "use_generative_llm"
-CHECK_RELEVANCY_PROPERTY = "check_relevancy"
-MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
-
-DEFAULT_VECTOR_STORE_TYPE = "faiss"
-DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
-DEFAULT_VECTOR_STORE = {
-    VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
-    SOURCE_PROPERTY: "./docs",
-    VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
-}
-
-DEFAULT_CHECK_RELEVANCY_PROPERTY = False
-DEFAULT_USE_LLM_PROPERTY = True
-DEFAULT_CITATION_ENABLED_PROPERTY = False
-
-DEFAULT_LLM_CONFIG = {
-    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
-    MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
-    TIMEOUT_CONFIG_KEY: 10,
-    TEMPERATURE_CONFIG_KEY: 0.0,
-    MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
-    MAX_RETRIES_CONFIG_KEY: 1,
-}
-
-DEFAULT_EMBEDDINGS_CONFIG = {
-    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
-    MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
-}
-
 ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
 ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
 
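
Note: the constants deleted in this hunk now ship in the rasa.core.policies.enterprise_search_policy_config module imported earlier in the diff. That module's source is not part of this diff; the following is only a hypothetical sketch of the config object it exposes, with attribute names inferred from how parsed_config is used in the hunks below.

# Hypothetical sketch -- the real EnterpriseSearchPolicyConfig is not shown in
# this diff; attributes are inferred from the usages visible below.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class EnterpriseSearchPolicyConfig:
    vector_store_config: Dict[str, Any]  # raw "vector_store" section of the config
    vector_store_type: str               # e.g. "faiss"
    vector_store_source: Optional[str]   # e.g. "./docs" for the FAISS default
    vector_store_threshold: float
    llm_config: Dict[str, Any]           # resolved against endpoints.yml model groups
    embeddings_config: Dict[str, Any]    # resolved likewise
    max_history: Optional[int]
    max_messages_in_query: int
    trace_prompt_tokens: bool
    use_generative_llm: bool
    enable_citation: bool
    check_relevancy: bool
    prompt_template: Optional[str]       # path to a custom jinja2 template, if any

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "EnterpriseSearchPolicyConfig":
        ...  # parse the raw component config and apply the relocated defaults
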
@@ -228,10 +186,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
     @staticmethod
     def get_default_config() -> Dict[str, Any]:
         """Returns the default config of the policy."""
-        return {
-            POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
-            VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
-        }
+        return DEFAULT_ENTERPRISE_SEARCH_CONFIG
 
     def __init__(
         self,
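
With the defaults relocated, get_default_config() collapses to a single constant. Judging from the removed return value, DEFAULT_ENTERPRISE_SEARCH_CONFIG presumably carries the same content as before; an illustrative expansion (based solely on the code removed in this hunk, not on the new module's source):

# Presumed expansion of the relocated constant -- illustrative only.
from rasa.core.constants import POLICY_PRIORITY, SEARCH_POLICY_PRIORITY

DEFAULT_ENTERPRISE_SEARCH_CONFIG = {
    POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
    "vector_store": {  # the old VECTOR_STORE_PROPERTY / DEFAULT_VECTOR_STORE
        "type": "faiss",
        "source": "./docs",
        "threshold": 0.0,
    },
}
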
@@ -246,105 +201,71 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
         """Constructs a new Policy object."""
         super().__init__(config, model_storage, resource, execution_context, featurizer)
 
-        # Check for deprecated keys and issue a warning if those are used
-        check_prompt_config_keys_and_warn_if_deprecated(
-            config, "enterprise_search_policy"
-        )
-        # Check for mutual exclusivity of extractive and generative search
-        self._check_and_warn_mutual_exclusivity_of_extractive_and_generative_search()
-
-        # Resolve LLM config
-        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
-            self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
-        )
-        # Resolve embeddings config
-        self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
-            self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
-        )
+        parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
 
         # Vector store object and configuration
         self.vector_store = vector_store
-        self.vector_store_config = self.config.get(
-            VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
-        )
-        self.vector_search_threshold = self.vector_store_config.get(
-            VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
-        )
+        self.vector_store_config = parsed_config.vector_store_config
+        self.vector_search_threshold = parsed_config.vector_store_threshold
+        self.vector_store_type = parsed_config.vector_store_type
 
-        # Embeddings configuration for encoding the search query
-        self.embeddings_config = (
-            self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
-        )
+        # Resolved embeddings configuration for encoding the search query
+        self.embeddings_config = parsed_config.embeddings_config
 
-        # LLM Configuration for response generation
-        self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
+        # Resolved LLM Configuration for response generation
+        self.llm_config = parsed_config.llm_config
 
         # Maximum number of turns to include in the prompt
-        self.max_history = self.config.get(POLICY_MAX_HISTORY)
+        self.max_history = parsed_config.max_history
 
         # Maximum number of messages to include in the search query
-        self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
+        self.max_messages_in_query = parsed_config.max_messages_in_query
 
         # Boolean to enable/disable tracing of prompt tokens
-        self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
+        self.trace_prompt_tokens = parsed_config.trace_prompt_tokens
 
         # Boolean to enable/disable the use of LLM for response generation
-        self.use_llm = self.config.get(USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY)
+        self.use_llm = parsed_config.use_generative_llm
 
         # Boolean to enable/disable citation generation. This flag enables citation
         # logic, but it only takes effect if `use_llm` is True.
-        self.citation_enabled = self.config.get(
-            CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
-        )
+        self.citation_enabled = parsed_config.enable_citation
 
         # Boolean to enable/disable the use of relevancy check alongside answer
         # generation. This flag enables citation logic, but it only takes effect if
         # `use_llm` is True.
-        self.relevancy_check_enabled = self.config.get(
-            CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
-        )
+        self.relevancy_check_enabled = parsed_config.check_relevancy
 
         # Resolve the prompt template. The prompt will only be used if the 'use_llm' is
         # set to True.
-        self.prompt_template = prompt_template or self._resolve_prompt_template(
-            self.config, LOG_COMPONENT_SOURCE_METHOD_INIT
+        self.prompt_template = prompt_template or get_prompt_template(
+            jinja_file_path=parsed_config.prompt_template,
+            default_prompt_template=self._select_default_prompt_template_based_on_features(
+                parsed_config.check_relevancy, parsed_config.enable_citation
+            ),
+            log_source_component=EnterpriseSearchPolicy.__name__,
+            log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
         )
 
-    def _check_and_warn_mutual_exclusivity_of_extractive_and_generative_search(
-        self,
-    ) -> None:
-        if self.config.get(
-            CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
-        ) and not self.config.get(USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY):
-            structlogger.warning(
-                "enterprise_search_policy.init"
-                ".relevancy_check_enabled_with_disabled_generative_search",
-                event_info=(
-                    f"The config parameter '{CHECK_RELEVANCY_PROPERTY}' is set to"
-                    f"'True', but the generative search is disabled (config"
-                    f"parameter '{USE_LLM_PROPERTY}' is set to 'False'). As a result, "
-                    "the relevancy check for the generative search will be disabled. "
-                    f"To use this check, set the config parameter '{USE_LLM_PROPERTY}' "
-                    f"to `True`."
-                ),
-            )
-
     @classmethod
-    def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
+    def _create_plain_embedder(cls, embeddings_config: Dict[Text, Any]) -> "Embeddings":
         """Creates an embedder based on the given configuration.
 
+        Args:
+            embeddings_config: A resolved embeddings configuration. Resolved means the
+                configuration is either:
+                - A reference to a model group that has already been expanded into
+                  its corresponding configuration using the information from
+                  `endpoints.yml`, or
+                - A full configuration for the embedder defined directly (i.e. not
+                  relying on model groups or indirections).
+
         Returns:
-        The embedder.
+            The embedder.
         """
         # Copy the config so original config is not modified
-        config = config.copy()
-        # Resolve config and instantiate the embedding client
-        config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
-            config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
-        )
-        client = embedder_factory(
-            config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
-        )
+        embeddings_config = embeddings_config.copy()
+        client = embedder_factory(embeddings_config, DEFAULT_EMBEDDINGS_CONFIG)
         # Wrap the embedding client in the adapter
         return _LangchainEmbeddingClientAdapter(client)
 
@@ -410,16 +331,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
         can load the policy from the resource.
         """
         # Perform health checks for both LLM and embeddings client configs
-        self._perform_health_checks(self.config, "enterprise_search_policy.train")
-
-        store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
+        self._perform_health_checks(
+            self.llm_config, self.embeddings_config, "enterprise_search_policy.train"
+        )
 
         # telemetry call to track training start
         track_enterprise_search_policy_train_started()
 
         # validate embedding configuration
         try:
-            embeddings = self._create_plain_embedder(self.config)
+            embeddings = self._create_plain_embedder(self.embeddings_config)
         except (ValidationError, Exception) as e:
             structlogger.error(
                 "enterprise_search_policy.train.embedder_instantiation_failed",
@@ -431,7 +352,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
                 f"required environment variables. Error: {e}"
             )
 
-        if store_type == DEFAULT_VECTOR_STORE_TYPE:
+        if self.vector_store_type == DEFAULT_VECTOR_STORE_TYPE:
             structlogger.info("enterprise_search_policy.train.faiss")
             with self._model_storage.write_to(self._resource) as path:
                 self.vector_store = FAISS_Store(
@@ -443,12 +364,13 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
                 )
         else:
             structlogger.info(
-                "enterprise_search_policy.train.custom", store_type=store_type
+                "enterprise_search_policy.train.custom",
+                store_type=self.vector_store_type,
             )
 
         # telemetry call to track training completion
         track_enterprise_search_policy_train_completed(
-            vector_store_type=store_type,
+            vector_store_type=self.vector_store_type,
             embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
             embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
             or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
@@ -471,8 +393,11 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
             rasa.shared.utils.io.write_text_file(
                 self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
             )
+            config = self.config.copy()
+            config[LLM_CONFIG_KEY] = self.llm_config
+            config[EMBEDDINGS_CONFIG_KEY] = self.embeddings_config
             rasa.shared.utils.io.dump_obj_as_json_to_file(
-                path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
+                path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, config
             )
 
     def _prepare_slots_for_template(
  def _prepare_slots_for_template(
@@ -511,8 +436,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
511
436
  endpoints: Endpoints configuration.
512
437
  """
513
438
  config = endpoints.vector_store if endpoints else None
514
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
515
- if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
439
+ if config is None and self.vector_store_type != DEFAULT_VECTOR_STORE_TYPE:
516
440
  structlogger.error(
517
441
  "enterprise_search_policy._connect_vector_store_or_raise.no_config"
518
442
  )
@@ -673,7 +597,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
         # telemetry call to track policy prediction
         track_enterprise_search_policy_predict(
-            vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
+            vector_store_type=self.vector_store_type,
             embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
             embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
             or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
@@ -732,7 +656,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
         Returns:
             An LLMResponse object, or None if the call fails.
         """
-        llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
+        llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
         try:
             response = await llm.acompletion(prompt)
             return LLMResponse.ensure_llm_response(response)
@@ -862,73 +786,88 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
         **kwargs: Any,
     ) -> "EnterpriseSearchPolicy":
         """Loads a trained policy (see parent class for full docstring)."""
+        parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
+
         # Perform health checks for both LLM and embeddings client configs
-        cls._perform_health_checks(config, "enterprise_search_policy.load")
+        cls._perform_health_checks(
+            parsed_config.llm_config,
+            parsed_config.embeddings_config,
+            "enterprise_search_policy.load",
+        )
 
-        prompt_template = None
+        prompt_template = cls._load_prompt_template(model_storage, resource)
+        embeddings = cls._create_plain_embedder(parsed_config.embeddings_config)
+        vector_store = cls._load_vector_store(
+            embeddings,
+            parsed_config.vector_store_type,
+            parsed_config.use_generative_llm,
+            model_storage,
+            resource,
+        )
+
+        structlogger.info("enterprise_search_policy.load", config=config)
+
+        return cls(
+            config,
+            model_storage,
+            resource,
+            execution_context,
+            vector_store=vector_store,
+            prompt_template=prompt_template,
+        )
+
+    @classmethod
+    def _load_prompt_template(
+        cls, model_storage: ModelStorage, resource: Resource
+    ) -> Optional[str]:
         try:
             with model_storage.read_from(resource) as path:
-                prompt_template = rasa.shared.utils.io.read_file(
+                return rasa.shared.utils.io.read_file(
                     path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
                 )
         except (FileNotFoundError, FileIOException) as e:
             structlogger.warning(
                 "enterprise_search_policy.load.failed", error=e, resource=resource.name
             )
+            return None
 
-        store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
-            VECTOR_STORE_TYPE_PROPERTY
-        )
-
-        embeddings = cls._create_plain_embedder(config)
-
-        structlogger.info("enterprise_search_policy.load", config=config)
+    @classmethod
+    def _load_vector_store(
+        cls,
+        embeddings: "Embeddings",
+        store_type: str,
+        use_generative_llm: bool,
+        model_storage: ModelStorage,
+        resource: Resource,
+    ) -> InformationRetrieval:
         if store_type == DEFAULT_VECTOR_STORE_TYPE:
             # if a vector store is not specified,
             # default to using FAISS with the index stored in the model
             # TODO figure out a way to get path without context manager
             with model_storage.read_from(resource) as path:
-                vector_store = FAISS_Store(
+                return FAISS_Store(
                     embeddings=embeddings,
                     index_path=path,
                     docs_folder=None,
                     create_index=False,
-                    parse_as_faq_pairs=not config.get(
-                        USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY
-                    ),
+                    parse_as_faq_pairs=not use_generative_llm,
                 )
         else:
-            vector_store = create_from_endpoint_config(
+            return create_from_endpoint_config(
                 config_type=store_type,
                 embeddings=embeddings,
-            )  # type: ignore
-
-        return cls(
-            config,
-            model_storage,
-            resource,
-            execution_context,
-            vector_store=vector_store,
-            prompt_template=prompt_template,
-        )
+            )
 
     @classmethod
-    def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
+    def _get_local_knowledge_data(
+        cls, store_type: str, source: Optional[str] = None
+    ) -> Optional[List[str]]:
         """This is required only for local knowledge base types.
 
         e.g. FAISS, to ensure that the graph component is retrained when the knowledge
         base is updated.
         """
-        merged_config = {**cls.get_default_config(), **config}
-
-        store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
-            VECTOR_STORE_TYPE_PROPERTY
-        )
-        if store_type != DEFAULT_VECTOR_STORE_TYPE:
-            return None
-
-        source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
-        if not source:
+        if store_type != DEFAULT_VECTOR_STORE_TYPE or not source:
             return None
 
         docs = FAISS_Store.load_documents(source)
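
Taken together, load() now reads as a short pipeline over the two new helpers; a condensed view using exactly the names introduced in this hunk:

# Condensed load() flow after the refactor:
parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
cls._perform_health_checks(
    parsed_config.llm_config,
    parsed_config.embeddings_config,
    "enterprise_search_policy.load",
)
prompt_template = cls._load_prompt_template(model_storage, resource)  # None if absent
embeddings = cls._create_plain_embedder(parsed_config.embeddings_config)
vector_store = cls._load_vector_store(  # FAISS from model storage, else endpoint config
    embeddings,
    parsed_config.vector_store_type,
    parsed_config.use_generative_llm,
    model_storage,
    resource,
)
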
@@ -944,18 +883,28 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
     @classmethod
     def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
         """Add a fingerprint of enterprise search policy for the graph."""
-        prompt_template = cls._resolve_prompt_template(
-            config, LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON
-        )
-
-        local_knowledge_data = cls._get_local_knowledge_data(config)
+        parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
 
-        llm_config = resolve_model_client_config(
-            config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
+        # Resolve the prompt template
+        default_prompt_template = cls._select_default_prompt_template_based_on_features(
+            parsed_config.check_relevancy, parsed_config.enable_citation
         )
-        embedding_config = resolve_model_client_config(
-            config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
+        prompt_template = get_prompt_template(
+            jinja_file_path=parsed_config.prompt_template,
+            default_prompt_template=default_prompt_template,
+            log_source_component=EnterpriseSearchPolicy.__name__,
+            log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
+        )
+
+        # Fetch the local knowledge data in case FAISS is used
+        local_knowledge_data = cls._get_local_knowledge_data(
+            parsed_config.vector_store_type, parsed_config.vector_store_source
         )
+
+        # Get the resolved LLM and embeddings configurations
+        llm_config = parsed_config.llm_config
+        embedding_config = parsed_config.embeddings_config
+
         return deep_container_fingerprint(
             [prompt_template, local_knowledge_data, llm_config, embedding_config]
         )
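
The fingerprint therefore depends on exactly four inputs, all derived from the parsed config; changing any of them (a custom prompt, the FAISS source documents, or the resolved LLM/embeddings settings) retrains the component in the graph. Schematically:

# What feeds deep_container_fingerprint after this change:
fingerprint_inputs = [
    prompt_template,       # custom template if configured, else feature-based default
    local_knowledge_data,  # FAISS only: the loaded source docs; None for remote stores
    llm_config,            # parsed_config.llm_config (already resolved)
    embedding_config,      # parsed_config.embeddings_config (already resolved)
]
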
@@ -1053,21 +1002,32 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
     @classmethod
     def _perform_health_checks(
-        cls, config: Dict[Text, Any], log_source_method: str
+        cls,
+        llm_config: Dict[Text, Any],
+        embeddings_config: Dict[Text, Any],
+        log_source_method: str,
     ) -> None:
-        # Perform health check of the LLM client config
-        llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
+        """
+        Perform the health checks using resolved LLM and embeddings configurations.
+        Resolved means the configuration is either:
+        - A reference to a model group that has already been expanded into
+          its corresponding configuration using the information from
+          `endpoints.yml`, or
+        - A full configuration for the embedder defined directly (i.e. not
+          relying on model groups or indirections).
+
+        Args:
+            llm_config: A resolved LLM configuration.
+            embeddings_config: A resolved embeddings configuration.
+            log_source_method: The method the health checks have been called from.
+
+        """
         cls.perform_llm_health_check(
             llm_config,
             DEFAULT_LLM_CONFIG,
             log_source_method,
             EnterpriseSearchPolicy.__name__,
         )
-
-        # Perform health check of the embeddings client config
-        embeddings_config = resolve_model_client_config(
-            config.get(EMBEDDINGS_CONFIG_KEY, {})
-        )
         cls.perform_embeddings_health_check(
             embeddings_config,
             DEFAULT_EMBEDDINGS_CONFIG,
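
Both call sites shown earlier in this diff now pass pre-resolved configs rather than the whole component config; for reference:

# train() call site (from the hunk at old line 410):
self._perform_health_checks(
    self.llm_config, self.embeddings_config, "enterprise_search_policy.train"
)
# load() call site (from the hunk at old line 862):
cls._perform_health_checks(
    parsed_config.llm_config,
    parsed_config.embeddings_config,
    "enterprise_search_policy.load",
)
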
@@ -1093,62 +1053,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
         Returns:
             The resolved jinja prompt template as a string.
         """
-
         # Get the feature flags
-        citation_enabled = config.get(
-            CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
-        )
-        relevancy_check_enabled = config.get(
-            CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
-        )
-
+        parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
         # Based on the enabled features (citation, relevancy check) fetch the
         # appropriate default prompt
         default_prompt = cls._select_default_prompt_template_based_on_features(
-            relevancy_check_enabled, citation_enabled
+            parsed_config.check_relevancy, parsed_config.enable_citation
         )
 
         return default_prompt
 
-    @classmethod
-    def _resolve_prompt_template(
-        cls,
-        config: dict,
-        log_source_method: Literal["init", "fingerprint"],
-    ) -> str:
-        """
-        Resolves the prompt template to use for the Enterprise Search Policy's
-        generative search.
-
-        Checks if a custom template is provided via component's configuration. If not,
-        it selects the appropriate default template based on the enabled features
-        (citation and relevancy check).
-
-        Args:
-            config: The component's configuration.
-            log_source_method: The name of the method or function emitting the log for
-                better traceability.
-        Returns:
-            The resolved jinja prompt template as a string.
-        """
-
-        # Read the template path from the configuration if available.
-        # The deprecated 'prompt' has a lower priority compared to 'prompt_template'
-        config_defined_prompt = (
-            config.get(PROMPT_TEMPLATE_CONFIG_KEY)
-            or config.get(PROMPT_CONFIG_KEY)
-            or None
-        )
-        # Select the default prompt based on the features set in the config.
-        default_prompt = cls.get_system_default_prompt_based_on_config(config)
-
-        return get_prompt_template(
-            config_defined_prompt,
-            default_prompt,
-            log_source_component=EnterpriseSearchPolicy.__name__,
-            log_source_method=log_source_method,
-        )
-
     @classmethod
     def _select_default_prompt_template_based_on_features(
         cls,