livekit-plugins-google 0.11.3__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/beta/realtime/__init__.py +1 -5
- livekit/plugins/google/beta/realtime/api_proto.py +2 -4
- livekit/plugins/google/beta/realtime/realtime_api.py +407 -449
- livekit/plugins/google/llm.py +158 -220
- livekit/plugins/google/stt.py +80 -115
- livekit/plugins/google/tts.py +40 -56
- livekit/plugins/google/utils.py +251 -0
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-0.11.3.dist-info → livekit_plugins_google-1.0.0.dist-info}/METADATA +11 -21
- livekit_plugins_google-1.0.0.dist-info/RECORD +16 -0
- {livekit_plugins_google-0.11.3.dist-info → livekit_plugins_google-1.0.0.dist-info}/WHEEL +1 -2
- livekit/plugins/google/_utils.py +0 -199
- livekit/plugins/google/beta/realtime/transcriber.py +0 -270
- livekit_plugins_google-0.11.3.dist-info/RECORD +0 -18
- livekit_plugins_google-0.11.3.dist-info/top_level.txt +0 -1
livekit/plugins/google/stt.py
CHANGED
@@ -19,8 +19,14 @@ import dataclasses
|
|
19
19
|
import time
|
20
20
|
import weakref
|
21
21
|
from dataclasses import dataclass
|
22
|
-
from typing import Callable,
|
22
|
+
from typing import Callable, Union
|
23
23
|
|
24
|
+
from google.api_core.client_options import ClientOptions
|
25
|
+
from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
|
26
|
+
from google.auth import default as gauth_default
|
27
|
+
from google.auth.exceptions import DefaultCredentialsError
|
28
|
+
from google.cloud.speech_v2 import SpeechAsyncClient
|
29
|
+
from google.cloud.speech_v2.types import cloud_speech
|
24
30
|
from livekit import rtc
|
25
31
|
from livekit.agents import (
|
26
32
|
DEFAULT_API_CONNECT_OPTIONS,
|
@@ -31,19 +37,17 @@ from livekit.agents import (
|
|
31
37
|
stt,
|
32
38
|
utils,
|
33
39
|
)
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
from
|
39
|
-
from google.cloud.speech_v2 import SpeechAsyncClient
|
40
|
-
from google.cloud.speech_v2.types import cloud_speech
|
40
|
+
from livekit.agents.types import (
|
41
|
+
NOT_GIVEN,
|
42
|
+
NotGivenOr,
|
43
|
+
)
|
44
|
+
from livekit.agents.utils import is_given
|
41
45
|
|
42
46
|
from .log import logger
|
43
47
|
from .models import SpeechLanguages, SpeechModels
|
44
48
|
|
45
49
|
LgType = Union[SpeechLanguages, str]
|
46
|
-
LanguageCode = Union[LgType,
|
50
|
+
LanguageCode = Union[LgType, list[LgType]]
|
47
51
|
|
48
52
|
# Google STT has a timeout of 5 mins, we'll attempt to restart the session
|
49
53
|
# before that timeout is reached
|
@@ -56,25 +60,23 @@ _min_confidence = 0.65
|
|
56
60
|
# This class is only used internally to encapsulate the options
|
57
61
|
@dataclass
|
58
62
|
class STTOptions:
|
59
|
-
languages:
|
63
|
+
languages: list[LgType]
|
60
64
|
detect_language: bool
|
61
65
|
interim_results: bool
|
62
66
|
punctuate: bool
|
63
67
|
spoken_punctuation: bool
|
64
68
|
model: SpeechModels | str
|
65
69
|
sample_rate: int
|
66
|
-
keywords:
|
70
|
+
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN
|
67
71
|
|
68
72
|
def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
|
69
|
-
if self.keywords:
|
73
|
+
if is_given(self.keywords):
|
70
74
|
return cloud_speech.SpeechAdaptation(
|
71
75
|
phrase_sets=[
|
72
76
|
cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
|
73
77
|
inline_phrase_set=cloud_speech.PhraseSet(
|
74
78
|
phrases=[
|
75
|
-
cloud_speech.PhraseSet.Phrase(
|
76
|
-
value=keyword, boost=boost
|
77
|
-
)
|
79
|
+
cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
|
78
80
|
for keyword, boost in self.keywords
|
79
81
|
]
|
80
82
|
)
|
@@ -96,9 +98,9 @@ class STT(stt.STT):
|
|
96
98
|
model: SpeechModels | str = "latest_long",
|
97
99
|
location: str = "global",
|
98
100
|
sample_rate: int = 16000,
|
99
|
-
credentials_info: dict
|
100
|
-
credentials_file: str
|
101
|
-
keywords:
|
101
|
+
credentials_info: NotGivenOr[dict] = NOT_GIVEN,
|
102
|
+
credentials_file: NotGivenOr[str] = NOT_GIVEN,
|
103
|
+
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
|
102
104
|
):
|
103
105
|
"""
|
104
106
|
Create a new instance of Google STT.
|
@@ -120,15 +122,13 @@ class STT(stt.STT):
|
|
120
122
|
credentials_file(str): the credentials file to use for recognition (default: None)
|
121
123
|
keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
|
122
124
|
"""
|
123
|
-
super().__init__(
|
124
|
-
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
|
125
|
-
)
|
125
|
+
super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
|
126
126
|
|
127
127
|
self._location = location
|
128
128
|
self._credentials_info = credentials_info
|
129
129
|
self._credentials_file = credentials_file
|
130
130
|
|
131
|
-
if credentials_file
|
131
|
+
if not is_given(credentials_file) and not is_given(credentials_info):
|
132
132
|
try:
|
133
133
|
gauth_default()
|
134
134
|
except DefaultCredentialsError:
|
@@ -136,7 +136,7 @@ class STT(stt.STT):
|
|
136
136
|
"Application default credentials must be available "
|
137
137
|
"when using Google STT without explicitly passing "
|
138
138
|
"credentials through credentials_info or credentials_file."
|
139
|
-
)
|
139
|
+
) from None
|
140
140
|
|
141
141
|
if isinstance(languages, str):
|
142
142
|
languages = [languages]
|
@@ -163,23 +163,17 @@ class STT(stt.STT):
|
|
163
163
|
client_options = None
|
164
164
|
client: SpeechAsyncClient | None = None
|
165
165
|
if self._location != "global":
|
166
|
-
client_options = ClientOptions(
|
167
|
-
|
168
|
-
)
|
169
|
-
if self._credentials_info:
|
166
|
+
client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
|
167
|
+
if is_given(self._credentials_info):
|
170
168
|
client = SpeechAsyncClient.from_service_account_info(
|
171
|
-
self._credentials_info,
|
172
|
-
client_options=client_options,
|
169
|
+
self._credentials_info, client_options=client_options
|
173
170
|
)
|
174
|
-
elif self._credentials_file:
|
171
|
+
elif is_given(self._credentials_file):
|
175
172
|
client = SpeechAsyncClient.from_service_account_file(
|
176
|
-
self._credentials_file,
|
177
|
-
client_options=client_options,
|
173
|
+
self._credentials_file, client_options=client_options
|
178
174
|
)
|
179
175
|
else:
|
180
|
-
client = SpeechAsyncClient(
|
181
|
-
client_options=client_options,
|
182
|
-
)
|
176
|
+
client = SpeechAsyncClient(client_options=client_options)
|
183
177
|
assert client is not None
|
184
178
|
return client
|
185
179
|
|
@@ -196,19 +190,17 @@ class STT(stt.STT):
|
|
196
190
|
_, project_id = ga_default()
|
197
191
|
return f"projects/{project_id}/locations/{self._location}/recognizers/_"
|
198
192
|
|
199
|
-
def _sanitize_options(self, *, language: str
|
193
|
+
def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
|
200
194
|
config = dataclasses.replace(self._config)
|
201
195
|
|
202
|
-
if language:
|
196
|
+
if is_given(language):
|
203
197
|
config.languages = [language]
|
204
198
|
|
205
199
|
if not isinstance(config.languages, list):
|
206
200
|
config.languages = [config.languages]
|
207
201
|
elif not config.detect_language:
|
208
202
|
if len(config.languages) > 1:
|
209
|
-
logger.warning(
|
210
|
-
"multiple languages provided, but language detection is disabled"
|
211
|
-
)
|
203
|
+
logger.warning("multiple languages provided, but language detection is disabled")
|
212
204
|
config.languages = [config.languages[0]]
|
213
205
|
|
214
206
|
return config
|
@@ -217,7 +209,7 @@ class STT(stt.STT):
|
|
217
209
|
self,
|
218
210
|
buffer: utils.AudioBuffer,
|
219
211
|
*,
|
220
|
-
language: SpeechLanguages | str
|
212
|
+
language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
|
221
213
|
conn_options: APIConnectOptions,
|
222
214
|
) -> stt.SpeechEvent:
|
223
215
|
config = self._sanitize_options(language=language)
|
@@ -252,21 +244,18 @@ class STT(stt.STT):
|
|
252
244
|
|
253
245
|
return _recognize_response_to_speech_event(raw)
|
254
246
|
except DeadlineExceeded:
|
255
|
-
raise APITimeoutError()
|
247
|
+
raise APITimeoutError() from None
|
256
248
|
except GoogleAPICallError as e:
|
257
|
-
raise APIStatusError(
|
258
|
-
e.message,
|
259
|
-
status_code=e.code or -1,
|
260
|
-
)
|
249
|
+
raise APIStatusError(e.message, status_code=e.code or -1) from None
|
261
250
|
except Exception as e:
|
262
251
|
raise APIConnectionError() from e
|
263
252
|
|
264
253
|
def stream(
|
265
254
|
self,
|
266
255
|
*,
|
267
|
-
language: SpeechLanguages | str
|
256
|
+
language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
|
268
257
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
269
|
-
) ->
|
258
|
+
) -> SpeechStream:
|
270
259
|
config = self._sanitize_options(language=language)
|
271
260
|
stream = SpeechStream(
|
272
261
|
stt=self,
|
@@ -281,34 +270,34 @@ class STT(stt.STT):
|
|
281
270
|
def update_options(
|
282
271
|
self,
|
283
272
|
*,
|
284
|
-
languages: LanguageCode
|
285
|
-
detect_language: bool
|
286
|
-
interim_results: bool
|
287
|
-
punctuate: bool
|
288
|
-
spoken_punctuation: bool
|
289
|
-
model: SpeechModels
|
290
|
-
location: str
|
291
|
-
keywords:
|
273
|
+
languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
|
274
|
+
detect_language: NotGivenOr[bool] = NOT_GIVEN,
|
275
|
+
interim_results: NotGivenOr[bool] = NOT_GIVEN,
|
276
|
+
punctuate: NotGivenOr[bool] = NOT_GIVEN,
|
277
|
+
spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
|
278
|
+
model: NotGivenOr[SpeechModels] = NOT_GIVEN,
|
279
|
+
location: NotGivenOr[str] = NOT_GIVEN,
|
280
|
+
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
|
292
281
|
):
|
293
|
-
if languages
|
282
|
+
if is_given(languages):
|
294
283
|
if isinstance(languages, str):
|
295
284
|
languages = [languages]
|
296
285
|
self._config.languages = languages
|
297
|
-
if detect_language
|
286
|
+
if is_given(detect_language):
|
298
287
|
self._config.detect_language = detect_language
|
299
|
-
if interim_results
|
288
|
+
if is_given(interim_results):
|
300
289
|
self._config.interim_results = interim_results
|
301
|
-
if punctuate
|
290
|
+
if is_given(punctuate):
|
302
291
|
self._config.punctuate = punctuate
|
303
|
-
if spoken_punctuation
|
292
|
+
if is_given(spoken_punctuation):
|
304
293
|
self._config.spoken_punctuation = spoken_punctuation
|
305
|
-
if model
|
294
|
+
if is_given(model):
|
306
295
|
self._config.model = model
|
307
|
-
if location
|
296
|
+
if is_given(location):
|
308
297
|
self._location = location
|
309
298
|
# if location is changed, fetch a new client and recognizer as per the new location
|
310
299
|
self._pool.invalidate()
|
311
|
-
if keywords
|
300
|
+
if is_given(keywords):
|
312
301
|
self._config.keywords = keywords
|
313
302
|
|
314
303
|
for stream in self._streams:
|
@@ -337,9 +326,7 @@ class SpeechStream(stt.SpeechStream):
|
|
337
326
|
recognizer_cb: Callable[[SpeechAsyncClient], str],
|
338
327
|
config: STTOptions,
|
339
328
|
) -> None:
|
340
|
-
super().__init__(
|
341
|
-
stt=stt, conn_options=conn_options, sample_rate=config.sample_rate
|
342
|
-
)
|
329
|
+
super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)
|
343
330
|
|
344
331
|
self._pool = pool
|
345
332
|
self._recognizer_cb = recognizer_cb
|
@@ -350,29 +337,29 @@ class SpeechStream(stt.SpeechStream):
|
|
350
337
|
def update_options(
|
351
338
|
self,
|
352
339
|
*,
|
353
|
-
languages: LanguageCode
|
354
|
-
detect_language: bool
|
355
|
-
interim_results: bool
|
356
|
-
punctuate: bool
|
357
|
-
spoken_punctuation: bool
|
358
|
-
model: SpeechModels
|
359
|
-
keywords:
|
340
|
+
languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
|
341
|
+
detect_language: NotGivenOr[bool] = NOT_GIVEN,
|
342
|
+
interim_results: NotGivenOr[bool] = NOT_GIVEN,
|
343
|
+
punctuate: NotGivenOr[bool] = NOT_GIVEN,
|
344
|
+
spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
|
345
|
+
model: NotGivenOr[SpeechModels] = NOT_GIVEN,
|
346
|
+
keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
|
360
347
|
):
|
361
|
-
if languages
|
348
|
+
if is_given(languages):
|
362
349
|
if isinstance(languages, str):
|
363
350
|
languages = [languages]
|
364
351
|
self._config.languages = languages
|
365
|
-
if detect_language
|
352
|
+
if is_given(detect_language):
|
366
353
|
self._config.detect_language = detect_language
|
367
|
-
if interim_results
|
354
|
+
if is_given(interim_results):
|
368
355
|
self._config.interim_results = interim_results
|
369
|
-
if punctuate
|
356
|
+
if is_given(punctuate):
|
370
357
|
self._config.punctuate = punctuate
|
371
|
-
if spoken_punctuation
|
358
|
+
if is_given(spoken_punctuation):
|
372
359
|
self._config.spoken_punctuation = spoken_punctuation
|
373
|
-
if model
|
360
|
+
if is_given(model):
|
374
361
|
self._config.model = model
|
375
|
-
if keywords
|
362
|
+
if is_given(keywords):
|
376
363
|
self._config.keywords = keywords
|
377
364
|
|
378
365
|
self._reconnect_event.set()
|
@@ -380,9 +367,7 @@ class SpeechStream(stt.SpeechStream):
|
|
380
367
|
async def _run(self) -> None:
|
381
368
|
# Google requires an async generator when calling streaming_recognize
|
382
369
|
# this function basically converts the queue into an async generator
|
383
|
-
async def input_generator(
|
384
|
-
client: SpeechAsyncClient, should_stop: asyncio.Event
|
385
|
-
):
|
370
|
+
async def input_generator(client: SpeechAsyncClient, should_stop: asyncio.Event):
|
386
371
|
try:
|
387
372
|
# first request should contain the config
|
388
373
|
yield cloud_speech.StreamingRecognizeRequest(
|
@@ -398,14 +383,10 @@ class SpeechStream(stt.SpeechStream):
|
|
398
383
|
return
|
399
384
|
|
400
385
|
if isinstance(frame, rtc.AudioFrame):
|
401
|
-
yield cloud_speech.StreamingRecognizeRequest(
|
402
|
-
audio=frame.data.tobytes()
|
403
|
-
)
|
386
|
+
yield cloud_speech.StreamingRecognizeRequest(audio=frame.data.tobytes())
|
404
387
|
|
405
388
|
except Exception:
|
406
|
-
logger.exception(
|
407
|
-
"an error occurred while streaming input to google STT"
|
408
|
-
)
|
389
|
+
logger.exception("an error occurred while streaming input to google STT")
|
409
390
|
|
410
391
|
async def process_stream(client: SpeechAsyncClient, stream):
|
411
392
|
has_started = False
|
@@ -421,7 +402,7 @@ class SpeechStream(stt.SpeechStream):
|
|
421
402
|
|
422
403
|
if (
|
423
404
|
resp.speech_event_type
|
424
|
-
== cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
|
405
|
+
== cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED # noqa: E501
|
425
406
|
):
|
426
407
|
result = resp.results[0]
|
427
408
|
speech_data = _streaming_recognize_response_to_speech_data(resp)
|
@@ -442,19 +423,14 @@ class SpeechStream(stt.SpeechStream):
|
|
442
423
|
alternatives=[speech_data],
|
443
424
|
)
|
444
425
|
)
|
445
|
-
if (
|
446
|
-
time.time() - self._session_connected_at
|
447
|
-
> _max_session_duration
|
448
|
-
):
|
426
|
+
if time.time() - self._session_connected_at > _max_session_duration:
|
449
427
|
logger.debug(
|
450
428
|
"Google STT maximum connection time reached. Reconnecting..."
|
451
429
|
)
|
452
430
|
self._pool.remove(client)
|
453
431
|
if has_started:
|
454
432
|
self._event_ch.send_nowait(
|
455
|
-
stt.SpeechEvent(
|
456
|
-
type=stt.SpeechEventType.END_OF_SPEECH
|
457
|
-
)
|
433
|
+
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
|
458
434
|
)
|
459
435
|
has_started = False
|
460
436
|
self._reconnect_event.set()
|
@@ -498,12 +474,8 @@ class SpeechStream(stt.SpeechStream):
|
|
498
474
|
)
|
499
475
|
self._session_connected_at = time.time()
|
500
476
|
|
501
|
-
process_stream_task = asyncio.create_task(
|
502
|
-
|
503
|
-
)
|
504
|
-
wait_reconnect_task = asyncio.create_task(
|
505
|
-
self._reconnect_event.wait()
|
506
|
-
)
|
477
|
+
process_stream_task = asyncio.create_task(process_stream(client, stream))
|
478
|
+
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
|
507
479
|
|
508
480
|
try:
|
509
481
|
done, _ = await asyncio.wait(
|
@@ -517,17 +489,12 @@ class SpeechStream(stt.SpeechStream):
|
|
517
489
|
break
|
518
490
|
self._reconnect_event.clear()
|
519
491
|
finally:
|
520
|
-
await utils.aio.gracefully_cancel(
|
521
|
-
process_stream_task, wait_reconnect_task
|
522
|
-
)
|
492
|
+
await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
|
523
493
|
should_stop.set()
|
524
494
|
except DeadlineExceeded:
|
525
|
-
raise APITimeoutError()
|
495
|
+
raise APITimeoutError() from None
|
526
496
|
except GoogleAPICallError as e:
|
527
|
-
raise APIStatusError(
|
528
|
-
e.message,
|
529
|
-
status_code=e.code or -1,
|
530
|
-
)
|
497
|
+
raise APIStatusError(e.message, status_code=e.code or -1) from None
|
531
498
|
except Exception as e:
|
532
499
|
raise APIConnectionError() from e
|
533
500
|
|
@@ -580,8 +547,6 @@ def _streaming_recognize_response_to_speech_data(
|
|
580
547
|
if text == "":
|
581
548
|
return None
|
582
549
|
|
583
|
-
data = stt.SpeechData(
|
584
|
-
language=lg, start_time=0, end_time=0, confidence=confidence, text=text
|
585
|
-
)
|
550
|
+
data = stt.SpeechData(language=lg, start_time=0, end_time=0, confidence=confidence, text=text)
|
586
551
|
|
587
552
|
return data
|
livekit/plugins/google/tts.py
CHANGED
@@ -15,8 +15,11 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
from dataclasses import dataclass
|
18
|
-
from typing import Optional
|
19
18
|
|
19
|
+
from google.api_core.client_options import ClientOptions
|
20
|
+
from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
|
21
|
+
from google.cloud import texttospeech
|
22
|
+
from google.cloud.texttospeech_v1.types import SsmlVoiceGender, SynthesizeSpeechResponse
|
20
23
|
from livekit.agents import (
|
21
24
|
APIConnectionError,
|
22
25
|
APIConnectOptions,
|
@@ -25,13 +28,12 @@ from livekit.agents import (
|
|
25
28
|
tts,
|
26
29
|
utils,
|
27
30
|
)
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
from .models import Gender, SpeechLanguages
|
31
|
+
from livekit.agents.types import (
|
32
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
33
|
+
NOT_GIVEN,
|
34
|
+
NotGivenOr,
|
35
|
+
)
|
36
|
+
from livekit.agents.utils import is_given
|
35
37
|
|
36
38
|
|
37
39
|
@dataclass
|
@@ -44,16 +46,14 @@ class TTS(tts.TTS):
|
|
44
46
|
def __init__(
|
45
47
|
self,
|
46
48
|
*,
|
47
|
-
|
48
|
-
gender: Gender | str = "neutral",
|
49
|
-
voice_name: str = "", # Not required
|
49
|
+
voice: NotGivenOr[texttospeech.VoiceSelectionParams] = NOT_GIVEN,
|
50
50
|
sample_rate: int = 24000,
|
51
51
|
pitch: int = 0,
|
52
52
|
effects_profile_id: str = "",
|
53
53
|
speaking_rate: float = 1.0,
|
54
54
|
location: str = "global",
|
55
|
-
credentials_info: dict
|
56
|
-
credentials_file: str
|
55
|
+
credentials_info: NotGivenOr[dict] = NOT_GIVEN,
|
56
|
+
credentials_file: NotGivenOr[str] = NOT_GIVEN,
|
57
57
|
) -> None:
|
58
58
|
"""
|
59
59
|
Create a new instance of Google TTS.
|
@@ -63,9 +63,7 @@ class TTS(tts.TTS):
|
|
63
63
|
environmental variable.
|
64
64
|
|
65
65
|
Args:
|
66
|
-
|
67
|
-
gender (Gender | str, optional): Voice gender ("male", "female", "neutral"). Default is "neutral".
|
68
|
-
voice_name (str, optional): Specific voice name. Default is an empty string.
|
66
|
+
voice (texttospeech.VoiceSelectionParams, optional): Voice selection parameters.
|
69
67
|
sample_rate (int, optional): Audio sample rate in Hz. Default is 24000.
|
70
68
|
location (str, optional): Location for the TTS client. Default is "global".
|
71
69
|
pitch (float, optional): Speaking pitch, ranging from -20.0 to 20.0 semitones relative to the original pitch. Default is 0.
|
@@ -73,7 +71,7 @@ class TTS(tts.TTS):
|
|
73
71
|
speaking_rate (float, optional): Speed of speech. Default is 1.0.
|
74
72
|
credentials_info (dict, optional): Dictionary containing Google Cloud credentials. Default is None.
|
75
73
|
credentials_file (str, optional): Path to the Google Cloud credentials JSON file. Default is None.
|
76
|
-
"""
|
74
|
+
""" # noqa: E501
|
77
75
|
|
78
76
|
super().__init__(
|
79
77
|
capabilities=tts.TTSCapabilities(
|
@@ -87,11 +85,12 @@ class TTS(tts.TTS):
|
|
87
85
|
self._credentials_info = credentials_info
|
88
86
|
self._credentials_file = credentials_file
|
89
87
|
self._location = location
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
88
|
+
if not is_given(voice):
|
89
|
+
voice = texttospeech.VoiceSelectionParams(
|
90
|
+
name="",
|
91
|
+
language_code="en-US",
|
92
|
+
ssml_gender=SsmlVoiceGender.NEUTRAL,
|
93
|
+
)
|
95
94
|
|
96
95
|
self._opts = _TTSOptions(
|
97
96
|
voice=voice,
|
@@ -107,26 +106,20 @@ class TTS(tts.TTS):
|
|
107
106
|
def update_options(
|
108
107
|
self,
|
109
108
|
*,
|
110
|
-
|
111
|
-
|
112
|
-
voice_name: str = "", # Not required
|
113
|
-
speaking_rate: float = 1.0,
|
109
|
+
voice: NotGivenOr[texttospeech.VoiceSelectionParams] = NOT_GIVEN,
|
110
|
+
speaking_rate: NotGivenOr[float] = NOT_GIVEN,
|
114
111
|
) -> None:
|
115
112
|
"""
|
116
113
|
Update the TTS options.
|
117
114
|
|
118
115
|
Args:
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
language_code=language,
|
127
|
-
ssml_gender=_gender_from_str(gender),
|
128
|
-
)
|
129
|
-
self._opts.audio_config.speaking_rate = speaking_rate
|
116
|
+
voice (texttospeech.VoiceSelectionParams, optional): Voice selection parameters.
|
117
|
+
speaking_rate (float, optional): Speed of speech.
|
118
|
+
""" # noqa: E501
|
119
|
+
if is_given(voice):
|
120
|
+
self._opts.voice = voice
|
121
|
+
if is_given(speaking_rate):
|
122
|
+
self._opts.audio_config.speaking_rate = speaking_rate
|
130
123
|
|
131
124
|
def _ensure_client(self) -> texttospeech.TextToSpeechAsyncClient:
|
132
125
|
api_endpoint = "texttospeech.googleapis.com"
|
@@ -135,19 +128,13 @@ class TTS(tts.TTS):
|
|
135
128
|
|
136
129
|
if self._client is None:
|
137
130
|
if self._credentials_info:
|
138
|
-
self._client = (
|
139
|
-
|
140
|
-
self._credentials_info,
|
141
|
-
client_options=ClientOptions(api_endpoint=api_endpoint),
|
142
|
-
)
|
131
|
+
self._client = texttospeech.TextToSpeechAsyncClient.from_service_account_info(
|
132
|
+
self._credentials_info, client_options=ClientOptions(api_endpoint=api_endpoint)
|
143
133
|
)
|
144
134
|
|
145
135
|
elif self._credentials_file:
|
146
|
-
self._client = (
|
147
|
-
|
148
|
-
self._credentials_file,
|
149
|
-
client_options=ClientOptions(api_endpoint=api_endpoint),
|
150
|
-
)
|
136
|
+
self._client = texttospeech.TextToSpeechAsyncClient.from_service_account_file(
|
137
|
+
self._credentials_file, client_options=ClientOptions(api_endpoint=api_endpoint)
|
151
138
|
)
|
152
139
|
else:
|
153
140
|
self._client = texttospeech.TextToSpeechAsyncClient(
|
@@ -161,8 +148,8 @@ class TTS(tts.TTS):
|
|
161
148
|
self,
|
162
149
|
text: str,
|
163
150
|
*,
|
164
|
-
conn_options:
|
165
|
-
) ->
|
151
|
+
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
152
|
+
) -> ChunkedStream:
|
166
153
|
return ChunkedStream(
|
167
154
|
tts=self,
|
168
155
|
input_text=text,
|
@@ -180,7 +167,7 @@ class ChunkedStream(tts.ChunkedStream):
|
|
180
167
|
input_text: str,
|
181
168
|
opts: _TTSOptions,
|
182
169
|
client: texttospeech.TextToSpeechAsyncClient,
|
183
|
-
conn_options:
|
170
|
+
conn_options: APIConnectOptions,
|
184
171
|
) -> None:
|
185
172
|
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
|
186
173
|
self._opts, self._client = opts, client
|
@@ -216,14 +203,11 @@ class ChunkedStream(tts.ChunkedStream):
|
|
216
203
|
await decoder.aclose()
|
217
204
|
|
218
205
|
except DeadlineExceeded:
|
219
|
-
raise APITimeoutError()
|
206
|
+
raise APITimeoutError() from None
|
220
207
|
except GoogleAPICallError as e:
|
221
208
|
raise APIStatusError(
|
222
|
-
e.message,
|
223
|
-
|
224
|
-
request_id=None,
|
225
|
-
body=None,
|
226
|
-
)
|
209
|
+
e.message, status_code=e.code or -1, request_id=None, body=None
|
210
|
+
) from None
|
227
211
|
except Exception as e:
|
228
212
|
raise APIConnectionError() from e
|
229
213
|
|