livekit-plugins-google 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +2 -1
- livekit/plugins/google/_utils.py +202 -0
- livekit/plugins/google/beta/realtime/__init__.py +0 -2
- livekit/plugins/google/beta/realtime/api_proto.py +5 -60
- livekit/plugins/google/beta/realtime/realtime_api.py +155 -41
- livekit/plugins/google/beta/realtime/transcriber.py +173 -0
- livekit/plugins/google/llm.py +414 -0
- livekit/plugins/google/models.py +2 -0
- livekit/plugins/google/stt.py +3 -3
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-0.9.1.dist-info → livekit_plugins_google-0.10.1.dist-info}/METADATA +2 -2
- livekit_plugins_google-0.10.1.dist-info/RECORD +18 -0
- livekit_plugins_google-0.9.1.dist-info/RECORD +0 -15
- {livekit_plugins_google-0.9.1.dist-info → livekit_plugins_google-0.10.1.dist-info}/WHEEL +0 -0
- {livekit_plugins_google-0.9.1.dist-info → livekit_plugins_google-0.10.1.dist-info}/top_level.txt +0 -0
livekit/plugins/google/__init__.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 from . import beta
+from .llm import LLM
 from .stt import STT, SpeechStream
 from .tts import TTS
 from .version import __version__
 
-__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta"]
+__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta", "LLM"]
 from livekit.agents import Plugin
 
 from .log import logger
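
With this change the package root re-exports the new Gemini LLM alongside the existing STT/TTS classes. A minimal import sketch (the `LLM` constructor arguments live in the new `llm.py`, which is not shown in this hunk):

from livekit.plugins.google import LLM, STT, TTS

# Assumes GOOGLE_API_KEY (or Vertex credentials) is configured in the environment;
# the available keyword arguments are defined in the new llm.py.
gemini_llm = LLM()
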
livekit/plugins/google/_utils.py (new file)
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import base64
+import inspect
+import json
+from typing import Any, Dict, List, Optional, get_args, get_origin
+
+from livekit import rtc
+from livekit.agents import llm, utils
+from livekit.agents.llm.function_context import _is_optional_type
+
+from google.genai import types
+
+JSON_SCHEMA_TYPE_MAP: dict[type, types.Type] = {
+    str: "STRING",
+    int: "INTEGER",
+    float: "NUMBER",
+    bool: "BOOLEAN",
+    dict: "OBJECT",
+    list: "ARRAY",
+}
+
+__all__ = ["_build_gemini_ctx", "_build_tools"]
+
+
+def _build_parameters(arguments: Dict[str, Any]) -> types.Schema | None:
+    properties: Dict[str, types.Schema] = {}
+    required: List[str] = []
+
+    for arg_name, arg_info in arguments.items():
+        prop = types.Schema()
+        if arg_info.description:
+            prop.description = arg_info.description
+
+        _, py_type = _is_optional_type(arg_info.type)
+        origin = get_origin(py_type)
+        if origin is list:
+            item_type = get_args(py_type)[0]
+            if item_type not in JSON_SCHEMA_TYPE_MAP:
+                raise ValueError(f"Unsupported type: {item_type}")
+            prop.type = "ARRAY"
+            prop.items = types.Schema(type=JSON_SCHEMA_TYPE_MAP[item_type])
+
+            if arg_info.choices:
+                prop.items.enum = arg_info.choices
+        else:
+            if py_type not in JSON_SCHEMA_TYPE_MAP:
+                raise ValueError(f"Unsupported type: {py_type}")
+
+            prop.type = JSON_SCHEMA_TYPE_MAP[py_type]
+
+            if arg_info.choices:
+                prop.enum = arg_info.choices
+                if py_type is int:
+                    raise ValueError(
+                        f"Parameter '{arg_info.name}' uses integer choices, not supported by this model."
+                    )
+
+        properties[arg_name] = prop
+
+        if arg_info.default is inspect.Parameter.empty:
+            required.append(arg_name)
+
+    if properties:
+        parameters = types.Schema(type="OBJECT", properties=properties)
+        if required:
+            parameters.required = required
+
+        return parameters
+
+    return None
+
+
+def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclaration]:
+    function_declarations: List[types.FunctionDeclaration] = []
+    for fnc_info in fnc_ctx.ai_functions.values():
+        parameters = _build_parameters(fnc_info.arguments)
+
+        func_decl = types.FunctionDeclaration(
+            name=fnc_info.name,
+            description=fnc_info.description,
+            parameters=parameters,
+        )
+
+        function_declarations.append(func_decl)
+    return function_declarations
+
+
+def _build_gemini_ctx(
+    chat_ctx: llm.ChatContext, cache_key: Any
+) -> tuple[list[types.Content], Optional[types.Content]]:
+    turns: list[types.Content] = []
+    system_instruction: Optional[types.Content] = None
+    current_role: Optional[str] = None
+    parts: list[types.Part] = []
+
+    for msg in chat_ctx.messages:
+        if msg.role == "system":
+            if isinstance(msg.content, str):
+                system_instruction = types.Content(parts=[types.Part(text=msg.content)])
+            continue
+
+        if msg.role == "assistant":
+            role = "model"
+        elif msg.role == "tool":
+            role = "user"
+        else:
+            role = "user"
+
+        # If role changed, finalize previous parts into a turn
+        if role != current_role:
+            if current_role is not None and parts:
+                turns.append(types.Content(role=current_role, parts=parts))
+            current_role = role
+            parts = []
+
+        if msg.tool_calls:
+            for fnc in msg.tool_calls:
+                parts.append(
+                    types.Part(
+                        function_call=types.FunctionCall(
+                            id=fnc.tool_call_id,
+                            name=fnc.function_info.name,
+                            args=fnc.arguments,
+                        )
+                    )
+                )
+
+        if msg.role == "tool":
+            if msg.content:
+                if isinstance(msg.content, dict):
+                    parts.append(
+                        types.Part(
+                            function_response=types.FunctionResponse(
+                                id=msg.tool_call_id,
+                                name=msg.name,
+                                response=msg.content,
+                            )
+                        )
+                    )
+                elif isinstance(msg.content, str):
+                    parts.append(
+                        types.Part(
+                            function_response=types.FunctionResponse(
+                                id=msg.tool_call_id,
+                                name=msg.name,
+                                response={"result": msg.content},
+                            )
+                        )
+                    )
+        else:
+            if msg.content:
+                if isinstance(msg.content, str):
+                    parts.append(types.Part(text=msg.content))
+                elif isinstance(msg.content, dict):
+                    parts.append(types.Part(text=json.dumps(msg.content)))
+                elif isinstance(msg.content, list):
+                    for item in msg.content:
+                        if isinstance(item, str):
+                            parts.append(types.Part(text=item))
+                        elif isinstance(item, llm.ChatImage):
+                            parts.append(_build_gemini_image_part(item, cache_key))
+
+    # Finalize last role's parts if any remain
+    if current_role is not None and parts:
+        turns.append(types.Content(role=current_role, parts=parts))
+
+    return turns, system_instruction
+
+
+def _build_gemini_image_part(image: llm.ChatImage, cache_key: Any) -> types.Part:
+    if isinstance(image.image, str):
+        # Check if the string is a Data URL
+        if image.image.startswith("data:image/jpeg;base64,"):
+            # Extract the base64 part after the comma
+            base64_data = image.image.split(",", 1)[1]
+            try:
+                image_bytes = base64.b64decode(base64_data)
+            except Exception as e:
+                raise ValueError("Invalid base64 data in image URL") from e
+
+            return types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
+        else:
+            # Assume it's a regular URL
+            return types.Part.from_uri(file_uri=image.image, mime_type="image/jpeg")
+
+    elif isinstance(image.image, rtc.VideoFrame):
+        if cache_key not in image._cache:
+            opts = utils.images.EncodeOptions()
+            if image.inference_width and image.inference_height:
+                opts.resize_options = utils.images.ResizeOptions(
+                    width=image.inference_width,
+                    height=image.inference_height,
+                    strategy="scale_aspect_fit",
+                )
+            encoded_data = utils.images.encode(image.image, opts)
+            image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
+
+        return types.Part.from_bytes(
+            data=image._cache[cache_key], mime_type="image/jpeg"
+        )
+    raise ValueError(f"Unsupported image type: {type(image.image)}")
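
The helpers above are shared by the LLM and realtime paths: `_build_tools` turns an object exposing `ai_functions` (livekit's function context) into `google.genai` function declarations, and `_build_gemini_ctx` splits a `ChatContext` into Gemini `Content` turns plus an optional system instruction. A hedged usage sketch (the `ChatContext.append` call assumes the livekit-agents 0.x API and is not part of this diff):

from livekit.agents import llm
from livekit.plugins.google._utils import _build_gemini_ctx, _build_tools

chat_ctx = llm.ChatContext()
chat_ctx.append(role="system", text="You are a helpful assistant.")
chat_ctx.append(role="user", text="What's the weather in Tokyo?")

# The system message is returned separately from the user/model turns.
turns, system_instruction = _build_gemini_ctx(chat_ctx, cache_key="example")

# With a function context (fnc_ctx) holding ai-callable functions:
# declarations = _build_tools(fnc_ctx)  # -> list[types.FunctionDeclaration]
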
livekit/plugins/google/beta/realtime/__init__.py
@@ -1,7 +1,6 @@
 from .api_proto import (
     ClientEvents,
     LiveAPIModels,
-    ResponseModality,
     Voice,
 )
 from .realtime_api import RealtimeModel
@@ -10,6 +9,5 @@ __all__ = [
     "RealtimeModel",
     "ClientEvents",
     "LiveAPIModels",
-    "ResponseModality",
     "Voice",
 ]
livekit/plugins/google/beta/realtime/api_proto.py
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
-import inspect
-from typing import Any, Dict, List, Literal, Sequence, Union
+from typing import Literal, Sequence, Union
 
-from google.genai import types
+from google.genai import types
+
+from ..._utils import _build_gemini_ctx, _build_tools
 
 LiveAPIModels = Literal["gemini-2.0-flash-exp"]
 
 Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
-ResponseModality = Literal["AUDIO", "TEXT"]
 
+__all__ = ["_build_tools", "ClientEvents", "_build_gemini_ctx"]
 
 ClientEvents = Union[
     types.ContentListUnion,
@@ -21,59 +22,3 @@ ClientEvents = Union[
     types.FunctionResponseOrDict,
     Sequence[types.FunctionResponseOrDict],
 ]
-
-
-JSON_SCHEMA_TYPE_MAP = {
-    str: "string",
-    int: "integer",
-    float: "number",
-    bool: "boolean",
-    dict: "object",
-    list: "array",
-}
-
-
-def _build_parameters(arguments: Dict[str, Any]) -> types.SchemaDict:
-    properties: Dict[str, types.SchemaDict] = {}
-    required: List[str] = []
-
-    for arg_name, arg_info in arguments.items():
-        py_type = arg_info.type
-        if py_type not in JSON_SCHEMA_TYPE_MAP:
-            raise ValueError(f"Unsupported type: {py_type}")
-
-        prop: types.SchemaDict = {
-            "type": JSON_SCHEMA_TYPE_MAP[py_type],
-            "description": arg_info.description,
-        }
-
-        if arg_info.choices:
-            prop["enum"] = arg_info.choices
-
-        properties[arg_name] = prop
-
-        if arg_info.default is inspect.Parameter.empty:
-            required.append(arg_name)
-
-    parameters: types.SchemaDict = {"type": "object", "properties": properties}
-
-    if required:
-        parameters["required"] = required
-
-    return parameters
-
-
-def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]:
-    function_declarations: List[types.FunctionDeclarationDict] = []
-    for fnc_info in fnc_ctx.ai_functions.values():
-        parameters = _build_parameters(fnc_info.arguments)
-
-        func_decl: types.FunctionDeclarationDict = {
-            "name": fnc_info.name,
-            "description": fnc_info.description,
-            "parameters": parameters,
-        }
-
-        function_declarations.append(func_decl)
-
-    return function_declarations
livekit/plugins/google/beta/realtime/realtime_api.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import asyncio
-import base64
 import json
 import os
 from dataclasses import dataclass
@@ -11,14 +10,22 @@ from livekit import rtc
 from livekit.agents import llm, utils
 from livekit.agents.llm.function_context import _create_ai_function_info
 
-from google import genai
-from google.genai.types import (
+from google import genai
+from google.genai._api_client import HttpOptions
+from google.genai.types import (
+    Blob,
+    Content,
     FunctionResponse,
-
+    GenerationConfig,
+    LiveClientContent,
+    LiveClientRealtimeInput,
     LiveClientToolResponse,
-
+    LiveConnectConfig,
+    Modality,
+    Part,
     PrebuiltVoiceConfig,
     SpeechConfig,
+    Tool,
     VoiceConfig,
 )
 
@@ -26,10 +33,11 @@ from ...log import logger
 from .api_proto import (
     ClientEvents,
     LiveAPIModels,
-    ResponseModality,
     Voice,
+    _build_gemini_ctx,
     _build_tools,
 )
+from .transcriber import TranscriberSession, TranscriptionContent
 
 EventTypes = Literal[
     "start_session",
@@ -39,6 +47,9 @@ EventTypes = Literal[
     "function_calls_collected",
     "function_calls_finished",
     "function_calls_cancelled",
+    "input_speech_transcription_completed",
+    "agent_speech_transcription_completed",
+    "agent_speech_stopped",
 ]
 
 
@@ -55,6 +66,12 @@ class GeminiContent:
     content_type: Literal["text", "audio"]
 
 
+@dataclass
+class InputTranscription:
+    item_id: str
+    transcript: str
+
+
 @dataclass
 class Capabilities:
     supports_truncate: bool
@@ -65,7 +82,7 @@ class ModelOptions:
     model: LiveAPIModels | str
     api_key: str | None
     voice: Voice | str
-    response_modalities:
+    response_modalities: list[Modality] | None
     vertexai: bool
     project: str | None
     location: str | None
@@ -76,18 +93,22 @@ class ModelOptions:
     top_k: int | None
     presence_penalty: float | None
     frequency_penalty: float | None
-    instructions:
+    instructions: Content | None
+    enable_user_audio_transcription: bool
+    enable_agent_audio_transcription: bool
 
 
 class RealtimeModel:
     def __init__(
         self,
         *,
-        instructions: str =
+        instructions: str | None = None,
         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
         api_key: str | None = None,
         voice: Voice | str = "Puck",
-        modalities:
+        modalities: list[Modality] = ["AUDIO"],
+        enable_user_audio_transcription: bool = True,
+        enable_agent_audio_transcription: bool = True,
         vertexai: bool = False,
         project: str | None = None,
         location: str | None = None,
@@ -103,15 +124,24 @@ class RealtimeModel:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
 
+        Environment Requirements:
+        - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file.
+          The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
+          `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
+          and the location defaults to "us-central1".
+        - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
+
         Args:
             instructions (str, optional): Initial system instructions for the model. Defaults to "".
             api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
-            modalities (
+            modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
             model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
+            enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
+            enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
-            project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai)
+            project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
             location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
             top_p (float, optional): The top-p value for response generation
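
Based on the docstring above, construction might look like the following sketch (the project id and instruction text are illustrative, not from the diff):

from livekit.plugins.google.beta.realtime import RealtimeModel

# Gemini API path: api_key falls back to the GOOGLE_API_KEY environment variable.
model = RealtimeModel(
    instructions="You are a concise voice assistant.",
    voice="Puck",
    modalities=["AUDIO"],
    enable_user_audio_transcription=True,
    enable_agent_audio_transcription=True,
    temperature=0.8,
)

# VertexAI path: needs GOOGLE_APPLICATION_CREDENTIALS plus project/location
# (or GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION).
vertex_model = RealtimeModel(vertexai=True, project="my-gcp-project", location="us-central1")
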
@@ -130,21 +160,38 @@ class RealtimeModel:
         self._model = model
         self._loop = loop or asyncio.get_event_loop()
         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
-        self.
-        self.
-
-
-
+        self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
+        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
+        if vertexai:
+            if not self._project or not self._location:
+                raise ValueError(
+                    "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"
+                )
+            self._api_key = None  # VertexAI does not require an API key
+
+        else:
+            self._project = None
+            self._location = None
+            if not self._api_key:
+                raise ValueError(
+                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
+                )
+
+        instructions_content = (
+            Content(parts=[Part(text=instructions)]) if instructions else None
+        )
 
         self._rt_sessions: list[GeminiRealtimeSession] = []
         self._opts = ModelOptions(
             model=model,
-            api_key=
+            api_key=self._api_key,
             voice=voice,
+            enable_user_audio_transcription=enable_user_audio_transcription,
+            enable_agent_audio_transcription=enable_agent_audio_transcription,
             response_modalities=modalities,
             vertexai=vertexai,
-            project=
-            location=
+            project=self._project,
+            location=self._location,
             candidate_count=candidate_count,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
@@ -152,7 +199,7 @@ class RealtimeModel:
             top_k=top_k,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
-            instructions=
+            instructions=instructions_content,
         )
 
     @property
@@ -208,16 +255,16 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         self._chat_ctx = chat_ctx
         self._fnc_ctx = fnc_ctx
         self._fnc_tasks = utils.aio.TaskSet()
+        self._is_interrupted = False
 
         tools = []
         if self._fnc_ctx is not None:
             functions = _build_tools(self._fnc_ctx)
-            tools.append(
+            tools.append(Tool(function_declarations=functions))
 
-        self._config =
-            model=self._opts.model,
+        self._config = LiveConnectConfig(
             response_modalities=self._opts.response_modalities,
-            generation_config=
+            generation_config=GenerationConfig(
                 candidate_count=self._opts.candidate_count,
                 temperature=self._opts.temperature,
                 max_output_tokens=self._opts.max_output_tokens,
@@ -237,7 +284,7 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
             tools=tools,
         )
         self._client = genai.Client(
-            http_options=
+            http_options=HttpOptions(api_version="v1alpha"),
             api_key=self._opts.api_key,
             vertexai=self._opts.vertexai,
             project=self._opts.project,
@@ -246,10 +293,18 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         self._main_atask = asyncio.create_task(
             self._main_task(), name="gemini-realtime-session"
         )
-
-
-
-
+        if self._opts.enable_user_audio_transcription:
+            self._transcriber = TranscriberSession(
+                client=self._client, model=self._opts.model
+            )
+            self._transcriber.on("input_speech_done", self._on_input_speech_done)
+        if self._opts.enable_agent_audio_transcription:
+            self._agent_transcriber = TranscriberSession(
+                client=self._client, model=self._opts.model
+            )
+            self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
+        # init dummy task
+        self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
         self._send_ch = utils.aio.Chan[ClientEvents]()
         self._active_response_id = None
 
@@ -269,10 +324,14 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         self._fnc_ctx = value
 
     def _push_audio(self, frame: rtc.AudioFrame) -> None:
-
-
+        if self._opts.enable_user_audio_transcription:
+            self._transcriber._push_audio(frame)
+        realtime_input = LiveClientRealtimeInput(
+            media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
+        )
+        self._queue_msg(realtime_input)
 
-    def _queue_msg(self, msg:
+    def _queue_msg(self, msg: ClientEvents) -> None:
         self._send_ch.send_nowait(msg)
 
     def chat_ctx_copy(self) -> llm.ChatContext:
@@ -281,20 +340,71 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
     async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
         self._chat_ctx = ctx.copy()
 
+    def cancel_response(self) -> None:
+        raise NotImplementedError("cancel_response is not supported yet")
+
+    def create_response(
+        self,
+        on_duplicate: Literal[
+            "cancel_existing", "cancel_new", "keep_both"
+        ] = "keep_both",
+    ) -> None:
+        turns, _ = _build_gemini_ctx(self._chat_ctx, id(self))
+        ctx = [self._opts.instructions] + turns if self._opts.instructions else turns
+
+        if not ctx:
+            logger.warning(
+                "gemini-realtime-session: No chat context to send, sending dummy content."
+            )
+            ctx = [Content(parts=[Part(text=".")])]
+
+        self._queue_msg(LiveClientContent(turns=ctx, turn_complete=True))
+
+    def commit_audio_buffer(self) -> None:
+        raise NotImplementedError("commit_audio_buffer is not supported yet")
+
+    def server_vad_enabled(self) -> bool:
+        return True
+
+    def _on_input_speech_done(self, content: TranscriptionContent) -> None:
+        if content.response_id and content.text:
+            self.emit(
+                "input_speech_transcription_completed",
+                InputTranscription(
+                    item_id=content.response_id,
+                    transcript=content.text,
+                ),
+            )
+
+        # self._chat_ctx.append(text=content.text, role="user")
+        # TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech
+
+    def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
+        if not self._is_interrupted and content.response_id and content.text:
+            self.emit(
+                "agent_speech_transcription_completed",
+                InputTranscription(
+                    item_id=content.response_id,
+                    transcript=content.text,
+                ),
+            )
+        # self._chat_ctx.append(text=content.text, role="assistant")
+
     @utils.log_exceptions(logger=logger)
     async def _main_task(self):
         @utils.log_exceptions(logger=logger)
         async def _send_task():
             async for msg in self._send_ch:
-                await self._session.send(msg)
+                await self._session.send(input=msg)
 
-            await self._session.send(".", end_of_turn=True)
+            await self._session.send(input=".", end_of_turn=True)
 
         @utils.log_exceptions(logger=logger)
         async def _recv_task():
             while True:
                 async for response in self._session.receive():
                     if self._active_response_id is None:
+                        self._is_interrupted = False
                         self._active_response_id = utils.shortuuid()
                         text_stream = utils.aio.Chan[str]()
                         audio_stream = utils.aio.Chan[rtc.AudioFrame]()
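
The new `_on_input_speech_done` / `_on_agent_speech_done` handlers surface transcripts through the events added to `EventTypes` earlier in this diff. A hedged listener sketch (`session` is assumed to be a `GeminiRealtimeSession` obtained from the model; the payload is the `InputTranscription` dataclass defined above):

def on_user_transcript(ev) -> None:
    # ev.item_id / ev.transcript per InputTranscription
    print("user:", ev.transcript)

def on_agent_transcript(ev) -> None:
    print("agent:", ev.transcript)

session.on("input_speech_transcription_completed", on_user_transcript)
session.on("agent_speech_transcription_completed", on_agent_transcript)
session.on("agent_speech_stopped", lambda: print("agent speech interrupted or complete"))
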
@@ -307,7 +417,7 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                             audio=[],
                             text_stream=text_stream,
                             audio_stream=audio_stream,
-                            content_type=
+                            content_type="audio",
                         )
                         self.emit("response_content_added", content)
 
@@ -326,6 +436,8 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                                         samples_per_channel=len(part.inline_data.data)
                                         // 2,
                                     )
+                                    if self._opts.enable_agent_audio_transcription:
+                                        self._agent_transcriber._push_audio(frame)
                                     content.audio_stream.send_nowait(frame)
 
                     if server_content.interrupted or server_content.turn_complete:
@@ -333,10 +445,8 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                             if isinstance(stream, utils.aio.Chan):
                                 stream.close()
 
-
-
-                        elif server_content.turn_complete:
-                            self.emit("response_content_done", content)
+                        self.emit("agent_speech_stopped")
+                        self._is_interrupted = True
 
                         self._active_response_id = None
 
@@ -387,6 +497,10 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
         finally:
             await utils.aio.gracefully_cancel(*tasks)
             await self._session.close()
+            if self._opts.enable_user_audio_transcription:
+                await self._transcriber.aclose()
+            if self._opts.enable_agent_audio_transcription:
+                await self._agent_transcriber.aclose()
 
     @utils.log_exceptions(logger=logger)
     async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
@@ -419,6 +533,6 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
                 )
             ]
         )
-        await self._session.send(tool_response)
+        await self._session.send(input=tool_response)
 
         self.emit("function_calls_finished", [called_fnc])