livekit-plugins-google 0.11.1__py3-none-any.whl → 1.0.0.dev4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- livekit/plugins/google/beta/realtime/__init__.py +1 -5
- livekit/plugins/google/beta/realtime/api_proto.py +3 -2
- livekit/plugins/google/beta/realtime/realtime_api.py +22 -51
- livekit/plugins/google/beta/realtime/transcriber.py +11 -27
- livekit/plugins/google/llm.py +127 -197
- livekit/plugins/google/stt.py +28 -58
- livekit/plugins/google/tts.py +10 -16
- livekit/plugins/google/utils.py +213 -0
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-0.11.1.dist-info → livekit_plugins_google-1.0.0.dev4.dist-info}/METADATA +12 -22
- livekit_plugins_google-1.0.0.dev4.dist-info/RECORD +17 -0
- {livekit_plugins_google-0.11.1.dist-info → livekit_plugins_google-1.0.0.dev4.dist-info}/WHEEL +1 -2
- livekit/plugins/google/_utils.py +0 -199
- livekit_plugins_google-0.11.1.dist-info/RECORD +0 -18
- livekit_plugins_google-0.11.1.dist-info/top_level.txt +0 -1
livekit/plugins/google/llm.py
CHANGED
@@ -15,45 +15,43 @@
 
 from __future__ import annotations
 
-import asyncio
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Literal, MutableSet, Union
-
-from livekit.agents import (
-    APIConnectionError,
-    APIStatusError,
-    llm,
-    utils,
-)
-from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
-from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
+from typing import Any, Literal, cast
 
 from google import genai
 from google.auth._default_async import default_async
 from google.genai import types
 from google.genai.errors import APIError, ClientError, ServerError
+from livekit.agents import APIConnectionError, APIStatusError, llm, utils
+from livekit.agents.llm import FunctionTool, ToolChoice
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    APIConnectOptions,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
 
-from ._utils import _build_gemini_ctx, _build_tools
 from .log import logger
 from .models import ChatModels
+from .utils import to_chat_ctx, to_fnc_ctx
 
 
 @dataclass
-class LLMOptions:
+class _LLMOptions:
     model: ChatModels | str
-    temperature: float
-    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
-    vertexai: bool
-    project: str
-    location: str
-    candidate_count: int
-    max_output_tokens: int | None
-    top_p: float | None
-    top_k: float | None
-    presence_penalty: float | None
-    frequency_penalty: float | None = None
+    temperature: NotGivenOr[float]
+    tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]]
+    vertexai: NotGivenOr[bool]
+    project: NotGivenOr[str]
+    location: NotGivenOr[str]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    top_k: NotGivenOr[float]
+    presence_penalty: NotGivenOr[float]
+    frequency_penalty: NotGivenOr[float]
 
 
 class LLM(llm.LLM):
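The import block swaps `None`-style defaults for the `NOT_GIVEN` sentinel exported by livekit.agents. A minimal sketch of how such a sentinel separates "argument omitted" from a legitimate falsy value (the definitions below are simplified stand-ins, not the real livekit.agents types):

    from typing import TypeVar, Union

    class NotGiven:
        # simplified stand-in for the livekit.agents sentinel type
        def __bool__(self) -> bool:
            return False

        def __repr__(self) -> str:
            return "NOT_GIVEN"

    NOT_GIVEN = NotGiven()
    T = TypeVar("T")
    NotGivenOr = Union[T, NotGiven]  # annotation alias: either T or the sentinel

    def is_given(value) -> bool:
        return not isinstance(value, NotGiven)

    def configure(temperature: NotGivenOr[float] = NOT_GIVEN) -> dict:
        opts = {}
        if is_given(temperature):  # keeps 0.0; skips only "not passed"
            opts["temperature"] = temperature
        return opts

    assert configure() == {}
    assert configure(temperature=0.0) == {"temperature": 0.0}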
@@ -61,18 +59,17 @@ class LLM(llm.LLM):
         self,
         *,
         model: ChatModels | str = "gemini-2.0-flash-001",
-        api_key: str | None = None,
-        vertexai: bool = False,
-        project: str | None = None,
-        location: str | None = None,
-        candidate_count: int = 1,
-        temperature: float | None = None,
-        max_output_tokens: int | None = None,
-        top_p: float | None = None,
-        top_k: float | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        vertexai: NotGivenOr[bool] = False,
+        project: NotGivenOr[str] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        top_k: NotGivenOr[float] = NOT_GIVEN,
+        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+        tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]] = NOT_GIVEN,
     ) -> None:
         """
         Create a new instance of Google GenAI LLM.
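For comparison, a hypothetical call site under the new signature; anything left unset stays `NOT_GIVEN` and is later omitted from the request rather than being serialized as `None`:

    from livekit.plugins import google

    llm = google.LLM(
        model="gemini-2.0-flash-001",
        temperature=0.8,  # forwarded because it was explicitly given
        # top_p, top_k, presence_penalty, ... left NOT_GIVEN -> never sent,
        # so Gemini's server-side defaults apply
    )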
@@ -90,7 +87,6 @@ class LLM(llm.LLM):
             vertexai (bool, optional): Whether to use VertexAI. Defaults to False.
             project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
             location (str, optional): The location to use for VertexAI API requests. Defaults value is "us-central1".
-            candidate_count (int, optional): Number of candidate responses to generate. Defaults to 1.
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
             top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
@@ -99,16 +95,9 @@ class LLM(llm.LLM):
             frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
             tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
         """
-        super().__init__(
-            capabilities=LLMCapabilities(
-                supports_choices_on_int=False,
-                requires_persistent_functions=False,
-            )
-        )
+        super().__init__()
         self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
-        self._location = location or os.environ.get(
-            "GOOGLE_CLOUD_LOCATION", "us-central1"
-        )
+        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1")
         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY", None)
         _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
         if _gac is None:
@@ -131,14 +120,13 @@ class LLM(llm.LLM):
                 "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
             )
 
-        self._opts = LLMOptions(
+        self._opts = _LLMOptions(
             model=model,
             temperature=temperature,
             tool_choice=tool_choice,
             vertexai=vertexai,
             project=project,
             location=location,
-            candidate_count=candidate_count,
             max_output_tokens=max_output_tokens,
             top_p=top_p,
             top_k=top_k,
@@ -151,41 +139,77 @@ class LLM(llm.LLM):
             project=self._project_id,
             location=self._location,
         )
-        self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
 
     def chat(
         self,
         *,
         chat_ctx: llm.ChatContext,
+        tools: list[FunctionTool] | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-        fnc_ctx: llm.FunctionContext | None = None,
-        temperature: float | None = None,
-        n: int | None = 1,
-        parallel_tool_calls: bool | None = None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
-    ) -> "LLMStream":
-        if ...
+        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
+        tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]] = NOT_GIVEN,
+        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+    ) -> LLMStream:
+        extra = {}
+
+        if is_given(extra_kwargs):
+            extra.update(extra_kwargs)
+
+        tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
+        if is_given(tool_choice):
+            gemini_tool_choice: types.ToolConfig
+            if isinstance(tool_choice, ToolChoice):
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="ANY",
+                        allowed_function_names=[tool_choice["function"]["name"]],
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+            elif tool_choice == "required":
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="ANY",
+                        allowed_function_names=[fnc.name for fnc in tools],
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+            elif tool_choice == "auto":
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="AUTO",
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+            elif tool_choice == "none":
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="NONE",
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+
+        if is_given(self._opts.temperature):
+            extra["temperature"] = self._opts.temperature
+        if is_given(self._opts.max_output_tokens):
+            extra["max_output_tokens"] = self._opts.max_output_tokens
+        if is_given(self._opts.top_p):
+            extra["top_p"] = self._opts.top_p
+        if is_given(self._opts.top_k):
+            extra["top_k"] = self._opts.top_k
+        if is_given(self._opts.presence_penalty):
+            extra["presence_penalty"] = self._opts.presence_penalty
+        if is_given(self._opts.frequency_penalty):
+            extra["frequency_penalty"] = self._opts.frequency_penalty
 
         return LLMStream(
             self,
             client=self._client,
             model=self._opts.model,
-            max_output_tokens=self._opts.max_output_tokens,
-            top_p=self._opts.top_p,
-            top_k=self._opts.top_k,
-            presence_penalty=self._opts.presence_penalty,
-            frequency_penalty=self._opts.frequency_penalty,
             chat_ctx=chat_ctx,
-            fnc_ctx=fnc_ctx,
+            tools=tools,
             conn_options=conn_options,
-            n=n,
-            temperature=temperature,
-            tool_choice=tool_choice,
+            extra_kwargs=extra,
         )
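The new chat() maps the framework-level tool_choice onto google.genai tool configs. The same mapping, pulled out into a standalone helper for readability (`build_tool_config` is our name, not part of the plugin; the `ToolConfig`/`FunctionCallingConfig` types and mode strings are the real google.genai ones used above):

    from __future__ import annotations

    from google.genai import types

    def build_tool_config(tool_choice: str, tool_names: list[str]) -> types.ToolConfig | None:
        if tool_choice == "required":
            # model must call one of the listed functions
            mode, allowed = "ANY", tool_names
        elif tool_choice == "auto":
            mode, allowed = "AUTO", None  # model decides
        elif tool_choice == "none":
            mode, allowed = "NONE", None  # function calling disabled
        else:
            return None
        return types.ToolConfig(
            function_calling_config=types.FunctionCallingConfig(
                mode=mode, allowed_function_names=allowed
            )
        )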
@@ -198,96 +222,37 @@ class LLMStream(llm.LLMStream):
         model: str | ChatModels,
         chat_ctx: llm.ChatContext,
         conn_options: APIConnectOptions,
-        fnc_ctx: llm.FunctionContext | None,
-        temperature: float | None,
-        n: int | None,
-        max_output_tokens: int | None,
-        top_p: float | None,
-        top_k: float | None,
-        presence_penalty: float | None,
-        frequency_penalty: float | None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
+        tools: list[FunctionTool] | None,
+        extra_kwargs: dict[str, Any],
     ) -> None:
-        super().__init__(
-            llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
-        )
+        super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
         self._client = client
         self._model = model
         self._llm: LLM = llm
-        self._max_output_tokens = max_output_tokens
-        self._top_p = top_p
-        self._top_k = top_k
-        self._presence_penalty = presence_penalty
-        self._frequency_penalty = frequency_penalty
-        self._temperature = temperature
-        self._n = n
-        self._tool_choice = tool_choice
+        self._extra_kwargs = extra_kwargs
 
     async def _run(self) -> None:
         retryable = True
         request_id = utils.shortuuid()
 
         try:
-            opts: dict[str, Any] = dict()
-            turns, system_instruction = _build_gemini_ctx(self._chat_ctx, id(self))
-
-            if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
-                functions = _build_tools(self._fnc_ctx)
-                opts["tools"] = [types.Tool(function_declarations=functions)]
-
-            if self._tool_choice is not None:
-                if isinstance(self._tool_choice, ToolChoice):
-                    # specific function
-                    tool_config = types.ToolConfig(
-                        function_calling_config=types.FunctionCallingConfig(
-                            mode=types.FunctionCallingConfigMode.ANY,
-                            allowed_function_names=[self._tool_choice.name],
-                        )
-                    )
-                elif self._tool_choice == "required":
-                    # model must call any function
-                    tool_config = types.ToolConfig(
-                        function_calling_config=types.FunctionCallingConfig(
-                            mode=types.FunctionCallingConfigMode.ANY,
-                            allowed_function_names=[
-                                fnc.name
-                                for fnc in self._fnc_ctx.ai_functions.values()
-                            ],
-                        )
-                    )
-                elif self._tool_choice == "auto":
-                    # model can call any function
-                    tool_config = types.ToolConfig(
-                        function_calling_config=types.FunctionCallingConfig(
-                            mode=types.FunctionCallingConfigMode.AUTO
-                        )
-                    )
-                elif self._tool_choice == "none":
-                    # model cannot call any function
-                    tool_config = types.ToolConfig(
-                        function_calling_config=types.FunctionCallingConfig(
-                            mode=types.FunctionCallingConfigMode.NONE,
-                        )
-                    )
-                opts["tool_config"] = tool_config
+            turns, system_instruction = to_chat_ctx(self._chat_ctx, id(self._llm))
 
+            self._extra_kwargs["tools"] = [
+                types.Tool(function_declarations=to_fnc_ctx(self._tools))
+            ]
             config = types.GenerateContentConfig(
-                candidate_count=self._n,
-                temperature=self._temperature,
-                max_output_tokens=self._max_output_tokens,
-                top_p=self._top_p,
-                top_k=self._top_k,
-                presence_penalty=self._presence_penalty,
-                frequency_penalty=self._frequency_penalty,
                 system_instruction=system_instruction,
-                **opts,
+                **self._extra_kwargs,
             )
+
             stream = await self._client.aio.models.generate_content_stream(
                 model=self._model,
                 contents=cast(types.ContentListUnion, turns),
                 config=config,
             )
-            async for response in stream:
+
+            async for response in stream:
                 if response.prompt_feedback:
                     raise APIStatusError(
                         response.prompt_feedback.json(),
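Rather than threading each sampling knob through LLMStream as its own attribute, the stream now splats one `extra_kwargs` dict into the request config. A sketch of the resulting call with made-up values (the keyword names are the real `GenerateContentConfig` fields used above):

    from google.genai import types

    extra = {"temperature": 0.8, "max_output_tokens": 1024, "top_p": 0.95}

    config = types.GenerateContentConfig(
        system_instruction="You are a helpful voice assistant.",
        **extra,  # everything chat() collected: sampling params, tools, tool_config
    )
    # stream = await client.aio.models.generate_content_stream(
    #     model="gemini-2.0-flash-001", contents=turns, config=config
    # )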
@@ -311,8 +276,8 @@ class LLMStream(llm.LLMStream):
                         "gemini llm: there are multiple candidates in the response, returning response from the first one."
                     )
 
-                for index, part in enumerate(response.candidates[0].content.parts):
-                    chat_chunk = self._parse_part(request_id, index, part)
+                for part in response.candidates[0].content.parts:
+                    chat_chunk = self._parse_part(request_id, part)
                     if chat_chunk is not None:
                         retryable = False
                         self._event_ch.send_nowait(chat_chunk)
@@ -321,7 +286,7 @@ class LLMStream(llm.LLMStream):
                 usage = response.usage_metadata
                 self._event_ch.send_nowait(
                     llm.ChatChunk(
-                        request_id=request_id,
+                        id=request_id,
                         usage=llm.CompletionUsage(
                             completion_tokens=usage.candidates_token_count or 0,
                             prompt_tokens=usage.prompt_token_count or 0,
@@ -329,6 +294,7 @@ class LLMStream(llm.LLMStream):
                         ),
                     )
                 )
+
         except ClientError as e:
             raise APIStatusError(
                 "gemini llm: client error",
@@ -359,61 +325,25 @@ class LLMStream(llm.LLMStream):
                 retryable=retryable,
             ) from e
 
-    def _parse_part(
-        self, id: str, index: int, part: types.Part
-    ) -> llm.ChatChunk | None:
+    def _parse_part(self, id: str, part: types.Part) -> llm.ChatChunk | None:
         if part.function_call:
-            return self._try_build_function(id, index, part)
-
-        return llm.ChatChunk(
-            request_id=id,
-            choices=[
-                llm.Choice(
-                    delta=llm.ChoiceDelta(content=part.text, role="assistant"),
-                    index=index,
-                )
-            ],
-        )
-
-    def _try_build_function(
-        self, id: str, index: int, part: types.Part
-    ) -> llm.ChatChunk | None:
-        if part.function_call is None:
-            logger.warning("gemini llm: no function call in the response")
-            return None
-
-        if part.function_call.name is None:
-            logger.warning("gemini llm: no function name in the response")
-            return None
-
-        if part.function_call.id is None:
-            part.function_call.id = utils.shortuuid()
-
-        if self._fnc_ctx is None:
-            logger.warning(
-                "google stream tried to run function without function context"
-            )
-            return None
-
-        fnc_info = _create_ai_function_info(
-            self._fnc_ctx,
-            part.function_call.id,
-            part.function_call.name,
-            json.dumps(part.function_call.args),
-        )
-
-        self._function_calls_info.append(fnc_info)
+            chat_chunk = llm.ChatChunk(
+                id=id,
+                delta=llm.ChoiceDelta(
+                    role="assistant",
+                    tool_calls=[
+                        llm.FunctionToolCall(
+                            arguments=json.dumps(part.function_call.args),
+                            name=part.function_call.name,
+                            call_id=part.function_call.id or utils.shortuuid("function_call_"),
+                        )
+                    ],
+                    content=part.text,
+                ),
+            )
+            return chat_chunk
 
         return llm.ChatChunk(
-            request_id=id,
-            choices=[
-                llm.Choice(
-                    delta=llm.ChoiceDelta(
-                        role="assistant",
-                        tool_calls=[fnc_info],
-                        content=part.text,
-                    ),
-                    index=index,
-                )
-            ],
+            id=id,
+            delta=llm.ChoiceDelta(content=part.text, role="assistant"),
         )
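Net effect of the llm.py changes on consumers: the 1.0 chunk is flat (`ChatChunk(id=..., delta=...)` instead of `ChatChunk(request_id=..., choices=[...])`), and function calls arrive as `FunctionToolCall` deltas. A hypothetical consumer sketch, assuming `delta` is unset on usage-only chunks as in the code above:

    async def consume(stream) -> None:
        async for chunk in stream:  # llm.ChatChunk
            if chunk.delta is None:
                continue  # usage-only chunk
            if chunk.delta.content:
                print(chunk.delta.content, end="")
            for call in chunk.delta.tool_calls or []:
                print(f"\n[tool call] {call.name}({call.arguments})")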
livekit/plugins/google/stt.py
CHANGED
@@ -19,8 +19,14 @@ import dataclasses
 import time
 import weakref
 from dataclasses import dataclass
-from typing import Callable, List, Union
+from typing import Callable, Union
 
+from google.api_core.client_options import ClientOptions
+from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
+from google.auth import default as gauth_default
+from google.auth.exceptions import DefaultCredentialsError
+from google.cloud.speech_v2 import SpeechAsyncClient
+from google.cloud.speech_v2.types import cloud_speech
 from livekit import rtc
 from livekit.agents import (
     DEFAULT_API_CONNECT_OPTIONS,
@@ -32,18 +38,11 @@ from livekit.agents import (
     utils,
 )
 
-from google.api_core.client_options import ClientOptions
-from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
-from google.auth import default as gauth_default
-from google.auth.exceptions import DefaultCredentialsError
-from google.cloud.speech_v2 import SpeechAsyncClient
-from google.cloud.speech_v2.types import cloud_speech
-
 from .log import logger
 from .models import SpeechLanguages, SpeechModels
 
 LgType = Union[SpeechLanguages, str]
-LanguageCode = Union[LgType, List[LgType]]
+LanguageCode = Union[LgType, list[LgType]]
 
 # Google STT has a timeout of 5 mins, we'll attempt to restart the session
 # before that timeout is reached
@@ -56,14 +55,14 @@ _min_confidence = 0.65
 # This class is only be used internally to encapsulate the options
 @dataclass
 class STTOptions:
-    languages: List[LgType]
+    languages: list[LgType]
     detect_language: bool
     interim_results: bool
     punctuate: bool
     spoken_punctuation: bool
     model: SpeechModels | str
     sample_rate: int
-    keywords: List[tuple[str, float]] | None
+    keywords: list[tuple[str, float]] | None
 
     def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
         if self.keywords:
@@ -72,9 +71,7 @@ class STTOptions:
                 cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
                     inline_phrase_set=cloud_speech.PhraseSet(
                         phrases=[
-                            cloud_speech.PhraseSet.Phrase(
-                                value=keyword, boost=boost
-                            )
+                            cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
                             for keyword, boost in self.keywords
                         ]
                     )
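build_adaptation() turns (keyword, boost) pairs into an inline phrase set. The equivalent standalone construction with the real cloud_speech types, for an illustrative keyword list:

    from google.cloud.speech_v2.types import cloud_speech

    keywords = [("LiveKit", 15.0), ("Gemini", 10.0)]  # (phrase, boost) pairs

    adaptation = cloud_speech.SpeechAdaptation(
        phrase_sets=[
            cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
                inline_phrase_set=cloud_speech.PhraseSet(
                    phrases=[
                        cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
                        for keyword, boost in keywords
                    ]
                )
            )
        ]
    )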
@@ -98,7 +95,7 @@ class STT(stt.STT):
         sample_rate: int = 16000,
         credentials_info: dict | None = None,
         credentials_file: str | None = None,
-        keywords: List[tuple[str, float]] | None = None,
+        keywords: list[tuple[str, float]] | None = None,
     ):
         """
         Create a new instance of Google STT.
@@ -120,9 +117,7 @@ class STT(stt.STT):
             credentials_file(str): the credentials file to use for recognition (default: None)
             keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
         """
-        super().__init__(
-            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
-        )
+        super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
 
         self._location = location
         self._credentials_info = credentials_info
@@ -163,9 +158,7 @@ class STT(stt.STT):
         client_options = None
         client: SpeechAsyncClient | None = None
         if self._location != "global":
-            client_options = ClientOptions(
-                api_endpoint=f"{self._location}-speech.googleapis.com"
-            )
+            client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
         if self._credentials_info:
             client = SpeechAsyncClient.from_service_account_info(
                 self._credentials_info,
@@ -206,9 +199,7 @@ class STT(stt.STT):
             config.languages = [config.languages]
         elif not config.detect_language:
             if len(config.languages) > 1:
-                logger.warning(
-                    "multiple languages provided, but language detection is disabled"
-                )
+                logger.warning("multiple languages provided, but language detection is disabled")
                 config.languages = [config.languages[0]]
 
         return config
@@ -266,7 +257,7 @@ class STT(stt.STT):
         *,
         language: SpeechLanguages | str | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-    ) -> "SpeechStream":
+    ) -> SpeechStream:
         config = self._sanitize_options(language=language)
         stream = SpeechStream(
             stt=self,
@@ -288,7 +279,7 @@ class STT(stt.STT):
         spoken_punctuation: bool | None = None,
         model: SpeechModels | None = None,
         location: str | None = None,
-        keywords: List[tuple[str, float]] | None = None,
+        keywords: list[tuple[str, float]] | None = None,
     ):
         if languages is not None:
             if isinstance(languages, str):
@@ -337,9 +328,7 @@ class SpeechStream(stt.SpeechStream):
         recognizer_cb: Callable[[SpeechAsyncClient], str],
         config: STTOptions,
     ) -> None:
-        super().__init__(
-            stt=stt, conn_options=conn_options, sample_rate=config.sample_rate
-        )
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)
 
         self._pool = pool
         self._recognizer_cb = recognizer_cb
@@ -356,7 +345,7 @@ class SpeechStream(stt.SpeechStream):
         punctuate: bool | None = None,
         spoken_punctuation: bool | None = None,
         model: SpeechModels | None = None,
-        keywords: List[tuple[str, float]] | None = None,
+        keywords: list[tuple[str, float]] | None = None,
     ):
         if languages is not None:
             if isinstance(languages, str):
@@ -380,9 +369,7 @@ class SpeechStream(stt.SpeechStream):
     async def _run(self) -> None:
         # google requires a async generator when calling streaming_recognize
         # this function basically convert the queue into a async generator
-        async def input_generator(
-            client: SpeechAsyncClient, should_stop: asyncio.Event
-        ):
+        async def input_generator(client: SpeechAsyncClient, should_stop: asyncio.Event):
             try:
                 # first request should contain the config
                 yield cloud_speech.StreamingRecognizeRequest(
@@ -398,14 +385,10 @@ class SpeechStream(stt.SpeechStream):
                         return
 
                     if isinstance(frame, rtc.AudioFrame):
-                        yield cloud_speech.StreamingRecognizeRequest(
-                            audio=frame.data.tobytes()
-                        )
+                        yield cloud_speech.StreamingRecognizeRequest(audio=frame.data.tobytes())
 
             except Exception:
-                logger.exception(
-                    "an error occurred while streaming input to google STT"
-                )
+                logger.exception("an error occurred while streaming input to google STT")
 
         async def process_stream(client: SpeechAsyncClient, stream):
             has_started = False
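The pattern behind input_generator: the v2 streaming_recognize call takes a request iterator whose first element carries the recognizer and config, with every later element carrying only audio bytes. A trimmed sketch (`recognizer`, `streaming_config`, and `frames` are placeholders):

    from google.cloud.speech_v2.types import cloud_speech

    async def requests(recognizer: str, streaming_config, frames):
        # first request must carry the config; the rest carry only audio
        yield cloud_speech.StreamingRecognizeRequest(
            recognizer=recognizer, streaming_config=streaming_config
        )
        async for frame in frames:  # e.g. rtc.AudioFrame instances
            yield cloud_speech.StreamingRecognizeRequest(audio=frame.data.tobytes())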
@@ -442,19 +425,14 @@ class SpeechStream(stt.SpeechStream):
                             alternatives=[speech_data],
                         )
                     )
-                    if (
-                        time.time() - self._session_connected_at
-                        > _max_session_duration
-                    ):
+                    if time.time() - self._session_connected_at > _max_session_duration:
                         logger.debug(
                             "Google STT maximum connection time reached. Reconnecting..."
                         )
                         self._pool.remove(client)
                         if has_started:
                             self._event_ch.send_nowait(
-                                stt.SpeechEvent(
-                                    type=stt.SpeechEventType.END_OF_SPEECH
-                                )
+                                stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                             )
                             has_started = False
                         self._reconnect_event.set()
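Context for the block above: Google closes streaming sessions after roughly five minutes, so the stream records when it connected and rotates the connection early. The gist, reduced to a sketch with made-up names and an assumed threshold:

    import time

    _max_session_duration = 240  # seconds; assumed value, below Google's ~5 min cap

    class Session:
        def __init__(self) -> None:
            self.connected_at = time.time()

        def should_reconnect(self) -> bool:
            # rotate the stream before the server force-closes it,
            # so no utterance is cut off by a hard disconnect
            return time.time() - self.connected_at > _max_session_duration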
@@ -499,12 +477,8 @@ class SpeechStream(stt.SpeechStream):
                 )
                 self._session_connected_at = time.time()
 
-                process_stream_task = asyncio.create_task(
-                    process_stream(client, stream)
-                )
-                wait_reconnect_task = asyncio.create_task(
-                    self._reconnect_event.wait()
-                )
+                process_stream_task = asyncio.create_task(process_stream(client, stream))
+                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
 
                 try:
                     done, _ = await asyncio.wait(
@@ -518,9 +492,7 @@ class SpeechStream(stt.SpeechStream):
                             break
                     self._reconnect_event.clear()
                 finally:
-                    await utils.aio.gracefully_cancel(
-                        process_stream_task, wait_reconnect_task
-                    )
+                    await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
                 should_stop.set()
         except DeadlineExceeded:
             raise APITimeoutError()
@@ -581,8 +553,6 @@ def _streaming_recognize_response_to_speech_data(
     if text == "":
         return None
 
-    data = stt.SpeechData(
-        language=lg, start_time=0, end_time=0, confidence=confidence, text=text
-    )
+    data = stt.SpeechData(language=lg, start_time=0, end_time=0, confidence=confidence, text=text)
 
     return data