livekit-plugins-google 1.0.23__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +3 -2
- livekit/plugins/google/beta/realtime/realtime_api.py +296 -118
- livekit/plugins/google/llm.py +60 -27
- livekit/plugins/google/stt.py +19 -12
- livekit/plugins/google/tools.py +11 -0
- livekit/plugins/google/tts.py +109 -136
- livekit/plugins/google/utils.py +39 -88
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-1.0.23.dist-info → livekit_plugins_google-1.1.0.dist-info}/METADATA +2 -2
- livekit_plugins_google-1.1.0.dist-info/RECORD +17 -0
- livekit_plugins_google-1.0.23.dist-info/RECORD +0 -16
- {livekit_plugins_google-1.0.23.dist-info → livekit_plugins_google-1.1.0.dist-info}/WHEEL +0 -0
livekit/plugins/google/llm.py
CHANGED
```diff
@@ -20,13 +20,17 @@ import os
 from dataclasses import dataclass
 from typing import Any, cast

-from google import genai
 from google.auth._default_async import default_async
-from google.genai import types
+from google.genai import Client, types
 from google.genai.errors import APIError, ClientError, ServerError
 from livekit.agents import APIConnectionError, APIStatusError, llm, utils
-from livekit.agents.llm import FunctionTool, ToolChoice, utils as llm_utils
-from livekit.agents.llm.tool_context import
+from livekit.agents.llm import FunctionTool, RawFunctionTool, ToolChoice, utils as llm_utils
+from livekit.agents.llm.tool_context import (
+    get_function_info,
+    get_raw_function_info,
+    is_function_tool,
+    is_raw_function_tool,
+)
 from livekit.agents.types import (
     DEFAULT_API_CONNECT_OPTIONS,
     NOT_GIVEN,
```
```diff
@@ -37,7 +41,8 @@ from livekit.agents.utils import is_given

 from .log import logger
 from .models import ChatModels
-from .
+from .tools import _LLMTool
+from .utils import create_tools_config, to_fnc_ctx, to_response_format


 @dataclass
```
```diff
@@ -54,6 +59,7 @@ class _LLMOptions:
     presence_penalty: NotGivenOr[float]
     frequency_penalty: NotGivenOr[float]
     thinking_config: NotGivenOr[types.ThinkingConfigOrDict]
+    gemini_tools: NotGivenOr[list[_LLMTool]]


 class LLM(llm.LLM):
```
```diff
@@ -73,6 +79,7 @@ class LLM(llm.LLM):
         frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
         thinking_config: NotGivenOr[types.ThinkingConfigOrDict] = NOT_GIVEN,
+        gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
     ) -> None:
         """
         Create a new instance of Google GenAI LLM.
```
```diff
@@ -98,10 +105,11 @@ class LLM(llm.LLM):
             frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
             tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
             thinking_config (ThinkingConfigOrDict, optional): The thinking configuration for response generation. Defaults to None.
+            gemini_tools (list[LLMTool], optional): The Gemini-specific tools to use for the session.
         """ # noqa: E501
         super().__init__()
         gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
-        gcp_location = (
+        gcp_location: str | None = (
             location
             if is_given(location)
             else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
```
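The new `gemini_tools` option threads provider-native tool declarations (for example, Google Search grounding) through to the session. A minimal usage sketch, assuming `_LLMTool` from the new `tools.py` (whose body is not shown in this diff) accepts `google.genai` declarations such as `types.GoogleSearch`, and with a hypothetical model name:

```python
# Hedged sketch: assumes _LLMTool covers google.genai declarations like
# types.GoogleSearch; tools.py itself is not shown in this diff.
from google.genai import types
from livekit.plugins import google

session_llm = google.LLM(
    model="gemini-2.0-flash-001",  # hypothetical model name
    gemini_tools=[types.GoogleSearch()],  # used by every chat() call unless overridden
)
```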
```diff
@@ -115,7 +123,7 @@ class LLM(llm.LLM):

         if use_vertexai:
             if not gcp_project:
-                _, gcp_project = default_async(
+                _, gcp_project = default_async(  # type: ignore
                     scopes=["https://www.googleapis.com/auth/cloud-platform"]
                 )
             gemini_api_key = None  # VertexAI does not require an API key
```
```diff
@@ -157,8 +165,9 @@ class LLM(llm.LLM):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             thinking_config=thinking_config,
+            gemini_tools=gemini_tools,
         )
-        self._client =
+        self._client = Client(
             api_key=gemini_api_key,
             vertexai=use_vertexai,
             project=gcp_project,
```
```diff
@@ -169,7 +178,7 @@ class LLM(llm.LLM):
         self,
         *,
         chat_ctx: llm.ChatContext,
-        tools: list[FunctionTool] | None = None,
+        tools: list[FunctionTool | RawFunctionTool] | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
         parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
```
```diff
@@ -177,50 +186,58 @@ class LLM(llm.LLM):
             types.SchemaUnion | type[llm_utils.ResponseFormatT]
         ] = NOT_GIVEN,
         extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+        gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
     ) -> LLMStream:
         extra = {}

         if is_given(extra_kwargs):
             extra.update(extra_kwargs)

-        tool_choice =
+        tool_choice = (
+            cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
+        )
         if is_given(tool_choice):
             gemini_tool_choice: types.ToolConfig
             if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
                 gemini_tool_choice = types.ToolConfig(
                     function_calling_config=types.FunctionCallingConfig(
-                        mode=
+                        mode=types.FunctionCallingConfigMode.ANY,
                         allowed_function_names=[tool_choice["function"]["name"]],
                     )
                 )
                 extra["tool_config"] = gemini_tool_choice
             elif tool_choice == "required":
+                tool_names = []
+                for tool in tools or []:
+                    if is_function_tool(tool):
+                        tool_names.append(get_function_info(tool).name)
+                    elif is_raw_function_tool(tool):
+                        tool_names.append(get_raw_function_info(tool).name)
+
                 gemini_tool_choice = types.ToolConfig(
                     function_calling_config=types.FunctionCallingConfig(
-                        mode=
-                        allowed_function_names=
-                        if tools
-                        else None,
+                        mode=types.FunctionCallingConfigMode.ANY,
+                        allowed_function_names=tool_names or None,
                     )
                 )
                 extra["tool_config"] = gemini_tool_choice
             elif tool_choice == "auto":
                 gemini_tool_choice = types.ToolConfig(
                     function_calling_config=types.FunctionCallingConfig(
-                        mode=
+                        mode=types.FunctionCallingConfigMode.AUTO,
                     )
                 )
                 extra["tool_config"] = gemini_tool_choice
             elif tool_choice == "none":
                 gemini_tool_choice = types.ToolConfig(
                     function_calling_config=types.FunctionCallingConfig(
-                        mode=
+                        mode=types.FunctionCallingConfigMode.NONE,
                     )
                 )
                 extra["tool_config"] = gemini_tool_choice

         if is_given(response_format):
-            extra["response_schema"] = to_response_format(response_format)
+            extra["response_schema"] = to_response_format(response_format)  # type: ignore
             extra["response_mime_type"] = "application/json"

         if is_given(self._opts.temperature):
```
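For reference, the `tool_choice="required"` branch above now collects names through `get_function_info`/`get_raw_function_info` and builds a config equivalent to this sketch (the function names are hypothetical stand-ins for the registered tools):

```python
# Equivalent ToolConfig for tool_choice="required"; "get_weather" and
# "get_time" are hypothetical registered tool names.
from google.genai import types

tool_config = types.ToolConfig(
    function_calling_config=types.FunctionCallingConfig(
        mode=types.FunctionCallingConfigMode.ANY,  # "required" maps to ANY
        allowed_function_names=["get_weather", "get_time"],  # None when no tools
    )
)
```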
```diff
@@ -240,6 +257,8 @@ class LLM(llm.LLM):
         if is_given(self._opts.thinking_config):
             extra["thinking_config"] = self._opts.thinking_config

+        gemini_tools = gemini_tools if is_given(gemini_tools) else self._opts.gemini_tools
+
         return LLMStream(
             self,
             client=self._client,
```
```diff
@@ -247,6 +266,7 @@ class LLM(llm.LLM):
             chat_ctx=chat_ctx,
             tools=tools or [],
             conn_options=conn_options,
+            gemini_tools=gemini_tools,
             extra_kwargs=extra,
         )

```
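With the `is_given()` fallback in the previous hunk, `gemini_tools` passed to `chat()` now take precedence over the constructor value. A hedged sketch of the call shape (the `ChatContext.empty()` setup is abbreviated and assumed):

```python
# Hedged sketch of the per-call override; ChatContext setup is assumed.
from google.genai import types
from livekit.agents.llm import ChatContext
from livekit.plugins import google

session_llm = google.LLM(gemini_tools=[types.GoogleSearch()])
chat_ctx = ChatContext.empty()

# This call's tools win over the constructor default via the is_given() fallback.
stream = session_llm.chat(chat_ctx=chat_ctx, gemini_tools=[types.GoogleSearch()])
```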
```diff
@@ -256,32 +276,45 @@ class LLMStream(llm.LLMStream):
         self,
         llm: LLM,
         *,
-        client:
+        client: Client,
         model: str | ChatModels,
         chat_ctx: llm.ChatContext,
         conn_options: APIConnectOptions,
-        tools: list[FunctionTool],
+        tools: list[FunctionTool | RawFunctionTool],
         extra_kwargs: dict[str, Any],
+        gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
     ) -> None:
         super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
         self._client = client
         self._model = model
         self._llm: LLM = llm
         self._extra_kwargs = extra_kwargs
+        self._gemini_tools = gemini_tools

     async def _run(self) -> None:
         retryable = True
         request_id = utils.shortuuid()

         try:
-
+            turns_dict, extra_data = self._chat_ctx.to_provider_format(format="google")
+            turns = [types.Content.model_validate(turn) for turn in turns_dict]
             function_declarations = to_fnc_ctx(self._tools)
-
-
-
-
+            tools_config = create_tools_config(
+                function_tools=function_declarations,
+                gemini_tools=self._gemini_tools if is_given(self._gemini_tools) else None,
+            )
+            if tools_config:
+                self._extra_kwargs["tools"] = tools_config
+
             config = types.GenerateContentConfig(
-                system_instruction=
+                system_instruction=(
+                    [types.Part(text=content) for content in extra_data.system_messages]
+                    if extra_data.system_messages
+                    else None
+                ),
+                http_options=types.HttpOptions(
+                    timeout=int(self._conn_options.timeout * 1000),
+                ),
                 **self._extra_kwargs,
             )

```
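Two details in the new `_run` are worth noting: the chat context is converted with `to_provider_format(format="google")` and re-validated into `types.Content`, and the request timeout now rides along in `GenerateContentConfig.http_options`. The `google.genai` `HttpOptions.timeout` is expressed in milliseconds while the connection options use seconds, hence the `* 1000`:

```python
# Standalone illustration of the timeout conversion above; the 10.0 s
# value is a hypothetical APIConnectOptions.timeout.
from google.genai import types

conn_timeout_s = 10.0
http_options = types.HttpOptions(timeout=int(conn_timeout_s * 1000))
assert http_options.timeout == 10_000  # milliseconds
```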
```diff
@@ -371,7 +404,7 @@ class LLMStream(llm.LLMStream):
                     tool_calls=[
                         llm.FunctionToolCall(
                             arguments=json.dumps(part.function_call.args),
-                            name=part.function_call.name,
+                            name=part.function_call.name,  # type: ignore
                             call_id=part.function_call.id or utils.shortuuid("function_call_"),
                         )
                     ],
```
livekit/plugins/google/stt.py
CHANGED
```diff
@@ -18,8 +18,9 @@ import asyncio
 import dataclasses
 import time
 import weakref
+from collections.abc import AsyncGenerator, AsyncIterable
 from dataclasses import dataclass
-from typing import Callable, Union
+from typing import Callable, Union, cast

 from google.api_core.client_options import ClientOptions
 from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
```
```diff
@@ -140,7 +141,7 @@ class STT(stt.STT):

         if not is_given(credentials_file) and not is_given(credentials_info):
             try:
-                gauth_default()
+                gauth_default()  # type: ignore
             except DefaultCredentialsError:
                 raise ValueError(
                     "Application default credentials must be available "
```
```diff
@@ -168,9 +169,10 @@ class STT(stt.STT):
             connect_cb=self._create_client,
         )

-    async def _create_client(self) -> SpeechAsyncClient:
+    async def _create_client(self, timeout: float) -> SpeechAsyncClient:
         # Add support for passing a specific location that matches recognizer
         # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+        # TODO(long): how to set timeout?
         client_options = None
         client: SpeechAsyncClient | None = None
         if self._location != "global":
```
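`_create_client` now accepts the `timeout` that the connection pool forwards to its `connect_cb`, even though the value is not yet applied to the client (see the TODO above). A sketch of the callback contract this change assumes:

```python
# Sketch of the connect_cb contract implied by the pool changes: the
# callback now receives a per-connection timeout in seconds.
from google.cloud.speech_v2 import SpeechAsyncClient


async def create_client(timeout: float) -> SpeechAsyncClient:
    # timeout is accepted to satisfy the pool's callback signature; the
    # diff does not yet thread it into the client options.
    return SpeechAsyncClient()
```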
```diff
@@ -198,7 +200,7 @@ class STT(stt.STT):
         except AttributeError:
             from google.auth import default as ga_default

-            _, project_id = ga_default()
+            _, project_id = ga_default()  # type: ignore
         return f"projects/{project_id}/locations/{self._location}/recognizers/_"

     def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
```
```diff
@@ -243,7 +245,7 @@ class STT(stt.STT):
         )

         try:
-            async with self._pool.connection() as client:
+            async with self._pool.connection(timeout=conn_options.timeout) as client:
                 raw = await client.recognize(
                     cloud_speech.RecognizeRequest(
                         recognizer=self._get_recognizer(client),
```
```diff
@@ -289,11 +291,11 @@ class STT(stt.STT):
         model: NotGivenOr[SpeechModels] = NOT_GIVEN,
         location: NotGivenOr[str] = NOT_GIVEN,
         keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
-    ):
+    ) -> None:
         if is_given(languages):
             if isinstance(languages, str):
                 languages = [languages]
-            self._config.languages = languages
+            self._config.languages = cast(list[LgType], languages)
         if is_given(detect_language):
             self._config.detect_language = detect_language
         if is_given(interim_results):
```
```diff
@@ -356,11 +358,11 @@ class SpeechStream(stt.SpeechStream):
         model: NotGivenOr[SpeechModels] = NOT_GIVEN,
         min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
         keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
-    ):
+    ) -> None:
         if is_given(languages):
             if isinstance(languages, str):
                 languages = [languages]
-            self._config.languages = languages
+            self._config.languages = cast(list[LgType], languages)
         if is_given(detect_language):
             self._config.detect_language = detect_language
         if is_given(interim_results):
```
```diff
@@ -381,7 +383,9 @@ class SpeechStream(stt.SpeechStream):
     async def _run(self) -> None:
         # google requires a async generator when calling streaming_recognize
         # this function basically convert the queue into a async generator
-        async def input_generator(
+        async def input_generator(
+            client: SpeechAsyncClient, should_stop: asyncio.Event
+        ) -> AsyncGenerator[cloud_speech.StreamingRecognizeRequest, None]:
             try:
                 # first request should contain the config
                 yield cloud_speech.StreamingRecognizeRequest(
```
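`streaming_recognize` consumes an async iterable of requests in which the first message carries the recognizer config and the rest carry audio. A self-contained sketch of the queue-to-generator pattern that the newly typed signature describes (the queue, recognizer, and config names here are hypothetical):

```python
# Minimal sketch of the queue-to-async-generator pattern; audio_queue,
# recognizer, and streaming_config are hypothetical stand-ins.
import asyncio
from collections.abc import AsyncGenerator

from google.cloud.speech_v2.types import cloud_speech


async def input_generator(
    recognizer: str,
    streaming_config: cloud_speech.StreamingRecognitionConfig,
    audio_queue: asyncio.Queue[bytes | None],
) -> AsyncGenerator[cloud_speech.StreamingRecognizeRequest, None]:
    # First request carries the recognizer and config only.
    yield cloud_speech.StreamingRecognizeRequest(
        recognizer=recognizer, streaming_config=streaming_config
    )
    # Subsequent requests carry raw audio; None acts as a stop sentinel.
    while (frame := await audio_queue.get()) is not None:
        yield cloud_speech.StreamingRecognizeRequest(audio=frame)
```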
```diff
@@ -402,7 +406,10 @@ class SpeechStream(stt.SpeechStream):
             except Exception:
                 logger.exception("an error occurred while streaming input to google STT")

-        async def process_stream(
+        async def process_stream(
+            client: SpeechAsyncClient,
+            stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse],
+        ) -> None:
             has_started = False
             async for resp in stream:
                 if (
```
```diff
@@ -464,7 +471,7 @@ class SpeechStream(stt.SpeechStream):

         while True:
             try:
-                async with self._pool.connection() as client:
+                async with self._pool.connection(timeout=self._conn_options.timeout) as client:
                     self._streaming_config = cloud_speech.StreamingRecognitionConfig(
                         config=cloud_speech.RecognitionConfig(
                             explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
```