livekit-plugins-volcenginee 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/volcengine/__init__.py +28 -0
- livekit/plugins/volcengine/llm.py +283 -0
- livekit/plugins/volcengine/log.py +3 -0
- livekit/plugins/volcengine/py.typed +0 -0
- livekit/plugins/volcengine/realtime.py +906 -0
- livekit/plugins/volcengine/stt.py +495 -0
- livekit/plugins/volcengine/tts.py +298 -0
- livekit/plugins/volcengine/utils.py +154 -0
- livekit/plugins/volcengine/version.py +1 -0
- livekit_plugins_volcenginee-1.3.0.dist-info/METADATA +615 -0
- livekit_plugins_volcenginee-1.3.0.dist-info/RECORD +12 -0
- livekit_plugins_volcenginee-1.3.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,906 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
5
|
+
import copy
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
import weakref
|
|
10
|
+
import gzip
|
|
11
|
+
import uuid
|
|
12
|
+
from collections.abc import Iterator
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Literal, Callable
|
|
15
|
+
|
|
16
|
+
import aiohttp
|
|
17
|
+
import numpy as np
|
|
18
|
+
from livekit import rtc
|
|
19
|
+
from livekit.agents import llm, utils
|
|
20
|
+
from livekit.agents.types import (
|
|
21
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
|
22
|
+
NOT_GIVEN,
|
|
23
|
+
APIConnectOptions,
|
|
24
|
+
NotGivenOr,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from .log import logger
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ---------------------------------------------------------------------------
# Binary protocol of the realtime dialogue websocket.
# Every frame starts with a 4-byte header built from 4-bit fields
# (see generate_header), optionally followed by extension words.
# ---------------------------------------------------------------------------
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Width, in bits, of each fixed header field.
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message type (high nibble of header byte 1).
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010

SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags (low nibble of header byte 1).
NO_SEQUENCE = 0b0000  # no sequence number follows the header
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011

MSG_WITH_EVENT = 0b0100

# Payload serialization (high nibble of header byte 2).
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Payload compression (low nibble of header byte 2).
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111


def generate_header(
    version: int = PROTOCOL_VERSION,
    message_type: int = CLIENT_FULL_REQUEST,
    message_type_specific_flags: int = MSG_WITH_EVENT,
    serial_method: int = JSON,
    compression_type: int = GZIP,
    reserved_data: int = 0x00,
    extension_header: bytes = bytes(),
) -> bytearray:
    """Build the binary header that prefixes every client request.

    Layout:

    * byte 0: protocol_version (4 bits) | header_size (4 bits, in 4-byte words)
    * byte 1: message_type (4 bits) | message_type_specific_flags (4 bits)
    * byte 2: serialization_method (4 bits) | message_compression (4 bits)
    * byte 3: reserved (8 bits)
    * then ``extension_header``, which is ``4 * (header_size - 1)`` bytes long

    The defaults describe a gzip-compressed JSON full client request that
    carries an event id.

    Returns:
        A new ``bytearray`` the caller may extend with the rest of the frame.

    Raises:
        ValueError: If a 4-bit field is outside ``0..15``, ``reserved_data`` is
            outside ``0..255``, or ``extension_header`` is not a whole number of
            4-byte words (or is too long for the 4-bit ``header_size`` field).
    """
    # Out-of-range nibbles would otherwise bleed into the neighbouring field
    # and silently corrupt the header.
    for field_name, value in (
        ("version", version),
        ("message_type", message_type),
        ("message_type_specific_flags", message_type_specific_flags),
        ("serial_method", serial_method),
        ("compression_type", compression_type),
    ):
        if not 0 <= value <= 0xF:
            raise ValueError(f"{field_name} must fit in 4 bits, got {value!r}")
    if not 0 <= reserved_data <= 0xFF:
        raise ValueError(f"reserved_data must fit in 8 bits, got {reserved_data!r}")
    if len(extension_header) % 4:
        raise ValueError("extension_header length must be a multiple of 4 bytes")

    # header_size counts 4-byte words, including the fixed first word.
    header_size = len(extension_header) // 4 + 1
    if header_size > 0xF:
        raise ValueError("extension_header is too long for the 4-bit header_size")

    header = bytearray()
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header


def parse_response(res):
    """Decode one binary frame received from the dialogue websocket.

    Frame layout after the header (``header_size`` 4-byte words):

    * full server responses / acks:
        * [4 bytes] sequence number, present when flag bit ``0b0001`` is set
          (``0001`` = has sequence, ``0010`` = tail packet)
        * [4 bytes] event id, present when flag bit ``0b0100`` is set
        * [4 bytes] session id length, then the session id
        * [4 bytes] payload length, then the payload
    * error responses: [4 bytes] error code, [4 bytes] payload length, payload

    The payload is gunzipped when the compression nibble is GZIP. It is then
    decoded as JSON, or as UTF-8 text for other non-raw serializations. Raw
    payloads (``NO_SERIALIZATION``) stay ``bytes``.

    Args:
        res: The websocket message data. Only ``bytes``-like input is parsed.

    Returns:
        A dict with any of ``message_type``, ``seq``, ``event``,
        ``session_id``, ``code``, ``payload_msg`` and ``payload_size``.
        Non-binary input (text frames, close codes, ``None``), frames shorter
        than the fixed 4-byte header, and unknown message types return only
        what could be parsed, possibly ``{}``.
    """
    if not isinstance(res, (bytes, bytearray, memoryview)):
        return {}
    res = bytes(res)
    if len(res) < 4:
        return {}

    header_size = res[0] & 0x0F
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0F
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0F
    payload = res[header_size * 4 :]

    result = {}
    payload_msg = None
    payload_size = 0

    if message_type in (SERVER_FULL_RESPONSE, SERVER_ACK):
        result["message_type"] = (
            "SERVER_ACK" if message_type == SERVER_ACK else "SERVER_FULL_RESPONSE"
        )
        offset = 0
        if message_type_specific_flags & POS_SEQUENCE:
            result["seq"] = int.from_bytes(
                payload[offset : offset + 4], "big", signed=False
            )
            offset += 4
        if message_type_specific_flags & MSG_WITH_EVENT:
            # Read at the running offset: a preceding sequence number has
            # already consumed the first four bytes.
            result["event"] = int.from_bytes(
                payload[offset : offset + 4], "big", signed=False
            )
            offset += 4
        payload = payload[offset:]
        session_id_size = max(int.from_bytes(payload[:4], "big", signed=True), 0)
        result["session_id"] = payload[4 : 4 + session_id_size].decode(
            "utf-8", errors="replace"
        )
        payload = payload[4 + session_id_size :]
        payload_size = int.from_bytes(payload[:4], "big", signed=False)
        payload_msg = payload[4:]
    elif message_type == SERVER_ERROR_RESPONSE:
        result["code"] = int.from_bytes(payload[:4], "big", signed=False)
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]

    if payload_msg is None:
        return result

    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        text = payload_msg.decode("utf-8")
        # An empty JSON payload is treated as an empty object rather than an error.
        payload_msg = json.loads(text) if text else {}
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = payload_msg.decode("utf-8")

    result["payload_msg"] = payload_msg
    result["payload_size"] = payload_size
    return result
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@dataclass
class _RealtimeOptions:
    """Connection and session settings for the Volcengine realtime dialogue API."""

    # Credentials, sent as websocket headers.
    app_id: str
    access_token: str
    # Dialog persona sent in the StartSession request.
    bot_name: str
    system_role: str
    max_session_duration: float | None
    conn_options: APIConnectOptions
    modalities: list[Literal["text", "audio"]]
    opening: str = "你好啊,今天过得怎么样?"
    speaking_style: str = "你的说话风格简洁明了,语速适中,语调自然。"
    speaker: str = "zh_female_vv_jupiter_bigtts"
    # Output audio format requested from the server.
    sample_rate: int = 24000
    num_channels: int = 1
    format: str = "pcm"
    model: Literal["O", "SC"] = "O"
    character_manifest: str | None = None
    end_smooth_window_ms: int = 500
    # Built-in web search options, forwarded in dialog.extra.
    enable_volc_websearch: bool = False
    volc_websearch_type: Literal["web_summary", "web"] = "web_summary"
    volc_websearch_api_key: str | None = None
    volc_websearch_no_result_message: str = "抱歉,我找不到相关信息。"

    @property
    def ws_url(self) -> str:
        """Websocket endpoint of the realtime dialogue service."""
        return "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"

    def get_ws_headers(self) -> dict:
        """Return the headers for opening a new websocket connection.

        A fresh ``X-Api-Connect-Id`` UUID is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_token,
            # The resource id and app key are fixed values.
            "X-Api-Resource-Id": "volc.speech.dialog",
            "X-Api-App-Key": "PlgvMymc7f3tQnJ6",
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }

    def get_start_session_reqs(self, dialog_id: str | None) -> dict:
        """Build the JSON body of the StartSession request.

        ``dialog_id`` falls back to a fresh short UUID when it is empty or ``None``.
        """
        asr_section = {"extra": {"end_smooth_window_ms": self.end_smooth_window_ms}}
        tts_section = {
            "audio_config": {
                "channel": self.num_channels,
                "format": self.format,
                "sample_rate": self.sample_rate,
            },
            "speaker": self.speaker,
        }
        dialog_extra = {
            "strict_audit": False,
            "enable_volc_websearch": self.enable_volc_websearch,
            "volc_websearch_type": self.volc_websearch_type,
            "volc_websearch_api_key": self.volc_websearch_api_key,
            "volc_websearch_no_result_message": self.volc_websearch_no_result_message,
            "model": self.model,
        }
        dialog_section = {
            "bot_name": self.bot_name,
            "system_role": self.system_role,
            "dialog_id": dialog_id if dialog_id else str(utils.shortuuid()),
            "speaking_style": self.speaking_style,
            "character_manifest": self.character_manifest,
            "extra": dialog_extra,
        }
        return {"asr": asr_section, "tts": tts_section, "dialog": dialog_section}
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@dataclass
class _MessageGeneration:
    """Output streams of a single assistant message."""

    message_id: str
    # Text chunks of the reply.
    text_ch: utils.aio.Chan[str]
    # Synthesized audio frames of the reply.
    audio_ch: utils.aio.Chan[rtc.AudioFrame]
    audio_transcript: str = ""
    # Resolved with the message's output modalities.
    modalities: asyncio.Future[list[Literal["text", "audio"]]] | None = None
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@dataclass
class _ResponseGeneration:
    """State of one assistant generation (a single reply turn)."""

    message_ch: utils.aio.Chan[llm.MessageGeneration]
    function_ch: utils.aio.Chan[llm.FunctionCall]

    messages: dict[str, _MessageGeneration]

    _done_fut: asyncio.Future[None]
    # Timestamp when the response was created.
    _created_timestamp: float
    # Timestamp when the first token was received.
    _first_token_timestamp: float | None = None
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class RealtimeModel(llm.RealtimeModel):
    """LiveKit ``RealtimeModel`` backed by the Volcengine realtime dialogue API.

    Credentials come from the ``app_id``/``access_token`` arguments or the
    ``VOLCENGINE_REALTIME_APP_ID`` / ``VOLCENGINE_REALTIME_ACCESS_TOKEN``
    environment variables. Each call to :meth:`session` creates a new
    :class:`RealtimeSession` that shares this model's options.
    """

    def __init__(
        self,
        bot_name: str = "豆包",
        speaking_style: str = "你的说话风格简洁明了,语速适中,语调自然。",
        speaker: str = "zh_female_vv_jupiter_bigtts",
        opening: str | None = None,
        app_id: str | None = None,
        access_token: str | None = None,
        system_role: str | None = None,
        character_manifest: str | None = None,
        model: Literal["O", "SC"] = "O",
        end_smooth_window_ms: int = 500,
        enable_volc_websearch: bool = False,
        volc_websearch_type: Literal["web_summary", "web"] = "web_summary",
        volc_websearch_api_key: str | None = None,
        volc_websearch_no_result_message: str = "抱歉,我找不到相关信息。",
        rag_fn: Callable[[str], str] | None = None,
        audio_output: bool = True,
        modalities: NotGivenOr[list[Literal["text", "audio"]]] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        max_session_duration: NotGivenOr[float | None] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> None:
        """Configure the realtime model.

        Args:
            bot_name: Assistant name sent in the StartSession ``dialog`` section.
            speaking_style: Speaking-style prompt sent with the session.
            speaker: TTS voice id.
            opening: Greeting spoken when a session starts; ``None`` disables it.
            app_id: Defaults to ``$VOLCENGINE_REALTIME_APP_ID``.
            access_token: Defaults to ``$VOLCENGINE_REALTIME_ACCESS_TOKEN``.
            system_role: System prompt for the dialog.
            character_manifest: Forwarded to the StartSession request.
            model: Dialogue model variant, forwarded to the StartSession request.
            end_smooth_window_ms: Forwarded as an ASR option.
            enable_volc_websearch: Forwarded to the StartSession request.
            volc_websearch_type: Forwarded to the StartSession request.
            volc_websearch_api_key: Forwarded to the StartSession request. Never
                logged.
            volc_websearch_no_result_message: Forwarded to the StartSession
                request.
            rag_fn: Optional synchronous callback. Sessions call it with the
                user's latest transcript and send its return value to the
                server as external RAG content.
            audio_output: Currently unused; audio output follows ``modalities``.
            modalities: Output modalities; defaults to ``["text", "audio"]``.
            http_session: aiohttp session to reuse. When omitted, one is
                obtained from the LiveKit HTTP context on first use.
            max_session_duration: Stored in the options; ``None`` when not given.
            conn_options: Connection options; ``timeout`` bounds the websocket
                connect.

        Raises:
            ValueError: If the app id or access token is missing or empty.
        """
        modalities = modalities if utils.is_given(modalities) else ["text", "audio"]
        super().__init__(
            capabilities=llm.RealtimeCapabilities(
                message_truncation=True,
                turn_detection=True,
                user_transcription=True,
                auto_tool_reply_generation=False,
                audio_output=("audio" in modalities),
                manual_function_calls=True,
            )
        )
        logger.info("Model: %s", model)
        logger.info("Character Manifest: %s", character_manifest)
        logger.info("End Smooth Window MS: %s", end_smooth_window_ms)
        logger.info("Enable Volc Websearch: %s", enable_volc_websearch)
        logger.info("Volc Websearch Type: %s", volc_websearch_type)
        # Never write the secret itself to the logs, only whether one was provided.
        logger.info(
            "Volc Websearch API Key: %s",
            "<set>" if volc_websearch_api_key else "<not set>",
        )
        logger.info(
            "Volc Websearch No Result Message: %s", volc_websearch_no_result_message
        )

        app_id = app_id or os.environ.get("VOLCENGINE_REALTIME_APP_ID")
        if not app_id:
            raise ValueError("VOLCENGINE_REALTIME_APP_ID is required")
        access_token = access_token or os.environ.get(
            "VOLCENGINE_REALTIME_ACCESS_TOKEN"
        )
        if not access_token:
            raise ValueError("VOLCENGINE_REALTIME_ACCESS_TOKEN is required")

        self._opts = _RealtimeOptions(
            app_id=app_id,
            access_token=access_token,
            bot_name=bot_name,
            system_role=system_role,
            speaker=speaker,
            opening=opening,
            speaking_style=speaking_style,
            character_manifest=character_manifest,
            model=model,
            end_smooth_window_ms=end_smooth_window_ms,
            enable_volc_websearch=enable_volc_websearch,
            volc_websearch_type=volc_websearch_type,
            volc_websearch_api_key=volc_websearch_api_key,
            volc_websearch_no_result_message=volc_websearch_no_result_message,
            modalities=modalities,
            # Keep the NOT_GIVEN sentinel out of the options object.
            max_session_duration=(
                max_session_duration if utils.is_given(max_session_duration) else None
            ),
            conn_options=conn_options,
        )
        self._rag_fn = rag_fn
        self._http_session = http_session
        self._sessions = weakref.WeakSet[RealtimeSession]()

    def update_options(
        self,
        *,
        max_session_duration: NotGivenOr[float | None] = NOT_GIVEN,
    ) -> None:
        """Update model-level options.

        Only ``max_session_duration`` is accepted. It is recorded on the shared
        options object, which sessions created by :meth:`session` also
        reference.
        """
        if utils.is_given(max_session_duration):
            self._opts.max_session_duration = max_session_duration

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the aiohttp session, borrowing one from the LiveKit HTTP context if unset."""
        if not self._http_session:
            self._http_session = utils.http_context.http_session()

        return self._http_session

    def session(self) -> RealtimeSession:
        """Create a new realtime session bound to this model and track it weakly."""
        sess = RealtimeSession(self)
        self._sessions.add(sess)
        return sess

    async def aclose(self) -> None: ...
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
class RealtimeSession(
|
|
355
|
+
llm.RealtimeSession[
|
|
356
|
+
Literal["volcengine_server_event_received", "volcengine_client_event_queued"]
|
|
357
|
+
]
|
|
358
|
+
):
|
|
359
|
+
"""
|
|
360
|
+
A session for the volcengine Realtime API.
|
|
361
|
+
|
|
362
|
+
This class is used to interact with the volcengine Realtime API.
|
|
363
|
+
It is responsible for sending events to the volcengine Realtime API and receiving events from it.
|
|
364
|
+
|
|
365
|
+
It exposes two more events:
|
|
366
|
+
- volcengine_server_event_received: expose the raw server events from the Volcengine Realtime API
|
|
367
|
+
- volcengine_client_event_queued: expose the raw client events sent to the Volcengine Realtime API
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
def __init__(self, realtime_model: RealtimeModel) -> None:
    """Create a session bound to *realtime_model* and start its background task.

    All session state is initialized before the main task is scheduled, so the
    task can never observe a partially constructed session.
    """
    super().__init__(realtime_model)
    self._realtime_model: RealtimeModel = realtime_model
    self._opts = realtime_model._opts
    self._tools = llm.ToolContext.empty()
    # Outgoing audio frames, drained by the websocket send loop in _run_ws.
    self._msg_ch = utils.aio.Chan[rtc.AudioFrame]()
    self._input_resampler: rtc.AudioResampler | None = None
    # Used as the dialog id and as the session id of every framed request.
    self.session_id = str(uuid.uuid4())

    self._instructions: str | None = None

    self._response_created_futures: dict[
        str, asyncio.Future[llm.GenerationCreatedEvent]
    ] = {}
    self._item_delete_future: dict[str, asyncio.Future] = {}
    self._item_create_future: dict[str, asyncio.Future] = {}

    self._current_generation: _ResponseGeneration | None = None
    self._current_item: _MessageGeneration | None = None
    self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext()
    # True while the configured opening greeting is being spoken.
    self._is_opening = False
    # Flags that gate the "first response" log lines of each turn.
    self._first_tts_response = True
    self._first_llm_response = True
    self._first_llm_sentence = True

    self._update_chat_ctx_lock = asyncio.Lock()
    self._update_fnc_ctx_lock = asyncio.Lock()

    # 100ms chunks
    self._bstream = utils.audio.AudioByteStream(
        self._opts.sample_rate,
        self._opts.num_channels,
        samples_per_channel=self._opts.sample_rate // 10,
    )
    # duration of audio pushed to the Volcengine realtime API
    self._pushed_duration_s: float = 0

    # Scheduled last so every attribute above exists before the task runs.
    self._main_atask = asyncio.create_task(
        self._main_task(), name="RealtimeSession._main_task"
    )
|
|
410
|
+
|
|
411
|
+
def send_event(self, event: rtc.AudioFrame) -> None:
    """Queue an audio frame for the websocket send loop.

    Frames queued after the channel has been closed are dropped silently.
    """
    try:
        self._msg_ch.send_nowait(event)
    except utils.aio.channel.ChanClosed:
        pass
|
|
414
|
+
|
|
415
|
+
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
    """Connect to the dialogue websocket and run the session until it ends.

    Any failure, including failing to open the connection, is logged, reported
    to listeners via ``_emit_error(..., recoverable=False)`` and re-raised.
    """
    logger.info("start realtime main task")
    try:
        # Connecting is inside the try so connection errors are reported too.
        ws_conn = await self._create_ws_conn()
        await self._run_ws(ws_conn)
    except Exception as e:
        logger.error("realtime main task error", exc_info=e)
        self._emit_error(e, recoverable=False)
        raise
    logger.info("realtime main task break")
|
|
430
|
+
|
|
431
|
+
async def _create_ws_conn(self) -> aiohttp.ClientWebSocketResponse:
    """Open the dialogue websocket, bounded by ``conn_options.timeout``."""
    opts = self._realtime_model._opts
    headers = opts.get_ws_headers()
    endpoint = opts.ws_url
    pending_connect = self._realtime_model._ensure_http_session().ws_connect(
        url=endpoint,
        headers=headers,
    )
    return await asyncio.wait_for(pending_connect, opts.conn_options.timeout)
|
|
441
|
+
|
|
442
|
+
async def _run_ws(self, ws_conn: aiohttp.ClientWebSocketResponse) -> None:
    """Drive one realtime dialogue over an already-open websocket.

    Sequence:

    1. StartConnection (event 1), then wait for one reply frame.
    2. StartSession (event 100) via ``_start_session``.
    3. If an opening line is configured, SayHello (event 300) and open a
       generation for the greeting.
    4. Run a send loop (audio frames from ``self._msg_ch`` as event 200) and
       a receive loop (server events -> LiveKit realtime events) concurrently
       until either one finishes.

    The websocket is closed on exit, including when the handshake fails.
    """
    opts = self._realtime_model._opts
    session_id_bytes = self.session_id.encode("utf-8")
    closed_types = (
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.CLOSING,
        aiohttp.WSMsgType.CLOSED,
    )

    def _encode_request(
        event_id: int,
        payload: bytes,
        *,
        header: bytearray | None = None,
        with_session: bool = True,
    ) -> bytearray:
        """Frame a request: header, event id, [session id len + id], payload len, payload."""
        frame = bytearray(generate_header() if header is None else header)
        frame.extend(event_id.to_bytes(4, "big"))
        if with_session:
            frame.extend(len(session_id_bytes).to_bytes(4, "big"))
            frame.extend(session_id_bytes)
        frame.extend(len(payload).to_bytes(4, "big"))
        frame.extend(payload)
        return frame

    def _json_payload(obj: dict) -> bytes:
        """Serialize *obj* as gzip-compressed JSON (generate_header's default format)."""
        return gzip.compress(json.dumps(obj).encode("utf-8"))

    def _begin_assistant_message() -> None:
        """Open a new generation with a single message and announce it to listeners."""
        self._current_generation = _ResponseGeneration(
            message_ch=utils.aio.Chan(),
            function_ch=utils.aio.Chan(),
            messages={},
            _created_timestamp=time.time(),
            _done_fut=asyncio.Future(),
        )
        self.emit(
            "generation_created",
            llm.GenerationCreatedEvent(
                message_stream=self._current_generation.message_ch,
                function_stream=self._current_generation.function_ch,
                user_initiated=False,
            ),
        )
        item_id = utils.shortuuid()
        modalities_fut: asyncio.Future[list[Literal["text", "audio"]]] = (
            asyncio.Future()
        )
        self._current_item = _MessageGeneration(
            message_id=item_id,
            text_ch=utils.aio.Chan(),
            audio_ch=utils.aio.Chan(),
            modalities=modalities_fut,
        )
        if self._realtime_model.capabilities.audio_output:
            modalities_fut.set_result(["audio", "text"])
        else:
            self._current_item.audio_ch.close()
            modalities_fut.set_result(["text"])
        self._current_generation.message_ch.send_nowait(
            llm.MessageGeneration(
                message_id=item_id,
                text_stream=self._current_item.text_ch,
                audio_stream=self._current_item.audio_ch,
                modalities=modalities_fut,
            )
        )

    def _end_assistant_message() -> None:
        """Close every stream of the current generation (TTSEnded)."""
        item = self._current_item
        if item is not None:
            item.audio_ch.close()
            if self._is_opening:
                # The greeting's text is pushed here rather than via chat-text
                # (550) events; presumably the server sends none for event 300.
                with contextlib.suppress(utils.aio.channel.ChanClosed):
                    item.text_ch.send_nowait(opts.opening)
                item.text_ch.close()
        self._is_opening = False
        generation = self._current_generation
        if generation is not None:
            generation.message_ch.close()
            generation.function_ch.close()
        self._current_generation = None
        self._first_tts_response = True

    @utils.log_exceptions(logger=logger)
    async def _send_task() -> None:
        """Forward queued audio frames to the server as gzip-compressed raw audio."""
        audio_header = generate_header(
            message_type=CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=NO_SERIALIZATION,
        )
        async for frame in self._msg_ch:
            try:
                request = _encode_request(
                    200, gzip.compress(frame.data.tobytes()), header=audio_header
                )
                await ws_conn.send_bytes(request)
            except Exception:
                logger.error("send task error", exc_info=True)
                break

        await ws_conn.close()

    @utils.log_exceptions(logger=logger)
    async def _recv_task() -> None:
        """Translate server events into LiveKit realtime events until the socket ends."""
        # Latest user transcript (interim or final); handed to rag_fn at ASREnded.
        last_transcript: str | None = None
        while True:
            try:
                msg = await ws_conn.receive()
                if msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error("realtime websocket error: %s", ws_conn.exception())
                    break
                if msg.type in closed_types:
                    # End the loop instead of skipping: after the socket closes,
                    # receive() presumably keeps returning CLOSED frames with
                    # data=None, which a plain `continue` would spin on.
                    logger.info("realtime websocket closed")
                    break
                if msg.type != aiohttp.WSMsgType.BINARY:
                    # Protocol frames are binary; ignore anything else.
                    continue

                response = parse_response(msg.data)
                if "code" in response:
                    logger.error(
                        "realtime server error %s: %s",
                        response["code"],
                        response.get("payload_msg"),
                    )
                    continue
                event = response.get("event")
                payload_msg = response.get("payload_msg")

                if event == 450:  # ASRInfo: user started speaking
                    self.emit("input_speech_started", llm.InputSpeechStartedEvent())
                    logger.info("transcription start")
                elif event == 451:  # ASRResponse
                    asr_result = payload_msg["results"][0]
                    last_transcript = asr_result["alternatives"][0]["text"]
                    if not asr_result["is_interim"]:
                        self.emit(
                            "input_audio_transcription_completed",
                            llm.InputTranscriptionCompleted(
                                item_id=utils.shortuuid(),
                                transcript=last_transcript,
                                is_final=True,
                            ),
                        )
                        if self._current_generation is None:
                            _begin_assistant_message()
                elif event == 459:  # ASREnded: user stopped speaking
                    logger.info("transcription end")
                    self.emit(
                        "input_speech_stopped",
                        llm.InputSpeechStoppedEvent(user_transcription_enabled=False),
                    )
                    rag_fn = self._realtime_model._rag_fn
                    if rag_fn is not None:
                        if last_transcript is None:
                            logger.warning("rag skipped: no transcription received yet")
                        else:
                            logger.info("rag start")
                            # NOTE: rag_fn is synchronous and runs on the event loop.
                            rag_payload = _json_payload(
                                {"external_rag": rag_fn(last_transcript)}
                            )
                            await ws_conn.send_bytes(_encode_request(502, rag_payload))
                            logger.info("rag end")
                    logger.info("llm start")
                    logger.info("tts start")
                elif event == 352:  # TTSResponse: a chunk of synthesized speech
                    if self._first_tts_response:
                        logger.info("llm first sentence")
                        logger.info("tts first response")
                        self._first_tts_response = False
                    # The server audio is float32; clip to [-1.0, 1.0] to avoid
                    # overflow, then convert to int16.
                    samples = np.clip(
                        np.frombuffer(payload_msg, dtype=np.float32), -1.0, 1.0
                    )
                    pcm = (samples * 32767).astype(np.int16).tobytes()
                    item = self._current_item
                    if item is not None:
                        # Dropped when the audio stream is already closed (audio
                        # output disabled, or the message has ended).
                        with contextlib.suppress(utils.aio.channel.ChanClosed):
                            item.audio_ch.send_nowait(
                                rtc.AudioFrame(
                                    data=pcm,
                                    sample_rate=opts.sample_rate,
                                    num_channels=opts.num_channels,
                                    samples_per_channel=len(pcm)
                                    // (2 * opts.num_channels),
                                )
                            )
                elif event == 359:  # TTSEnded
                    logger.info("tts end")
                    _end_assistant_message()
                elif event == 550:  # text content of the model reply
                    if self._first_llm_response:
                        logger.info("llm first response")
                        self._first_llm_response = False
                    text = payload_msg["content"]
                    item = self._current_item
                    if item is not None:
                        with contextlib.suppress(utils.aio.channel.ChanClosed):
                            item.text_ch.send_nowait(text)
                elif event == 559:  # end of the model reply text
                    logger.info("llm end")
                    if self._current_item is not None:
                        self._current_item.text_ch.close()
                    self._first_llm_response = True
                # Other events (e.g. 350 TTSSentenceStart, 351 TTSSentenceEnd)
                # are ignored.
            except Exception:
                logger.error("recv task error", exc_info=True)
                break

    try:
        logger.info("start connection")
        await ws_conn.send_bytes(
            _encode_request(1, _json_payload({}), with_session=False)
        )
        _ = await ws_conn.receive_bytes()

        logger.info("start session")
        await self._start_session(ws_conn=ws_conn, dialog_id=self.session_id)

        if opts.opening is not None:
            self._is_opening = True
            await ws_conn.send_bytes(
                _encode_request(300, _json_payload({"content": opts.opening}))
            )
            logger.info("send hello request")
            # The greeting is an assistant turn of its own.
            _begin_assistant_message()

        tasks = [
            asyncio.create_task(_recv_task(), name="_recv_task"),
            asyncio.create_task(_send_task(), name="_send_task"),
        ]
        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                task.result()
        finally:
            await utils.aio.cancel_and_wait(*tasks)
    finally:
        # Also reached when the handshake fails, so the socket never leaks.
        await ws_conn.close()
|
|
708
|
+
|
|
709
|
+
def _create_session_update_event(self):
    """Placeholder: session updates are not implemented; always returns ``None``."""
    return None
|
|
711
|
+
|
|
712
|
+
async def chat_tts_text(
    self,
    start: bool,
    end: bool,
    content: str,
    ws_conn: aiohttp.ClientWebSocketResponse,
) -> None:
    """Send a ChatTTSText request (event 500) carrying *content* to be spoken.

    ``start`` and ``end`` are forwarded verbatim in the JSON payload.
    NOTE(review): presumably they mark the first and last chunk of a streamed
    utterance; confirm against the Volcengine dialogue API docs.
    """
    logger.info("ChatTTSTextRequest")
    body = gzip.compress(
        str.encode(json.dumps({"start": start, "end": end, "content": content}))
    )
    sid = self.session_id
    request = bytearray(generate_header())
    for part in (
        int(500).to_bytes(4, "big"),
        len(sid).to_bytes(4, "big"),
        str.encode(sid),
        len(body).to_bytes(4, "big"),
        body,
    ):
        request.extend(part)
    await ws_conn.send_bytes(request)
|
|
736
|
+
|
|
737
|
+
async def _start_session(
    self, ws_conn: aiohttp.ClientWebSocketResponse, dialog_id: str
) -> None:
    """Send StartSession (event 100) with the configured session settings.

    Waits for one binary reply frame; its content is not inspected.
    """
    params = self._realtime_model._opts.get_start_session_reqs(dialog_id=dialog_id)
    body = gzip.compress(str.encode(json.dumps(params)))
    sid = self.session_id
    request = bytearray(generate_header())
    request += int(100).to_bytes(4, "big")
    request += len(sid).to_bytes(4, "big")
    request += str.encode(sid)
    request += len(body).to_bytes(4, "big")
    request += body
    await ws_conn.send_bytes(request)
    _ = await ws_conn.receive_bytes()
|
|
753
|
+
|
|
754
|
+
async def _finish_session(self, ws_conn: aiohttp.ClientWebSocketResponse) -> None:
|
|
755
|
+
finish_session_request = bytearray(generate_header())
|
|
756
|
+
finish_session_request.extend(int(102).to_bytes(4, "big"))
|
|
757
|
+
payload_bytes = str.encode("{}")
|
|
758
|
+
payload_bytes = gzip.compress(payload_bytes)
|
|
759
|
+
finish_session_request.extend((len(self.session_id)).to_bytes(4, "big"))
|
|
760
|
+
finish_session_request.extend(str.encode(self.session_id))
|
|
761
|
+
finish_session_request.extend((len(payload_bytes)).to_bytes(4, "big"))
|
|
762
|
+
finish_session_request.extend(payload_bytes)
|
|
763
|
+
await ws_conn.send_bytes(finish_session_request)
|
|
764
|
+
|
|
765
|
+
async def _finish_connection(
|
|
766
|
+
self, ws_conn: aiohttp.ClientWebSocketResponse
|
|
767
|
+
) -> None:
|
|
768
|
+
finish_connection_request = bytearray(generate_header())
|
|
769
|
+
finish_connection_request.extend(int(2).to_bytes(4, "big"))
|
|
770
|
+
payload_bytes = str.encode("{}")
|
|
771
|
+
payload_bytes = gzip.compress(payload_bytes)
|
|
772
|
+
finish_connection_request.extend((len(payload_bytes)).to_bytes(4, "big"))
|
|
773
|
+
finish_connection_request.extend(payload_bytes)
|
|
774
|
+
await ws_conn.send_bytes(finish_connection_request)
|
|
775
|
+
_ = await ws_conn.receive_bytes()
|
|
776
|
+
|
|
777
|
+
@property
|
|
778
|
+
def chat_ctx(self) -> llm.ChatContext:
|
|
779
|
+
return self._remote_chat_ctx.to_chat_ctx()
|
|
780
|
+
|
|
781
|
+
@property
|
|
782
|
+
def tools(self) -> llm.ToolContext:
|
|
783
|
+
return self._tools.copy()
|
|
784
|
+
|
|
785
|
+
def update_options(
|
|
786
|
+
self,
|
|
787
|
+
*,
|
|
788
|
+
tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
|
|
789
|
+
voice: NotGivenOr[str] = NOT_GIVEN,
|
|
790
|
+
) -> None:
|
|
791
|
+
pass
|
|
792
|
+
|
|
793
|
+
async def update_tools(self, tools: list[llm.Tool]) -> None:
|
|
794
|
+
pass
|
|
795
|
+
|
|
796
|
+
async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
|
|
797
|
+
pass
|
|
798
|
+
|
|
799
|
+
def _create_update_chat_ctx_events(self, chat_ctx: llm.ChatContext):
|
|
800
|
+
events = []
|
|
801
|
+
|
|
802
|
+
return events
|
|
803
|
+
|
|
804
|
+
async def update_instructions(self, instructions: str) -> None:
|
|
805
|
+
self._opts.system_role = instructions
|
|
806
|
+
|
|
807
|
+
def push_audio(self, frame: rtc.AudioFrame) -> None:
|
|
808
|
+
for f in self._resample_audio(frame):
|
|
809
|
+
data = f.data.tobytes()
|
|
810
|
+
for nf in self._bstream.write(data):
|
|
811
|
+
self.send_event(nf)
|
|
812
|
+
self._pushed_duration_s += nf.duration
|
|
813
|
+
|
|
814
|
+
def push_video(self, frame: rtc.VideoFrame) -> None:
|
|
815
|
+
pass
|
|
816
|
+
|
|
817
|
+
def commit_audio(self) -> None:
|
|
818
|
+
if self._pushed_duration_s > 0.1:
|
|
819
|
+
self._pushed_duration_s = 0
|
|
820
|
+
|
|
821
|
+
def clear_audio(self) -> None:
|
|
822
|
+
self._pushed_duration_s = 0
|
|
823
|
+
|
|
824
|
+
def generate_reply(
|
|
825
|
+
self,
|
|
826
|
+
*,
|
|
827
|
+
instructions: NotGivenOr[str] = NOT_GIVEN,
|
|
828
|
+
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
|
|
829
|
+
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
|
|
830
|
+
) -> asyncio.Future[llm.GenerationCreatedEvent]:
|
|
831
|
+
"""仅文字输入"""
|
|
832
|
+
event_id = utils.shortuuid("response_create_")
|
|
833
|
+
fut = asyncio.Future[llm.GenerationCreatedEvent]()
|
|
834
|
+
self._response_created_futures[event_id] = fut
|
|
835
|
+
|
|
836
|
+
def _on_timeout() -> None:
|
|
837
|
+
if fut and not fut.done():
|
|
838
|
+
fut.set_exception(llm.RealtimeError("generate_reply timed out."))
|
|
839
|
+
|
|
840
|
+
handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
|
|
841
|
+
fut.add_done_callback(lambda _: handle.cancel())
|
|
842
|
+
return fut
|
|
843
|
+
|
|
844
|
+
def interrupt(self) -> None:
|
|
845
|
+
pass
|
|
846
|
+
|
|
847
|
+
def truncate(
|
|
848
|
+
self,
|
|
849
|
+
*,
|
|
850
|
+
message_id: str,
|
|
851
|
+
modalities: list[Literal["text", "audio"]],
|
|
852
|
+
audio_end_ms: int,
|
|
853
|
+
audio_transcript: NotGivenOr[str] = NOT_GIVEN,
|
|
854
|
+
) -> None:
|
|
855
|
+
if "audio" in modalities:
|
|
856
|
+
# 当前 volcengine 实时接口未暴露远端音频截断事件;占位以对齐接口
|
|
857
|
+
pass
|
|
858
|
+
elif utils.is_given(audio_transcript):
|
|
859
|
+
# 同步转写文本到远端会话上下文
|
|
860
|
+
chat_ctx = self.chat_ctx.copy()
|
|
861
|
+
if (idx := chat_ctx.index_by_id(message_id)) is not None:
|
|
862
|
+
new_item = copy.copy(chat_ctx.items[idx])
|
|
863
|
+
assert new_item.type == "message"
|
|
864
|
+
|
|
865
|
+
new_item.content = [audio_transcript]
|
|
866
|
+
chat_ctx.items[idx] = new_item
|
|
867
|
+
events = self._create_update_chat_ctx_events(chat_ctx)
|
|
868
|
+
for ev in events:
|
|
869
|
+
self.send_event(ev)
|
|
870
|
+
|
|
871
|
+
async def aclose(self) -> None:
|
|
872
|
+
self._msg_ch.close()
|
|
873
|
+
await self._main_atask
|
|
874
|
+
|
|
875
|
+
def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
|
|
876
|
+
if self._input_resampler:
|
|
877
|
+
if frame.sample_rate != self._input_resampler._input_rate:
|
|
878
|
+
# input audio changed to a different sample rate
|
|
879
|
+
self._input_resampler = None
|
|
880
|
+
|
|
881
|
+
if self._input_resampler is None and (
|
|
882
|
+
frame.sample_rate != self._realtime_model._opts.sample_rate
|
|
883
|
+
or frame.num_channels != self._realtime_model._opts.num_channels
|
|
884
|
+
):
|
|
885
|
+
self._input_resampler = rtc.AudioResampler(
|
|
886
|
+
input_rate=frame.sample_rate,
|
|
887
|
+
output_rate=self._realtime_model._opts.sample_rate,
|
|
888
|
+
num_channels=self._realtime_model._opts.num_channels,
|
|
889
|
+
)
|
|
890
|
+
|
|
891
|
+
if self._input_resampler:
|
|
892
|
+
# TODO(long): flush the resampler when the input source is changed
|
|
893
|
+
yield from self._input_resampler.push(frame)
|
|
894
|
+
else:
|
|
895
|
+
yield frame
|
|
896
|
+
|
|
897
|
+
def _emit_error(self, error: Exception, recoverable: bool) -> None:
|
|
898
|
+
self.emit(
|
|
899
|
+
"error",
|
|
900
|
+
llm.RealtimeModelError(
|
|
901
|
+
timestamp=time.time(),
|
|
902
|
+
label=self._realtime_model._label,
|
|
903
|
+
error=error,
|
|
904
|
+
recoverable=recoverable,
|
|
905
|
+
),
|
|
906
|
+
)
|