rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +2 -0
- rasa/cli/studio/studio.py +18 -8
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +17 -0
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +13 -2
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/voice_channel.py +13 -1
- rasa/core/nlg/contextual_response_rephraser.py +18 -10
- rasa/core/policies/enterprise_search_policy.py +27 -67
- rasa/core/policies/intentless_policy.py +25 -67
- rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +24 -21
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/engine/validation.py +57 -41
- rasa/model_service.py +3 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
- rasa/shared/providers/llm/_base_litellm_client.py +6 -1
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/llm.py +5 -2
- rasa/shared/utils/yaml.py +102 -62
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +21 -10
- rasa/telemetry.py +12 -0
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +20 -0
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +8 -16
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +3 -2
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +51 -47
- rasa/shared/utils/health_check.py +0 -533
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
rasa/cli/inspect.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import List, Text
|
|
|
5
5
|
|
|
6
6
|
from sanic import Sanic
|
|
7
7
|
|
|
8
|
+
from rasa import telemetry
|
|
8
9
|
from rasa.cli import SubParsersAction
|
|
9
10
|
from rasa.cli.arguments import shell as arguments
|
|
10
11
|
from rasa.core import constants
|
|
@@ -70,4 +71,5 @@ def inspect(args: argparse.Namespace) -> None:
|
|
|
70
71
|
args.credentials = None
|
|
71
72
|
args.server_listeners = [(after_start_hook_open_inspector, "after_server_start")]
|
|
72
73
|
|
|
74
|
+
telemetry.track_inspect_started(args.connector)
|
|
73
75
|
rasa.cli.run.run(args)
|
rasa/cli/studio/studio.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from typing import List, Optional
|
|
2
|
+
from typing import List, Optional, Tuple
|
|
3
3
|
from urllib.parse import ParseResult, urlparse
|
|
4
4
|
|
|
5
5
|
import questionary
|
|
@@ -149,7 +149,7 @@ def _configure_studio_url() -> Optional[str]:
|
|
|
149
149
|
return studio_url
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
def _get_advanced_config(studio_url: str) -> tuple:
|
|
152
|
+
def _get_advanced_config(studio_url: str) -> Tuple:
|
|
153
153
|
"""Get the advanced configuration values for Rasa Studio."""
|
|
154
154
|
keycloak_url = questionary.text(
|
|
155
155
|
"Please provide your Rasa Studio Keycloak URL",
|
|
@@ -167,7 +167,7 @@ def _get_advanced_config(studio_url: str) -> tuple:
|
|
|
167
167
|
return keycloak_url, realm_name, client_id
|
|
168
168
|
|
|
169
169
|
|
|
170
|
-
def _get_default_config(studio_url: str) -> tuple:
|
|
170
|
+
def _get_default_config(studio_url: str) -> Tuple:
|
|
171
171
|
"""Get the default configuration values for Rasa Studio."""
|
|
172
172
|
keycloak_url = studio_url + "auth/"
|
|
173
173
|
realm_name = DEFAULT_REALM_NAME
|
|
@@ -178,6 +178,7 @@ def _get_default_config(studio_url: str) -> tuple:
|
|
|
178
178
|
f"Keycloak URL: {keycloak_url}, "
|
|
179
179
|
f"Realm Name: '{realm_name}', "
|
|
180
180
|
f"Client ID: '{client_id}'. "
|
|
181
|
+
f"SSL verification is enabled."
|
|
181
182
|
f"You can use '--advanced' to configure these settings."
|
|
182
183
|
)
|
|
183
184
|
|
|
@@ -185,7 +186,11 @@ def _get_default_config(studio_url: str) -> tuple:
|
|
|
185
186
|
|
|
186
187
|
|
|
187
188
|
def _create_studio_config(
|
|
188
|
-
studio_url: str,
|
|
189
|
+
studio_url: str,
|
|
190
|
+
keycloak_url: str,
|
|
191
|
+
realm_name: str,
|
|
192
|
+
client_id: str,
|
|
193
|
+
disable_verify: bool = False,
|
|
189
194
|
) -> StudioConfig:
|
|
190
195
|
"""Create a StudioConfig object with the provided parameters."""
|
|
191
196
|
return StudioConfig(
|
|
@@ -193,6 +198,7 @@ def _create_studio_config(
|
|
|
193
198
|
studio_url=studio_url + "api/graphql/",
|
|
194
199
|
client_id=client_id,
|
|
195
200
|
realm_name=realm_name,
|
|
201
|
+
disable_verify=disable_verify,
|
|
196
202
|
)
|
|
197
203
|
|
|
198
204
|
|
|
@@ -227,19 +233,23 @@ def _configure_studio_config(args: argparse.Namespace) -> StudioConfig:
|
|
|
227
233
|
|
|
228
234
|
# create a configuration and auth object to try to reach the studio
|
|
229
235
|
studio_config = _create_studio_config(
|
|
230
|
-
studio_url,
|
|
236
|
+
studio_url,
|
|
237
|
+
keycloak_url,
|
|
238
|
+
realm_name,
|
|
239
|
+
client_id,
|
|
240
|
+
disable_verify=args.disable_verify,
|
|
231
241
|
)
|
|
232
242
|
|
|
233
|
-
if
|
|
243
|
+
if studio_config.disable_verify:
|
|
234
244
|
rasa.shared.utils.cli.print_info(
|
|
235
245
|
"Disabling SSL verification for the Rasa Studio authentication server."
|
|
236
246
|
)
|
|
237
|
-
studio_auth = StudioAuth(studio_config, verify=False)
|
|
238
247
|
else:
|
|
239
248
|
rasa.shared.utils.cli.print_info(
|
|
240
249
|
"Enabling SSL verification for the Rasa Studio authentication server."
|
|
241
250
|
)
|
|
242
|
-
|
|
251
|
+
|
|
252
|
+
studio_auth = StudioAuth(studio_config)
|
|
243
253
|
|
|
244
254
|
if _check_studio_auth(studio_auth):
|
|
245
255
|
return studio_config
|
|
@@ -3,6 +3,15 @@ from typing import Optional, Dict, Any, List
|
|
|
3
3
|
from rasa.core.actions.action import Action
|
|
4
4
|
from rasa.core.channels import OutputChannel
|
|
5
5
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
6
|
+
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
7
|
+
CollectInformationPatternFlowStackFrame,
|
|
8
|
+
)
|
|
9
|
+
from rasa.dialogue_understanding.patterns.repeat import (
|
|
10
|
+
RepeatBotMessagesPatternFlowStackFrame,
|
|
11
|
+
)
|
|
12
|
+
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
13
|
+
UserSilencePatternFlowStackFrame,
|
|
14
|
+
)
|
|
6
15
|
from rasa.shared.core.constants import ACTION_REPEAT_BOT_MESSAGES
|
|
7
16
|
from rasa.shared.core.domain import Domain
|
|
8
17
|
from rasa.shared.core.events import Event, BotUttered, UserUttered
|
|
@@ -39,6 +48,14 @@ class ActionRepeatBotMessages(Action):
|
|
|
39
48
|
The elif condition doesn't break when it sees User3 event.
|
|
40
49
|
But it does at User2 event.
|
|
41
50
|
"""
|
|
51
|
+
# Skip action if we are in a collect information step whose
|
|
52
|
+
# default behavior is to repeat anyways
|
|
53
|
+
top_frame = tracker.stack.top(
|
|
54
|
+
lambda frame: isinstance(frame, RepeatBotMessagesPatternFlowStackFrame)
|
|
55
|
+
or isinstance(frame, UserSilencePatternFlowStackFrame)
|
|
56
|
+
)
|
|
57
|
+
if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
|
|
58
|
+
return []
|
|
42
59
|
# filter user and bot events
|
|
43
60
|
filtered = [
|
|
44
61
|
e for e in tracker.events if isinstance(e, (BotUttered, UserUttered))
|
rasa/core/channels/channel.py
CHANGED
|
@@ -313,6 +313,23 @@ class OutputChannel:
|
|
|
313
313
|
button_msg = cli_utils.button_to_string(button, idx)
|
|
314
314
|
await self.send_text_message(recipient_id, button_msg)
|
|
315
315
|
|
|
316
|
+
async def send_text_with_buttons_concise(
|
|
317
|
+
self,
|
|
318
|
+
recipient_id: str,
|
|
319
|
+
text: str,
|
|
320
|
+
buttons: List[Dict[str, Any]],
|
|
321
|
+
**kwargs: Any,
|
|
322
|
+
) -> None:
|
|
323
|
+
"""Sends buttons in a concise format, useful for voice channels."""
|
|
324
|
+
if text.strip()[-1] not in {".", "!", "?", ":"}:
|
|
325
|
+
text += "."
|
|
326
|
+
text += " "
|
|
327
|
+
for idx, button in enumerate(buttons):
|
|
328
|
+
text += button["title"]
|
|
329
|
+
if idx != len(buttons) - 1:
|
|
330
|
+
text += ", "
|
|
331
|
+
await self.send_text_message(recipient_id, text)
|
|
332
|
+
|
|
316
333
|
async def send_quick_replies(
|
|
317
334
|
self,
|
|
318
335
|
recipient_id: Text,
|
|
@@ -21,6 +21,7 @@ from sanic.exceptions import NotFound, SanicException, ServerError
|
|
|
21
21
|
from sanic.request import Request
|
|
22
22
|
from sanic.response import HTTPResponse
|
|
23
23
|
|
|
24
|
+
from rasa.utils.io import remove_emojis
|
|
24
25
|
|
|
25
26
|
structlogger = structlog.get_logger()
|
|
26
27
|
|
|
@@ -449,6 +450,7 @@ class AudiocodesOutput(OutputChannel):
|
|
|
449
450
|
self, recipient_id: Text, text: Text, **kwargs: Any
|
|
450
451
|
) -> None:
|
|
451
452
|
"""Send a text message."""
|
|
453
|
+
text = remove_emojis(text)
|
|
452
454
|
await self.add_message({"type": "message", "text": text})
|
|
453
455
|
|
|
454
456
|
async def send_image_url(
|
|
@@ -471,6 +473,16 @@ class AudiocodesOutput(OutputChannel):
|
|
|
471
473
|
"""Indicate that the conversation should be ended."""
|
|
472
474
|
await self.add_message({"type": "event", "name": "hangup"})
|
|
473
475
|
|
|
476
|
+
async def send_text_with_buttons(
|
|
477
|
+
self,
|
|
478
|
+
recipient_id: str,
|
|
479
|
+
text: str,
|
|
480
|
+
buttons: List[Dict[str, Any]],
|
|
481
|
+
**kwargs: Any,
|
|
482
|
+
) -> None:
|
|
483
|
+
"""Uses the concise button output format for voice channels."""
|
|
484
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
485
|
+
|
|
474
486
|
|
|
475
487
|
class WebsocketOutput(AudiocodesOutput):
|
|
476
488
|
def __init__(self, ws: Any, conversation_id: Text) -> None:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Awaitable, Callable, Dict, Optional, Text
|
|
1
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Text
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from rasa.core.channels.channel import InputChannel, OutputChannel, UserMessage
|
|
@@ -14,7 +14,7 @@ from sanic.request import Request
|
|
|
14
14
|
from sanic.response import HTTPResponse
|
|
15
15
|
|
|
16
16
|
from rasa.shared.utils.common import mark_as_beta_feature
|
|
17
|
-
|
|
17
|
+
from rasa.utils.io import remove_emojis
|
|
18
18
|
|
|
19
19
|
structlogger = structlog.get_logger()
|
|
20
20
|
|
|
@@ -87,6 +87,7 @@ class JambonzWebsocketOutput(OutputChannel):
|
|
|
87
87
|
self, recipient_id: Text, text: Text, **kwargs: Any
|
|
88
88
|
) -> None:
|
|
89
89
|
"""Send a text message."""
|
|
90
|
+
text = remove_emojis(text)
|
|
90
91
|
await self.add_message({"type": "message", "text": text})
|
|
91
92
|
|
|
92
93
|
async def send_image_url(
|
|
@@ -108,3 +109,13 @@ class JambonzWebsocketOutput(OutputChannel):
|
|
|
108
109
|
async def hangup(self, recipient_id: Text, **kwargs: Any) -> None:
|
|
109
110
|
"""Indicate that the conversation should be ended."""
|
|
110
111
|
await send_ws_hangup_message(DEFAULT_HANGUP_DELAY_SECONDS, self.ws)
|
|
112
|
+
|
|
113
|
+
async def send_text_with_buttons(
|
|
114
|
+
self,
|
|
115
|
+
recipient_id: str,
|
|
116
|
+
text: str,
|
|
117
|
+
buttons: List[Dict[str, Any]],
|
|
118
|
+
**kwargs: Any,
|
|
119
|
+
) -> None:
|
|
120
|
+
"""Uses the concise button output format for voice channels."""
|
|
121
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
@@ -358,38 +358,23 @@ class TwilioVoiceCollectingOutputChannel(CollectingOutputChannel):
|
|
|
358
358
|
"""Name of the output channel."""
|
|
359
359
|
return "twilio_voice"
|
|
360
360
|
|
|
361
|
-
@staticmethod
|
|
362
|
-
def _emoji_warning(text: Text) -> None:
|
|
363
|
-
"""Raises a warning if text contains an emoji."""
|
|
364
|
-
emoji_regex = rasa.utils.io.get_emoji_regex()
|
|
365
|
-
if emoji_regex.findall(text):
|
|
366
|
-
rasa.shared.utils.io.raise_warning(
|
|
367
|
-
"Text contains an emoji in a voice response. "
|
|
368
|
-
"Review responses to provide a voice-friendly alternative."
|
|
369
|
-
)
|
|
370
|
-
|
|
371
361
|
async def send_text_message(
|
|
372
362
|
self, recipient_id: Text, text: Text, **kwargs: Any
|
|
373
363
|
) -> None:
|
|
374
364
|
"""Sends the text message after removing emojis."""
|
|
375
|
-
|
|
365
|
+
text = rasa.utils.io.remove_emojis(text)
|
|
376
366
|
for message_part in text.strip().split("\n\n"):
|
|
377
367
|
await self._persist_message(self._message(recipient_id, text=message_part))
|
|
378
368
|
|
|
379
369
|
async def send_text_with_buttons(
|
|
380
370
|
self,
|
|
381
|
-
recipient_id:
|
|
382
|
-
text:
|
|
383
|
-
buttons: List[Dict[
|
|
371
|
+
recipient_id: str,
|
|
372
|
+
text: str,
|
|
373
|
+
buttons: List[Dict[str, Any]],
|
|
384
374
|
**kwargs: Any,
|
|
385
375
|
) -> None:
|
|
386
|
-
"""
|
|
387
|
-
self.
|
|
388
|
-
await self._persist_message(self._message(recipient_id, text=text))
|
|
389
|
-
|
|
390
|
-
for b in buttons:
|
|
391
|
-
self._emoji_warning(b["title"])
|
|
392
|
-
await self._persist_message(self._message(recipient_id, text=b["title"]))
|
|
376
|
+
"""Uses the concise button output format for voice channels."""
|
|
377
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
393
378
|
|
|
394
379
|
async def send_image_url(
|
|
395
380
|
self, recipient_id: Text, image: Text, **kwargs: Any
|
|
@@ -2,7 +2,7 @@ import asyncio
|
|
|
2
2
|
import structlog
|
|
3
3
|
import copy
|
|
4
4
|
from dataclasses import asdict, dataclass
|
|
5
|
-
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional, Tuple
|
|
5
|
+
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
from rasa.core.channels.voice_stream.util import generate_silence
|
|
8
8
|
from rasa.shared.core.constants import (
|
|
@@ -40,6 +40,7 @@ from rasa.core.channels.voice_stream.tts.azure import AzureTTS
|
|
|
40
40
|
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
41
41
|
from rasa.core.channels.voice_stream.tts.cartesia import CartesiaTTS
|
|
42
42
|
from rasa.core.channels.voice_stream.tts.tts_cache import TTSCache
|
|
43
|
+
from rasa.utils.io import remove_emojis
|
|
43
44
|
|
|
44
45
|
logger = structlog.get_logger(__name__)
|
|
45
46
|
|
|
@@ -157,9 +158,20 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
157
158
|
self.tracker_state["slots"][SLOT_SILENCE_TIMEOUT]
|
|
158
159
|
)
|
|
159
160
|
|
|
161
|
+
async def send_text_with_buttons(
|
|
162
|
+
self,
|
|
163
|
+
recipient_id: str,
|
|
164
|
+
text: str,
|
|
165
|
+
buttons: List[Dict[str, Any]],
|
|
166
|
+
**kwargs: Any,
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Uses the concise button output format for voice channels."""
|
|
169
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
170
|
+
|
|
160
171
|
async def send_text_message(
|
|
161
172
|
self, recipient_id: str, text: str, **kwargs: Any
|
|
162
173
|
) -> None:
|
|
174
|
+
text = remove_emojis(text)
|
|
163
175
|
self.update_silence_timeout()
|
|
164
176
|
cached_audio_bytes = self.tts_cache.get(text)
|
|
165
177
|
collected_audio_bytes = RasaAudioBytes(b"")
|
|
@@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Text
|
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from jinja2 import Template
|
|
5
|
-
|
|
6
5
|
from rasa import telemetry
|
|
7
6
|
from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
|
|
8
7
|
from rasa.core.nlg.summarize import summarize_conversation
|
|
@@ -19,6 +18,7 @@ from rasa.shared.constants import (
|
|
|
19
18
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
20
19
|
from rasa.shared.core.events import BotUttered, UserUttered
|
|
21
20
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
21
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
22
22
|
from rasa.shared.utils.llm import (
|
|
23
23
|
DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
24
24
|
DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
@@ -28,7 +28,6 @@ from rasa.shared.utils.llm import (
|
|
|
28
28
|
llm_factory,
|
|
29
29
|
resolve_model_client_config,
|
|
30
30
|
)
|
|
31
|
-
from rasa.shared.utils.health_check import perform_training_time_llm_health_check
|
|
32
31
|
from rasa.shared.utils.llm import (
|
|
33
32
|
tracker_as_readable_transcript,
|
|
34
33
|
)
|
|
@@ -44,6 +43,8 @@ RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
|
|
|
44
43
|
RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
|
|
45
44
|
|
|
46
45
|
DEFAULT_REPHRASE_ALL = False
|
|
46
|
+
DEFAULT_SUMMARIZE_HISTORY = True
|
|
47
|
+
DEFAULT_MAX_HISTORICAL_TURNS = 5
|
|
47
48
|
|
|
48
49
|
DEFAULT_LLM_CONFIG = {
|
|
49
50
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
@@ -68,7 +69,9 @@ Suggested AI Response: {{suggested_response}}
|
|
|
68
69
|
Rephrased AI Response:"""
|
|
69
70
|
|
|
70
71
|
|
|
71
|
-
class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
72
|
+
class ContextualResponseRephraser(
|
|
73
|
+
LLMHealthCheckMixin, TemplatedNaturalLanguageGenerator
|
|
74
|
+
):
|
|
72
75
|
"""Generates responses based on modified templates.
|
|
73
76
|
|
|
74
77
|
The templates are filled with the entities and slots that are available in the
|
|
@@ -102,13 +105,19 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
102
105
|
self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
|
|
103
106
|
"trace_prompt_tokens", False
|
|
104
107
|
)
|
|
108
|
+
self.summarize_history = self.nlg_endpoint.kwargs.get(
|
|
109
|
+
"summarize_history", DEFAULT_SUMMARIZE_HISTORY
|
|
110
|
+
)
|
|
111
|
+
self.max_historical_turns = self.nlg_endpoint.kwargs.get(
|
|
112
|
+
"max_historical_turns", DEFAULT_MAX_HISTORICAL_TURNS
|
|
113
|
+
)
|
|
105
114
|
|
|
106
115
|
self.llm_config = resolve_model_client_config(
|
|
107
116
|
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
|
|
108
117
|
ContextualResponseRephraser.__name__,
|
|
109
118
|
)
|
|
110
119
|
|
|
111
|
-
|
|
120
|
+
self.perform_llm_health_check(
|
|
112
121
|
self.llm_config,
|
|
113
122
|
DEFAULT_LLM_CONFIG,
|
|
114
123
|
"contextual_response_rephraser.init",
|
|
@@ -213,18 +222,17 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
213
222
|
prompt_template_text = self._template_for_response_rephrasing(response)
|
|
214
223
|
|
|
215
224
|
# Retrieve inputs for the dynamic prompt
|
|
216
|
-
transcript = tracker_as_readable_transcript(tracker, max_turns=5)
|
|
217
225
|
latest_message = self._last_message_if_human(tracker)
|
|
218
226
|
current_input = f"{USER}: {latest_message}" if latest_message else ""
|
|
219
227
|
|
|
220
228
|
# Only summarise conversation history if flagged
|
|
221
|
-
|
|
222
|
-
RESPONSE_SUMMARISE_CONVERSATION_KEY, False
|
|
223
|
-
)
|
|
224
|
-
if summarize_conversation_flag:
|
|
229
|
+
if self.summarize_history:
|
|
225
230
|
history = await self._create_history(tracker)
|
|
226
231
|
else:
|
|
227
|
-
history
|
|
232
|
+
# make sure the transcript/history contains the last user utterance
|
|
233
|
+
max_turns = max(self.max_historical_turns, 1)
|
|
234
|
+
history = tracker_as_readable_transcript(tracker, max_turns=max_turns)
|
|
235
|
+
# the history already contains the current input
|
|
228
236
|
current_input = ""
|
|
229
237
|
|
|
230
238
|
prompt = Template(prompt_template_text).render(
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import json
|
|
3
3
|
import re
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
5
|
import dotenv
|
|
6
6
|
import structlog
|
|
7
7
|
from jinja2 import Template
|
|
@@ -25,8 +25,6 @@ from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
|
25
25
|
from rasa.core.utils import AvailableEndpoints
|
|
26
26
|
from rasa.dialogue_understanding.generator.constants import (
|
|
27
27
|
LLM_CONFIG_KEY,
|
|
28
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
29
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
30
28
|
)
|
|
31
29
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
32
30
|
CannotHandlePatternFlowStackFrame,
|
|
@@ -71,6 +69,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
71
69
|
)
|
|
72
70
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
73
71
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
72
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
73
|
+
EmbeddingsHealthCheckMixin,
|
|
74
|
+
)
|
|
75
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
74
76
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
75
77
|
from rasa.shared.utils.llm import (
|
|
76
78
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
@@ -82,12 +84,6 @@ from rasa.shared.utils.llm import (
|
|
|
82
84
|
tracker_as_readable_transcript,
|
|
83
85
|
resolve_model_client_config,
|
|
84
86
|
)
|
|
85
|
-
from rasa.shared.utils.health_check import (
|
|
86
|
-
perform_training_time_llm_health_check,
|
|
87
|
-
perform_training_time_embeddings_health_check,
|
|
88
|
-
perform_inference_time_llm_health_check,
|
|
89
|
-
perform_inference_time_embeddings_health_check,
|
|
90
|
-
)
|
|
91
87
|
from rasa.telemetry import (
|
|
92
88
|
track_enterprise_search_policy_predict,
|
|
93
89
|
track_enterprise_search_policy_train_completed,
|
|
@@ -161,7 +157,7 @@ class VectorStoreConfigurationError(RasaException):
|
|
|
161
157
|
@DefaultV1Recipe.register(
|
|
162
158
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
163
159
|
)
|
|
164
|
-
class EnterpriseSearchPolicy(Policy):
|
|
160
|
+
class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
165
161
|
"""Policy which uses a vector store and LLMs to respond to user messages.
|
|
166
162
|
|
|
167
163
|
The policy uses a vector store and LLMs to respond to user messages. The
|
|
@@ -300,6 +296,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
300
296
|
A policy must return its resource locator so that potential children nodes
|
|
301
297
|
can load the policy from the resource.
|
|
302
298
|
"""
|
|
299
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
300
|
+
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
301
|
+
|
|
303
302
|
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
304
303
|
|
|
305
304
|
# telemetry call to track training start
|
|
@@ -319,11 +318,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
319
318
|
f"required environment variables. Error: {e}"
|
|
320
319
|
)
|
|
321
320
|
|
|
322
|
-
(
|
|
323
|
-
self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
|
|
324
|
-
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
|
|
325
|
-
) = self._perform_training_time_health_checks()
|
|
326
|
-
|
|
327
321
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
328
322
|
logger.info("enterprise_search_policy.train.faiss")
|
|
329
323
|
with self._model_storage.write_to(self._resource) as path:
|
|
@@ -698,16 +692,16 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
698
692
|
**kwargs: Any,
|
|
699
693
|
) -> "EnterpriseSearchPolicy":
|
|
700
694
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
695
|
+
|
|
696
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
697
|
+
cls._perform_health_checks(config, "enterprise_search_policy.load")
|
|
698
|
+
|
|
701
699
|
prompt_template = None
|
|
702
|
-
persisted_config = None
|
|
703
700
|
try:
|
|
704
701
|
with model_storage.read_from(resource) as path:
|
|
705
702
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
706
703
|
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
707
704
|
)
|
|
708
|
-
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
709
|
-
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME
|
|
710
|
-
)
|
|
711
705
|
except (FileNotFoundError, FileIOException) as e:
|
|
712
706
|
logger.warning(
|
|
713
707
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
@@ -737,7 +731,7 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
737
731
|
embeddings=embeddings,
|
|
738
732
|
) # type: ignore
|
|
739
733
|
|
|
740
|
-
|
|
734
|
+
return cls(
|
|
741
735
|
config,
|
|
742
736
|
model_storage,
|
|
743
737
|
resource,
|
|
@@ -746,14 +740,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
746
740
|
prompt_template=prompt_template,
|
|
747
741
|
)
|
|
748
742
|
|
|
749
|
-
cls._perform_inference_time_health_checks(
|
|
750
|
-
persisted_config,
|
|
751
|
-
policy.config.get(LLM_CONFIG_KEY),
|
|
752
|
-
policy.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
753
|
-
)
|
|
754
|
-
|
|
755
|
-
return policy
|
|
756
|
-
|
|
757
743
|
@classmethod
|
|
758
744
|
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
|
|
759
745
|
"""This is required only for local knowledge base types.
|
|
@@ -894,52 +880,26 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
894
880
|
|
|
895
881
|
return joined_answer + joined_sources
|
|
896
882
|
|
|
897
|
-
def _perform_training_time_health_checks(
|
|
898
|
-
self,
|
|
899
|
-
) -> Tuple[Optional[str], Optional[str]]:
|
|
900
|
-
train_model_name = perform_training_time_llm_health_check(
|
|
901
|
-
self.config.get(LLM_CONFIG_KEY),
|
|
902
|
-
DEFAULT_LLM_CONFIG,
|
|
903
|
-
"enterprise_search_policy.train",
|
|
904
|
-
EnterpriseSearchPolicy.__name__,
|
|
905
|
-
)
|
|
906
|
-
train_embedding_name = perform_training_time_embeddings_health_check(
|
|
907
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
908
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
909
|
-
"enterprise_search_policy.train",
|
|
910
|
-
EnterpriseSearchPolicy.__name__,
|
|
911
|
-
)
|
|
912
|
-
return train_model_name, train_embedding_name
|
|
913
|
-
|
|
914
883
|
@classmethod
|
|
915
|
-
def _perform_inference_time_health_checks(
|
|
916
|
-
cls,
|
|
917
|
-
persisted_config: Optional[Dict[str, Any]],
|
|
918
|
-
resolved_llm_config: Optional[Dict[str, Any]],
|
|
919
|
-
resolved_embeddings_config: Optional[Dict[str, Any]],
|
|
884
|
+
def _perform_health_checks(
|
|
885
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
920
886
|
) -> None:
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
)
|
|
926
|
-
perform_inference_time_llm_health_check(
|
|
927
|
-
resolved_llm_config,
|
|
887
|
+
# Perform health check of the LLM client config
|
|
888
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
889
|
+
cls.perform_llm_health_check(
|
|
890
|
+
llm_config,
|
|
928
891
|
DEFAULT_LLM_CONFIG,
|
|
929
|
-
|
|
930
|
-
"enterprise_search_policy.load",
|
|
892
|
+
log_source_method,
|
|
931
893
|
EnterpriseSearchPolicy.__name__,
|
|
932
894
|
)
|
|
933
895
|
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
else None
|
|
896
|
+
# Perform health check of the embeddings client config
|
|
897
|
+
embeddings_config = resolve_model_client_config(
|
|
898
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
938
899
|
)
|
|
939
|
-
|
|
940
|
-
|
|
900
|
+
cls.perform_embeddings_health_check(
|
|
901
|
+
embeddings_config,
|
|
941
902
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
942
|
-
|
|
943
|
-
"enterprise_search_policy.load",
|
|
903
|
+
log_source_method,
|
|
944
904
|
EnterpriseSearchPolicy.__name__,
|
|
945
905
|
)
|