livekit-plugins-aws 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release: this version of livekit-plugins-aws might be problematic.
- livekit/plugins/aws/llm.py +160 -239
- livekit/plugins/aws/models.py +1 -1
- livekit/plugins/aws/stt.py +114 -98
- livekit/plugins/aws/tts.py +72 -79
- livekit/plugins/aws/utils.py +144 -0
- livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-0.1.0.dist-info → livekit_plugins_aws-1.0.0.dist-info}/METADATA +14 -24
- livekit_plugins_aws-1.0.0.dist-info/RECORD +12 -0
- {livekit_plugins_aws-0.1.0.dist-info → livekit_plugins_aws-1.0.0.dist-info}/WHEEL +1 -2
- livekit/plugins/aws/_utils.py +0 -216
- livekit_plugins_aws-0.1.0.dist-info/RECORD +0 -13
- livekit_plugins_aws-0.1.0.dist-info/top_level.txt +0 -1
livekit/plugins/aws/stt.py
CHANGED
@@ -14,70 +14,72 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Optional
 
+import aioboto3
+from amazon_transcribe.auth import StaticCredentialResolver
 from amazon_transcribe.client import TranscribeStreamingClient
 from amazon_transcribe.model import Result, TranscriptEvent
+
 from livekit import rtc
-from livekit.agents import (
-    DEFAULT_API_CONNECT_OPTIONS,
-    APIConnectOptions,
-    stt,
-    utils,
-)
-
-from ._utils import _get_aws_credentials
+from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
+
 from .log import logger
+from .utils import DEFAULT_REGION, get_aws_async_session
+
+REFRESH_INTERVAL = 1800
 
 
 @dataclass
 class STTOptions:
-    speech_region: str
     sample_rate: int
     language: str
     encoding: str
-    vocabulary_name: Optional[str]
-    session_id: Optional[str]
-    vocab_filter_method: Optional[str]
-    vocab_filter_name: Optional[str]
-    show_speaker_label: Optional[bool]
-    enable_channel_identification: Optional[bool]
-    number_of_channels: Optional[int]
-    enable_partial_results_stabilization: Optional[bool]
-    partial_results_stability: Optional[str]
-    language_model_name: Optional[str]
+    vocabulary_name: NotGivenOr[str]
+    session_id: NotGivenOr[str]
+    vocab_filter_method: NotGivenOr[str]
+    vocab_filter_name: NotGivenOr[str]
+    show_speaker_label: NotGivenOr[bool]
+    enable_channel_identification: NotGivenOr[bool]
+    number_of_channels: NotGivenOr[int]
+    enable_partial_results_stabilization: NotGivenOr[bool]
+    partial_results_stability: NotGivenOr[str]
+    language_model_name: NotGivenOr[str]
 
 
 class STT(stt.STT):
     def __init__(
         self,
         *,
-        speech_region: str = "us-east-1",
-        api_key: str | None = None,
-        api_secret: str | None = None,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
         sample_rate: int = 48000,
         language: str = "en-US",
         encoding: str = "pcm",
-        vocabulary_name: Optional[str] = None,
-        session_id: Optional[str] = None,
-        vocab_filter_method: Optional[str] = None,
-        vocab_filter_name: Optional[str] = None,
-        show_speaker_label: Optional[bool] = None,
-        enable_channel_identification: Optional[bool] = None,
-        number_of_channels: Optional[int] = None,
-        enable_partial_results_stabilization: Optional[bool] = None,
-        partial_results_stability: Optional[str] = None,
-        language_model_name: Optional[str] = None,
+        vocabulary_name: NotGivenOr[str] = NOT_GIVEN,
+        session_id: NotGivenOr[str] = NOT_GIVEN,
+        vocab_filter_method: NotGivenOr[str] = NOT_GIVEN,
+        vocab_filter_name: NotGivenOr[str] = NOT_GIVEN,
+        show_speaker_label: NotGivenOr[bool] = NOT_GIVEN,
+        enable_channel_identification: NotGivenOr[bool] = NOT_GIVEN,
+        number_of_channels: NotGivenOr[int] = NOT_GIVEN,
+        enable_partial_results_stabilization: NotGivenOr[bool] = NOT_GIVEN,
+        partial_results_stability: NotGivenOr[str] = NOT_GIVEN,
+        language_model_name: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
+        refresh_interval: NotGivenOr[int] = NOT_GIVEN,
     ):
-        super().__init__(
-            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
+        super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
+        self._region = region if is_given(region) else DEFAULT_REGION
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=self._region,
         )
 
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, speech_region
-        )
         self._config = STTOptions(
-            speech_region=speech_region,
             language=language,
             sample_rate=sample_rate,
             encoding=encoding,
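The hunk above replaces the old `Optional[...] = None` parameters with the `NOT_GIVEN` sentinel (`NotGivenOr[T]`) from livekit.agents.types, so an explicit `None` can no longer be confused with an omitted argument; `is_given()` performs the check. Below is a minimal, self-contained sketch of that pattern. The real `NOT_GIVEN`, `NotGivenOr`, and `is_given` ship with livekit.agents, so the re-implementations here are purely illustrative.

from typing import TypeVar, Union


class NotGiven:
    """Illustrative sentinel type: marks an argument the caller never supplied."""


NOT_GIVEN = NotGiven()
T = TypeVar("T")
NotGivenOr = Union[T, NotGiven]


def is_given(value: object) -> bool:
    return not isinstance(value, NotGiven)


def resolve_region(region: NotGivenOr[str] = NOT_GIVEN) -> str:
    # Mirrors `region if is_given(region) else DEFAULT_REGION` from the diff.
    return region if is_given(region) else "us-east-1"


assert resolve_region() == "us-east-1"
assert resolve_region("eu-west-1") == "eu-west-1"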
@@ -92,26 +94,47 @@ class STT(stt.STT):
             partial_results_stability=partial_results_stability,
             language_model_name=language_model_name,
         )
+        self._pool = utils.ConnectionPool[TranscribeStreamingClient](
+            connect_cb=self._create_client,
+            max_session_duration=refresh_interval
+            if is_given(refresh_interval)
+            else REFRESH_INTERVAL,
+        )
+
+    async def _create_client(self) -> TranscribeStreamingClient:
+        creds = await self._session.get_credentials()
+        frozen_credentials = await creds.get_frozen_credentials()
+        return TranscribeStreamingClient(
+            region=self._region,
+            credential_resolver=StaticCredentialResolver(
+                access_key_id=frozen_credentials.access_key,
+                secret_access_key=frozen_credentials.secret_key,
+                session_token=frozen_credentials.token,
+            ),
+        )
+
+    async def aclose(self) -> None:
+        await self._pool.aclose()
+        await super().aclose()
 
     async def _recognize_impl(
         self,
         buffer: utils.AudioBuffer,
         *,
-        language: str | None = None,
+        language: NotGivenOr[str] = NOT_GIVEN,
         conn_options: APIConnectOptions,
     ) -> stt.SpeechEvent:
-        raise NotImplementedError(
-            "Amazon Transcribe does not support single frame recognition"
-        )
+        raise NotImplementedError("Amazon Transcribe does not support single frame recognition")
 
     def stream(
         self,
         *,
-        language: str | None = None,
+        language: NotGivenOr[str] = NOT_GIVEN,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-    ) -> "SpeechStream":
+    ) -> SpeechStream:
         return SpeechStream(
             stt=self,
+            pool=self._pool,
             conn_options=conn_options,
             opts=self._config,
         )
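In the hunk above, `_create_client` snapshots the session's credentials into a `StaticCredentialResolver`, so a pooled `TranscribeStreamingClient` never sees later credential rotation; that is why the `ConnectionPool` recycles clients after `max_session_duration` (the new `refresh_interval` argument, defaulting to `REFRESH_INTERVAL` = 1800 seconds). A standalone sketch of the same pattern, assuming aioboto3 and amazon-transcribe are installed and default AWS credentials are configured:

import asyncio

import aioboto3
from amazon_transcribe.auth import StaticCredentialResolver
from amazon_transcribe.client import TranscribeStreamingClient


async def build_client(session: aioboto3.Session, region: str) -> TranscribeStreamingClient:
    # Frozen credentials are a point-in-time snapshot; temporary (STS) credentials
    # expire, so callers should rebuild the client periodically, as the pool does.
    creds = await session.get_credentials()
    frozen = await creds.get_frozen_credentials()
    return TranscribeStreamingClient(
        region=region,
        credential_resolver=StaticCredentialResolver(
            access_key_id=frozen.access_key,
            secret_access_key=frozen.secret_key,
            session_token=frozen.token,
        ),
    )


async def main() -> None:
    client = await build_client(aioboto3.Session(), "us-east-1")
    print(type(client).__name__)  # TranscribeStreamingClient


asyncio.run(main())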
@@ -122,54 +145,54 @@ class SpeechStream(stt.SpeechStream):
         self,
         stt: STT,
         opts: STTOptions,
+        pool: utils.ConnectionPool[TranscribeStreamingClient],
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
     ) -> None:
-        super().__init__(
-            stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
-        )
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
         self._opts = opts
-        self.  [line truncated in this diff view]
+        self._pool = pool
 
     async def _run(self) -> None:
-        [old lines 134-172: previous _run implementation, not captured in this diff view]
+        async with self._pool.connection() as client:
+            live_config = {
+                "language_code": self._opts.language,
+                "media_sample_rate_hz": self._opts.sample_rate,
+                "media_encoding": self._opts.encoding,
+                "vocabulary_name": self._opts.vocabulary_name,
+                "session_id": self._opts.session_id,
+                "vocab_filter_method": self._opts.vocab_filter_method,
+                "vocab_filter_name": self._opts.vocab_filter_name,
+                "show_speaker_label": self._opts.show_speaker_label,
+                "enable_channel_identification": self._opts.enable_channel_identification,
+                "number_of_channels": self._opts.number_of_channels,
+                "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,  # noqa: E501
+                "partial_results_stability": self._opts.partial_results_stability,
+                "language_model_name": self._opts.language_model_name,
+            }
+            filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
+            stream = await client.start_stream_transcription(**filtered_config)
+
+            @utils.log_exceptions(logger=logger)
+            async def input_generator():
+                async for frame in self._input_ch:
+                    if isinstance(frame, rtc.AudioFrame):
+                        await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
+                await stream.input_stream.end_stream()
+
+            @utils.log_exceptions(logger=logger)
+            async def handle_transcript_events():
+                async for event in stream.output_stream:
+                    if isinstance(event, TranscriptEvent):
+                        self._process_transcript_event(event)
+
+            tasks = [
+                asyncio.create_task(input_generator()),
+                asyncio.create_task(handle_transcript_events()),
+            ]
+            try:
+                await asyncio.gather(*tasks)
+            finally:
+                await utils.aio.gracefully_cancel(*tasks)
 
     def _process_transcript_event(self, transcript_event: TranscriptEvent):
         stream = transcript_event.transcript.results
@@ -184,9 +207,7 @@ class SpeechStream(stt.SpeechStream):
                 self._event_ch.send_nowait(
                     stt.SpeechEvent(
                         type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                        alternatives=[
-                            _streaming_recognize_response_to_speech_data(resp)
-                        ],
+                        alternatives=[_streaming_recognize_response_to_speech_data(resp)],
                     )
                 )
 
@@ -194,16 +215,12 @@ class SpeechStream(stt.SpeechStream):
                 self._event_ch.send_nowait(
                     stt.SpeechEvent(
                         type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                        alternatives=[
-                            _streaming_recognize_response_to_speech_data(resp)
-                        ],
+                        alternatives=[_streaming_recognize_response_to_speech_data(resp)],
                     )
                 )
 
                 if not resp.is_partial:
-                    self._event_ch.send_nowait(
-                        stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
-                    )
+                    self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
 
 
 def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
@@ -211,7 +228,6 @@ def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData
         language="en-US",
         start_time=resp.start_time if resp.start_time else 0.0,
         end_time=resp.end_time if resp.end_time else 0.0,
-        confidence=0.0,
         text=resp.alternatives[0].transcript if resp.alternatives else "",
     )
 
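Putting the stt.py changes together, the 1.0.0 constructor might be used as below. This is a hypothetical sketch based only on the signatures shown in this diff; the parameter values and the `livekit.plugins.aws` import path are illustrative.

import aioboto3

from livekit.plugins import aws

stt = aws.STT(
    region="us-east-1",          # falls back to DEFAULT_REGION when omitted
    session=aioboto3.Session(),  # optional: reuse an existing credential chain
    refresh_interval=900,        # recycle pooled Transcribe clients every 15 minutes
    language="en-US",
    sample_rate=48000,
)
stream = stt.stream()  # SpeechStream backed by the shared connection pool

Single-frame recognition remains unsupported: `_recognize_impl` still raises NotImplementedError, so only the streaming path is usable.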
livekit/plugins/aws/tts.py
CHANGED
@@ -14,11 +14,10 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
 
+import aioboto3
 import aiohttp
-
-from livekit import rtc
+
 from livekit.agents import (
     APIConnectionError,
     APIConnectOptions,
@@ -27,14 +26,18 @@ from livekit.agents import (
     tts,
     utils,
 )
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
 
-from ._utils import _get_aws_credentials, get_session
-from .models import TTS_LANGUAGE, TTS_OUTPUT_FORMAT, TTS_SPEECH_ENGINE
+from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
+from .utils import _strip_nones, get_aws_async_session
 
 TTS_NUM_CHANNELS: int = 1
-DEFAULT_OUTPUT_FORMAT: TTS_OUTPUT_FORMAT = "pcm"
 DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
-DEFAULT_SPEECH_REGION = "us-east-1"
 DEFAULT_VOICE = "Ruth"
 DEFAULT_SAMPLE_RATE = 16000
 
@@ -42,27 +45,25 @@ DEFAULT_SAMPLE_RATE = 16000
 @dataclass
 class _TTSOptions:
     # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
-    voice: str
-    output_format: TTS_OUTPUT_FORMAT
-    speech_engine: TTS_SPEECH_ENGINE
-    speech_region: str
+    voice: NotGivenOr[str]
+    speech_engine: NotGivenOr[TTS_SPEECH_ENGINE]
+    region: str
     sample_rate: int
-    language: TTS_LANGUAGE | str
+    language: NotGivenOr[TTS_LANGUAGE | str]
 
 
 class TTS(tts.TTS):
     def __init__(
         self,
         *,
-        voice: str = DEFAULT_VOICE,
-        language: TTS_LANGUAGE | str | None = None,
-        output_format: TTS_OUTPUT_FORMAT = DEFAULT_OUTPUT_FORMAT,
-        speech_engine: TTS_SPEECH_ENGINE = DEFAULT_SPEECH_ENGINE,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        language: NotGivenOr[TTS_LANGUAGE | str] = NOT_GIVEN,
+        speech_engine: NotGivenOr[TTS_SPEECH_ENGINE] = NOT_GIVEN,
         sample_rate: int = DEFAULT_SAMPLE_RATE,
-        speech_region: str = DEFAULT_SPEECH_REGION,
-        api_key: str | None = None,
-        api_secret: str | None = None,
-        session: Any | None = None,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Polly TTS.
@@ -75,13 +76,13 @@ class TTS(tts.TTS):
         Args:
             Voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
             language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
-            output_format(TTS_OUTPUT_FORMAT, optional): The format in which the returned output will be encoded. Defaults to "pcm".
             sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
             speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
-            speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
+            region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
            api_key(str, optional): AWS access key id.
             api_secret(str, optional): AWS secret access key.
-        """
+            session(aioboto3.Session, optional): Optional aioboto3 session to use.
+        """ # noqa: E501
         super().__init__(
             capabilities=tts.TTSCapabilities(
                 streaming=False,
@@ -89,41 +90,31 @@ class TTS(tts.TTS):
             sample_rate=sample_rate,
             num_channels=TTS_NUM_CHANNELS,
         )
-
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, speech_region
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=region if is_given(region) else None,
         )
-
         self._opts = _TTSOptions(
             voice=voice,
-            output_format=output_format,
             speech_engine=speech_engine,
-            speech_region=speech_region,
+            region=region,
             language=language,
             sample_rate=sample_rate,
         )
-        self._session = session or get_session()
-
-    def _get_client(self):
-        return self._session.create_client(
-            "polly",
-            region_name=self._opts.speech_region,
-            aws_access_key_id=self._api_key,
-            aws_secret_access_key=self._api_secret,
-        )
 
     def synthesize(
         self,
         text: str,
         *,
-        conn_options: Optional[APIConnectOptions] = None,
-    ) -> "ChunkedStream":
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> ChunkedStream:
         return ChunkedStream(
             tts=self,
             text=text,
             conn_options=conn_options,
+            session=self._session,
             opts=self._opts,
-            get_client=self._get_client,
         )
 
 
@@ -133,57 +124,63 @@ class ChunkedStream(tts.ChunkedStream):
         *,
         tts: TTS,
         text: str,
-        conn_options: Optional[APIConnectOptions] = None,
+        session: aioboto3.Session,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
         opts: _TTSOptions,
-        get_client: Callable[[], Any],
     ) -> None:
         super().__init__(tts=tts, input_text=text, conn_options=conn_options)
         self._opts = opts
-        self._get_client = get_client
         self._segment_id = utils.shortuuid()
+        self._session = session
 
     async def _run(self):
         request_id = utils.shortuuid()
 
         try:
-            async with self._get_client() as client:
+            async with self._session.client("polly") as client:
                 params = {
                     "Text": self._input_text,
-                    "OutputFormat": self._opts.output_format,
-                    "Engine": self._opts.speech_engine,
-                    "VoiceId": self._opts.voice,
+                    "OutputFormat": "mp3",
+                    "Engine": self._opts.speech_engine
+                    if is_given(self._opts.speech_engine)
+                    else DEFAULT_SPEECH_ENGINE,
+                    "VoiceId": self._opts.voice if is_given(self._opts.voice) else DEFAULT_VOICE,
                     "TextType": "text",
                     "SampleRate": str(self._opts.sample_rate),
-                    "LanguageCode": self._opts.language,
+                    "LanguageCode": self._opts.language if is_given(self._opts.language) else None,
                 }
                 response = await client.synthesize_speech(**_strip_nones(params))
                 if "AudioStream" in response:
-                    decoder = utils.codecs.  [line truncated in this diff view]
-                    [old lines 162-186: remainder of the previous decoding implementation, not captured in this diff view]
+                    decoder = utils.codecs.AudioStreamDecoder(
+                        sample_rate=self._opts.sample_rate,
+                        num_channels=1,
+                    )
+
+                    # Create a task to push data to the decoder
+                    async def push_data():
+                        try:
+                            async with response["AudioStream"] as resp:
+                                async for data, _ in resp.content.iter_chunks():
+                                    decoder.push(data)
+                        finally:
+                            decoder.end_input()
+
+                    # Start pushing data to the decoder
+                    push_task = asyncio.create_task(push_data())
+
+                    try:
+                        # Create emitter and process decoded frames
+                        emitter = tts.SynthesizedAudioEmitter(
+                            event_ch=self._event_ch,
+                            request_id=request_id,
+                            segment_id=self._segment_id,
+                        )
+                        async for frame in decoder:
+                            emitter.push(frame)
+                        emitter.flush()
+                        await push_task
+                    finally:
+                        await utils.aio.gracefully_cancel(push_task)
 
         except asyncio.TimeoutError as e:
             raise APITimeoutError() from e
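The new `_run` above always requests mp3 from Polly ("OutputFormat": "mp3") and converts it back to PCM with `utils.codecs.AudioStreamDecoder`: a producer task pushes encoded bytes into the decoder while the coroutine consumes decoded frames. Reduced to its essentials, the shape is the sketch below; the decoder calls mirror the diff, and the in-memory byte list is a stand-in for Polly's streaming AudioStream.

import asyncio

from livekit.agents import utils


async def decode_mp3_chunks(chunks: list[bytes], sample_rate: int = 16000) -> int:
    decoder = utils.codecs.AudioStreamDecoder(sample_rate=sample_rate, num_channels=1)

    async def push() -> None:
        try:
            for chunk in chunks:  # stand-in for Polly's streaming AudioStream body
                decoder.push(chunk)
        finally:
            decoder.end_input()  # always close the input so the consumer can finish

    push_task = asyncio.create_task(push())
    frames = 0
    try:
        async for _frame in decoder:  # decoded PCM frames, ready for an emitter
            frames += 1
        await push_task
    finally:
        await utils.aio.gracefully_cancel(push_task)
    return frames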
@@ -196,7 +193,3 @@ class ChunkedStream(tts.ChunkedStream):
             ) from e
         except Exception as e:
             raise APIConnectionError() from e
-
-
-def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
-    return {k: v for k, v in d.items() if v is not None}