rasa-pro 3.12.18.dev1__py3-none-any.whl → 3.12.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (53)
  1. rasa/__init__.py +0 -6
  2. rasa/core/actions/action.py +2 -5
  3. rasa/core/actions/action_repeat_bot_messages.py +18 -22
  4. rasa/core/channels/voice_stream/asr/asr_engine.py +5 -1
  5. rasa/core/channels/voice_stream/asr/azure.py +9 -0
  6. rasa/core/channels/voice_stream/asr/deepgram.py +5 -0
  7. rasa/core/channels/voice_stream/audiocodes.py +9 -4
  8. rasa/core/channels/voice_stream/twilio_media_streams.py +7 -0
  9. rasa/core/channels/voice_stream/voice_channel.py +47 -9
  10. rasa/core/policies/enterprise_search_policy.py +196 -72
  11. rasa/core/policies/intentless_policy.py +1 -3
  12. rasa/core/processor.py +50 -5
  13. rasa/core/utils.py +11 -2
  14. rasa/dialogue_understanding/coexistence/llm_based_router.py +1 -0
  15. rasa/dialogue_understanding/commands/__init__.py +4 -0
  16. rasa/dialogue_understanding/commands/cancel_flow_command.py +3 -1
  17. rasa/dialogue_understanding/commands/correct_slots_command.py +0 -10
  18. rasa/dialogue_understanding/commands/set_slot_command.py +6 -0
  19. rasa/dialogue_understanding/commands/utils.py +26 -2
  20. rasa/dialogue_understanding/generator/command_generator.py +15 -5
  21. rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -15
  22. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
  23. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +4 -44
  24. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +1 -14
  25. rasa/dialogue_understanding/processor/command_processor.py +23 -16
  26. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +17 -4
  27. rasa/dialogue_understanding/stack/utils.py +3 -1
  28. rasa/dialogue_understanding/utils.py +68 -12
  29. rasa/dialogue_understanding_test/du_test_schema.yml +3 -3
  30. rasa/e2e_test/e2e_test_coverage_report.py +1 -1
  31. rasa/e2e_test/e2e_test_schema.yml +3 -3
  32. rasa/hooks.py +0 -55
  33. rasa/llm_fine_tuning/annotation_module.py +43 -11
  34. rasa/llm_fine_tuning/utils.py +2 -4
  35. rasa/shared/constants.py +0 -5
  36. rasa/shared/core/constants.py +1 -0
  37. rasa/shared/core/flows/constants.py +2 -0
  38. rasa/shared/core/flows/flow.py +129 -13
  39. rasa/shared/core/flows/flows_list.py +18 -1
  40. rasa/shared/core/flows/steps/link.py +7 -2
  41. rasa/shared/providers/constants.py +0 -9
  42. rasa/shared/providers/llm/_base_litellm_client.py +4 -14
  43. rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
  44. rasa/shared/providers/llm/llm_client.py +15 -24
  45. rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
  46. rasa/tracing/instrumentation/attribute_extractors.py +2 -2
  47. rasa/version.py +1 -1
  48. {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/METADATA +3 -4
  49. {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/RECORD +52 -53
  50. rasa/monkey_patches.py +0 -91
  51. {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/NOTICE +0 -0
  52. {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/WHEEL +0 -0
  53. {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/entry_points.txt +0 -0
rasa/__init__.py CHANGED
@@ -5,11 +5,5 @@ from rasa import version
5
5
  # define the version before the other imports since these need it
6
6
  __version__ = version.__version__
7
7
 
8
- from litellm.integrations.langfuse.langfuse import LangFuseLogger
9
-
10
- from rasa.monkey_patches import litellm_langfuse_logger_init_fixed
11
-
12
- # Monkey-patch the init method as early as possible before the class is used
13
- LangFuseLogger.__init__ = litellm_langfuse_logger_init_fixed # type: ignore
14
8
 
15
9
  logging.getLogger(__name__).addHandler(logging.NullHandler())
@@ -898,7 +898,7 @@ class RemoteAction(Action):
898
898
  draft["buttons"].extend(buttons)
899
899
 
900
900
  # Avoid overwriting `draft` values with empty values
901
- response = {k: v for k, v in response.items() if v is not None}
901
+ response = {k: v for k, v in response.items() if v}
902
902
  draft.update(response)
903
903
  bot_messages.append(create_bot_utterance(draft))
904
904
 
@@ -1137,15 +1137,12 @@ class ActionSendText(Action):
1137
1137
  tracker: "DialogueStateTracker",
1138
1138
  domain: "Domain",
1139
1139
  metadata: Optional[Dict[Text, Any]] = None,
1140
- create_bot_uttered_event: bool = True,
1141
1140
  ) -> List[Event]:
1142
1141
  """Runs action. Please see parent class for the full docstring."""
1143
1142
  fallback = {"text": ""}
1144
1143
  metadata_copy = copy.deepcopy(metadata) if metadata else {}
1145
1144
  message = metadata_copy.get("message", fallback)
1146
- if create_bot_uttered_event:
1147
- return [create_bot_utterance(message)]
1148
- return []
1145
+ return [create_bot_utterance(message)]
1149
1146
 
1150
1147
 
1151
1148
  class ActionExtractSlots(Action):
@@ -25,7 +25,7 @@ class ActionRepeatBotMessages(Action):
25
25
  """Return the name of the action."""
26
26
  return ACTION_REPEAT_BOT_MESSAGES
27
27
 
28
- def _get_last_bot_events(self, tracker: DialogueStateTracker) -> List[Event]:
28
+ def _get_last_bot_events(self, tracker: DialogueStateTracker) -> List[BotUttered]:
29
29
  """Get the last consecutive bot events before the most recent user message.
30
30
 
31
31
  This function scans the dialogue history in reverse to find the last sequence of
@@ -48,33 +48,21 @@ class ActionRepeatBotMessages(Action):
48
48
  The elif condition doesn't break when it sees User3 event.
49
49
  But it does at User2 event.
50
50
  """
51
- # Skip action if we are in a collect information step whose
52
- # default behavior is to repeat anyways
53
- top_frame = tracker.stack.top(
54
- lambda frame: isinstance(frame, RepeatBotMessagesPatternFlowStackFrame)
55
- or isinstance(frame, UserSilencePatternFlowStackFrame)
56
- )
57
- if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
58
- return []
59
51
  # filter user and bot events
60
- filtered = [
52
+ user_and_bot_events = [
61
53
  e for e in tracker.events if isinstance(e, (BotUttered, UserUttered))
62
54
  ]
63
- bot_events: List[Event] = []
55
+ last_bot_events: List[BotUttered] = []
64
56
 
65
57
  # find the last BotUttered events
66
- for e in reversed(filtered):
67
- if isinstance(e, BotUttered):
68
- # insert instead of append because the list is reversed
69
- bot_events.insert(0, e)
70
-
71
- # stop if a UserUttered event is found
72
- # only if we have collected some bot events already
73
- # this condition skips the first N UserUttered events
74
- elif bot_events:
58
+ for e in reversed(user_and_bot_events):
59
+ # stop when seeing a user event after having seen bot events already
60
+ if isinstance(e, UserUttered) and len(last_bot_events) > 0:
75
61
  break
62
+ elif isinstance(e, BotUttered):
63
+ last_bot_events.append(e)
76
64
 
77
- return bot_events
65
+ return list(reversed(last_bot_events))
78
66
 
79
67
  async def run(
80
68
  self,
@@ -85,5 +73,13 @@ class ActionRepeatBotMessages(Action):
85
73
  metadata: Optional[Dict[str, Any]] = None,
86
74
  ) -> List[Event]:
87
75
  """Send the last bot messages to the channel again"""
88
- bot_events = self._get_last_bot_events(tracker)
76
+ top_frame = tracker.stack.top(
77
+ lambda frame: isinstance(frame, RepeatBotMessagesPatternFlowStackFrame)
78
+ or isinstance(frame, UserSilencePatternFlowStackFrame)
79
+ )
80
+
81
+ bot_events: List[Event] = list(self._get_last_bot_events(tracker))
82
+ # drop the last bot event in a collect step as that part will be repeated anyway
83
+ if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
84
+ bot_events = bot_events[:-1]
89
85
  return bot_events
@@ -26,7 +26,7 @@ logger = structlog.get_logger(__name__)
26
26
 
27
27
  @dataclass
28
28
  class ASREngineConfig(MergeableConfig):
29
- pass
29
+ keep_alive_interval: int = 5 # seconds
30
30
 
31
31
 
32
32
  class ASREngine(Generic[T]):
@@ -93,3 +93,7 @@ class ASREngine(Generic[T]):
93
93
  def get_default_config() -> T:
94
94
  """Get the default config for this component."""
95
95
  raise NotImplementedError
96
+
97
+ async def send_keep_alive(self) -> None:
98
+ """Send a keep-alive message to the ASR system if supported."""
99
+ pass
@@ -3,6 +3,8 @@ import os
3
3
  from dataclasses import dataclass
4
4
  from typing import Any, AsyncIterator, Dict, Optional
5
5
 
6
+ import structlog
7
+
6
8
  from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineConfig
7
9
  from rasa.core.channels.voice_stream.asr.asr_event import (
8
10
  ASREvent,
@@ -13,6 +15,8 @@ from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
13
15
  from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
14
16
  from rasa.shared.exceptions import ConnectionException
15
17
 
18
+ logger = structlog.get_logger(__name__)
19
+
16
20
 
17
21
  @dataclass
18
22
  class AzureASRConfig(ASREngineConfig):
@@ -61,6 +65,11 @@ class AzureASR(ASREngine[AzureASRConfig]):
61
65
  and self.config.speech_endpoint is None
62
66
  ):
63
67
  self.config.speech_region = "eastus"
68
+ logger.warning(
69
+ "voice_channel.asr.azure.no_region",
70
+ message="No speech region configured, using 'eastus' as default",
71
+ region="eastus",
72
+ )
64
73
  speech_config = speechsdk.SpeechConfig(
65
74
  subscription=os.environ[AZURE_SPEECH_API_KEY_ENV_VAR],
66
75
  region=self.config.speech_region,
@@ -145,3 +145,8 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
145
145
  def concatenate_transcripts(t1: str, t2: str) -> str:
146
146
  """Concatenate two transcripts making sure there is a space between them."""
147
147
  return (t1.strip() + " " + t2.strip()).strip()
148
+
149
+ async def send_keep_alive(self) -> None:
150
+ """Send a keep-alive message to the Deepgram websocket connection."""
151
+ if self.asr_socket is not None:
152
+ await self.asr_socket.send(json.dumps({"type": "KeepAlive"}))
@@ -81,6 +81,12 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
81
81
  logger.debug("Sending start marker", stream_id=self._get_stream_id())
82
82
  await self.voice_websocket.send(media_message)
83
83
 
84
+ # This should be set when the bot actually starts speaking
85
+ # however, Audiocodes does not have an event to indicate that.
86
+ # This is an approximation, as the bot will be sent the audio chunks next
87
+ # which are played to the user immediately.
88
+ call_state.is_bot_speaking = True # type: ignore[attr-defined]
89
+
84
90
  async def send_intermediate_marker(self, recipient_id: str) -> None:
85
91
  """Audiocodes doesn't need intermediate markers, so do nothing."""
86
92
  pass
@@ -173,21 +179,20 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
173
179
  if data["type"] == "activities":
174
180
  activities = data["activities"]
175
181
  for activity in activities:
176
- logger.debug("audiocodes_stream.activity", data=activity)
177
182
  if activity["name"] == "start":
178
- # already handled in collect_call_parameters
183
+ # handled in collect_call_parameters
179
184
  pass
180
185
  elif activity["name"] == "dtmf":
181
- # TODO: handle DTMF input
186
+ logger.info("audiocodes_stream.dtmf_ignored", data=activity)
182
187
  pass
183
188
  elif activity["name"] == "playFinished":
184
189
  logger.debug("audiocodes_stream.playFinished", data=activity)
190
+ call_state.is_bot_speaking = False # type: ignore[attr-defined]
185
191
  if call_state.should_hangup:
186
192
  logger.info("audiocodes_stream.hangup")
187
193
  self._send_hangup(ws, data)
188
194
  # the conversation should continue until
189
195
  # we receive a end message from audiocodes
190
- pass
191
196
  else:
192
197
  logger.warning("audiocodes_stream.unknown_activity", data=activity)
193
198
  elif data["type"] == "userStream.start":
@@ -135,6 +135,13 @@ class TwilioMediaStreamsInputChannel(VoiceInputChannel):
135
135
  def name(cls) -> str:
136
136
  return "twilio_media_streams"
137
137
 
138
+ def get_sender_id(self, call_parameters: CallParameters) -> str:
139
+ """Get the sender ID for the channel.
140
+
141
+ Twilio Media Streams uses the Stream ID as Sender ID because
142
+ it is required in OutputChannel.send_text_message to send messages."""
143
+ return call_parameters.stream_id # type: ignore[return-value]
144
+
138
145
  def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
139
146
  return RasaAudioBytes(base64.b64decode(input_bytes))
140
147
 
@@ -288,6 +288,17 @@ class VoiceInputChannel(InputChannel):
288
288
  self.monitor_silence = monitor_silence
289
289
  self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
290
290
 
291
+ logger.info(
292
+ "voice_channel.initialized",
293
+ server_url=self.server_url,
294
+ asr_config=self.asr_config,
295
+ tts_config=self.tts_config,
296
+ )
297
+
298
+ def get_sender_id(self, call_parameters: CallParameters) -> str:
299
+ """Get the sender ID for the channel."""
300
+ return call_parameters.call_id
301
+
291
302
  async def monitor_silence_timeout(self, asr_event_queue: asyncio.Queue) -> None:
292
303
  timeout = call_state.silence_timeout
293
304
  if not timeout:
@@ -334,9 +345,9 @@ class VoiceInputChannel(InputChannel):
334
345
  ) -> None:
335
346
  output_channel = self.create_output_channel(channel_websocket, tts_engine)
336
347
  message = UserMessage(
337
- "/session_start",
338
- output_channel,
339
- call_parameters.stream_id,
348
+ text="/session_start",
349
+ output_channel=output_channel,
350
+ sender_id=self.get_sender_id(call_parameters),
340
351
  input_channel=self.name(),
341
352
  metadata=asdict(call_parameters),
342
353
  )
@@ -393,6 +404,9 @@ class VoiceInputChannel(InputChannel):
393
404
  await asr_engine.send_audio_chunks(channel_action.audio_bytes)
394
405
  elif isinstance(channel_action, EndConversationAction):
395
406
  # end stream event came from the other side
407
+ await self.handle_disconnect(
408
+ channel_websocket, on_new_message, tts_engine, call_parameters
409
+ )
396
410
  break
397
411
 
398
412
  async def receive_asr_events() -> None:
@@ -410,10 +424,17 @@ class VoiceInputChannel(InputChannel):
410
424
  call_parameters,
411
425
  )
412
426
 
427
+ async def asr_keep_alive_task() -> None:
428
+ interval = getattr(asr_engine.config, "keep_alive_interval", 5)
429
+ while True:
430
+ await asyncio.sleep(interval)
431
+ await asr_engine.send_keep_alive()
432
+
413
433
  tasks = [
414
434
  asyncio.create_task(consume_audio_bytes()),
415
435
  asyncio.create_task(receive_asr_events()),
416
436
  asyncio.create_task(handle_asr_events()),
437
+ asyncio.create_task(asr_keep_alive_task()),
417
438
  ]
418
439
  await asyncio.wait(
419
440
  tasks,
@@ -449,9 +470,9 @@ class VoiceInputChannel(InputChannel):
449
470
  call_state.is_user_speaking = False # type: ignore[attr-defined]
450
471
  output_channel = self.create_output_channel(voice_websocket, tts_engine)
451
472
  message = UserMessage(
452
- e.text,
453
- output_channel,
454
- call_parameters.stream_id,
473
+ text=e.text,
474
+ output_channel=output_channel,
475
+ sender_id=self.get_sender_id(call_parameters),
455
476
  input_channel=self.name(),
456
477
  metadata=asdict(call_parameters),
457
478
  )
@@ -462,10 +483,27 @@ class VoiceInputChannel(InputChannel):
462
483
  elif isinstance(e, UserSilence):
463
484
  output_channel = self.create_output_channel(voice_websocket, tts_engine)
464
485
  message = UserMessage(
465
- "/silence_timeout",
466
- output_channel,
467
- call_parameters.stream_id,
486
+ text="/silence_timeout",
487
+ output_channel=output_channel,
488
+ sender_id=self.get_sender_id(call_parameters),
468
489
  input_channel=self.name(),
469
490
  metadata=asdict(call_parameters),
470
491
  )
471
492
  await on_new_message(message)
493
+
494
+ async def handle_disconnect(
495
+ self,
496
+ channel_websocket: Websocket,
497
+ on_new_message: Callable[[UserMessage], Awaitable[Any]],
498
+ tts_engine: TTSEngine,
499
+ call_parameters: CallParameters,
500
+ ) -> None:
501
+ """Handle disconnection from the channel."""
502
+ output_channel = self.create_output_channel(channel_websocket, tts_engine)
503
+ message = UserMessage(
504
+ text="/session_end",
505
+ output_channel=output_channel,
506
+ sender_id=self.get_sender_id(call_parameters),
507
+ input_channel=self.name(),
508
+ )
509
+ await on_new_message(message)
@@ -1,7 +1,9 @@
1
+ import glob
1
2
  import importlib.resources
2
3
  import json
4
+ import os.path
3
5
  import re
4
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
5
7
 
6
8
  import dotenv
7
9
  import structlog
@@ -162,6 +164,8 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
162
164
  "rasa.core.policies", "enterprise_search_prompt_with_citation_template.jinja2"
163
165
  )
164
166
 
167
+ _ENTERPRISE_SEARCH_CITATION_PATTERN = re.compile(r"\[([^\]]+)\]")
168
+
165
169
 
166
170
  class VectorStoreConnectionError(RasaException):
167
171
  """Exception raised for errors in connecting to the vector store."""
@@ -378,9 +382,11 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
378
382
 
379
383
  if store_type == DEFAULT_VECTOR_STORE_TYPE:
380
384
  logger.info("enterprise_search_policy.train.faiss")
385
+ docs_folder = self.vector_store_config.get(SOURCE_PROPERTY)
386
+ self._validate_documents_folder(docs_folder)
381
387
  with self._model_storage.write_to(self._resource) as path:
382
388
  self.vector_store = FAISS_Store(
383
- docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
389
+ docs_folder=docs_folder,
384
390
  embeddings=embeddings,
385
391
  index_path=path,
386
392
  create_index=True,
@@ -760,6 +766,33 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
760
766
  result[domain.index_for_action(action_name)] = score # type: ignore[assignment]
761
767
  return result
762
768
 
769
+ @classmethod
770
+ def _validate_documents_folder(cls, docs_folder: str) -> None:
771
+ if not os.path.exists(docs_folder) or not os.path.isdir(docs_folder):
772
+ error_message = (
773
+ f"Document source directory does not exist or is not a "
774
+ f"directory: '{docs_folder}'. "
775
+ "Please specify a valid path to the documents source directory in the "
776
+ "vector_store configuration."
777
+ )
778
+ logger.error(
779
+ "enterprise_search_policy.train.faiss.invalid_source_directory",
780
+ message=error_message,
781
+ )
782
+ print_error_and_exit(error_message)
783
+
784
+ docs = glob.glob(os.path.join(docs_folder, "*.txt"), recursive=True)
785
+ if not docs or len(docs) < 1:
786
+ error_message = (
787
+ f"Document source directory is empty: '{docs_folder}'. "
788
+ "Please add documents to this directory or specify a different one."
789
+ )
790
+ logger.error(
791
+ "enterprise_search_policy.train.faiss.source_directory_empty",
792
+ message=error_message,
793
+ )
794
+ print_error_and_exit(error_message)
795
+
763
796
  @classmethod
764
797
  def load(
765
798
  cls,
@@ -833,7 +866,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
833
866
  return None
834
867
 
835
868
  source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
836
- if not source:
869
+ if not source or not os.path.exists(source) or not os.path.isdir(source):
837
870
  return None
838
871
 
839
872
  docs = FAISS_Store.load_documents(source)
@@ -870,10 +903,18 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
870
903
 
871
904
  @staticmethod
872
905
  def post_process_citations(llm_answer: str) -> str:
873
- """Post-process the LLM answer.
874
-
875
- Re-writes the bracketed numbers to start from 1 and
876
- re-arranges the sources to follow the enumeration order.
906
+ """Post-processes the LLM answer to correctly number and sort citations and
907
+ sources.
908
+
909
+ - Handles both single `[1]` and grouped `[1, 3]` citations.
910
+ - Rewrites the numbers in square brackets in the answer text to start from 1
911
+ and be sorted within each group.
912
+ - Reorders the sources according to the order of their first appearance
913
+ in the text.
914
+ - Removes citations from the text that point to sources missing from
915
+ the source list.
916
+ - Keeps sources that are not cited in the text, placing them at the end
917
+ of the list.
877
918
 
878
919
  Args:
879
920
  llm_answer: The LLM answer.
@@ -887,77 +928,160 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
887
928
 
888
929
  # Split llm_answer into answer and citations
889
930
  try:
890
- answer, citations = llm_answer.rsplit("Sources:", 1)
931
+ answer_part, sources_part = llm_answer.rsplit("Sources:", 1)
891
932
  except ValueError:
892
- # if there is no "Sources:" in the llm_answer
893
- return llm_answer
894
-
895
- # Find all source references in the answer
896
- pattern = r"\[\s*(\d+(?:\s*,\s*\d+)*)\s*\]"
897
- matches = re.findall(pattern, answer)
898
- old_source_indices = [
899
- int(num.strip()) for match in matches for num in match.split(",")
900
- ]
933
+ # if there is no "Sources:" separator, return the original llm_answer
934
+ return llm_answer.strip()
935
+
936
+ # Parse the sources block to extract valid sources and other lines
937
+ valid_sources, other_source_lines = EnterpriseSearchPolicy._parse_sources_block(
938
+ sources_part
939
+ )
940
+
941
+ # Find all unique, valid citations in the answer text in their order
942
+ # of appearance
943
+ cited_order = EnterpriseSearchPolicy._get_cited_order(
944
+ answer_part, valid_sources
945
+ )
946
+
947
+ # Create a mapping from the old source numbers to the new, sequential numbers.
948
+ # For example, if the citation order in the text was [3, 1, 2], this map
949
+ # becomes {3: 1, 1: 2, 2: 3}. This allows for a quick lookup when rewriting
950
+ # the citations
951
+ renumbering_map = {
952
+ old_num: new_num + 1 for new_num, old_num in enumerate(cited_order)
953
+ }
954
+
955
+ # Rewrite the citations in the answer text based on the renumbering map
956
+ processed_answer = EnterpriseSearchPolicy._rewrite_answer_citations(
957
+ answer_part, renumbering_map
958
+ )
959
+
960
+ # Build the new list of sources
961
+ new_sources_list = EnterpriseSearchPolicy._build_final_sources_list(
962
+ cited_order,
963
+ renumbering_map,
964
+ valid_sources,
965
+ other_source_lines,
966
+ )
967
+
968
+ if len(new_sources_list) > 0:
969
+ processed_answer += "\nSources:\n" + "\n".join(new_sources_list)
901
970
 
902
- # Map old source references to the correct enumeration
903
- renumber_mapping = {num: idx + 1 for idx, num in enumerate(old_source_indices)}
904
-
905
- # remove whitespace from original source citations in answer
906
- for match in matches:
907
- answer = answer.replace(f"[{match}]", f"[{match.replace(' ', '')}]")
908
-
909
- new_answer = []
910
- for word in answer.split():
911
- matches = re.findall(pattern, word)
912
- if matches:
913
- for match in matches:
914
- if "," in match:
915
- old_indices = [
916
- int(num.strip()) for num in match.split(",") if num
917
- ]
918
- new_indices = [
919
- renumber_mapping[old_index]
920
- for old_index in old_indices
921
- if old_index in renumber_mapping
922
- ]
923
- if not new_indices:
924
- continue
925
-
926
- word = word.replace(
927
- match, f"{', '.join(map(str, new_indices))}"
928
- )
929
- else:
930
- old_index = int(match.strip("[].,:;?!"))
931
- new_index = renumber_mapping.get(old_index)
932
- if not new_index:
933
- continue
934
-
935
- word = word.replace(str(old_index), str(new_index))
936
- new_answer.append(word)
937
-
938
- # join the words
939
- joined_answer = " ".join(new_answer)
940
- joined_answer += "\nSources:\n"
941
-
942
- new_sources: List[str] = []
943
-
944
- for line in citations.split("\n"):
945
- pattern = r"(?<=\[)\d+"
946
- match = re.search(pattern, line)
971
+ return processed_answer
972
+
973
+ @staticmethod
974
+ def _parse_sources_block(sources_part: str) -> Tuple[Dict[int, str], List[str]]:
975
+ """Parses the sources block from the LLM response.
976
+ Returns a tuple containing:
977
+ - A dictionary of valid sources matching the "[1] ..." format,
978
+ where the key is the source number
979
+ - A list of other source lines that do not match the specified format
980
+ """
981
+ valid_sources: Dict[int, str] = {}
982
+ other_source_lines: List[str] = []
983
+ source_line_pattern = re.compile(r"^\s*\[(\d+)\](.*)")
984
+
985
+ source_lines = sources_part.strip().split("\n")
986
+
987
+ for line in source_lines:
988
+ line = line.strip()
989
+ if not line:
990
+ continue
991
+
992
+ match = source_line_pattern.match(line)
947
993
  if match:
948
- old_index = int(match.group(0))
949
- new_index = renumber_mapping[old_index]
950
- # replace only the first occurrence of the old index
951
- line = line.replace(f"[{old_index}]", f"[{new_index}]", 1)
994
+ num = int(match.group(1))
995
+ valid_sources[num] = line
996
+ else:
997
+ other_source_lines.append(line)
998
+
999
+ return valid_sources, other_source_lines
1000
+
1001
+ @staticmethod
1002
+ def _get_cited_order(
1003
+ answer_part: str, available_sources: Dict[int, str]
1004
+ ) -> List[int]:
1005
+ """Find all unique, valid citations in the answer text in their order
1006
+ # of appearance
1007
+ """
1008
+ cited_order: List[int] = []
1009
+ seen_indices = set()
1010
+
1011
+ for match in _ENTERPRISE_SEARCH_CITATION_PATTERN.finditer(answer_part):
1012
+ content = match.group(1)
1013
+ indices_str = [s.strip() for s in content.split(",")]
1014
+ for index_str in indices_str:
1015
+ if index_str.isdigit():
1016
+ index = int(index_str)
1017
+ if index in available_sources and index not in seen_indices:
1018
+ cited_order.append(index)
1019
+ seen_indices.add(index)
1020
+
1021
+ return cited_order
1022
+
1023
+ @staticmethod
1024
+ def _rewrite_answer_citations(
1025
+ answer_part: str, renumber_map: Dict[int, int]
1026
+ ) -> str:
1027
+ """Rewrites the citations in the answer text based on the renumbering map."""
1028
+
1029
+ def replacer(match: re.Match) -> str:
1030
+ content = match.group(1)
1031
+ old_indices_str = [s.strip() for s in content.split(",")]
1032
+ new_indices = [
1033
+ renumber_map[int(s)]
1034
+ for s in old_indices_str
1035
+ if s.isdigit() and int(s) in renumber_map
1036
+ ]
1037
+ if not new_indices:
1038
+ return ""
1039
+
1040
+ return f"[{', '.join(map(str, sorted(list(set(new_indices)))))}]"
1041
+
1042
+ processed_answer = _ENTERPRISE_SEARCH_CITATION_PATTERN.sub(
1043
+ replacer, answer_part
1044
+ )
1045
+
1046
+ # Clean up formatting after replacements
1047
+ processed_answer = re.sub(r"\s+([,.?])", r"\1", processed_answer)
1048
+ processed_answer = processed_answer.replace("[]", " ")
1049
+ processed_answer = re.sub(r"\s+", " ", processed_answer)
1050
+ processed_answer = processed_answer.strip()
1051
+
1052
+ return processed_answer
1053
+
1054
+ @staticmethod
1055
+ def _build_final_sources_list(
1056
+ cited_order: List[int],
1057
+ renumbering_map: Dict[int, int],
1058
+ valid_sources: Dict[int, str],
1059
+ other_source_lines: List[str],
1060
+ ) -> List[str]:
1061
+ """Builds the final list of sources based on the cited order and
1062
+ renumbering map.
1063
+ """
1064
+ new_sources_list: List[str] = []
1065
+
1066
+ # First, add the sorted, used sources
1067
+ for old_num in cited_order:
1068
+ new_num = renumbering_map[old_num]
1069
+ source_line = valid_sources[old_num]
1070
+ new_sources_list.append(
1071
+ source_line.replace(f"[{old_num}]", f"[{new_num}]", 1)
1072
+ )
952
1073
 
953
- # insert the line into the new_index position
954
- new_sources.insert(new_index - 1, line)
955
- elif line.strip():
956
- new_sources.append(line)
1074
+ # Then, add the unused but validly numbered sources
1075
+ used_source_nums = set(cited_order)
1076
+ # Sort by number to ensure a consistent order for uncited sources
1077
+ for num, line in sorted(valid_sources.items()):
1078
+ if num not in used_source_nums:
1079
+ new_sources_list.append(line)
957
1080
 
958
- joined_sources = "\n".join(new_sources)
1081
+ # Finally, add any other source lines
1082
+ new_sources_list.extend(other_source_lines)
959
1083
 
960
- return joined_answer + joined_sources
1084
+ return new_sources_list
961
1085
 
962
1086
  @classmethod
963
1087
  def _perform_health_checks(
@@ -721,9 +721,7 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
721
721
  final_response_examples.append(resp)
722
722
 
723
723
  llm_response = await self.generate_answer(
724
- final_response_examples,
725
- conversation_samples,
726
- history,
724
+ final_response_examples, conversation_samples, history
727
725
  )
728
726
  if not llm_response:
729
727
  structlogger.debug("intentless_policy.prediction.skip_llm_fail")