rasa-pro 3.12.0rc2__py3-none-any.whl → 3.12.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (43):
  1. rasa/cli/dialogue_understanding_test.py +5 -8
  2. rasa/cli/llm_fine_tuning.py +47 -12
  3. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  4. rasa/core/channels/voice_stream/audiocodes.py +19 -6
  5. rasa/core/channels/voice_stream/call_state.py +3 -9
  6. rasa/core/channels/voice_stream/genesys.py +40 -55
  7. rasa/core/channels/voice_stream/voice_channel.py +61 -39
  8. rasa/core/tracker_store.py +123 -34
  9. rasa/dialogue_understanding/commands/set_slot_command.py +1 -0
  10. rasa/dialogue_understanding/commands/utils.py +1 -4
  11. rasa/dialogue_understanding/generator/command_parser.py +41 -0
  12. rasa/dialogue_understanding/generator/constants.py +7 -2
  13. rasa/dialogue_understanding/generator/llm_based_command_generator.py +9 -2
  14. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +29 -48
  15. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_fallback_other_models_template.jinja2 +57 -0
  16. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +23 -50
  17. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +76 -24
  18. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +32 -18
  19. rasa/dialogue_understanding/processor/command_processor.py +39 -19
  20. rasa/dialogue_understanding/stack/utils.py +11 -6
  21. rasa/engine/language.py +67 -25
  22. rasa/llm_fine_tuning/conversations.py +3 -31
  23. rasa/llm_fine_tuning/llm_data_preparation_module.py +5 -3
  24. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +18 -13
  25. rasa/llm_fine_tuning/paraphrasing_module.py +6 -2
  26. rasa/llm_fine_tuning/train_test_split_module.py +27 -27
  27. rasa/llm_fine_tuning/utils.py +7 -0
  28. rasa/shared/constants.py +4 -0
  29. rasa/shared/core/domain.py +2 -0
  30. rasa/shared/providers/_configs/azure_entra_id_config.py +8 -8
  31. rasa/shared/providers/llm/litellm_router_llm_client.py +1 -0
  32. rasa/shared/providers/router/_base_litellm_router_client.py +38 -7
  33. rasa/shared/utils/llm.py +69 -13
  34. rasa/telemetry.py +13 -3
  35. rasa/tracing/instrumentation/attribute_extractors.py +2 -5
  36. rasa/validator.py +2 -2
  37. rasa/version.py +1 -1
  38. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/METADATA +1 -1
  39. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/RECORD +42 -41
  40. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +0 -68
  41. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/NOTICE +0 -0
  42. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/WHEEL +0 -0
  43. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/entry_points.txt +0 -0
@@ -3,7 +3,7 @@ import asyncio
3
3
  import datetime
4
4
  import importlib
5
5
  import sys
6
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Dict, List, Optional, Type, cast
7
7
 
8
8
  import structlog
9
9
 
@@ -20,9 +20,7 @@ from rasa.core.exceptions import AgentNotReady
20
20
  from rasa.core.processor import MessageProcessor
21
21
  from rasa.core.utils import AvailableEndpoints
22
22
  from rasa.dialogue_understanding.commands import Command
23
- from rasa.dialogue_understanding.generator import (
24
- LLMBasedCommandGenerator,
25
- )
23
+ from rasa.dialogue_understanding.generator import LLMBasedCommandGenerator
26
24
  from rasa.dialogue_understanding.generator.command_parser import DEFAULT_COMMANDS
27
25
  from rasa.dialogue_understanding_test.command_metric_calculation import (
28
26
  calculate_command_metrics,
@@ -372,18 +370,17 @@ def split_test_results(
372
370
  def _get_llm_command_generator_config(
373
371
  processor: MessageProcessor,
374
372
  ) -> Optional[Dict[str, Any]]:
375
- from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
376
-
377
373
  train_schema = processor.model_metadata.train_schema
378
374
 
379
375
  for node_name, node in train_schema.nodes.items():
380
376
  if node.matches_type(LLMBasedCommandGenerator, include_subtypes=True):
381
377
  # Configurations can reference model groups defined in the endpoints.yml
382
- resolved_config = resolve_model_client_config(
378
+ resolved_llm_config = resolve_model_client_config(
383
379
  node.config.get(LLM_CONFIG_KEY, {}), node_name
384
380
  )
381
+ llm_command_generator = cast(Type[LLMBasedCommandGenerator], node.uses)
385
382
  return combine_custom_and_default_config(
386
- resolved_config, DEFAULT_LLM_CONFIG
383
+ resolved_llm_config, llm_command_generator.get_default_llm_config()
387
384
  )
388
385
 
389
386
  return None
@@ -1,7 +1,7 @@
1
1
  import argparse
2
2
  import asyncio
3
3
  import sys
4
- from typing import Any, Dict, List
4
+ from typing import Any, Dict, List, Type, cast
5
5
 
6
6
  import structlog
7
7
 
@@ -22,7 +22,12 @@ from rasa.cli.e2e_test import (
22
22
  )
23
23
  from rasa.core.exceptions import AgentNotReady
24
24
  from rasa.core.utils import AvailableEndpoints
25
- from rasa.dialogue_understanding.generator import SingleStepLLMCommandGenerator
25
+ from rasa.dialogue_understanding.generator.llm_based_command_generator import (
26
+ LLMBasedCommandGenerator,
27
+ )
28
+ from rasa.dialogue_understanding.generator.multi_step.multi_step_llm_command_generator import ( # noqa: E501
29
+ MultiStepLLMCommandGenerator,
30
+ )
26
31
  from rasa.e2e_test.e2e_test_runner import E2ETestRunner
27
32
  from rasa.llm_fine_tuning.annotation_module import annotate_e2e_tests
28
33
  from rasa.llm_fine_tuning.llm_data_preparation_module import convert_to_fine_tuning_data
@@ -112,7 +117,6 @@ def create_llm_finetune_data_preparation_subparser(
112
117
  help_text="Configuration file for the model server and the connectors as a "
113
118
  "yml file.",
114
119
  )
115
-
116
120
  return data_preparation_subparser
117
121
 
118
122
 
@@ -205,6 +209,9 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
205
209
 
206
210
  flows = asyncio.run(e2e_test_runner.agent.processor.get_flows())
207
211
  llm_command_generator_config = _get_llm_command_generator_config(e2e_test_runner)
212
+ llm_command_generator: Type[LLMBasedCommandGenerator] = _get_llm_command_generator(
213
+ e2e_test_runner
214
+ )
208
215
 
209
216
  # set up storage context
210
217
  storage_context = create_storage_context(StorageType.FILE, output_dir)
@@ -235,6 +242,7 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
235
242
  rephrase_config,
236
243
  args.num_rephrases,
237
244
  flows,
245
+ llm_command_generator,
238
246
  llm_command_generator_config,
239
247
  storage_context,
240
248
  )
@@ -271,30 +279,57 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
271
279
  write_statistics(statistics, output_dir)
272
280
 
273
281
  rasa.shared.utils.cli.print_success(
274
- f"Data and intermediate results are written " f"to '{output_dir}'."
282
+ f"Data and intermediate results are written to '{output_dir}'."
275
283
  )
276
284
 
277
285
 
278
286
  def _get_llm_command_generator_config(e2e_test_runner: E2ETestRunner) -> Dict[str, Any]:
279
- from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
280
-
281
287
  train_schema = e2e_test_runner.agent.processor.model_metadata.train_schema # type: ignore
282
288
 
283
289
  for node_name, node in train_schema.nodes.items():
284
- if node.matches_type(SingleStepLLMCommandGenerator, include_subtypes=True):
290
+ if node.matches_type(
291
+ LLMBasedCommandGenerator, include_subtypes=True
292
+ ) and not node.matches_type(
293
+ MultiStepLLMCommandGenerator, include_subtypes=True
294
+ ):
285
295
  # Configurations can reference model groups defined in the endpoints.yml
286
- resolved_config = resolve_model_client_config(
296
+ resolved_llm_config = resolve_model_client_config(
287
297
  node.config.get(LLM_CONFIG_KEY, {}), node_name
288
298
  )
299
+ llm_command_generator = cast(Type[LLMBasedCommandGenerator], node.uses)
289
300
  return combine_custom_and_default_config(
290
- resolved_config, DEFAULT_LLM_CONFIG
301
+ resolved_llm_config, llm_command_generator.get_default_llm_config()
291
302
  )
292
303
 
293
304
  rasa.shared.utils.cli.print_error(
294
305
  "The provided model is not trained using 'SingleStepLLMCommandGenerator' or "
295
- "its subclasses. Without it, no data for fine-tuning can be generated. To "
296
- "resolve this, please include 'SingleStepLLMCommandGenerator' or its subclass "
297
- "in your config and train your model."
306
+ "'CompactLLMCommandGenerator' or its subclasses. Without it, no data for "
307
+ "fine-tuning can be generated. To resolve this, please include "
308
+ "'SingleStepLLMCommandGenerator' or 'CompactLLMCommandGenerator' or its "
309
+ "subclasses in your config and train your model."
310
+ )
311
+ sys.exit(1)
312
+
313
+
314
+ def _get_llm_command_generator(
315
+ e2e_test_runner: E2ETestRunner,
316
+ ) -> Type[LLMBasedCommandGenerator]:
317
+ train_schema = e2e_test_runner.agent.processor.model_metadata.train_schema # type: ignore
318
+
319
+ for _, node in train_schema.nodes.items():
320
+ if node.matches_type(
321
+ LLMBasedCommandGenerator, include_subtypes=True
322
+ ) and not node.matches_type(
323
+ MultiStepLLMCommandGenerator, include_subtypes=True
324
+ ):
325
+ return cast(Type[LLMBasedCommandGenerator], node.uses)
326
+
327
+ rasa.shared.utils.cli.print_error(
328
+ "The provided model is not trained using 'SingleStepLLMCommandGenerator' or "
329
+ "'CompactLLMCommandGenerator' or its subclasses. Without it, no data for "
330
+ "fine-tuning can be generated. To resolve this, please include "
331
+ "'SingleStepLLMCommandGenerator' or 'CompactLLMCommandGenerator' or its "
332
+ "subclasses in your config and train your model."
298
333
  )
299
334
  sys.exit(1)
300
335
 
@@ -16,3 +16,8 @@ class NewTranscript(ASREvent):
16
16
  @dataclass
17
17
  class UserIsSpeaking(ASREvent):
18
18
  pass
19
+
20
+
21
+ @dataclass
22
+ class UserSilence(ASREvent):
23
+ pass
@@ -46,6 +46,19 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
46
46
  def name(cls) -> str:
47
47
  return "ac_voice"
48
48
 
49
+ def _ensure_stream_id(self) -> None:
50
+ """Audiocodes requires a stream ID with playStream messages."""
51
+ if "stream_id" not in call_state.channel_data:
52
+ call_state.channel_data["stream_id"] = 0
53
+
54
+ def _increment_stream_id(self) -> None:
55
+ self._ensure_stream_id()
56
+ call_state.channel_data["stream_id"] += 1
57
+
58
+ def _get_stream_id(self) -> str:
59
+ self._ensure_stream_id()
60
+ return str(call_state.channel_data["stream_id"])
61
+
49
62
  def rasa_audio_bytes_to_channel_bytes(
50
63
  self, rasa_audio_bytes: RasaAudioBytes
51
64
  ) -> bytes:
@@ -55,7 +68,7 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
55
68
  media_message = json.dumps(
56
69
  {
57
70
  "type": "playStream.chunk",
58
- "streamId": str(call_state.stream_id),
71
+ "streamId": self._get_stream_id(),
59
72
  "audioChunk": channel_bytes.decode("utf-8"),
60
73
  }
61
74
  )
@@ -63,14 +76,14 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
63
76
 
64
77
  async def send_start_marker(self, recipient_id: str) -> None:
65
78
  """Send playStream.start before first audio chunk."""
66
- call_state.stream_id += 1 # type: ignore[attr-defined]
79
+ self._increment_stream_id()
67
80
  media_message = json.dumps(
68
81
  {
69
82
  "type": "playStream.start",
70
- "streamId": str(call_state.stream_id),
83
+ "streamId": self._get_stream_id(),
71
84
  }
72
85
  )
73
- logger.debug("Sending start marker", stream_id=call_state.stream_id)
86
+ logger.debug("Sending start marker", stream_id=self._get_stream_id())
74
87
  await self.voice_websocket.send(media_message)
75
88
 
76
89
  async def send_intermediate_marker(self, recipient_id: str) -> None:
@@ -82,10 +95,10 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
82
95
  media_message = json.dumps(
83
96
  {
84
97
  "type": "playStream.stop",
85
- "streamId": str(call_state.stream_id),
98
+ "streamId": self._get_stream_id(),
86
99
  }
87
100
  )
88
- logger.debug("Sending end marker", stream_id=call_state.stream_id)
101
+ logger.debug("Sending end marker", stream_id=self._get_stream_id())
89
102
  await self.voice_websocket.send(media_message)
90
103
 
91
104
 
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
  from contextvars import ContextVar
3
3
  from dataclasses import dataclass, field
4
- from typing import Optional
4
+ from typing import Any, Dict, Optional
5
5
 
6
6
  from werkzeug.local import LocalProxy
7
7
 
@@ -19,14 +19,8 @@ class CallState:
19
19
  should_hangup: bool = False
20
20
  connection_failed: bool = False
21
21
 
22
- # Genesys requires the server and client each maintain a
23
- # monotonically increasing message sequence number.
24
- client_sequence_number: int = 0
25
- server_sequence_number: int = 0
26
- audio_buffer: bytearray = field(default_factory=bytearray)
27
-
28
- # Audiocodes requires a stream ID at start and end of stream
29
- stream_id: int = 0
22
+ # Generic field for channel-specific state data
23
+ channel_data: Dict[str, Any] = field(default_factory=dict)
30
24
 
31
25
 
32
26
  _call_state: ContextVar[CallState] = ContextVar("call_state")
@@ -27,8 +27,23 @@ from rasa.core.channels.voice_stream.voice_channel import (
27
27
  VoiceOutputChannel,
28
28
  )
29
29
 
30
- # Not mentioned in the documentation but observed in Geneys's example
31
- # https://github.com/GenesysCloudBlueprints/audioconnector-server-reference-implementation
30
+ """
31
+ Genesys throws a rate limit error with too many audio messages.
32
+ To avoid this, we buffer the audio messages and send them in chunks.
33
+
34
+ - global.inbound.binary.average.rate.per.second: 5
35
+ The allowed average rate per second of inbound binary data
36
+
37
+ - global.inbound.binary.max: 25
38
+ The maximum number of inbound binary data messages
39
+ that can be sent instantaneously
40
+
41
+ https://developer.genesys.cloud/organization/organization/limits#audiohook
42
+
43
+ The maximum binary message size is not mentioned
44
+ in the documentation but observed in their example app
45
+ https://github.com/GenesysCloudBlueprints/audioconnector-server-reference-implementation
46
+ """
32
47
  MAXIMUM_BINARY_MESSAGE_SIZE = 64000 # 64KB
33
48
  logger = structlog.get_logger(__name__)
34
49
 
@@ -56,52 +71,7 @@ class GenesysOutputChannel(VoiceOutputChannel):
56
71
  async def send_audio_bytes(
57
72
  self, recipient_id: str, audio_bytes: RasaAudioBytes
58
73
  ) -> None:
59
- """
60
- Send audio bytes to the recipient with buffering.
61
-
62
- Genesys throws a rate limit error with too many audio messages.
63
- To avoid this, we buffer the audio messages and send them in chunks.
64
-
65
- - global.inbound.binary.average.rate.per.second: 5
66
- The allowed average rate per second of inbound binary data
67
-
68
- - global.inbound.binary.max: 25
69
- The maximum number of inbound binary data messages
70
- that can be sent instantaneously
71
-
72
- https://developer.genesys.cloud/organization/organization/limits#audiohook
73
- """
74
- call_state.audio_buffer.extend(audio_bytes)
75
-
76
- # If we receive a non-standard chunk size, assume it's the end of a sequence
77
- # or buffer is more than 32KB (this is half of genesys's max audio message size)
78
- if len(audio_bytes) != 1024 or len(call_state.audio_buffer) >= (
79
- MAXIMUM_BINARY_MESSAGE_SIZE / 2
80
- ):
81
- # TODO: we should send the buffer when we receive a synthesis complete event
82
- # from TTS. This will ensure that the last audio chunk is always sent.
83
- await self._send_audio_buffer(self.voice_websocket)
84
-
85
- async def _send_audio_buffer(self, ws: Websocket) -> None:
86
- """Send the audio buffer to the recipient if it's not empty."""
87
- if call_state.audio_buffer:
88
- buffer_bytes = bytes(call_state.audio_buffer)
89
- await self._send_bytes_to_ws(ws, buffer_bytes)
90
- call_state.audio_buffer.clear()
91
-
92
- async def _send_bytes_to_ws(self, ws: Websocket, data: bytes) -> None:
93
- """Send audio bytes to the recipient as a binary websocket message."""
94
- if len(data) <= MAXIMUM_BINARY_MESSAGE_SIZE:
95
- await self.voice_websocket.send(data)
96
- else:
97
- # split the audio into chunks
98
- current_position = 0
99
- while current_position < len(data):
100
- end_position = min(
101
- current_position + MAXIMUM_BINARY_MESSAGE_SIZE, len(data)
102
- )
103
- await self.voice_websocket.send(data[current_position:end_position])
104
- current_position = end_position
74
+ await self.voice_websocket.send(audio_bytes)
105
75
 
106
76
  async def send_marker_message(self, recipient_id: str) -> None:
107
77
  """
@@ -119,6 +89,17 @@ class GenesysInputChannel(VoiceInputChannel):
119
89
  def __init__(self, *args: Any, **kwargs: Any) -> None:
120
90
  super().__init__(*args, **kwargs)
121
91
 
92
+ def _ensure_channel_data_initialized(self) -> None:
93
+ """Initialize Genesys-specific channel data if not already present.
94
+
95
+ Genesys requires the server and client each maintain a
96
+ monotonically increasing message sequence number.
97
+ """
98
+ if "server_sequence_number" not in call_state.channel_data:
99
+ call_state.channel_data["server_sequence_number"] = 0
100
+ if "client_sequence_number" not in call_state.channel_data:
101
+ call_state.channel_data["client_sequence_number"] = 0
102
+
122
103
  def _get_next_sequence(self) -> int:
123
104
  """
124
105
  Get the next message sequence number
@@ -128,23 +109,26 @@ class GenesysInputChannel(VoiceInputChannel):
128
109
  Genesys requires the server and client each maintain a
129
110
  monotonically increasing message sequence number.
130
111
  """
131
- cs = call_state
132
- cs.server_sequence_number += 1 # type: ignore[attr-defined]
133
- return cs.server_sequence_number
112
+ self._ensure_channel_data_initialized()
113
+ call_state.channel_data["server_sequence_number"] += 1
114
+ return call_state.channel_data["server_sequence_number"]
134
115
 
135
116
  def _get_last_client_sequence(self) -> int:
136
117
  """Get the last client(Genesys) sequence number."""
137
- return call_state.client_sequence_number
118
+ self._ensure_channel_data_initialized()
119
+ return call_state.channel_data["client_sequence_number"]
138
120
 
139
121
  def _update_client_sequence(self, seq: int) -> None:
140
122
  """Update the client(Genesys) sequence number."""
141
- if seq - call_state.client_sequence_number != 1:
123
+ self._ensure_channel_data_initialized()
124
+
125
+ if seq - call_state.channel_data["client_sequence_number"] != 1:
142
126
  logger.warning(
143
127
  "genesys.update_client_sequence.sequence_gap",
144
128
  received_seq=seq,
145
- last_seq=call_state.client_sequence_number,
129
+ last_seq=call_state.channel_data["client_sequence_number"],
146
130
  )
147
- call_state.client_sequence_number = seq # type: ignore[attr-defined]
131
+ call_state.channel_data["client_sequence_number"] = seq
148
132
 
149
133
  def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
150
134
  return RasaAudioBytes(input_bytes)
@@ -211,6 +195,7 @@ class GenesysInputChannel(VoiceInputChannel):
211
195
  voice_websocket,
212
196
  tts_engine,
213
197
  self.tts_cache,
198
+ min_buffer_size=MAXIMUM_BINARY_MESSAGE_SIZE // 2,
214
199
  )
215
200
 
216
201
  async def handle_open(self, ws: Websocket, message: dict) -> CallParameters:
@@ -17,6 +17,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import (
17
17
  ASREvent,
18
18
  NewTranscript,
19
19
  UserIsSpeaking,
20
+ UserSilence,
20
21
  )
21
22
  from rasa.core.channels.voice_stream.asr.azure import AzureASR
22
23
  from rasa.core.channels.voice_stream.asr.deepgram import DeepgramASR
@@ -120,13 +121,14 @@ class VoiceOutputChannel(OutputChannel):
120
121
  voice_websocket: Websocket,
121
122
  tts_engine: TTSEngine,
122
123
  tts_cache: TTSCache,
124
+ min_buffer_size: int = 0,
123
125
  ):
124
126
  super().__init__()
125
127
  self.voice_websocket = voice_websocket
126
128
  self.tts_engine = tts_engine
127
129
  self.tts_cache = tts_cache
128
-
129
130
  self.latest_message_id: Optional[str] = None
131
+ self.min_buffer_size = min_buffer_size
130
132
 
131
133
  def rasa_audio_bytes_to_channel_bytes(
132
134
  self, rasa_audio_bytes: RasaAudioBytes
@@ -186,6 +188,7 @@ class VoiceOutputChannel(OutputChannel):
186
188
  cached_audio_bytes = self.tts_cache.get(text)
187
189
  collected_audio_bytes = RasaAudioBytes(b"")
188
190
  seconds_marker = -1
191
+ last_sent_offset = 0
189
192
 
190
193
  # Send start marker before first chunk
191
194
  try:
@@ -205,17 +208,37 @@ class VoiceOutputChannel(OutputChannel):
205
208
  audio_stream = self.chunk_audio(generate_silence())
206
209
 
207
210
  async for audio_bytes in audio_stream:
208
- try:
209
- await self.send_audio_bytes(recipient_id, audio_bytes)
210
- full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
211
- if full_seconds_of_audio > seconds_marker:
212
- await self.send_intermediate_marker(recipient_id)
213
- seconds_marker = full_seconds_of_audio
211
+ collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)
214
212
 
213
+ # Check if we have enough new bytes to send
214
+ current_buffer_size = len(collected_audio_bytes) - last_sent_offset
215
+ should_send = current_buffer_size >= self.min_buffer_size
216
+
217
+ if should_send:
218
+ try:
219
+ # Send only the new bytes since last send
220
+ new_bytes = RasaAudioBytes(collected_audio_bytes[last_sent_offset:])
221
+ await self.send_audio_bytes(recipient_id, new_bytes)
222
+ last_sent_offset = len(collected_audio_bytes)
223
+
224
+ full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
225
+ if full_seconds_of_audio > seconds_marker:
226
+ await self.send_intermediate_marker(recipient_id)
227
+ seconds_marker = full_seconds_of_audio
228
+
229
+ except (WebsocketClosed, ServerError):
230
+ # ignore sending error, and keep collecting and caching audio bytes
231
+ call_state.connection_failed = True # type: ignore[attr-defined]
232
+
233
+ # Send any remaining audio not yet sent
234
+ remaining_bytes = len(collected_audio_bytes) - last_sent_offset
235
+ if remaining_bytes > 0:
236
+ try:
237
+ new_bytes = RasaAudioBytes(collected_audio_bytes[last_sent_offset:])
238
+ await self.send_audio_bytes(recipient_id, new_bytes)
215
239
  except (WebsocketClosed, ServerError):
216
- # ignore sending error, and keep collecting and caching audio bytes
240
+ # ignore sending error
217
241
  call_state.connection_failed = True # type: ignore[attr-defined]
218
- collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)
219
242
 
220
243
  try:
221
244
  await self.send_end_marker(recipient_id)
@@ -265,13 +288,7 @@ class VoiceInputChannel(InputChannel):
265
288
  self.monitor_silence = monitor_silence
266
289
  self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
267
290
 
268
- async def handle_silence_timeout(
269
- self,
270
- voice_websocket: Websocket,
271
- on_new_message: Callable[[UserMessage], Awaitable[Any]],
272
- tts_engine: TTSEngine,
273
- call_parameters: CallParameters,
274
- ) -> None:
291
+ async def monitor_silence_timeout(self, asr_event_queue: asyncio.Queue) -> None:
275
292
  timeout = call_state.silence_timeout
276
293
  if not timeout:
277
294
  return
@@ -279,16 +296,8 @@ class VoiceInputChannel(InputChannel):
279
296
  return
280
297
  logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
281
298
  await asyncio.sleep(timeout)
299
+ await asr_event_queue.put(UserSilence())
282
300
  logger.debug("voice_channel.silence_timeout_tripped")
283
- output_channel = self.create_output_channel(voice_websocket, tts_engine)
284
- message = UserMessage(
285
- "/silence_timeout",
286
- output_channel,
287
- call_parameters.stream_id,
288
- input_channel=self.name(),
289
- metadata=asdict(call_parameters),
290
- )
291
- await on_new_message(message)
292
301
 
293
302
  @staticmethod
294
303
  def _cancel_silence_timeout_watcher() -> None:
@@ -350,6 +359,7 @@ class VoiceInputChannel(InputChannel):
350
359
  _call_state.set(CallState())
351
360
  asr_engine = asr_engine_from_config(self.asr_config)
352
361
  tts_engine = tts_engine_from_config(self.tts_config)
362
+ asr_event_queue: asyncio.Queue = asyncio.Queue()
353
363
  await asr_engine.connect()
354
364
 
355
365
  call_parameters = await self.collect_call_parameters(channel_websocket)
@@ -376,12 +386,7 @@ class VoiceInputChannel(InputChannel):
376
386
  self._cancel_silence_timeout_watcher()
377
387
  call_state.silence_timeout_watcher = ( # type: ignore[attr-defined]
378
388
  asyncio.create_task(
379
- self.handle_silence_timeout(
380
- channel_websocket,
381
- on_new_message,
382
- tts_engine,
383
- call_parameters,
384
- )
389
+ self.monitor_silence_timeout(asr_event_queue)
385
390
  )
386
391
  )
387
392
  if isinstance(channel_action, NewAudioAction):
@@ -390,8 +395,13 @@ class VoiceInputChannel(InputChannel):
390
395
  # end stream event came from the other side
391
396
  break
392
397
 
393
- async def consume_asr_events() -> None:
398
+ async def receive_asr_events() -> None:
394
399
  async for event in asr_engine.stream_asr_events():
400
+ await asr_event_queue.put(event)
401
+
402
+ async def handle_asr_events() -> None:
403
+ while True:
404
+ event = await asr_event_queue.get()
395
405
  await self.handle_asr_event(
396
406
  event,
397
407
  channel_websocket,
@@ -400,16 +410,18 @@ class VoiceInputChannel(InputChannel):
400
410
  call_parameters,
401
411
  )
402
412
 
403
- audio_forwarding_task = asyncio.create_task(consume_audio_bytes())
404
- asr_event_task = asyncio.create_task(consume_asr_events())
413
+ tasks = [
414
+ asyncio.create_task(consume_audio_bytes()),
415
+ asyncio.create_task(receive_asr_events()),
416
+ asyncio.create_task(handle_asr_events()),
417
+ ]
405
418
  await asyncio.wait(
406
- [audio_forwarding_task, asr_event_task],
419
+ tasks,
407
420
  return_when=asyncio.FIRST_COMPLETED,
408
421
  )
409
- if not audio_forwarding_task.done():
410
- audio_forwarding_task.cancel()
411
- if not asr_event_task.done():
412
- asr_event_task.cancel()
422
+ for task in tasks:
423
+ if not task.done():
424
+ task.cancel()
413
425
  await tts_engine.close_connection()
414
426
  await asr_engine.close_connection()
415
427
  await channel_websocket.close()
@@ -447,3 +459,13 @@ class VoiceInputChannel(InputChannel):
447
459
  elif isinstance(e, UserIsSpeaking):
448
460
  self._cancel_silence_timeout_watcher()
449
461
  call_state.is_user_speaking = True # type: ignore[attr-defined]
462
+ elif isinstance(e, UserSilence):
463
+ output_channel = self.create_output_channel(voice_websocket, tts_engine)
464
+ message = UserMessage(
465
+ "/silence_timeout",
466
+ output_channel,
467
+ call_parameters.stream_id,
468
+ input_channel=self.name(),
469
+ metadata=asdict(call_parameters),
470
+ )
471
+ await on_new_message(message)