rasa-pro 3.12.0rc2__py3-none-any.whl → 3.12.0rc3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of rasa-pro might be problematic.
- rasa/cli/dialogue_understanding_test.py +5 -8
- rasa/cli/llm_fine_tuning.py +47 -12
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/audiocodes.py +19 -6
- rasa/core/channels/voice_stream/call_state.py +3 -9
- rasa/core/channels/voice_stream/genesys.py +40 -55
- rasa/core/channels/voice_stream/voice_channel.py +61 -39
- rasa/core/tracker_store.py +123 -34
- rasa/dialogue_understanding/commands/set_slot_command.py +1 -0
- rasa/dialogue_understanding/commands/utils.py +1 -4
- rasa/dialogue_understanding/generator/command_parser.py +41 -0
- rasa/dialogue_understanding/generator/constants.py +7 -2
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +9 -2
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +29 -48
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_fallback_other_models_template.jinja2 +57 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +23 -50
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +76 -24
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +32 -18
- rasa/dialogue_understanding/processor/command_processor.py +39 -19
- rasa/dialogue_understanding/stack/utils.py +11 -6
- rasa/engine/language.py +67 -25
- rasa/llm_fine_tuning/conversations.py +3 -31
- rasa/llm_fine_tuning/llm_data_preparation_module.py +5 -3
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +18 -13
- rasa/llm_fine_tuning/paraphrasing_module.py +6 -2
- rasa/llm_fine_tuning/train_test_split_module.py +27 -27
- rasa/llm_fine_tuning/utils.py +7 -0
- rasa/shared/constants.py +4 -0
- rasa/shared/core/domain.py +2 -0
- rasa/shared/providers/_configs/azure_entra_id_config.py +8 -8
- rasa/shared/providers/llm/litellm_router_llm_client.py +1 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +38 -7
- rasa/shared/utils/llm.py +69 -13
- rasa/telemetry.py +13 -3
- rasa/tracing/instrumentation/attribute_extractors.py +2 -5
- rasa/validator.py +2 -2
- rasa/version.py +1 -1
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/METADATA +1 -1
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/RECORD +42 -41
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +0 -68
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/entry_points.txt +0 -0
rasa/cli/dialogue_understanding_test.py
CHANGED
@@ -3,7 +3,7 @@ import asyncio
 import datetime
 import importlib
 import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type, cast

 import structlog

@@ -20,9 +20,7 @@ from rasa.core.exceptions import AgentNotReady
 from rasa.core.processor import MessageProcessor
 from rasa.core.utils import AvailableEndpoints
 from rasa.dialogue_understanding.commands import Command
-from rasa.dialogue_understanding.generator import (
-    LLMBasedCommandGenerator,
-)
+from rasa.dialogue_understanding.generator import LLMBasedCommandGenerator
 from rasa.dialogue_understanding.generator.command_parser import DEFAULT_COMMANDS
 from rasa.dialogue_understanding_test.command_metric_calculation import (
     calculate_command_metrics,
@@ -372,18 +370,17 @@ def split_test_results(
 def _get_llm_command_generator_config(
     processor: MessageProcessor,
 ) -> Optional[Dict[str, Any]]:
-    from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
-
     train_schema = processor.model_metadata.train_schema

     for node_name, node in train_schema.nodes.items():
         if node.matches_type(LLMBasedCommandGenerator, include_subtypes=True):
             # Configurations can reference model groups defined in the endpoints.yml
-
+            resolved_llm_config = resolve_model_client_config(
                 node.config.get(LLM_CONFIG_KEY, {}), node_name
             )
+            llm_command_generator = cast(Type[LLMBasedCommandGenerator], node.uses)
             return combine_custom_and_default_config(
-
+                resolved_llm_config, llm_command_generator.get_default_llm_config()
             )

     return None
rasa/cli/llm_fine_tuning.py
CHANGED
@@ -1,7 +1,7 @@
 import argparse
 import asyncio
 import sys
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Type, cast

 import structlog

@@ -22,7 +22,12 @@ from rasa.cli.e2e_test import (
 )
 from rasa.core.exceptions import AgentNotReady
 from rasa.core.utils import AvailableEndpoints
-from rasa.dialogue_understanding.generator import
+from rasa.dialogue_understanding.generator.llm_based_command_generator import (
+    LLMBasedCommandGenerator,
+)
+from rasa.dialogue_understanding.generator.multi_step.multi_step_llm_command_generator import (  # noqa: E501
+    MultiStepLLMCommandGenerator,
+)
 from rasa.e2e_test.e2e_test_runner import E2ETestRunner
 from rasa.llm_fine_tuning.annotation_module import annotate_e2e_tests
 from rasa.llm_fine_tuning.llm_data_preparation_module import convert_to_fine_tuning_data
@@ -112,7 +117,6 @@ def create_llm_finetune_data_preparation_subparser(
         help_text="Configuration file for the model server and the connectors as a "
         "yml file.",
     )
-
     return data_preparation_subparser


@@ -205,6 +209,9 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:

     flows = asyncio.run(e2e_test_runner.agent.processor.get_flows())
     llm_command_generator_config = _get_llm_command_generator_config(e2e_test_runner)
+    llm_command_generator: Type[LLMBasedCommandGenerator] = _get_llm_command_generator(
+        e2e_test_runner
+    )

     # set up storage context
     storage_context = create_storage_context(StorageType.FILE, output_dir)
@@ -235,6 +242,7 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
         rephrase_config,
         args.num_rephrases,
         flows,
+        llm_command_generator,
         llm_command_generator_config,
         storage_context,
     )
@@ -271,30 +279,57 @@ def prepare_llm_fine_tuning_data(args: argparse.Namespace) -> None:
     write_statistics(statistics, output_dir)

     rasa.shared.utils.cli.print_success(
-        f"Data and intermediate results are written
+        f"Data and intermediate results are written to '{output_dir}'."
     )


 def _get_llm_command_generator_config(e2e_test_runner: E2ETestRunner) -> Dict[str, Any]:
-    from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
-
     train_schema = e2e_test_runner.agent.processor.model_metadata.train_schema  # type: ignore

     for node_name, node in train_schema.nodes.items():
-        if node.matches_type(
+        if node.matches_type(
+            LLMBasedCommandGenerator, include_subtypes=True
+        ) and not node.matches_type(
+            MultiStepLLMCommandGenerator, include_subtypes=True
+        ):
             # Configurations can reference model groups defined in the endpoints.yml
-
+            resolved_llm_config = resolve_model_client_config(
                 node.config.get(LLM_CONFIG_KEY, {}), node_name
             )
+            llm_command_generator = cast(Type[LLMBasedCommandGenerator], node.uses)
             return combine_custom_and_default_config(
-
+                resolved_llm_config, llm_command_generator.get_default_llm_config()
             )

     rasa.shared.utils.cli.print_error(
         "The provided model is not trained using 'SingleStepLLMCommandGenerator' or "
-        "its subclasses. Without it, no data for
-        "resolve this, please include
-        "
+        "'CompactLLMCommandGenerator' or its subclasses. Without it, no data for "
+        "fine-tuning can be generated. To resolve this, please include "
+        "'SingleStepLLMCommandGenerator' or 'CompactLLMCommandGenerator' or its "
+        "subclasses in your config and train your model."
+    )
+    sys.exit(1)
+
+
+def _get_llm_command_generator(
+    e2e_test_runner: E2ETestRunner,
+) -> Type[LLMBasedCommandGenerator]:
+    train_schema = e2e_test_runner.agent.processor.model_metadata.train_schema  # type: ignore
+
+    for _, node in train_schema.nodes.items():
+        if node.matches_type(
+            LLMBasedCommandGenerator, include_subtypes=True
+        ) and not node.matches_type(
+            MultiStepLLMCommandGenerator, include_subtypes=True
+        ):
+            return cast(Type[LLMBasedCommandGenerator], node.uses)
+
+    rasa.shared.utils.cli.print_error(
+        "The provided model is not trained using 'SingleStepLLMCommandGenerator' or "
+        "'CompactLLMCommandGenerator' or its subclasses. Without it, no data for "
+        "fine-tuning can be generated. To resolve this, please include "
+        "'SingleStepLLMCommandGenerator' or 'CompactLLMCommandGenerator' or its "
+        "subclasses in your config and train your model."
     )
     sys.exit(1)

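Both helpers above select a generator by the same rule: any LLMBasedCommandGenerator subtype except MultiStepLLMCommandGenerator and its subtypes. A self-contained sketch of that rule with stand-in classes (select_generator and the node list are hypothetical; rasa's schema nodes express this via node.matches_type):

```python
from typing import List, Optional, Type


# Stand-in hierarchy mirroring the generator classes named in the diff.
class LLMBasedCommandGenerator: ...
class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator): ...
class CompactLLMCommandGenerator(LLMBasedCommandGenerator): ...
class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator): ...


def select_generator(nodes: List[type]) -> Optional[Type[LLMBasedCommandGenerator]]:
    """First LLM-based generator that is not multi-step, mirroring the schema scan."""
    for uses in nodes:
        if issubclass(uses, LLMBasedCommandGenerator) and not issubclass(
            uses, MultiStepLLMCommandGenerator
        ):
            return uses
    return None


assert (
    select_generator([MultiStepLLMCommandGenerator, CompactLLMCommandGenerator])
    is CompactLLMCommandGenerator
)
assert select_generator([MultiStepLLMCommandGenerator]) is None  # CLI would sys.exit(1)
```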
rasa/core/channels/voice_stream/audiocodes.py
CHANGED
@@ -46,6 +46,19 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
     def name(cls) -> str:
         return "ac_voice"

+    def _ensure_stream_id(self) -> None:
+        """Audiocodes requires a stream ID with playStream messages."""
+        if "stream_id" not in call_state.channel_data:
+            call_state.channel_data["stream_id"] = 0
+
+    def _increment_stream_id(self) -> None:
+        self._ensure_stream_id()
+        call_state.channel_data["stream_id"] += 1
+
+    def _get_stream_id(self) -> str:
+        self._ensure_stream_id()
+        return str(call_state.channel_data["stream_id"])
+
     def rasa_audio_bytes_to_channel_bytes(
         self, rasa_audio_bytes: RasaAudioBytes
     ) -> bytes:
@@ -55,7 +68,7 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
         media_message = json.dumps(
             {
                 "type": "playStream.chunk",
-                "streamId":
+                "streamId": self._get_stream_id(),
                 "audioChunk": channel_bytes.decode("utf-8"),
             }
         )
@@ -63,14 +76,14 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):

     async def send_start_marker(self, recipient_id: str) -> None:
         """Send playStream.start before first audio chunk."""
-
+        self._increment_stream_id()
         media_message = json.dumps(
             {
                 "type": "playStream.start",
-                "streamId":
+                "streamId": self._get_stream_id(),
             }
         )
-        logger.debug("Sending start marker", stream_id=
+        logger.debug("Sending start marker", stream_id=self._get_stream_id())
         await self.voice_websocket.send(media_message)

     async def send_intermediate_marker(self, recipient_id: str) -> None:
@@ -82,10 +95,10 @@ class AudiocodesVoiceOutputChannel(VoiceOutputChannel):
         media_message = json.dumps(
             {
                 "type": "playStream.stop",
-                "streamId":
+                "streamId": self._get_stream_id(),
             }
         )
-        logger.debug("Sending end marker", stream_id=
+        logger.debug("Sending end marker", stream_id=self._get_stream_id())
         await self.voice_websocket.send(media_message)

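The stream ID lifecycle above: send_start_marker bumps the counter, and every subsequent playStream.chunk and playStream.stop reuses it, so each utterance is tagged with a fresh ID. A runnable sketch of that lifecycle against a plain dict standing in for call_state.channel_data:

```python
import json

channel_data: dict = {}  # stand-in for call_state.channel_data


def _get_stream_id() -> str:
    channel_data.setdefault("stream_id", 0)
    return str(channel_data["stream_id"])


def start_marker() -> str:
    # playStream.start bumps the counter, giving each utterance a fresh ID
    channel_data.setdefault("stream_id", 0)
    channel_data["stream_id"] += 1
    return json.dumps({"type": "playStream.start", "streamId": _get_stream_id()})


def stop_marker() -> str:
    # chunks and playStream.stop reuse the ID from the latest start
    return json.dumps({"type": "playStream.stop", "streamId": _get_stream_id()})


print(start_marker())  # {"type": "playStream.start", "streamId": "1"}
print(stop_marker())   # {"type": "playStream.stop", "streamId": "1"}
print(start_marker())  # the next utterance gets "2"
```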
rasa/core/channels/voice_stream/call_state.py
CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 from contextvars import ContextVar
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Dict, Optional

 from werkzeug.local import LocalProxy

@@ -19,14 +19,8 @@ class CallState:
     should_hangup: bool = False
     connection_failed: bool = False

-    #
-
-    client_sequence_number: int = 0
-    server_sequence_number: int = 0
-    audio_buffer: bytearray = field(default_factory=bytearray)
-
-    # Audiocodes requires a stream ID at start and end of stream
-    stream_id: int = 0
+    # Generic field for channel-specific state data
+    channel_data: Dict[str, Any] = field(default_factory=dict)


 _call_state: ContextVar[CallState] = ContextVar("call_state")
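With the channel-specific fields gone, the shared CallState no longer has to know about every channel: each channel namespaces its own keys inside channel_data, and the ContextVar keeps the state isolated per call. A minimal sketch of the pattern (simplified; the real CallState carries more fields):

```python
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class CallState:
    should_hangup: bool = False
    # Generic field for channel-specific state data
    channel_data: Dict[str, Any] = field(default_factory=dict)


_call_state: ContextVar[CallState] = ContextVar("call_state")

# Each call sets its own state; channels then stash whatever they need under
# their own keys instead of adding fields to the shared dataclass.
_call_state.set(CallState())
_call_state.get().channel_data["stream_id"] = 0               # Audiocodes
_call_state.get().channel_data["server_sequence_number"] = 0  # Genesys
print(_call_state.get().channel_data)
```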
rasa/core/channels/voice_stream/genesys.py
CHANGED
@@ -27,8 +27,23 @@ from rasa.core.channels.voice_stream.voice_channel import (
     VoiceOutputChannel,
 )

-
-
+"""
+Genesys throws a rate limit error with too many audio messages.
+To avoid this, we buffer the audio messages and send them in chunks.
+
+- global.inbound.binary.average.rate.per.second: 5
+  The allowed average rate per second of inbound binary data
+
+- global.inbound.binary.max: 25
+  The maximum number of inbound binary data messages
+  that can be sent instantaneously
+
+https://developer.genesys.cloud/organization/organization/limits#audiohook
+
+The maximum binary message size is not mentioned
+in the documentation but observed in their example app
+https://github.com/GenesysCloudBlueprints/audioconnector-server-reference-implementation
+"""
 MAXIMUM_BINARY_MESSAGE_SIZE = 64000  # 64KB
 logger = structlog.get_logger(__name__)

@@ -56,52 +71,7 @@ class GenesysOutputChannel(VoiceOutputChannel):
     async def send_audio_bytes(
         self, recipient_id: str, audio_bytes: RasaAudioBytes
     ) -> None:
-
-        Send audio bytes to the recipient with buffering.
-
-        Genesys throws a rate limit error with too many audio messages.
-        To avoid this, we buffer the audio messages and send them in chunks.
-
-        - global.inbound.binary.average.rate.per.second: 5
-          The allowed average rate per second of inbound binary data
-
-        - global.inbound.binary.max: 25
-          The maximum number of inbound binary data messages
-          that can be sent instantaneously
-
-        https://developer.genesys.cloud/organization/organization/limits#audiohook
-        """
-        call_state.audio_buffer.extend(audio_bytes)
-
-        # If we receive a non-standard chunk size, assume it's the end of a sequence
-        # or buffer is more than 32KB (this is half of genesys's max audio message size)
-        if len(audio_bytes) != 1024 or len(call_state.audio_buffer) >= (
-            MAXIMUM_BINARY_MESSAGE_SIZE / 2
-        ):
-            # TODO: we should send the buffer when we receive a synthesis complete event
-            # from TTS. This will ensure that the last audio chunk is always sent.
-            await self._send_audio_buffer(self.voice_websocket)
-
-    async def _send_audio_buffer(self, ws: Websocket) -> None:
-        """Send the audio buffer to the recipient if it's not empty."""
-        if call_state.audio_buffer:
-            buffer_bytes = bytes(call_state.audio_buffer)
-            await self._send_bytes_to_ws(ws, buffer_bytes)
-            call_state.audio_buffer.clear()
-
-    async def _send_bytes_to_ws(self, ws: Websocket, data: bytes) -> None:
-        """Send audio bytes to the recipient as a binary websocket message."""
-        if len(data) <= MAXIMUM_BINARY_MESSAGE_SIZE:
-            await self.voice_websocket.send(data)
-        else:
-            # split the audio into chunks
-            current_position = 0
-            while current_position < len(data):
-                end_position = min(
-                    current_position + MAXIMUM_BINARY_MESSAGE_SIZE, len(data)
-                )
-                await self.voice_websocket.send(data[current_position:end_position])
-                current_position = end_position
+        await self.voice_websocket.send(audio_bytes)

     async def send_marker_message(self, recipient_id: str) -> None:
         """
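Why half of MAXIMUM_BINARY_MESSAGE_SIZE works as a flush threshold: assuming 8 kHz audio at one byte per sample (consistent with the len(...) // HERTZ seconds arithmetic elsewhere in the voice channel, but an assumption here), each 32KB message carries 4 seconds of audio, so even TTS output arriving several times faster than real time stays well under the 5-messages-per-second average limit:

```python
# Back-of-the-envelope check, assuming 8 kHz mu-law audio at one byte per
# sample (so HERTZ == 8000 bytes of audio per second; an assumption here).
HERTZ = 8000
MAXIMUM_BINARY_MESSAGE_SIZE = 64000
flush_threshold = MAXIMUM_BINARY_MESSAGE_SIZE // 2  # 32000 bytes per send

seconds_of_audio_per_message = flush_threshold / HERTZ  # 4.0 s
# Even if TTS produced audio 10x faster than real time, the send rate would be:
sends_per_second = 10 * HERTZ / flush_threshold  # 2.5 msg/s

assert sends_per_second < 5  # global.inbound.binary.average.rate.per.second
assert flush_threshold <= MAXIMUM_BINARY_MESSAGE_SIZE  # observed 64KB size cap
print(f"{seconds_of_audio_per_message} s of audio per message, "
      f"{sends_per_second} msg/s at 10x real-time synthesis")
```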
@@ -119,6 +89,17 @@ class GenesysInputChannel(VoiceInputChannel):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

+    def _ensure_channel_data_initialized(self) -> None:
+        """Initialize Genesys-specific channel data if not already present.
+
+        Genesys requires the server and client each maintain a
+        monotonically increasing message sequence number.
+        """
+        if "server_sequence_number" not in call_state.channel_data:
+            call_state.channel_data["server_sequence_number"] = 0
+        if "client_sequence_number" not in call_state.channel_data:
+            call_state.channel_data["client_sequence_number"] = 0
+
     def _get_next_sequence(self) -> int:
         """
         Get the next message sequence number
@@ -128,23 +109,26 @@ class GenesysInputChannel(VoiceInputChannel):
         Genesys requires the server and client each maintain a
         monotonically increasing message sequence number.
         """
-
-
-        return
+        self._ensure_channel_data_initialized()
+        call_state.channel_data["server_sequence_number"] += 1
+        return call_state.channel_data["server_sequence_number"]

     def _get_last_client_sequence(self) -> int:
         """Get the last client(Genesys) sequence number."""
-
+        self._ensure_channel_data_initialized()
+        return call_state.channel_data["client_sequence_number"]

     def _update_client_sequence(self, seq: int) -> None:
         """Update the client(Genesys) sequence number."""
-
+        self._ensure_channel_data_initialized()
+
+        if seq - call_state.channel_data["client_sequence_number"] != 1:
             logger.warning(
                 "genesys.update_client_sequence.sequence_gap",
                 received_seq=seq,
-                last_seq=call_state.client_sequence_number,
+                last_seq=call_state.channel_data["client_sequence_number"],
             )
-        call_state.client_sequence_number = seq
+        call_state.channel_data["client_sequence_number"] = seq

     def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
         return RasaAudioBytes(input_bytes)
@@ -211,6 +195,7 @@ class GenesysInputChannel(VoiceInputChannel):
             voice_websocket,
             tts_engine,
             self.tts_cache,
+            min_buffer_size=MAXIMUM_BINARY_MESSAGE_SIZE // 2,
         )

     async def handle_open(self, ws: Websocket, message: dict) -> CallParameters:
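The sequence bookkeeping above in isolation: the server numbers its own outbound messages, while inbound Genesys messages are checked for strict +1 ordering, and gaps are logged rather than treated as fatal. A runnable sketch against a plain dict standing in for call_state.channel_data:

```python
import structlog

logger = structlog.get_logger(__name__)
channel_data: dict = {"client_sequence_number": 0, "server_sequence_number": 0}


def next_server_sequence() -> int:
    # The server numbers its own outbound messages monotonically.
    channel_data["server_sequence_number"] += 1
    return channel_data["server_sequence_number"]


def update_client_sequence(seq: int) -> None:
    # Inbound (Genesys) messages should arrive in strict +1 order; anything
    # else points at a dropped or duplicated message, so warn but carry on.
    if seq - channel_data["client_sequence_number"] != 1:
        logger.warning(
            "genesys.update_client_sequence.sequence_gap",
            received_seq=seq,
            last_seq=channel_data["client_sequence_number"],
        )
    channel_data["client_sequence_number"] = seq


update_client_sequence(1)  # in order, no warning
update_client_sequence(3)  # logs a sequence_gap warning (2 never arrived)
```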
rasa/core/channels/voice_stream/voice_channel.py
CHANGED
@@ -17,6 +17,7 @@ from rasa.core.channels.voice_stream.asr.asr_event import (
     ASREvent,
     NewTranscript,
     UserIsSpeaking,
+    UserSilence,
 )
 from rasa.core.channels.voice_stream.asr.azure import AzureASR
 from rasa.core.channels.voice_stream.asr.deepgram import DeepgramASR
@@ -120,13 +121,14 @@ class VoiceOutputChannel(OutputChannel):
         voice_websocket: Websocket,
         tts_engine: TTSEngine,
         tts_cache: TTSCache,
+        min_buffer_size: int = 0,
     ):
         super().__init__()
         self.voice_websocket = voice_websocket
         self.tts_engine = tts_engine
         self.tts_cache = tts_cache
-
         self.latest_message_id: Optional[str] = None
+        self.min_buffer_size = min_buffer_size

     def rasa_audio_bytes_to_channel_bytes(
         self, rasa_audio_bytes: RasaAudioBytes
@@ -186,6 +188,7 @@ class VoiceOutputChannel(OutputChannel):
         cached_audio_bytes = self.tts_cache.get(text)
         collected_audio_bytes = RasaAudioBytes(b"")
         seconds_marker = -1
+        last_sent_offset = 0

         # Send start marker before first chunk
         try:
@@ -205,17 +208,37 @@ class VoiceOutputChannel(OutputChannel):
             audio_stream = self.chunk_audio(generate_silence())

         async for audio_bytes in audio_stream:
-
-            await self.send_audio_bytes(recipient_id, audio_bytes)
-            full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
-            if full_seconds_of_audio > seconds_marker:
-                await self.send_intermediate_marker(recipient_id)
-                seconds_marker = full_seconds_of_audio
+            collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)

+            # Check if we have enough new bytes to send
+            current_buffer_size = len(collected_audio_bytes) - last_sent_offset
+            should_send = current_buffer_size >= self.min_buffer_size
+
+            if should_send:
+                try:
+                    # Send only the new bytes since last send
+                    new_bytes = RasaAudioBytes(collected_audio_bytes[last_sent_offset:])
+                    await self.send_audio_bytes(recipient_id, new_bytes)
+                    last_sent_offset = len(collected_audio_bytes)
+
+                    full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
+                    if full_seconds_of_audio > seconds_marker:
+                        await self.send_intermediate_marker(recipient_id)
+                        seconds_marker = full_seconds_of_audio
+
+                except (WebsocketClosed, ServerError):
+                    # ignore sending error, and keep collecting and caching audio bytes
+                    call_state.connection_failed = True  # type: ignore[attr-defined]
+
+        # Send any remaining audio not yet sent
+        remaining_bytes = len(collected_audio_bytes) - last_sent_offset
+        if remaining_bytes > 0:
+            try:
+                new_bytes = RasaAudioBytes(collected_audio_bytes[last_sent_offset:])
+                await self.send_audio_bytes(recipient_id, new_bytes)
             except (WebsocketClosed, ServerError):
-                # ignore sending error
+                # ignore sending error
                 call_state.connection_failed = True  # type: ignore[attr-defined]
-            collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)

         try:
             await self.send_end_marker(recipient_id)
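The buffering that used to live in GenesysOutputChannel.send_audio_bytes now sits in the base class, controlled by min_buffer_size: channels that pass 0 (the default) forward every chunk immediately, while Genesys passes MAXIMUM_BINARY_MESSAGE_SIZE // 2 (32000) and gets batched sends plus a final flush of the tail. The core of that logic, extracted into a standalone sketch (stream_with_min_buffer is a hypothetical name; the real method also handles markers, caching, and send errors):

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable


async def stream_with_min_buffer(
    chunks: AsyncIterator[bytes],
    send: Callable[[bytes], Awaitable[None]],
    min_buffer_size: int = 0,
) -> None:
    """Flush only once at least min_buffer_size new bytes have accumulated."""
    collected = b""
    last_sent_offset = 0
    async for chunk in chunks:
        collected += chunk
        if len(collected) - last_sent_offset >= min_buffer_size:
            await send(collected[last_sent_offset:])  # only the new bytes
            last_sent_offset = len(collected)
    if len(collected) > last_sent_offset:
        await send(collected[last_sent_offset:])  # flush the remaining tail


async def demo() -> None:
    async def chunks() -> AsyncIterator[bytes]:
        for _ in range(100):  # 100 KiB of audio in 1 KiB chunks
            yield b"\x00" * 1024

    sent_sizes = []

    async def send(data: bytes) -> None:
        sent_sizes.append(len(data))

    await stream_with_min_buffer(chunks(), send, min_buffer_size=32000)
    print(sent_sizes)  # [32768, 32768, 32768, 4096]


asyncio.run(demo())
```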
@@ -265,13 +288,7 @@ class VoiceInputChannel(InputChannel):
         self.monitor_silence = monitor_silence
         self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))

-    async def
-        self,
-        voice_websocket: Websocket,
-        on_new_message: Callable[[UserMessage], Awaitable[Any]],
-        tts_engine: TTSEngine,
-        call_parameters: CallParameters,
-    ) -> None:
+    async def monitor_silence_timeout(self, asr_event_queue: asyncio.Queue) -> None:
         timeout = call_state.silence_timeout
         if not timeout:
             return
@@ -279,16 +296,8 @@ class VoiceInputChannel(InputChannel):
             return
         logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
         await asyncio.sleep(timeout)
+        await asr_event_queue.put(UserSilence())
         logger.debug("voice_channel.silence_timeout_tripped")
-        output_channel = self.create_output_channel(voice_websocket, tts_engine)
-        message = UserMessage(
-            "/silence_timeout",
-            output_channel,
-            call_parameters.stream_id,
-            input_channel=self.name(),
-            metadata=asdict(call_parameters),
-        )
-        await on_new_message(message)

     @staticmethod
     def _cancel_silence_timeout_watcher() -> None:
@@ -350,6 +359,7 @@ class VoiceInputChannel(InputChannel):
         _call_state.set(CallState())
         asr_engine = asr_engine_from_config(self.asr_config)
         tts_engine = tts_engine_from_config(self.tts_config)
+        asr_event_queue: asyncio.Queue = asyncio.Queue()
         await asr_engine.connect()

         call_parameters = await self.collect_call_parameters(channel_websocket)
@@ -376,12 +386,7 @@ class VoiceInputChannel(InputChannel):
                     self._cancel_silence_timeout_watcher()
                     call_state.silence_timeout_watcher = (  # type: ignore[attr-defined]
                         asyncio.create_task(
-                            self.
-                                channel_websocket,
-                                on_new_message,
-                                tts_engine,
-                                call_parameters,
-                            )
+                            self.monitor_silence_timeout(asr_event_queue)
                         )
                     )
                 if isinstance(channel_action, NewAudioAction):
@@ -390,8 +395,13 @@ class VoiceInputChannel(InputChannel):
                 # end stream event came from the other side
                 break

-        async def
+        async def receive_asr_events() -> None:
             async for event in asr_engine.stream_asr_events():
+                await asr_event_queue.put(event)
+
+        async def handle_asr_events() -> None:
+            while True:
+                event = await asr_event_queue.get()
                 await self.handle_asr_event(
                     event,
                     channel_websocket,
@@ -400,16 +410,18 @@ class VoiceInputChannel(InputChannel):
                     call_parameters,
                 )

-
-
+        tasks = [
+            asyncio.create_task(consume_audio_bytes()),
+            asyncio.create_task(receive_asr_events()),
+            asyncio.create_task(handle_asr_events()),
+        ]
         await asyncio.wait(
-
+            tasks,
             return_when=asyncio.FIRST_COMPLETED,
         )
-
-
-
-        asr_event_task.cancel()
+        for task in tasks:
+            if not task.done():
+                task.cancel()
         await tts_engine.close_connection()
         await asr_engine.close_connection()
         await channel_websocket.close()
@@ -447,3 +459,13 @@ class VoiceInputChannel(InputChannel):
         elif isinstance(e, UserIsSpeaking):
             self._cancel_silence_timeout_watcher()
             call_state.is_user_speaking = True  # type: ignore[attr-defined]
+        elif isinstance(e, UserSilence):
+            output_channel = self.create_output_channel(voice_websocket, tts_engine)
+            message = UserMessage(
+                "/silence_timeout",
+                output_channel,
+                call_parameters.stream_id,
+                input_channel=self.name(),
+                metadata=asdict(call_parameters),
+            )
+            await on_new_message(message)