rasa-pro 3.12.6.dev2__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. (The original page linked to further details, which were not captured here.)

Files changed (92)
  1. rasa/__init__.py +0 -6
  2. rasa/cli/scaffold.py +1 -1
  3. rasa/core/actions/action.py +38 -34
  4. rasa/core/actions/action_run_slot_rejections.py +1 -1
  5. rasa/core/channels/studio_chat.py +16 -43
  6. rasa/core/channels/voice_ready/audiocodes.py +46 -17
  7. rasa/core/information_retrieval/faiss.py +68 -7
  8. rasa/core/information_retrieval/information_retrieval.py +40 -2
  9. rasa/core/information_retrieval/milvus.py +7 -2
  10. rasa/core/information_retrieval/qdrant.py +7 -2
  11. rasa/core/nlg/contextual_response_rephraser.py +11 -27
  12. rasa/core/nlg/generator.py +5 -21
  13. rasa/core/nlg/response.py +6 -43
  14. rasa/core/nlg/summarize.py +1 -15
  15. rasa/core/nlg/translate.py +0 -8
  16. rasa/core/policies/enterprise_search_policy.py +64 -316
  17. rasa/core/policies/flows/flow_executor.py +3 -38
  18. rasa/core/policies/intentless_policy.py +4 -17
  19. rasa/core/policies/policy.py +0 -2
  20. rasa/core/processor.py +27 -6
  21. rasa/core/utils.py +53 -0
  22. rasa/dialogue_understanding/coexistence/llm_based_router.py +4 -18
  23. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  24. rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
  25. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  26. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  27. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  28. rasa/dialogue_understanding/generator/llm_based_command_generator.py +7 -23
  29. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
  30. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  31. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
  32. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +24 -2
  33. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +8 -12
  34. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  35. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  36. rasa/dialogue_understanding/stack/utils.py +0 -38
  37. rasa/dialogue_understanding_test/command_metric_calculation.py +7 -40
  38. rasa/dialogue_understanding_test/command_metrics.py +38 -0
  39. rasa/dialogue_understanding_test/du_test_case.py +58 -25
  40. rasa/dialogue_understanding_test/du_test_result.py +228 -132
  41. rasa/dialogue_understanding_test/du_test_runner.py +10 -1
  42. rasa/dialogue_understanding_test/io.py +48 -16
  43. rasa/document_retrieval/__init__.py +0 -0
  44. rasa/document_retrieval/constants.py +32 -0
  45. rasa/document_retrieval/document_post_processor.py +351 -0
  46. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  47. rasa/document_retrieval/document_retriever.py +333 -0
  48. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  49. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
  50. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
  51. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
  52. rasa/document_retrieval/query_rewriter.py +234 -0
  53. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
  54. rasa/engine/recipes/default_components.py +2 -0
  55. rasa/hooks.py +0 -55
  56. rasa/model_manager/model_api.py +1 -1
  57. rasa/model_manager/socket_bridge.py +0 -7
  58. rasa/shared/constants.py +0 -5
  59. rasa/shared/core/constants.py +0 -8
  60. rasa/shared/core/domain.py +12 -3
  61. rasa/shared/core/flows/flow.py +0 -17
  62. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  63. rasa/shared/core/flows/steps/collect.py +5 -18
  64. rasa/shared/core/flows/utils.py +1 -16
  65. rasa/shared/core/slot_mappings.py +11 -5
  66. rasa/shared/core/slots.py +1 -1
  67. rasa/shared/core/trackers.py +4 -10
  68. rasa/shared/nlu/constants.py +0 -1
  69. rasa/shared/providers/constants.py +0 -9
  70. rasa/shared/providers/llm/_base_litellm_client.py +4 -14
  71. rasa/shared/providers/llm/default_litellm_llm_client.py +2 -2
  72. rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
  73. rasa/shared/providers/llm/llm_client.py +15 -24
  74. rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
  75. rasa/shared/utils/common.py +11 -1
  76. rasa/shared/utils/health_check/health_check.py +1 -7
  77. rasa/shared/utils/llm.py +1 -1
  78. rasa/tracing/instrumentation/attribute_extractors.py +50 -17
  79. rasa/tracing/instrumentation/instrumentation.py +12 -12
  80. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +1 -2
  81. rasa/utils/licensing.py +0 -15
  82. rasa/validator.py +1 -123
  83. rasa/version.py +1 -1
  84. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +2 -3
  85. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +88 -80
  86. rasa/core/actions/action_handle_digressions.py +0 -164
  87. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  88. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  89. rasa/monkey_patches.py +0 -91
  90. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
  91. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
  92. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -1,12 +1,10 @@
1
1
  import importlib.resources
2
- import json
3
2
  import re
4
3
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
5
4
 
6
5
  import dotenv
7
6
  import structlog
8
7
  from jinja2 import Template
9
- from pydantic import ValidationError
10
8
 
11
9
  import rasa.shared.utils.io
12
10
  from rasa.core.constants import (
@@ -16,12 +14,9 @@ from rasa.core.constants import (
16
14
  UTTER_SOURCE_METADATA_KEY,
17
15
  )
18
16
  from rasa.core.information_retrieval import (
19
- InformationRetrieval,
20
- InformationRetrievalException,
21
17
  SearchResult,
22
- create_from_endpoint_config,
18
+ SearchResultList,
23
19
  )
24
- from rasa.core.information_retrieval.faiss import FAISS_Store
25
20
  from rasa.core.policies.policy import Policy, PolicyPrediction
26
21
  from rasa.core.utils import AvailableEndpoints
27
22
  from rasa.dialogue_understanding.generator.constants import (
@@ -38,6 +33,10 @@ from rasa.dialogue_understanding.stack.frames import (
38
33
  PatternFlowStackFrame,
39
34
  SearchStackFrame,
40
35
  )
36
+ from rasa.document_retrieval.constants import (
37
+ POST_PROCESSED_DOCUMENTS_KEY,
38
+ SEARCH_QUERY_KEY,
39
+ )
41
40
  from rasa.engine.graph import ExecutionContext
42
41
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
43
42
  from rasa.engine.storage.resource import Resource
@@ -45,14 +44,7 @@ from rasa.engine.storage.storage import ModelStorage
45
44
  from rasa.graph_components.providers.forms_provider import Forms
46
45
  from rasa.graph_components.providers.responses_provider import Responses
47
46
  from rasa.shared.constants import (
48
- EMBEDDINGS_CONFIG_KEY,
49
- LANGFUSE_CUSTOM_METADATA_DICT,
50
- LANGFUSE_METADATA_SESSION_ID,
51
- LANGFUSE_METADATA_USER_ID,
52
- LANGFUSE_TAGS,
53
47
  MODEL_CONFIG_KEY,
54
- MODEL_GROUP_ID_CONFIG_KEY,
55
- MODEL_NAME_CONFIG_KEY,
56
48
  OPENAI_PROVIDER,
57
49
  PROMPT_CONFIG_KEY,
58
50
  PROVIDER_CONFIG_KEY,
@@ -64,10 +56,10 @@ from rasa.shared.core.constants import (
64
56
  DEFAULT_SLOT_NAMES,
65
57
  )
66
58
  from rasa.shared.core.domain import Domain
67
- from rasa.shared.core.events import BotUttered, Event, UserUttered
59
+ from rasa.shared.core.events import Event
68
60
  from rasa.shared.core.generator import TrackerWithCachedStates
69
- from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
70
- from rasa.shared.exceptions import FileIOException, RasaException
61
+ from rasa.shared.core.trackers import DialogueStateTracker
62
+ from rasa.shared.exceptions import FileIOException
71
63
  from rasa.shared.nlu.constants import (
72
64
  KEY_COMPONENT_NAME,
73
65
  KEY_LLM_RESPONSE_METADATA,
@@ -76,12 +68,8 @@ from rasa.shared.nlu.constants import (
76
68
  PROMPTS,
77
69
  )
78
70
  from rasa.shared.nlu.training_data.training_data import TrainingData
79
- from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
80
- _LangchainEmbeddingClientAdapter,
81
- )
82
71
  from rasa.shared.providers.llm.llm_client import LLMClient
83
72
  from rasa.shared.providers.llm.llm_response import LLMResponse, measure_llm_latency
84
- from rasa.shared.utils.cli import print_error_and_exit
85
73
  from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
86
74
  EmbeddingsHealthCheckMixin,
87
75
  )
@@ -89,23 +77,13 @@ from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheck
89
77
  from rasa.shared.utils.io import deep_container_fingerprint
90
78
  from rasa.shared.utils.llm import (
91
79
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
92
- DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
93
- embedder_factory,
94
80
  get_prompt_template,
95
81
  llm_factory,
96
82
  resolve_model_client_config,
97
- sanitize_message_for_prompt,
98
83
  tracker_as_readable_transcript,
99
84
  )
100
- from rasa.telemetry import (
101
- track_enterprise_search_policy_predict,
102
- track_enterprise_search_policy_train_completed,
103
- track_enterprise_search_policy_train_started,
104
- )
105
85
 
106
86
  if TYPE_CHECKING:
107
- from langchain.schema.embeddings import Embeddings
108
-
109
87
  from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
110
88
 
111
89
  from rasa.utils.log_utils import log_llm
@@ -114,22 +92,11 @@ logger = structlog.get_logger()
114
92
 
115
93
  dotenv.load_dotenv("./.env")
116
94
 
117
- SOURCE_PROPERTY = "source"
118
- VECTOR_STORE_TYPE_PROPERTY = "type"
119
- VECTOR_STORE_PROPERTY = "vector_store"
120
- VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
121
95
  TRACE_TOKENS_PROPERTY = "trace_prompt_tokens"
122
96
  CITATION_ENABLED_PROPERTY = "citation_enabled"
123
97
  USE_LLM_PROPERTY = "use_generative_llm"
124
98
  MAX_MESSAGES_IN_QUERY_KEY = "max_messages_in_query"
125
99
 
126
- DEFAULT_VECTOR_STORE_TYPE = "faiss"
127
- DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
128
- DEFAULT_VECTOR_STORE = {
129
- VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
130
- SOURCE_PROPERTY: "./docs",
131
- VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
132
- }
133
100
 
134
101
  DEFAULT_LLM_CONFIG = {
135
102
  PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
@@ -140,11 +107,6 @@ DEFAULT_LLM_CONFIG = {
140
107
  "max_retries": 1,
141
108
  }
142
109
 
143
- DEFAULT_EMBEDDINGS_CONFIG = {
144
- PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
145
- "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
146
- }
147
-
148
110
  ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
149
111
  ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
150
112
 
@@ -160,14 +122,6 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
160
122
  )
161
123
 
162
124
 
163
- class VectorStoreConnectionError(RasaException):
164
- """Exception raised for errors in connecting to the vector store."""
165
-
166
-
167
- class VectorStoreConfigurationError(RasaException):
168
- """Exception raised for errors in vector store configuration."""
169
-
170
-
171
125
  @DefaultV1Recipe.register(
172
126
  DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
173
127
  )
@@ -201,7 +155,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
201
155
  """Returns the default config of the policy."""
202
156
  return {
203
157
  POLICY_PRIORITY: SEARCH_POLICY_PRIORITY,
204
- VECTOR_STORE_PROPERTY: DEFAULT_VECTOR_STORE,
205
158
  }
206
159
 
207
160
  def __init__(
@@ -210,7 +163,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
210
163
  model_storage: ModelStorage,
211
164
  resource: Resource,
212
165
  execution_context: ExecutionContext,
213
- vector_store: Optional[InformationRetrieval] = None,
214
166
  featurizer: Optional["TrackerFeaturizer"] = None,
215
167
  prompt_template: Optional[Text] = None,
216
168
  ) -> None:
@@ -221,21 +173,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
221
173
  self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
222
174
  self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
223
175
  )
224
- # Resolve embeddings config
225
- self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
226
- self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
227
- )
228
-
229
- # Vector store object and configuration
230
- self.vector_store = vector_store
231
- self.vector_store_config = self.config.get(
232
- VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
233
- )
234
-
235
- # Embeddings configuration for encoding the search query
236
- self.embeddings_config = (
237
- self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
238
- )
239
176
 
240
177
  # LLM Configuration for response generation
241
178
  self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
@@ -243,9 +180,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
243
180
  # Maximum number of turns to include in the prompt
244
181
  self.max_history = self.config.get(POLICY_MAX_HISTORY)
245
182
 
246
- # Maximum number of messages to include in the search query
247
- self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
248
-
249
183
  # boolean to enable/disable tracing of prompt tokens
250
184
  self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
251
185
 
@@ -267,25 +201,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
267
201
  if self.citation_enabled:
268
202
  self.prompt_template = self.citation_prompt_template
269
203
 
270
- @classmethod
271
- def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
272
- """Creates an embedder based on the given configuration.
273
-
274
- Returns:
275
- The embedder.
276
- """
277
- # Copy the config so original config is not modified
278
- config = config.copy()
279
- # Resolve config and instantiate the embedding client
280
- config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
281
- config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
282
- )
283
- client = embedder_factory(
284
- config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
285
- )
286
- # Wrap the embedding client in the adapter
287
- return _LangchainEmbeddingClientAdapter(client)
288
-
289
204
  @classmethod
290
205
  def _add_prompt_and_llm_response_to_latest_message(
291
206
  cls,
@@ -350,52 +265,24 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
350
265
  # Perform health checks for both LLM and embeddings client configs
351
266
  self._perform_health_checks(self.config, "enterprise_search_policy.train")
352
267
 
353
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
268
+ # # telemetry call to track training start
269
+ # track_enterprise_search_policy_train_started()
270
+ # # telemetry call to track training completion
271
+ # track_enterprise_search_policy_train_completed(
272
+ # vector_store_type=store_type,
273
+ # embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
274
+ # embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
275
+ # or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
276
+ # embeddings_model_group_id=self.embeddings_config.get(
277
+ # MODEL_GROUP_ID_CONFIG_KEY
278
+ # ),
279
+ # llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
280
+ # llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
281
+ # or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
282
+ # llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
283
+ # citation_enabled=self.citation_enabled,
284
+ # )
354
285
 
355
- # telemetry call to track training start
356
- track_enterprise_search_policy_train_started()
357
-
358
- # validate embedding configuration
359
- try:
360
- embeddings = self._create_plain_embedder(self.config)
361
- except (ValidationError, Exception) as e:
362
- logger.error(
363
- "enterprise_search_policy.train.embedder_instantiation_failed",
364
- message="Unable to instantiate the embedding client.",
365
- error=e,
366
- )
367
- print_error_and_exit(
368
- "Unable to create embedder. Please make sure you specified the "
369
- f"required environment variables. Error: {e}"
370
- )
371
-
372
- if store_type == DEFAULT_VECTOR_STORE_TYPE:
373
- logger.info("enterprise_search_policy.train.faiss")
374
- with self._model_storage.write_to(self._resource) as path:
375
- self.vector_store = FAISS_Store(
376
- docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
377
- embeddings=embeddings,
378
- index_path=path,
379
- create_index=True,
380
- )
381
- else:
382
- logger.info("enterprise_search_policy.train.custom", store_type=store_type)
383
-
384
- # telemetry call to track training completion
385
- track_enterprise_search_policy_train_completed(
386
- vector_store_type=store_type,
387
- embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
388
- embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
389
- or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
390
- embeddings_model_group_id=self.embeddings_config.get(
391
- MODEL_GROUP_ID_CONFIG_KEY
392
- ),
393
- llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
394
- llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
395
- or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
396
- llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
397
- citation_enabled=self.citation_enabled,
398
- )
399
286
  self.persist()
400
287
  return self._resource
401
288
 
@@ -432,60 +319,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
432
319
  )
433
320
  return template_slots
434
321
 
435
- def _connect_vector_store_or_raise(
436
- self, endpoints: Optional[AvailableEndpoints]
437
- ) -> None:
438
- """Connects to the vector store or raises an exception.
439
-
440
- Raise exceptions for the following cases:
441
- - The configuration is not specified
442
- - Unable to connect to the vector store
443
-
444
- Args:
445
- endpoints: Endpoints configuration.
446
- """
447
- config = endpoints.vector_store if endpoints else None
448
- store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
449
- if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
450
- logger.error(
451
- "enterprise_search_policy._connect_vector_store_or_raise.no_config"
452
- )
453
- raise VectorStoreConfigurationError(
454
- """No vector store specified. Please specify a vector
455
- store in the endpoints configuration"""
456
- )
457
- try:
458
- self.vector_store.connect(config) # type: ignore
459
- except Exception as e:
460
- logger.error(
461
- "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
462
- error=e,
463
- config=config,
464
- )
465
- raise VectorStoreConnectionError(
466
- f"Unable to connect to the vector store. Error: {e}"
467
- )
468
-
469
- def _prepare_search_query(self, tracker: DialogueStateTracker, history: int) -> str:
470
- """Prepares the search query.
471
- The search query is the last N messages in the conversation history.
472
-
473
- Args:
474
- tracker: The tracker containing the conversation history up to now.
475
- history: The number of messages to include in the search query.
476
-
477
- Returns:
478
- The search query.
479
- """
480
- transcript = []
481
- for event in tracker.applied_events():
482
- if isinstance(event, UserUttered) or isinstance(event, BotUttered):
483
- transcript.append(sanitize_message_for_prompt(event.text))
484
-
485
- search_query = " ".join(transcript[-history:][::-1])
486
- logger.debug("search_query", search_query=search_query)
487
- return search_query
488
-
489
322
  async def predict_action_probabilities( # type: ignore[override]
490
323
  self,
491
324
  tracker: DialogueStateTracker,
@@ -509,49 +342,37 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
509
342
  The prediction.
510
343
  """
511
344
  logger_key = "enterprise_search_policy.predict_action_probabilities"
512
- vector_search_threshold = self.vector_store_config.get(
513
- VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
514
- )
515
- llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
345
+
516
346
  if not self.supports_current_stack_frame(
517
347
  tracker, False, False
518
348
  ) or self.should_abstain_in_coexistence(tracker, True):
519
349
  return self._prediction(self._default_predictions(domain))
520
350
 
521
- if not self.vector_store:
522
- logger.error(f"{logger_key}.no_vector_store")
523
- return self._create_prediction_internal_error(domain, tracker)
524
-
525
- try:
526
- self._connect_vector_store_or_raise(endpoints)
527
- except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
528
- logger.error(f"{logger_key}.connection_error", error=e)
529
- return self._create_prediction_internal_error(domain, tracker)
351
+ # retrieve documents from the latest message
352
+ # document retrieval happened earlier in the pipeline
353
+ if tracker.latest_message is None or tracker.latest_message.parse_data is None:
354
+ logger.info(f"{logger_key}.no_documents")
355
+ return self._create_prediction_cannot_handle(domain, tracker)
530
356
 
531
- search_query = self._prepare_search_query(
532
- tracker, int(self.max_messages_in_query)
357
+ documents_data = tracker.latest_message.parse_data.get(
358
+ POST_PROCESSED_DOCUMENTS_KEY
533
359
  )
534
- tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
535
360
 
536
- try:
537
- documents = await self.vector_store.search(
538
- query=search_query,
539
- tracker_state=tracker_state,
540
- threshold=vector_search_threshold,
541
- )
542
- except InformationRetrievalException as e:
543
- logger.error(f"{logger_key}.search_error", error=e)
544
- return self._create_prediction_internal_error(domain, tracker)
361
+ if not documents_data:
362
+ logger.info(f"{logger_key}.no_documents")
363
+ return self._create_prediction_cannot_handle(domain, tracker)
364
+
365
+ documents = SearchResultList.from_dict(documents_data)
545
366
 
546
367
  if not documents.results:
547
368
  logger.info(f"{logger_key}.no_documents")
548
369
  return self._create_prediction_cannot_handle(domain, tracker)
549
370
 
371
+ llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
372
+
550
373
  if self.use_llm:
551
374
  prompt = self._render_prompt(tracker, documents.results)
552
- llm_response = await self._generate_llm_answer(
553
- llm, prompt, tracker.sender_id
554
- )
375
+ llm_response = await self._generate_llm_answer(llm, prompt)
555
376
  llm_response = LLMResponse.ensure_llm_response(llm_response)
556
377
 
557
378
  self._add_prompt_and_llm_response_to_latest_message(
@@ -593,25 +414,29 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
593
414
  result.text for result in documents.results
594
415
  ],
595
416
  UTTER_SOURCE_METADATA_KEY: self.__class__.__name__,
596
- SEARCH_QUERY_METADATA_KEY: search_query,
417
+ SEARCH_QUERY_METADATA_KEY: tracker.latest_message.parse_data.get(
418
+ SEARCH_QUERY_KEY
419
+ ),
597
420
  }
598
421
  }
599
422
 
600
- # telemetry call to track policy prediction
601
- track_enterprise_search_policy_predict(
602
- vector_store_type=self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY),
603
- embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
604
- embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
605
- or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
606
- embeddings_model_group_id=self.embeddings_config.get(
607
- MODEL_GROUP_ID_CONFIG_KEY
608
- ),
609
- llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
610
- llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
611
- or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
612
- llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
613
- citation_enabled=self.citation_enabled,
614
- )
423
+ # # telemetry call to track policy prediction
424
+ # track_enterprise_search_policy_predict(
425
+ # vector_store_type=self.vector_store_config.get(
426
+ # VECTOR_STORE_TYPE_PROPERTY),
427
+ # embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
428
+ # embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
429
+ # or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
430
+ # embeddings_model_group_id=self.embeddings_config.get(
431
+ # MODEL_GROUP_ID_CONFIG_KEY
432
+ # ),
433
+ # llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
434
+ # llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
435
+ # or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
436
+ # llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
437
+ # citation_enabled=self.citation_enabled,
438
+ # )
439
+
615
440
  return self._create_prediction(
616
441
  domain=domain, tracker=tracker, action_metadata=action_metadata
617
442
  )
@@ -647,26 +472,19 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
647
472
 
648
473
  @measure_llm_latency
649
474
  async def _generate_llm_answer(
650
- self, llm: LLMClient, prompt: Text, sender_id: str
475
+ self, llm: LLMClient, prompt: Text
651
476
  ) -> Optional[LLMResponse]:
652
477
  """Fetches an LLM completion for the provided prompt.
653
478
 
654
479
  Args:
655
480
  llm: The LLM client used to get the completion.
656
481
  prompt: The prompt text to send to the model.
657
- sender_id: sender_id from the tracker.
658
482
 
659
483
  Returns:
660
484
  An LLMResponse object, or None if the call fails.
661
485
  """
662
- metadata = {
663
- LANGFUSE_METADATA_USER_ID: self.user_id,
664
- LANGFUSE_METADATA_SESSION_ID: sender_id,
665
- LANGFUSE_CUSTOM_METADATA_DICT: {"component": self.__class__.__name__},
666
- LANGFUSE_TAGS: [self.__class__.__name__],
667
- }
668
486
  try:
669
- return await llm.acompletion(prompt, metadata)
487
+ return await llm.acompletion(prompt)
670
488
  except Exception as e:
671
489
  # unfortunately, langchain does not wrap LLM exceptions which means
672
490
  # we have to catch all exceptions here
@@ -786,73 +604,19 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
786
604
  "enterprise_search_policy.load.failed", error=e, resource=resource.name
787
605
  )
788
606
 
789
- store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
790
- VECTOR_STORE_TYPE_PROPERTY
791
- )
792
-
793
- embeddings = cls._create_plain_embedder(config)
794
-
795
607
  logger.info("enterprise_search_policy.load", config=config)
796
- if store_type == DEFAULT_VECTOR_STORE_TYPE:
797
- # if a vector store is not specified,
798
- # default to using FAISS with the index stored in the model
799
- # TODO figure out a way to get path without context manager
800
- with model_storage.read_from(resource) as path:
801
- vector_store = FAISS_Store(
802
- embeddings=embeddings,
803
- index_path=path,
804
- docs_folder=None,
805
- create_index=False,
806
- )
807
- else:
808
- vector_store = create_from_endpoint_config(
809
- config_type=store_type,
810
- embeddings=embeddings,
811
- ) # type: ignore
812
608
 
813
609
  return cls(
814
610
  config,
815
611
  model_storage,
816
612
  resource,
817
613
  execution_context,
818
- vector_store=vector_store,
819
614
  prompt_template=prompt_template,
820
615
  )
821
616
 
822
- @classmethod
823
- def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
824
- """This is required only for local knowledge base types.
825
-
826
- e.g. FAISS, to ensure that the graph component is retrained when the knowledge
827
- base is updated.
828
- """
829
- merged_config = {**cls.get_default_config(), **config}
830
-
831
- store_type = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(
832
- VECTOR_STORE_TYPE_PROPERTY
833
- )
834
- if store_type != DEFAULT_VECTOR_STORE_TYPE:
835
- return None
836
-
837
- source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
838
- if not source:
839
- return None
840
-
841
- docs = FAISS_Store.load_documents(source)
842
-
843
- if len(docs) == 0:
844
- return None
845
-
846
- docs_as_strings = [
847
- json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
848
- ]
849
- return sorted(docs_as_strings)
850
-
851
617
  @classmethod
852
618
  def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
853
619
  """Add a fingerprint of enterprise search policy for the graph."""
854
- local_knowledge_data = cls._get_local_knowledge_data(config)
855
-
856
620
  prompt_template = get_prompt_template(
857
621
  config.get(PROMPT_CONFIG_KEY),
858
622
  DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
@@ -861,12 +625,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
861
625
  llm_config = resolve_model_client_config(
862
626
  config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
863
627
  )
864
- embedding_config = resolve_model_client_config(
865
- config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
866
- )
867
- return deep_container_fingerprint(
868
- [prompt_template, local_knowledge_data, llm_config, embedding_config]
869
- )
628
+ return deep_container_fingerprint([prompt_template, llm_config])
870
629
 
871
630
  @staticmethod
872
631
  def post_process_citations(llm_answer: str) -> str:
@@ -971,14 +730,3 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
971
730
  log_source_method,
972
731
  EnterpriseSearchPolicy.__name__,
973
732
  )
974
-
975
- # Perform health check of the embeddings client config
976
- embeddings_config = resolve_model_client_config(
977
- config.get(EMBEDDINGS_CONFIG_KEY, {})
978
- )
979
- cls.perform_embeddings_health_check(
980
- embeddings_config,
981
- DEFAULT_EMBEDDINGS_CONFIG,
982
- log_source_method,
983
- EnterpriseSearchPolicy.__name__,
984
- )
@@ -23,7 +23,6 @@ from rasa.core.policies.flows.flow_step_result import (
23
23
  )
24
24
  from rasa.dialogue_understanding.commands import CancelFlowCommand
25
25
  from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
26
- from rasa.dialogue_understanding.patterns.clarify import ClarifyPatternFlowStackFrame
27
26
  from rasa.dialogue_understanding.patterns.collect_information import (
28
27
  CollectInformationPatternFlowStackFrame,
29
28
  )
@@ -51,7 +50,6 @@ from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
51
50
  )
52
51
  from rasa.dialogue_understanding.stack.utils import (
53
52
  top_user_flow_frame,
54
- user_flows_on_the_stack,
55
53
  )
56
54
  from rasa.shared.constants import RASA_PATTERN_HUMAN_HANDOFF
57
55
  from rasa.shared.core.constants import (
@@ -280,33 +278,6 @@ def trigger_pattern_continue_interrupted(
280
278
  return events
281
279
 
282
280
 
283
- def trigger_pattern_clarification(
284
- current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
285
- ) -> None:
286
- """Trigger the pattern to clarify which topic to continue if needed."""
287
- if not isinstance(current_frame, UserFlowStackFrame):
288
- return None
289
-
290
- if current_frame.frame_type in [
291
- FlowStackFrameType.CALL,
292
- FlowStackFrameType.INTERRUPT,
293
- ]:
294
- # we want to return to the flow that called
295
- # the current flow or the flow that was interrupted
296
- # by the current flow
297
- return None
298
-
299
- pending_flows = [
300
- flows.flow_by_id(frame.flow_id)
301
- for frame in stack.frames
302
- if isinstance(frame, UserFlowStackFrame)
303
- and frame.flow_id != current_frame.flow_id
304
- ]
305
-
306
- flow_names = [flow.readable_name() for flow in pending_flows if flow is not None]
307
- stack.push(ClarifyPatternFlowStackFrame(names=flow_names))
308
-
309
-
310
281
  def trigger_pattern_completed(
311
282
  current_frame: DialogueStackFrame, stack: DialogueStack, flows: FlowsList
312
283
  ) -> None:
@@ -675,15 +646,9 @@ def _run_end_step(
675
646
  structlogger.debug("flow.step.run.flow_end")
676
647
  current_frame = stack.pop()
677
648
  trigger_pattern_completed(current_frame, stack, flows)
678
- resumed_events = []
679
- if len(user_flows_on_the_stack(stack)) > 1:
680
- # if there are more user flows on the stack,
681
- # we need to trigger the pattern clarify
682
- trigger_pattern_clarification(current_frame, stack, flows)
683
- else:
684
- resumed_events = trigger_pattern_continue_interrupted(
685
- current_frame, stack, flows, tracker
686
- )
649
+ resumed_events = trigger_pattern_continue_interrupted(
650
+ current_frame, stack, flows, tracker
651
+ )
687
652
  reset_events: List[Event] = reset_scoped_slots(current_frame, flow, tracker)
688
653
  return ContinueFlowWithNextStep(
689
654
  events=initial_events + reset_events + resumed_events, has_flow_ended=True