xiaozhi-sdk 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xiaozhi_sdk/core.py ADDED
@@ -0,0 +1,269 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import re
6
+ import uuid
7
+ from collections import deque
8
+ from typing import Any, Callable, Deque, Dict, Optional
9
+
10
+ import websockets
11
+
12
+ from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
13
+ from xiaozhi_sdk.iot import OtaDevice
14
+ from xiaozhi_sdk.mcp import McpTool
15
+ from xiaozhi_sdk.utils import get_wav_info, read_audio_file, setup_opus
16
+
17
# Must run before importing xiaozhi_sdk.opus, which imports opuslib at module
# import time. NOTE(review): setup_opus presumably makes the native opus library
# discoverable (opus.py used to set DYLD_LIBRARY_PATH itself) — confirm.
setup_opus()
from xiaozhi_sdk.opus import AudioOpus

# Shared SDK logger (the same name is used in iot.py and mcp.py).
logger = logging.getLogger("xiaozhi_sdk")
21
+
22
+
23
class XiaoZhiWebsocket(McpTool):
    """Websocket client that talks to a xiaozhi server as a device.

    ``init_connection`` registers the device with the OTA server, connects the
    websocket, performs the ``hello``/``listen`` handshake and, when the OTA
    response carries an activation challenge, starts the activation loop.
    Outgoing PCM is Opus-encoded by ``send_audio``; incoming Opus audio is decoded
    into ``output_audio_queue``; MCP messages are answered by ``McpTool``.
    """

    def __init__(
        self,
        message_handler_callback: Optional[Callable] = None,
        url: Optional[str] = None,
        ota_url: Optional[str] = None,
        audio_sample_rate: int = 16000,
        audio_channels: int = 1,
        wake_word: str = "",
    ):
        """Create an unconnected client; call ``init_connection`` to connect.

        Args:
            message_handler_callback: async callable awaited with every JSON
                message other than ``hello``/``mcp``, and with a
                ``{"type": "websocket", "state": "close", ...}`` dict when the
                connection closes.
            url: websocket URL; if omitted, the URL returned by the OTA server is used.
            ota_url: OTA server URL handed to ``OtaDevice``.
            audio_sample_rate: sample rate used to build the ``AudioOpus`` codec.
            audio_channels: channel count used to build the ``AudioOpus`` codec.
            wake_word: if non-empty, sent as a ``listen``/``detect`` message after connecting.
        """
        super().__init__()
        self.url = url
        self.ota_url = ota_url
        self.audio_channels = audio_channels
        self.audio_opus = AudioOpus(audio_sample_rate, audio_channels)
        self.wake_word = wake_word

        # Client identity
        self.client_id = str(uuid.uuid4())
        self.mac_addr: Optional[str] = None
        self.aec = False
        self.websocket_token = ""

        # User callback
        self.message_handler_callback = message_handler_callback

        # Connection state
        self.hello_received = asyncio.Event()
        self.session_id = ""
        self.websocket = None
        self.message_handler_task: Optional[asyncio.Task] = None

        # Decoded audio received from the server
        self.output_audio_queue: Deque[bytes] = deque()
        self.is_playing: bool = False

        # OTA device / activation
        self.ota: Optional[OtaDevice] = None
        self.iot_task: Optional[asyncio.Task] = None
        self.wait_device_activated: bool = False

        # MCP tools, keyed by tool name
        self.mcp_tool_dict = {}

    async def _send_hello(self, aec: bool) -> None:
        """Send the ``hello`` handshake and wait up to 10 s for the server's ``hello``.

        Raises:
            asyncio.TimeoutError: if the server does not answer within 10 seconds.
        """
        # Clear first: after a reconnect the event is still set from the previous
        # connection, and the wait below would return before the new session_id arrives.
        self.hello_received.clear()
        hello_message = {
            "type": "hello",
            "version": 1,
            "features": {"mcp": True, "aec": aec},
            "transport": "websocket",
            "audio_params": {
                "format": "opus",
                "sample_rate": 16000,
                "channels": 1,
                "frame_duration": 60,
            },
        }
        await self.websocket.send(json.dumps(hello_message))
        await asyncio.wait_for(self.hello_received.wait(), timeout=10.0)

    async def _start_listen(self) -> None:
        """Ask the server to start listening in ``realtime`` mode for the current session."""
        listen_message = {"session_id": self.session_id, "type": "listen", "state": "start", "mode": "realtime"}
        await self.websocket.send(json.dumps(listen_message))

    async def is_activate(self, ota_info):
        """Return True when the OTA response contains no ``activation`` entry."""
        return not ota_info.get("activation")

    async def _activate_iot_device(self, license_key: str, ota_info: Dict[str, Any]) -> None:
        """Poll the OTA server (up to 10 tries, 3 s apart) until the activation challenge is accepted.

        ``wait_device_activated`` is True while polling and reset to False on
        success; it stays True if every attempt fails.
        """
        if not self.ota:
            return

        challenge = ota_info["activation"]["challenge"]
        await asyncio.sleep(3)
        self.wait_device_activated = True
        for _ in range(10):
            if await self.ota.check_activate(challenge, license_key):
                self.wait_device_activated = False
                break
            await asyncio.sleep(3)

    async def _send_demo_audio(self) -> None:
        """Stream the bundled ``greet.wav`` to the server, followed by silence."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        wav_path = os.path.join(current_dir, "../file/audio/greet.wav")
        framerate, channels = get_wav_info(wav_path)
        # The demo file has its own format, so it gets a dedicated encoder.
        audio_opus = AudioOpus(framerate, channels)

        for pcm_data in read_audio_file(wav_path):
            opus_data = await audio_opus.pcm_to_opus(pcm_data)
            await self.websocket.send(opus_data)
        await self.send_silence_audio()

    async def send_wake_word(self, wake_word: str) -> None:
        """Send ``wake_word`` as a ``listen``/``detect`` message."""
        await self.websocket.send(
            json.dumps({"session_id": self.session_id, "type": "listen", "state": "detect", "text": wake_word})
        )

    async def send_silence_audio(self, duration_seconds: float = 1.2) -> None:
        """Send ``duration_seconds`` of silence as 60 ms PCM frames."""
        frames_count = int(duration_seconds * 1000 / 60)
        # 16-bit zero samples for one 60 ms frame at the server input rate.
        pcm_frame = b"\x00\x00" * int(INPUT_SERVER_AUDIO_SAMPLE_RATE / 1000 * 60)

        for _ in range(frames_count):
            await self.send_audio(pcm_frame)

    async def _handle_websocket_message(self, message: Any) -> None:
        """Dispatch one received websocket message.

        Binary frames are Opus audio and are decoded into ``output_audio_queue``.
        Text frames are JSON: ``hello`` completes the handshake, ``mcp`` goes to
        ``McpTool.mcp``, and everything else is passed to the user callback.
        """
        if isinstance(message, bytes):
            pcm_array = await self.audio_opus.opus_to_pcm(message)
            self.output_audio_queue.extend(pcm_array)
            return

        data = json.loads(message)
        message_type = data["type"]
        if message_type == "hello":
            self.hello_received.set()
            self.session_id = data["session_id"]
            return
        if message_type == "mcp":
            await self.mcp(data)
            return

        # Track TTS playback even without a callback: McpTool.play_custom_music
        # polls is_playing.
        if message_type == "tts":
            self.is_playing = data.get("state") == "sentence_start"

        if self.message_handler_callback:
            await self.message_handler_callback(data)

    async def _message_handler(self) -> None:
        """Read and dispatch messages until the connection ends, then report the close.

        The websocket iterator finishes quietly on a normal close and raises
        ``ConnectionClosed`` on an abnormal one; both are reported the same way.
        """
        try:
            async for message in self.websocket:
                try:
                    await self._handle_websocket_message(message)
                except websockets.ConnectionClosed:
                    raise
                except Exception:
                    # A malformed message or a failing callback must not stop the reader.
                    logger.exception("[websocket] failed to handle message")
        except websockets.ConnectionClosed:
            pass
        if self.message_handler_callback:
            await self.message_handler_callback(
                {"type": "websocket", "state": "close", "source": "sdk.message_handler"}
            )
        logger.debug("[websocket] close")

    async def set_mcp_tool(self, mcp_tool_list) -> None:
        """Register MCP tool descriptions, keyed by their ``name``."""
        for mcp_tool in mcp_tool_list:
            self.mcp_tool_dict[mcp_tool["name"]] = mcp_tool

    async def connect_websocket(self, websocket_token):
        """Open the websocket, start the reader task, and perform the hello/listen handshake.

        Returns:
            True when connected and listening. False if the handshake failed with
            ``InvalidMessage``; this is logged and ``self.websocket`` is left unchanged.
        """
        headers = {
            "Authorization": "Bearer {}".format(websocket_token),
            "Protocol-Version": "1",
            "Device-Id": self.mac_addr,
            "Client-Id": self.client_id,
        }
        try:
            self.websocket = await websockets.connect(uri=self.url, additional_headers=headers)
        except websockets.exceptions.InvalidMessage as e:
            logger.error("[websocket] 连接失败,请检查网络连接或设备状态。当前链接地址: %s, 错误信息:%s", self.url, e)
            return False
        # The reader must run before the hello is sent: it sets hello_received.
        self.message_handler_task = asyncio.create_task(self._message_handler())

        await self._send_hello(self.aec)
        await self._start_listen()
        logger.debug("[websocket] Connection successful")
        await asyncio.sleep(0.5)
        return True

    async def init_connection(
        self, mac_addr: str, aec: bool = False, serial_number: str = "", license_key: str = ""
    ) -> None:
        """Register with the OTA server, connect the websocket and start listening.

        Args:
            mac_addr: device id in ``XX:XX:XX:XX:XX:XX`` form (stored lower-cased).
            aec: advertised as the ``aec`` feature in the hello message.
            serial_number: forwarded to ``OtaDevice``.
            license_key: key used by ``OtaDevice.check_activate`` to sign the activation challenge.

        Raises:
            ValueError: if ``mac_addr`` is not a colon-separated MAC address.
        """
        # fullmatch: "$" in re.match would also accept a trailing newline.
        if not re.fullmatch(r"([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}", mac_addr):
            raise ValueError(f"无效的MAC地址格式: {mac_addr}。正确格式应为 XX:XX:XX:XX:XX:XX")

        self.mac_addr = mac_addr.lower()
        self.aec = aec

        self.ota = OtaDevice(self.mac_addr, self.client_id, self.ota_url, serial_number)
        ota_info = await self.ota.activate_device()
        # Tolerate a response without a websocket section; a url passed to the
        # constructor can still be used in that case.
        ws_info = ota_info.get("websocket") or {}
        self.url = self.url or ws_info.get("url")

        if not self.url:
            logger.warning("[websocket] 未找到websocket链接地址")
            return

        if "tenclass.net" not in self.url and "xiaozhi.me" not in self.url:
            logger.warning("[websocket] 检测到非官方服务器,当前链接地址: %s", self.url)

        self.websocket_token = ws_info.get("token", "")
        connected = await self.connect_websocket(self.websocket_token)

        if not await self.is_activate(ota_info):
            self.iot_task = asyncio.create_task(self._activate_iot_device(license_key, ota_info))
            logger.debug("[IOT] 设备未激活")

        # send_wake_word writes to self.websocket, which only exists after a successful connect.
        if self.wake_word and connected:
            await self.send_wake_word(self.wake_word)

    async def send_audio(self, pcm: bytes) -> None:
        """Opus-encode one PCM frame and send it; handle a closed connection.

        If the socket is closed while activation is pending, reconnect.
        Otherwise notify the callback and drop the socket.
        """
        if not self.websocket:
            return

        state = self.websocket.state
        if state == websockets.protocol.State.OPEN:
            opus_data = await self.audio_opus.pcm_to_opus(pcm)
            await self.websocket.send(opus_data)
        elif state in [websockets.protocol.State.CLOSED, websockets.protocol.State.CLOSING]:
            if self.wait_device_activated:
                logger.debug("[websocket] Server actively disconnected, reconnecting...")
                await self.connect_websocket(self.websocket_token)
            elif self.message_handler_callback:
                await self.message_handler_callback({"type": "websocket", "state": "close", "source": "sdk.send_audio"})
                self.websocket = None
                logger.debug("[websocket] Server actively disconnected")

            await asyncio.sleep(0.5)
        else:
            # Still connecting: back off briefly and drop this frame.
            await asyncio.sleep(0.1)

    async def close(self) -> None:
        """Stop the reader task, cancel activation polling and close the websocket."""
        if self.message_handler_task and not self.message_handler_task.done():
            self.message_handler_task.cancel()
            try:
                await self.message_handler_task
            except asyncio.CancelledError:
                pass

        if self.iot_task:
            self.iot_task.cancel()

        if self.websocket:
            await self.websocket.close()
xiaozhi_sdk/iot.py CHANGED
@@ -1,50 +1,84 @@
1
- import aiohttp
1
+ import hashlib
2
+ import hmac
2
3
  import json
4
+ import logging
5
+ from typing import Any, Dict, Optional
6
+
7
+ import aiohttp
3
8
 
9
+ from xiaozhi_sdk import __version__
4
10
  from xiaozhi_sdk.config import OTA_URL
5
11
 
6
- USER_AGENT = "XiaoXhi-SDK/1.0"
12
# Constants identifying this SDK to the OTA server: BOARD_TYPE/BOARD_NAME go in
# the activate_device payload, USER_AGENT in the request headers.
BOARD_TYPE = "xiaozhi-sdk-box"
USER_AGENT = "xiaozhi-sdk/{}".format(__version__)
BOARD_NAME = "xiaozhi-sdk"

# Shared SDK logger.
logger = logging.getLogger("xiaozhi_sdk")
18
+
19
+
20
class OtaDevice:
    """Client for the OTA server: device registration and activation checks.

    Attributes:
        ota_url (str): OTA server URL, stored without a trailing slash.
        mac_addr (str): device MAC address, sent as the ``Device-Id`` header.
        client_id (str): client id, sent as the ``Client-Id`` header.
        serial_number (str): device serial number.
    """

    def __init__(self, mac_addr: str, client_id: str, ota_url: Optional[str] = None, serial_number: str = "") -> None:
        # Fall back to the SDK default, and drop the trailing slash so paths can be appended.
        self.ota_url = (ota_url or OTA_URL).rstrip("/")

        self.mac_addr = mac_addr
        self.client_id = client_id
        self.serial_number = serial_number

    def _get_base_headers(self) -> Dict[str, str]:
        """Return the headers sent with every OTA request."""
        return {
            "user-agent": USER_AGENT,
            "Device-Id": self.mac_addr,
            "Client-Id": self.client_id,
            "Content-Type": "application/json",
        }

    async def activate_device(self) -> Dict[str, Any]:
        """POST the device description to the OTA server and return the decoded JSON reply.

        NOTE(review): callers read the ``websocket`` (url/token) and
        ``activation`` (challenge) keys from the returned dict.

        Raises:
            Exception: with the response body as its message when the status is not 200.
        """
        headers = self._get_base_headers()
        headers["serial-number"] = self.serial_number

        payload = {
            "application": {"version": __version__},
            "board": {
                "type": BOARD_TYPE,
                "name": BOARD_NAME,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.ota_url + "/", headers=headers, data=json.dumps(payload)) as response:
                if response.status != 200:
                    # Surface the server's explanation; a bare status code says little here.
                    raise Exception(await response.text())
                return await response.json()

    async def check_activate(self, challenge: str, license_key: str = "") -> bool:
        """Answer an activation challenge; return True when the server replies with HTTP 200.

        The request carries the hex HMAC-SHA256 of ``challenge`` keyed with ``license_key``.
        """
        url = f"{self.ota_url}/activate"

        headers = self._get_base_headers()

        hmac_result = hmac.new(license_key.encode(), challenge.encode(), hashlib.sha256).hexdigest()

        payload = {"serial_number": self.serial_number, "challenge": challenge, "hmac": hmac_result}

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=json.dumps(payload)) as response:
                is_ok = response.status == 200
                if not is_ok:
                    logger.debug("[IOT] wait for activate device...")
                return is_ok
xiaozhi_sdk/mcp.py CHANGED
@@ -1,75 +1,171 @@
1
+ import asyncio
2
+ import copy
1
3
  import json
4
+ import logging
5
+ import time
6
+ from typing import Any, Dict
2
7
 
8
+ import numpy as np
3
9
  import requests
4
10
 
5
- from xiaozhi_sdk.config import VL_URL
6
- from xiaozhi_sdk.data import mcp_initialize_payload, mcp_tools_payload, mcp_tool_conf
11
+ from xiaozhi_sdk.utils.mcp_tool import _get_random_music_info
12
+
13
logger = logging.getLogger("xiaozhi_sdk")

# JSON-RPC result template answering the MCP ``initialize`` request.
# McpTool.mcp sets the request ``id`` when replying.
mcp_initialize_payload: Dict[str, Any] = {
    "jsonrpc": "2.0",
    "id": 1,
    "result": {
        "protocolVersion": "2024-11-05",
        "capabilities": {"tools": {}},
        "serverInfo": {"name": "", "version": "0.0.1"},
    },
}

# JSON-RPC result template answering the MCP ``tools/list`` request.
# McpTool.mcp sets the request ``id`` and supplies the tool descriptions when replying.
mcp_tools_payload: Dict[str, Any] = {
    "jsonrpc": "2.0",
    "id": 2,
    "result": {"tools": []},
}
7
30
 
8
31
 
9
32
class McpTool(object):
    """Answers MCP JSON-RPC requests (``initialize``, ``tools/list``, ``tools/call``).

    The requests arrive as ``{"type": "mcp", "payload": ...}`` websocket messages.
    Tools live in ``mcp_tool_dict`` as ``name -> description`` dicts. Each dict also
    carries a ``tool_func`` callable and optionally ``is_async``; those two keys are
    stripped before the description is sent to the server. Subclasses set
    ``websocket``; ``play_custom_music`` also needs ``audio_opus`` and
    ``output_audio_queue``.
    """

    def __init__(self):
        self.session_id = ""
        # Vision ("explain") endpoint and bearer token, taken from the initialize request.
        self.explain_url = ""
        self.explain_token = ""
        self.websocket = None
        self.mcp_tool_dict = {}
        # True while the server is speaking (updated by the subclass from tts messages).
        self.is_playing = False
        self.message_handler_callback = None

    def get_mcp_json(self, payload: dict):
        """Wrap a JSON-RPC ``payload`` in the websocket ``mcp`` envelope and serialize it."""
        return json.dumps({"session_id": self.session_id, "type": "mcp", "payload": payload})

    def _build_response(self, request_id: str, content: str, is_error: bool = False):
        """Build a serialized ``tools/call`` result with a single text content item."""
        return self.get_mcp_json(
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "content": [{"type": "text", "text": content}],
                    "isError": is_error,
                },
            }
        )

    async def analyze_image(self, img_byte: bytes, question: str = "这张图片里有什么?"):
        """Send a JPEG to the vision endpoint and return ``(result, is_error)``.

        ``result`` is the decoded JSON reply. If the request or decoding fails,
        it is the string "网络异常" and ``is_error`` is True.
        """
        headers = {"Authorization": f"Bearer {self.explain_token}"}
        files = {"file": ("camera.jpg", img_byte, "image/jpeg")}
        payload = {"question": question}
        init_time = time.time()
        loop = asyncio.get_running_loop()
        try:
            # requests is blocking: run it in the default executor so the event loop
            # (and the websocket reader) keeps running during the up-to-8 s request.
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(self.explain_url, files=files, data=payload, headers=headers, timeout=8),
            )
            res_json = response.json()
        except Exception as e:
            logger.error("[MCP] 图片解析 error: %s", e)
            return "网络异常", True
        if isinstance(res_json, dict) and res_json.get("error"):
            return res_json, True
        logger.debug("[MCP] 图片解析耗时:%s", time.time() - init_time)
        return res_json, False

    async def play_custom_music(self, tool_func, arguments):
        """Resample PCM from ``tool_func`` and queue it for playback once TTS stops.

        Currently unused by ``mcp_tool_call``, which plays music by URL instead.
        ``tool_func`` is awaited and must return ``(pcm_array, is_error)``;
        ``is_error`` is ignored here.
        """
        pcm_array, is_error = await tool_func(arguments)
        # Don't interleave music with the sentence the server is currently speaking.
        while self.is_playing:
            await asyncio.sleep(0.1)
        pcm_array = await self.audio_opus.change_sample_rate(np.array(pcm_array))
        self.output_audio_queue.extend(pcm_array)

    async def mcp_tool_call(self, mcp_json: dict):
        """Run the tool named in a ``tools/call`` request and return the serialized response.

        Tool functions receive the ``arguments`` dict and return ``(result, is_error)``.
        ``result`` is JSON-encoded into the response text. For ``take_photo``, the
        result is passed as JPEG bytes to ``analyze_image``, and that analysis
        becomes the reply.
        """
        tool_name = mcp_json["params"]["name"]
        mcp_tool = self.mcp_tool_dict[tool_name]
        # MCP allows "arguments" to be omitted for tools without parameters.
        arguments = mcp_json["params"].get("arguments") or {}
        try:
            if tool_name == "play_custom_music":
                # v1: hand the music URL to the user callback.
                music_info = await _get_random_music_info(arguments["id_list"])
                if not music_info.get("url"):
                    tool_res, is_error = {"message": "播放失败"}, True
                else:
                    tool_res, is_error = {"message": "正在为你播放: {}".format(arguments["music_name"])}, False
                    data = {
                        "type": "music",
                        "state": "start",
                        "url": music_info["url"],
                        "text": arguments["music_name"],
                        "source": "sdk.mcp_music_tool",
                    }
                    await self.message_handler_callback(data)
                # v2 (not enabled): decode the audio locally via play_custom_music.
            elif mcp_tool.get("is_async"):
                tool_res, is_error = await mcp_tool["tool_func"](arguments)
            else:
                tool_res, is_error = mcp_tool["tool_func"](arguments)
        except Exception as e:
            logger.error("[MCP] tool_name: %s, error: %s", tool_name, e)
            return self._build_response(mcp_json["id"], "工具调用失败", True)

        if is_error:
            logger.error("[MCP] tool_name: %s, error: %s", tool_name, tool_res)
            return self._build_response(mcp_json["id"], "工具调用失败, {}".format(tool_res), True)

        if tool_name == "take_photo":
            question = arguments.get("question")
            if question:
                tool_res, is_error = await self.analyze_image(tool_res, question)
            else:
                tool_res, is_error = await self.analyze_image(tool_res)

        content = json.dumps(tool_res, ensure_ascii=False)
        return self._build_response(mcp_json["id"], content, is_error)

    async def mcp(self, data: dict):
        """Handle one ``{"type": "mcp", "payload": ...}`` message from the server."""
        payload = data["payload"]
        method = payload["method"]

        if method == "initialize":
            vision = payload.get("params", {}).get("capabilities", {}).get("vision") or {}
            self.explain_url = vision.get("url", "")
            self.explain_token = vision.get("token", "")

            # Fill a copy: the module-level template is shared by all instances.
            response = copy.deepcopy(mcp_initialize_payload)
            response["id"] = payload["id"]
            await self.websocket.send(self.get_mcp_json(response))

        elif method == "notifications/initialized":
            # Notification only; nothing to answer.
            pass

        elif method == "notifications/cancelled":
            logger.error("[MCP] 工具加载失败")

        elif method == "tools/list":
            # Build a fresh payload; appending into the shared template would
            # accumulate duplicate tools on every request.
            response = copy.deepcopy(mcp_tools_payload)
            response["id"] = payload["id"]
            tool_name_list = []
            for mcp_tool in self.mcp_tool_dict.values():
                if not mcp_tool.get("tool_func"):
                    # Skip the broken tool but still answer the request.
                    logger.error("[MCP] Tool %s has no tool_func", mcp_tool["name"])
                    continue
                tool_name_list.append(mcp_tool["name"])
                response["result"]["tools"].append(
                    {key: value for key, value in mcp_tool.items() if key not in ("tool_func", "is_async")}
                )

            await self.websocket.send(self.get_mcp_json(response))
            logger.debug("[MCP] 加载成功,当前可用工具列表为:%s", tool_name_list)

        elif method == "tools/call":
            tool_name = payload["params"]["name"]

            if not self.mcp_tool_dict.get(tool_name):
                logger.warning("[MCP] Tool not found: %s", tool_name)
                # Every JSON-RPC request needs a reply, or the caller keeps waiting.
                await self.websocket.send(
                    self._build_response(payload["id"], "工具调用失败, tool not found: {}".format(tool_name), True)
                )
                return

            mcp_res = await self.mcp_tool_call(payload)
            await self.websocket.send(mcp_res)
            logger.debug("[MCP] Tool %s called", tool_name)
        else:
            logger.warning("[MCP] unknown method %s: %s", method, payload)
xiaozhi_sdk/opus.py CHANGED
@@ -1,13 +1,11 @@
1
- import os
2
- from xiaozhi_sdk import INPUT_SERVER_AUDIO_SAMPLE_RATE
1
+ import math
3
2
 
4
- # 设置 opus 库路径
5
- os.environ["DYLD_LIBRARY_PATH"] = "/opt/homebrew/lib:" + os.environ.get("DYLD_LIBRARY_PATH", "")
6
- os.environ["LIBRARY_PATH"] = "/opt/homebrew/lib:" + os.environ.get("LIBRARY_PATH", "")
7
3
  import av
8
4
  import numpy as np
9
5
  import opuslib
10
6
 
7
+ from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
8
+
11
9
 
12
10
  class AudioOpus:
13
11
 
@@ -33,11 +31,16 @@ class AudioOpus:
33
31
  pcm_bytes = pcm_array.tobytes()
34
32
  return self.opus_encoder.encode(pcm_bytes, 960)
35
33
 
36
- async def change_sample_rate(self, pcm_array):
34
+ @staticmethod
35
+ def to_n_960(samples) -> np.ndarray:
36
+ n = math.ceil(samples.shape[0] / 960)
37
+ arr_padded = np.pad(samples, (0, 960 * n - samples.shape[0]), mode="constant", constant_values=0)
38
+ return arr_padded.reshape(n, 960)
39
+
40
+ async def change_sample_rate(self, pcm_array) -> np.ndarray:
37
41
  if self.sample_rate == INPUT_SERVER_AUDIO_SAMPLE_RATE:
38
- return pcm_array.reshape(1, 960)
42
+ return self.to_n_960(pcm_array)
39
43
 
40
- c = int(self.sample_rate / INPUT_SERVER_AUDIO_SAMPLE_RATE)
41
44
  frame = av.AudioFrame.from_ndarray(np.array(pcm_array).reshape(1, -1), format="s16", layout="mono")
42
45
  frame.sample_rate = INPUT_SERVER_AUDIO_SAMPLE_RATE # Assuming input is 16kHz
43
46
  resampled_frames = self.resampler.resample(frame)
@@ -49,10 +52,9 @@ class AudioOpus:
49
52
  )
50
53
  new_frame.sample_rate = self.sample_rate
51
54
  new_samples = new_frame.to_ndarray().flatten()
52
- arr_padded = np.pad(new_samples, (0, 960 * c - new_samples.shape[0]), mode="constant", constant_values=0)
53
- return arr_padded.reshape(c, 960)
55
+ return self.to_n_960(new_samples)
54
56
 
55
- async def opus_to_pcm(self, opus):
57
+ async def opus_to_pcm(self, opus) -> np.ndarray:
56
58
  pcm_data = self.opus_decoder.decode(opus, 960)
57
59
  pcm_array = np.frombuffer(pcm_data, dtype=np.int16)
58
60
  samples = await self.change_sample_rate(pcm_array)