livekit-plugins-volcenginee 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,906 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import copy
6
+ import json
7
+ import os
8
+ import time
9
+ import weakref
10
+ import gzip
11
+ import uuid
12
+ from collections.abc import Iterator
13
+ from dataclasses import dataclass
14
+ from typing import Literal, Callable
15
+
16
+ import aiohttp
17
+ import numpy as np
18
+ from livekit import rtc
19
+ from livekit.agents import llm, utils
20
+ from livekit.agents.types import (
21
+ DEFAULT_API_CONNECT_OPTIONS,
22
+ NOT_GIVEN,
23
+ APIConnectOptions,
24
+ NotGivenOr,
25
+ )
26
+
27
+ from .log import logger
28
+
29
+
30
# ---------------------------------------------------------------------------
# Binary frame header constants.
#
# generate_header() packs these values two per byte (high nibble / low
# nibble), so each one is written here as a single hexadecimal digit.
# ---------------------------------------------------------------------------

# Byte 0: protocol version (high nibble) and header size (low nibble).
PROTOCOL_VERSION = 0x1
DEFAULT_HEADER_SIZE = 0x1

# Width of each header field, in bits.
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Byte 1, high nibble: message type.
CLIENT_FULL_REQUEST = 0x1
CLIENT_AUDIO_ONLY_REQUEST = 0x2

SERVER_FULL_RESPONSE = 0x9
SERVER_ACK = 0xB
SERVER_ERROR_RESPONSE = 0xF

# Byte 1, low nibble: message-type-specific flags.
NO_SEQUENCE = 0x0  # no check sequence
POS_SEQUENCE = 0x1
NEG_SEQUENCE = 0x2
NEG_SEQUENCE_1 = 0x3

MSG_WITH_EVENT = 0x4

# Byte 2, high nibble: payload serialization.
NO_SERIALIZATION = 0x0
JSON = 0x1
THRIFT = 0x3
CUSTOM_TYPE = 0xF

# Byte 2, low nibble: payload compression.
NO_COMPRESSION = 0x0
GZIP = 0x1
CUSTOM_COMPRESSION = 0xF
67
+
68
+
69
def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=b"",
):
    """Build the binary header that prefixes every frame sent to the server.

    Layout of the fixed 4-byte part:

    - byte 0: protocol_version (4 bits) | header_size (4 bits)
    - byte 1: message_type (4 bits) | message_type_specific_flags (4 bits)
    - byte 2: serialization_method (4 bits) | message_compression (4 bits)
    - byte 3: reserved (8 bits)

    ``extension_header`` is appended after the fixed part; its size is expected
    to be ``8 * 4 * (header_size - 1)`` bits, i.e. a whole number of 4-byte
    words.

    Returns:
        A ``bytearray`` containing the fixed header followed by the extension.
    """
    # header_size counts 4-byte words: the fixed word plus the extension words.
    word_count = len(extension_header) // 4 + 1
    packed = bytearray(
        (
            (version << 4) | word_count,
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            reserved_data,
        )
    )
    packed.extend(extension_header)
    return packed
93
+
94
+
95
def parse_response(res):
    """Decode one frame received from the realtime dialogue websocket.

    Frame layout::

        - header
            - (4 bytes) header
            - (4 bits) version (v1) + (4 bits) header_size
            - (4 bits) messageType + (4 bits) messageTypeFlags
                -- 0001 CompleteClient | -- 0001 hasSequence
                -- 0010 audioonly      | -- 0010 isTailPacket
                                       | -- 0100 hasEvent
            - (4 bits) payloadFormat + (4 bits) compression
            - (8 bits) reserve
        - payload
            - [optional 4 bytes] event
            - [optional] session ID
                -- (4 bytes) session ID len
                -- session ID data
            - (4 bytes) data len
            - data

    Args:
        res: The raw websocket message data. Text frames (``str``) carry no
            binary protocol data and yield an empty result.

    Returns:
        A dict that may contain ``message_type``, ``seq``, ``event``,
        ``session_id``, ``code``, ``payload_msg`` and ``payload_size``. For a
        message type other than full response, ack or error, the dict is empty.
        ``payload_msg`` is the decompressed payload: a parsed JSON value for
        JSON serialization, raw bytes when there is no serialization, and a
        UTF-8 string otherwise.
    """
    if isinstance(res, str):
        return {}
    # res[0] >> 4 is the protocol version and res[3] is reserved; neither is used.
    header_size = res[0] & 0x0F
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0F
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0F
    # header_size is expressed in 4-byte words; everything after it is payload.
    payload = res[header_size * 4 :]
    result = {}
    payload_msg = None
    payload_size = 0
    start = 0
    if message_type == SERVER_FULL_RESPONSE or message_type == SERVER_ACK:
        result["message_type"] = (
            "SERVER_ACK" if message_type == SERVER_ACK else "SERVER_FULL_RESPONSE"
        )
        if message_type_specific_flags & NEG_SEQUENCE > 0:
            result["seq"] = int.from_bytes(
                payload[start : start + 4], "big", signed=False
            )
            start += 4
        if message_type_specific_flags & MSG_WITH_EVENT > 0:
            # Read at the running offset: when a sequence number precedes the
            # event id, the event occupies the *next* four bytes.
            result["event"] = int.from_bytes(
                payload[start : start + 4], "big", signed=False
            )
            start += 4
        payload = payload[start:]
        session_id_size = int.from_bytes(payload[:4], "big", signed=True)
        session_id = payload[4 : session_id_size + 4]
        # Decode the bytes; str(bytes) would produce the repr "b'...'".
        result["session_id"] = str(session_id, "utf-8", "replace")
        payload = payload[4 + session_id_size :]
        payload_size = int.from_bytes(payload[:4], "big", signed=False)
        payload_msg = payload[4:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result["code"] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result["payload_msg"] = payload_msg
    result["payload_size"] = payload_size
    return result
162
+
163
+
164
+ @dataclass
165
+ class _RealtimeOptions:
166
+ app_id: str
167
+ access_token: str
168
+ bot_name: str
169
+ system_role: str
170
+ max_session_duration: float | None
171
+ conn_options: APIConnectOptions
172
+ modalities: list[Literal["text", "audio"]]
173
+ opening: str = "你好啊,今天过得怎么样?"
174
+ speaking_style: str = "你的说话风格简洁明了,语速适中,语调自然。"
175
+ speaker: str = "zh_female_vv_jupiter_bigtts"
176
+ sample_rate: int = 24000
177
+ num_channels: int = 1
178
+ format: str = "pcm"
179
+ model: Literal["O", "SC"] = "O"
180
+ character_manifest: str | None = None
181
+ end_smooth_window_ms: int = 500
182
+ enable_volc_websearch: bool = False
183
+ volc_websearch_type: Literal["web_summary", "web"] = "web_summary"
184
+ volc_websearch_api_key: str | None = None
185
+ volc_websearch_no_result_message: str = "抱歉,我找不到相关信息。"
186
+
187
+ @property
188
+ def ws_url(self) -> str:
189
+ return "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
190
+
191
+ def get_ws_headers(self) -> dict:
192
+ headers = {
193
+ "X-Api-App-ID": self.app_id,
194
+ "X-Api-Access-Key": self.access_token,
195
+ "X-Api-Resource-Id": "volc.speech.dialog", # 固定值
196
+ "X-Api-App-Key": "PlgvMymc7f3tQnJ6", # 固定值
197
+ "X-Api-Connect-Id": str(uuid.uuid4()),
198
+ }
199
+ return headers
200
+
201
+ def get_start_session_reqs(self, dialog_id: str | None) -> dict:
202
+ start_session_req = {
203
+ "asr": {
204
+ "extra": {
205
+ "end_smooth_window_ms": self.end_smooth_window_ms,
206
+ }
207
+ },
208
+ "tts": {
209
+ "audio_config": {
210
+ "channel": self.num_channels,
211
+ "format": self.format,
212
+ "sample_rate": self.sample_rate,
213
+ },
214
+ "speaker": self.speaker,
215
+ },
216
+ "dialog": {
217
+ "bot_name": self.bot_name,
218
+ "system_role": self.system_role,
219
+ "dialog_id": dialog_id or str(utils.shortuuid()),
220
+ "speaking_style": self.speaking_style,
221
+ "character_manifest": self.character_manifest,
222
+ "extra": {
223
+ "strict_audit": False,
224
+ "enable_volc_websearch": self.enable_volc_websearch,
225
+ "volc_websearch_type": self.volc_websearch_type,
226
+ "volc_websearch_api_key": self.volc_websearch_api_key,
227
+ "volc_websearch_no_result_message": self.volc_websearch_no_result_message,
228
+ "model": self.model,
229
+ },
230
+ },
231
+ }
232
+ return start_session_req
233
+
234
+
235
+ @dataclass
236
+ class _MessageGeneration:
237
+ message_id: str
238
+ text_ch: utils.aio.Chan[str]
239
+ audio_ch: utils.aio.Chan[rtc.AudioFrame]
240
+ audio_transcript: str = ""
241
+ modalities: asyncio.Future[list[Literal["text", "audio"]]] | None = None
242
+
243
+
244
+ @dataclass
245
+ class _ResponseGeneration:
246
+ message_ch: utils.aio.Chan[llm.MessageGeneration]
247
+ function_ch: utils.aio.Chan[llm.FunctionCall]
248
+
249
+ messages: dict[str, _MessageGeneration]
250
+
251
+ _done_fut: asyncio.Future[None]
252
+ _created_timestamp: float
253
+ """timestamp when the response was created"""
254
+ _first_token_timestamp: float | None = None
255
+ """timestamp when the first token was received"""
256
+
257
+
258
class RealtimeModel(llm.RealtimeModel):
    """Realtime model backed by the Volcengine realtime dialogue websocket API."""

    def __init__(
        self,
        bot_name: str = "豆包",
        speaking_style: str = "你的说话风格简洁明了,语速适中,语调自然。",
        speaker: str = "zh_female_vv_jupiter_bigtts",
        opening: str | None = None,
        app_id: str | None = None,
        access_token: str | None = None,
        system_role: str | None = None,
        character_manifest: str | None = None,
        model: Literal["O", "SC"] = "O",
        end_smooth_window_ms: int = 500,
        enable_volc_websearch: bool = False,
        volc_websearch_type: Literal["web_summary", "web"] = "web_summary",
        volc_websearch_api_key: str | None = None,
        volc_websearch_no_result_message: str = "抱歉,我找不到相关信息。",
        rag_fn: Callable[[str], str] | None = None,
        audio_output: bool = True,
        modalities: NotGivenOr[list[Literal["text", "audio"]]] = NOT_GIVEN,
        http_session: aiohttp.ClientSession | None = None,
        max_session_duration: NotGivenOr[float | None] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> None:
        """Create the model and resolve its credentials.

        Args:
            bot_name: Sent as ``dialog.bot_name`` in the start-session request.
            speaking_style: Sent as ``dialog.speaking_style``.
            speaker: Sent as ``tts.speaker``.
            opening: Greeting sent right after the session starts; ``None``
                disables the greeting.
            app_id: Application id; falls back to the
                ``VOLCENGINE_REALTIME_APP_ID`` environment variable.
            access_token: Access token; falls back to the
                ``VOLCENGINE_REALTIME_ACCESS_TOKEN`` environment variable.
            system_role: Sent as ``dialog.system_role``.
            character_manifest: Sent as ``dialog.character_manifest``.
            model: Sent as ``dialog.extra.model``.
            end_smooth_window_ms: Sent as ``asr.extra.end_smooth_window_ms``.
            enable_volc_websearch: Sent as ``dialog.extra.enable_volc_websearch``.
            volc_websearch_type: Sent as ``dialog.extra.volc_websearch_type``.
            volc_websearch_api_key: Sent as ``dialog.extra.volc_websearch_api_key``.
            volc_websearch_no_result_message: Sent as
                ``dialog.extra.volc_websearch_no_result_message``.
            rag_fn: Optional callable that receives the latest user transcript
                and returns the external RAG content forwarded to the server.
            audio_output: Accepted for compatibility; the audio capability is
                derived from ``modalities`` instead.
            modalities: Output modalities; defaults to text and audio.
            http_session: Existing aiohttp session to reuse for the websocket.
            max_session_duration: Stored in the options; ``None`` when not given.
            conn_options: Connection options; its ``timeout`` bounds the
                websocket connect.

        Raises:
            ValueError: If the app id or access token is missing or empty.
        """
        modalities = modalities if utils.is_given(modalities) else ["text", "audio"]
        super().__init__(
            capabilities=llm.RealtimeCapabilities(
                message_truncation=True,
                turn_detection=True,
                user_transcription=True,
                auto_tool_reply_generation=False,
                audio_output=("audio" in modalities),
                manual_function_calls=True,
            )
        )
        logger.info("Model: %s", model)
        logger.info("Character Manifest: %s", character_manifest)
        logger.info("End Smooth Window MS: %s", end_smooth_window_ms)
        logger.info("Enable Volc Websearch: %s", enable_volc_websearch)
        logger.info("Volc Websearch Type: %s", volc_websearch_type)
        # The key is a secret: record only whether one was supplied.
        logger.info(
            "Volc Websearch API Key: %s",
            "<set>" if volc_websearch_api_key else "<not set>",
        )
        logger.info(
            "Volc Websearch No Result Message: %s", volc_websearch_no_result_message
        )

        # An empty string (e.g. an exported-but-blank variable) is as unusable
        # as a missing value, so both are rejected.
        app_id = app_id or os.environ.get("VOLCENGINE_REALTIME_APP_ID")
        if not app_id:
            raise ValueError("VOLCENGINE_REALTIME_APP_ID is required")
        access_token = access_token or os.environ.get(
            "VOLCENGINE_REALTIME_ACCESS_TOKEN"
        )
        if not access_token:
            raise ValueError("VOLCENGINE_REALTIME_ACCESS_TOKEN is required")

        self._opts = _RealtimeOptions(
            app_id=app_id,
            access_token=access_token,
            bot_name=bot_name,
            system_role=system_role,
            speaker=speaker,
            opening=opening,
            speaking_style=speaking_style,
            character_manifest=character_manifest,
            model=model,
            end_smooth_window_ms=end_smooth_window_ms,
            enable_volc_websearch=enable_volc_websearch,
            volc_websearch_type=volc_websearch_type,
            volc_websearch_api_key=volc_websearch_api_key,
            volc_websearch_no_result_message=volc_websearch_no_result_message,
            modalities=modalities,
            # Keep the NOT_GIVEN sentinel out of a field typed ``float | None``.
            max_session_duration=(
                max_session_duration if utils.is_given(max_session_duration) else None
            ),
            conn_options=conn_options,
        )
        self._rag_fn = rag_fn
        self._http_session = http_session
        self._sessions = weakref.WeakSet[RealtimeSession]()

    def update_options(
        self,
        *,
        max_session_duration: NotGivenOr[float | None] = NOT_GIVEN,
    ) -> None:
        """Accept option updates; currently a no-op."""
        pass

    def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the configured HTTP session, fetching the shared one lazily."""
        if not self._http_session:
            self._http_session = utils.http_context.http_session()

        return self._http_session

    def session(self) -> RealtimeSession:
        """Create a new realtime session and track it weakly on the model."""
        sess = RealtimeSession(self)
        self._sessions.add(sess)
        return sess

    async def aclose(self) -> None:
        """Close the model; there is nothing to release at this level."""
        ...
352
+
353
+
354
class RealtimeSession(
    llm.RealtimeSession[
        Literal["volcengine_server_event_received", "volcengine_client_event_queued"]
    ]
):
    """
    A session for the volcengine Realtime API.

    The session owns one websocket connection. A background task performs the
    connection and session handshake, then runs a sender (microphone audio to
    the server) and a receiver (server events to framework events) until
    either one stops.

    The event type additionally declares two event names:
    - volcengine_server_event_received
    - volcengine_client_event_queued
    The code in this class does not emit them.
    """

    def __init__(self, realtime_model: RealtimeModel) -> None:
        """Initialise the session state and start the background main task."""
        super().__init__(realtime_model)
        self._realtime_model: RealtimeModel = realtime_model
        # NOTE: this is the model's own options object, not a copy.
        self._opts = realtime_model._opts
        self._tools = llm.ToolContext.empty()
        self._msg_ch = utils.aio.Chan[rtc.AudioFrame]()
        self._input_resampler: rtc.AudioResampler | None = None
        self.session_id = str(uuid.uuid4())

        self._instructions: str | None = None

        self._response_created_futures: dict[
            str, asyncio.Future[llm.GenerationCreatedEvent]
        ] = {}
        self._item_delete_future: dict[str, asyncio.Future] = {}
        self._item_create_future: dict[str, asyncio.Future] = {}

        self._current_generation: _ResponseGeneration | None = None
        self._current_item: _MessageGeneration | None = None
        self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext()
        self._is_opening = False
        self._first_tts_response = True
        self._first_llm_response = True
        self._first_llm_sentence = True

        self._update_chat_ctx_lock = asyncio.Lock()
        self._update_fnc_ctx_lock = asyncio.Lock()

        # 100ms chunks
        self._bstream = utils.audio.AudioByteStream(
            self._realtime_model._opts.sample_rate,
            self._realtime_model._opts.num_channels,
            samples_per_channel=self._realtime_model._opts.sample_rate // 10,
        )
        # duration of audio pushed to the server since the last commit/clear
        self._pushed_duration_s: float = 0

        # Started last so every attribute exists before the task first runs.
        self._main_atask = asyncio.create_task(
            self._main_task(), name="RealtimeSession._main_task"
        )

    def send_event(self, event: rtc.AudioFrame) -> None:
        """Queue an audio frame for sending; silently dropped once closed."""
        with contextlib.suppress(utils.aio.channel.ChanClosed):
            self._msg_ch.send_nowait(event)

    @utils.log_exceptions(logger=logger)
    async def _main_task(self) -> None:
        """Connect, run the websocket loop, and report any failure.

        Failures, including a failed connect, are emitted as a non-recoverable
        ``error`` event and re-raised. The websocket is always closed, even
        when the handshake fails part-way.
        """
        logger.info("start realtime main task")
        ws_conn: aiohttp.ClientWebSocketResponse | None = None
        try:
            ws_conn = await self._create_ws_conn()
            await self._run_ws(ws_conn)
        except Exception as e:
            logger.error("realtime main task error", exc_info=e)
            self._emit_error(e, recoverable=False)
            raise
        finally:
            if ws_conn is not None:
                await ws_conn.close()
        logger.info("realtime main task break")

    async def _create_ws_conn(self) -> aiohttp.ClientWebSocketResponse:
        """Open the websocket, bounded by the configured connect timeout."""
        headers = self._realtime_model._opts.get_ws_headers()
        url = self._realtime_model._opts.ws_url
        return await asyncio.wait_for(
            self._realtime_model._ensure_http_session().ws_connect(
                url=url,
                headers=headers,
            ),
            self._realtime_model._opts.conn_options.timeout,
        )

    def _build_request(
        self,
        event_id: int,
        payload: bytes,
        *,
        with_session_id: bool = True,
        message_type: int = CLIENT_FULL_REQUEST,
        serial_method: int = JSON,
    ) -> bytearray:
        """Assemble one client frame.

        The frame is: header, 4-byte event id, optionally the length-prefixed
        session id, then the length-prefixed gzip-compressed payload.

        Args:
            event_id: Numeric event identifier written after the header.
            payload: Uncompressed payload bytes.
            with_session_id: Whether to include the session id section.
            message_type: Header message type.
            serial_method: Header serialization method.
        """
        request = bytearray(
            generate_header(message_type=message_type, serial_method=serial_method)
        )
        request.extend(event_id.to_bytes(4, "big"))
        if with_session_id:
            # The length prefix counts the encoded bytes that follow it.
            session_bytes = str.encode(self.session_id)
            request.extend(len(session_bytes).to_bytes(4, "big"))
            request.extend(session_bytes)
        compressed = gzip.compress(payload)
        request.extend(len(compressed).to_bytes(4, "big"))
        request.extend(compressed)
        return request

    def _start_generation(self) -> None:
        """Create a new response with one message and announce it.

        Sets ``_current_generation`` and ``_current_item``, emits
        ``generation_created`` and publishes the message on the generation's
        message channel. When audio output is disabled, the message's audio
        channel is closed immediately.
        """
        generation = _ResponseGeneration(
            message_ch=utils.aio.Chan(),
            function_ch=utils.aio.Chan(),
            messages={},
            _created_timestamp=time.time(),
            _done_fut=asyncio.Future(),
        )
        self._current_generation = generation

        self.emit(
            "generation_created",
            llm.GenerationCreatedEvent(
                message_stream=generation.message_ch,
                function_stream=generation.function_ch,
                user_initiated=False,
            ),
        )

        item_id = utils.shortuuid()
        modalities_fut: asyncio.Future[list[Literal["text", "audio"]]] = (
            asyncio.Future()
        )
        item = _MessageGeneration(
            message_id=item_id,
            text_ch=utils.aio.Chan(),
            audio_ch=utils.aio.Chan(),
            modalities=modalities_fut,
        )
        self._current_item = item
        if not self._realtime_model.capabilities.audio_output:
            item.audio_ch.close()
            modalities_fut.set_result(["text"])
        else:
            modalities_fut.set_result(["audio", "text"])

        generation.message_ch.send_nowait(
            llm.MessageGeneration(
                message_id=item_id,
                text_stream=item.text_ch,
                audio_stream=item.audio_ch,
                modalities=modalities_fut,
            )
        )

    async def _run_ws(self, ws_conn: aiohttp.ClientWebSocketResponse) -> None:
        """Perform the handshake and pump audio and events until one side stops."""
        closing = False
        close_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )

        logger.info("start connection")
        # Event 1: start connection (no session id section).
        await ws_conn.send_bytes(
            self._build_request(1, str.encode("{}"), with_session_id=False)
        )
        _ = await ws_conn.receive_bytes()

        logger.info("start session")
        await self._start_session(ws_conn=ws_conn, dialog_id=self.session_id)

        opening = self._realtime_model._opts.opening
        if opening is not None:
            self._is_opening = True
            # Event 300: hello request carrying the greeting text.
            hello_payload = str.encode(json.dumps({"content": opening}))
            await ws_conn.send_bytes(self._build_request(300, hello_payload))
            logger.info("send hello request")
            # The greeting is delivered to the framework as a generation.
            self._start_generation()

        @utils.log_exceptions(logger=logger)
        async def _send_task() -> None:
            nonlocal closing
            async for frame in self._msg_ch:
                try:
                    # Event 200: audio-only request with the raw frame bytes.
                    await ws_conn.send_bytes(
                        self._build_request(
                            200,
                            frame.data.tobytes(),
                            message_type=CLIENT_AUDIO_ONLY_REQUEST,
                            serial_method=NO_SERIALIZATION,
                        )
                    )
                except Exception:
                    logger.error("send task error", exc_info=True)
                    break

            closing = True
            await ws_conn.close()

        @utils.log_exceptions(logger=logger)
        async def _recv_task() -> None:
            # Latest ASR text. Initialised so that an ASREnd arriving before
            # any ASRResponse cannot hit an unbound local.
            transcription = ""
            while True:
                try:
                    msg = await ws_conn.receive()
                    if msg.type in close_types:
                        # A closed socket keeps returning close messages
                        # without suspending; looping on them would spin.
                        if not closing:
                            logger.warning("websocket closed by the remote end")
                        break
                    if msg.data is None:
                        continue
                    response = parse_response(msg.data)
                    event = response.get("event")
                    if event == 450:  # ASRInfo
                        self.emit("input_speech_started", llm.InputSpeechStartedEvent())
                        logger.info("transcription start")
                    elif event == 451:  # ASRResponse
                        first_result = response["payload_msg"]["results"][0]
                        transcription = first_result["alternatives"][0]["text"]
                        if not first_result["is_interim"]:
                            self.emit(
                                "input_audio_transcription_completed",
                                llm.InputTranscriptionCompleted(
                                    item_id=utils.shortuuid(),
                                    transcript=transcription,
                                    is_final=True,
                                ),
                            )
                            if self._current_generation is None:
                                self._start_generation()

                    elif event == 459:  # ASREnd
                        logger.info("transcription end")
                        self.emit(
                            "input_speech_stopped",
                            llm.InputSpeechStoppedEvent(
                                user_transcription_enabled=False
                            ),
                        )
                        rag_fn = self._realtime_model._rag_fn
                        if rag_fn is not None:
                            logger.info("rag start")
                            # NOTE: rag_fn is called synchronously, so a slow
                            # implementation stalls the event loop.
                            rag_payload = str.encode(
                                json.dumps({"external_rag": rag_fn(transcription)})
                            )
                            # Event 502: chat RAG text request.
                            await ws_conn.send_bytes(
                                self._build_request(502, rag_payload)
                            )
                            logger.info("rag end")
                        logger.info("llm start")
                        logger.info("tts start")

                    elif event == 352:  # TTSResponse
                        if self._first_tts_response:
                            logger.info("llm first sentence")
                            logger.info("tts first response")
                            self._first_tts_response = False
                        item = self._current_item
                        if item is None:
                            logger.warning("tts audio received without an active item")
                            continue
                        # The payload is float32 samples; convert to int16.
                        samples = np.frombuffer(
                            response["payload_msg"], dtype=np.float32
                        )
                        # Clip to [-1.0, 1.0] to avoid overflow.
                        samples = np.clip(samples, -1.0, 1.0)
                        audio_bytes = (samples * 32767).astype(np.int16).tobytes()
                        # The audio channel is closed up front in text-only
                        # mode; dropping the frame keeps the receiver alive.
                        with contextlib.suppress(utils.aio.channel.ChanClosed):
                            item.audio_ch.send_nowait(
                                rtc.AudioFrame(
                                    data=audio_bytes,
                                    sample_rate=self._realtime_model._opts.sample_rate,
                                    num_channels=1,
                                    samples_per_channel=len(audio_bytes) // 2,
                                )
                            )
                    elif event == 350:  # TTSSentenceStart
                        pass
                    elif event == 351:  # TTSSentenceEnd
                        pass
                    elif event == 359:  # TTSEnded
                        logger.info("tts end")
                        item = self._current_item
                        generation = self._current_generation
                        if item is not None:
                            item.audio_ch.close()
                            if self._is_opening:
                                # The greeting text is published here, once
                                # its audio has finished.
                                item.text_ch.send_nowait(
                                    self._realtime_model._opts.opening
                                )
                                item.text_ch.close()
                                self._is_opening = False
                        if generation is not None:
                            generation.message_ch.close()
                            generation.function_ch.close()
                        self._current_generation = None
                        self._first_tts_response = True
                    elif event == 550:  # text content of the model reply
                        if self._first_llm_response:
                            logger.info("llm first response")
                            self._first_llm_response = False
                        text = response["payload_msg"]["content"]
                        if self._current_item is None:
                            logger.warning("llm text received without an active item")
                            continue
                        self._current_item.text_ch.send_nowait(text)
                    elif event == 559:  # end of the model reply text
                        logger.info("llm end")
                        if self._current_item is not None:
                            self._current_item.text_ch.close()
                        self._first_llm_response = True
                    else:
                        pass
                except Exception:
                    logger.error("recv task error", exc_info=True)
                    break

        tasks = [
            asyncio.create_task(_recv_task(), name="_recv_task"),
            asyncio.create_task(_send_task(), name="_send_task"),
        ]
        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                task.result()

        finally:
            await utils.aio.cancel_and_wait(*tasks)
            await ws_conn.close()

    def _create_session_update_event(self):
        """Placeholder; not implemented."""
        pass

    async def chat_tts_text(
        self,
        start: bool,
        end: bool,
        content: str,
        ws_conn: aiohttp.ClientWebSocketResponse,
    ) -> None:
        """Send a Chat TTS Text message (event 500)."""
        payload = {
            "start": start,
            "end": end,
            "content": content,
        }
        logger.info("ChatTTSTextRequest")
        await ws_conn.send_bytes(
            self._build_request(500, str.encode(json.dumps(payload)))
        )

    async def _start_session(
        self, ws_conn: aiohttp.ClientWebSocketResponse, dialog_id: str
    ) -> None:
        """Send the start-session request (event 100) and wait for one reply."""
        request_params = self._realtime_model._opts.get_start_session_reqs(
            dialog_id=dialog_id
        )
        await ws_conn.send_bytes(
            self._build_request(100, str.encode(json.dumps(request_params)))
        )
        _ = await ws_conn.receive_bytes()

    async def _finish_session(self, ws_conn: aiohttp.ClientWebSocketResponse) -> None:
        """Send the finish-session request (event 102) without awaiting a reply."""
        await ws_conn.send_bytes(self._build_request(102, str.encode("{}")))

    async def _finish_connection(
        self, ws_conn: aiohttp.ClientWebSocketResponse
    ) -> None:
        """Send the finish-connection request (event 2) and wait for one reply."""
        await ws_conn.send_bytes(
            self._build_request(2, str.encode("{}"), with_session_id=False)
        )
        _ = await ws_conn.receive_bytes()

    @property
    def chat_ctx(self) -> llm.ChatContext:
        """Chat context built from the session's remote chat context."""
        return self._remote_chat_ctx.to_chat_ctx()

    @property
    def tools(self) -> llm.ToolContext:
        """A copy of the session's tool context."""
        return self._tools.copy()

    def update_options(
        self,
        *,
        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
        voice: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        """Accept option updates; currently a no-op."""
        pass

    async def update_tools(self, tools: list[llm.Tool]) -> None:
        """Accept a tool list; currently a no-op."""
        pass

    async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
        """Accept a chat context; currently a no-op."""
        pass

    def _create_update_chat_ctx_events(self, chat_ctx: llm.ChatContext):
        """Return the events needed to sync ``chat_ctx``; currently always empty."""
        events = []

        return events

    async def update_instructions(self, instructions: str) -> None:
        """Store ``instructions`` as the system role on the shared options.

        The value is only read when the start-session request is built, so a
        session that has already started is not affected.
        """
        self._opts.system_role = instructions

    def push_audio(self, frame: rtc.AudioFrame) -> None:
        """Resample ``frame`` if needed, re-chunk it and queue it for sending."""
        for f in self._resample_audio(frame):
            data = f.data.tobytes()
            for nf in self._bstream.write(data):
                self.send_event(nf)
                self._pushed_duration_s += nf.duration

    def push_video(self, frame: rtc.VideoFrame) -> None:
        """Video input is ignored."""
        pass

    def commit_audio(self) -> None:
        """Reset the pushed-audio counter once more than 100ms was pushed."""
        if self._pushed_duration_s > 0.1:
            self._pushed_duration_s = 0

    def clear_audio(self) -> None:
        """Reset the pushed-audio counter."""
        self._pushed_duration_s = 0

    def generate_reply(
        self,
        *,
        instructions: NotGivenOr[str] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
    ) -> asyncio.Future[llm.GenerationCreatedEvent]:
        """Text input only.

        Registers a future for a reply and returns it. Nothing in this class
        resolves the future, so unless a caller does, it fails with
        ``llm.RealtimeError`` after 5 seconds. The future is removed from the
        registry as soon as it completes.
        """
        event_id = utils.shortuuid("response_create_")
        fut = asyncio.Future[llm.GenerationCreatedEvent]()
        self._response_created_futures[event_id] = fut

        def _on_timeout() -> None:
            if fut and not fut.done():
                fut.set_exception(llm.RealtimeError("generate_reply timed out."))

        handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)

        def _on_done(_: asyncio.Future) -> None:
            handle.cancel()
            # Drop the entry so completed futures do not accumulate.
            self._response_created_futures.pop(event_id, None)

        fut.add_done_callback(_on_done)
        return fut

    def interrupt(self) -> None:
        """Accept an interruption request; currently a no-op."""
        pass

    def truncate(
        self,
        *,
        message_id: str,
        modalities: list[Literal["text", "audio"]],
        audio_end_ms: int,
        audio_transcript: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        """Truncate a message; only the text-only path does any work."""
        if "audio" in modalities:
            # The volcengine realtime API exposes no remote audio-truncation
            # event; this branch exists only to satisfy the interface.
            pass
        elif utils.is_given(audio_transcript):
            # Sync the transcript text into the remote session context.
            chat_ctx = self.chat_ctx.copy()
            if (idx := chat_ctx.index_by_id(message_id)) is not None:
                new_item = copy.copy(chat_ctx.items[idx])
                assert new_item.type == "message"

                new_item.content = [audio_transcript]
                chat_ctx.items[idx] = new_item
                events = self._create_update_chat_ctx_events(chat_ctx)
                for ev in events:
                    self.send_event(ev)

    async def aclose(self) -> None:
        """Close the outgoing channel and wait for the main task to finish."""
        self._msg_ch.close()
        await self._main_atask

    def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
        """Yield ``frame`` converted to the configured sample rate.

        A resampler is created on first need and discarded when the input
        sample rate changes. Frames already in the configured format pass
        through unchanged.
        """
        if self._input_resampler:
            if frame.sample_rate != self._input_resampler._input_rate:
                # input audio changed to a different sample rate
                self._input_resampler = None

        if self._input_resampler is None and (
            frame.sample_rate != self._realtime_model._opts.sample_rate
            or frame.num_channels != self._realtime_model._opts.num_channels
        ):
            self._input_resampler = rtc.AudioResampler(
                input_rate=frame.sample_rate,
                output_rate=self._realtime_model._opts.sample_rate,
                num_channels=self._realtime_model._opts.num_channels,
            )

        if self._input_resampler:
            # TODO(long): flush the resampler when the input source is changed
            yield from self._input_resampler.push(frame)
        else:
            yield frame

    def _emit_error(self, error: Exception, recoverable: bool) -> None:
        """Emit an ``error`` event describing ``error``."""
        self.emit(
            "error",
            llm.RealtimeModelError(
                timestamp=time.time(),
                label=self._realtime_model._label,
                error=error,
                recoverable=recoverable,
            ),
        )