rasa-pro 3.13.0.dev9__py3-none-any.whl → 3.13.0.dev11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (45) hide show
  1. rasa/cli/export.py +2 -0
  2. rasa/cli/studio/download.py +3 -9
  3. rasa/cli/studio/link.py +1 -2
  4. rasa/cli/studio/pull.py +3 -2
  5. rasa/cli/studio/push.py +1 -1
  6. rasa/cli/studio/train.py +0 -1
  7. rasa/core/exporter.py +36 -0
  8. rasa/core/policies/enterprise_search_policy.py +151 -240
  9. rasa/core/policies/enterprise_search_policy_config.py +242 -0
  10. rasa/core/policies/enterprise_search_prompt_with_relevancy_check_and_citation_template.jinja2 +6 -5
  11. rasa/core/utils.py +11 -2
  12. rasa/dialogue_understanding/commands/__init__.py +4 -0
  13. rasa/dialogue_understanding/generator/command_generator.py +11 -1
  14. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +78 -0
  15. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +2 -2
  16. rasa/dialogue_understanding/processor/command_processor.py +5 -5
  17. rasa/shared/core/flows/validation.py +9 -2
  18. rasa/shared/providers/_configs/azure_openai_client_config.py +2 -2
  19. rasa/shared/providers/_configs/default_litellm_client_config.py +1 -1
  20. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +1 -1
  21. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  22. rasa/shared/providers/_configs/rasa_llm_client_config.py +1 -1
  23. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -1
  24. rasa/shared/providers/_configs/utils.py +0 -99
  25. rasa/shared/utils/configs.py +110 -0
  26. rasa/shared/utils/constants.py +0 -3
  27. rasa/shared/utils/pykwalify_extensions.py +0 -9
  28. rasa/studio/constants.py +1 -0
  29. rasa/studio/download.py +164 -0
  30. rasa/studio/link.py +1 -1
  31. rasa/studio/{download/flows.py → pull/data.py} +2 -131
  32. rasa/studio/{download → pull}/domains.py +1 -1
  33. rasa/studio/pull/pull.py +235 -0
  34. rasa/studio/push.py +5 -0
  35. rasa/studio/train.py +1 -1
  36. rasa/tracing/instrumentation/attribute_extractors.py +10 -5
  37. rasa/version.py +1 -1
  38. {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/METADATA +1 -1
  39. {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/RECORD +43 -40
  40. rasa/studio/download/download.py +0 -416
  41. rasa/studio/pull.py +0 -94
  42. /rasa/studio/{download → pull}/__init__.py +0 -0
  43. {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/NOTICE +0 -0
  44. {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/WHEEL +0 -0
  45. {rasa_pro-3.13.0.dev9.dist-info → rasa_pro-3.13.0.dev11.dist-info}/entry_points.txt +0 -0
@@ -2,7 +2,7 @@ import dataclasses
2
2
  import importlib.resources
3
3
  import json
4
4
  import re
5
- from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Text
5
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
6
6
 
7
7
  import dotenv
8
8
  import structlog
@@ -12,9 +12,6 @@ from pydantic import ValidationError
12
12
  import rasa.shared.utils.io
13
13
  from rasa.core.available_endpoints import AvailableEndpoints
14
14
  from rasa.core.constants import (
15
- POLICY_MAX_HISTORY,
16
- POLICY_PRIORITY,
17
- SEARCH_POLICY_PRIORITY,
18
15
  UTTER_SOURCE_METADATA_KEY,
19
16
  )
20
17
  from rasa.core.information_retrieval import (
@@ -24,6 +21,14 @@ from rasa.core.information_retrieval import (
24
21
  create_from_endpoint_config,
25
22
  )
26
23
  from rasa.core.information_retrieval.faiss import FAISS_Store
24
+ from rasa.core.policies.enterprise_search_policy_config import (
25
+ DEFAULT_EMBEDDINGS_CONFIG,
26
+ DEFAULT_ENTERPRISE_SEARCH_CONFIG,
27
+ DEFAULT_LLM_CONFIG,
28
+ DEFAULT_VECTOR_STORE_TYPE,
29
+ SOURCE_PROPERTY,
30
+ EnterpriseSearchPolicyConfig,
31
+ )
27
32
  from rasa.core.policies.policy import Policy, PolicyPrediction
28
33
  from rasa.dialogue_understanding.generator.constants import (
29
34
  LLM_CONFIG_KEY,
@@ -47,18 +52,11 @@ from rasa.graph_components.providers.forms_provider import Forms
47
52
  from rasa.graph_components.providers.responses_provider import Responses
48
53
  from rasa.shared.constants import (
49
54
  EMBEDDINGS_CONFIG_KEY,
50
- MAX_COMPLETION_TOKENS_CONFIG_KEY,
51
- MAX_RETRIES_CONFIG_KEY,
52
55
  MODEL_CONFIG_KEY,
53
56
  MODEL_GROUP_ID_CONFIG_KEY,
54
57
  MODEL_NAME_CONFIG_KEY,
55
- OPENAI_PROVIDER,
56
- PROMPT_CONFIG_KEY,
57
- PROMPT_TEMPLATE_CONFIG_KEY,
58
58
  PROVIDER_CONFIG_KEY,
59
59
  RASA_PATTERN_CANNOT_HANDLE_NO_RELEVANT_ANSWER,
60
- TEMPERATURE_CONFIG_KEY,
61
- TIMEOUT_CONFIG_KEY,
62
60
  )
63
61
  from rasa.shared.core.constants import (
64
62
  ACTION_CANCEL_FLOW,
@@ -93,13 +91,9 @@ from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
93
91
  from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
94
92
  from rasa.shared.utils.io import deep_container_fingerprint
95
93
  from rasa.shared.utils.llm import (
96
- DEFAULT_OPENAI_CHAT_MODEL_NAME,
97
- DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
98
- check_prompt_config_keys_and_warn_if_deprecated,
99
94
  embedder_factory,
100
95
  get_prompt_template,
101
96
  llm_factory,
102
- resolve_model_client_config,
103
97
  sanitize_message_for_prompt,
104
98
  tracker_as_readable_transcript,
105
99
  )
@@ -120,42 +114,6 @@ structlogger = structlog.get_logger()
120
114
 
121
115
  dotenv.load_dotenv("./.env")
122
116
 
123
- SOURCE_PROPERTY = "source"
124
- VECTOR_STORE_TYPE_PROPERTY = "type"
125
- VECTOR_STORE_PROPERTY = "vector_store"
126
- VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
127
- TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
128
- CITATION_ENABLED_PROPERTY = "citation_enabled"
129
- USE_LLM_PROPERTY = "use_generative_llm"
130
- CHECK_RELEVANCY_PROPERTY = "check_relevancy"
131
- MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
132
-
133
- DEFAULT_VECTOR_STORE_TYPE = "faiss"
134
- DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
135
- DEFAULT_VECTOR_STORE = {
136
- VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
137
- SOURCE_PROPERTY: "./docs",
138
- VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
139
- }
140
-
141
- DEFAULT_CHECK_RELEVANCY_PROPERTY = False
142
- DEFAULT_USE_LLM_PROPERTY = True
143
- DEFAULT_CITATION_ENABLED_PROPERTY = False
144
-
145
- DEFAULT_LLM_CONFIG = {
146
- PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
147
- MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
148
- TIMEOUT_CONFIG_KEY: 10,
149
- TEMPERATURE_CONFIG_KEY: 0.0,
150
- MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
151
- MAX_RETRIES_CONFIG_KEY: 1,
152
- }
153
-
154
- DEFAULT_EMBEDDINGS_CONFIG = {
155
- PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
156
- MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
157
- }
158
-
159
117
  ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
160
118
  ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
161
119
 
@@ -177,10 +135,7 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_RELEVANCY_CHECK_AND_CITATION_TEMPLATE = (
177
135
  )
178
136
  )
179
137
 
180
- # TODO: Update this pattern once the experiments are done
181
- _ENTERPRISE_SEARCH_ANSWER_NOT_RELEVANT_PATTERN = re.compile(
182
- r"\[NO_RELEVANT_ANSWER_FOUND\]"
183
- )
138
+ _ENTERPRISE_SEARCH_ANSWER_NOT_RELEVANT_PATTERN = re.compile(r"\[NO_RAG_ANSWER\]")
184
139
 
185
140
 
186
141
  class VectorStoreConnectionError(RasaException):
@@ -228,10 +183,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
228
183
  @staticmethod
229
184
  def get_default_config() -> Dict[str, Any]:
230
185
  """Returns the default config of the policy."""
231
- return {
232
- POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
233
- VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
234
- }
186
+ return DEFAULT_ENTERPRISE_SEARCH_CONFIG
235
187
 
236
188
  def __init__(
237
189
  self,
@@ -246,105 +198,71 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
246
198
  """Constructs a new Policy object."""
247
199
  super().__init__(config, model_storage, resource, execution_context, featurizer)
248
200
 
249
- # Check for deprecated keys and issue a warning if those are used
250
- check_prompt_config_keys_and_warn_if_deprecated(
251
- config, "enterprise_search_policy"
252
- )
253
- # Check for mutual exclusivity of extractive and generative search
254
- self._check_and_warn_mutual_exclusivity_of_extractive_and_generative_search()
255
-
256
- # Resolve LLM config
257
- self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
258
- self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
259
- )
260
- # Resolve embeddings config
261
- self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
262
- self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
263
- )
201
+ parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
264
202
 
265
203
  # Vector store object and configuration
266
204
  self.vector_store = vector_store
267
- self.vector_store_config = self.config.get(
268
- VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
269
- )
270
- self.vector_search_threshold = self.vector_store_config.get(
271
- VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
272
- )
205
+ self.vector_store_config = parsed_config.vector_store_config
206
+ self.vector_search_threshold = parsed_config.vector_store_threshold
207
+ self.vector_store_type = parsed_config.vector_store_type
273
208
 
274
- # Embeddings configuration for encoding the search query
275
- self.embeddings_config = (
276
- self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
277
- )
209
+ # Resolved embeddings configuration for encoding the search query
210
+ self.embeddings_config = parsed_config.embeddings_config
278
211
 
279
- # LLM Configuration for response generation
280
- self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
212
+ # Resolved LLM Configuration for response generation
213
+ self.llm_config = parsed_config.llm_config
281
214
 
282
215
  # Maximum number of turns to include in the prompt
283
- self.max_history = self.config.get(POLICY_MAX_HISTORY)
216
+ self.max_history = parsed_config.max_history
284
217
 
285
218
  # Maximum number of messages to include in the search query
286
- self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
219
+ self.max_messages_in_query = parsed_config.max_messages_in_query
287
220
 
288
221
  # Boolean to enable/disable tracing of prompt tokens
289
- self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
222
+ self.trace_prompt_tokens = parsed_config.trace_prompt_tokens
290
223
 
291
224
  # Boolean to enable/disable the use of LLM for response generation
292
- self.use_llm = self.config.get(USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY)
225
+ self.use_llm = parsed_config.use_generative_llm
293
226
 
294
227
  # Boolean to enable/disable citation generation. This flag enables citation
295
228
  # logic, but it only takes effect if `use_llm` is True.
296
- self.citation_enabled = self.config.get(
297
- CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
298
- )
229
+ self.citation_enabled = parsed_config.enable_citation
299
230
 
300
231
  # Boolean to enable/disable the use of relevancy check alongside answer
301
232
  # generation. This flag enables citation logic, but it only takes effect if
302
233
  # `use_llm` is True.
303
- self.relevancy_check_enabled = self.config.get(
304
- CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
305
- )
234
+ self.relevancy_check_enabled = parsed_config.check_relevancy
306
235
 
307
236
  # Resolve the prompt template. The prompt will only be used if the 'use_llm' is
308
237
  # set to True.
309
- self.prompt_template = prompt_template or self._resolve_prompt_template(
310
- self.config, LOG_COMPONENT_SOURCE_METHOD_INIT
238
+ self.prompt_template = prompt_template or get_prompt_template(
239
+ jinja_file_path=parsed_config.prompt_template,
240
+ default_prompt_template=self._select_default_prompt_template_based_on_features(
241
+ parsed_config.check_relevancy, parsed_config.enable_citation
242
+ ),
243
+ log_source_component=EnterpriseSearchPolicy.__name__,
244
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_INIT,
311
245
  )
312
246
 
313
- def _check_and_warn_mutual_exclusivity_of_extractive_and_generative_search(
314
- self,
315
- ) -> None:
316
- if self.config.get(
317
- CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
318
- ) and not self.config.get(USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY):
319
- structlogger.warning(
320
- "enterprise_search_policy.init"
321
- ".relevancy_check_enabled_with_disabled_generative_search",
322
- event_info=(
323
- f"The config parameter '{CHECK_RELEVANCY_PROPERTY}' is set to"
324
- f"'True', but the generative search is disabled (config"
325
- f"parameter '{USE_LLM_PROPERTY}' is set to 'False'). As a result, "
326
- "the relevancy check for the generative search will be disabled. "
327
- f"To use this check, set the config parameter '{USE_LLM_PROPERTY}' "
328
- f"to `True`."
329
- ),
330
- )
331
-
332
247
  @classmethod
333
- def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
248
+ def _create_plain_embedder(cls, embeddings_config: Dict[Text, Any]) -> "Embeddings":
334
249
  """Creates an embedder based on the given configuration.
335
250
 
251
+ Args:
252
+ embeddings_config: A resolved embeddings configuration. Resolved means the
253
+ configuration is either:
254
+ - A reference to a model group that has already been expanded into
255
+ its corresponding configuration using the information from
256
+ `endpoints.yml`, or
257
+ - A full configuration for the embedder defined directly (i.e. not
258
+ relying on model groups or indirections).
259
+
336
260
  Returns:
337
- The embedder.
261
+ The embedder.
338
262
  """
339
263
  # Copy the config so original config is not modified
340
- config = config.copy()
341
- # Resolve config and instantiate the embedding client
342
- config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
343
- config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
344
- )
345
- client = embedder_factory(
346
- config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
347
- )
264
+ embeddings_config = embeddings_config.copy()
265
+ client = embedder_factory(embeddings_config, DEFAULT_EMBEDDINGS_CONFIG)
348
266
  # Wrap the embedding client in the adapter
349
267
  return _LangchainEmbeddingClientAdapter(client)
350
268
 
@@ -410,16 +328,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
410
328
  can load the policy from the resource.
411
329
  """
412
330
  # Perform health checks for both LLM and embeddings client configs
413
- self._perform_health_checks(self.config, "enterprise_search_policy.train")
414
-
415
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
331
+ self._perform_health_checks(
332
+ self.llm_config, self.embeddings_config, "enterprise_search_policy.train"
333
+ )
416
334
 
417
335
  # telemetry call to track training start
418
336
  track_enterprise_search_policy_train_started()
419
337
 
420
338
  # validate embedding configuration
421
339
  try:
422
- embeddings = self._create_plain_embedder(self.config)
340
+ embeddings = self._create_plain_embedder(self.embeddings_config)
423
341
  except (ValidationError, Exception) as e:
424
342
  structlogger.error(
425
343
  "enterprise_search_policy.train.embedder_instantiation_failed",
@@ -431,7 +349,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
431
349
  f"required environment variables. Error: {e}"
432
350
  )
433
351
 
434
- if store_type == DEFAULT_VECTOR_STORE_TYPE:
352
+ if self.vector_store_type == DEFAULT_VECTOR_STORE_TYPE:
435
353
  structlogger.info("enterprise_search_policy.train.faiss")
436
354
  with self._model_storage.write_to(self._resource) as path:
437
355
  self.vector_store = FAISS_Store(
@@ -443,12 +361,13 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
443
361
  )
444
362
  else:
445
363
  structlogger.info(
446
- "enterprise_search_policy.train.custom", store_type=store_type
364
+ "enterprise_search_policy.train.custom",
365
+ store_type=self.vector_store_type,
447
366
  )
448
367
 
449
368
  # telemetry call to track training completion
450
369
  track_enterprise_search_policy_train_completed(
451
- vector_store_type=store_type,
370
+ vector_store_type=self.vector_store_type,
452
371
  embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
453
372
  embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
454
373
  or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
@@ -471,8 +390,11 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
471
390
  rasa.shared.utils.io.write_text_file(
472
391
  self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
473
392
  )
393
+ config = self.config.copy()
394
+ config[LLM_CONFIG_KEY] = self.llm_config
395
+ config[EMBEDDINGS_CONFIG_KEY] = self.embeddings_config
474
396
  rasa.shared.utils.io.dump_obj_as_json_to_file(
475
- path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
397
+ path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, config
476
398
  )
477
399
 
478
400
  def _prepare_slots_for_template(
@@ -511,8 +433,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
511
433
  endpoints: Endpoints configuration.
512
434
  """
513
435
  config = endpoints.vector_store if endpoints else None
514
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
515
- if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
436
+ if config is None and self.vector_store_type != DEFAULT_VECTOR_STORE_TYPE:
516
437
  structlogger.error(
517
438
  "enterprise_search_policy._connect_vector_store_or_raise.no_config"
518
439
  )
@@ -673,7 +594,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
673
594
 
674
595
  # telemetry call to track policy prediction
675
596
  track_enterprise_search_policy_predict(
676
- vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
597
+ vector_store_type=self.vector_store_type,
677
598
  embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
678
599
  embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
679
600
  or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
@@ -732,7 +653,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
732
653
  Returns:
733
654
  An LLMResponse object, or None if the call fails.
734
655
  """
735
- llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
656
+ llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
736
657
  try:
737
658
  response = await llm.acompletion(prompt)
738
659
  return LLMResponse.ensure_llm_response(response)
@@ -862,73 +783,88 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
862
783
  **kwargs: Any,
863
784
  ) -> "EnterpriseSearchPolicy":
864
785
  """Loads a trained policy (see parent class for full docstring)."""
786
+ parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
787
+
865
788
  # Perform health checks for both LLM and embeddings client configs
866
- cls._perform_health_checks(config, "enterprise_search_policy.load")
789
+ cls._perform_health_checks(
790
+ parsed_config.llm_config,
791
+ parsed_config.embeddings_config,
792
+ "enterprise_search_policy.load",
793
+ )
794
+
795
+ prompt_template = cls._load_prompt_template(model_storage, resource)
796
+ embeddings = cls._create_plain_embedder(parsed_config.embeddings_config)
797
+ vector_store = cls._load_vector_store(
798
+ embeddings,
799
+ parsed_config.vector_store_type,
800
+ parsed_config.use_generative_llm,
801
+ model_storage,
802
+ resource,
803
+ )
804
+
805
+ structlogger.info("enterprise_search_policy.load", config=config)
806
+
807
+ return cls(
808
+ config,
809
+ model_storage,
810
+ resource,
811
+ execution_context,
812
+ vector_store=vector_store,
813
+ prompt_template=prompt_template,
814
+ )
867
815
 
868
- prompt_template = None
816
+ @classmethod
817
+ def _load_prompt_template(
818
+ cls, model_storage: ModelStorage, resource: Resource
819
+ ) -> Optional[str]:
869
820
  try:
870
821
  with model_storage.read_from(resource) as path:
871
- prompt_template = rasa.shared.utils.io.read_file(
822
+ return rasa.shared.utils.io.read_file(
872
823
  path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
873
824
  )
874
825
  except (FileNotFoundError, FileIOException) as e:
875
826
  structlogger.warning(
876
827
  "enterprise_search_policy.load.failed", error=e, resource=resource.name
877
828
  )
829
+ return None
878
830
 
879
- store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
880
- VECTOR_STORE_TYPE_PROPERTY
881
- )
882
-
883
- embeddings = cls._create_plain_embedder(config)
884
-
885
- structlogger.info("enterprise_search_policy.load", config=config)
831
+ @classmethod
832
+ def _load_vector_store(
833
+ cls,
834
+ embeddings: "Embeddings",
835
+ store_type: str,
836
+ use_generative_llm: bool,
837
+ model_storage: ModelStorage,
838
+ resource: Resource,
839
+ ) -> InformationRetrieval:
886
840
  if store_type == DEFAULT_VECTOR_STORE_TYPE:
887
841
  # if a vector store is not specified,
888
842
  # default to using FAISS with the index stored in the model
889
843
  # TODO figure out a way to get path without context manager
890
844
  with model_storage.read_from(resource) as path:
891
- vector_store = FAISS_Store(
845
+ return FAISS_Store(
892
846
  embeddings=embeddings,
893
847
  index_path=path,
894
848
  docs_folder=None,
895
849
  create_index=False,
896
- parse_as_faq_pairs=not config.get(
897
- USE_LLM_PROPERTY, DEFAULT_USE_LLM_PROPERTY
898
- ),
850
+ parse_as_faq_pairs=not use_generative_llm,
899
851
  )
900
852
  else:
901
- vector_store = create_from_endpoint_config(
853
+ return create_from_endpoint_config(
902
854
  config_type=store_type,
903
855
  embeddings=embeddings,
904
- ) # type: ignore
905
-
906
- return cls(
907
- config,
908
- model_storage,
909
- resource,
910
- execution_context,
911
- vector_store=vector_store,
912
- prompt_template=prompt_template,
913
- )
856
+ )
914
857
 
915
858
  @classmethod
916
- def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
859
+ def _get_local_knowledge_data(
860
+ cls, store_type: str, source: Optional[str] = None
861
+ ) -> Optional[List[str]]:
917
862
  """This is required only for local knowledge base types.
918
863
 
919
864
  e.g. FAISS, to ensure that the graph component is retrained when the knowledge
920
865
  base is updated.
921
866
  """
922
- merged_config = {**cls.get_default_config(), **config}
923
-
924
- store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
925
- VECTOR_STORE_TYPE_PROPERTY
926
- )
927
- if store_type != DEFAULT_VECTOR_STORE_TYPE:
928
- return None
929
-
930
- source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
931
- if not source:
867
+ if store_type != DEFAULT_VECTOR_STORE_TYPE or not source:
932
868
  return None
933
869
 
934
870
  docs = FAISS_Store.load_documents(source)
@@ -944,18 +880,28 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
944
880
  @classmethod
945
881
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
946
882
  """Add a fingerprint of enterprise search policy for the graph."""
947
- prompt_template = cls._resolve_prompt_template(
948
- config, LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON
949
- )
950
-
951
- local_knowledge_data = cls._get_local_knowledge_data(config)
883
+ parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
952
884
 
953
- llm_config = resolve_model_client_config(
954
- config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
885
+ # Resolve the prompt template
886
+ default_prompt_template = cls._select_default_prompt_template_based_on_features(
887
+ parsed_config.check_relevancy, parsed_config.enable_citation
955
888
  )
956
- embedding_config = resolve_model_client_config(
957
- config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
889
+ prompt_template = get_prompt_template(
890
+ jinja_file_path=parsed_config.prompt_template,
891
+ default_prompt_template=default_prompt_template,
892
+ log_source_component=EnterpriseSearchPolicy.__name__,
893
+ log_source_method=LOG_COMPONENT_SOURCE_METHOD_FINGERPRINT_ADDON,
894
+ )
895
+
896
+ # Fetch the local knowledge data in case FAISS is used
897
+ local_knowledge_data = cls._get_local_knowledge_data(
898
+ parsed_config.vector_store_type, parsed_config.vector_store_source
958
899
  )
900
+
901
+ # Get the resolved LLM and embeddings configurations
902
+ llm_config = parsed_config.llm_config
903
+ embedding_config = parsed_config.embeddings_config
904
+
959
905
  return deep_container_fingerprint(
960
906
  [prompt_template, local_knowledge_data, llm_config, embedding_config]
961
907
  )
@@ -1053,21 +999,32 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
1053
999
 
1054
1000
  @classmethod
1055
1001
  def _perform_health_checks(
1056
- cls, config: Dict[Text, Any], log_source_method: str
1002
+ cls,
1003
+ llm_config: Dict[Text, Any],
1004
+ embeddings_config: Dict[Text, Any],
1005
+ log_source_method: str,
1057
1006
  ) -> None:
1058
- # Perform health check of the LLM client config
1059
- llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
1007
+ """
1008
+ Perform the health checks using resolved LLM and embeddings configurations.
1009
+ Resolved means the configuration is either:
1010
+ - A reference to a model group that has already been expanded into
1011
+ its corresponding configuration using the information from
1012
+ `endpoints.yml`, or
1013
+ - A full configuration for the embedder defined directly (i.e. not
1014
+ relying on model groups or indirections).
1015
+
1016
+ Args:
1017
+ llm_config: A resolved LLM configuration.
1018
+ embeddings_config: A resolved embeddings configuration.
1019
+ log_source_method: The method health checks has been called from.
1020
+
1021
+ """
1060
1022
  cls.perform_llm_health_check(
1061
1023
  llm_config,
1062
1024
  DEFAULT_LLM_CONFIG,
1063
1025
  log_source_method,
1064
1026
  EnterpriseSearchPolicy.__name__,
1065
1027
  )
1066
-
1067
- # Perform health check of the embeddings client config
1068
- embeddings_config = resolve_model_client_config(
1069
- config.get(EMBEDDINGS_CONFIG_KEY, {})
1070
- )
1071
1028
  cls.perform_embeddings_health_check(
1072
1029
  embeddings_config,
1073
1030
  DEFAULT_EMBEDDINGS_CONFIG,
@@ -1093,62 +1050,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
1093
1050
  Returns:
1094
1051
  The resolved jinja prompt template as a string.
1095
1052
  """
1096
-
1097
1053
  # Get the feature flags
1098
- citation_enabled = config.get(
1099
- CITATION_ENABLED_PROPERTY, DEFAULT_CITATION_ENABLED_PROPERTY
1100
- )
1101
- relevancy_check_enabled = config.get(
1102
- CHECK_RELEVANCY_PROPERTY, DEFAULT_CHECK_RELEVANCY_PROPERTY
1103
- )
1104
-
1054
+ parsed_config = EnterpriseSearchPolicyConfig.from_dict(config)
1105
1055
  # Based on the enabled features (citation, relevancy check) fetch the
1106
1056
  # appropriate default prompt
1107
1057
  default_prompt = cls._select_default_prompt_template_based_on_features(
1108
- relevancy_check_enabled, citation_enabled
1058
+ parsed_config.check_relevancy, parsed_config.enable_citation
1109
1059
  )
1110
1060
 
1111
1061
  return default_prompt
1112
1062
 
1113
- @classmethod
1114
- def _resolve_prompt_template(
1115
- cls,
1116
- config: dict,
1117
- log_source_method: Literal["init", "fingerprint"],
1118
- ) -> str:
1119
- """
1120
- Resolves the prompt template to use for the Enterprise Search Policy's
1121
- generative search.
1122
-
1123
- Checks if a custom template is provided via component's configuration. If not,
1124
- it selects the appropriate default template based on the enabled features
1125
- (citation and relevancy check).
1126
-
1127
- Args:
1128
- config: The component's configuration.
1129
- log_source_method: The name of the method or function emitting the log for
1130
- better traceability.
1131
- Returns:
1132
- The resolved jinja prompt template as a string.
1133
- """
1134
-
1135
- # Read the template path from the configuration if available.
1136
- # The deprecated 'prompt' has a lower priority compared to 'prompt_template'
1137
- config_defined_prompt = (
1138
- config.get(PROMPT_TEMPLATE_CONFIG_KEY)
1139
- or config.get(PROMPT_CONFIG_KEY)
1140
- or None
1141
- )
1142
- # Select the default prompt based on the features set in the config.
1143
- default_prompt = cls.get_system_default_prompt_based_on_config(config)
1144
-
1145
- return get_prompt_template(
1146
- config_defined_prompt,
1147
- default_prompt,
1148
- log_source_component=EnterpriseSearchPolicy.__name__,
1149
- log_source_method=log_source_method,
1150
- )
1151
-
1152
1063
  @classmethod
1153
1064
  def _select_default_prompt_template_based_on_features(
1154
1065
  cls,