livekit-plugins-aws 1.0.0rc6__py3-none-any.whl → 1.3.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
--- a/livekit/plugins/aws/stt.py
+++ b/livekit/plugins/aws/stt.py
@@ -13,23 +13,59 @@
  from __future__ import annotations

  import asyncio
+ import concurrent.futures
+ import contextlib
+ import os
  from dataclasses import dataclass
-
- from amazon_transcribe.client import TranscribeStreamingClient
- from amazon_transcribe.model import Result, TranscriptEvent
+ from typing import Any

  from livekit import rtc
- from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils
+ from livekit.agents import (
+     DEFAULT_API_CONNECT_OPTIONS,
+     APIConnectOptions,
+     stt,
+     utils,
+ )
  from livekit.agents.types import NOT_GIVEN, NotGivenOr
  from livekit.agents.utils import is_given
+ from livekit.agents.voice.io import TimedString

  from .log import logger
- from .utils import get_aws_credentials
+ from .utils import DEFAULT_REGION
+
+ try:
+     from aws_sdk_transcribe_streaming.client import TranscribeStreamingClient  # type: ignore
+     from aws_sdk_transcribe_streaming.config import Config  # type: ignore
+     from aws_sdk_transcribe_streaming.models import (  # type: ignore
+         AudioEvent,
+         AudioStream,
+         AudioStreamAudioEvent,
+         BadRequestException,
+         Result,
+         StartStreamTranscriptionInput,
+         TranscriptEvent,
+         TranscriptResultStream,
+     )
+     from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver
+     from smithy_core.aio.interfaces.eventstream import (
+         EventPublisher,
+         EventReceiver,
+     )
+
+     _AWS_SDK_AVAILABLE = True
+ except ImportError:
+     _AWS_SDK_AVAILABLE = False
+
+
+ @dataclass
+ class Credentials:
+     access_key_id: str
+     secret_access_key: str
+     session_token: str | None = None


  @dataclass
  class STTOptions:
-     speech_region: str
      sample_rate: int
      language: str
      encoding: str
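
The guarded import block above makes the new `aws_sdk_transcribe_streaming` SDK an optional dependency, and the `Credentials` dataclass lets callers inject static or temporary AWS credentials instead of relying on environment resolution. A hedged sketch of how temporary credentials might be produced and adapted (the STS call and response shape are boto3's documented API; the role ARN and the `livekit.plugins.aws.stt` import path are illustrative assumptions):

    import boto3
    from livekit.plugins.aws.stt import Credentials

    # Exchange an IAM role for temporary credentials via STS, then adapt
    # the response fields to the plugin's Credentials dataclass.
    sts = boto3.client("sts")
    resp = sts.assume_role(
        RoleArn="arn:aws:iam::123456789012:role/transcribe-role",  # placeholder
        RoleSessionName="livekit-stt",
    )
    creds = Credentials(
        access_key_id=resp["Credentials"]["AccessKeyId"],
        secret_access_key=resp["Credentials"]["SecretAccessKey"],
        session_token=resp["Credentials"]["SessionToken"],
    )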
@@ -43,16 +79,15 @@ class STTOptions:
      enable_partial_results_stabilization: NotGivenOr[bool]
      partial_results_stability: NotGivenOr[str]
      language_model_name: NotGivenOr[str]
+     region: str


  class STT(stt.STT):
      def __init__(
          self,
          *,
-         speech_region: str = "us-east-1",
-         api_key: NotGivenOr[str] = NOT_GIVEN,
-         api_secret: NotGivenOr[str] = NOT_GIVEN,
-         sample_rate: int = 48000,
+         region: NotGivenOr[str] = NOT_GIVEN,
+         sample_rate: int = 24000,
          language: str = "en-US",
          encoding: str = "pcm",
          vocabulary_name: NotGivenOr[str] = NOT_GIVEN,
@@ -65,14 +100,24 @@ class STT(stt.STT):
          enable_partial_results_stabilization: NotGivenOr[bool] = NOT_GIVEN,
          partial_results_stability: NotGivenOr[str] = NOT_GIVEN,
          language_model_name: NotGivenOr[str] = NOT_GIVEN,
+         credentials: NotGivenOr[Credentials] = NOT_GIVEN,
      ):
-         super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
-
-         self._api_key, self._api_secret, self._speech_region = get_aws_credentials(
-             api_key, api_secret, speech_region
+         super().__init__(
+             capabilities=stt.STTCapabilities(
+                 streaming=True, interim_results=True, aligned_transcript="word"
+             )
          )
+
+         if not _AWS_SDK_AVAILABLE:
+             raise ImportError(
+                 "The 'aws_sdk_transcribe_streaming' package is not installed. "
+                 "This implementation requires Python 3.12+ and the 'aws_sdk_transcribe_streaming' dependency."
+             )
+
+         if not is_given(region):
+             region = os.getenv("AWS_REGION") or DEFAULT_REGION
+
          self._config = STTOptions(
-             speech_region=self._speech_region,
              language=language,
              sample_rate=sample_rate,
              encoding=encoding,
@@ -86,8 +131,26 @@ class STT(stt.STT):
              enable_partial_results_stabilization=enable_partial_results_stabilization,
              partial_results_stability=partial_results_stability,
              language_model_name=language_model_name,
+             region=region,
+         )
+
+         self._credentials = credentials if is_given(credentials) else None
+
+     @property
+     def model(self) -> str:
+         return (
+             self._config.language_model_name
+             if is_given(self._config.language_model_name)
+             else "unknown"
          )

+     @property
+     def provider(self) -> str:
+         return "Amazon Transcribe"
+
+     async def aclose(self) -> None:
+         await super().aclose()
+
      async def _recognize_impl(
          self,
          buffer: utils.AudioBuffer,
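
With these changes, the constructor resolves the region from the `region` argument, then the `AWS_REGION` environment variable, then `DEFAULT_REGION`, and accepts the optional `credentials` object. A minimal usage sketch (reusing the `creds` object from the earlier example; omitting it falls back to environment-based credential resolution):

    from livekit.plugins import aws

    # Explicit region plus injected credentials:
    stt_engine = aws.STT(region="us-west-2", credentials=creds)

    # Or rely on AWS_REGION / DEFAULT_REGION and environment credentials:
    stt_engine = aws.STT(language="en-US", sample_rate=24000)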
@@ -107,6 +170,7 @@ class STT(stt.STT):
              stt=self,
              conn_options=conn_options,
              opts=self._config,
+             credentials=self._credentials,
          )


@@ -116,66 +180,132 @@ class SpeechStream(stt.SpeechStream):
          stt: STT,
          opts: STTOptions,
          conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+         credentials: Credentials | None = None,
      ) -> None:
          super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
          self._opts = opts
-         self._client = TranscribeStreamingClient(region=self._opts.speech_region)
+         self._credentials = credentials

      async def _run(self) -> None:
-         live_config = {
-             "language_code": self._opts.language,
-             "media_sample_rate_hz": self._opts.sample_rate,
-             "media_encoding": self._opts.encoding,
-             "vocabulary_name": self._opts.vocabulary_name,
-             "session_id": self._opts.session_id,
-             "vocab_filter_method": self._opts.vocab_filter_method,
-             "vocab_filter_name": self._opts.vocab_filter_name,
-             "show_speaker_label": self._opts.show_speaker_label,
-             "enable_channel_identification": self._opts.enable_channel_identification,
-             "number_of_channels": self._opts.number_of_channels,
-             "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,
-             "partial_results_stability": self._opts.partial_results_stability,
-             "language_model_name": self._opts.language_model_name,
-         }
-         filtered_config = {k: v for k, v in live_config.items() if is_given(v)}
-         stream = await self._client.start_stream_transcription(**filtered_config)
-
-         @utils.log_exceptions(logger=logger)
-         async def input_generator():
-             async for frame in self._input_ch:
-                 if isinstance(frame, rtc.AudioFrame):
-                     await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
-             await stream.input_stream.end_stream()
-
-         @utils.log_exceptions(logger=logger)
-         async def handle_transcript_events():
-             async for event in stream.output_stream:
-                 if isinstance(event, TranscriptEvent):
-                     self._process_transcript_event(event)
-
-         tasks = [
-             asyncio.create_task(input_generator()),
-             asyncio.create_task(handle_transcript_events()),
-         ]
-         try:
-             await asyncio.gather(*tasks)
-         finally:
-             await utils.aio.gracefully_cancel(*tasks)
-
-     def _process_transcript_event(self, transcript_event: TranscriptEvent):
+         while True:
+             config_kwargs: dict[str, Any] = {"region": self._opts.region}
+             if self._credentials:
+                 config_kwargs["aws_access_key_id"] = self._credentials.access_key_id
+                 config_kwargs["aws_secret_access_key"] = self._credentials.secret_access_key
+                 config_kwargs["aws_session_token"] = self._credentials.session_token
+             else:
+                 config_kwargs["aws_credentials_identity_resolver"] = (
+                     EnvironmentCredentialsResolver()
+                 )
+
+             client: TranscribeStreamingClient = TranscribeStreamingClient(
+                 config=Config(**config_kwargs)
+             )
+
+             live_config = {
+                 "language_code": self._opts.language,
+                 "media_sample_rate_hertz": self._opts.sample_rate,
+                 "media_encoding": self._opts.encoding,
+                 "vocabulary_name": self._opts.vocabulary_name,
+                 "session_id": self._opts.session_id,
+                 "vocab_filter_method": self._opts.vocab_filter_method,
+                 "vocab_filter_name": self._opts.vocab_filter_name,
+                 "show_speaker_label": self._opts.show_speaker_label,
+                 "enable_channel_identification": self._opts.enable_channel_identification,
+                 "number_of_channels": self._opts.number_of_channels,
+                 "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,
+                 "partial_results_stability": self._opts.partial_results_stability,
+                 "language_model_name": self._opts.language_model_name,
+             }
+             filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
+
+             try:
+                 stream = await client.start_stream_transcription(
+                     input=StartStreamTranscriptionInput(**filtered_config)
+                 )
+
+                 # Get the output stream
+                 _, output_stream = await stream.await_output()
+
+                 async def input_generator(
+                     audio_stream: EventPublisher[AudioStream],
+                 ) -> None:
+                     try:
+                         async for frame in self._input_ch:
+                             if isinstance(frame, rtc.AudioFrame):
+                                 await audio_stream.send(
+                                     AudioStreamAudioEvent(
+                                         value=AudioEvent(audio_chunk=frame.data.tobytes())
+                                     )
+                                 )
+                         # Send empty frame to close
+                         await audio_stream.send(
+                             AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
+                         )
+                     finally:
+                         with contextlib.suppress(Exception):
+                             await audio_stream.close()
+
+                 async def handle_transcript_events(
+                     output_stream: EventReceiver[TranscriptResultStream],
+                 ) -> None:
+                     try:
+                         async for event in output_stream:
+                             if isinstance(event.value, TranscriptEvent):
+                                 self._process_transcript_event(event.value)
+                     except concurrent.futures.InvalidStateError:
+                         logger.warning(
+                             "AWS Transcribe stream closed unexpectedly (InvalidStateError)"
+                         )
+
+                 tasks = [
+                     asyncio.create_task(input_generator(stream.input_stream)),
+                     asyncio.create_task(handle_transcript_events(output_stream)),
+                 ]
+                 gather_future = asyncio.gather(*tasks)
+
+                 await asyncio.shield(gather_future)
+             except BadRequestException as e:
+                 if e.message and e.message.startswith("Your request timed out"):
+                     # AWS times out after 15s of inactivity; this tends to happen
+                     # at the end of the session when the input is gone, so we
+                     # ignore it and treat it as a silent retry
+                     logger.info("restarting transcribe session")
+                     continue
+                 else:
+                     raise e
+             finally:
+                 # Close input stream first
+                 await utils.aio.gracefully_cancel(tasks[0])
+
+                 # Wait for output stream to close cleanly
+                 try:
+                     await asyncio.wait_for(tasks[1], timeout=3.0)
+                 except (asyncio.TimeoutError, asyncio.CancelledError):
+                     await utils.aio.gracefully_cancel(tasks[1])
+
+                 # Ensure gather future is retrieved to avoid "exception never retrieved"
+                 with contextlib.suppress(Exception):
+                     await gather_future
+
+     def _process_transcript_event(self, transcript_event: TranscriptEvent) -> None:
+         if not transcript_event.transcript or not transcript_event.transcript.results:
+             return
+
          stream = transcript_event.transcript.results
          for resp in stream:
-             if resp.start_time and resp.start_time == 0.0:
+             if resp.start_time is not None and resp.start_time == 0.0:
                  self._event_ch.send_nowait(
                      stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                  )

-             if resp.end_time and resp.end_time > 0.0:
+             if resp.end_time is not None and resp.end_time > 0.0:
                  if resp.is_partial:
                      self._event_ch.send_nowait(
                          stt.SpeechEvent(
                              type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                             alternatives=[_streaming_recognize_response_to_speech_data(resp)],
+                             alternatives=[self._streaming_recognize_response_to_speech_data(resp)],
                          )
                      )
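
The reworked `_run` wraps each Transcribe session in a `while True` loop: AWS tears down streams after roughly 15 seconds of inactivity, surfacing a `BadRequestException` whose message starts with "Your request timed out", and the code treats that specific failure as a cue to reconnect silently. Reduced to a sketch (`run_one_session` is a hypothetical stand-in for the session body above):

    while True:
        try:
            await run_one_session()  # hypothetical: one start_stream_transcription cycle
        except BadRequestException as e:
            if e.message and e.message.startswith("Your request timed out"):
                continue  # idle timeout: start a fresh session silently
            raise
        break  # normal completion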
@@ -183,21 +313,34 @@ class SpeechStream(stt.SpeechStream):
                  self._event_ch.send_nowait(
                      stt.SpeechEvent(
                          type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                         alternatives=[_streaming_recognize_response_to_speech_data(resp)],
+                         alternatives=[self._streaming_recognize_response_to_speech_data(resp)],
                      )
                  )

              if not resp.is_partial:
                  self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))

+     def _streaming_recognize_response_to_speech_data(self, resp: Result) -> stt.SpeechData:
+         confidence = 0.0
+         if resp.alternatives and (items := resp.alternatives[0].items):
+             confidence = items[0].confidence or 0.0

- def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
-     data = stt.SpeechData(
-         language="en-US",
-         start_time=resp.start_time if resp.start_time else 0.0,
-         end_time=resp.end_time if resp.end_time else 0.0,
-         confidence=0.0,
-         text=resp.alternatives[0].transcript if resp.alternatives else "",
-     )
-
-     return data
+         return stt.SpeechData(
+             language=resp.language_code or self._opts.language,
+             start_time=(resp.start_time or 0.0) + self.start_time_offset,
+             end_time=(resp.end_time or 0.0) + self.start_time_offset,
+             text=resp.alternatives[0].transcript if resp.alternatives else "",
+             confidence=confidence,
+             words=[
+                 TimedString(
+                     text=item.content,
+                     start_time=item.start_time + self.start_time_offset,
+                     end_time=item.end_time + self.start_time_offset,
+                     start_time_offset=self.start_time_offset,
+                     confidence=item.confidence or 0.0,
+                 )
+                 for item in resp.alternatives[0].items
+             ]
+             if resp.alternatives and resp.alternatives[0].items
+             else None,
+         )
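
`SpeechData` now carries word-level alignment: `words` is a list of `TimedString` values (or `None`), each offset by the stream's `start_time_offset`. A hedged consumption sketch, assuming `TimedString` stringifies to the word text and `stream` is an open `SpeechStream`:

    async for event in stream:
        if event.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
            speech_data = event.alternatives[0]
            for word in speech_data.words or []:
                # each word carries absolute timing alongside its text
                print(f"{word!s}: {word.start_time:.2f}s -> {word.end_time:.2f}s")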
--- a/livekit/plugins/aws/tts.py
+++ b/livekit/plugins/aws/tts.py
@@ -12,20 +12,19 @@

  from __future__ import annotations

- import asyncio
- from dataclasses import dataclass
- from typing import Any, Callable
+ from dataclasses import dataclass, replace
+ from typing import cast

- import aiohttp
- from aiobotocore.session import AioSession, get_session
+ import aioboto3  # type: ignore
+ import botocore  # type: ignore
+ import botocore.exceptions  # type: ignore
+ from aiobotocore.config import AioConfig  # type: ignore

  from livekit.agents import (
      APIConnectionError,
      APIConnectOptions,
-     APIStatusError,
      APITimeoutError,
      tts,
-     utils,
  )
  from livekit.agents.types import (
      DEFAULT_API_CONNECT_OPTIONS,
@@ -34,38 +33,38 @@ from livekit.agents.types import (
  )
  from livekit.agents.utils import is_given

- from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
- from .utils import _strip_nones, get_aws_credentials
+ from .models import TTSLanguages, TTSSpeechEngine, TTSTextType
+ from .utils import _strip_nones

- TTS_NUM_CHANNELS: int = 1
- DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
- DEFAULT_SPEECH_REGION = "us-east-1"
+ DEFAULT_SPEECH_ENGINE: TTSSpeechEngine = "generative"
  DEFAULT_VOICE = "Ruth"
- DEFAULT_SAMPLE_RATE = 16000
+ DEFAULT_TEXT_TYPE: TTSTextType = "text"


  @dataclass
  class _TTSOptions:
      # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
-     voice: NotGivenOr[str]
-     speech_engine: NotGivenOr[TTS_SPEECH_ENGINE]
-     speech_region: str
+     voice: str
+     speech_engine: TTSSpeechEngine
+     region: str | None
      sample_rate: int
-     language: NotGivenOr[TTS_LANGUAGE | str]
+     language: TTSLanguages | str | None
+     text_type: TTSTextType


  class TTS(tts.TTS):
      def __init__(
          self,
          *,
-         voice: NotGivenOr[str] = NOT_GIVEN,
-         language: NotGivenOr[TTS_LANGUAGE | str] = NOT_GIVEN,
-         speech_engine: NotGivenOr[TTS_SPEECH_ENGINE] = NOT_GIVEN,
-         sample_rate: int = DEFAULT_SAMPLE_RATE,
-         speech_region: NotGivenOr[str] = DEFAULT_SPEECH_REGION,
-         api_key: NotGivenOr[str] = NOT_GIVEN,
-         api_secret: NotGivenOr[str] = NOT_GIVEN,
-         session: AioSession | None = None,
+         voice: str = "Ruth",
+         language: NotGivenOr[TTSLanguages | str] = NOT_GIVEN,
+         speech_engine: TTSSpeechEngine = "generative",
+         text_type: TTSTextType = "text",
+         sample_rate: int = 16000,
+         region: str | None = None,
+         api_key: str | None = None,
+         api_secret: str | None = None,
+         session: aioboto3.Session | None = None,
      ) -> None:
          """
          Create a new instance of AWS Polly TTS.
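
The TTS constructor now builds (or accepts) an `aioboto3.Session` rather than resolving raw keys itself, so a preconfigured session can be shared across instances. A sketch assuming aioboto3's boto3-compatible `Session` constructor (`profile_name` is a standard boto3 session argument):

    import aioboto3
    from livekit.plugins import aws

    # Reuse one session (e.g. a named profile) for multiple TTS instances.
    session = aioboto3.Session(profile_name="prod")
    tts_engine = aws.TTS(session=session, voice="Ruth", region="us-east-1")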
@@ -76,130 +75,111 @@ class TTS(tts.TTS):
          See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on the AWS Polly TTS.

          Args:
-             Voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
-             language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
+             voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
+             language (TTSLanguages, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
+             speech_engine(TTSSpeechEngine, optional): The engine to use for the synthesis. Defaults to "generative".
+             text_type(TTSTextType, optional): Type of text to synthesize. Use "ssml" for SSML-enhanced text. Defaults to "text".
              sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
-             speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
-             speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
+             region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
              api_key(str, optional): AWS access key id.
              api_secret(str, optional): AWS secret access key.
+             session(aioboto3.Session, optional): Optional aioboto3 session to use.
          """  # noqa: E501
          super().__init__(
              capabilities=tts.TTSCapabilities(
                  streaming=False,
              ),
              sample_rate=sample_rate,
-             num_channels=TTS_NUM_CHANNELS,
+             num_channels=1,
          )
-
-         self._api_key, self._api_secret, self._speech_region = get_aws_credentials(
-             api_key, api_secret, speech_region
+         self._session = session or aioboto3.Session(
+             aws_access_key_id=api_key if is_given(api_key) else None,
+             aws_secret_access_key=api_secret if is_given(api_secret) else None,
+             region_name=region if is_given(region) else None,
          )

          self._opts = _TTSOptions(
              voice=voice,
              speech_engine=speech_engine,
-             speech_region=self._speech_region,
-             language=language,
+             text_type=text_type,
+             region=region or None,
+             language=language or None,
              sample_rate=sample_rate,
          )
-         self._session = session or get_session()
-
-     def _get_client(self):
-         return self._session.create_client(
-             "polly",
-             region_name=self._opts.speech_region,
-             aws_access_key_id=self._api_key,
-             aws_secret_access_key=self._api_secret,
-         )
+
+     @property
+     def model(self) -> str:
+         return self._opts.speech_engine
+
+     @property
+     def provider(self) -> str:
+         return "Amazon Polly"

      def synthesize(
+         self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
+     ) -> ChunkedStream:
+         return ChunkedStream(tts=self, text=text, conn_options=conn_options)
+
+     def update_options(
          self,
-         text: str,
          *,
-         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-     ) -> ChunkedStream:
-         return ChunkedStream(
-             tts=self,
-             text=text,
-             conn_options=conn_options,
-             opts=self._opts,
-             get_client=self._get_client,
-         )
+         voice: NotGivenOr[str] = NOT_GIVEN,
+         language: NotGivenOr[str] = NOT_GIVEN,
+         speech_engine: NotGivenOr[TTSSpeechEngine] = NOT_GIVEN,
+         text_type: NotGivenOr[TTSTextType] = NOT_GIVEN,
+     ) -> None:
+         if is_given(voice):
+             self._opts.voice = voice
+         if is_given(language):
+             self._opts.language = language
+         if is_given(speech_engine):
+             self._opts.speech_engine = cast(TTSSpeechEngine, speech_engine)
+         if is_given(text_type):
+             self._opts.text_type = cast(TTSTextType, text_type)


  class ChunkedStream(tts.ChunkedStream):
      def __init__(
-         self,
-         *,
-         tts: TTS,
-         text: str,
-         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-         opts: _TTSOptions,
-         get_client: Callable[[], Any],
+         self, *, tts: TTS, text: str, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
      ) -> None:
          super().__init__(tts=tts, input_text=text, conn_options=conn_options)
-         self._opts = opts
-         self._get_client = get_client
-         self._segment_id = utils.shortuuid()
-
-     async def _run(self):
-         request_id = utils.shortuuid()
+         self._tts = tts
+         self._opts = replace(tts._opts)

+     async def _run(self, output_emitter: tts.AudioEmitter) -> None:
          try:
-             async with self._get_client() as client:
-                 params = {
-                     "Text": self._input_text,
-                     "OutputFormat": "mp3",
-                     "Engine": self._opts.speech_engine
-                     if is_given(self._opts.speech_engine)
-                     else DEFAULT_SPEECH_ENGINE,
-                     "VoiceId": self._opts.voice if is_given(self._opts.voice) else DEFAULT_VOICE,
-                     "TextType": "text",
-                     "SampleRate": str(self._opts.sample_rate),
-                     "LanguageCode": self._opts.language if is_given(self._opts.language) else None,
-                 }
-                 response = await client.synthesize_speech(**_strip_nones(params))
+             config = AioConfig(
+                 connect_timeout=self._conn_options.timeout,
+                 read_timeout=10,
+                 retries={"mode": "standard", "total_max_attempts": 1},
+             )
+             async with self._tts._session.client("polly", config=config) as client:  # type: ignore
+                 response = await client.synthesize_speech(
+                     **_strip_nones(
+                         {
+                             "Text": self._input_text,
+                             "OutputFormat": "mp3",
+                             "Engine": self._opts.speech_engine,
+                             "VoiceId": self._opts.voice,
+                             "TextType": self._opts.text_type,
+                             "SampleRate": str(self._opts.sample_rate),
+                             "LanguageCode": self._opts.language,
+                         }
+                     )
+                 )
+
                  if "AudioStream" in response:
-                     decoder = utils.codecs.AudioStreamDecoder(
+                     output_emitter.initialize(
+                         request_id=response["ResponseMetadata"]["RequestId"],
                          sample_rate=self._opts.sample_rate,
                          num_channels=1,
+                         mime_type="audio/mp3",
                      )

-                     # Create a task to push data to the decoder
-                     async def push_data():
-                         try:
-                             async with response["AudioStream"] as resp:
-                                 async for data, _ in resp.content.iter_chunks():
-                                     decoder.push(data)
-                         finally:
-                             decoder.end_input()
-
-                     # Start pushing data to the decoder
-                     push_task = asyncio.create_task(push_data())
-
-                     try:
-                         # Create emitter and process decoded frames
-                         emitter = tts.SynthesizedAudioEmitter(
-                             event_ch=self._event_ch,
-                             request_id=request_id,
-                             segment_id=self._segment_id,
-                         )
-                         async for frame in decoder:
-                             emitter.push(frame)
-                         emitter.flush()
-                         await push_task
-                     finally:
-                         await utils.aio.gracefully_cancel(push_task)
-
-         except asyncio.TimeoutError as e:
-             raise APITimeoutError() from e
-         except aiohttp.ClientResponseError as e:
-             raise APIStatusError(
-                 message=e.message,
-                 status_code=e.status,
-                 request_id=request_id,
-                 body=None,
-             ) from e
+                     async with response["AudioStream"] as resp:
+                         async for data, _ in resp.content.iter_chunks():
+                             output_emitter.push(data)
+         except botocore.exceptions.ConnectTimeoutError:
+             raise APITimeoutError() from None
          except Exception as e:
              raise APIConnectionError() from e
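
Together with the new `text_type` option, `update_options` lets callers retune a live instance between syntheses. A usage sketch based on the signatures above (the SSML string is illustrative):

    # Switch engine and enable SSML input on an existing instance:
    tts_engine.update_options(speech_engine="neural", text_type="ssml")
    stream = tts_engine.synthesize("<speak>Hello<break time='300ms'/> world.</speak>")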