rasa-pro 3.11.0rc2__py3-none-any.whl → 3.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (65)
  1. rasa/__main__.py +9 -3
  2. rasa/cli/studio/upload.py +0 -15
  3. rasa/cli/utils.py +1 -1
  4. rasa/core/channels/development_inspector.py +8 -2
  5. rasa/core/channels/voice_ready/audiocodes.py +3 -4
  6. rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
  7. rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
  8. rasa/core/channels/voice_stream/asr/azure.py +16 -9
  9. rasa/core/channels/voice_stream/asr/deepgram.py +17 -14
  10. rasa/core/channels/voice_stream/tts/azure.py +3 -1
  11. rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
  12. rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
  13. rasa/core/channels/voice_stream/voice_channel.py +48 -18
  14. rasa/core/information_retrieval/qdrant.py +1 -0
  15. rasa/core/nlg/contextual_response_rephraser.py +2 -2
  16. rasa/core/persistor.py +93 -49
  17. rasa/core/policies/enterprise_search_policy.py +5 -5
  18. rasa/core/policies/flows/flow_executor.py +18 -8
  19. rasa/core/policies/intentless_policy.py +9 -5
  20. rasa/core/processor.py +7 -5
  21. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +2 -1
  22. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
  23. rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
  24. rasa/e2e_test/assertions.py +133 -16
  25. rasa/e2e_test/assertions_schema.yml +23 -0
  26. rasa/e2e_test/e2e_test_runner.py +2 -2
  27. rasa/engine/loader.py +12 -0
  28. rasa/engine/validation.py +310 -86
  29. rasa/model_manager/config.py +8 -0
  30. rasa/model_manager/model_api.py +166 -61
  31. rasa/model_manager/runner_service.py +31 -26
  32. rasa/model_manager/trainer_service.py +14 -23
  33. rasa/model_manager/warm_rasa_process.py +187 -0
  34. rasa/model_service.py +3 -5
  35. rasa/model_training.py +3 -1
  36. rasa/shared/constants.py +27 -5
  37. rasa/shared/core/constants.py +1 -1
  38. rasa/shared/core/domain.py +8 -31
  39. rasa/shared/core/flows/yaml_flows_io.py +13 -4
  40. rasa/shared/importers/importer.py +19 -2
  41. rasa/shared/importers/rasa.py +5 -1
  42. rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
  43. rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
  44. rasa/shared/providers/_utils.py +79 -0
  45. rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
  46. rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
  47. rasa/shared/providers/llm/_base_litellm_client.py +26 -0
  48. rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
  49. rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
  50. rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
  51. rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
  52. rasa/shared/utils/common.py +30 -3
  53. rasa/shared/utils/health_check/health_check.py +26 -24
  54. rasa/shared/utils/yaml.py +116 -31
  55. rasa/studio/data_handler.py +3 -1
  56. rasa/studio/upload.py +119 -57
  57. rasa/telemetry.py +3 -1
  58. rasa/tracing/config.py +1 -1
  59. rasa/validator.py +40 -4
  60. rasa/version.py +1 -1
  61. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
  62. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +65 -63
  63. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
rasa/__main__.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ from typing import Optional, List
2
3
  import structlog
3
4
  import os
4
5
  import platform
@@ -97,12 +98,17 @@ def print_version() -> None:
97
98
  print(f"License Expires : {get_license_expiration_date()}")
98
99
 
99
100
 
100
- def main() -> None:
101
- """Run as standalone python application."""
101
+ def main(raw_arguments: Optional[List[str]] = None) -> None:
102
+ """Run as standalone python application.
103
+
104
+ Args:
105
+ raw_arguments: Arguments to parse. If not provided,
106
+ arguments will be taken from the command line.
107
+ """
102
108
  warn_if_rasa_plus_package_installed()
103
109
  parse_last_positional_argument_as_model_path()
104
110
  arg_parser = create_argument_parser()
105
- cmdline_arguments = arg_parser.parse_args()
111
+ cmdline_arguments = arg_parser.parse_args(raw_arguments)
106
112
 
107
113
  log_level = getattr(cmdline_arguments, "loglevel", None)
108
114
  logging_config_file = getattr(cmdline_arguments, "logging_config_file", None)
rasa/cli/studio/upload.py CHANGED
@@ -32,25 +32,10 @@ def add_subparser(
32
32
  set_upload_arguments(upload_parser)
33
33
 
34
34
 
35
- def add_flows_param(
36
- parser: argparse.ArgumentParser,
37
- help_text: str = "Name of flows file to upload to Rasa Studio. Works with --calm",
38
- default_path: str = "flows.yml",
39
- ) -> None:
40
- parser.add_argument(
41
- "--flows",
42
- default=default_path,
43
- nargs="+",
44
- type=str,
45
- help=help_text,
46
- )
47
-
48
-
49
35
  def set_upload_arguments(parser: argparse.ArgumentParser) -> None:
50
36
  """Add arguments for running `rasa upload`."""
51
37
  add_data_param(parser, data_type="training")
52
38
  add_domain_param(parser)
53
- add_flows_param(parser)
54
39
  add_config_param(parser)
55
40
  add_endpoint_param(parser, help_text="Path to the endpoints file.")
56
41
 
rasa/cli/utils.py CHANGED
@@ -305,7 +305,7 @@ def _validate_domain(validator: "Validator") -> bool:
305
305
  valid_forms_in_stories_rules = validator.verify_forms_in_stories_rules()
306
306
  valid_form_slots = validator.verify_form_slots()
307
307
  valid_slot_mappings = validator.verify_slot_mappings()
308
- valid_responses = validator.check_for_no_empty_paranthesis_in_responses()
308
+ valid_responses = validator.check_for_no_empty_parenthesis_in_responses()
309
309
  valid_buttons = validator.validate_button_payloads()
310
310
  return (
311
311
  valid_domain_validity
@@ -128,9 +128,12 @@ class DevelopmentInspectProxy(InputChannel):
128
128
 
129
129
  inspect_path = app.url_for(f"{app.name}.{underlying_webhook.name}.inspect")
130
130
 
131
+ # replace 0.0.0.0 with localhost
132
+ serve_location = app.serve_location.replace("0.0.0.0", "localhost")
133
+
131
134
  print_info(
132
135
  f"Development inspector for channel {self.name()} is running. To "
133
- f"inspect conversations, visit {app.serve_location}{inspect_path}"
136
+ f"inspect conversations, visit {serve_location}{inspect_path}"
134
137
  )
135
138
 
136
139
  underlying_webhook.add_websocket_route(
@@ -187,5 +190,8 @@ class TrackerStream:
187
190
  if not self._connected_clients:
188
191
  return
189
192
  await asyncio.wait(
190
- [self._send(websocket, message) for websocket in self._connected_clients]
193
+ [
194
+ asyncio.create_task(self._send(websocket, message))
195
+ for websocket in self._connected_clients
196
+ ]
191
197
  )
@@ -74,7 +74,7 @@ class Conversation:
74
74
  @staticmethod
75
75
  def get_metadata(activity: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
76
76
  """Get metadata from the activity."""
77
- return activity.get("parameters")
77
+ return asdict(map_call_params(activity["parameters"]))
78
78
 
79
79
  @staticmethod
80
80
  def _handle_event(event: Dict[Text, Any]) -> Text:
@@ -88,17 +88,16 @@ class Conversation:
88
88
 
89
89
  if event["name"] == EVENT_START:
90
90
  text = f"{INTENT_MESSAGE_PREFIX}{USER_INTENT_SESSION_START}"
91
- event_params = asdict(map_call_params(event["parameters"]))
92
91
  elif event["name"] == EVENT_DTMF:
93
92
  text = f"{INTENT_MESSAGE_PREFIX}vaig_event_DTMF"
94
93
  event_params = {"value": event["value"]}
94
+ text += json.dumps(event_params)
95
95
  else:
96
96
  structlogger.warning(
97
97
  "audiocodes.handle.event.unknown_event", event_payload=event
98
98
  )
99
99
  return ""
100
100
 
101
- text += json.dumps(event_params)
102
101
  return text
103
102
 
104
103
  def is_active_conversation(self, now: datetime, delta: timedelta) -> bool:
@@ -384,7 +383,7 @@ class AudiocodesInput(InputChannel):
384
383
  {"conversation": <conversation_id>, "reason": Optional[Text]}.
385
384
  """
386
385
  self._get_conversation(request.token, conversation_id)
387
- reason = json.dumps({"reason": request.json.get("reason")})
386
+ reason = {"reason": request.json.get("reason")}
388
387
  await on_new_message(
389
388
  UserMessage(
390
389
  text=f"{INTENT_MESSAGE_PREFIX}session_end",
@@ -1,5 +1,14 @@
1
1
  from dataclasses import dataclass
2
- from typing import Dict, AsyncIterator, Any, Generic, Optional, Type, TypeVar
2
+ from typing import (
3
+ Dict,
4
+ AsyncIterator,
5
+ Any,
6
+ Generic,
7
+ Optional,
8
+ Tuple,
9
+ Type,
10
+ TypeVar,
11
+ )
3
12
 
4
13
  from websockets.legacy.client import WebSocketClientProtocol
5
14
 
@@ -7,6 +16,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import ASREvent
7
16
  from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
8
17
  from rasa.core.channels.voice_stream.util import MergeableConfig
9
18
  from rasa.shared.exceptions import ConnectionException
19
+ from rasa.shared.utils.common import validate_environment
10
20
 
11
21
  T = TypeVar("T", bound="ASREngineConfig")
12
22
  E = TypeVar("E", bound="ASREngine")
@@ -18,9 +28,17 @@ class ASREngineConfig(MergeableConfig):
18
28
 
19
29
 
20
30
  class ASREngine(Generic[T]):
31
+ required_env_vars: Tuple[str, ...] = ()
32
+ required_packages: Tuple[str, ...] = ()
33
+
21
34
  def __init__(self, config: Optional[T] = None):
22
35
  self.config = self.get_default_config().merge(config)
23
36
  self.asr_socket: Optional[WebSocketClientProtocol] = None
37
+ validate_environment(
38
+ self.required_env_vars,
39
+ self.required_packages,
40
+ f"ASR Engine {self.__class__.__name__}",
41
+ )
24
42
 
25
43
  async def connect(self) -> None:
26
44
  self.asr_socket = await self.open_websocket_connection()
@@ -14,5 +14,5 @@ class NewTranscript(ASREvent):
14
14
 
15
15
 
16
16
  @dataclass
17
- class UserStartedSpeaking(ASREvent):
17
+ class UserIsSpeaking(ASREvent):
18
18
  pass
@@ -7,9 +7,10 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineC
7
7
  from rasa.core.channels.voice_stream.asr.asr_event import (
8
8
  ASREvent,
9
9
  NewTranscript,
10
- UserStartedSpeaking,
10
+ UserIsSpeaking,
11
11
  )
12
12
  from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
13
+ from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
13
14
  from rasa.shared.exceptions import ConnectionException
14
15
 
15
16
 
@@ -20,10 +21,14 @@ class AzureASRConfig(ASREngineConfig):
20
21
 
21
22
 
22
23
  class AzureASR(ASREngine[AzureASRConfig]):
24
+ required_env_vars = (AZURE_SPEECH_API_KEY_ENV_VAR,)
25
+ required_packages = ("azure.cognitiveservices.speech",)
26
+
23
27
  def __init__(self, config: Optional[AzureASRConfig] = None):
28
+ super().__init__(config)
29
+
24
30
  import azure.cognitiveservices.speech as speechsdk
25
31
 
26
- super().__init__(config)
27
32
  self.speech_recognizer: Optional[speechsdk.SpeechRecognizer] = None
28
33
  self.stream: Optional[speechsdk.audio.PushAudioInputStream] = None
29
34
  self.is_recognizing = False
@@ -31,9 +36,13 @@ class AzureASR(ASREngine[AzureASRConfig]):
31
36
  asyncio.Queue()
32
37
  )
33
38
 
34
- def signal_user_started_speaking(self, event: Any) -> None:
35
- """Replace the unspecific azure event with a specific start event."""
36
- self.fill_queue(UserStartedSpeaking())
39
+ @staticmethod
40
+ def validate_environment() -> None:
41
+ """Make sure all needed requirements for this component are met."""
42
+
43
+ def signal_user_is_speaking(self, event: Any) -> None:
44
+ """Replace the azure event with a generic is speaking event."""
45
+ self.fill_queue(UserIsSpeaking())
37
46
 
38
47
  def fill_queue(self, event: Any) -> None:
39
48
  """Either puts the event or a dedicated ASR Event into the queue."""
@@ -43,7 +52,7 @@ class AzureASR(ASREngine[AzureASRConfig]):
43
52
  import azure.cognitiveservices.speech as speechsdk
44
53
 
45
54
  speech_config = speechsdk.SpeechConfig(
46
- subscription=os.environ["AZURE_SPEECH_API_KEY"],
55
+ subscription=os.environ[AZURE_SPEECH_API_KEY_ENV_VAR],
47
56
  region=self.config.speech_region,
48
57
  )
49
58
  audio_format = speechsdk.audio.AudioStreamFormat(
@@ -60,9 +69,7 @@ class AzureASR(ASREngine[AzureASRConfig]):
60
69
  audio_config=audio_config,
61
70
  )
62
71
  self.speech_recognizer.recognized.connect(self.fill_queue)
63
- self.speech_recognizer.speech_start_detected.connect(
64
- self.signal_user_started_speaking
65
- )
72
+ self.speech_recognizer.recognizing.connect(self.signal_user_is_speaking)
66
73
  self.speech_recognizer.start_continuous_recognition_async()
67
74
  self.is_recognizing = True
68
75
 
@@ -10,11 +10,10 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine, ASREngineC
10
10
  from rasa.core.channels.voice_stream.asr.asr_event import (
11
11
  ASREvent,
12
12
  NewTranscript,
13
- UserStartedSpeaking,
13
+ UserIsSpeaking,
14
14
  )
15
15
  from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
16
-
17
- DEEPGRAM_API_KEY = "DEEPGRAM_API_KEY"
16
+ from rasa.shared.constants import DEEPGRAM_API_KEY_ENV_VAR
18
17
 
19
18
 
20
19
  @dataclass
@@ -28,13 +27,15 @@ class DeepgramASRConfig(ASREngineConfig):
28
27
 
29
28
 
30
29
  class DeepgramASR(ASREngine[DeepgramASRConfig]):
30
+ required_env_vars = (DEEPGRAM_API_KEY_ENV_VAR,)
31
+
31
32
  def __init__(self, config: Optional[DeepgramASRConfig] = None):
32
33
  super().__init__(config)
33
34
  self.accumulated_transcript = ""
34
35
 
35
36
  async def open_websocket_connection(self) -> WebSocketClientProtocol:
36
37
  """Connect to the ASR system."""
37
- deepgram_api_key = os.environ[DEEPGRAM_API_KEY]
38
+ deepgram_api_key = os.environ[DEEPGRAM_API_KEY_ENV_VAR]
38
39
  extra_headers = {"Authorization": f"Token {deepgram_api_key}"}
39
40
  api_url = self._get_api_url()
40
41
  query_params = self._get_query_params()
@@ -49,7 +50,7 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
49
50
  def _get_query_params(self) -> str:
50
51
  return (
51
52
  f"encoding=mulaw&sample_rate={HERTZ}&endpointing={self.config.endpointing}"
52
- f"&vad_events=true&language={self.config.language}"
53
+ f"&vad_events=true&language={self.config.language}&interim_results=true"
53
54
  f"&model={self.config.model}&smart_format={str(self.config.smart_format).lower()}"
54
55
  )
55
56
 
@@ -66,16 +67,18 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
66
67
  def engine_event_to_asr_event(self, e: Any) -> Optional[ASREvent]:
67
68
  """Translate an engine event to a common ASREvent."""
68
69
  data = json.loads(e)
69
- if data.get("is_final"):
70
+ if "is_final" in data:
70
71
  transcript = data["channel"]["alternatives"][0]["transcript"]
71
- if data.get("speech_final"):
72
- full_transcript = self.accumulated_transcript + transcript
73
- self.accumulated_transcript = ""
74
- return NewTranscript(full_transcript)
75
- else:
76
- self.accumulated_transcript += transcript
77
- elif data.get("type") == "SpeechStarted":
78
- return UserStartedSpeaking()
72
+ if data["is_final"]:
73
+ if data.get("speech_final"):
74
+ full_transcript = self.accumulated_transcript + transcript
75
+ self.accumulated_transcript = ""
76
+ if full_transcript:
77
+ return NewTranscript(full_transcript)
78
+ else:
79
+ self.accumulated_transcript += transcript
80
+ elif transcript:
81
+ return UserIsSpeaking()
79
82
  return None
80
83
 
81
84
  @staticmethod
@@ -12,6 +12,7 @@ from rasa.core.channels.voice_stream.tts.tts_engine import (
12
12
  TTSEngineConfig,
13
13
  TTSError,
14
14
  )
15
+ from rasa.shared.constants import AZURE_SPEECH_API_KEY_ENV_VAR
15
16
  from rasa.shared.exceptions import ConnectionException
16
17
 
17
18
 
@@ -25,6 +26,7 @@ class AzureTTSConfig(TTSEngineConfig):
25
26
 
26
27
  class AzureTTS(TTSEngine[AzureTTSConfig]):
27
28
  session: Optional[aiohttp.ClientSession] = None
29
+ required_env_vars = (AZURE_SPEECH_API_KEY_ENV_VAR,)
28
30
 
29
31
  def __init__(self, config: Optional[AzureTTSConfig] = None):
30
32
  super().__init__(config)
@@ -66,7 +68,7 @@ class AzureTTS(TTSEngine[AzureTTSConfig]):
66
68
 
67
69
  @staticmethod
68
70
  def get_request_headers() -> dict[str, str]:
69
- azure_speech_api_key = os.environ["AZURE_SPEECH_API_KEY"]
71
+ azure_speech_api_key = os.environ[AZURE_SPEECH_API_KEY_ENV_VAR]
70
72
  return {
71
73
  "Ocp-Apim-Subscription-Key": azure_speech_api_key,
72
74
  "Content-Type": "application/ssml+xml",
@@ -11,12 +11,11 @@ from rasa.core.channels.voice_stream.tts.tts_engine import (
11
11
 
12
12
  from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
13
13
  from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
14
+ from rasa.shared.constants import CARTESIA_API_KEY_ENV_VAR
14
15
  from rasa.shared.exceptions import ConnectionException
15
16
 
16
17
  structlogger = structlog.get_logger()
17
18
 
18
- CARTESIA_API_KEY = "CARTESIA_API_KEY"
19
-
20
19
 
21
20
  @dataclass
22
21
  class CartesiaTTSConfig(TTSEngineConfig):
@@ -26,6 +25,7 @@ class CartesiaTTSConfig(TTSEngineConfig):
26
25
 
27
26
  class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
28
27
  session: Optional[aiohttp.ClientSession] = None
28
+ required_env_vars = (CARTESIA_API_KEY_ENV_VAR,)
29
29
 
30
30
  def __init__(self, config: Optional[CartesiaTTSConfig] = None):
31
31
  super().__init__(config)
@@ -62,7 +62,7 @@ class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
62
62
 
63
63
  @staticmethod
64
64
  def get_request_headers(config: CartesiaTTSConfig) -> dict[str, str]:
65
- cartesia_api_key = os.environ[CARTESIA_API_KEY]
65
+ cartesia_api_key = os.environ[CARTESIA_API_KEY_ENV_VAR]
66
66
  return {
67
67
  "Cartesia-Version": str(config.version),
68
68
  "Content-Type": "application/json",
@@ -1,9 +1,10 @@
1
- from typing import AsyncIterator, Dict, Generic, Optional, Type, TypeVar
1
+ from typing import AsyncIterator, Dict, Generic, Optional, Tuple, Type, TypeVar
2
2
  from dataclasses import dataclass
3
3
 
4
4
  from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
5
5
  from rasa.core.channels.voice_stream.util import MergeableConfig
6
6
  from rasa.shared.exceptions import RasaException
7
+ from rasa.shared.utils.common import validate_environment
7
8
 
8
9
 
9
10
  class TTSError(RasaException):
@@ -22,8 +23,16 @@ class TTSEngineConfig(MergeableConfig):
22
23
 
23
24
 
24
25
  class TTSEngine(Generic[T]):
26
+ required_env_vars: Tuple[str, ...] = ()
27
+ required_packages: Tuple[str, ...] = ()
28
+
25
29
  def __init__(self, config: Optional[T] = None):
26
30
  self.config = self.get_default_config().merge(config)
31
+ validate_environment(
32
+ self.required_env_vars,
33
+ self.required_packages,
34
+ f"TTS Engine {self.__class__.__name__}",
35
+ )
27
36
 
28
37
  async def close_connection(self) -> None:
29
38
  """Cleanup the connection if necessary."""
@@ -5,10 +5,7 @@ from dataclasses import asdict, dataclass
5
5
  from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
6
6
 
7
7
  from rasa.core.channels.voice_stream.util import generate_silence
8
- from rasa.shared.core.constants import (
9
- SILENCE_TIMEOUT_DEFAULT_VALUE,
10
- SLOT_SILENCE_TIMEOUT,
11
- )
8
+ from rasa.shared.core.constants import SLOT_SILENCE_TIMEOUT
12
9
  from rasa.shared.utils.common import (
13
10
  class_from_module_path,
14
11
  mark_as_beta_feature,
@@ -24,7 +21,7 @@ from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine
24
21
  from rasa.core.channels.voice_stream.asr.asr_event import (
25
22
  ASREvent,
26
23
  NewTranscript,
27
- UserStartedSpeaking,
24
+ UserIsSpeaking,
28
25
  )
29
26
  from sanic import Websocket # type: ignore
30
27
 
@@ -233,11 +230,18 @@ class VoiceOutputChannel(OutputChannel):
233
230
 
234
231
 
235
232
  class VoiceInputChannel(InputChannel):
236
- def __init__(self, server_url: str, asr_config: Dict, tts_config: Dict):
233
+ def __init__(
234
+ self,
235
+ server_url: str,
236
+ asr_config: Dict,
237
+ tts_config: Dict,
238
+ monitor_silence: bool = False,
239
+ ):
237
240
  validate_voice_license_scope()
238
241
  self.server_url = server_url
239
242
  self.asr_config = asr_config
240
243
  self.tts_config = tts_config
244
+ self.monitor_silence = monitor_silence
241
245
  self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
242
246
 
243
247
  async def handle_silence_timeout(
@@ -247,10 +251,14 @@ class VoiceInputChannel(InputChannel):
247
251
  tts_engine: TTSEngine,
248
252
  call_parameters: CallParameters,
249
253
  ) -> None:
250
- timeout = call_state.silence_timeout or SILENCE_TIMEOUT_DEFAULT_VALUE
251
- logger.info("voice_channel.silence_timeout_watch_started", timeout=timeout)
254
+ timeout = call_state.silence_timeout
255
+ if not timeout:
256
+ return
257
+ if not self.monitor_silence:
258
+ return
259
+ logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
252
260
  await asyncio.sleep(timeout)
253
- logger.info("voice_channel.silence_timeout_tripped")
261
+ logger.debug("voice_channel.silence_timeout_tripped")
254
262
  output_channel = self.create_output_channel(voice_websocket, tts_engine)
255
263
  message = UserMessage(
256
264
  "/silence_timeout",
@@ -261,10 +269,23 @@ class VoiceInputChannel(InputChannel):
261
269
  )
262
270
  await on_new_message(message)
263
271
 
272
+ @staticmethod
273
+ def _cancel_silence_timeout_watcher() -> None:
274
+ """Cancels the silent timeout task if it exists."""
275
+ if call_state.silence_timeout_watcher:
276
+ logger.debug("voice_channel.cancelling_current_timeout_watcher_task")
277
+ call_state.silence_timeout_watcher.cancel()
278
+ call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
279
+
264
280
  @classmethod
265
281
  def from_credentials(cls, credentials: Optional[Dict[str, Any]]) -> InputChannel:
266
282
  credentials = credentials or {}
267
- return cls(credentials["server_url"], credentials["asr"], credentials["tts"])
283
+ return cls(
284
+ credentials["server_url"],
285
+ credentials["asr"],
286
+ credentials["tts"],
287
+ credentials.get("monitor_silence", False),
288
+ )
268
289
 
269
290
  def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
270
291
  raise NotImplementedError
@@ -323,11 +344,14 @@ class VoiceInputChannel(InputChannel):
323
344
  is_bot_speaking_after = call_state.is_bot_speaking
324
345
 
325
346
  if not is_bot_speaking_before and is_bot_speaking_after:
326
- logger.info("voice_channel.bot_started_speaking")
347
+ logger.debug("voice_channel.bot_started_speaking")
348
+ # relevant when the bot speaks multiple messages in one turn
349
+ self._cancel_silence_timeout_watcher()
327
350
 
328
351
  # we just stopped speaking, starting a watcher for silence timeout
329
352
  if is_bot_speaking_before and not is_bot_speaking_after:
330
- logger.info("voice_channel.bot_stopped_speaking")
353
+ logger.debug("voice_channel.bot_stopped_speaking")
354
+ self._cancel_silence_timeout_watcher()
331
355
  call_state.silence_timeout_watcher = ( # type: ignore[attr-defined]
332
356
  asyncio.create_task(
333
357
  self.handle_silence_timeout(
@@ -354,12 +378,20 @@ class VoiceInputChannel(InputChannel):
354
378
  call_parameters,
355
379
  )
356
380
 
381
+ audio_forwarding_task = asyncio.create_task(consume_audio_bytes())
382
+ asr_event_task = asyncio.create_task(consume_asr_events())
357
383
  await asyncio.wait(
358
- [consume_audio_bytes(), consume_asr_events()],
384
+ [audio_forwarding_task, asr_event_task],
359
385
  return_when=asyncio.FIRST_COMPLETED,
360
386
  )
387
+ if not audio_forwarding_task.done():
388
+ audio_forwarding_task.cancel()
389
+ if not asr_event_task.done():
390
+ asr_event_task.cancel()
361
391
  await tts_engine.close_connection()
362
392
  await asr_engine.close_connection()
393
+ await channel_websocket.close()
394
+ self._cancel_silence_timeout_watcher()
363
395
 
364
396
  def create_output_channel(
365
397
  self, voice_websocket: Websocket, tts_engine: TTSEngine
@@ -377,7 +409,7 @@ class VoiceInputChannel(InputChannel):
377
409
  ) -> None:
378
410
  """Handle a new event from the ASR system."""
379
411
  if isinstance(e, NewTranscript) and e.text:
380
- logger.info(
412
+ logger.debug(
381
413
  "VoiceInputChannel.handle_asr_event.new_transcript", transcript=e.text
382
414
  )
383
415
  call_state.is_user_speaking = False # type: ignore[attr-defined]
@@ -390,8 +422,6 @@ class VoiceInputChannel(InputChannel):
390
422
  metadata=asdict(call_parameters),
391
423
  )
392
424
  await on_new_message(message)
393
- elif isinstance(e, UserStartedSpeaking):
394
- if call_state.silence_timeout_watcher:
395
- call_state.silence_timeout_watcher.cancel()
396
- call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
425
+ elif isinstance(e, UserIsSpeaking):
426
+ self._cancel_silence_timeout_watcher()
397
427
  call_state.is_user_speaking = True # type: ignore[attr-defined]
@@ -62,6 +62,7 @@ class Qdrant_Store(InformationRetrieval):
62
62
  embeddings=self.embeddings,
63
63
  content_payload_key=params.get("content_payload_key", "text"),
64
64
  metadata_payload_key=params.get("metadata_payload_key", "metadata"),
65
+ vector_name=params.get("vector_name", None),
65
66
  )
66
67
 
67
68
  async def search(
@@ -13,7 +13,7 @@ from rasa.shared.constants import (
13
13
  PROVIDER_CONFIG_KEY,
14
14
  OPENAI_PROVIDER,
15
15
  TIMEOUT_CONFIG_KEY,
16
- MODEL_GROUP_CONFIG_KEY,
16
+ MODEL_GROUP_ID_CONFIG_KEY,
17
17
  )
18
18
  from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
19
19
  from rasa.shared.core.events import BotUttered, UserUttered
@@ -253,7 +253,7 @@ class ContextualResponseRephraser(
253
253
  llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
254
254
  llm_model=self.llm_property(MODEL_CONFIG_KEY)
255
255
  or self.llm_property(MODEL_NAME_CONFIG_KEY),
256
- llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
256
+ llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
257
257
  )
258
258
  if not (updated_text := await self._generate_llm_response(prompt)):
259
259
  # If the LLM fails to generate a response, we