xiaozhi-sdk 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xiaozhi-sdk might be problematic. Click here for more details.

xiaozhi_sdk/core.py ADDED
@@ -0,0 +1,270 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import re
6
+ import uuid
7
+ from collections import deque
8
+ from typing import Any, Callable, Deque, Dict, Optional
9
+
10
+ import websockets
11
+
12
+ from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
13
+ from xiaozhi_sdk.iot import OtaDevice
14
+ from xiaozhi_sdk.mcp import McpTool
15
+ from xiaozhi_sdk.utils import get_wav_info, read_audio_file, setup_opus
16
+ from xiaozhi_sdk.utils.mcp_tool import async_mcp_play_music, async_search_custom_music
17
+
18
# setup_opus() replaces ctypes.util.find_library so that the "opus" library is
# resolved to a file bundled with the SDK. It runs before AudioOpus is imported,
# presumably because opuslib looks up libopus at import time — TODO confirm.
setup_opus()
from xiaozhi_sdk.opus import AudioOpus  # noqa: E402  (must follow setup_opus())

# Logger shared under the "xiaozhi_sdk" name by the SDK modules.
logger = logging.getLogger("xiaozhi_sdk")
22
+
23
+
24
class XiaoZhiWebsocket(McpTool):
    """WebSocket client for the XiaoZhi voice-assistant protocol.

    Lifecycle: ``init_connection()`` registers the device with the OTA server
    and opens the WebSocket, ``send_audio()`` streams local PCM (encoded to
    Opus), decoded server audio accumulates in ``output_audio_queue``, and
    ``close()`` stops the background tasks and closes the socket. MCP requests
    from the server are answered by the :class:`McpTool` base class.
    """

    def __init__(
        self,
        message_handler_callback: Optional[Callable] = None,
        url: Optional[str] = None,
        ota_url: Optional[str] = None,
        audio_sample_rate: int = 16000,
        audio_channels: int = 1,
        send_wake: bool = False,
    ):
        """Configure the client; no network I/O happens here.

        Args:
            message_handler_callback: async callable awaited with every JSON
                message that is not ``hello``/``mcp``, and with SDK-generated
                notifications such as ``{"type": "websocket", "state": "close", ...}``.
            url: WebSocket URL; when omitted the URL from the OTA response is used.
            ota_url: OTA server URL, passed to :class:`OtaDevice`.
            audio_sample_rate: sample rate of the local audio, passed to ``AudioOpus``.
            audio_channels: channel count of the local audio, passed to ``AudioOpus``.
            send_wake: send the wake word right after connecting.
        """
        super().__init__()
        self.url = url
        self.ota_url = ota_url
        self.send_wake = send_wake
        self.audio_channels = audio_channels
        self.audio_opus = AudioOpus(audio_sample_rate, audio_channels)

        # Client identity
        self.client_id = str(uuid.uuid4())
        self.mac_addr: Optional[str] = None
        self.aec = False
        self.websocket_token = ""

        # Application callback
        self.message_handler_callback = message_handler_callback

        # Connection state
        self.hello_received = asyncio.Event()
        self.session_id = ""
        self.websocket = None
        self.message_handler_task: Optional[asyncio.Task] = None

        # Decoded server audio. Entries are whatever AudioOpus.opus_to_pcm yields
        # (presumably rows of int16 numpy samples — TODO confirm), not raw bytes.
        self.output_audio_queue: Deque[Any] = deque()
        self.is_playing: bool = False

        # OTA / device activation
        self.ota: Optional[OtaDevice] = None
        self.iot_task: Optional[asyncio.Task] = None
        # True while _activate_iot_device polls; send_audio then reconnects a closed socket.
        self.wait_device_activated: bool = False
        # Built-in MCP tools; extend with set_mcp_tool_callback().
        self.tool_func = {
            "async_play_custom_music": async_mcp_play_music,
            "async_search_custom_music": async_search_custom_music,
        }

    async def _send_hello(self, aec: bool) -> None:
        """Send the ``hello`` handshake and wait for the server's ``hello`` reply.

        Args:
            aec: value advertised in ``features.aec``.

        Raises:
            asyncio.TimeoutError: if no reply arrives within 10 seconds.
        """
        hello_message = {
            "type": "hello",
            "version": 1,
            "features": {"mcp": True, "aec": aec},
            "transport": "websocket",
            # Audio parameters advertised to the server: 16 kHz mono Opus, 60 ms frames.
            "audio_params": {
                "format": "opus",
                "sample_rate": 16000,
                "channels": 1,
                "frame_duration": 60,
            },
        }
        # The event stays set after a previous connection; clear it so that a
        # reconnect really waits for the new hello (and its new session_id).
        self.hello_received.clear()
        await self.websocket.send(json.dumps(hello_message))
        await asyncio.wait_for(self.hello_received.wait(), timeout=10.0)

    async def _start_listen(self) -> None:
        """Ask the server to start listening in ``realtime`` mode for this session."""
        listen_message = {"session_id": self.session_id, "type": "listen", "state": "start", "mode": "realtime"}
        await self.websocket.send(json.dumps(listen_message))

    async def is_activate(self, ota_info):
        """Return True when the OTA response has no truthy ``activation`` section.

        A present ``activation`` entry means the server still expects the
        device to answer an activation challenge.
        """
        return not ota_info.get("activation")

    async def _activate_iot_device(self, license_key: str, ota_info: Dict[str, Any]) -> None:
        """Poll the OTA server until it accepts the activation challenge.

        Waits 3 s, then tries up to 10 times at 3 s intervals. While polling,
        ``wait_device_activated`` is True so that :meth:`send_audio` reconnects
        a closed socket; the flag is reset only when activation succeeds.
        """
        if not self.ota:
            return

        challenge = ota_info["activation"]["challenge"]
        await asyncio.sleep(3)
        self.wait_device_activated = True
        for _ in range(10):
            if await self.ota.check_activate(challenge, license_key):
                self.wait_device_activated = False
                break
            await asyncio.sleep(3)

    async def _send_demo_audio(self) -> None:
        """Stream the bundled greeting WAV to the server, followed by silence."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        wav_path = os.path.join(current_dir, "../file/audio/greet.wav")
        framerate, channels = get_wav_info(wav_path)
        # Dedicated encoder matching the WAV's own format (not self.audio_opus).
        audio_opus = AudioOpus(framerate, channels)

        # NOTE(review): the last chunk from read_audio_file may hold fewer than
        # 960 frames; confirm pcm_to_opus handles short input.
        for pcm_data in read_audio_file(wav_path):
            opus_data = await audio_opus.pcm_to_opus(pcm_data)
            await self.websocket.send(opus_data)
        await self.send_silence_audio()

    async def send_wake_word(self, wake_word: str = "你好,小智") -> None:
        """Send a ``listen``/``detect`` message carrying ``wake_word`` as text."""
        await self.websocket.send(
            json.dumps({"session_id": self.session_id, "type": "listen", "state": "detect", "text": wake_word})
        )

    async def send_silence_audio(self, duration_seconds: float = 1.2) -> None:
        """Send about ``duration_seconds`` of silence as 60 ms frames via :meth:`send_audio`."""
        frames_count = int(duration_seconds * 1000 / 60)
        # One 60 ms frame of 16-bit zero samples at INPUT_SERVER_AUDIO_SAMPLE_RATE, mono.
        # NOTE(review): assumes the local audio format matches that rate/channel count — confirm.
        pcm_frame = b"\x00\x00" * int(INPUT_SERVER_AUDIO_SAMPLE_RATE / 1000 * 60)

        for _ in range(frames_count):
            await self.send_audio(pcm_frame)

    async def _handle_websocket_message(self, message: Any) -> None:
        """Dispatch one incoming WebSocket frame.

        Binary frames are Opus audio: they are decoded and appended to
        ``output_audio_queue``. Text frames are JSON: ``hello`` completes the
        handshake, ``mcp`` is handled by :meth:`McpTool.mcp`, and anything else
        is forwarded to ``message_handler_callback`` when one is set.
        """
        if isinstance(message, bytes):
            pcm_array = await self.audio_opus.opus_to_pcm(message)
            self.output_audio_queue.extend(pcm_array)
            return

        data = json.loads(message)
        message_type = data.get("type")
        if message_type == "hello":
            # Store the session id *before* waking _send_hello, so the following
            # listen message can never carry a stale id.
            self.session_id = data.get("session_id", self.session_id)
            self.hello_received.set()
        elif message_type == "mcp":
            await self.mcp(data)
        elif self.message_handler_callback:
            if message_type == "tts":
                # Only "sentence_start" marks active playback; any other tts state clears it.
                self.is_playing = data.get("state") == "sentence_start"
            await self.message_handler_callback(data)

    async def _message_handler(self) -> None:
        """Read frames until the connection closes, dispatching each one.

        A failure while handling a single frame is logged and skipped, so one
        malformed message (or a raising callback) does not stop the reader task
        for the rest of the session. When the connection closes with an error,
        the callback receives a ``websocket``/``close`` notification.
        """
        try:
            async for message in self.websocket:
                try:
                    await self._handle_websocket_message(message)
                except websockets.ConnectionClosed:
                    raise
                except Exception:
                    logger.exception("[websocket] failed to handle message")
        except websockets.ConnectionClosed:
            if self.message_handler_callback:
                await self.message_handler_callback(
                    {"type": "websocket", "state": "close", "source": "sdk.message_handler"}
                )
            logger.debug("[websocket] close")

    async def set_mcp_tool_callback(self, tool_func: Dict[str, Callable[..., Any]]) -> None:
        """Register or override MCP tools (name -> callable)."""
        self.tool_func.update(tool_func)

    async def connect_websocket(self, websocket_token):
        """Open the WebSocket, start the reader task and perform the handshake.

        Sends ``hello`` (waiting for the reply) and then ``listen``. If the
        handshake fails with ``InvalidMessage`` the error is logged and the
        method returns without assigning ``self.websocket``.
        """
        headers = {
            "Authorization": "Bearer {}".format(websocket_token),
            "Protocol-Version": "1",
            "Device-Id": self.mac_addr,
            "Client-Id": self.client_id,
        }
        try:
            self.websocket = await websockets.connect(uri=self.url, additional_headers=headers)
        except websockets.exceptions.InvalidMessage as e:
            logger.error("[websocket] 连接失败,请检查网络连接或设备状态。当前链接地址: %s, 错误信息:%s", self.url, e)
            return
        self.message_handler_task = asyncio.create_task(self._message_handler())

        await self._send_hello(self.aec)
        await self._start_listen()
        logger.debug("[websocket] Connection successful")
        # NOTE(review): fixed 0.5 s pause after connecting; its purpose is not documented.
        await asyncio.sleep(0.5)

    async def init_connection(
        self, mac_addr: str, aec: bool = False, serial_number: str = "", license_key: str = ""
    ) -> None:
        """Register with the OTA server and open the WebSocket connection.

        Args:
            mac_addr: device MAC as ``XX:XX:XX:XX:XX:XX`` (stored lower-case).
            aec: advertise acoustic echo cancellation in the hello message.
            serial_number: device serial number sent to the OTA server.
            license_key: key used to sign the activation challenge.

        Raises:
            ValueError: if ``mac_addr`` is not a valid MAC address.
        """
        # fullmatch instead of match(r"^...$"): "$" also matches before a trailing
        # newline, which would let "aa:bb:cc:dd:ee:ff\n" through into HTTP headers.
        mac_pattern = r"([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}"
        if not re.fullmatch(mac_pattern, mac_addr):
            raise ValueError(f"无效的MAC地址格式: {mac_addr}。正确格式应为 XX:XX:XX:XX:XX:XX")

        self.mac_addr = mac_addr.lower()
        self.aec = aec

        self.ota = OtaDevice(self.mac_addr, self.client_id, self.ota_url, serial_number)
        ota_info = await self.ota.activate_device()
        # The OTA response may lack a "websocket" section; read it defensively.
        websocket_info = ota_info.get("websocket") or {}
        self.url = self.url or websocket_info.get("url")

        if not self.url:
            logger.warning("[websocket] 未找到websocket链接地址")
            return

        if "tenclass.net" not in self.url and "xiaozhi.me" not in self.url:
            logger.warning("[websocket] 检测到非官方服务器,当前链接地址: %s", self.url)

        self.websocket_token = websocket_info.get("token", "")
        await self.connect_websocket(self.websocket_token)

        if not await self.is_activate(ota_info):
            self.iot_task = asyncio.create_task(self._activate_iot_device(license_key, ota_info))
            logger.debug("[IOT] 设备未激活")

        # connect_websocket() returns without a socket when the handshake failed.
        if self.send_wake and self.websocket is not None:
            await self.send_wake_word()

    async def send_audio(self, pcm: bytes) -> None:
        """Encode one PCM frame to Opus and send it if the socket is open.

        If the socket is closing/closed: while activation is pending it
        reconnects; otherwise it notifies ``message_handler_callback`` (if set)
        and drops the socket reference, so later calls return immediately.
        In any other state (still connecting) the frame is dropped after a
        short pause.
        """
        if not self.websocket:
            return

        state = self.websocket.state
        if state == websockets.protocol.State.OPEN:
            opus_data = await self.audio_opus.pcm_to_opus(pcm)
            await self.websocket.send(opus_data)
        elif state in (websockets.protocol.State.CLOSED, websockets.protocol.State.CLOSING):
            if self.wait_device_activated:
                # Activation still pending: reconnect so the session survives the close.
                logger.debug("[websocket] Server actively disconnected, reconnecting...")
                await self.connect_websocket(self.websocket_token)
            else:
                if self.message_handler_callback:
                    await self.message_handler_callback(
                        {"type": "websocket", "state": "close", "source": "sdk.send_audio"}
                    )
                # Drop the reference even without a callback; otherwise every later
                # frame would hit this branch and sleep 0.5 s.
                self.websocket = None
                logger.debug("[websocket] Server actively disconnected")

            await asyncio.sleep(0.5)
        else:
            await asyncio.sleep(0.1)

    async def close(self) -> None:
        """Cancel the reader and activation tasks, then close the WebSocket."""
        for task in (self.message_handler_task, self.iot_task):
            if task and not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        if self.websocket:
            await self.websocket.close()
xiaozhi_sdk/iot.py CHANGED
@@ -1,50 +1,84 @@
1
- import aiohttp
1
+ import hashlib
2
+ import hmac
2
3
  import json
4
+ import logging
5
+ from typing import Any, Dict, Optional
6
+
7
+ import aiohttp
3
8
 
9
+ from xiaozhi_sdk import __version__
4
10
  from xiaozhi_sdk.config import OTA_URL
5
11
 
6
- USER_AGENT = "XiaoXhi-SDK/1.0"
12
# Constants sent to the OTA server: the board type/name go into the
# registration payload, the User-Agent header embeds the package version.
BOARD_TYPE = "xiaozhi-sdk-box"
USER_AGENT = "xiaozhi-sdk/{}".format(__version__)
BOARD_NAME = "xiaozhi-sdk"

logger = logging.getLogger("xiaozhi_sdk")
18
+
19
+
20
class OtaDevice:
    """Client for the OTA server's device registration and activation endpoints.

    ``activate_device`` registers the device and returns the server's JSON
    reply; ``check_activate`` answers an activation challenge with an
    HMAC-SHA256 signature.

    Attributes:
        ota_url (str): OTA server URL without a trailing slash.
        mac_addr (str): device MAC address, sent as ``Device-Id``.
        client_id (str): client identifier, sent as ``Client-Id``.
        serial_number (str): device serial number.
    """

    def __init__(self, mac_addr: str, client_id: str, ota_url: Optional[str] = None, serial_number: str = "") -> None:
        # Fall back to the SDK default server; strip "/" so paths can be appended.
        self.ota_url = (ota_url or OTA_URL).rstrip("/")
        self.mac_addr = mac_addr
        self.client_id = client_id
        self.serial_number = serial_number

    def _get_base_headers(self) -> Dict[str, str]:
        """Return the HTTP headers shared by every OTA request."""
        return {
            "user-agent": USER_AGENT,
            "Device-Id": self.mac_addr,
            "Client-Id": self.client_id,
            "Content-Type": "application/json",
        }

    async def activate_device(self) -> Dict[str, Any]:
        """Register the device with the OTA server and return its decoded JSON reply.

        Raises:
            RuntimeError: if the server answers with a status other than 200;
                the message contains the status code and the response body.
        """
        headers = self._get_base_headers()
        headers["serial-number"] = self.serial_number

        payload = {
            "application": {"version": __version__},
            "board": {
                "type": BOARD_TYPE,
                "name": BOARD_NAME,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.ota_url + "/", headers=headers, data=json.dumps(payload)) as response:
                if response.status != 200:
                    err_text = await response.text()
                    # RuntimeError subclasses Exception, so existing handlers still catch it.
                    raise RuntimeError("OTA request failed with HTTP {}: {}".format(response.status, err_text))
                return await response.json()

    async def check_activate(self, challenge: str, license_key: str = "") -> bool:
        """Answer the server's activation challenge.

        The challenge is signed with HMAC-SHA256 keyed by ``license_key``.

        Returns:
            True if the server accepted the activation (HTTP 200), False otherwise.
        """
        url = f"{self.ota_url}/activate"
        headers = self._get_base_headers()

        hmac_result = hmac.new(license_key.encode(), challenge.encode(), hashlib.sha256).hexdigest()
        payload = {"serial_number": self.serial_number, "challenge": challenge, "hmac": hmac_result}

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=json.dumps(payload)) as response:
                is_ok = response.status == 200
                if not is_ok:
                    logger.debug("[IOT] wait for activate device...")
                return is_ok
xiaozhi_sdk/mcp.py CHANGED
@@ -1,75 +1,140 @@
1
+ import asyncio
1
2
  import json
3
+ import logging
2
4
 
5
+ import numpy as np
3
6
  import requests
4
7
 
5
- from xiaozhi_sdk.config import VL_URL
6
- from xiaozhi_sdk.data import mcp_initialize_payload, mcp_tools_payload, mcp_tool_conf
8
+ from xiaozhi_sdk.utils.mcp_data import mcp_initialize_payload, mcp_tool_conf, mcp_tools_payload
9
+ from xiaozhi_sdk.utils.mcp_tool import _get_random_music_info
10
+
11
# Logger shared under the "xiaozhi_sdk" name by the SDK modules.
logger = logging.getLogger("xiaozhi_sdk")
7
12
 
8
13
 
9
14
class McpTool(object):
    """Device-side handler for MCP (JSON-RPC 2.0) requests sent by the server.

    Requests arrive wrapped in ``{"type": "mcp", "payload": ...}`` WebSocket
    messages and replies are sent back in the same envelope via
    ``self.websocket``. Subclasses assign ``websocket`` and ``tool_func``; the
    music tool and :meth:`play_custom_music` additionally use
    ``message_handler_callback``, ``audio_opus`` and ``output_audio_queue``,
    which this class does not define.
    """

    def __init__(self):
        self.session_id = ""
        # Vision ("explain") service endpoint and token, provided in "initialize".
        self.explain_url = ""
        self.explain_token = ""
        self.websocket = None
        # Tool name -> callable returning (result, is_error); "async_" names are awaited.
        self.tool_func = {}
        self.is_playing = False

    def get_mcp_json(self, payload: dict):
        """Wrap a JSON-RPC payload in this session's ``mcp`` envelope and serialize it."""
        return json.dumps({"session_id": self.session_id, "type": "mcp", "payload": payload})

    def _build_response(self, request_id: str, content: str, is_error: bool = False):
        """Serialize a tool result holding a single text content item."""
        return self.get_mcp_json(
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "content": [{"type": "text", "text": content}],
                    "isError": is_error,
                },
            }
        )

    async def analyze_image(self, img_byte: bytes, question: str = "这张图片里有什么?"):
        """Ask the vision service about a JPEG image.

        Returns:
            ``(result, is_error)``: the decoded JSON reply, or ``"网络异常"`` with
            ``is_error=True`` if the request or JSON decoding fails. A dict reply
            with a truthy ``"error"`` key is also flagged as an error.
        """
        headers = {"Authorization": f"Bearer {self.explain_token}"}
        files = {"file": ("camera.jpg", img_byte, "image/jpeg")}
        payload = {"question": question}
        try:
            # requests is blocking; run it in a worker thread so the event loop
            # (and with it the WebSocket reader) keeps running during the upload.
            response = await asyncio.to_thread(
                requests.post, self.explain_url, files=files, data=payload, headers=headers, timeout=5
            )
            res_json = response.json()
        except Exception:
            return "网络异常", True
        if isinstance(res_json, dict) and res_json.get("error"):
            return res_json, True
        return res_json, False

    async def play_custom_music(self, tool_func, arguments):
        """Fetch music PCM through ``tool_func`` and queue it once playback is idle.

        Waits while ``is_playing`` is set, resamples with ``audio_opus`` and
        extends ``output_audio_queue`` (both provided by the subclass). Not
        invoked by :meth:`mcp_tool_call`.
        """
        pcm_array, _ = await tool_func(arguments)
        while self.is_playing:
            await asyncio.sleep(0.1)
        pcm_array = await self.audio_opus.change_sample_rate(np.array(pcm_array))
        self.output_audio_queue.extend(pcm_array)

    async def mcp_tool_call(self, mcp_json: dict):
        """Run the tool named in a ``tools/call`` request and return the serialized reply.

        Tools return ``(result, is_error)``. Names starting with ``async_`` are
        awaited, others are called directly. ``async_play_custom_music`` is
        special-cased: it picks a track and hands its URL to
        ``message_handler_callback``. A successful ``take_photo`` result is sent
        to the vision service. Exceptions raised by a tool produce an
        ``isError`` reply instead of propagating.
        """
        params = mcp_json["params"]
        tool_name = params["name"]
        tool_func = self.tool_func[tool_name]
        # "arguments" may be omitted for tools without parameters.
        arguments = params.get("arguments") or {}
        try:
            if tool_name == "async_play_custom_music":
                music_info = await _get_random_music_info(arguments["id_list"])
                if not music_info.get("url"):
                    tool_res, is_error = {"message": "播放失败"}, True
                else:
                    tool_res, is_error = {"message": "正在为你播放: {}".format(arguments["music_name"])}, False
                    data = {
                        "type": "music",
                        "state": "start",
                        "url": music_info["url"],
                        "text": arguments["music_name"],
                        "source": "sdk.mcp_music_tool",
                    }
                    # Playback is delegated to the application; a missing callback
                    # raises here and is reported as a failed tool call.
                    await self.message_handler_callback(data)
            elif tool_name.startswith("async_"):
                tool_res, is_error = await tool_func(arguments)
            else:
                tool_res, is_error = tool_func(arguments)
        except Exception as e:
            logger.error("[MCP] tool_func error: %s", e)
            return self._build_response(mcp_json["id"], "工具调用失败", True)

        if tool_name == "take_photo" and not is_error:
            # tool_res holds the captured image; replace it with the vision analysis.
            if "question" in arguments:
                tool_res, is_error = await self.analyze_image(tool_res, arguments["question"])
            else:
                tool_res, is_error = await self.analyze_image(tool_res)

        content = json.dumps(tool_res, ensure_ascii=False)
        return self._build_response(mcp_json["id"], content, is_error)

    async def mcp(self, data: dict):
        """Handle one ``{"type": "mcp"}`` message from the server.

        ``initialize`` stores the vision service URL/token and replies;
        ``tools/list`` replies with the configured tools; ``tools/call`` runs a
        tool and replies. ``notifications/initialized`` is ignored,
        ``notifications/cancelled`` is logged as an error, and unknown methods
        are logged as warnings.
        """
        payload = data["payload"]
        method = payload["method"]

        if method == "initialize":
            params = payload.get("params") or {}
            capabilities = params.get("capabilities") or {}
            # The vision capability is optional; without it take_photo analysis fails gracefully.
            vision = capabilities.get("vision") or {}
            self.explain_url = vision.get("url", "")
            self.explain_token = vision.get("token", "")

            # Copy the shared template rather than mutating module-level state.
            response = {**mcp_initialize_payload, "id": payload["id"]}
            await self.websocket.send(self.get_mcp_json(response))

        elif method == "notifications/initialized":
            pass

        elif method == "notifications/cancelled":
            logger.error("[MCP] 工具加载失败")

        elif method == "tools/list":
            tools = list(mcp_tools_payload["result"]["tools"])
            tool_list = []
            for name, func in self.tool_func.items():
                if not func:
                    continue
                tool_list.append(name)
                # Per-tool copy: "x" and "async_x" share one template entry.
                tool_conf = dict(mcp_tool_conf[name.removeprefix("async_")])
                tool_conf["name"] = name
                tools.append(tool_conf)

            # Build a fresh reply: appending into the shared template made every
            # repeated tools/list reply accumulate duplicate tool entries.
            response = {
                **mcp_tools_payload,
                "id": payload["id"],
                "result": {**mcp_tools_payload["result"], "tools": tools},
            }
            await self.websocket.send(self.get_mcp_json(response))
            logger.debug("[MCP] 加载成功,当前可用工具列表为:%s", tool_list)

        elif method == "tools/call":
            tool_name = payload["params"]["name"]
            if not self.tool_func.get(tool_name):
                logger.warning("[MCP] Tool not found: %s", tool_name)
                # JSON-RPC requests expect a reply with the same id, even on failure.
                await self.websocket.send(self._build_response(payload["id"], "Tool not found", True))
                return

            mcp_res = await self.mcp_tool_call(payload)
            await self.websocket.send(mcp_res)
            logger.debug("[MCP] Tool %s called", tool_name)
        else:
            logger.warning("[MCP] unknown method %s: %s", method, payload)
xiaozhi_sdk/opus.py CHANGED
@@ -1,13 +1,11 @@
1
- import os
2
- from xiaozhi_sdk import INPUT_SERVER_AUDIO_SAMPLE_RATE
1
+ import math
3
2
 
4
- # 设置 opus 库路径
5
- os.environ["DYLD_LIBRARY_PATH"] = "/opt/homebrew/lib:" + os.environ.get("DYLD_LIBRARY_PATH", "")
6
- os.environ["LIBRARY_PATH"] = "/opt/homebrew/lib:" + os.environ.get("LIBRARY_PATH", "")
7
3
  import av
8
4
  import numpy as np
9
5
  import opuslib
10
6
 
7
+ from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
8
+
11
9
 
12
10
  class AudioOpus:
13
11
 
@@ -33,11 +31,16 @@ class AudioOpus:
33
31
  pcm_bytes = pcm_array.tobytes()
34
32
  return self.opus_encoder.encode(pcm_bytes, 960)
35
33
 
36
- async def change_sample_rate(self, pcm_array):
34
+ @staticmethod
35
+ def to_n_960(samples) -> np.ndarray:
36
+ n = math.ceil(samples.shape[0] / 960)
37
+ arr_padded = np.pad(samples, (0, 960 * n - samples.shape[0]), mode="constant", constant_values=0)
38
+ return arr_padded.reshape(n, 960)
39
+
40
+ async def change_sample_rate(self, pcm_array) -> np.ndarray:
37
41
  if self.sample_rate == INPUT_SERVER_AUDIO_SAMPLE_RATE:
38
- return pcm_array.reshape(1, 960)
42
+ return self.to_n_960(pcm_array)
39
43
 
40
- c = int(self.sample_rate / INPUT_SERVER_AUDIO_SAMPLE_RATE)
41
44
  frame = av.AudioFrame.from_ndarray(np.array(pcm_array).reshape(1, -1), format="s16", layout="mono")
42
45
  frame.sample_rate = INPUT_SERVER_AUDIO_SAMPLE_RATE # Assuming input is 16kHz
43
46
  resampled_frames = self.resampler.resample(frame)
@@ -49,10 +52,9 @@ class AudioOpus:
49
52
  )
50
53
  new_frame.sample_rate = self.sample_rate
51
54
  new_samples = new_frame.to_ndarray().flatten()
52
- arr_padded = np.pad(new_samples, (0, 960 * c - new_samples.shape[0]), mode="constant", constant_values=0)
53
- return arr_padded.reshape(c, 960)
55
+ return self.to_n_960(new_samples)
54
56
 
55
- async def opus_to_pcm(self, opus):
57
+ async def opus_to_pcm(self, opus) -> np.ndarray:
56
58
  pcm_data = self.opus_decoder.decode(opus, 960)
57
59
  pcm_array = np.frombuffer(pcm_data, dtype=np.int16)
58
60
  samples = await self.change_sample_rate(pcm_array)
@@ -0,0 +1,57 @@
1
+ import ctypes.util
2
+ import os
3
+ import platform
4
+ import wave
5
+
6
+
7
def get_wav_info(file_path):
    """Return ``(sample_rate, channel_count)`` read from a WAV file's header."""
    wav_reader = wave.open(file_path, "rb")
    try:
        sample_rate = wav_reader.getframerate()
        channel_count = wav_reader.getnchannels()
    finally:
        wav_reader.close()
    return sample_rate, channel_count
10
+
11
+
12
def read_audio_file(file_path, frames_per_chunk: int = 960):
    """Yield the raw PCM payload of a WAV file in fixed-size chunks.

    Args:
        file_path (str): path (or binary file object) accepted by ``wave.open``.
        frames_per_chunk (int): frames per yielded chunk. The default of 960
            equals 60 ms only for 16 kHz audio; other rates give other durations.

    Yields:
        bytes: PCM data; every chunk except possibly the last holds exactly
        ``frames_per_chunk`` frames.

    Raises:
        ValueError: if ``frames_per_chunk`` is not positive (raised on first iteration).
    """
    if frames_per_chunk <= 0:
        raise ValueError("frames_per_chunk must be positive")
    with wave.open(file_path, "rb") as wav_file:
        while pcm := wav_file.readframes(frames_per_chunk):
            yield pcm
28
+
29
+
30
def setup_opus():
    """Make ``ctypes.util.find_library("opus")`` resolve to the bundled libopus.

    ``ctypes.util.find_library`` is replaced process-wide by a wrapper that,
    for the name ``"opus"``, returns the platform/architecture-specific library
    under ``file/opus`` two directories above this module. Every other name —
    and ``"opus"`` on an unrecognized OS or when the bundled file is missing —
    is delegated to the original lookup, so other ctypes users are unaffected.
    Calling this function more than once does not stack wrappers.
    """
    # Unwrap a previously installed wrapper so repeated calls don't chain lookups.
    current = ctypes.util.find_library
    original_find_library = getattr(current, "_xiaozhi_original", current)

    def fake_find_library(name):
        if name != "opus":
            return original_find_library(name)

        current_dir = os.path.dirname(os.path.abspath(__file__))
        system = platform.system().lower()
        machine = platform.machine().lower()

        if machine in ["x86_64", "amd64", "x64"]:
            arch = "x64"
        elif machine in ["arm64", "aarch64"]:
            arch = "arm64"
        else:
            # Unknown architecture: fall back to the x64 build.
            arch = "x64"

        if system == "darwin":  # macOS
            lib_path = f"{current_dir}/../../file/opus/macos-{arch}-libopus.dylib"
        elif system == "windows":
            lib_path = f"{current_dir}/../../file/opus/windows-opus.dll"
        elif system == "linux":
            lib_path = f"{current_dir}/../../file/opus/linux-{arch}-libopus.so"
        else:
            lib_path = None

        if lib_path and os.path.exists(lib_path):
            return lib_path
        # Call the saved original, never ctypes.util.find_library itself
        # (that is this wrapper and would recurse forever).
        return original_find_library(name)

    fake_find_library._xiaozhi_original = original_find_library
    ctypes.util.find_library = fake_find_library