rasa-pro 3.12.18.dev1__py3-none-any.whl → 3.12.25__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of rasa-pro might be problematic.
- rasa/__init__.py +0 -6
- rasa/core/actions/action.py +2 -5
- rasa/core/actions/action_repeat_bot_messages.py +18 -22
- rasa/core/channels/voice_stream/asr/asr_engine.py +5 -1
- rasa/core/channels/voice_stream/asr/azure.py +9 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +5 -0
- rasa/core/channels/voice_stream/audiocodes.py +9 -4
- rasa/core/channels/voice_stream/twilio_media_streams.py +7 -0
- rasa/core/channels/voice_stream/voice_channel.py +47 -9
- rasa/core/policies/enterprise_search_policy.py +196 -72
- rasa/core/policies/intentless_policy.py +1 -3
- rasa/core/processor.py +50 -5
- rasa/core/utils.py +11 -2
- rasa/dialogue_understanding/coexistence/llm_based_router.py +1 -0
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +3 -1
- rasa/dialogue_understanding/commands/correct_slots_command.py +0 -10
- rasa/dialogue_understanding/commands/set_slot_command.py +6 -0
- rasa/dialogue_understanding/commands/utils.py +26 -2
- rasa/dialogue_understanding/generator/command_generator.py +15 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +4 -44
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +1 -14
- rasa/dialogue_understanding/processor/command_processor.py +23 -16
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +17 -4
- rasa/dialogue_understanding/stack/utils.py +3 -1
- rasa/dialogue_understanding/utils.py +68 -12
- rasa/dialogue_understanding_test/du_test_schema.yml +3 -3
- rasa/e2e_test/e2e_test_coverage_report.py +1 -1
- rasa/e2e_test/e2e_test_schema.yml +3 -3
- rasa/hooks.py +0 -55
- rasa/llm_fine_tuning/annotation_module.py +43 -11
- rasa/llm_fine_tuning/utils.py +2 -4
- rasa/shared/constants.py +0 -5
- rasa/shared/core/constants.py +1 -0
- rasa/shared/core/flows/constants.py +2 -0
- rasa/shared/core/flows/flow.py +129 -13
- rasa/shared/core/flows/flows_list.py +18 -1
- rasa/shared/core/flows/steps/link.py +7 -2
- rasa/shared/providers/constants.py +0 -9
- rasa/shared/providers/llm/_base_litellm_client.py +4 -14
- rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
- rasa/shared/providers/llm/llm_client.py +15 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
- rasa/tracing/instrumentation/attribute_extractors.py +2 -2
- rasa/version.py +1 -1
- {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/METADATA +3 -4
- {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/RECORD +52 -53
- rasa/monkey_patches.py +0 -91
- {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.18.dev1.dist-info → rasa_pro-3.12.25.dist-info}/entry_points.txt +0 -0
rasa/__init__.py
CHANGED
@@ -5,11 +5,5 @@ from rasa import version
 # define the version before the other imports since these need it
 __version__ = version.__version__
 
-from litellm.integrations.langfuse.langfuse import LangFuseLogger
-
-from rasa.monkey_patches import litellm_langfuse_logger_init_fixed
-
-# Monkey-patch the init method as early as possible before the class is used
-LangFuseLogger.__init__ = litellm_langfuse_logger_init_fixed  # type: ignore
 
 logging.getLogger(__name__).addHandler(logging.NullHandler())
rasa/core/actions/action.py
CHANGED
@@ -898,7 +898,7 @@ class RemoteAction(Action):
             draft["buttons"].extend(buttons)
 
             # Avoid overwriting `draft` values with empty values
-            response = {k: v for k, v in response.items() if v
+            response = {k: v for k, v in response.items() if v}
             draft.update(response)
             bot_messages.append(create_bot_utterance(draft))
 
@@ -1137,15 +1137,12 @@ class ActionSendText(Action):
         tracker: "DialogueStateTracker",
         domain: "Domain",
         metadata: Optional[Dict[Text, Any]] = None,
-        create_bot_uttered_event: bool = True,
     ) -> List[Event]:
         """Runs action. Please see parent class for the full docstring."""
         fallback = {"text": ""}
         metadata_copy = copy.deepcopy(metadata) if metadata else {}
         message = metadata_copy.get("message", fallback)
-
-            return [create_bot_utterance(message)]
-        return []
+        return [create_bot_utterance(message)]
 
 
 class ActionExtractSlots(Action):
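Note on the RemoteAction change above: the corrected dict comprehension drops falsy values from the custom-action response before it is merged into the draft message, so existing draft values are not overwritten by empty ones. A minimal, self-contained illustration (the payload below is a hypothetical example, not taken from Rasa):

# Hypothetical custom-action response; empty values are filtered out
response = {"text": "Here you go", "buttons": [], "image": None}
draft = {"text": "placeholder", "buttons": [{"title": "Yes"}]}

filtered = {k: v for k, v in response.items() if v}  # {'text': 'Here you go'}
draft.update(filtered)  # the draft keeps its buttons instead of being overwritten with []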
rasa/core/actions/action_repeat_bot_messages.py
CHANGED
@@ -25,7 +25,7 @@ class ActionRepeatBotMessages(Action):
         """Return the name of the action."""
         return ACTION_REPEAT_BOT_MESSAGES
 
-    def _get_last_bot_events(self, tracker: DialogueStateTracker) -> List[
+    def _get_last_bot_events(self, tracker: DialogueStateTracker) -> List[BotUttered]:
         """Get the last consecutive bot events before the most recent user message.
 
         This function scans the dialogue history in reverse to find the last sequence of
@@ -48,33 +48,21 @@
         The elif condition doesn't break when it sees User3 event.
         But it does at User2 event.
         """
-        # Skip action if we are in a collect information step whose
-        # default behavior is to repeat anyways
-        top_frame = tracker.stack.top(
-            lambda frame: isinstance(frame, RepeatBotMessagesPatternFlowStackFrame)
-            or isinstance(frame, UserSilencePatternFlowStackFrame)
-        )
-        if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
-            return []
         # filter user and bot events
-
+        user_and_bot_events = [
             e for e in tracker.events if isinstance(e, (BotUttered, UserUttered))
         ]
-
+        last_bot_events: List[BotUttered] = []
 
         # find the last BotUttered events
-        for e in reversed(
-
-
-            bot_events.insert(0, e)
-
-            # stop if a UserUttered event is found
-            # only if we have collected some bot events already
-            # this condition skips the first N UserUttered events
-            elif bot_events:
+        for e in reversed(user_and_bot_events):
+            # stop when seeing a user event after having seen bot events already
+            if isinstance(e, UserUttered) and len(last_bot_events) > 0:
                 break
+            elif isinstance(e, BotUttered):
+                last_bot_events.append(e)
 
-        return
+        return list(reversed(last_bot_events))
 
     async def run(
         self,
@@ -85,5 +73,13 @@
         metadata: Optional[Dict[str, Any]] = None,
     ) -> List[Event]:
         """Send the last bot messages to the channel again"""
-
+        top_frame = tracker.stack.top(
+            lambda frame: isinstance(frame, RepeatBotMessagesPatternFlowStackFrame)
+            or isinstance(frame, UserSilencePatternFlowStackFrame)
+        )
+
+        bot_events: List[Event] = list(self._get_last_bot_events(tracker))
+        # drop the last bot event in a collect step as that part will be repeated anyway
+        if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
+            bot_events = bot_events[:-1]
         return bot_events
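For context on the reworked _get_last_bot_events above: the method now collects consecutive bot messages by walking the filtered event list in reverse and stops at the first user message once at least one bot message has been gathered. A standalone sketch of that scan, using stand-in event classes instead of Rasa's BotUttered/UserUttered:

from dataclasses import dataclass
from typing import List, Union

@dataclass
class Bot:
    text: str

@dataclass
class User:
    text: str

def last_bot_messages(events: List[Union[Bot, User]]) -> List[Bot]:
    collected: List[Bot] = []
    for e in reversed(events):
        # stop at the first user message once bot messages have been collected
        if isinstance(e, User) and collected:
            break
        elif isinstance(e, Bot):
            collected.append(e)
    return list(reversed(collected))

events = [User("hi"), Bot("hello"), Bot("how can I help?"), User("what did you say?")]
print([b.text for b in last_bot_messages(events)])  # ['hello', 'how can I help?']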
rasa/core/channels/voice_stream/asr/asr_engine.py
CHANGED
@@ -26,7 +26,7 @@ logger = structlog.get_logger(__name__)
 
 @dataclass
 class ASREngineConfig(MergeableConfig):
-
+    keep_alive_interval: int = 5  # seconds
 
 
 class ASREngine(Generic[T]):
@@ -93,3 +93,7 @@ class ASREngine(Generic[T]):
     def get_default_config() -> T:
         """Get the default config for this component."""
         raise NotImplementedError
+
+    async def send_keep_alive(self) -> None:
+        """Send a keep-alive message to the ASR system if supported."""
+        pass
rasa/core/channels/voice_stream/asr/azure.py
CHANGED
@@ -3,6 +3,8 @@ import os
 from dataclasses import dataclass
 from typing import Any, AsyncIterator, Dict, Optional
 
+import structlog
+
 from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineConfig
 from rasa.core.channels.voice_stream.asr.asr_event import (
     ASREvent,
@@ -13,6 +15,8 @@ from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
 from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
 from rasa.shared.exceptions import ConnectionException
 
+logger = structlog.get_logger(__name__)
+
 
 @dataclass
 class AzureASRConfig(ASREngineConfig):
@@ -61,6 +65,11 @@ class AzureASR(ASREngine[AzureASRConfig]):
             and self.config.speech_endpoint is None
         ):
             self.config.speech_region = "eastus"
+            logger.warning(
+                "voice_channel.asr.azure.no_region",
+                message="No speech region configured, using 'eastus' as default",
+                region="eastus",
+            )
         speech_config = speechsdk.SpeechConfig(
             subscription=os.environ[AZURE_SPEECH_API_KEY_ENV_VAR],
             region=self.config.speech_region,
rasa/core/channels/voice_stream/asr/deepgram.py
CHANGED
@@ -145,3 +145,8 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
     def concatenate_transcripts(t1: str, t2: str) -> str:
         """Concatenate two transcripts making sure there is a space between them."""
         return (t1.strip() + " " + t2.strip()).strip()
+
+    async def send_keep_alive(self) -> None:
+        """Send a keep-alive message to the Deepgram websocket connection."""
+        if self.asr_socket is not None:
+            await self.asr_socket.send(json.dumps({"type": "KeepAlive"}))
rasa/core/channels/voice_stream/audiocodes.py
CHANGED
@@ -81,6 +81,12 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
         logger.debug("Sending start marker", stream_id=self._get_stream_id())
         await self.voice_websocket.send(media_message)
 
+        # This should be set when the bot actually starts speaking
+        # however, Audiocodes does not have an event to indicate that.
+        # This is an approximation, as the bot will be sent the audio chunks next
+        # which are played to the user immediately.
+        call_state.is_bot_speaking = True  # type: ignore[attr-defined]
+
     async def send_intermediate_marker(self, recipient_id: str) -> None:
         """Audiocodes doesn't need intermediate markers, so do nothing."""
         pass
@@ -173,21 +179,20 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
         if data["type"] == "activities":
             activities = data["activities"]
             for activity in activities:
-                logger.debug("audiocodes_stream.activity", data=activity)
                 if activity["name"] == "start":
-                    #
+                    # handled in collect_call_parameters
                     pass
                 elif activity["name"] == "dtmf":
-
+                    logger.info("audiocodes_stream.dtmf_ignored", data=activity)
                     pass
                 elif activity["name"] == "playFinished":
                     logger.debug("audiocodes_stream.playFinished", data=activity)
+                    call_state.is_bot_speaking = False  # type: ignore[attr-defined]
                     if call_state.should_hangup:
                         logger.info("audiocodes_stream.hangup")
                         self._send_hangup(ws, data)
                     # the conversation should continue until
                     # we receive a end message from audiocodes
-                    pass
                 else:
                     logger.warning("audiocodes_stream.unknown_activity", data=activity)
         elif data["type"] == "userStream.start":
rasa/core/channels/voice_stream/twilio_media_streams.py
CHANGED
@@ -135,6 +135,13 @@ class TwilioMediaStreamsInputChannel(VoiceInputChannel):
     def name(cls) -> str:
         return "twilio_media_streams"
 
+    def get_sender_id(self, call_parameters: CallParameters) -> str:
+        """Get the sender ID for the channel.
+
+        Twilio Media Streams uses the Stream ID as Sender ID because
+        it is required in OutputChannel.send_text_message to send messages."""
+        return call_parameters.stream_id  # type: ignore[return-value]
+
     def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
         return RasaAudioBytes(base64.b64decode(input_bytes))
 
rasa/core/channels/voice_stream/voice_channel.py
CHANGED
@@ -288,6 +288,17 @@ class VoiceInputChannel(InputChannel):
         self.monitor_silence = monitor_silence
         self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
 
+        logger.info(
+            "voice_channel.initialized",
+            server_url=self.server_url,
+            asr_config=self.asr_config,
+            tts_config=self.tts_config,
+        )
+
+    def get_sender_id(self, call_parameters: CallParameters) -> str:
+        """Get the sender ID for the channel."""
+        return call_parameters.call_id
+
     async def monitor_silence_timeout(self, asr_event_queue: asyncio.Queue) -> None:
         timeout = call_state.silence_timeout
         if not timeout:
@@ -334,9 +345,9 @@
     ) -> None:
         output_channel = self.create_output_channel(channel_websocket, tts_engine)
         message = UserMessage(
-            "/session_start",
-            output_channel,
-            call_parameters
+            text="/session_start",
+            output_channel=output_channel,
+            sender_id=self.get_sender_id(call_parameters),
             input_channel=self.name(),
             metadata=asdict(call_parameters),
         )
@@ -393,6 +404,9 @@
                 await asr_engine.send_audio_chunks(channel_action.audio_bytes)
             elif isinstance(channel_action, EndConversationAction):
                 # end stream event came from the other side
+                await self.handle_disconnect(
+                    channel_websocket, on_new_message, tts_engine, call_parameters
+                )
                 break
 
         async def receive_asr_events() -> None:
@@ -410,10 +424,17 @@
                 call_parameters,
             )
 
+        async def asr_keep_alive_task() -> None:
+            interval = getattr(asr_engine.config, "keep_alive_interval", 5)
+            while True:
+                await asyncio.sleep(interval)
+                await asr_engine.send_keep_alive()
+
         tasks = [
             asyncio.create_task(consume_audio_bytes()),
             asyncio.create_task(receive_asr_events()),
             asyncio.create_task(handle_asr_events()),
+            asyncio.create_task(asr_keep_alive_task()),
         ]
         await asyncio.wait(
             tasks,
@@ -449,9 +470,9 @@
             call_state.is_user_speaking = False  # type: ignore[attr-defined]
             output_channel = self.create_output_channel(voice_websocket, tts_engine)
             message = UserMessage(
-                e.text,
-                output_channel,
-                call_parameters
+                text=e.text,
+                output_channel=output_channel,
+                sender_id=self.get_sender_id(call_parameters),
                 input_channel=self.name(),
                 metadata=asdict(call_parameters),
             )
@@ -462,10 +483,27 @@
         elif isinstance(e, UserSilence):
             output_channel = self.create_output_channel(voice_websocket, tts_engine)
             message = UserMessage(
-                "/silence_timeout",
-                output_channel,
-                call_parameters
+                text="/silence_timeout",
+                output_channel=output_channel,
+                sender_id=self.get_sender_id(call_parameters),
                 input_channel=self.name(),
                 metadata=asdict(call_parameters),
             )
             await on_new_message(message)
+
+    async def handle_disconnect(
+        self,
+        channel_websocket: Websocket,
+        on_new_message: Callable[[UserMessage], Awaitable[Any]],
+        tts_engine: TTSEngine,
+        call_parameters: CallParameters,
+    ) -> None:
+        """Handle disconnection from the channel."""
+        output_channel = self.create_output_channel(channel_websocket, tts_engine)
+        message = UserMessage(
+            text="/session_end",
+            output_channel=output_channel,
+            sender_id=self.get_sender_id(call_parameters),
+            input_channel=self.name(),
+        )
+        await on_new_message(message)
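The keep-alive additions above follow a common asyncio pattern: a background task pings the ASR connection at a fixed interval while the audio and ASR-event tasks run. A minimal, self-contained sketch of that pattern (the engine class and interval below are illustrative stand-ins, not Rasa's API):

import asyncio

class DummyASREngine:
    keep_alive_interval = 1  # seconds; illustrative value

    async def send_keep_alive(self) -> None:
        print("keep-alive sent")

async def keep_alive_loop(engine: DummyASREngine) -> None:
    interval = getattr(engine, "keep_alive_interval", 5)
    while True:
        await asyncio.sleep(interval)
        await engine.send_keep_alive()

async def main() -> None:
    keep_alive = asyncio.create_task(keep_alive_loop(DummyASREngine()))
    await asyncio.sleep(3.5)  # stand-in for the real audio/ASR tasks
    keep_alive.cancel()  # pending tasks are cancelled when the call ends

asyncio.run(main())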
rasa/core/policies/enterprise_search_policy.py
CHANGED
@@ -1,7 +1,9 @@
+import glob
 import importlib.resources
 import json
+import os.path
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
 
 import dotenv
 import structlog
@@ -162,6 +164,8 @@ DEFAULT_ENTERPRISE_SEARCH_PROMPT_WITH_CITATION_TEMPLATE = importlib.resources.re
     "rasa.core.policies", "enterprise_search_prompt_with_citation_template.jinja2"
 )
 
+_ENTERPRISE_SEARCH_CITATION_PATTERN = re.compile(r"\[([^\]]+)\]")
+
 
 class VectorStoreConnectionError(RasaException):
     """Exception raised for errors in connecting to the vector store."""
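The module-level _ENTERPRISE_SEARCH_CITATION_PATTERN added above matches anything inside square brackets, which is what lets the citation post-processing handle grouped citations such as [1, 3] as well as single ones. A quick illustration:

import re

_ENTERPRISE_SEARCH_CITATION_PATTERN = re.compile(r"\[([^\]]+)\]")
print(_ENTERPRISE_SEARCH_CITATION_PATTERN.findall("See [1] and [2, 5]."))  # ['1', '2, 5']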
@@ -378,9 +382,11 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
         if store_type == DEFAULT_VECTOR_STORE_TYPE:
             logger.info("enterprise_search_policy.train.faiss")
+            docs_folder = self.vector_store_config.get(SOURCE_PROPERTY)
+            self._validate_documents_folder(docs_folder)
             with self._model_storage.write_to(self._resource) as path:
                 self.vector_store = FAISS_Store(
-                    docs_folder=
+                    docs_folder=docs_folder,
                     embeddings=embeddings,
                     index_path=path,
                     create_index=True,
@@ -760,6 +766,33 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
             result[domain.index_for_action(action_name)] = score  # type: ignore[assignment]
         return result
 
+    @classmethod
+    def _validate_documents_folder(cls, docs_folder: str) -> None:
+        if not os.path.exists(docs_folder) or not os.path.isdir(docs_folder):
+            error_message = (
+                f"Document source directory does not exist or is not a "
+                f"directory: '{docs_folder}'. "
+                "Please specify a valid path to the documents source directory in the "
+                "vector_store configuration."
+            )
+            logger.error(
+                "enterprise_search_policy.train.faiss.invalid_source_directory",
+                message=error_message,
+            )
+            print_error_and_exit(error_message)
+
+        docs = glob.glob(os.path.join(docs_folder, "*.txt"), recursive=True)
+        if not docs or len(docs) < 1:
+            error_message = (
+                f"Document source directory is empty: '{docs_folder}'. "
+                "Please add documents to this directory or specify a different one."
+            )
+            logger.error(
+                "enterprise_search_policy.train.faiss.source_directory_empty",
+                message=error_message,
+            )
+            print_error_and_exit(error_message)
+
     @classmethod
     def load(
         cls,
@@ -833,7 +866,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
             return None
 
         source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
-        if not source:
+        if not source or not os.path.exists(source) or not os.path.isdir(source):
             return None
 
         docs = FAISS_Store.load_documents(source)
@@ -870,10 +903,18 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
     @staticmethod
     def post_process_citations(llm_answer: str) -> str:
-        """Post-
-
-
-
+        """Post-processes the LLM answer to correctly number and sort citations and
+        sources.
+
+        - Handles both single `[1]` and grouped `[1, 3]` citations.
+        - Rewrites the numbers in square brackets in the answer text to start from 1
+          and be sorted within each group.
+        - Reorders the sources according to the order of their first appearance
+          in the text.
+        - Removes citations from the text that point to sources missing from
+          the source list.
+        - Keeps sources that are not cited in the text, placing them at the end
+          of the list.
 
         Args:
             llm_answer: The LLM answer.
@@ -887,77 +928,160 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
         # Split llm_answer into answer and citations
         try:
-
+            answer_part, sources_part = llm_answer.rsplit("Sources:", 1)
         except ValueError:
-            # if there is no "Sources:"
-            return llm_answer
-
-        #
-
-
-
-
-
+            # if there is no "Sources:" separator, return the original llm_answer
+            return llm_answer.strip()
+
+        # Parse the sources block to extract valid sources and other lines
+        valid_sources, other_source_lines = EnterpriseSearchPolicy._parse_sources_block(
+            sources_part
+        )
+
+        # Find all unique, valid citations in the answer text in their order
+        # of appearance
+        cited_order = EnterpriseSearchPolicy._get_cited_order(
+            answer_part, valid_sources
+        )
+
+        # Create a mapping from the old source numbers to the new, sequential numbers.
+        # For example, if the citation order in the text was [3, 1, 2], this map
+        # becomes {3: 1, 1: 2, 2: 3}. This allows for a quick lookup when rewriting
+        # the citations
+        renumbering_map = {
+            old_num: new_num + 1 for new_num, old_num in enumerate(cited_order)
+        }
+
+        # Rewrite the citations in the answer text based on the renumbering map
+        processed_answer = EnterpriseSearchPolicy._rewrite_answer_citations(
+            answer_part, renumbering_map
+        )
+
+        # Build the new list of sources
+        new_sources_list = EnterpriseSearchPolicy._build_final_sources_list(
+            cited_order,
+            renumbering_map,
+            valid_sources,
+            other_source_lines,
+        )
+
+        if len(new_sources_list) > 0:
+            processed_answer += "\nSources:\n" + "\n".join(new_sources_list)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    continue
-
-                word = word.replace(
-                    match, f"{', '.join(map(str, new_indices))}"
-                )
-            else:
-                old_index = int(match.strip("[].,:;?!"))
-                new_index = renumber_mapping.get(old_index)
-                if not new_index:
-                    continue
-
-                word = word.replace(str(old_index), str(new_index))
-            new_answer.append(word)
-
-        # join the words
-        joined_answer = " ".join(new_answer)
-        joined_answer += "\nSources:\n"
-
-        new_sources: List[str] = []
-
-        for line in citations.split("\n"):
-            pattern = r"(?<=\[)\d+"
-            match = re.search(pattern, line)
+        return processed_answer
+
+    @staticmethod
+    def _parse_sources_block(sources_part: str) -> Tuple[Dict[int, str], List[str]]:
+        """Parses the sources block from the LLM response.
+        Returns a tuple containing:
+        - A dictionary of valid sources matching the "[1] ..." format,
+          where the key is the source number
+        - A list of other source lines that do not match the specified format
+        """
+        valid_sources: Dict[int, str] = {}
+        other_source_lines: List[str] = []
+        source_line_pattern = re.compile(r"^\s*\[(\d+)\](.*)")
+
+        source_lines = sources_part.strip().split("\n")
+
+        for line in source_lines:
+            line = line.strip()
+            if not line:
+                continue
+
+            match = source_line_pattern.match(line)
             if match:
-
-
-
-
+                num = int(match.group(1))
+                valid_sources[num] = line
+            else:
+                other_source_lines.append(line)
+
+        return valid_sources, other_source_lines
+
+    @staticmethod
+    def _get_cited_order(
+        answer_part: str, available_sources: Dict[int, str]
+    ) -> List[int]:
+        """Find all unique, valid citations in the answer text in their order
+        # of appearance
+        """
+        cited_order: List[int] = []
+        seen_indices = set()
+
+        for match in _ENTERPRISE_SEARCH_CITATION_PATTERN.finditer(answer_part):
+            content = match.group(1)
+            indices_str = [s.strip() for s in content.split(",")]
+            for index_str in indices_str:
+                if index_str.isdigit():
+                    index = int(index_str)
+                    if index in available_sources and index not in seen_indices:
+                        cited_order.append(index)
+                        seen_indices.add(index)
+
+        return cited_order
+
+    @staticmethod
+    def _rewrite_answer_citations(
+        answer_part: str, renumber_map: Dict[int, int]
+    ) -> str:
+        """Rewrites the citations in the answer text based on the renumbering map."""
+
+        def replacer(match: re.Match) -> str:
+            content = match.group(1)
+            old_indices_str = [s.strip() for s in content.split(",")]
+            new_indices = [
+                renumber_map[int(s)]
+                for s in old_indices_str
+                if s.isdigit() and int(s) in renumber_map
+            ]
+            if not new_indices:
+                return ""
+
+            return f"[{', '.join(map(str, sorted(list(set(new_indices)))))}]"
+
+        processed_answer = _ENTERPRISE_SEARCH_CITATION_PATTERN.sub(
+            replacer, answer_part
+        )
+
+        # Clean up formatting after replacements
+        processed_answer = re.sub(r"\s+([,.?])", r"\1", processed_answer)
+        processed_answer = processed_answer.replace("[]", " ")
+        processed_answer = re.sub(r"\s+", " ", processed_answer)
+        processed_answer = processed_answer.strip()
+
+        return processed_answer
+
+    @staticmethod
+    def _build_final_sources_list(
+        cited_order: List[int],
+        renumbering_map: Dict[int, int],
+        valid_sources: Dict[int, str],
+        other_source_lines: List[str],
+    ) -> List[str]:
+        """Builds the final list of sources based on the cited order and
+        renumbering map.
+        """
+        new_sources_list: List[str] = []
+
+        # First, add the sorted, used sources
+        for old_num in cited_order:
+            new_num = renumbering_map[old_num]
+            source_line = valid_sources[old_num]
+            new_sources_list.append(
+                source_line.replace(f"[{old_num}]", f"[{new_num}]", 1)
+            )
 
-
-
-
-
+        # Then, add the unused but validly numbered sources
+        used_source_nums = set(cited_order)
+        # Sort by number to ensure a consistent order for uncited sources
+        for num, line in sorted(valid_sources.items()):
+            if num not in used_source_nums:
+                new_sources_list.append(line)
 
-
+        # Finally, add any other source lines
+        new_sources_list.extend(other_source_lines)
 
-        return
+        return new_sources_list
 
     @classmethod
     def _perform_health_checks(
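To make the citation rework above concrete, here is an illustrative before/after based on the docstring and helpers in this hunk. The answer text and source files are made up, and the exact whitespace of the real output may differ since the answer text is collapsed to a single line:

from rasa.core.policies.enterprise_search_policy import EnterpriseSearchPolicy

llm_answer = (
    "Paris is the capital [3] and has about 2.1 million residents [1, 3].\n"
    "Sources:\n"
    "[1] population.txt\n"
    "[2] geography.txt\n"
    "[3] capitals.txt"
)
processed = EnterpriseSearchPolicy.post_process_citations(llm_answer)
# Expected shape of the result: citations renumbered by first appearance
# ([3] -> [1], [1, 3] -> [1, 2]), cited sources moved to the front and
# renumbered, uncited sources kept at the end with their original label:
#   Paris is the capital [1] and has about 2.1 million residents [1, 2].
#   Sources:
#   [1] capitals.txt
#   [2] population.txt
#   [2] geography.txt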
rasa/core/policies/intentless_policy.py
CHANGED
@@ -721,9 +721,7 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
             final_response_examples.append(resp)
 
         llm_response = await self.generate_answer(
-            final_response_examples,
-            conversation_samples,
-            history,
+            final_response_examples, conversation_samples, history
         )
         if not llm_response:
             structlogger.debug("intentless_policy.prediction.skip_llm_fail")