livekit-plugins-volcenginee 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,495 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import gzip
5
+ import json
6
+ import os
7
+ import weakref
8
+ from dataclasses import dataclass
9
+ from typing import Literal
10
+
11
+ import aiohttp
12
+
13
+ from livekit import rtc
14
+ from livekit.agents import (
15
+ DEFAULT_API_CONNECT_OPTIONS,
16
+ APIConnectOptions,
17
+ APIStatusError,
18
+ stt,
19
+ utils,
20
+ )
21
+ from livekit.agents.types import (
22
+ NOT_GIVEN,
23
+ NotGivenOr,
24
+ )
25
+ from livekit.agents.utils import AudioBuffer
26
+
27
+ from .log import logger
28
+
29
# Binary frame header fields. Every value below is packed into a 4-bit
# nibble of the header, hence the single-hex-digit literals.
PROTOCOL_VERSION = 0x1
DEFAULT_HEADER_SIZE = 0x1

# Message type (high nibble of header byte 1).
FULL_CLIENT_REQUEST = 0x1
AUDIO_ONLY_REQUEST = 0x2
FULL_SERVER_RESPONSE = 0x9
SERVER_ACK = 0xB
SERVER_ERROR_RESPONSE = 0xF

# Message-type-specific flags (low nibble of header byte 1).
NO_SEQUENCE = 0x0
POS_SEQUENCE = 0x1
NEG_SEQUENCE = 0x2
NEG_WITH_SEQUENCE = 0x3
NEG_SEQUENCE_1 = 0x3  # same value as NEG_WITH_SEQUENCE

# Payload serialization (high nibble of header byte 2).
NO_SERIALIZATION = 0x0
JSON = 0x1

# Payload compression (low nibble of header byte 2).
NO_COMPRESSION = 0x0
GZIP = 0x1
53
+
54
+
55
def generate_header(
    message_type=FULL_CLIENT_REQUEST,
    message_type_specific_flags=NO_SEQUENCE,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
):
    """Build the 4-byte header that starts every frame.

    Layout, one byte per row (high nibble | low nibble):
        byte 0: protocol version | header size
        byte 1: message type | message-type-specific flags
        byte 2: serialization method | compression type
        byte 3: reserved

    Returns:
        A new ``bytearray`` holding the four header bytes.
    """
    # The receiver multiplies this nibble by 4 to find where the header ends.
    size_nibble = 1
    return bytearray(
        (
            (PROTOCOL_VERSION << 4) | size_nibble,
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            reserved_data,
        )
    )
69
+
70
+
71
def generate_before_payload(sequence: int):
    """Encode ``sequence`` as the 4-byte big-endian signed field that precedes the payload."""
    return bytearray(sequence.to_bytes(4, "big", signed=True))
75
+
76
+
77
@dataclass
class STTOptions:
    """Credentials plus audio and recognition settings for one STT session.

    Besides holding the settings, the class serialises them into the binary
    frames and HTTP headers used on the websocket connection.
    """

    # Credentials. When left as None, get_ws_header() reads them from the
    # VOLCENGINE_STT_APP_ID / VOLCENGINE_STT_ACCESS_TOKEN environment variables.
    app_id: str | None = None
    access_token: str | None = None
    # Picks the default X-Api-Resource-Id when resource_id is not given.
    source_type: Literal["duration", "concurrent"] = "duration"
    resource_id: str | None = None

    base_url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
    format: Literal["pcm", "wav", "ogg"] = "pcm"
    sample_rate: int = 16000
    bits: int = 16
    num_channels: int = 1
    language: str = "zh-CN"

    model_name: str = "bigmodel"
    codec: Literal["raw", "opus"] = "raw"
    enable_itn: bool = False
    enable_punc: bool = True
    enable_ddc: bool = False
    show_utterance: bool = True
    result_type: Literal["full", "single"] = "single"
    vad_segment_duration: int = 3000
    end_window_size: int = 500
    force_to_speech_time: int = 1000

    def get_ws_url(self):
        """Return the websocket endpoint URL."""
        return self.base_url

    def get_ws_query_params(self, uid: str | None = None) -> bytearray:
        """Build the initial "full client request" frame describing the session.

        The configuration is JSON-encoded, gzip-compressed and framed as
        header + sequence number 1 + 4-byte payload length + payload.

        Args:
            uid: user id placed in the request; a short uuid is generated when
                omitted.
        """
        user_id = utils.shortuuid() if uid is None else uid
        audio_section = {
            "format": self.format,
            "rate": self.sample_rate,
            "bits": self.bits,
            "channels": self.num_channels,
            "codec": self.codec,
        }
        request_section = {
            "model_name": self.model_name,
            "enable_itn": self.enable_itn,
            "enable_punc": self.enable_punc,
            "enable_ddc": self.enable_ddc,
            "show_utterance": self.show_utterance,
            "result_type": self.result_type,
            "vad_segment_duration": self.vad_segment_duration,
            "end_window_size": self.end_window_size,
            "force_to_speech_time": self.force_to_speech_time,
        }
        config = {
            "user": {"uid": user_id},
            "audio": audio_section,
            "request": request_section,
        }
        body = gzip.compress(json.dumps(config).encode())

        frame = bytearray(generate_header(message_type_specific_flags=POS_SEQUENCE))
        frame += generate_before_payload(sequence=1)
        frame += len(body).to_bytes(4, "big")
        frame += body
        return frame

    def get_chunk_request(
        self, chunk: bytes, seq: int, last: bool = False
    ) -> bytearray:
        """Frame one gzip-compressed audio chunk as an audio-only request.

        Args:
            chunk: raw audio bytes.
            seq: sequence number written in front of the payload.
            last: when true the header carries NEG_WITH_SEQUENCE instead of
                POS_SEQUENCE.
        """
        body = gzip.compress(chunk)
        flags = NEG_WITH_SEQUENCE if last else POS_SEQUENCE
        frame = bytearray(
            generate_header(
                message_type=AUDIO_ONLY_REQUEST,
                message_type_specific_flags=flags,
            )
        )
        frame += generate_before_payload(sequence=seq)
        frame += len(body).to_bytes(4, "big")
        frame += body
        return frame

    def get_ws_header(self, reqid: str | None = None) -> dict[str, str]:
        """Return the HTTP headers for the websocket handshake.

        Missing credentials are looked up in the environment and stored back on
        this instance.

        Args:
            reqid: value used for both the connect id and the request id; a
                short uuid is generated when omitted.

        Raises:
            ValueError: app id or access token is neither set nor present in
                the environment.
        """
        request_id = utils.shortuuid() if reqid is None else reqid

        if self.resource_id is not None:
            resource = self.resource_id
        elif self.source_type == "duration":
            resource = "volc.bigasr.sauc.duration"
        else:
            resource = "volc.bigasr.sauc.concurrent"

        if self.app_id is None:
            self.app_id = os.environ.get("VOLCENGINE_STT_APP_ID", None)
            if self.app_id is None:
                raise ValueError("VOLCENGINE_STT_APP_ID is not set")
        if self.access_token is None:
            self.access_token = os.environ.get("VOLCENGINE_STT_ACCESS_TOKEN", None)
            if self.access_token is None:
                raise ValueError("VOLCENGINE_STT_ACCESS_TOKEN is not set")

        return {
            "X-Api-Resource-Id": resource,
            "X-Api-Access-Key": self.access_token,
            "X-Api-App-Key": self.app_id,
            "X-Api-Connect-Id": request_id,
            "X-Api-Request-Id": request_id,
        }
183
+
184
+
185
class STT(stt.STT):
    """Streaming-only speech-to-text client; each ``stream()`` call opens a new session."""

    def __init__(
        self,
        *,
        app_id: str | None = None,
        base_url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        access_token: str | None = None,
        resource_id: str | None = None,
        model_name: str = "bigmodel",
        enable_itn: bool = False,
        enable_punc: bool = True,
        enable_ddc: bool = False,
        vad_segment_duration: int = 3000,
        end_window_size: int = 500,
        force_to_speech_time: int = 1000,
        http_session: aiohttp.ClientSession | None = None,
        interim_results: bool = True,
    ) -> None:
        """Store the recognition settings; no connection is opened here.

        All recognition arguments are copied into an ``STTOptions`` shared by
        every stream created from this instance. ``http_session`` may be
        omitted, in which case one is obtained lazily from the http context.
        """
        capabilities = stt.STTCapabilities(
            streaming=True, interim_results=interim_results
        )
        super().__init__(capabilities=capabilities)

        self._opts = STTOptions(
            base_url=base_url,
            access_token=access_token,
            app_id=app_id,
            resource_id=resource_id,
            model_name=model_name,
            enable_itn=enable_itn,
            enable_punc=enable_punc,
            enable_ddc=enable_ddc,
            vad_segment_duration=vad_segment_duration,
            end_window_size=end_window_size,
            force_to_speech_time=force_to_speech_time,
        )

        self._session = http_session
        # Weak references: finished streams can be collected without
        # explicit deregistration.
        self._streams = weakref.WeakSet[SpeechStream]()

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching one from the http context on first use."""
        if self._session:
            return self._session
        self._session = utils.http_context.http_session()
        return self._session

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """One-shot recognition is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("not implemented")

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> SpeechStream:
        """Create and register a new ``SpeechStream`` using this instance's options.

        ``language`` is accepted but not used by this method.
        """
        session = self._ensure_session()
        speech_stream = SpeechStream(
            stt=self,
            opts=self._opts,
            conn_options=conn_options,
            http_session=session,
        )
        self._streams.add(speech_stream)
        return speech_stream
254
+
255
+
256
class SpeechStream(stt.SpeechStream):
    """One websocket recognition session.

    Audio pushed into the stream is re-chunked into 100 ms frames and sent to
    the server; server responses are turned into ``stt.SpeechEvent`` objects
    on the event channel.
    """

    def __init__(
        self,
        *,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions,
        http_session: aiohttp.ClientSession,
    ) -> None:
        super().__init__(
            stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
        )

        self._opts = opts
        self._session = http_session
        # True between START_OF_SPEECH and the matching END_OF_SPEECH.
        self._speaking = False

        self._request_id = utils.shortuuid()
        self._reconnect_event = asyncio.Event()

    async def _run(self) -> None:
        """Connect, then pump audio up and results down until the input ends.

        Raises:
            APIStatusError: the server closed the connection while audio was
                still being sent.
        """
        # Set once all input has been sent, so that the server closing the
        # connection afterwards is treated as a normal end of session.
        closing_ws = False

        @utils.log_exceptions(logger=logger)
        async def send_task(ws: aiohttp.ClientWebSocketResponse):
            nonlocal closing_ws

            full_client_request = self._opts.get_ws_query_params(uid=self._request_id)
            await ws.send_bytes(full_client_request)

            samples_100ms = self._opts.sample_rate // 10
            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=self._opts.num_channels,
                samples_per_channel=samples_100ms,
            )
            # The configuration frame used sequence 1; audio frames count up
            # from 2. The counter itself always stays positive.
            seq = 1
            async for data in self._input_ch:
                frames: list[rtc.AudioFrame] = []
                is_flush = False
                if isinstance(data, rtc.AudioFrame):
                    frames.extend(audio_bstream.write(data.data.tobytes()))
                elif isinstance(data, self._FlushSentinel):
                    frames.extend(audio_bstream.flush())
                    is_flush = True

                final_index = len(frames) - 1
                for index, frame in enumerate(frames):
                    seq += 1
                    # Only the final frame of a flushed batch is marked as the
                    # last packet; it alone carries the negated sequence
                    # number. Frames after a flush are numbered normally again.
                    is_last = is_flush and index == final_index
                    chunk_request = self._opts.get_chunk_request(
                        frame.data.tobytes(),
                        seq=-seq if is_last else seq,
                        last=is_last,
                    )
                    await ws.send_bytes(chunk_request)

            closing_ws = True

        @utils.log_exceptions(logger=logger)
        async def recv_task(ws: aiohttp.ClientWebSocketResponse):
            while True:
                msg = await ws.receive()
                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws:
                        return
                    raise APIStatusError(message="connection closed unexpectedly")

                # Responses are binary frames; anything else cannot be parsed.
                if msg.type != aiohttp.WSMsgType.BINARY:
                    logger.warning("unexpected message type %s", msg.type)
                    continue

                try:
                    self._process_stream_event(msg.data)
                except Exception:
                    # A single malformed frame must not tear the stream down.
                    logger.exception("failed to process message")

        ws: aiohttp.ClientWebSocketResponse | None = None

        while True:
            try:
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(send_task(ws)),
                    asyncio.create_task(recv_task(ws)),
                ]
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
                try:
                    done, _ = await asyncio.wait(
                        [asyncio.gather(*tasks), wait_reconnect_task],
                        return_when=asyncio.FIRST_COMPLETED,
                    )

                    # Re-raise whatever the send/recv pair failed with.
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()

                    # Finished without a reconnect request: the session is over.
                    if wait_reconnect_task not in done:
                        break

                    self._reconnect_event.clear()
                finally:
                    await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
            finally:
                if ws is not None:
                    await ws.close()

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Open the websocket, bounded by the configured connect timeout."""
        ws = await asyncio.wait_for(
            self._session.ws_connect(
                self._opts.get_ws_url(),
                headers=self._opts.get_ws_header(reqid=self._request_id),
                max_msg_size=1_000_000_000,
            ),
            self._conn_options.timeout,
        )
        return ws

    def _process_stream_event(self, data: bytes) -> None:
        """Parse one server frame and emit the matching speech events.

        Non-definite results produce START_OF_SPEECH (first one only) and
        INTERIM_TRANSCRIPT; a definite result while speaking produces
        FINAL_TRANSCRIPT followed by END_OF_SPEECH. Frames without a usable
        transcript are ignored, and server error frames are logged.

        Args:
            data: raw binary frame received from the websocket.
        """
        parsed = parse_response(res=data)
        if "code" in parsed:
            logger.error(
                "server returned an error",
                extra={"code": parsed["code"], "payload": parsed.get("payload_msg")},
            )
            return

        # Frames such as bare acks carry no payload at all.
        payload = parsed.get("payload_msg")
        if not isinstance(payload, dict):
            return
        result = payload.get("result", None)
        if not isinstance(result, dict):
            return
        text = result.get("text", "")
        if not text:
            return
        utterances = result.get("utterances", [])
        if not utterances:
            return

        utterance = utterances[0]
        # A missing flag means "not final yet".
        definite = utterance.get("definite", False)

        # NOTE(review): a definite result that arrives while no speech is in
        # progress is dropped. This may be a guard against repeated final
        # results -- confirm before emitting it.
        if definite and not self._speaking:
            return

        # NOTE(review): times are passed through exactly as received; confirm
        # their unit matches what stt.SpeechData expects.
        alternatives = [
            stt.SpeechData(
                language=self._opts.language,
                text=text,
                start_time=utterance.get("start_time", 0.0),
                end_time=utterance.get("end_time", 0.0),
                confidence=result.get("confidence", 0.0),
            )
        ]

        if not definite:
            if not self._speaking:
                self._speaking = True
                self._event_ch.send_nowait(
                    stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                )
                logger.info("transcription start")
            self._event_ch.send_nowait(
                stt.SpeechEvent(
                    type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                    request_id=self._request_id,
                    alternatives=alternatives,
                )
            )
            return

        self._event_ch.send_nowait(
            stt.SpeechEvent(
                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                request_id=self._request_id,
                alternatives=alternatives,
            )
        )
        self._event_ch.send_nowait(
            stt.SpeechEvent(
                type=stt.SpeechEventType.END_OF_SPEECH,
                request_id=self._request_id,
            )
        )
        self._speaking = False
        logger.info("transcription end", extra={"text": text})
451
+
452
+
453
def parse_response(res):
    """Decode one binary frame received from the server.

    The frame starts with a header whose length is ``(res[0] & 0x0F) * 4``
    bytes. When flag bit 0 is set, a 4-byte signed sequence number follows.
    The rest depends on the message type:

    * FULL_SERVER_RESPONSE: 4-byte size, then the payload.
    * SERVER_ACK: 4-byte signed seq, then (if present) 4-byte size and payload.
    * SERVER_ERROR_RESPONSE: 4-byte code, 4-byte size, then the payload.

    The payload is gunzipped and/or decoded according to the header nibbles.

    Args:
        res: the raw frame, indexable as bytes.

    Returns:
        A dict that always has ``"is_last_package"`` and, depending on the
        frame, ``"payload_sequence"``, ``"seq"``, ``"code"``,
        ``"payload_msg"`` and ``"payload_size"``. The last two are absent when
        the frame carries no payload.
    """
    header_words = res[0] & 0x0F
    kind = res[1] >> 4
    flags = res[1] & 0x0F
    serialization = res[2] >> 4
    compression = res[2] & 0x0F
    body = res[header_words * 4 :]

    # Flag bit 1 marks the final frame of the session.
    parsed = {"is_last_package": bool(flags & 0x02)}

    # Flag bit 0: a sequence number sits in front of the body.
    if flags & 0x01:
        parsed["payload_sequence"] = int.from_bytes(body[:4], "big", signed=True)
        body = body[4:]

    if kind == FULL_SERVER_RESPONSE:
        size = int.from_bytes(body[:4], "big", signed=True)
        message = body[4:]
    elif kind == SERVER_ACK:
        parsed["seq"] = int.from_bytes(body[:4], "big", signed=True)
        if len(body) < 8:
            # Bare ack: nothing beyond the sequence number.
            return parsed
        size = int.from_bytes(body[4:8], "big", signed=False)
        message = body[8:]
    elif kind == SERVER_ERROR_RESPONSE:
        parsed["code"] = int.from_bytes(body[:4], "big", signed=False)
        size = int.from_bytes(body[4:8], "big", signed=False)
        message = body[8:]
    else:
        # Unknown message type: report only what the header told us.
        return parsed

    if compression == GZIP:
        message = gzip.decompress(message)
    if serialization == JSON:
        message = json.loads(str(message, "utf-8"))
    elif serialization != NO_SERIALIZATION:
        message = str(message, "utf-8")

    parsed["payload_msg"] = message
    parsed["payload_size"] = size
    return parsed