livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +33 -7
- livekit/plugins/google/beta/__init__.py +13 -0
- livekit/plugins/google/beta/gemini_tts.py +258 -0
- livekit/plugins/google/llm.py +562 -0
- livekit/plugins/google/log.py +3 -0
- livekit/plugins/google/models.py +160 -32
- livekit/plugins/google/realtime/__init__.py +9 -0
- livekit/plugins/google/realtime/api_proto.py +68 -0
- livekit/plugins/google/realtime/realtime_api.py +1249 -0
- livekit/plugins/google/stt.py +717 -283
- livekit/plugins/google/tools.py +71 -0
- livekit/plugins/google/tts.py +455 -0
- livekit/plugins/google/utils.py +220 -0
- livekit/plugins/google/version.py +1 -1
- livekit_plugins_google-1.3.11.dist-info/METADATA +63 -0
- livekit_plugins_google-1.3.11.dist-info/RECORD +18 -0
- {livekit_plugins_google-0.3.0.dist-info → livekit_plugins_google-1.3.11.dist-info}/WHEEL +1 -2
- livekit_plugins_google-0.3.0.dist-info/METADATA +0 -47
- livekit_plugins_google-0.3.0.dist-info/RECORD +0 -9
- livekit_plugins_google-0.3.0.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from google.genai import types
|
|
6
|
+
from livekit.agents import llm
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GeminiTool(llm.ProviderTool, ABC):
    """Base class for Gemini provider-side tools.

    Each subclass knows how to render its own configuration as a
    ``google.genai`` ``types.Tool`` via :meth:`to_tool_config`.
    """

    @abstractmethod
    def to_tool_config(self) -> types.Tool:
        """Return the ``types.Tool`` configuration for this tool."""
        ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class GoogleSearch(GeminiTool):
    """Gemini Google Search grounding tool.

    Fields mirror the corresponding ``types.GoogleSearch`` options and are
    forwarded unchanged.
    """

    exclude_domains: Optional[list[str]] = None
    blocking_confidence: Optional[types.PhishBlockThreshold] = None
    time_range_filter: Optional[types.Interval] = None

    def to_tool_config(self) -> types.Tool:
        """Build the ``types.Tool`` wrapping this search configuration."""
        search_cfg = types.GoogleSearch(
            exclude_domains=self.exclude_domains,
            blocking_confidence=self.blocking_confidence,
            time_range_filter=self.time_range_filter,
        )
        return types.Tool(google_search=search_cfg)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class GoogleMaps(GeminiTool):
    """Gemini Google Maps tool.

    Fields mirror the corresponding ``types.GoogleMaps`` options and are
    forwarded unchanged.
    """

    auth_config: Optional[types.AuthConfig] = None
    enable_widget: Optional[bool] = None

    def to_tool_config(self) -> types.Tool:
        """Build the ``types.Tool`` wrapping this Maps configuration."""
        maps_cfg = types.GoogleMaps(
            auth_config=self.auth_config,
            enable_widget=self.enable_widget,
        )
        return types.Tool(google_maps=maps_cfg)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class URLContext(GeminiTool):
    """Gemini URL-context tool; carries no configuration of its own."""

    def to_tool_config(self) -> types.Tool:
        """Build the ``types.Tool`` enabling URL context."""
        return types.Tool(url_context=types.UrlContext())
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class FileSearch(GeminiTool):
    """Gemini file-search tool.

    Fields mirror the corresponding ``types.FileSearch`` options;
    ``file_search_store_names`` is required, the rest are optional.
    """

    file_search_store_names: list[str]
    top_k: Optional[int] = None
    metadata_filter: Optional[str] = None

    def to_tool_config(self) -> types.Tool:
        """Build the ``types.Tool`` wrapping this file-search configuration."""
        file_search_cfg = types.FileSearch(
            file_search_store_names=self.file_search_store_names,
            top_k=self.top_k,
            metadata_filter=self.metadata_filter,
        )
        return types.Tool(file_search=file_search_cfg)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class ToolCodeExecution(GeminiTool):
    """Gemini code-execution tool; carries no configuration of its own."""

    def to_tool_config(self) -> types.Tool:
        """Build the ``types.Tool`` enabling code execution."""
        return types.Tool(code_execution=types.ToolCodeExecution())
|
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import weakref
|
|
19
|
+
from collections.abc import AsyncGenerator
|
|
20
|
+
from dataclasses import dataclass, replace
|
|
21
|
+
|
|
22
|
+
from google.api_core.client_options import ClientOptions
|
|
23
|
+
from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
|
|
24
|
+
from google.cloud import texttospeech
|
|
25
|
+
from google.cloud.texttospeech_v1.types import (
|
|
26
|
+
CustomPronunciations,
|
|
27
|
+
SsmlVoiceGender,
|
|
28
|
+
SynthesizeSpeechResponse,
|
|
29
|
+
)
|
|
30
|
+
from livekit.agents import APIConnectOptions, APIStatusError, APITimeoutError, tokenize, tts, utils
|
|
31
|
+
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
|
|
32
|
+
from livekit.agents.utils import is_given
|
|
33
|
+
|
|
34
|
+
from .log import logger
|
|
35
|
+
from .models import GeminiTTSModels, Gender, SpeechLanguages
|
|
36
|
+
|
|
37
|
+
# Google TTS output is mono.
# NOTE(review): call sites below pass the literal ``1`` for num_channels
# rather than this constant — consider unifying.
NUM_CHANNELS = 1
# Fallbacks used when the caller does not specify language / gender.
DEFAULT_LANGUAGE = "en-US"
DEFAULT_GENDER = "neutral"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class _TTSOptions:
    """Resolved synthesis settings, built once in ``TTS.__init__`` and
    snapshotted (via ``dataclasses.replace``) by each stream."""

    voice: texttospeech.VoiceSelectionParams
    encoding: texttospeech.AudioEncoding
    sample_rate: int
    pitch: float
    effects_profile_id: str
    speaking_rate: float
    tokenizer: tokenize.SentenceTokenizer
    volume_gain_db: float
    # None when the caller supplied no custom pronunciations.
    custom_pronunciations: CustomPronunciations | None
    # SSML input mode; only valid for non-streaming synthesis.
    enable_ssml: bool
    # Markup input mode; mutually exclusive with enable_ssml.
    use_markup: bool
    # TTS model identifier; for "chirp_3" it is kept here but must not be
    # set on voice (see TTS.__init__).
    model_name: str | None
    # Style prompt for Gemini TTS models; None when not given.
    prompt: str | None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class TTS(tts.TTS):
    """Google Cloud Text-to-Speech adapter supporting both chunked and
    streaming synthesis."""

    def __init__(
        self,
        *,
        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
        gender: NotGivenOr[Gender | str] = NOT_GIVEN,
        voice_name: NotGivenOr[str] = NOT_GIVEN,
        voice_cloning_key: NotGivenOr[str] = NOT_GIVEN,
        model_name: GeminiTTSModels | str = "gemini-2.5-flash-tts",
        prompt: NotGivenOr[str] = NOT_GIVEN,
        sample_rate: int = 24000,
        pitch: int = 0,
        effects_profile_id: str = "",
        speaking_rate: float = 1.0,
        volume_gain_db: float = 0.0,
        location: str = "global",
        audio_encoding: texttospeech.AudioEncoding = texttospeech.AudioEncoding.PCM,  # type: ignore
        credentials_info: NotGivenOr[dict] = NOT_GIVEN,
        credentials_file: NotGivenOr[str] = NOT_GIVEN,
        tokenizer: NotGivenOr[tokenize.SentenceTokenizer] = NOT_GIVEN,
        custom_pronunciations: NotGivenOr[CustomPronunciations] = NOT_GIVEN,
        use_streaming: bool = True,
        enable_ssml: bool = False,
        use_markup: bool = False,
    ) -> None:
        """
        Create a new instance of Google TTS.

        Credentials must be provided, either by using the ``credentials_info`` dict, or reading
        from the file specified in ``credentials_file`` or the ``GOOGLE_APPLICATION_CREDENTIALS``
        environmental variable.

        Args:
            language (SpeechLanguages | str, optional): Language code (e.g., "en-US"). Default is "en-US".
            gender (Gender | str, optional): Voice gender ("male", "female", "neutral"). Default is "neutral".
            voice_name (str, optional): Specific voice name. Default is an empty string. See https://docs.cloud.google.com/text-to-speech/docs/gemini-tts#voice_options for supported voice in Gemini TTS models.
            voice_cloning_key (str, optional): Voice clone key. Created via https://cloud.google.com/text-to-speech/docs/chirp3-instant-custom-voice
            model_name (GeminiTTSModels | str, optional): Model name for TTS (e.g., "gemini-2.5-flash-tts", "chirp_3"). Default is "gemini-2.5-flash-tts".
            prompt (str, optional): Style prompt for Gemini TTS models. Controls tone, style, and speaking characteristics. Only applied to first input chunk in streaming mode.
            sample_rate (int, optional): Audio sample rate in Hz. Default is 24000.
            location (str, optional): Location for the TTS client. Default is "global".
            pitch (float, optional): Speaking pitch, ranging from -20.0 to 20.0 semitones relative to the original pitch. Default is 0.
            effects_profile_id (str): Optional identifier for selecting audio effects profiles to apply to the synthesized speech.
            speaking_rate (float, optional): Speed of speech. Default is 1.0.
            volume_gain_db (float, optional): Volume gain in decibels. Default is 0.0. In the range [-96.0, 16.0]. Strongly recommended not to exceed +10 (dB).
            credentials_info (dict, optional): Dictionary containing Google Cloud credentials. Default is None.
            credentials_file (str, optional): Path to the Google Cloud credentials JSON file. Default is None.
            tokenizer (tokenize.SentenceTokenizer, optional): Tokenizer for the TTS. Defaults to `livekit.agents.tokenize.blingfire.SentenceTokenizer`.
            custom_pronunciations (CustomPronunciations, optional): Custom pronunciations for the TTS. Default is None.
            use_streaming (bool, optional): Whether to use streaming synthesis. Default is True.
            enable_ssml (bool, optional): Whether to enable SSML support. Default is False.
            use_markup (bool, optional): Whether to enable markup input for HD voices. Default is False.
        """  # noqa: E501
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=use_streaming),
            sample_rate=sample_rate,
            num_channels=1,
        )

        # SSML is only available in non-streaming mode and cannot be
        # combined with markup input.
        if enable_ssml:
            if use_streaming:
                raise ValueError("SSML support is not available for streaming synthesis")
            if use_markup:
                raise ValueError("SSML support is not available for markup input")

        # The gRPC client is created lazily in _ensure_client().
        self._client: texttospeech.TextToSpeechAsyncClient | None = None
        self._credentials_info = credentials_info
        self._credentials_file = credentials_file
        self._location = location

        lang = language if is_given(language) else DEFAULT_LANGUAGE
        ssml_gender = _gender_from_str(DEFAULT_GENDER if not is_given(gender) else gender)

        voice_params = texttospeech.VoiceSelectionParams(
            language_code=lang,
            ssml_gender=ssml_gender,
        )
        if model_name != "chirp_3":  # voice_params.model_name must not be set for Chirp 3
            voice_params.model_name = model_name

        # A voice-clone key takes precedence over any named voice.
        if is_given(voice_cloning_key):
            voice_params.voice_clone = texttospeech.VoiceCloneParams(
                voice_cloning_key=voice_cloning_key,
            )
        else:
            if is_given(voice_name):
                voice_params.name = voice_name
            elif model_name == "chirp_3":
                # Chirp 3 uses fully-qualified voice names.
                voice_params.name = "en-US-Chirp3-HD-Charon"
            else:
                voice_params.name = "Charon"

        if not is_given(tokenizer):
            tokenizer = tokenize.blingfire.SentenceTokenizer()

        pronunciations = None if not is_given(custom_pronunciations) else custom_pronunciations

        self._opts = _TTSOptions(
            voice=voice_params,
            encoding=audio_encoding,
            sample_rate=sample_rate,
            pitch=pitch,
            effects_profile_id=effects_profile_id,
            speaking_rate=speaking_rate,
            tokenizer=tokenizer,
            volume_gain_db=volume_gain_db,
            custom_pronunciations=pronunciations,
            enable_ssml=enable_ssml,
            use_markup=use_markup,
            model_name=model_name,
            prompt=prompt if is_given(prompt) else None,
        )
        # Weak references so closed/garbage-collected streams drop out
        # automatically; aclose() iterates what remains.
        self._streams = weakref.WeakSet[SynthesizeStream]()

    @property
    def model(self) -> str:
        # model_name is always set in __init__; the "Chirp3" fallback covers
        # a None assigned later via update_options typing.
        return self._opts.model_name or "Chirp3"

    @property
    def provider(self) -> str:
        return "Google Cloud Platform"

    def update_options(
        self,
        *,
        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
        gender: NotGivenOr[Gender | str] = NOT_GIVEN,
        voice_name: NotGivenOr[str] = NOT_GIVEN,
        model_name: NotGivenOr[str] = NOT_GIVEN,
        prompt: NotGivenOr[str] = NOT_GIVEN,
        speaking_rate: NotGivenOr[float] = NOT_GIVEN,
        volume_gain_db: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """
        Update the TTS options.

        Args:
            language (SpeechLanguages | str, optional): Language code (e.g., "en-US").
            gender (Gender | str, optional): Voice gender ("male", "female", "neutral").
            voice_name (str, optional): Specific voice name.
            model_name (str, optional): Model name for TTS (e.g., "gemini-2.5-flash-tts").
            prompt (str, optional): Style prompt for Gemini TTS models.
            speaking_rate (float, optional): Speed of speech.
            volume_gain_db (float, optional): Volume gain in decibels.
        """
        params = {}
        if is_given(language):
            params["language_code"] = str(language)
        if is_given(gender):
            params["ssml_gender"] = _gender_from_str(str(gender))
        if is_given(voice_name):
            params["name"] = voice_name
        if is_given(model_name):
            params["model_name"] = model_name
            self._opts.model_name = model_name

        # NOTE(review): this replaces the voice with a params object built
        # only from the arguments given here, so any previously configured
        # fields (e.g. voice name when only language is updated, or a
        # voice_clone set in __init__) are dropped — confirm intended.
        if params:
            self._opts.voice = texttospeech.VoiceSelectionParams(**params)

        if is_given(speaking_rate):
            self._opts.speaking_rate = speaking_rate
        if is_given(volume_gain_db):
            self._opts.volume_gain_db = volume_gain_db
        if is_given(prompt):
            self._opts.prompt = prompt

    def _ensure_client(self) -> texttospeech.TextToSpeechAsyncClient:
        """Lazily create (and cache) the async TTS client for the configured
        location and credentials."""
        # Endpoint is recomputed on each call, but only used on first
        # creation; the client is cached afterwards.
        api_endpoint = "texttospeech.googleapis.com"
        if self._location != "global":
            api_endpoint = f"{self._location}-texttospeech.googleapis.com"

        if self._client is None:
            if self._credentials_info:
                self._client = texttospeech.TextToSpeechAsyncClient.from_service_account_info(
                    self._credentials_info, client_options=ClientOptions(api_endpoint=api_endpoint)
                )

            elif self._credentials_file:
                self._client = texttospeech.TextToSpeechAsyncClient.from_service_account_file(
                    self._credentials_file, client_options=ClientOptions(api_endpoint=api_endpoint)
                )
            else:
                # Falls back to application-default credentials
                # (GOOGLE_APPLICATION_CREDENTIALS, etc.).
                self._client = texttospeech.TextToSpeechAsyncClient(
                    client_options=ClientOptions(api_endpoint=api_endpoint)
                )

        assert self._client is not None
        return self._client

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> SynthesizeStream:
        """Create a streaming synthesis session; tracked so aclose() can
        shut it down."""
        stream = SynthesizeStream(tts=self, conn_options=conn_options)
        self._streams.add(stream)
        return stream

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> ChunkedStream:
        """Synthesize ``text`` in a single non-streaming request."""
        return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)

    async def aclose(self) -> None:
        """Close all tracked streaming sessions."""
        for stream in list(self._streams):
            await stream.aclose()
        self._streams.clear()
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class ChunkedStream(tts.ChunkedStream):
    """Non-streaming synthesis: one request producing one audio payload."""

    def __init__(self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions) -> None:
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._tts: TTS = tts
        # snapshot the options so later update_options() calls don't affect us
        self._opts = replace(tts._opts)

    def _build_ssml(self) -> str:
        """Wrap the raw input text in a minimal SSML document."""
        return "<speak>" + self._input_text + "</speak>"

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        try:
            # exactly one of markup / ssml / text carries the input
            input_kwargs: dict = {"custom_pronunciations": self._opts.custom_pronunciations}
            if self._opts.use_markup:
                input_kwargs["markup"] = self._input_text
            elif self._opts.enable_ssml:
                input_kwargs["ssml"] = self._build_ssml()
            else:
                input_kwargs["text"] = self._input_text
            tts_input = texttospeech.SynthesisInput(**input_kwargs)

            if self._opts.prompt is not None:
                tts_input.prompt = self._opts.prompt

            audio_config = texttospeech.AudioConfig(
                audio_encoding=self._opts.encoding,
                sample_rate_hertz=self._opts.sample_rate,
                pitch=self._opts.pitch,
                effects_profile_id=self._opts.effects_profile_id,
                speaking_rate=self._opts.speaking_rate,
                volume_gain_db=self._opts.volume_gain_db,
            )
            response: SynthesizeSpeechResponse = await self._tts._ensure_client().synthesize_speech(
                input=tts_input,
                voice=self._opts.voice,
                audio_config=audio_config,
                timeout=self._conn_options.timeout,
            )

            output_emitter.initialize(
                request_id=utils.shortuuid(),
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                mime_type=_encoding_to_mimetype(self._opts.encoding),
            )

            output_emitter.push(response.audio_content)
        except DeadlineExceeded:
            raise APITimeoutError() from None
        except GoogleAPICallError as e:
            raise APIStatusError(e.message, status_code=e.code or -1) from e
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class SynthesizeStream(tts.SynthesizeStream):
    """Streaming synthesis: sentence-tokenized segments pushed over a
    bidirectional streaming_synthesize call."""

    def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
        super().__init__(tts=tts, conn_options=conn_options)
        self._tts: TTS = tts
        # snapshot of the TTS options at stream creation time
        self._opts = replace(tts._opts)
        # each flush boundary produces one SentenceStream segment
        self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        encoding = self._opts.encoding
        # streaming_synthesize only supports OGG_OPUS and PCM
        if encoding not in (texttospeech.AudioEncoding.OGG_OPUS, texttospeech.AudioEncoding.PCM):
            enc_name = texttospeech.AudioEncoding._member_names_[encoding]
            logger.warning(
                f"encoding {enc_name} isn't supported by the streaming_synthesize, "
                "fallbacking to PCM"
            )
            encoding = texttospeech.AudioEncoding.PCM  # type: ignore

        output_emitter.initialize(
            request_id=utils.shortuuid(),
            sample_rate=self._opts.sample_rate,
            num_channels=1,
            mime_type=_encoding_to_mimetype(encoding),
            stream=True,
        )

        streaming_config = texttospeech.StreamingSynthesizeConfig(
            voice=self._opts.voice,
            streaming_audio_config=texttospeech.StreamingAudioConfig(
                audio_encoding=encoding,
                sample_rate_hertz=self._opts.sample_rate,
                speaking_rate=self._opts.speaking_rate,
            ),
            custom_pronunciations=self._opts.custom_pronunciations,
        )

        async def _tokenize_input() -> None:
            # Feed incoming text into a sentence tokenizer; a flush sentinel
            # ends the current segment and the next text starts a new one.
            input_stream = None
            async for input in self._input_ch:
                if isinstance(input, str):
                    if input_stream is None:
                        input_stream = self._opts.tokenizer.stream()
                        self._segments_ch.send_nowait(input_stream)
                    input_stream.push_text(input)
                elif isinstance(input, self._FlushSentinel):
                    if input_stream:
                        input_stream.end_input()
                    input_stream = None

            self._segments_ch.close()

        async def _run_segments() -> None:
            # Segments are synthesized sequentially, one gRPC stream each.
            async for input_stream in self._segments_ch:
                await self._run_stream(input_stream, output_emitter, streaming_config)

        tasks = [
            asyncio.create_task(_tokenize_input()),
            asyncio.create_task(_run_segments()),
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            # ensure both tasks are torn down even if one fails
            await utils.aio.cancel_and_wait(*tasks)

    async def _run_stream(
        self,
        input_stream: tokenize.SentenceStream,
        output_emitter: tts.AudioEmitter,
        streaming_config: texttospeech.StreamingSynthesizeConfig,
    ) -> None:
        """Run one streaming_synthesize call for a single sentence segment."""

        @utils.log_exceptions(logger=logger)
        async def input_generator() -> AsyncGenerator[
            texttospeech.StreamingSynthesizeRequest, None
        ]:
            try:
                # the first request must carry the streaming config only
                yield texttospeech.StreamingSynthesizeRequest(streaming_config=streaming_config)

                is_first_input = True
                async for input in input_stream:
                    self._mark_started()
                    # prompt is only supported in the first input chunk (for Gemini TTS)
                    synthesis_input = texttospeech.StreamingSynthesisInput(
                        markup=input.token if self._opts.use_markup else None,
                        text=None if self._opts.use_markup else input.token,
                        prompt=self._opts.prompt if is_first_input else None,
                    )
                    is_first_input = False
                    yield texttospeech.StreamingSynthesizeRequest(input=synthesis_input)

            except Exception:
                # best-effort: log rather than propagate through the gRPC stream
                logger.exception("an error occurred while streaming input to google TTS")

        input_gen = input_generator()
        try:
            stream = await self._tts._ensure_client().streaming_synthesize(
                input_gen, timeout=self._conn_options.timeout
            )
            output_emitter.start_segment(segment_id=utils.shortuuid())

            async for resp in stream:
                output_emitter.push(resp.audio_content)

            output_emitter.end_segment()

        except DeadlineExceeded:
            raise APITimeoutError() from None
        except GoogleAPICallError as e:
            raise APIStatusError(e.message, status_code=e.code or -1) from e
        finally:
            # close the request generator so it doesn't leak if the stream
            # ends early
            await input_gen.aclose()
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _gender_from_str(gender: str) -> SsmlVoiceGender:
    """Map "male"/"female" to the matching SsmlVoiceGender; anything else
    resolves to NEUTRAL."""
    gender_map = {
        "male": SsmlVoiceGender.MALE,
        "female": SsmlVoiceGender.FEMALE,
    }
    return gender_map.get(gender, SsmlVoiceGender.NEUTRAL)  # type: ignore
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _encoding_to_mimetype(encoding: texttospeech.AudioEncoding) -> str:
    """Return the MIME type for a supported AudioEncoding.

    Raises:
        RuntimeError: if the encoding has no known MIME type.
    """
    mime_types = {
        texttospeech.AudioEncoding.PCM: "audio/pcm",
        texttospeech.AudioEncoding.LINEAR16: "audio/wav",
        texttospeech.AudioEncoding.MP3: "audio/mp3",
        texttospeech.AudioEncoding.OGG_OPUS: "audio/opus",
    }
    try:
        return mime_types[encoding]
    except KeyError:
        raise RuntimeError(f"encoding {encoding} isn't supported") from None
|