rasa-pro 3.11.0rc2__py3-none-any.whl → 3.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +9 -3
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/utils.py +1 -1
- rasa/core/channels/development_inspector.py +8 -2
- rasa/core/channels/voice_ready/audiocodes.py +3 -4
- rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +16 -9
- rasa/core/channels/voice_stream/asr/deepgram.py +17 -14
- rasa/core/channels/voice_stream/tts/azure.py +3 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
- rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
- rasa/core/channels/voice_stream/voice_channel.py +48 -18
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +2 -2
- rasa/core/persistor.py +93 -49
- rasa/core/policies/enterprise_search_policy.py +5 -5
- rasa/core/policies/flows/flow_executor.py +18 -8
- rasa/core/policies/intentless_policy.py +9 -5
- rasa/core/processor.py +7 -5
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +2 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
- rasa/e2e_test/assertions.py +133 -16
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_runner.py +2 -2
- rasa/engine/loader.py +12 -0
- rasa/engine/validation.py +310 -86
- rasa/model_manager/config.py +8 -0
- rasa/model_manager/model_api.py +166 -61
- rasa/model_manager/runner_service.py +31 -26
- rasa/model_manager/trainer_service.py +14 -23
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +3 -5
- rasa/model_training.py +3 -1
- rasa/shared/constants.py +27 -5
- rasa/shared/core/constants.py +1 -1
- rasa/shared/core/domain.py +8 -31
- rasa/shared/core/flows/yaml_flows_io.py +13 -4
- rasa/shared/importers/importer.py +19 -2
- rasa/shared/importers/rasa.py +5 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
- rasa/shared/providers/llm/_base_litellm_client.py +26 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
- rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
- rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
- rasa/shared/utils/common.py +30 -3
- rasa/shared/utils/health_check/health_check.py +26 -24
- rasa/shared/utils/yaml.py +116 -31
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +119 -57
- rasa/telemetry.py +3 -1
- rasa/tracing/config.py +1 -1
- rasa/validator.py +40 -4
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +65 -63
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
rasa/__main__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
from typing import Optional, List
|
|
2
3
|
import structlog
|
|
3
4
|
import os
|
|
4
5
|
import platform
|
|
@@ -97,12 +98,17 @@ def print_version() -> None:
|
|
|
97
98
|
print(f"License Expires : {get_license_expiration_date()}")
|
|
98
99
|
|
|
99
100
|
|
|
100
|
-
def main() -> None:
|
|
101
|
-
"""Run as standalone python application.
|
|
101
|
+
def main(raw_arguments: Optional[List[str]] = None) -> None:
|
|
102
|
+
"""Run as standalone python application.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
raw_arguments: Arguments to parse. If not provided,
|
|
106
|
+
arguments will be taken from the command line.
|
|
107
|
+
"""
|
|
102
108
|
warn_if_rasa_plus_package_installed()
|
|
103
109
|
parse_last_positional_argument_as_model_path()
|
|
104
110
|
arg_parser = create_argument_parser()
|
|
105
|
-
cmdline_arguments = arg_parser.parse_args()
|
|
111
|
+
cmdline_arguments = arg_parser.parse_args(raw_arguments)
|
|
106
112
|
|
|
107
113
|
log_level = getattr(cmdline_arguments, "loglevel", None)
|
|
108
114
|
logging_config_file = getattr(cmdline_arguments, "logging_config_file", None)
|
rasa/cli/studio/upload.py
CHANGED
|
@@ -32,25 +32,10 @@ def add_subparser(
|
|
|
32
32
|
set_upload_arguments(upload_parser)
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def add_flows_param(
|
|
36
|
-
parser: argparse.ArgumentParser,
|
|
37
|
-
help_text: str = "Name of flows file to upload to Rasa Studio. Works with --calm",
|
|
38
|
-
default_path: str = "flows.yml",
|
|
39
|
-
) -> None:
|
|
40
|
-
parser.add_argument(
|
|
41
|
-
"--flows",
|
|
42
|
-
default=default_path,
|
|
43
|
-
nargs="+",
|
|
44
|
-
type=str,
|
|
45
|
-
help=help_text,
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
|
|
49
35
|
def set_upload_arguments(parser: argparse.ArgumentParser) -> None:
|
|
50
36
|
"""Add arguments for running `rasa upload`."""
|
|
51
37
|
add_data_param(parser, data_type="training")
|
|
52
38
|
add_domain_param(parser)
|
|
53
|
-
add_flows_param(parser)
|
|
54
39
|
add_config_param(parser)
|
|
55
40
|
add_endpoint_param(parser, help_text="Path to the endpoints file.")
|
|
56
41
|
|
rasa/cli/utils.py
CHANGED
|
@@ -305,7 +305,7 @@ def _validate_domain(validator: "Validator") -> bool:
|
|
|
305
305
|
valid_forms_in_stories_rules = validator.verify_forms_in_stories_rules()
|
|
306
306
|
valid_form_slots = validator.verify_form_slots()
|
|
307
307
|
valid_slot_mappings = validator.verify_slot_mappings()
|
|
308
|
-
valid_responses = validator.
|
|
308
|
+
valid_responses = validator.check_for_no_empty_parenthesis_in_responses()
|
|
309
309
|
valid_buttons = validator.validate_button_payloads()
|
|
310
310
|
return (
|
|
311
311
|
valid_domain_validity
|
|
@@ -128,9 +128,12 @@ class DevelopmentInspectProxy(InputChannel):
|
|
|
128
128
|
|
|
129
129
|
inspect_path = app.url_for(f"{app.name}.{underlying_webhook.name}.inspect")
|
|
130
130
|
|
|
131
|
+
# replace 0.0.0.0 with localhost
|
|
132
|
+
serve_location = app.serve_location.replace("0.0.0.0", "localhost")
|
|
133
|
+
|
|
131
134
|
print_info(
|
|
132
135
|
f"Development inspector for channel {self.name()} is running. To "
|
|
133
|
-
f"inspect conversations, visit {
|
|
136
|
+
f"inspect conversations, visit {serve_location}{inspect_path}"
|
|
134
137
|
)
|
|
135
138
|
|
|
136
139
|
underlying_webhook.add_websocket_route(
|
|
@@ -187,5 +190,8 @@ class TrackerStream:
|
|
|
187
190
|
if not self._connected_clients:
|
|
188
191
|
return
|
|
189
192
|
await asyncio.wait(
|
|
190
|
-
[
|
|
193
|
+
[
|
|
194
|
+
asyncio.create_task(self._send(websocket, message))
|
|
195
|
+
for websocket in self._connected_clients
|
|
196
|
+
]
|
|
191
197
|
)
|
|
@@ -74,7 +74,7 @@ class Conversation:
|
|
|
74
74
|
@staticmethod
|
|
75
75
|
def get_metadata(activity: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
|
|
76
76
|
"""Get metadata from the activity."""
|
|
77
|
-
return activity
|
|
77
|
+
return asdict(map_call_params(activity["parameters"]))
|
|
78
78
|
|
|
79
79
|
@staticmethod
|
|
80
80
|
def _handle_event(event: Dict[Text, Any]) -> Text:
|
|
@@ -88,17 +88,16 @@ class Conversation:
|
|
|
88
88
|
|
|
89
89
|
if event["name"] == EVENT_START:
|
|
90
90
|
text = f"{INTENT_MESSAGE_PREFIX}{USER_INTENT_SESSION_START}"
|
|
91
|
-
event_params = asdict(map_call_params(event["parameters"]))
|
|
92
91
|
elif event["name"] == EVENT_DTMF:
|
|
93
92
|
text = f"{INTENT_MESSAGE_PREFIX}vaig_event_DTMF"
|
|
94
93
|
event_params = {"value": event["value"]}
|
|
94
|
+
text += json.dumps(event_params)
|
|
95
95
|
else:
|
|
96
96
|
structlogger.warning(
|
|
97
97
|
"audiocodes.handle.event.unknown_event", event_payload=event
|
|
98
98
|
)
|
|
99
99
|
return ""
|
|
100
100
|
|
|
101
|
-
text += json.dumps(event_params)
|
|
102
101
|
return text
|
|
103
102
|
|
|
104
103
|
def is_active_conversation(self, now: datetime, delta: timedelta) -> bool:
|
|
@@ -384,7 +383,7 @@ class AudiocodesInput(InputChannel):
|
|
|
384
383
|
{"conversation": <conversation_id>, "reason": Optional[Text]}.
|
|
385
384
|
"""
|
|
386
385
|
self._get_conversation(request.token, conversation_id)
|
|
387
|
-
reason =
|
|
386
|
+
reason = {"reason": request.json.get("reason")}
|
|
388
387
|
await on_new_message(
|
|
389
388
|
UserMessage(
|
|
390
389
|
text=f"{INTENT_MESSAGE_PREFIX}session_end",
|
|
@@ -1,5 +1,14 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import (
|
|
3
|
+
Dict,
|
|
4
|
+
AsyncIterator,
|
|
5
|
+
Any,
|
|
6
|
+
Generic,
|
|
7
|
+
Optional,
|
|
8
|
+
Tuple,
|
|
9
|
+
Type,
|
|
10
|
+
TypeVar,
|
|
11
|
+
)
|
|
3
12
|
|
|
4
13
|
from websockets.legacy.client import WebSocketClientProtocol
|
|
5
14
|
|
|
@@ -7,6 +16,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import ASREvent
|
|
|
7
16
|
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
8
17
|
from rasa.core.channels.voice_stream.util import MergeableConfig
|
|
9
18
|
from rasa.shared.exceptions import ConnectionException
|
|
19
|
+
from rasa.shared.utils.common import validate_environment
|
|
10
20
|
|
|
11
21
|
T = TypeVar("T", bound="ASREngineConfig")
|
|
12
22
|
E = TypeVar("E", bound="ASREngine")
|
|
@@ -18,9 +28,17 @@ class ASREngineConfig(MergeableConfig):
|
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
class ASREngine(Generic[T]):
|
|
31
|
+
required_env_vars: Tuple[str, ...] = ()
|
|
32
|
+
required_packages: Tuple[str, ...] = ()
|
|
33
|
+
|
|
21
34
|
def __init__(self, config: Optional[T] = None):
|
|
22
35
|
self.config = self.get_default_config().merge(config)
|
|
23
36
|
self.asr_socket: Optional[WebSocketClientProtocol] = None
|
|
37
|
+
validate_environment(
|
|
38
|
+
self.required_env_vars,
|
|
39
|
+
self.required_packages,
|
|
40
|
+
f"ASR Engine {self.__class__.__name__}",
|
|
41
|
+
)
|
|
24
42
|
|
|
25
43
|
async def connect(self) -> None:
|
|
26
44
|
self.asr_socket = await self.open_websocket_connection()
|
|
@@ -7,9 +7,10 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineC
|
|
|
7
7
|
from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
8
8
|
ASREvent,
|
|
9
9
|
NewTranscript,
|
|
10
|
-
|
|
10
|
+
UserIsSpeaking,
|
|
11
11
|
)
|
|
12
12
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
13
|
+
from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
|
|
13
14
|
from rasa.shared.exceptions import ConnectionException
|
|
14
15
|
|
|
15
16
|
|
|
@@ -20,10 +21,14 @@ class AzureASRConfig(ASREngineConfig):
|
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class AzureASR(ASREngine[AzureASRConfig]):
|
|
24
|
+
required_env_vars = (AZURE_SPEECH_API_KEY_ENV_VAR,)
|
|
25
|
+
required_packages = ("azure.cognitiveservices.speech",)
|
|
26
|
+
|
|
23
27
|
def __init__(self, config: Optional[AzureASRConfig] = None):
|
|
28
|
+
super().__init__(config)
|
|
29
|
+
|
|
24
30
|
import azure.cognitiveservices.speech as speechsdk
|
|
25
31
|
|
|
26
|
-
super().__init__(config)
|
|
27
32
|
self.speech_recognizer: Optional[speechsdk.SpeechRecognizer] = None
|
|
28
33
|
self.stream: Optional[speechsdk.audio.PushAudioInputStream] = None
|
|
29
34
|
self.is_recognizing = False
|
|
@@ -31,9 +36,13 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
31
36
|
asyncio.Queue()
|
|
32
37
|
)
|
|
33
38
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
39
|
+
@staticmethod
|
|
40
|
+
def validate_environment() -> None:
|
|
41
|
+
"""Make sure all needed requirements for this component are met."""
|
|
42
|
+
|
|
43
|
+
def signal_user_is_speaking(self, event: Any) -> None:
|
|
44
|
+
"""Replace the azure event with a generic is speaking event."""
|
|
45
|
+
self.fill_queue(UserIsSpeaking())
|
|
37
46
|
|
|
38
47
|
def fill_queue(self, event: Any) -> None:
|
|
39
48
|
"""Either puts the event or a dedicated ASR Event into the queue."""
|
|
@@ -43,7 +52,7 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
43
52
|
import azure.cognitiveservices.speech as speechsdk
|
|
44
53
|
|
|
45
54
|
speech_config = speechsdk.SpeechConfig(
|
|
46
|
-
subscription=os.environ[
|
|
55
|
+
subscription=os.environ[AZURE_SPEECH_API_KEY_ENV_VAR],
|
|
47
56
|
region=self.config.speech_region,
|
|
48
57
|
)
|
|
49
58
|
audio_format = speechsdk.audio.AudioStreamFormat(
|
|
@@ -60,9 +69,7 @@ class AzureASR(ASREngine[AzureASRConfig]):
|
|
|
60
69
|
audio_config=audio_config,
|
|
61
70
|
)
|
|
62
71
|
self.speech_recognizer.recognized.connect(self.fill_queue)
|
|
63
|
-
self.speech_recognizer.
|
|
64
|
-
self.signal_user_started_speaking
|
|
65
|
-
)
|
|
72
|
+
self.speech_recognizer.recognizing.connect(self.signal_user_is_speaking)
|
|
66
73
|
self.speech_recognizer.start_continuous_recognition_async()
|
|
67
74
|
self.is_recognizing = True
|
|
68
75
|
|
|
@@ -10,11 +10,10 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineC
|
|
|
10
10
|
from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
11
11
|
ASREvent,
|
|
12
12
|
NewTranscript,
|
|
13
|
-
|
|
13
|
+
UserIsSpeaking,
|
|
14
14
|
)
|
|
15
15
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
16
|
-
|
|
17
|
-
DEEPGRAM_API_KEY = "DEEPGRAM_API_KEY"
|
|
16
|
+
from rasa.shared.constants import DEEPGRAM_API_KEY_ENV_VAR
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
@dataclass
|
|
@@ -28,13 +27,15 @@ class DeepgramASRConfig(ASREngineConfig):
|
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
class DeepgramASR(ASREngine[DeepgramASRConfig]):
|
|
30
|
+
required_env_vars = (DEEPGRAM_API_KEY_ENV_VAR,)
|
|
31
|
+
|
|
31
32
|
def __init__(self, config: Optional[DeepgramASRConfig] = None):
|
|
32
33
|
super().__init__(config)
|
|
33
34
|
self.accumulated_transcript = ""
|
|
34
35
|
|
|
35
36
|
async def open_websocket_connection(self) -> WebSocketClientProtocol:
|
|
36
37
|
"""Connect to the ASR system."""
|
|
37
|
-
deepgram_api_key = os.environ[
|
|
38
|
+
deepgram_api_key = os.environ[DEEPGRAM_API_KEY_ENV_VAR]
|
|
38
39
|
extra_headers = {"Authorization": f"Token {deepgram_api_key}"}
|
|
39
40
|
api_url = self._get_api_url()
|
|
40
41
|
query_params = self._get_query_params()
|
|
@@ -49,7 +50,7 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
|
|
|
49
50
|
def _get_query_params(self) -> str:
|
|
50
51
|
return (
|
|
51
52
|
f"encoding=mulaw&sample_rate={HERTZ}&endpointing={self.config.endpointing}"
|
|
52
|
-
f"&vad_events=true&language={self.config.language}"
|
|
53
|
+
f"&vad_events=true&language={self.config.language}&interim_results=true"
|
|
53
54
|
f"&model={self.config.model}&smart_format={str(self.config.smart_format).lower()}"
|
|
54
55
|
)
|
|
55
56
|
|
|
@@ -66,16 +67,18 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
|
|
|
66
67
|
def engine_event_to_asr_event(self, e: Any) -> Optional[ASREvent]:
|
|
67
68
|
"""Translate an engine event to a common ASREvent."""
|
|
68
69
|
data = json.loads(e)
|
|
69
|
-
if
|
|
70
|
+
if "is_final" in data:
|
|
70
71
|
transcript = data["channel"]["alternatives"][0]["transcript"]
|
|
71
|
-
if data
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
72
|
+
if data["is_final"]:
|
|
73
|
+
if data.get("speech_final"):
|
|
74
|
+
full_transcript = self.accumulated_transcript + transcript
|
|
75
|
+
self.accumulated_transcript = ""
|
|
76
|
+
if full_transcript:
|
|
77
|
+
return NewTranscript(full_transcript)
|
|
78
|
+
else:
|
|
79
|
+
self.accumulated_transcript += transcript
|
|
80
|
+
elif transcript:
|
|
81
|
+
return UserIsSpeaking()
|
|
79
82
|
return None
|
|
80
83
|
|
|
81
84
|
@staticmethod
|
|
@@ -12,6 +12,7 @@ from rasa.core.channels.voice_stream.tts.tts_engine import (
|
|
|
12
12
|
TTSEngineConfig,
|
|
13
13
|
TTSError,
|
|
14
14
|
)
|
|
15
|
+
from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
|
|
15
16
|
from rasa.shared.exceptions import ConnectionException
|
|
16
17
|
|
|
17
18
|
|
|
@@ -25,6 +26,7 @@ class AzureTTSConfig(TTSEngineConfig):
|
|
|
25
26
|
|
|
26
27
|
class AzureTTS(TTSEngine[AzureTTSConfig]):
|
|
27
28
|
session: Optional[aiohttp.ClientSession] = None
|
|
29
|
+
required_env_vars = (AZURE_SPEECH_API_KEY_ENV_VAR,)
|
|
28
30
|
|
|
29
31
|
def __init__(self, config: Optional[AzureTTSConfig] = None):
|
|
30
32
|
super().__init__(config)
|
|
@@ -66,7 +68,7 @@ class AzureTTS(TTSEngine[AzureTTSConfig]):
|
|
|
66
68
|
|
|
67
69
|
@staticmethod
|
|
68
70
|
def get_request_headers() -> dict[str, str]:
|
|
69
|
-
azure_speech_api_key = os.environ[
|
|
71
|
+
azure_speech_api_key = os.environ[AZURE_SPEECH_API_KEY_ENV_VAR]
|
|
70
72
|
return {
|
|
71
73
|
"Ocp-Apim-Subscription-Key": azure_speech_api_key,
|
|
72
74
|
"Content-Type": "application/ssml+xml",
|
|
@@ -11,12 +11,11 @@ from rasa.core.channels.voice_stream.tts.tts_engine import (
|
|
|
11
11
|
|
|
12
12
|
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
13
13
|
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
14
|
+
from rasa.shared.constants import CARTESIA_API_KEY_ENV_VAR
|
|
14
15
|
from rasa.shared.exceptions import ConnectionException
|
|
15
16
|
|
|
16
17
|
structlogger = structlog.get_logger()
|
|
17
18
|
|
|
18
|
-
CARTESIA_API_KEY = "CARTESIA_API_KEY"
|
|
19
|
-
|
|
20
19
|
|
|
21
20
|
@dataclass
|
|
22
21
|
class CartesiaTTSConfig(TTSEngineConfig):
|
|
@@ -26,6 +25,7 @@ class CartesiaTTSConfig(TTSEngineConfig):
|
|
|
26
25
|
|
|
27
26
|
class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
|
|
28
27
|
session: Optional[aiohttp.ClientSession] = None
|
|
28
|
+
required_env_vars = (CARTESIA_API_KEY_ENV_VAR,)
|
|
29
29
|
|
|
30
30
|
def __init__(self, config: Optional[CartesiaTTSConfig] = None):
|
|
31
31
|
super().__init__(config)
|
|
@@ -62,7 +62,7 @@ class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
|
|
|
62
62
|
|
|
63
63
|
@staticmethod
|
|
64
64
|
def get_request_headers(config: CartesiaTTSConfig) -> dict[str, str]:
|
|
65
|
-
cartesia_api_key = os.environ[
|
|
65
|
+
cartesia_api_key = os.environ[CARTESIA_API_KEY_ENV_VAR]
|
|
66
66
|
return {
|
|
67
67
|
"Cartesia-Version": str(config.version),
|
|
68
68
|
"Content-Type": "application/json",
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from typing import AsyncIterator, Dict, Generic, Optional, Type, TypeVar
|
|
1
|
+
from typing import AsyncIterator, Dict, Generic, Optional, Tuple, Type, TypeVar
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
5
5
|
from rasa.core.channels.voice_stream.util import MergeableConfig
|
|
6
6
|
from rasa.shared.exceptions import RasaException
|
|
7
|
+
from rasa.shared.utils.common import validate_environment
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class TTSError(RasaException):
|
|
@@ -22,8 +23,16 @@ class TTSEngineConfig(MergeableConfig):
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class TTSEngine(Generic[T]):
|
|
26
|
+
required_env_vars: Tuple[str, ...] = ()
|
|
27
|
+
required_packages: Tuple[str, ...] = ()
|
|
28
|
+
|
|
25
29
|
def __init__(self, config: Optional[T] = None):
|
|
26
30
|
self.config = self.get_default_config().merge(config)
|
|
31
|
+
validate_environment(
|
|
32
|
+
self.required_env_vars,
|
|
33
|
+
self.required_packages,
|
|
34
|
+
f"TTS Engine {self.__class__.__name__}",
|
|
35
|
+
)
|
|
27
36
|
|
|
28
37
|
async def close_connection(self) -> None:
|
|
29
38
|
"""Cleanup the connection if necessary."""
|
|
@@ -5,10 +5,7 @@ from dataclasses import asdict, dataclass
|
|
|
5
5
|
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
from rasa.core.channels.voice_stream.util import generate_silence
|
|
8
|
-
from rasa.shared.core.constants import
|
|
9
|
-
SILENCE_TIMEOUT_DEFAULT_VALUE,
|
|
10
|
-
SLOT_SILENCE_TIMEOUT,
|
|
11
|
-
)
|
|
8
|
+
from rasa.shared.core.constants import SLOT_SILENCE_TIMEOUT
|
|
12
9
|
from rasa.shared.utils.common import (
|
|
13
10
|
class_from_module_path,
|
|
14
11
|
mark_as_beta_feature,
|
|
@@ -24,7 +21,7 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine
|
|
|
24
21
|
from rasa.core.channels.voice_stream.asr.asr_event import (
|
|
25
22
|
ASREvent,
|
|
26
23
|
NewTranscript,
|
|
27
|
-
|
|
24
|
+
UserIsSpeaking,
|
|
28
25
|
)
|
|
29
26
|
from sanic import Websocket # type: ignore
|
|
30
27
|
|
|
@@ -233,11 +230,18 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
233
230
|
|
|
234
231
|
|
|
235
232
|
class VoiceInputChannel(InputChannel):
|
|
236
|
-
def __init__(
|
|
233
|
+
def __init__(
|
|
234
|
+
self,
|
|
235
|
+
server_url: str,
|
|
236
|
+
asr_config: Dict,
|
|
237
|
+
tts_config: Dict,
|
|
238
|
+
monitor_silence: bool = False,
|
|
239
|
+
):
|
|
237
240
|
validate_voice_license_scope()
|
|
238
241
|
self.server_url = server_url
|
|
239
242
|
self.asr_config = asr_config
|
|
240
243
|
self.tts_config = tts_config
|
|
244
|
+
self.monitor_silence = monitor_silence
|
|
241
245
|
self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
|
|
242
246
|
|
|
243
247
|
async def handle_silence_timeout(
|
|
@@ -247,10 +251,14 @@ class VoiceInputChannel(InputChannel):
|
|
|
247
251
|
tts_engine: TTSEngine,
|
|
248
252
|
call_parameters: CallParameters,
|
|
249
253
|
) -> None:
|
|
250
|
-
timeout = call_state.silence_timeout
|
|
251
|
-
|
|
254
|
+
timeout = call_state.silence_timeout
|
|
255
|
+
if not timeout:
|
|
256
|
+
return
|
|
257
|
+
if not self.monitor_silence:
|
|
258
|
+
return
|
|
259
|
+
logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
|
|
252
260
|
await asyncio.sleep(timeout)
|
|
253
|
-
logger.
|
|
261
|
+
logger.debug("voice_channel.silence_timeout_tripped")
|
|
254
262
|
output_channel = self.create_output_channel(voice_websocket, tts_engine)
|
|
255
263
|
message = UserMessage(
|
|
256
264
|
"/silence_timeout",
|
|
@@ -261,10 +269,23 @@ class VoiceInputChannel(InputChannel):
|
|
|
261
269
|
)
|
|
262
270
|
await on_new_message(message)
|
|
263
271
|
|
|
272
|
+
@staticmethod
|
|
273
|
+
def _cancel_silence_timeout_watcher() -> None:
|
|
274
|
+
"""Cancels the silent timeout task if it exists."""
|
|
275
|
+
if call_state.silence_timeout_watcher:
|
|
276
|
+
logger.debug("voice_channel.cancelling_current_timeout_watcher_task")
|
|
277
|
+
call_state.silence_timeout_watcher.cancel()
|
|
278
|
+
call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
|
|
279
|
+
|
|
264
280
|
@classmethod
|
|
265
281
|
def from_credentials(cls, credentials: Optional[Dict[str, Any]]) -> InputChannel:
|
|
266
282
|
credentials = credentials or {}
|
|
267
|
-
return cls(
|
|
283
|
+
return cls(
|
|
284
|
+
credentials["server_url"],
|
|
285
|
+
credentials["asr"],
|
|
286
|
+
credentials["tts"],
|
|
287
|
+
credentials.get("monitor_silence", False),
|
|
288
|
+
)
|
|
268
289
|
|
|
269
290
|
def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
|
|
270
291
|
raise NotImplementedError
|
|
@@ -323,11 +344,14 @@ class VoiceInputChannel(InputChannel):
|
|
|
323
344
|
is_bot_speaking_after = call_state.is_bot_speaking
|
|
324
345
|
|
|
325
346
|
if not is_bot_speaking_before and is_bot_speaking_after:
|
|
326
|
-
logger.
|
|
347
|
+
logger.debug("voice_channel.bot_started_speaking")
|
|
348
|
+
# relevant when the bot speaks multiple messages in one turn
|
|
349
|
+
self._cancel_silence_timeout_watcher()
|
|
327
350
|
|
|
328
351
|
# we just stopped speaking, starting a watcher for silence timeout
|
|
329
352
|
if is_bot_speaking_before and not is_bot_speaking_after:
|
|
330
|
-
logger.
|
|
353
|
+
logger.debug("voice_channel.bot_stopped_speaking")
|
|
354
|
+
self._cancel_silence_timeout_watcher()
|
|
331
355
|
call_state.silence_timeout_watcher = ( # type: ignore[attr-defined]
|
|
332
356
|
asyncio.create_task(
|
|
333
357
|
self.handle_silence_timeout(
|
|
@@ -354,12 +378,20 @@ class VoiceInputChannel(InputChannel):
|
|
|
354
378
|
call_parameters,
|
|
355
379
|
)
|
|
356
380
|
|
|
381
|
+
audio_forwarding_task = asyncio.create_task(consume_audio_bytes())
|
|
382
|
+
asr_event_task = asyncio.create_task(consume_asr_events())
|
|
357
383
|
await asyncio.wait(
|
|
358
|
-
[
|
|
384
|
+
[audio_forwarding_task, asr_event_task],
|
|
359
385
|
return_when=asyncio.FIRST_COMPLETED,
|
|
360
386
|
)
|
|
387
|
+
if not audio_forwarding_task.done():
|
|
388
|
+
audio_forwarding_task.cancel()
|
|
389
|
+
if not asr_event_task.done():
|
|
390
|
+
asr_event_task.cancel()
|
|
361
391
|
await tts_engine.close_connection()
|
|
362
392
|
await asr_engine.close_connection()
|
|
393
|
+
await channel_websocket.close()
|
|
394
|
+
self._cancel_silence_timeout_watcher()
|
|
363
395
|
|
|
364
396
|
def create_output_channel(
|
|
365
397
|
self, voice_websocket: Websocket, tts_engine: TTSEngine
|
|
@@ -377,7 +409,7 @@ class VoiceInputChannel(InputChannel):
|
|
|
377
409
|
) -> None:
|
|
378
410
|
"""Handle a new event from the ASR system."""
|
|
379
411
|
if isinstance(e, NewTranscript) and e.text:
|
|
380
|
-
logger.
|
|
412
|
+
logger.debug(
|
|
381
413
|
"VoiceInputChannel.handle_asr_event.new_transcript", transcript=e.text
|
|
382
414
|
)
|
|
383
415
|
call_state.is_user_speaking = False # type: ignore[attr-defined]
|
|
@@ -390,8 +422,6 @@ class VoiceInputChannel(InputChannel):
|
|
|
390
422
|
metadata=asdict(call_parameters),
|
|
391
423
|
)
|
|
392
424
|
await on_new_message(message)
|
|
393
|
-
elif isinstance(e,
|
|
394
|
-
|
|
395
|
-
call_state.silence_timeout_watcher.cancel()
|
|
396
|
-
call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
|
|
425
|
+
elif isinstance(e, UserIsSpeaking):
|
|
426
|
+
self._cancel_silence_timeout_watcher()
|
|
397
427
|
call_state.is_user_speaking = True # type: ignore[attr-defined]
|
|
@@ -62,6 +62,7 @@ class Qdrant_Store(InformationRetrieval):
|
|
|
62
62
|
embeddings=self.embeddings,
|
|
63
63
|
content_payload_key=params.get("content_payload_key", "text"),
|
|
64
64
|
metadata_payload_key=params.get("metadata_payload_key", "metadata"),
|
|
65
|
+
vector_name=params.get("vector_name", None),
|
|
65
66
|
)
|
|
66
67
|
|
|
67
68
|
async def search(
|
|
@@ -13,7 +13,7 @@ from rasa.shared.constants import (
|
|
|
13
13
|
PROVIDER_CONFIG_KEY,
|
|
14
14
|
OPENAI_PROVIDER,
|
|
15
15
|
TIMEOUT_CONFIG_KEY,
|
|
16
|
-
|
|
16
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
17
17
|
)
|
|
18
18
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
19
19
|
from rasa.shared.core.events import BotUttered, UserUttered
|
|
@@ -253,7 +253,7 @@ class ContextualResponseRephraser(
|
|
|
253
253
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
254
254
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
255
255
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
256
|
-
llm_model_group_id=self.llm_property(
|
|
256
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
257
257
|
)
|
|
258
258
|
if not (updated_text := await self._generate_llm_response(prompt)):
|
|
259
259
|
# If the LLM fails to generate a response, we
|