rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +2 -0
- rasa/cli/studio/studio.py +18 -8
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +17 -0
- rasa/core/channels/development_inspector.py +4 -1
- rasa/core/channels/voice_ready/audiocodes.py +15 -4
- rasa/core/channels/voice_ready/jambonz.py +13 -2
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +5 -7
- rasa/core/channels/voice_stream/asr/deepgram.py +13 -11
- rasa/core/channels/voice_stream/voice_channel.py +61 -19
- rasa/core/nlg/contextual_response_rephraser.py +20 -12
- rasa/core/policies/enterprise_search_policy.py +32 -72
- rasa/core/policies/intentless_policy.py +34 -72
- rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +26 -22
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/engine/validation.py +88 -60
- rasa/model_service.py +3 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +5 -5
- rasa/shared/core/constants.py +1 -1
- rasa/shared/core/domain.py +0 -26
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
- rasa/shared/providers/llm/_base_litellm_client.py +32 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
- rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
- rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
- rasa/shared/utils/common.py +1 -1
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/llm.py +5 -2
- rasa/shared/utils/yaml.py +102 -62
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +21 -10
- rasa/telemetry.py +15 -1
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +20 -0
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +8 -16
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/METADATA +3 -2
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/RECORD +65 -61
- rasa/shared/utils/health_check.py +0 -533
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/entry_points.txt +0 -0
rasa/cli/inspect.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import List, Text
|
|
|
5
5
|
|
|
6
6
|
from sanic import Sanic
|
|
7
7
|
|
|
8
|
+
from rasa import telemetry
|
|
8
9
|
from rasa.cli import SubParsersAction
|
|
9
10
|
from rasa.cli.arguments import shell as arguments
|
|
10
11
|
from rasa.core import constants
|
|
@@ -70,4 +71,5 @@ def inspect(args: argparse.Namespace) -> None:
|
|
|
70
71
|
args.credentials = None
|
|
71
72
|
args.server_listeners = [(after_start_hook_open_inspector, "after_server_start")]
|
|
72
73
|
|
|
74
|
+
telemetry.track_inspect_started(args.connector)
|
|
73
75
|
rasa.cli.run.run(args)
|
rasa/cli/studio/studio.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from typing import List, Optional
|
|
2
|
+
from typing import List, Optional, Tuple
|
|
3
3
|
from urllib.parse import ParseResult, urlparse
|
|
4
4
|
|
|
5
5
|
import questionary
|
|
@@ -149,7 +149,7 @@ def _configure_studio_url() -> Optional[str]:
|
|
|
149
149
|
return studio_url
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
def _get_advanced_config(studio_url: str) -> tuple:
|
|
152
|
+
def _get_advanced_config(studio_url: str) -> Tuple:
|
|
153
153
|
"""Get the advanced configuration values for Rasa Studio."""
|
|
154
154
|
keycloak_url = questionary.text(
|
|
155
155
|
"Please provide your Rasa Studio Keycloak URL",
|
|
@@ -167,7 +167,7 @@ def _get_advanced_config(studio_url: str) -> tuple:
|
|
|
167
167
|
return keycloak_url, realm_name, client_id
|
|
168
168
|
|
|
169
169
|
|
|
170
|
-
def _get_default_config(studio_url: str) -> tuple:
|
|
170
|
+
def _get_default_config(studio_url: str) -> Tuple:
|
|
171
171
|
"""Get the default configuration values for Rasa Studio."""
|
|
172
172
|
keycloak_url = studio_url + "auth/"
|
|
173
173
|
realm_name = DEFAULT_REALM_NAME
|
|
@@ -178,6 +178,7 @@ def _get_default_config(studio_url: str) -> tuple:
|
|
|
178
178
|
f"Keycloak URL: {keycloak_url}, "
|
|
179
179
|
f"Realm Name: '{realm_name}', "
|
|
180
180
|
f"Client ID: '{client_id}'. "
|
|
181
|
+
f"SSL verification is enabled."
|
|
181
182
|
f"You can use '--advanced' to configure these settings."
|
|
182
183
|
)
|
|
183
184
|
|
|
@@ -185,7 +186,11 @@ def _get_default_config(studio_url: str) -> tuple:
|
|
|
185
186
|
|
|
186
187
|
|
|
187
188
|
def _create_studio_config(
|
|
188
|
-
studio_url: str,
|
|
189
|
+
studio_url: str,
|
|
190
|
+
keycloak_url: str,
|
|
191
|
+
realm_name: str,
|
|
192
|
+
client_id: str,
|
|
193
|
+
disable_verify: bool = False,
|
|
189
194
|
) -> StudioConfig:
|
|
190
195
|
"""Create a StudioConfig object with the provided parameters."""
|
|
191
196
|
return StudioConfig(
|
|
@@ -193,6 +198,7 @@ def _create_studio_config(
|
|
|
193
198
|
studio_url=studio_url + "api/graphql/",
|
|
194
199
|
client_id=client_id,
|
|
195
200
|
realm_name=realm_name,
|
|
201
|
+
disable_verify=disable_verify,
|
|
196
202
|
)
|
|
197
203
|
|
|
198
204
|
|
|
@@ -227,19 +233,23 @@ def _configure_studio_config(args: argparse.Namespace) -> StudioConfig:
|
|
|
227
233
|
|
|
228
234
|
# create a configuration and auth object to try to reach the studio
|
|
229
235
|
studio_config = _create_studio_config(
|
|
230
|
-
studio_url,
|
|
236
|
+
studio_url,
|
|
237
|
+
keycloak_url,
|
|
238
|
+
realm_name,
|
|
239
|
+
client_id,
|
|
240
|
+
disable_verify=args.disable_verify,
|
|
231
241
|
)
|
|
232
242
|
|
|
233
|
-
if
|
|
243
|
+
if studio_config.disable_verify:
|
|
234
244
|
rasa.shared.utils.cli.print_info(
|
|
235
245
|
"Disabling SSL verification for the Rasa Studio authentication server."
|
|
236
246
|
)
|
|
237
|
-
studio_auth = StudioAuth(studio_config, verify=False)
|
|
238
247
|
else:
|
|
239
248
|
rasa.shared.utils.cli.print_info(
|
|
240
249
|
"Enabling SSL verification for the Rasa Studio authentication server."
|
|
241
250
|
)
|
|
242
|
-
|
|
251
|
+
|
|
252
|
+
studio_auth = StudioAuth(studio_config)
|
|
243
253
|
|
|
244
254
|
if _check_studio_auth(studio_auth):
|
|
245
255
|
return studio_config
|
|
@@ -3,6 +3,15 @@ from typing import Optional, Dict, Any, List
|
|
|
3
3
|
from rasa.core.actions.action import Action
|
|
4
4
|
from rasa.core.channels import OutputChannel
|
|
5
5
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
6
|
+
from rasa.dialogue_understanding.patterns.collect_information import (
|
|
7
|
+
CollectInformationPatternFlowStackFrame,
|
|
8
|
+
)
|
|
9
|
+
from rasa.dialogue_understanding.patterns.repeat import (
|
|
10
|
+
RepeatBotMessagesPatternFlowStackFrame,
|
|
11
|
+
)
|
|
12
|
+
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
13
|
+
UserSilencePatternFlowStackFrame,
|
|
14
|
+
)
|
|
6
15
|
from rasa.shared.core.constants import ACTION_REPEAT_BOT_MESSAGES
|
|
7
16
|
from rasa.shared.core.domain import Domain
|
|
8
17
|
from rasa.shared.core.events import Event, BotUttered, UserUttered
|
|
@@ -39,6 +48,14 @@ class ActionRepeatBotMessages(Action):
|
|
|
39
48
|
The elif condition doesn't break when it sees User3 event.
|
|
40
49
|
But it does at User2 event.
|
|
41
50
|
"""
|
|
51
|
+
# Skip action if we are in a collect information step whose
|
|
52
|
+
# default behavior is to repeat anyways
|
|
53
|
+
top_frame = tracker.stack.top(
|
|
54
|
+
lambda frame: isinstance(frame, RepeatBotMessagesPatternFlowStackFrame)
|
|
55
|
+
or isinstance(frame, UserSilencePatternFlowStackFrame)
|
|
56
|
+
)
|
|
57
|
+
if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
|
|
58
|
+
return []
|
|
42
59
|
# filter user and bot events
|
|
43
60
|
filtered = [
|
|
44
61
|
e for e in tracker.events if isinstance(e, (BotUttered, UserUttered))
|
rasa/core/channels/channel.py
CHANGED
|
@@ -313,6 +313,23 @@ class OutputChannel:
|
|
|
313
313
|
button_msg = cli_utils.button_to_string(button, idx)
|
|
314
314
|
await self.send_text_message(recipient_id, button_msg)
|
|
315
315
|
|
|
316
|
+
async def send_text_with_buttons_concise(
|
|
317
|
+
self,
|
|
318
|
+
recipient_id: str,
|
|
319
|
+
text: str,
|
|
320
|
+
buttons: List[Dict[str, Any]],
|
|
321
|
+
**kwargs: Any,
|
|
322
|
+
) -> None:
|
|
323
|
+
"""Sends buttons in a concise format, useful for voice channels."""
|
|
324
|
+
if text.strip()[-1] not in {".", "!", "?", ":"}:
|
|
325
|
+
text += "."
|
|
326
|
+
text += " "
|
|
327
|
+
for idx, button in enumerate(buttons):
|
|
328
|
+
text += button["title"]
|
|
329
|
+
if idx != len(buttons) - 1:
|
|
330
|
+
text += ", "
|
|
331
|
+
await self.send_text_message(recipient_id, text)
|
|
332
|
+
|
|
316
333
|
async def send_quick_replies(
|
|
317
334
|
self,
|
|
318
335
|
recipient_id: Text,
|
|
@@ -187,5 +187,8 @@ class TrackerStream:
|
|
|
187
187
|
if not self._connected_clients:
|
|
188
188
|
return
|
|
189
189
|
await asyncio.wait(
|
|
190
|
-
[
|
|
190
|
+
[
|
|
191
|
+
asyncio.create_task(self._send(websocket, message))
|
|
192
|
+
for websocket in self._connected_clients
|
|
193
|
+
]
|
|
191
194
|
)
|
|
@@ -21,6 +21,7 @@ from sanic.exceptions import NotFound, SanicException, ServerError
|
|
|
21
21
|
from sanic.request import Request
|
|
22
22
|
from sanic.response import HTTPResponse
|
|
23
23
|
|
|
24
|
+
from rasa.utils.io import remove_emojis
|
|
24
25
|
|
|
25
26
|
structlogger = structlog.get_logger()
|
|
26
27
|
|
|
@@ -73,7 +74,7 @@ class Conversation:
|
|
|
73
74
|
@staticmethod
|
|
74
75
|
def get_metadata(activity: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
|
|
75
76
|
"""Get metadata from the activity."""
|
|
76
|
-
return activity
|
|
77
|
+
return asdict(map_call_params(activity["parameters"]))
|
|
77
78
|
|
|
78
79
|
@staticmethod
|
|
79
80
|
def _handle_event(event: Dict[Text, Any]) -> Text:
|
|
@@ -87,17 +88,16 @@ class Conversation:
|
|
|
87
88
|
|
|
88
89
|
if event["name"] == EVENT_START:
|
|
89
90
|
text = f"{INTENT_MESSAGE_PREFIX}{USER_INTENT_SESSION_START}"
|
|
90
|
-
event_params = asdict(map_call_params(event["parameters"]))
|
|
91
91
|
elif event["name"] == EVENT_DTMF:
|
|
92
92
|
text = f"{INTENT_MESSAGE_PREFIX}vaig_event_DTMF"
|
|
93
93
|
event_params = {"value": event["value"]}
|
|
94
|
+
text += json.dumps(event_params)
|
|
94
95
|
else:
|
|
95
96
|
structlogger.warning(
|
|
96
97
|
"audiocodes.handle.event.unknown_event", event_payload=event
|
|
97
98
|
)
|
|
98
99
|
return ""
|
|
99
100
|
|
|
100
|
-
text += json.dumps(event_params)
|
|
101
101
|
return text
|
|
102
102
|
|
|
103
103
|
def is_active_conversation(self, now: datetime, delta: timedelta) -> bool:
|
|
@@ -383,7 +383,7 @@ class AudiocodesInput(InputChannel):
|
|
|
383
383
|
{"conversation": <conversation_id>, "reason": Optional[Text]}.
|
|
384
384
|
"""
|
|
385
385
|
self._get_conversation(request.token, conversation_id)
|
|
386
|
-
reason =
|
|
386
|
+
reason = {"reason": request.json.get("reason")}
|
|
387
387
|
await on_new_message(
|
|
388
388
|
UserMessage(
|
|
389
389
|
text=f"{INTENT_MESSAGE_PREFIX}session_end",
|
|
@@ -449,6 +449,7 @@ class AudiocodesOutput(OutputChannel):
|
|
|
449
449
|
self, recipient_id: Text, text: Text, **kwargs: Any
|
|
450
450
|
) -> None:
|
|
451
451
|
"""Send a text message."""
|
|
452
|
+
text = remove_emojis(text)
|
|
452
453
|
await self.add_message({"type": "message", "text": text})
|
|
453
454
|
|
|
454
455
|
async def send_image_url(
|
|
@@ -471,6 +472,16 @@ class AudiocodesOutput(OutputChannel):
|
|
|
471
472
|
"""Indicate that the conversation should be ended."""
|
|
472
473
|
await self.add_message({"type": "event", "name": "hangup"})
|
|
473
474
|
|
|
475
|
+
async def send_text_with_buttons(
|
|
476
|
+
self,
|
|
477
|
+
recipient_id: str,
|
|
478
|
+
text: str,
|
|
479
|
+
buttons: List[Dict[str, Any]],
|
|
480
|
+
**kwargs: Any,
|
|
481
|
+
) -> None:
|
|
482
|
+
"""Uses the concise button output format for voice channels."""
|
|
483
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
484
|
+
|
|
474
485
|
|
|
475
486
|
class WebsocketOutput(AudiocodesOutput):
|
|
476
487
|
def __init__(self, ws: Any, conversation_id: Text) -> None:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Awaitable, Callable, Dict, Optional, Text
|
|
1
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Text
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from rasa.core.channels.channel import InputChannel, OutputChannel, UserMessage
|
|
@@ -14,7 +14,7 @@ from sanic.request import Request
|
|
|
14
14
|
from sanic.response import HTTPResponse
|
|
15
15
|
|
|
16
16
|
from rasa.shared.utils.common import mark_as_beta_feature
|
|
17
|
-
|
|
17
|
+
from rasa.utils.io import remove_emojis
|
|
18
18
|
|
|
19
19
|
structlogger = structlog.get_logger()
|
|
20
20
|
|
|
@@ -87,6 +87,7 @@ class JambonzWebsocketOutput(OutputChannel):
|
|
|
87
87
|
self, recipient_id: Text, text: Text, **kwargs: Any
|
|
88
88
|
) -> None:
|
|
89
89
|
"""Send a text message."""
|
|
90
|
+
text = remove_emojis(text)
|
|
90
91
|
await self.add_message({"type": "message", "text": text})
|
|
91
92
|
|
|
92
93
|
async def send_image_url(
|
|
@@ -108,3 +109,13 @@ class JambonzWebsocketOutput(OutputChannel):
|
|
|
108
109
|
async def hangup(self, recipient_id: Text, **kwargs: Any) -> None:
|
|
109
110
|
"""Indicate that the conversation should be ended."""
|
|
110
111
|
await send_ws_hangup_message(DEFAULT_HANGUP_DELAY_SECONDS, self.ws)
|
|
112
|
+
|
|
113
|
+
async def send_text_with_buttons(
|
|
114
|
+
self,
|
|
115
|
+
recipient_id: str,
|
|
116
|
+
text: str,
|
|
117
|
+
buttons: List[Dict[str, Any]],
|
|
118
|
+
**kwargs: Any,
|
|
119
|
+
) -> None:
|
|
120
|
+
"""Uses the concise button output format for voice channels."""
|
|
121
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
@@ -358,38 +358,23 @@ class TwilioVoiceCollectingOutputChannel(CollectingOutputChannel):
|
|
|
358
358
|
"""Name of the output channel."""
|
|
359
359
|
return "twilio_voice"
|
|
360
360
|
|
|
361
|
-
@staticmethod
|
|
362
|
-
def _emoji_warning(text: Text) -> None:
|
|
363
|
-
"""Raises a warning if text contains an emoji."""
|
|
364
|
-
emoji_regex = rasa.utils.io.get_emoji_regex()
|
|
365
|
-
if emoji_regex.findall(text):
|
|
366
|
-
rasa.shared.utils.io.raise_warning(
|
|
367
|
-
"Text contains an emoji in a voice response. "
|
|
368
|
-
"Review responses to provide a voice-friendly alternative."
|
|
369
|
-
)
|
|
370
|
-
|
|
371
361
|
async def send_text_message(
|
|
372
362
|
self, recipient_id: Text, text: Text, **kwargs: Any
|
|
373
363
|
) -> None:
|
|
374
364
|
"""Sends the text message after removing emojis."""
|
|
375
|
-
|
|
365
|
+
text = rasa.utils.io.remove_emojis(text)
|
|
376
366
|
for message_part in text.strip().split("\n\n"):
|
|
377
367
|
await self._persist_message(self._message(recipient_id, text=message_part))
|
|
378
368
|
|
|
379
369
|
async def send_text_with_buttons(
|
|
380
370
|
self,
|
|
381
|
-
recipient_id:
|
|
382
|
-
text:
|
|
383
|
-
buttons: List[Dict[
|
|
371
|
+
recipient_id: str,
|
|
372
|
+
text: str,
|
|
373
|
+
buttons: List[Dict[str, Any]],
|
|
384
374
|
**kwargs: Any,
|
|
385
375
|
) -> None:
|
|
386
|
-
"""
|
|
387
|
-
self.
|
|
388
|
-
await self._persist_message(self._message(recipient_id, text=text))
|
|
389
|
-
|
|
390
|
-
for b in buttons:
|
|
391
|
-
self._emoji_warning(b["title"])
|
|
392
|
-
await self._persist_message(self._message(recipient_id, text=b["title"]))
|
|
376
|
+
"""Uses the concise button output format for voice channels."""
|
|
377
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
393
378
|
|
|
394
379
|
async def send_image_url(
|
|
395
380
|
self, recipient_id: Text, image: Text, **kwargs: Any
|
|
@@ -7,7 +7,7 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineC
|
|
|
7
7
|
from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
8
8
|
ASREvent,
|
|
9
9
|
NewTranscript,
|
|
10
|
-
|
|
10
|
+
UserIsSpeaking,
|
|
11
11
|
)
|
|
12
12
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
13
13
|
from rasa.shared.exceptions import ConnectionException
|
|
@@ -31,9 +31,9 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
31
31
|
asyncio.Queue()
|
|
32
32
|
)
|
|
33
33
|
|
|
34
|
-
def
|
|
35
|
-
"""Replace the
|
|
36
|
-
self.fill_queue(
|
|
34
|
+
def signal_user_is_speaking(self, event: Any) -> None:
|
|
35
|
+
"""Replace the azure event with a generic is speaking event."""
|
|
36
|
+
self.fill_queue(UserIsSpeaking())
|
|
37
37
|
|
|
38
38
|
def fill_queue(self, event: Any) -> None:
|
|
39
39
|
"""Either puts the event or a dedicated ASR Event into the queue."""
|
|
@@ -60,9 +60,7 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
60
60
|
audio_config=audio_config,
|
|
61
61
|
)
|
|
62
62
|
self.speech_recognizer.recognized.connect(self.fill_queue)
|
|
63
|
-
self.speech_recognizer.
|
|
64
|
-
self.signal_user_started_speaking
|
|
65
|
-
)
|
|
63
|
+
self.speech_recognizer.recognizing.connect(self.signal_user_is_speaking)
|
|
66
64
|
self.speech_recognizer.start_continuous_recognition_async()
|
|
67
65
|
self.is_recognizing = True
|
|
68
66
|
|
|
@@ -10,7 +10,7 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineC
|
|
|
10
10
|
from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
11
11
|
ASREvent,
|
|
12
12
|
NewTranscript,
|
|
13
|
-
|
|
13
|
+
UserIsSpeaking,
|
|
14
14
|
)
|
|
15
15
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
16
16
|
|
|
@@ -49,7 +49,7 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
|
|
|
49
49
|
def _get_query_params(self) -> str:
|
|
50
50
|
return (
|
|
51
51
|
f"encoding=mulaw&sample_rate={HERTZ}&endpointing={self.config.endpointing}"
|
|
52
|
-
f"&vad_events=true&language={self.config.language}"
|
|
52
|
+
f"&vad_events=true&language={self.config.language}&interim_results=true"
|
|
53
53
|
f"&model={self.config.model}&smart_format={str(self.config.smart_format).lower()}"
|
|
54
54
|
)
|
|
55
55
|
|
|
@@ -66,16 +66,18 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
|
|
|
66
66
|
def engine_event_to_asr_event(self, e: Any) -> Optional[ASREvent]:
|
|
67
67
|
"""Translate an engine event to a common ASREvent."""
|
|
68
68
|
data = json.loads(e)
|
|
69
|
-
if
|
|
69
|
+
if "is_final" in data:
|
|
70
70
|
transcript = data["channel"]["alternatives"][0]["transcript"]
|
|
71
|
-
if data
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
71
|
+
if data["is_final"]:
|
|
72
|
+
if data.get("speech_final"):
|
|
73
|
+
full_transcript = self.accumulated_transcript + transcript
|
|
74
|
+
self.accumulated_transcript = ""
|
|
75
|
+
if full_transcript:
|
|
76
|
+
return NewTranscript(full_transcript)
|
|
77
|
+
else:
|
|
78
|
+
self.accumulated_transcript += transcript
|
|
79
|
+
elif transcript:
|
|
80
|
+
return UserIsSpeaking()
|
|
79
81
|
return None
|
|
80
82
|
|
|
81
83
|
@staticmethod
|
|
@@ -2,13 +2,10 @@ import asyncio
|
|
|
2
2
|
import structlog
|
|
3
3
|
import copy
|
|
4
4
|
from dataclasses import asdict, dataclass
|
|
5
|
-
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional, Tuple
|
|
5
|
+
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
from rasa.core.channels.voice_stream.util import generate_silence
|
|
8
|
-
from rasa.shared.core.constants import
|
|
9
|
-
SILENCE_TIMEOUT_DEFAULT_VALUE,
|
|
10
|
-
SLOT_SILENCE_TIMEOUT,
|
|
11
|
-
)
|
|
8
|
+
from rasa.shared.core.constants import SLOT_SILENCE_TIMEOUT
|
|
12
9
|
from rasa.shared.utils.common import (
|
|
13
10
|
class_from_module_path,
|
|
14
11
|
mark_as_beta_feature,
|
|
@@ -24,7 +21,7 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine
|
|
|
24
21
|
from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
25
22
|
ASREvent,
|
|
26
23
|
NewTranscript,
|
|
27
|
-
|
|
24
|
+
UserIsSpeaking,
|
|
28
25
|
)
|
|
29
26
|
from sanic import Websocket # type: ignore
|
|
30
27
|
|
|
@@ -40,6 +37,7 @@ from rasa.core.channels.voice_stream.tts.azure import AzureTTS
|
|
|
40
37
|
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
41
38
|
from rasa.core.channels.voice_stream.tts.cartesia import CartesiaTTS
|
|
42
39
|
from rasa.core.channels.voice_stream.tts.tts_cache import TTSCache
|
|
40
|
+
from rasa.utils.io import remove_emojis
|
|
43
41
|
|
|
44
42
|
logger = structlog.get_logger(__name__)
|
|
45
43
|
|
|
@@ -157,9 +155,20 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
157
155
|
self.tracker_state["slots"][SLOT_SILENCE_TIMEOUT]
|
|
158
156
|
)
|
|
159
157
|
|
|
158
|
+
async def send_text_with_buttons(
|
|
159
|
+
self,
|
|
160
|
+
recipient_id: str,
|
|
161
|
+
text: str,
|
|
162
|
+
buttons: List[Dict[str, Any]],
|
|
163
|
+
**kwargs: Any,
|
|
164
|
+
) -> None:
|
|
165
|
+
"""Uses the concise button output format for voice channels."""
|
|
166
|
+
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
167
|
+
|
|
160
168
|
async def send_text_message(
|
|
161
169
|
self, recipient_id: str, text: str, **kwargs: Any
|
|
162
170
|
) -> None:
|
|
171
|
+
text = remove_emojis(text)
|
|
163
172
|
self.update_silence_timeout()
|
|
164
173
|
cached_audio_bytes = self.tts_cache.get(text)
|
|
165
174
|
collected_audio_bytes = RasaAudioBytes(b"")
|
|
@@ -221,11 +230,18 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
221
230
|
|
|
222
231
|
|
|
223
232
|
class VoiceInputChannel(InputChannel):
|
|
224
|
-
def __init__(
|
|
233
|
+
def __init__(
|
|
234
|
+
self,
|
|
235
|
+
server_url: str,
|
|
236
|
+
asr_config: Dict,
|
|
237
|
+
tts_config: Dict,
|
|
238
|
+
monitor_silence: bool = False,
|
|
239
|
+
):
|
|
225
240
|
validate_voice_license_scope()
|
|
226
241
|
self.server_url = server_url
|
|
227
242
|
self.asr_config = asr_config
|
|
228
243
|
self.tts_config = tts_config
|
|
244
|
+
self.monitor_silence = monitor_silence
|
|
229
245
|
self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
|
|
230
246
|
|
|
231
247
|
async def handle_silence_timeout(
|
|
@@ -235,10 +251,14 @@ class VoiceInputChannel(InputChannel):
|
|
|
235
251
|
tts_engine: TTSEngine,
|
|
236
252
|
call_parameters: CallParameters,
|
|
237
253
|
) -> None:
|
|
238
|
-
timeout = call_state.silence_timeout
|
|
239
|
-
|
|
254
|
+
timeout = call_state.silence_timeout
|
|
255
|
+
if not timeout:
|
|
256
|
+
return
|
|
257
|
+
if not self.monitor_silence:
|
|
258
|
+
return
|
|
259
|
+
logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
|
|
240
260
|
await asyncio.sleep(timeout)
|
|
241
|
-
logger.
|
|
261
|
+
logger.debug("voice_channel.silence_timeout_tripped")
|
|
242
262
|
output_channel = self.create_output_channel(voice_websocket, tts_engine)
|
|
243
263
|
message = UserMessage(
|
|
244
264
|
"/silence_timeout",
|
|
@@ -249,10 +269,23 @@ class VoiceInputChannel(InputChannel):
|
|
|
249
269
|
)
|
|
250
270
|
await on_new_message(message)
|
|
251
271
|
|
|
272
|
+
@staticmethod
|
|
273
|
+
def _cancel_silence_timeout_watcher() -> None:
|
|
274
|
+
"""Cancels the silent timeout task if it exists."""
|
|
275
|
+
if call_state.silence_timeout_watcher:
|
|
276
|
+
logger.debug("voice_channel.cancelling_current_timeout_watcher_task")
|
|
277
|
+
call_state.silence_timeout_watcher.cancel()
|
|
278
|
+
call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
|
|
279
|
+
|
|
252
280
|
@classmethod
|
|
253
281
|
def from_credentials(cls, credentials: Optional[Dict[str, Any]]) -> InputChannel:
|
|
254
282
|
credentials = credentials or {}
|
|
255
|
-
return cls(
|
|
283
|
+
return cls(
|
|
284
|
+
credentials["server_url"],
|
|
285
|
+
credentials["asr"],
|
|
286
|
+
credentials["tts"],
|
|
287
|
+
credentials.get("monitor_silence", False),
|
|
288
|
+
)
|
|
256
289
|
|
|
257
290
|
def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
|
|
258
291
|
raise NotImplementedError
|
|
@@ -311,11 +344,14 @@ class VoiceInputChannel(InputChannel):
|
|
|
311
344
|
is_bot_speaking_after = call_state.is_bot_speaking
|
|
312
345
|
|
|
313
346
|
if not is_bot_speaking_before and is_bot_speaking_after:
|
|
314
|
-
logger.
|
|
347
|
+
logger.debug("voice_channel.bot_started_speaking")
|
|
348
|
+
# relevant when the bot speaks multiple messages in one turn
|
|
349
|
+
self._cancel_silence_timeout_watcher()
|
|
315
350
|
|
|
316
351
|
# we just stopped speaking, starting a watcher for silence timeout
|
|
317
352
|
if is_bot_speaking_before and not is_bot_speaking_after:
|
|
318
|
-
logger.
|
|
353
|
+
logger.debug("voice_channel.bot_stopped_speaking")
|
|
354
|
+
self._cancel_silence_timeout_watcher()
|
|
319
355
|
call_state.silence_timeout_watcher = ( # type: ignore[attr-defined]
|
|
320
356
|
asyncio.create_task(
|
|
321
357
|
self.handle_silence_timeout(
|
|
@@ -342,12 +378,20 @@ class VoiceInputChannel(InputChannel):
|
|
|
342
378
|
call_parameters,
|
|
343
379
|
)
|
|
344
380
|
|
|
381
|
+
audio_forwarding_task = asyncio.create_task(consume_audio_bytes())
|
|
382
|
+
asr_event_task = asyncio.create_task(consume_asr_events())
|
|
345
383
|
await asyncio.wait(
|
|
346
|
-
[
|
|
384
|
+
[audio_forwarding_task, asr_event_task],
|
|
347
385
|
return_when=asyncio.FIRST_COMPLETED,
|
|
348
386
|
)
|
|
387
|
+
if not audio_forwarding_task.done():
|
|
388
|
+
audio_forwarding_task.cancel()
|
|
389
|
+
if not asr_event_task.done():
|
|
390
|
+
asr_event_task.cancel()
|
|
349
391
|
await tts_engine.close_connection()
|
|
350
392
|
await asr_engine.close_connection()
|
|
393
|
+
await channel_websocket.close()
|
|
394
|
+
self._cancel_silence_timeout_watcher()
|
|
351
395
|
|
|
352
396
|
def create_output_channel(
|
|
353
397
|
self, voice_websocket: Websocket, tts_engine: TTSEngine
|
|
@@ -365,7 +409,7 @@ class VoiceInputChannel(InputChannel):
|
|
|
365
409
|
) -> None:
|
|
366
410
|
"""Handle a new event from the ASR system."""
|
|
367
411
|
if isinstance(e, NewTranscript) and e.text:
|
|
368
|
-
logger.
|
|
412
|
+
logger.debug(
|
|
369
413
|
"VoiceInputChannel.handle_asr_event.new_transcript", transcript=e.text
|
|
370
414
|
)
|
|
371
415
|
call_state.is_user_speaking = False # type: ignore[attr-defined]
|
|
@@ -378,8 +422,6 @@ class VoiceInputChannel(InputChannel):
|
|
|
378
422
|
metadata=asdict(call_parameters),
|
|
379
423
|
)
|
|
380
424
|
await on_new_message(message)
|
|
381
|
-
elif isinstance(e,
|
|
382
|
-
|
|
383
|
-
call_state.silence_timeout_watcher.cancel()
|
|
384
|
-
call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
|
|
425
|
+
elif isinstance(e, UserIsSpeaking):
|
|
426
|
+
self._cancel_silence_timeout_watcher()
|
|
385
427
|
call_state.is_user_speaking = True # type: ignore[attr-defined]
|