livekit-plugins-aws 1.0.0rc6__py3-none-any.whl → 1.3.9__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- livekit/plugins/aws/__init__.py +47 -7
- livekit/plugins/aws/experimental/realtime/__init__.py +11 -0
- livekit/plugins/aws/experimental/realtime/events.py +545 -0
- livekit/plugins/aws/experimental/realtime/pretty_printer.py +49 -0
- livekit/plugins/aws/experimental/realtime/realtime_model.py +2106 -0
- livekit/plugins/aws/experimental/realtime/turn_tracker.py +171 -0
- livekit/plugins/aws/experimental/realtime/types.py +38 -0
- livekit/plugins/aws/llm.py +109 -71
- livekit/plugins/aws/log.py +4 -0
- livekit/plugins/aws/models.py +4 -3
- livekit/plugins/aws/stt.py +214 -71
- livekit/plugins/aws/tts.py +96 -116
- livekit/plugins/aws/utils.py +29 -125
- livekit/plugins/aws/version.py +1 -1
- livekit_plugins_aws-1.3.9.dist-info/METADATA +385 -0
- livekit_plugins_aws-1.3.9.dist-info/RECORD +18 -0
- {livekit_plugins_aws-1.0.0rc6.dist-info → livekit_plugins_aws-1.3.9.dist-info}/WHEEL +1 -1
- livekit_plugins_aws-1.0.0rc6.dist-info/METADATA +0 -43
- livekit_plugins_aws-1.0.0rc6.dist-info/RECORD +0 -12
livekit/plugins/aws/stt.py
CHANGED
```diff
@@ -13,23 +13,59 @@
 from __future__ import annotations
 
 import asyncio
+import concurrent.futures
+import contextlib
+import os
 from dataclasses import dataclass
-
-from amazon_transcribe.client import TranscribeStreamingClient
-from amazon_transcribe.model import Result, TranscriptEvent
+from typing import Any
 
 from livekit import rtc
-from livekit.agents import
+from livekit.agents import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    APIConnectOptions,
+    stt,
+    utils,
+)
 from livekit.agents.types import NOT_GIVEN, NotGivenOr
 from livekit.agents.utils import is_given
+from livekit.agents.voice.io import TimedString
 
 from .log import logger
-from .utils import
+from .utils import DEFAULT_REGION
+
+try:
+    from aws_sdk_transcribe_streaming.client import TranscribeStreamingClient  # type: ignore
+    from aws_sdk_transcribe_streaming.config import Config  # type: ignore
+    from aws_sdk_transcribe_streaming.models import (  # type: ignore
+        AudioEvent,
+        AudioStream,
+        AudioStreamAudioEvent,
+        BadRequestException,
+        Result,
+        StartStreamTranscriptionInput,
+        TranscriptEvent,
+        TranscriptResultStream,
+    )
+    from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver
+    from smithy_core.aio.interfaces.eventstream import (
+        EventPublisher,
+        EventReceiver,
+    )
+
+    _AWS_SDK_AVAILABLE = True
+except ImportError:
+    _AWS_SDK_AVAILABLE = False
+
+
+@dataclass
+class Credentials:
+    access_key_id: str
+    secret_access_key: str
+    session_token: str | None = None
 
 
 @dataclass
 class STTOptions:
-    speech_region: str
     sample_rate: int
     language: str
     encoding: str
@@ -43,16 +79,15 @@ class STTOptions:
     enable_partial_results_stabilization: NotGivenOr[bool]
     partial_results_stability: NotGivenOr[str]
     language_model_name: NotGivenOr[str]
+    region: str
 
 
 class STT(stt.STT):
     def __init__(
         self,
         *,
-        # [2 removed lines not captured in the source diff view]
-        api_secret: NotGivenOr[str] = NOT_GIVEN,
-        sample_rate: int = 48000,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        sample_rate: int = 24000,
         language: str = "en-US",
         encoding: str = "pcm",
         vocabulary_name: NotGivenOr[str] = NOT_GIVEN,
@@ -65,14 +100,24 @@ class STT(stt.STT):
         enable_partial_results_stabilization: NotGivenOr[bool] = NOT_GIVEN,
         partial_results_stability: NotGivenOr[str] = NOT_GIVEN,
         language_model_name: NotGivenOr[str] = NOT_GIVEN,
+        credentials: NotGivenOr[Credentials] = NOT_GIVEN,
     ):
-        super().__init__(
-        # [3 removed lines not captured in the source diff view]
+        super().__init__(
+            capabilities=stt.STTCapabilities(
+                streaming=True, interim_results=True, aligned_transcript="word"
+            )
         )
+
+        if not _AWS_SDK_AVAILABLE:
+            raise ImportError(
+                "The 'aws_sdk_transcribe_streaming' package is not installed. "
+                "This implementation requires Python 3.12+ and the 'aws_sdk_transcribe_streaming' dependency."
+            )
+
+        if not is_given(region):
+            region = os.getenv("AWS_REGION") or DEFAULT_REGION
+
         self._config = STTOptions(
-            speech_region=self._speech_region,
             language=language,
             sample_rate=sample_rate,
             encoding=encoding,
@@ -86,8 +131,26 @@ class STT(stt.STT):
             enable_partial_results_stabilization=enable_partial_results_stabilization,
             partial_results_stability=partial_results_stability,
             language_model_name=language_model_name,
+            region=region,
+        )
+
+        self._credentials = credentials if is_given(credentials) else None
+
+    @property
+    def model(self) -> str:
+        return (
+            self._config.language_model_name
+            if is_given(self._config.language_model_name)
+            else "unknown"
         )
 
+    @property
+    def provider(self) -> str:
+        return "Amazon Transcribe"
+
+    async def aclose(self) -> None:
+        await super().aclose()
+
     async def _recognize_impl(
         self,
         buffer: utils.AudioBuffer,
```
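Taken together, the constructor hunks above replace the old `api_secret`/`speech_region` keywords with a single `region` argument plus an optional `Credentials` dataclass, and raise `ImportError` when the new `aws_sdk_transcribe_streaming` dependency is missing. A minimal usage sketch based on the new signature; the import path follows the file layout and is not itself confirmed by the diff:

```python
# Hypothetical usage sketch inferred from the new STT constructor above.
import os

from livekit.plugins.aws.stt import STT, Credentials

stt_engine = STT(
    language="en-US",
    sample_rate=24000,  # the new default shown in the diff
    # region falls back to AWS_REGION, then DEFAULT_REGION, when omitted
    credentials=Credentials(
        access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
        secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    ),
)
```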
```diff
@@ -107,6 +170,7 @@ class STT(stt.STT):
             stt=self,
             conn_options=conn_options,
             opts=self._config,
+            credentials=self._credentials,
         )
 
 
@@ -116,66 +180,132 @@ class SpeechStream(stt.SpeechStream):
         stt: STT,
         opts: STTOptions,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        credentials: Credentials | None = None,
     ) -> None:
         super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
         self._opts = opts
-        self.
+        self._credentials = credentials
 
     async def _run(self) -> None:
-        # [old _run body (41 removed lines) not captured in the source diff view]
+        while True:
+            config_kwargs: dict[str, Any] = {"region": self._opts.region}
+            if self._credentials:
+                config_kwargs["aws_access_key_id"] = self._credentials.access_key_id
+                config_kwargs["aws_secret_access_key"] = self._credentials.secret_access_key
+                config_kwargs["aws_session_token"] = self._credentials.session_token
+            else:
+                config_kwargs["aws_credentials_identity_resolver"] = (
+                    EnvironmentCredentialsResolver()
+                )
+
+            client: TranscribeStreamingClient = TranscribeStreamingClient(
+                config=Config(**config_kwargs)
+            )
+
+            live_config = {
+                "language_code": self._opts.language,
+                "media_sample_rate_hertz": self._opts.sample_rate,
+                "media_encoding": self._opts.encoding,
+                "vocabulary_name": self._opts.vocabulary_name,
+                "session_id": self._opts.session_id,
+                "vocab_filter_method": self._opts.vocab_filter_method,
+                "vocab_filter_name": self._opts.vocab_filter_name,
+                "show_speaker_label": self._opts.show_speaker_label,
+                "enable_channel_identification": self._opts.enable_channel_identification,
+                "number_of_channels": self._opts.number_of_channels,
+                "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,
+                "partial_results_stability": self._opts.partial_results_stability,
+                "language_model_name": self._opts.language_model_name,
+            }
+            filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
+
+            try:
+                stream = await client.start_stream_transcription(
+                    input=StartStreamTranscriptionInput(**filtered_config)
+                )
+
+                # Get the output stream
+                _, output_stream = await stream.await_output()
+
+                async def input_generator(
+                    audio_stream: EventPublisher[AudioStream],
+                ) -> None:
+                    try:
+                        async for frame in self._input_ch:
+                            if isinstance(frame, rtc.AudioFrame):
+                                await audio_stream.send(
+                                    AudioStreamAudioEvent(
+                                        value=AudioEvent(audio_chunk=frame.data.tobytes())
+                                    )
+                                )
+                        # Send empty frame to close
+                        await audio_stream.send(
+                            AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
+                        )
+                    finally:
+                        with contextlib.suppress(Exception):
+                            await audio_stream.close()
+
+                async def handle_transcript_events(
+                    output_stream: EventReceiver[TranscriptResultStream],
+                ) -> None:
+                    try:
+                        async for event in output_stream:
+                            if isinstance(event.value, TranscriptEvent):
+                                self._process_transcript_event(event.value)
+                    except concurrent.futures.InvalidStateError:
+                        logger.warning(
+                            "AWS Transcribe stream closed unexpectedly (InvalidStateError)"
+                        )
+                        pass
+
+                tasks = [
+                    asyncio.create_task(input_generator(stream.input_stream)),
+                    asyncio.create_task(handle_transcript_events(output_stream)),
+                ]
+                gather_future = asyncio.gather(*tasks)
+
+                await asyncio.shield(gather_future)
+            except BadRequestException as e:
+                if e.message and e.message.startswith("Your request timed out"):
+                    # AWS times out after 15s of inactivity, this tends to happen
+                    # at the end of the session, when the input is gone, we'll ignore it and
+                    # just treat it as a silent retry
+                    logger.info("restarting transcribe session")
+                    continue
+                else:
+                    raise e
+            finally:
+                # Close input stream first
+                await utils.aio.gracefully_cancel(tasks[0])
+
+                # Wait for output stream to close cleanly
+                try:
+                    await asyncio.wait_for(tasks[1], timeout=3.0)
+                except (asyncio.TimeoutError, asyncio.CancelledError):
+                    await utils.aio.gracefully_cancel(tasks[1])
+
+                # Ensure gather future is retrieved to avoid "exception never retrieved"
+                with contextlib.suppress(Exception):
+                    await gather_future
+
+    def _process_transcript_event(self, transcript_event: TranscriptEvent) -> None:
+        if not transcript_event.transcript or not transcript_event.transcript.results:
+            return
+
         stream = transcript_event.transcript.results
         for resp in stream:
-            if resp.start_time and resp.start_time == 0.0:
+            if resp.start_time is not None and resp.start_time == 0.0:
                 self._event_ch.send_nowait(
                     stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                 )
 
-            if resp.end_time and resp.end_time > 0.0:
+            if resp.end_time is not None and resp.end_time > 0.0:
                 if resp.is_partial:
                     self._event_ch.send_nowait(
                         stt.SpeechEvent(
                             type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                            alternatives=[_streaming_recognize_response_to_speech_data(resp)],
+                            alternatives=[self._streaming_recognize_response_to_speech_data(resp)],
                         )
                     )
 
```
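The rewritten `_run` above wraps each Transcribe session in a `while True` loop and treats AWS's roughly 15-second inactivity timeout, surfaced as a `BadRequestException` whose message starts with "Your request timed out", as a cue to reconnect silently. A stripped-down sketch of that reconnect shape, with the session body reduced to a hypothetical helper:

```python
# Simplified sketch of the reconnect loop in _run above. run_session() is a
# hypothetical stand-in for the stream setup and task management it performs;
# the real loop's exit conditions differ slightly.
async def run_with_reconnect() -> None:
    while True:
        try:
            await run_session()  # hypothetical helper
        except BadRequestException as e:
            if e.message and e.message.startswith("Your request timed out"):
                # AWS closes idle streams after ~15s of silence;
                # reconnect instead of surfacing an error
                continue
            raise
        break  # normal completion
```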
```diff
@@ -183,21 +313,34 @@ class SpeechStream(stt.SpeechStream):
                 self._event_ch.send_nowait(
                     stt.SpeechEvent(
                         type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                        alternatives=[_streaming_recognize_response_to_speech_data(resp)],
+                        alternatives=[self._streaming_recognize_response_to_speech_data(resp)],
                     )
                 )
 
             if not resp.is_partial:
                 self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
 
+    def _streaming_recognize_response_to_speech_data(self, resp: Result) -> stt.SpeechData:
+        confidence = 0.0
+        if resp.alternatives and (items := resp.alternatives[0].items):
+            confidence = items[0].confidence or 0.0
 
-# [old module-level helper (10 removed lines) not captured in the source diff view]
+        return stt.SpeechData(
+            language=resp.language_code or self._opts.language,
+            start_time=(resp.start_time or 0.0) + self.start_time_offset,
+            end_time=(resp.end_time or 0.0) + self.start_time_offset,
+            text=resp.alternatives[0].transcript if resp.alternatives else "",
+            confidence=confidence,
+            words=[
+                TimedString(
+                    text=item.content,
+                    start_time=item.start_time + self.start_time_offset,
+                    end_time=item.end_time + self.start_time_offset,
+                    start_time_offset=self.start_time_offset,
+                    confidence=item.confidence or 0.0,
+                )
+                for item in resp.alternatives[0].items
+            ]
+            if resp.alternatives and resp.alternatives[0].items
+            else None,
+        )
```
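The new `_streaming_recognize_response_to_speech_data` helper attaches per-word `TimedString` entries, which backs the `aligned_transcript="word"` capability declared in the constructor. A hypothetical consumer sketch; `speech_stream` stands in for a `SpeechStream` instance, and printing a `TimedString` directly assumes it behaves like a string (its timing attributes are assumed from how it is constructed above):

```python
# Hypothetical consumer: iterate final transcripts and print word timings.
async for event in speech_stream:
    if event.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
        data = event.alternatives[0]
        for word in data.words or []:
            print(f"{word.start_time:6.2f}s-{word.end_time:6.2f}s  {word}")
```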
livekit/plugins/aws/tts.py
CHANGED
```diff
@@ -12,20 +12,19 @@
 
 from __future__ import annotations
 
-import
-from
-from typing import Any, Callable
+from dataclasses import dataclass, replace
+from typing import cast
 
-import
-# [1 removed line not captured in the source diff view]
+import aioboto3  # type: ignore
+import botocore  # type: ignore
+import botocore.exceptions  # type: ignore
+from aiobotocore.config import AioConfig  # type: ignore
 
 from livekit.agents import (
     APIConnectionError,
     APIConnectOptions,
-    APIStatusError,
     APITimeoutError,
     tts,
-    utils,
 )
 from livekit.agents.types import (
     DEFAULT_API_CONNECT_OPTIONS,
@@ -34,38 +33,38 @@ from livekit.agents.types import (
 )
 from livekit.agents.utils import is_given
 
-from .models import
-from .utils import _strip_nones
+from .models import TTSLanguages, TTSSpeechEngine, TTSTextType
+from .utils import _strip_nones
 
-
-DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
-DEFAULT_SPEECH_REGION = "us-east-1"
+DEFAULT_SPEECH_ENGINE: TTSSpeechEngine = "generative"
 DEFAULT_VOICE = "Ruth"
-
+DEFAULT_TEXT_TYPE: TTSTextType = "text"
 
 
 @dataclass
 class _TTSOptions:
     # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
-    voice:
-    speech_engine:
-
+    voice: str
+    speech_engine: TTSSpeechEngine
+    region: str | None
     sample_rate: int
-    language:
+    language: TTSLanguages | str | None
+    text_type: TTSTextType
 
 
 class TTS(tts.TTS):
     def __init__(
         self,
         *,
-        voice:
-        language: NotGivenOr[
-        speech_engine:
-        # [5 removed lines not captured in the source diff view]
+        voice: str = "Ruth",
+        language: NotGivenOr[TTSLanguages | str] = NOT_GIVEN,
+        speech_engine: TTSSpeechEngine = "generative",
+        text_type: TTSTextType = "text",
+        sample_rate: int = 16000,
+        region: str | None = None,
+        api_key: str | None = None,
+        api_secret: str | None = None,
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Polly TTS.
@@ -76,130 +75,111 @@ class TTS(tts.TTS):
         See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on the the AWS Polly TTS.
 
         Args:
-
-            language (
+            voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
+            language (TTSLanguages, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
+            speech_engine(TTSSpeechEngine, optional): The engine to use for the synthesis. Defaults to "generative".
+            text_type(TTSTextType, optional): Type of text to synthesize. Use "ssml" for SSML-enhanced text. Defaults to "text".
             sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
-
-            speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
+            region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
             api_key(str, optional): AWS access key id.
             api_secret(str, optional): AWS secret access key.
+            session(aioboto3.Session, optional): Optional aioboto3 session to use.
         """  # noqa: E501
         super().__init__(
            capabilities=tts.TTSCapabilities(
                streaming=False,
            ),
            sample_rate=sample_rate,
-            num_channels=
+            num_channels=1,
         )
-        # [3 removed lines not captured in the source diff view]
+        self._session = session or aioboto3.Session(
+            aws_access_key_id=api_key if is_given(api_key) else None,
+            aws_secret_access_key=api_secret if is_given(api_secret) else None,
+            region_name=region if is_given(region) else None,
         )
 
         self._opts = _TTSOptions(
             voice=voice,
             speech_engine=speech_engine,
-            # [2 removed lines not captured in the source diff view]
+            text_type=text_type,
+            region=region or None,
+            language=language or None,
             sample_rate=sample_rate,
         )
-
-
-    def
-        return self.
-        # [4 removed lines not captured in the source diff view]
-        )
+
+    @property
+    def model(self) -> str:
+        return self._opts.speech_engine
+
+    @property
+    def provider(self) -> str:
+        return "Amazon Polly"
 
     def synthesize(
+        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
+    ) -> ChunkedStream:
+        return ChunkedStream(tts=self, text=text, conn_options=conn_options)
+
+    def update_options(
         self,
-        text: str,
         *,
-        # [6 removed lines not captured in the source diff view]
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        language: NotGivenOr[str] = NOT_GIVEN,
+        speech_engine: NotGivenOr[TTSSpeechEngine] = NOT_GIVEN,
+        text_type: NotGivenOr[TTSTextType] = NOT_GIVEN,
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
+        if is_given(language):
+            self._opts.language = language
+        if is_given(speech_engine):
+            self._opts.speech_engine = cast(TTSSpeechEngine, speech_engine)
+        if is_given(text_type):
+            self._opts.text_type = cast(TTSTextType, text_type)
 
 
 class ChunkedStream(tts.ChunkedStream):
     def __init__(
-        self,
-        *,
-        tts: TTS,
-        text: str,
-        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-        opts: _TTSOptions,
-        get_client: Callable[[], Any],
+        self, *, tts: TTS, text: str, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> None:
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)
-        self.
-        self.
-        self._segment_id = utils.shortuuid()
-
-    async def _run(self):
-        request_id = utils.shortuuid()
+        self._tts = tts
+        self._opts = replace(tts._opts)
 
+    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
         try:
-            # [old synthesis body (13 removed lines) not captured in the source diff view]
+            config = AioConfig(
+                connect_timeout=self._conn_options.timeout,
+                read_timeout=10,
+                retries={"mode": "standard", "total_max_attempts": 1},
+            )
+            async with self._tts._session.client("polly", config=config) as client:  # type: ignore
+                response = await client.synthesize_speech(
+                    **_strip_nones(
+                        {
+                            "Text": self._input_text,
+                            "OutputFormat": "mp3",
+                            "Engine": self._opts.speech_engine,
+                            "VoiceId": self._opts.voice,
+                            "TextType": self._opts.text_type,
+                            "SampleRate": str(self._opts.sample_rate),
+                            "LanguageCode": self._opts.language,
+                        }
+                    )
+                )
+
                 if "AudioStream" in response:
-
+                    output_emitter.initialize(
+                        request_id=response["ResponseMetadata"]["RequestId"],
                         sample_rate=self._opts.sample_rate,
                         num_channels=1,
+                        mime_type="audio/mp3",
                     )
 
-                    # [old lines 169-173 not captured in the source diff view]
-                            decoder.push(data)
-                    finally:
-                        decoder.end_input()
-
-                # Start pushing data to the decoder
-                push_task = asyncio.create_task(push_data())
-
-                try:
-                    # Create emitter and process decoded frames
-                    emitter = tts.SynthesizedAudioEmitter(
-                        event_ch=self._event_ch,
-                        request_id=request_id,
-                        segment_id=self._segment_id,
-                    )
-                    async for frame in decoder:
-                        emitter.push(frame)
-                    emitter.flush()
-                    await push_task
-                finally:
-                    await utils.aio.gracefully_cancel(push_task)
-
-        except asyncio.TimeoutError as e:
-            raise APITimeoutError() from e
-        except aiohttp.ClientResponseError as e:
-            raise APIStatusError(
-                message=e.message,
-                status_code=e.status,
-                request_id=request_id,
-                body=None,
-            ) from e
+                    async with response["AudioStream"] as resp:
+                        async for data, _ in resp.content.iter_chunks():
+                            output_emitter.push(data)
+        except botocore.exceptions.ConnectTimeoutError:
+            raise APITimeoutError() from None
         except Exception as e:
             raise APIConnectionError() from e
```