livekit-plugins-google 0.9.0__tar.gz → 0.10.0__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/PKG-INFO +13 -3
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/__init__.py +2 -1
- livekit_plugins_google-0.10.0/livekit/plugins/google/_utils.py +202 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/beta/realtime/__init__.py +0 -2
- livekit_plugins_google-0.10.0/livekit/plugins/google/beta/realtime/api_proto.py +24 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/beta/realtime/realtime_api.py +168 -42
- livekit_plugins_google-0.10.0/livekit/plugins/google/beta/realtime/transcriber.py +173 -0
- livekit_plugins_google-0.10.0/livekit/plugins/google/llm.py +414 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/models.py +2 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/stt.py +64 -10
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit_plugins_google.egg-info/PKG-INFO +13 -3
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit_plugins_google.egg-info/SOURCES.txt +3 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit_plugins_google.egg-info/requires.txt +1 -1
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/setup.py +1 -1
- livekit_plugins_google-0.9.0/livekit/plugins/google/beta/realtime/api_proto.py +0 -79
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/README.md +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/beta/__init__.py +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/log.py +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/py.typed +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/tts.py +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit_plugins_google.egg-info/dependency_links.txt +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit_plugins_google.egg-info/top_level.txt +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/pyproject.toml +0 -0
- {livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/setup.cfg +0 -0
{livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: livekit-plugins-google
-Version: 0.9.0
+Version: 0.10.0
 Summary: Agent Framework plugin for services from Google Cloud
 Home-page: https://github.com/livekit/agents
 License: Apache-2.0
@@ -22,8 +22,18 @@ Description-Content-Type: text/markdown
 Requires-Dist: google-auth<3,>=2
 Requires-Dist: google-cloud-speech<3,>=2
 Requires-Dist: google-cloud-texttospeech<3,>=2
-Requires-Dist: google-genai
+Requires-Dist: google-genai==0.5.0
 Requires-Dist: livekit-agents>=0.12.3
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: project-url
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # LiveKit Plugins Google
 
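Two metadata changes stand out: the bump to Metadata-Version 2.2 with the new Dynamic: fields (which typically indicates a newer setuptools build backend), and an exact pin on google-genai. To confirm the resolved versions at runtime, a minimal sketch using only the standard library:

    from importlib.metadata import version

    # 0.10.0 declares Requires-Dist: google-genai==0.5.0, so after a clean
    # install both of these should report the pinned versions.
    print(version("livekit-plugins-google"))  # expected: 0.10.0
    print(version("google-genai"))            # expected: 0.5.0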
{livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/__init__.py
RENAMED
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 from . import beta
+from .llm import LLM
 from .stt import STT, SpeechStream
 from .tts import TTS
 from .version import __version__
 
-__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta"]
+__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta", "LLM"]
 from livekit.agents import Plugin
 
 from .log import logger
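The new export corresponds to the chat-completions LLM added in llm.py (+414 lines, whose body is not part of this excerpt); a minimal import sketch:

    from livekit.plugins import google

    # "LLM" now sits alongside "STT", "TTS", and "SpeechStream".
    assert "LLM" in google.__all__
    chat_llm_cls = google.LLM  # constructor arguments are defined in the new llm.py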
livekit_plugins_google-0.10.0/livekit/plugins/google/_utils.py
ADDED
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import base64
+import inspect
+import json
+from typing import Any, Dict, List, Optional, get_args, get_origin
+
+from livekit import rtc
+from livekit.agents import llm, utils
+from livekit.agents.llm.function_context import _is_optional_type
+
+from google.genai import types
+
+JSON_SCHEMA_TYPE_MAP: dict[type, types.Type] = {
+    str: "STRING",
+    int: "INTEGER",
+    float: "NUMBER",
+    bool: "BOOLEAN",
+    dict: "OBJECT",
+    list: "ARRAY",
+}
+
+__all__ = ["_build_gemini_ctx", "_build_tools"]
+
+
+def _build_parameters(arguments: Dict[str, Any]) -> types.Schema | None:
+    properties: Dict[str, types.Schema] = {}
+    required: List[str] = []
+
+    for arg_name, arg_info in arguments.items():
+        prop = types.Schema()
+        if arg_info.description:
+            prop.description = arg_info.description
+
+        _, py_type = _is_optional_type(arg_info.type)
+        origin = get_origin(py_type)
+        if origin is list:
+            item_type = get_args(py_type)[0]
+            if item_type not in JSON_SCHEMA_TYPE_MAP:
+                raise ValueError(f"Unsupported type: {item_type}")
+            prop.type = "ARRAY"
+            prop.items = types.Schema(type=JSON_SCHEMA_TYPE_MAP[item_type])
+
+            if arg_info.choices:
+                prop.items.enum = arg_info.choices
+        else:
+            if py_type not in JSON_SCHEMA_TYPE_MAP:
+                raise ValueError(f"Unsupported type: {py_type}")
+
+            prop.type = JSON_SCHEMA_TYPE_MAP[py_type]
+
+            if arg_info.choices:
+                prop.enum = arg_info.choices
+                if py_type is int:
+                    raise ValueError(
+                        f"Parameter '{arg_info.name}' uses integer choices, not supported by this model."
+                    )
+
+        properties[arg_name] = prop
+
+        if arg_info.default is inspect.Parameter.empty:
+            required.append(arg_name)
+
+    if properties:
+        parameters = types.Schema(type="OBJECT", properties=properties)
+        if required:
+            parameters.required = required
+
+        return parameters
+
+    return None
+
+
+def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclaration]:
+    function_declarations: List[types.FunctionDeclaration] = []
+    for fnc_info in fnc_ctx.ai_functions.values():
+        parameters = _build_parameters(fnc_info.arguments)
+
+        func_decl = types.FunctionDeclaration(
+            name=fnc_info.name,
+            description=fnc_info.description,
+            parameters=parameters,
+        )
+
+        function_declarations.append(func_decl)
+    return function_declarations
+
+
+def _build_gemini_ctx(
+    chat_ctx: llm.ChatContext, cache_key: Any
+) -> tuple[list[types.Content], Optional[types.Content]]:
+    turns: list[types.Content] = []
+    system_instruction: Optional[types.Content] = None
+    current_role: Optional[str] = None
+    parts: list[types.Part] = []
+
+    for msg in chat_ctx.messages:
+        if msg.role == "system":
+            if isinstance(msg.content, str):
+                system_instruction = types.Content(parts=[types.Part(text=msg.content)])
+            continue
+
+        if msg.role == "assistant":
+            role = "model"
+        elif msg.role == "tool":
+            role = "user"
+        else:
+            role = "user"
+
+        # If role changed, finalize previous parts into a turn
+        if role != current_role:
+            if current_role is not None and parts:
+                turns.append(types.Content(role=current_role, parts=parts))
+            current_role = role
+            parts = []
+
+        if msg.tool_calls:
+            for fnc in msg.tool_calls:
+                parts.append(
+                    types.Part(
+                        function_call=types.FunctionCall(
+                            id=fnc.tool_call_id,
+                            name=fnc.function_info.name,
+                            args=fnc.arguments,
+                        )
+                    )
+                )
+
+        if msg.role == "tool":
+            if msg.content:
+                if isinstance(msg.content, dict):
+                    parts.append(
+                        types.Part(
+                            function_response=types.FunctionResponse(
+                                id=msg.tool_call_id,
+                                name=msg.name,
+                                response=msg.content,
+                            )
+                        )
+                    )
+                elif isinstance(msg.content, str):
+                    parts.append(
+                        types.Part(
+                            function_response=types.FunctionResponse(
+                                id=msg.tool_call_id,
+                                name=msg.name,
+                                response={"result": msg.content},
+                            )
+                        )
+                    )
+        else:
+            if msg.content:
+                if isinstance(msg.content, str):
+                    parts.append(types.Part(text=msg.content))
+                elif isinstance(msg.content, dict):
+                    parts.append(types.Part(text=json.dumps(msg.content)))
+                elif isinstance(msg.content, list):
+                    for item in msg.content:
+                        if isinstance(item, str):
+                            parts.append(types.Part(text=item))
+                        elif isinstance(item, llm.ChatImage):
+                            parts.append(_build_gemini_image_part(item, cache_key))
+
+    # Finalize last role's parts if any remain
+    if current_role is not None and parts:
+        turns.append(types.Content(role=current_role, parts=parts))
+
+    return turns, system_instruction
+
+
+def _build_gemini_image_part(image: llm.ChatImage, cache_key: Any) -> types.Part:
+    if isinstance(image.image, str):
+        # Check if the string is a Data URL
+        if image.image.startswith("data:image/jpeg;base64,"):
+            # Extract the base64 part after the comma
+            base64_data = image.image.split(",", 1)[1]
+            try:
+                image_bytes = base64.b64decode(base64_data)
+            except Exception as e:
+                raise ValueError("Invalid base64 data in image URL") from e
+
+            return types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
+        else:
+            # Assume it's a regular URL
+            return types.Part.from_uri(file_uri=image.image, mime_type="image/jpeg")
+
+    elif isinstance(image.image, rtc.VideoFrame):
+        if cache_key not in image._cache:
+            opts = utils.images.EncodeOptions()
+            if image.inference_width and image.inference_height:
+                opts.resize_options = utils.images.ResizeOptions(
+                    width=image.inference_width,
+                    height=image.inference_height,
+                    strategy="scale_aspect_fit",
+                )
+            encoded_data = utils.images.encode(image.image, opts)
+            image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
+
+        return types.Part.from_bytes(
+            data=image._cache[cache_key], mime_type="image/jpeg"
+        )
+    raise ValueError(f"Unsupported image type: {type(image.image)}")
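A quick illustration of what _build_gemini_ctx returns for a simple context; the helper is private, so this sketch is illustrative only and assumes a livekit-agents 0.12.x ChatContext:

    from livekit.agents import llm

    from livekit.plugins.google._utils import _build_gemini_ctx

    chat_ctx = llm.ChatContext()
    chat_ctx.append(text="You are a helpful assistant.", role="system")
    chat_ctx.append(text="What's the weather like?", role="user")

    # cache_key only matters when the context contains VideoFrame images
    # (it memoizes their encoded bytes); any hashable value works here.
    turns, system_instruction = _build_gemini_ctx(chat_ctx, cache_key=0)
    # turns -> [Content(role="user", parts=[Part(text="What's the weather like?")])]
    # system_instruction -> Content(parts=[Part(text="You are a helpful assistant.")])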
{livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/beta/realtime/__init__.py
RENAMED
@@ -1,7 +1,6 @@
 from .api_proto import (
     ClientEvents,
     LiveAPIModels,
-    ResponseModality,
     Voice,
 )
 from .realtime_api import RealtimeModel
@@ -10,6 +9,5 @@ __all__ = [
     "RealtimeModel",
     "ClientEvents",
     "LiveAPIModels",
-    "ResponseModality",
     "Voice",
 ]
livekit_plugins_google-0.10.0/livekit/plugins/google/beta/realtime/api_proto.py
ADDED
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from typing import Literal, Sequence, Union
+
+from google.genai import types
+
+from ..._utils import _build_gemini_ctx, _build_tools
+
+LiveAPIModels = Literal["gemini-2.0-flash-exp"]
+
+Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
+
+__all__ = ["_build_tools", "ClientEvents", "_build_gemini_ctx"]
+
+ClientEvents = Union[
+    types.ContentListUnion,
+    types.ContentListUnionDict,
+    types.LiveClientContentOrDict,
+    types.LiveClientRealtimeInput,
+    types.LiveClientRealtimeInputOrDict,
+    types.LiveClientToolResponseOrDict,
+    types.FunctionResponseOrDict,
+    Sequence[types.FunctionResponseOrDict],
+]
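Anything matching this union can be queued on the session's send channel; both of the following construct valid ClientEvents values, a sketch using only types the diff itself exercises:

    from google.genai import types

    # A realtime PCM audio chunk, the shape _push_audio builds:
    audio_event = types.LiveClientRealtimeInput(
        media_chunks=[types.Blob(data=b"\x00\x00", mime_type="audio/pcm")]
    )

    # A tool result, the shape _run_fnc_task sends back:
    tool_event = types.FunctionResponse(
        id="call_1", name="get_weather", response={"result": "sunny"}
    )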
{livekit_plugins_google-0.9.0 → livekit_plugins_google-0.10.0}/livekit/plugins/google/beta/realtime/realtime_api.py
RENAMED
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import asyncio
-import base64
 import json
 import os
 from dataclasses import dataclass
@@ -11,14 +10,22 @@ from livekit import rtc
 from livekit.agents import llm, utils
 from livekit.agents.llm.function_context import _create_ai_function_info
 
-from google import genai
-from google.genai.types import (
+from google import genai
+from google.genai._api_client import HttpOptions
+from google.genai.types import (
+    Blob,
+    Content,
     FunctionResponse,
-    GenerationConfigDict,
+    GenerationConfig,
+    LiveClientContent,
+    LiveClientRealtimeInput,
     LiveClientToolResponse,
-    LiveConnectConfigDict,
+    LiveConnectConfig,
+    Modality,
+    Part,
     PrebuiltVoiceConfig,
     SpeechConfig,
+    Tool,
     VoiceConfig,
 )
 
@@ -26,10 +33,11 @@ from ...log import logger
 from .api_proto import (
     ClientEvents,
     LiveAPIModels,
-    ResponseModality,
     Voice,
+    _build_gemini_ctx,
     _build_tools,
 )
+from .transcriber import TranscriberSession, TranscriptionContent
 
 EventTypes = Literal[
     "start_session",
@@ -39,6 +47,9 @@ EventTypes = Literal[
     "function_calls_collected",
     "function_calls_finished",
     "function_calls_cancelled",
+    "input_speech_transcription_completed",
+    "agent_speech_transcription_completed",
+    "agent_speech_stopped",
 ]
 
 
@@ -55,6 +66,12 @@ class GeminiContent:
     content_type: Literal["text", "audio"]
 
 
+@dataclass
+class InputTranscription:
+    item_id: str
+    transcript: str
+
+
 @dataclass
 class Capabilities:
     supports_truncate: bool
@@ -65,7 +82,7 @@ class ModelOptions:
     model: LiveAPIModels | str
     api_key: str | None
     voice: Voice | str
-    response_modalities: ResponseModality
+    response_modalities: list[Modality] | None
     vertexai: bool
     project: str | None
     location: str | None
@@ -76,18 +93,22 @@ class ModelOptions:
     top_k: int | None
     presence_penalty: float | None
     frequency_penalty: float | None
-    instructions: str
+    instructions: Content | None
+    enable_user_audio_transcription: bool
+    enable_agent_audio_transcription: bool
 
 
 class RealtimeModel:
     def __init__(
         self,
         *,
-        instructions: str = "",
+        instructions: str | None = None,
         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
        api_key: str | None = None,
         voice: Voice | str = "Puck",
-        modalities: ResponseModality = "AUDIO",
+        modalities: list[Modality] = ["AUDIO"],
+        enable_user_audio_transcription: bool = True,
+        enable_agent_audio_transcription: bool = True,
         vertexai: bool = False,
         project: str | None = None,
         location: str | None = None,
@@ -103,15 +124,24 @@ class RealtimeModel:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
 
+        Environment Requirements:
+        - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file.
+          The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
+          `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
+          and the location defaults to "us-central1".
+        - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
+
         Args:
             instructions (str, optional): Initial system instructions for the model. Defaults to "".
-            api_key (str or None, optional):
-            modalities (
+            api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
+            modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
             model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
+            enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True.
+            enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True.
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
-            project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai)
+            project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
             location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
             top_p (float, optional): The top-p value for response generation
@@ -130,21 +160,38 @@ class RealtimeModel:
         self._model = model
         self._loop = loop or asyncio.get_event_loop()
         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
-        self.
-        self.
-
-
-
+        self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
+        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
+        if vertexai:
+            if not self._project or not self._location:
+                raise ValueError(
+                    "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"
+                )
+            self._api_key = None  # VertexAI does not require an API key
+
+        else:
+            self._project = None
+            self._location = None
+            if not self._api_key:
+                raise ValueError(
+                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
+                )
+
+        instructions_content = (
+            Content(parts=[Part(text=instructions)]) if instructions else None
+        )
 
         self._rt_sessions: list[GeminiRealtimeSession] = []
         self._opts = ModelOptions(
             model=model,
-            api_key=
+            api_key=self._api_key,
             voice=voice,
+            enable_user_audio_transcription=enable_user_audio_transcription,
+            enable_agent_audio_transcription=enable_agent_audio_transcription,
             response_modalities=modalities,
             vertexai=vertexai,
-            project=
-            location=
+            project=self._project,
+            location=self._location,
             candidate_count=candidate_count,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
@@ -152,7 +199,7 @@ class RealtimeModel:
             top_k=top_k,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
-            instructions=instructions,
+            instructions=instructions_content,
         )
 
     @property
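The constructor now enforces two mutually exclusive auth paths; a construction sketch for each (the project id is hypothetical):

    from livekit.plugins.google.beta.realtime import RealtimeModel

    # Gemini API path: api_key or the GOOGLE_API_KEY env var must be set.
    model = RealtimeModel(
        instructions="You are a helpful voice assistant.",
        voice="Kore",
        modalities=["AUDIO"],
    )

    # VertexAI path: project/location come from the arguments or from
    # GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION; no API key is used.
    vertex_model = RealtimeModel(
        vertexai=True,
        project="my-gcp-project",  # hypothetical project id
        location="us-central1",
    )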
@@ -208,16 +255,16 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         self._chat_ctx = chat_ctx
         self._fnc_ctx = fnc_ctx
         self._fnc_tasks = utils.aio.TaskSet()
+        self._is_interrupted = False
 
         tools = []
         if self._fnc_ctx is not None:
             functions = _build_tools(self._fnc_ctx)
-            tools.append(
+            tools.append(Tool(function_declarations=functions))
 
-        self._config = LiveConnectConfigDict(
-            model=self._opts.model,
+        self._config = LiveConnectConfig(
             response_modalities=self._opts.response_modalities,
-            generation_config=GenerationConfigDict(
+            generation_config=GenerationConfig(
                 candidate_count=self._opts.candidate_count,
                 temperature=self._opts.temperature,
                 max_output_tokens=self._opts.max_output_tokens,
@@ -237,7 +284,7 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
             tools=tools,
         )
         self._client = genai.Client(
-            http_options=
+            http_options=HttpOptions(api_version="v1alpha"),
             api_key=self._opts.api_key,
             vertexai=self._opts.vertexai,
             project=self._opts.project,
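The switch from dict-based configuration to the pydantic models in google-genai 0.5.0 is the core of these two hunks. The same shape can be built standalone; a sketch mirroring the fields visible above plus the speech config the session presumably assembles from the voice option (those lines fall outside the shown hunks):

    from google.genai.types import (
        GenerationConfig,
        LiveConnectConfig,
        PrebuiltVoiceConfig,
        SpeechConfig,
        VoiceConfig,
    )

    config = LiveConnectConfig(
        response_modalities=["AUDIO"],
        generation_config=GenerationConfig(temperature=0.8),
        speech_config=SpeechConfig(
            voice_config=VoiceConfig(
                prebuilt_voice_config=PrebuiltVoiceConfig(voice_name="Puck")
            )
        ),
    )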
@@ -246,12 +293,22 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         self._main_atask = asyncio.create_task(
             self._main_task(), name="gemini-realtime-session"
         )
-
-
-
-
+        if self._opts.enable_user_audio_transcription:
+            self._transcriber = TranscriberSession(
+                client=self._client, model=self._opts.model
+            )
+            self._transcriber.on("input_speech_done", self._on_input_speech_done)
+        if self._opts.enable_agent_audio_transcription:
+            self._agent_transcriber = TranscriberSession(
+                client=self._client, model=self._opts.model
+            )
+            self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
+        # init dummy task
+        self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
         self._send_ch = utils.aio.Chan[ClientEvents]()
         self._active_response_id = None
+        if chat_ctx:
+            self.generate_reply(chat_ctx)
 
     async def aclose(self) -> None:
         if self._send_ch.closed:
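The two TranscriberSession instances surface their results through the new events; a sketch of wiring handlers onto an existing GeminiRealtimeSession (function names are illustrative):

    def wire_transcription_events(session) -> None:
        # Both events carry an InputTranscription(item_id=..., transcript=...).
        def on_user_transcript(ev) -> None:
            print("user:", ev.transcript)

        def on_agent_transcript(ev) -> None:
            print("agent:", ev.transcript)

        session.on("input_speech_transcription_completed", on_user_transcript)
        session.on("agent_speech_transcription_completed", on_agent_transcript)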
@@ -269,32 +326,97 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         self._fnc_ctx = value
 
     def _push_audio(self, frame: rtc.AudioFrame) -> None:
-
-
+        if self._opts.enable_user_audio_transcription:
+            self._transcriber._push_audio(frame)
+        realtime_input = LiveClientRealtimeInput(
+            media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
+        )
+        self._queue_msg(realtime_input)
 
-    def _queue_msg(self, msg:
+    def _queue_msg(self, msg: ClientEvents) -> None:
         self._send_ch.send_nowait(msg)
 
+    def generate_reply(
+        self,
+        ctx: llm.ChatContext | llm.ChatMessage,
+        turn_complete: bool = True,
+    ) -> None:
+        if isinstance(ctx, llm.ChatMessage) and isinstance(ctx.content, str):
+            new_chat_ctx = llm.ChatContext()
+            new_chat_ctx.append(text=ctx.content, role=ctx.role)
+        elif isinstance(ctx, llm.ChatContext):
+            new_chat_ctx = ctx
+        else:
+            raise ValueError("Invalid chat context")
+        turns, _ = _build_gemini_ctx(new_chat_ctx, id(self))
+        client_content = LiveClientContent(
+            turn_complete=turn_complete,
+            turns=turns,
+        )
+        self._queue_msg(client_content)
+
     def chat_ctx_copy(self) -> llm.ChatContext:
         return self._chat_ctx.copy()
 
     async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
         self._chat_ctx = ctx.copy()
 
+    def cancel_response(self) -> None:
+        raise NotImplementedError("cancel_response is not supported yet")
+
+    def create_response(
+        self,
+        on_duplicate: Literal[
+            "cancel_existing", "cancel_new", "keep_both"
+        ] = "keep_both",
+    ) -> None:
+        raise NotImplementedError("create_response is not supported yet")
+
+    def commit_audio_buffer(self) -> None:
+        raise NotImplementedError("commit_audio_buffer is not supported yet")
+
+    def server_vad_enabled(self) -> bool:
+        return True
+
+    def _on_input_speech_done(self, content: TranscriptionContent) -> None:
+        if content.response_id and content.text:
+            self.emit(
+                "input_speech_transcription_completed",
+                InputTranscription(
+                    item_id=content.response_id,
+                    transcript=content.text,
+                ),
+            )
+
+        # self._chat_ctx.append(text=content.text, role="user")
+        # TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech
+
+    def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
+        if not self._is_interrupted and content.response_id and content.text:
+            self.emit(
+                "agent_speech_transcription_completed",
+                InputTranscription(
+                    item_id=content.response_id,
+                    transcript=content.text,
+                ),
+            )
+        # self._chat_ctx.append(text=content.text, role="assistant")
+
     @utils.log_exceptions(logger=logger)
     async def _main_task(self):
         @utils.log_exceptions(logger=logger)
         async def _send_task():
             async for msg in self._send_ch:
-                await self._session.send(msg)
+                await self._session.send(input=msg)
 
-            await self._session.send(".", end_of_turn=True)
+            await self._session.send(input=".", end_of_turn=True)
 
         @utils.log_exceptions(logger=logger)
         async def _recv_task():
             while True:
                 async for response in self._session.receive():
                     if self._active_response_id is None:
+                        self._is_interrupted = False
                         self._active_response_id = utils.shortuuid()
                         text_stream = utils.aio.Chan[str]()
                         audio_stream = utils.aio.Chan[rtc.AudioFrame]()
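The new generate_reply path also powers the constructor's initial chat_ctx handling; a usage sketch against an existing GeminiRealtimeSession:

    from livekit.agents import llm

    def greet(session) -> None:
        ctx = llm.ChatContext()
        ctx.append(text="Say hello to the user.", role="user")
        # Queues a LiveClientContent with turn_complete=True.
        session.generate_reply(ctx)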
@@ -307,7 +429,7 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                             audio=[],
                             text_stream=text_stream,
                             audio_stream=audio_stream,
-                            content_type=
+                            content_type="audio",
                         )
                         self.emit("response_content_added", content)
 
@@ -326,6 +448,8 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                                     samples_per_channel=len(part.inline_data.data)
                                     // 2,
                                 )
+                                if self._opts.enable_agent_audio_transcription:
+                                    self._agent_transcriber._push_audio(frame)
                                 content.audio_stream.send_nowait(frame)
 
                     if server_content.interrupted or server_content.turn_complete:
@@ -333,10 +457,8 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                             if isinstance(stream, utils.aio.Chan):
                                 stream.close()
 
-
-
-                        elif server_content.turn_complete:
-                            self.emit("response_content_done", content)
+                        self.emit("agent_speech_stopped")
+                        self._is_interrupted = True
 
                         self._active_response_id = None
 
@@ -387,6 +509,10 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         finally:
             await utils.aio.gracefully_cancel(*tasks)
             await self._session.close()
+            if self._opts.enable_user_audio_transcription:
+                await self._transcriber.aclose()
+            if self._opts.enable_agent_audio_transcription:
+                await self._agent_transcriber.aclose()
 
     @utils.log_exceptions(logger=logger)
     async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
@@ -419,6 +545,6 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                 )
             ]
         )
-        await self._session.send(tool_response)
+        await self._session.send(input=tool_response)
 
         self.emit("function_calls_finished", [called_fnc])