livekit-plugins-google 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- livekit/plugins/google/__init__.py +2 -1
- livekit/plugins/google/_utils.py +202 -0
- livekit/plugins/google/beta/realtime/__init__.py +0 -2
- livekit/plugins/google/beta/realtime/api_proto.py +5 -60
- livekit/plugins/google/beta/realtime/realtime_api.py +168 -42
- livekit/plugins/google/beta/realtime/transcriber.py +173 -0
- livekit/plugins/google/llm.py +414 -0
- livekit/plugins/google/models.py +2 -0
- livekit/plugins/google/stt.py +64 -10
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-0.9.0.dist-info → livekit_plugins_google-0.10.0.dist-info}/METADATA +13 -3
- livekit_plugins_google-0.10.0.dist-info/RECORD +18 -0
- {livekit_plugins_google-0.9.0.dist-info → livekit_plugins_google-0.10.0.dist-info}/WHEEL +1 -1
- livekit_plugins_google-0.9.0.dist-info/RECORD +0 -15
- {livekit_plugins_google-0.9.0.dist-info → livekit_plugins_google-0.10.0.dist-info}/top_level.txt +0 -0
livekit/plugins/google/llm.py
ADDED
@@ -0,0 +1,414 @@
+# Copyright 2023 LiveKit, Inc.
+#
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, Literal, MutableSet, Union, cast
+
+from livekit.agents import (
+    APIConnectionError,
+    APIStatusError,
+    llm,
+    utils,
+)
+from livekit.agents.llm import ToolChoice, _create_ai_function_info
+from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
+
+from google import genai
+from google.auth._default_async import default_async
+from google.genai import types
+from google.genai.errors import APIError, ClientError, ServerError
+
+from ._utils import _build_gemini_ctx, _build_tools
+from .log import logger
+from .models import ChatModels
+
+
+@dataclass
+class LLMOptions:
+    model: ChatModels | str
+    temperature: float | None
+    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
+    vertexai: bool = False
+    project: str | None = None
+    location: str | None = None
+    candidate_count: int = 1
+    max_output_tokens: int | None = None
+    top_p: float | None = None
+    top_k: float | None = None
+    presence_penalty: float | None = None
+    frequency_penalty: float | None = None
+
+
+class LLM(llm.LLM):
+    def __init__(
+        self,
+        *,
+        model: ChatModels | str = "gemini-2.0-flash-exp",
+        api_key: str | None = None,
+        vertexai: bool = False,
+        project: str | None = None,
+        location: str | None = None,
+        candidate_count: int = 1,
+        temperature: float = 0.8,
+        max_output_tokens: int | None = None,
+        top_p: float | None = None,
+        top_k: float | None = None,
+        presence_penalty: float | None = None,
+        frequency_penalty: float | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
+    ) -> None:
+        """
+        Create a new instance of Google GenAI LLM.
+
+        Environment Requirements:
+        - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file.
+        The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
+        `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
+        and the location defaults to "us-central1".
+        - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
+
+        Args:
+            model (ChatModels | str, optional): The model name to use. Defaults to "gemini-2.0-flash-exp".
+            api_key (str, optional): The API key for Google Gemini. If not provided, it attempts to read from the `GOOGLE_API_KEY` environment variable.
+            vertexai (bool, optional): Whether to use VertexAI. Defaults to False.
+            project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
+            location (str, optional): The location to use for VertexAI API requests. Default value is "us-central1".
+            candidate_count (int, optional): Number of candidate responses to generate. Defaults to 1.
+            temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
+            max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
+            top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
+            top_k (int, optional): The top-k sampling value for response generation. Defaults to None.
+            presence_penalty (float, optional): Penalizes the model for generating previously mentioned concepts. Defaults to None.
+            frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
+            tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
+        """
+        super().__init__()
+        self._capabilities = llm.LLMCapabilities(supports_choices_on_int=False)
+        self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
+        self._location = location or os.environ.get(
+            "GOOGLE_CLOUD_LOCATION", "us-central1"
+        )
+        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY", None)
+        _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
+        if _gac is None:
+            raise ValueError(
+                "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file."
+            )
+
+        if vertexai:
+            if not self._project_id:
+                _, self._project_id = default_async(
+                    scopes=["https://www.googleapis.com/auth/cloud-platform"]
+                )
+            self._api_key = None  # VertexAI does not require an API key
+
+        else:
+            self._project_id = None
+            self._location = None
+            if not self._api_key:
+                raise ValueError(
+                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
+                )
+
+        self._opts = LLMOptions(
+            model=model,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            vertexai=vertexai,
+            project=project,
+            location=location,
+            candidate_count=candidate_count,
+            max_output_tokens=max_output_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+        )
+        self._client = genai.Client(
+            api_key=self._api_key,
+            vertexai=vertexai,
+            project=self._project_id,
+            location=self._location,
+        )
+        self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
+
+    def chat(
+        self,
+        *,
+        chat_ctx: llm.ChatContext,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        fnc_ctx: llm.FunctionContext | None = None,
+        temperature: float | None = None,
+        n: int | None = 1,
+        parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
+    ) -> "LLMStream":
+        if tool_choice is None:
+            tool_choice = self._opts.tool_choice
+
+        if temperature is None:
+            temperature = self._opts.temperature
+
+        return LLMStream(
+            self,
+            client=self._client,
+            model=self._opts.model,
+            max_output_tokens=self._opts.max_output_tokens,
+            top_p=self._opts.top_p,
+            top_k=self._opts.top_k,
+            presence_penalty=self._opts.presence_penalty,
+            frequency_penalty=self._opts.frequency_penalty,
+            chat_ctx=chat_ctx,
+            fnc_ctx=fnc_ctx,
+            conn_options=conn_options,
+            n=n,
+            temperature=temperature,
+            tool_choice=tool_choice,
+        )
+
+
+class LLMStream(llm.LLMStream):
+    def __init__(
+        self,
+        llm: LLM,
+        *,
+        client: genai.Client,
+        model: str | ChatModels,
+        chat_ctx: llm.ChatContext,
+        conn_options: APIConnectOptions,
+        fnc_ctx: llm.FunctionContext | None,
+        temperature: float | None,
+        n: int | None,
+        max_output_tokens: int | None,
+        top_p: float | None,
+        top_k: float | None,
+        presence_penalty: float | None,
+        frequency_penalty: float | None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
+    ) -> None:
+        super().__init__(
+            llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
+        )
+        self._client = client
+        self._model = model
+        self._llm: LLM = llm
+        self._max_output_tokens = max_output_tokens
+        self._top_p = top_p
+        self._top_k = top_k
+        self._presence_penalty = presence_penalty
+        self._frequency_penalty = frequency_penalty
+        self._temperature = temperature
+        self._n = n
+        self._tool_choice = tool_choice
+
+    async def _run(self) -> None:
+        retryable = True
+        request_id = utils.shortuuid()
+
+        try:
+            opts: dict[str, Any] = dict()
+            turns, system_instruction = _build_gemini_ctx(self._chat_ctx, id(self))
+
+            if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
+                functions = _build_tools(self._fnc_ctx)
+                opts["tools"] = [types.Tool(function_declarations=functions)]
+
+                if self._tool_choice is not None:
+                    if isinstance(self._tool_choice, ToolChoice):
+                        # specific function
+                        tool_config = types.ToolConfig(
+                            function_calling_config=types.FunctionCallingConfig(
+                                mode="ANY",
+                                allowed_function_names=[self._tool_choice.name],
+                            )
+                        )
+                    elif self._tool_choice == "required":
+                        # model must call any function
+                        tool_config = types.ToolConfig(
+                            function_calling_config=types.FunctionCallingConfig(
+                                mode="ANY",
+                                allowed_function_names=[
+                                    fnc.name
+                                    for fnc in self._fnc_ctx.ai_functions.values()
+                                ],
+                            )
+                        )
+                    elif self._tool_choice == "auto":
+                        # model can call any function
+                        tool_config = types.ToolConfig(
+                            function_calling_config=types.FunctionCallingConfig(
+                                mode="AUTO"
+                            )
+                        )
+                    elif self._tool_choice == "none":
+                        # model cannot call any function
+                        tool_config = types.ToolConfig(
+                            function_calling_config=types.FunctionCallingConfig(
+                                mode="NONE",
+                            )
+                        )
+                    opts["tool_config"] = tool_config
+
+            config = types.GenerateContentConfig(
+                candidate_count=self._n,
+                temperature=self._temperature,
+                max_output_tokens=self._max_output_tokens,
+                top_p=self._top_p,
+                top_k=self._top_k,
+                presence_penalty=self._presence_penalty,
+                frequency_penalty=self._frequency_penalty,
+                system_instruction=system_instruction,
+                **opts,
+            )
+            async for response in self._client.aio.models.generate_content_stream(
+                model=self._model,
+                contents=cast(types.ContentListUnion, turns),
+                config=config,
+            ):
+                if response.prompt_feedback:
+                    raise APIStatusError(
+                        response.prompt_feedback.json(),
+                        retryable=False,
+                        request_id=request_id,
+                    )
+
+                if (
+                    not response.candidates
+                    or not response.candidates[0].content
+                    or not response.candidates[0].content.parts
+                ):
+                    raise APIStatusError(
+                        "No candidates in the response",
+                        retryable=True,
+                        request_id=request_id,
+                    )
+
+                if len(response.candidates) > 1:
+                    logger.warning(
+                        "gemini llm: there are multiple candidates in the response, returning response from the first one."
+                    )
+
+                for index, part in enumerate(response.candidates[0].content.parts):
+                    chat_chunk = self._parse_part(request_id, index, part)
+                    if chat_chunk is not None:
+                        retryable = False
+                        self._event_ch.send_nowait(chat_chunk)
+
+                if response.usage_metadata is not None:
+                    usage = response.usage_metadata
+                    self._event_ch.send_nowait(
+                        llm.ChatChunk(
+                            request_id=request_id,
+                            usage=llm.CompletionUsage(
+                                completion_tokens=usage.candidates_token_count or 0,
+                                prompt_tokens=usage.prompt_token_count or 0,
+                                total_tokens=usage.total_token_count or 0,
+                            ),
+                        )
+                    )
+        except ClientError as e:
+            raise APIStatusError(
+                "gemini llm: client error",
+                status_code=e.code,
+                body=e.message,
+                request_id=request_id,
+                retryable=False if e.code != 429 else True,
+            ) from e
+        except ServerError as e:
+            raise APIStatusError(
+                "gemini llm: server error",
+                status_code=e.code,
+                body=e.message,
+                request_id=request_id,
+                retryable=retryable,
+            ) from e
+        except APIError as e:
+            raise APIStatusError(
+                "gemini llm: api error",
+                status_code=e.code,
+                body=e.message,
+                request_id=request_id,
+                retryable=retryable,
+            ) from e
+        except Exception as e:
+            raise APIConnectionError(
+                "gemini llm: error generating content",
+                retryable=retryable,
+            ) from e
+
+    def _parse_part(
+        self, id: str, index: int, part: types.Part
+    ) -> llm.ChatChunk | None:
+        if part.function_call:
+            return self._try_build_function(id, index, part)
+
+        return llm.ChatChunk(
+            request_id=id,
+            choices=[
+                llm.Choice(
+                    delta=llm.ChoiceDelta(content=part.text, role="assistant"),
+                    index=index,
+                )
+            ],
+        )
+
+    def _try_build_function(
+        self, id: str, index: int, part: types.Part
+    ) -> llm.ChatChunk | None:
+        if part.function_call is None:
+            logger.warning("gemini llm: no function call in the response")
+            return None
+
+        if part.function_call.name is None:
+            logger.warning("gemini llm: no function name in the response")
+            return None
+
+        if part.function_call.id is None:
+            part.function_call.id = utils.shortuuid()
+
+        if self._fnc_ctx is None:
+            logger.warning(
+                "google stream tried to run function without function context"
+            )
+            return None
+
+        fnc_info = _create_ai_function_info(
+            self._fnc_ctx,
+            part.function_call.id,
+            part.function_call.name,
+            json.dumps(part.function_call.args),
+        )
+
+        self._function_calls_info.append(fnc_info)
+
+        return llm.ChatChunk(
+            request_id=id,
+            choices=[
+                llm.Choice(
+                    delta=llm.ChoiceDelta(
+                        role="assistant",
+                        tool_calls=[fnc_info],
+                        content=part.text,
+                    ),
+                    index=index,
+                )
+            ],
+        )
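The new llm.py above adds a Gemini-backed implementation of `llm.LLM` with streaming output and function calling. Below is a minimal usage sketch, not part of the diff; it assumes the livekit-agents 0.12 `ChatContext` API and a `GOOGLE_API_KEY` in the environment. Note that, as written, the constructor also requires `GOOGLE_APPLICATION_CREDENTIALS` to be set, even on the Gemini API-key path.

# A minimal sketch, not part of the package. Assumes livekit-agents >= 0.12.3,
# GOOGLE_API_KEY set, and GOOGLE_APPLICATION_CREDENTIALS set (the constructor
# above raises ValueError without it, even when using the Gemini API key path).
import asyncio

from livekit.agents import llm as agents_llm
from livekit.plugins.google import LLM


async def main() -> None:
    gemini = LLM(model="gemini-2.0-flash-exp", temperature=0.8)

    chat_ctx = agents_llm.ChatContext().append(
        role="user", text="Say hello in one short sentence."
    )

    # chat() returns an LLMStream; iterating it yields ChatChunk deltas,
    # including tool_calls when the model decides to call a function.
    async for chunk in gemini.chat(chat_ctx=chat_ctx):
        for choice in chunk.choices:
            if choice.delta.content:
                print(choice.delta.content, end="", flush=True)


asyncio.run(main())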
livekit/plugins/google/models.py
CHANGED
livekit/plugins/google/stt.py
CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 import asyncio
 import dataclasses
+import time
 import weakref
 from dataclasses import dataclass
 from typing import List, Union
@@ -44,6 +45,10 @@ from .models import SpeechLanguages, SpeechModels
 LgType = Union[SpeechLanguages, str]
 LanguageCode = Union[LgType, List[LgType]]
 
+# Google STT has a timeout of 5 mins, we'll attempt to restart the session
+# before that timeout is reached
+_max_session_duration = 240
+
 
 # This class is only be used internally to encapsulate the options
 @dataclass
@@ -229,8 +234,6 @@ class STT(stt.STT):
             raise APIStatusError(
                 e.message,
                 status_code=e.code or -1,
-                request_id=None,
-                body=None,
             )
         except Exception as e:
             raise APIConnectionError() from e
@@ -278,6 +281,13 @@ class STT(stt.STT):
             self._config.spoken_punctuation = spoken_punctuation
         if model is not None:
             self._config.model = model
+        client = None
+        recognizer = None
+        if location is not None:
+            self._location = location
+            # if location is changed, fetch a new client and recognizer as per the new location
+            client = self._ensure_client()
+            recognizer = self._recognizer
         if keywords is not None:
             self._config.keywords = keywords
 
@@ -289,8 +299,9 @@ class STT(stt.STT):
                 punctuate=punctuate,
                 spoken_punctuation=spoken_punctuation,
                 model=model,
-                location=location,
                 keywords=keywords,
+                client=client,
+                recognizer=recognizer,
             )
 
 
@@ -312,6 +323,7 @@ class SpeechStream(stt.SpeechStream):
         self._recognizer = recognizer
         self._config = config
         self._reconnect_event = asyncio.Event()
+        self._session_connected_at: float = 0
 
     def update_options(
         self,
@@ -322,8 +334,9 @@ class SpeechStream(stt.SpeechStream):
         punctuate: bool | None = None,
         spoken_punctuation: bool | None = None,
         model: SpeechModels | None = None,
-        location: str | None = None,
         keywords: List[tuple[str, float]] | None = None,
+        client: SpeechAsyncClient | None = None,
+        recognizer: str | None = None,
     ):
         if languages is not None:
             if isinstance(languages, str):
@@ -341,13 +354,17 @@ class SpeechStream(stt.SpeechStream):
             self._config.model = model
         if keywords is not None:
             self._config.keywords = keywords
+        if client is not None:
+            self._client = client
+        if recognizer is not None:
+            self._recognizer = recognizer
 
         self._reconnect_event.set()
 
     async def _run(self) -> None:
         # google requires a async generator when calling streaming_recognize
         # this function basically convert the queue into a async generator
-        async def input_generator():
+        async def input_generator(should_stop: asyncio.Event):
             try:
                 # first request should contain the config
                 yield cloud_speech.StreamingRecognizeRequest(
@@ -356,6 +373,12 @@ class SpeechStream(stt.SpeechStream):
                 )
 
                 async for frame in self._input_ch:
+                    # when the stream is aborted due to reconnect, this input_generator
+                    # needs to stop consuming frames
+                    # when the generator stops, the previous gRPC stream will close
+                    if should_stop.is_set():
+                        return
+
                     if isinstance(frame, rtc.AudioFrame):
                         yield cloud_speech.StreamingRecognizeRequest(
                             audio=frame.data.tobytes()
@@ -367,6 +390,7 @@ class SpeechStream(stt.SpeechStream):
                 )
 
         async def process_stream(stream):
+            has_started = False
             async for resp in stream:
                 if (
                     resp.speech_event_type
@@ -375,6 +399,7 @@ class SpeechStream(stt.SpeechStream):
                     self._event_ch.send_nowait(
                         stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                     )
+                    has_started = True
 
                 if (
                     resp.speech_event_type
@@ -399,6 +424,22 @@ class SpeechStream(stt.SpeechStream):
                             alternatives=[speech_data],
                         )
                     )
+                    if (
+                        time.time() - self._session_connected_at
+                        > _max_session_duration
+                    ):
+                        logger.debug(
+                            "Google STT maximum connection time reached. Reconnecting..."
+                        )
+                        if has_started:
+                            self._event_ch.send_nowait(
+                                stt.SpeechEvent(
+                                    type=stt.SpeechEventType.END_OF_SPEECH
+                                )
+                            )
+                            has_started = False
+                        self._reconnect_event.set()
+                        return
 
                 if (
                     resp.speech_event_type
@@ -407,6 +448,7 @@ class SpeechStream(stt.SpeechStream):
                     self._event_ch.send_nowait(
                         stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                     )
+                    has_started = False
 
         while True:
             try:
@@ -431,12 +473,15 @@ class SpeechStream(stt.SpeechStream):
                     ),
                 )
 
+                should_stop = asyncio.Event()
                 stream = await self._client.streaming_recognize(
-                    requests=input_generator(),
+                    requests=input_generator(should_stop),
                 )
+                self._session_connected_at = time.time()
 
                 process_stream_task = asyncio.create_task(process_stream(stream))
                 wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
+
                 try:
                     done, _ = await asyncio.wait(
                         [process_stream_task, wait_reconnect_task],
@@ -445,14 +490,23 @@ class SpeechStream(stt.SpeechStream):
                     for task in done:
                         if task != wait_reconnect_task:
                             task.result()
+                    if wait_reconnect_task not in done:
+                        break
+                    self._reconnect_event.clear()
                 finally:
                     await utils.aio.gracefully_cancel(
                         process_stream_task, wait_reconnect_task
                     )
-
-
-
-
+                    should_stop.set()
+            except DeadlineExceeded:
+                raise APITimeoutError()
+            except GoogleAPICallError as e:
+                raise APIStatusError(
+                    e.message,
+                    status_code=e.code or -1,
+                )
+            except Exception as e:
+                raise APIConnectionError() from e
 
 
     def _recognize_response_to_speech_event(
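Taken together, the stt.py changes implement proactive session rotation: per the diff's own comment, Google's streaming recognizer times out a single gRPC stream after about five minutes, so the stream is rebuilt once `_max_session_duration` (240 s) has elapsed, `should_stop` ends the old request generator, and `_reconnect_event` drives the outer loop. A condensed sketch of that pattern follows, with the Google-specific pieces replaced by hypothetical stand-ins (`open_stream`, `process_response`):

# A condensed sketch of the rotation pattern the diff adds; open_stream and
# process_response are hypothetical stand-ins for the Google-specific calls.
import asyncio
import time

_max_session_duration = 240  # rotate well before the ~5 min stream cap


async def run_with_rotation(open_stream, process_response) -> None:
    reconnect = asyncio.Event()
    while True:
        should_stop = asyncio.Event()  # tells the request generator to exit
        stream = await open_stream(should_stop)
        connected_at = time.time()

        async def pump() -> None:
            async for resp in stream:
                process_response(resp)
                if time.time() - connected_at > _max_session_duration:
                    reconnect.set()  # ask the outer loop to rebuild the stream
                    return

        pump_task = asyncio.create_task(pump())
        wait_task = asyncio.create_task(reconnect.wait())
        try:
            done, _ = await asyncio.wait(
                [pump_task, wait_task], return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                if task is not wait_task:
                    task.result()  # surface pump errors, as the diff does
            if wait_task not in done:
                break  # the stream ended on its own; no rotation requested
            reconnect.clear()
        finally:
            pump_task.cancel()
            wait_task.cancel()
            should_stop.set()  # stops frame consumption, closing the old stream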
{livekit_plugins_google-0.9.0.dist-info → livekit_plugins_google-0.10.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: livekit-plugins-google
-Version: 0.9.0
+Version: 0.10.0
 Summary: Agent Framework plugin for services from Google Cloud
 Home-page: https://github.com/livekit/agents
 License: Apache-2.0
@@ -22,8 +22,18 @@ Description-Content-Type: text/markdown
 Requires-Dist: google-auth<3,>=2
 Requires-Dist: google-cloud-speech<3,>=2
 Requires-Dist: google-cloud-texttospeech<3,>=2
-Requires-Dist: google-genai
+Requires-Dist: google-genai==0.5.0
 Requires-Dist: livekit-agents>=0.12.3
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: project-url
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # LiveKit Plugins Google
 

livekit_plugins_google-0.10.0.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
+livekit/plugins/google/__init__.py,sha256=e_kSlFNmKhyyeliz7f4WOKc_Y0-y39QjO5nCWuguhss,1171
+livekit/plugins/google/_utils.py,sha256=mjsqblhGMgAZ2MNPisAVkNsqq4gfO6vvprEKzAGoVwE,7248
+livekit/plugins/google/llm.py,sha256=vL8iyRqWVPT0wCDeXTlybytlyJ-J-VolVQYqP-ZVlb0,16388
+livekit/plugins/google/log.py,sha256=GI3YWN5YzrafnUccljzPRS_ZALkMNk1i21IRnTl2vNA,69
+livekit/plugins/google/models.py,sha256=w_qmOk5y86vjtszDiGpP9p0ctjQeaB8-UzqprxgpvCY,1407
+livekit/plugins/google/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+livekit/plugins/google/stt.py,sha256=E5kXPbicH4FEXBjyBzfqQWA-nPhKkojzcc-cbtWdmNs,21088
+livekit/plugins/google/tts.py,sha256=95qXCigVQYWNbcN3pIKBpIah4b31U_MWtXv5Ji0AMc4,9229
+livekit/plugins/google/version.py,sha256=sAL7xgP18DEksjwYUwabcCgRgKAAGXSWs6xp7NgcxoU,601
+livekit/plugins/google/beta/__init__.py,sha256=AxRYc7NGG62Tv1MmcZVCDHNvlhbC86hM-_yP01Qb28k,47
+livekit/plugins/google/beta/realtime/__init__.py,sha256=sGTn6JFNyA30QUXBZ_BV3l2eHpGAzR35ByXxg77vWNU,205
+livekit/plugins/google/beta/realtime/api_proto.py,sha256=9EhmwgeIgKDqdSijv5Q9pgx7UhAakK02ZDwbnUsra_o,657
+livekit/plugins/google/beta/realtime/realtime_api.py,sha256=vCjDQZvHS749Gf-QOLo-RaW4HlQHlzuArd3IlN5xMmY,21459
+livekit/plugins/google/beta/realtime/transcriber.py,sha256=3TaYbtvPWHkxKlDSZSMLWBbR7KewBRg3HcdIxuGhl9c,5880
+livekit_plugins_google-0.10.0.dist-info/METADATA,sha256=lsA9pwlWHE-q-9x3HKn2EeJ7ZdcpjxzEtYs1wRH5axE,2057
+livekit_plugins_google-0.10.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+livekit_plugins_google-0.10.0.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
+livekit_plugins_google-0.10.0.dist-info/RECORD,,
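For reference, each RECORD row above is `path,hash,size` (PEP 376/427): the hash is the urlsafe-base64-encoded SHA-256 digest of the file with trailing `=` padding stripped. A small sketch of how an entry can be verified:

# A small sketch of verifying a RECORD entry, per PEP 376/427.
import base64
import hashlib


def record_hash(path: str) -> str:
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()


# e.g. record_hash("livekit/plugins/google/llm.py") inside the unpacked wheel
# should return "sha256=vL8iyRqWVPT0wCDeXTlybytlyJ-J-VolVQYqP-ZVlb0"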