xiaozhi-sdk 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- file/audio/play_music.wav +0 -0
- xiaozhi_sdk/__init__.py +2 -225
- xiaozhi_sdk/__main__.py +3 -115
- xiaozhi_sdk/cli.py +139 -0
- xiaozhi_sdk/core.py +270 -0
- xiaozhi_sdk/iot.py +7 -1
- xiaozhi_sdk/mcp.py +44 -7
- xiaozhi_sdk/opus.py +13 -7
- xiaozhi_sdk/{utils.py → utils/__init__.py} +3 -3
- xiaozhi_sdk/{data.py → utils/mcp_data.py} +16 -0
- xiaozhi_sdk/utils/mcp_tool.py +88 -0
- {xiaozhi_sdk-0.0.5.dist-info → xiaozhi_sdk-0.0.7.dist-info}/METADATA +17 -22
- xiaozhi_sdk-0.0.7.dist-info/RECORD +26 -0
- xiaozhi_sdk-0.0.5.dist-info/RECORD +0 -22
- /file/opus/{windows-x86_64-opus.dll → windows-opus.dll} +0 -0
- {xiaozhi_sdk-0.0.5.dist-info → xiaozhi_sdk-0.0.7.dist-info}/WHEEL +0 -0
- {xiaozhi_sdk-0.0.5.dist-info → xiaozhi_sdk-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {xiaozhi_sdk-0.0.5.dist-info → xiaozhi_sdk-0.0.7.dist-info}/top_level.txt +0 -0
|
Binary file
|
xiaozhi_sdk/__init__.py
CHANGED
|
@@ -1,226 +1,3 @@
|
|
|
1
|
-
__version__ = "0.0.5"
|
|
1
|
+
__version__ = "0.0.7"
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
|
-
import json
|
|
5
|
-
import logging
|
|
6
|
-
import os
|
|
7
|
-
import re
|
|
8
|
-
import uuid
|
|
9
|
-
from collections import deque
|
|
10
|
-
from typing import Any, Callable, Dict, Optional
|
|
11
|
-
|
|
12
|
-
import websockets
|
|
13
|
-
|
|
14
|
-
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
15
|
-
from xiaozhi_sdk.iot import OtaDevice
|
|
16
|
-
from xiaozhi_sdk.mcp import McpTool
|
|
17
|
-
from xiaozhi_sdk.utils import get_wav_info, read_audio_file, setup_opus
|
|
18
|
-
|
|
19
|
-
setup_opus()
|
|
20
|
-
from xiaozhi_sdk.opus import AudioOpus
|
|
21
|
-
|
|
22
|
-
logger = logging.getLogger("xiaozhi_sdk")
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class XiaoZhiWebsocket(McpTool):
|
|
26
|
-
|
|
27
|
-
def __init__(
|
|
28
|
-
self,
|
|
29
|
-
message_handler_callback: Optional[Callable] = None,
|
|
30
|
-
url: Optional[str] = None,
|
|
31
|
-
ota_url: Optional[str] = None,
|
|
32
|
-
audio_sample_rate: int = 16000,
|
|
33
|
-
audio_channels: int = 1,
|
|
34
|
-
):
|
|
35
|
-
super().__init__()
|
|
36
|
-
self.url = url
|
|
37
|
-
self.ota_url = ota_url
|
|
38
|
-
self.audio_channels = audio_channels
|
|
39
|
-
self.audio_opus = AudioOpus(audio_sample_rate, audio_channels)
|
|
40
|
-
|
|
41
|
-
# 客户端标识
|
|
42
|
-
self.client_id = str(uuid.uuid4())
|
|
43
|
-
self.mac_addr: Optional[str] = None
|
|
44
|
-
|
|
45
|
-
# 回调函数
|
|
46
|
-
self.message_handler_callback = message_handler_callback
|
|
47
|
-
|
|
48
|
-
# 连接状态
|
|
49
|
-
self.hello_received = asyncio.Event()
|
|
50
|
-
self.session_id = ""
|
|
51
|
-
self.websocket = None
|
|
52
|
-
self.message_handler_task: Optional[asyncio.Task] = None
|
|
53
|
-
|
|
54
|
-
# 输出音频
|
|
55
|
-
self.output_audio_queue: deque[bytes] = deque()
|
|
56
|
-
|
|
57
|
-
# OTA设备
|
|
58
|
-
self.ota: Optional[OtaDevice] = None
|
|
59
|
-
|
|
60
|
-
async def _send_hello(self, aec: bool) -> None:
|
|
61
|
-
"""发送hello消息"""
|
|
62
|
-
hello_message = {
|
|
63
|
-
"type": "hello",
|
|
64
|
-
"version": 1,
|
|
65
|
-
"features": {"aec": aec, "mcp": True},
|
|
66
|
-
"transport": "websocket",
|
|
67
|
-
"audio_params": {
|
|
68
|
-
"format": "opus",
|
|
69
|
-
"sample_rate": INPUT_SERVER_AUDIO_SAMPLE_RATE,
|
|
70
|
-
"channels": 1,
|
|
71
|
-
"frame_duration": 60,
|
|
72
|
-
},
|
|
73
|
-
}
|
|
74
|
-
await self.websocket.send(json.dumps(hello_message))
|
|
75
|
-
await asyncio.wait_for(self.hello_received.wait(), timeout=10.0)
|
|
76
|
-
|
|
77
|
-
async def _start_listen(self) -> None:
|
|
78
|
-
"""开始监听"""
|
|
79
|
-
|
|
80
|
-
listen_message = {"session_id": self.session_id, "type": "listen", "state": "start", "mode": "realtime"}
|
|
81
|
-
await self.websocket.send(json.dumps(listen_message))
|
|
82
|
-
|
|
83
|
-
async def _activate_iot_device(self, license_key: str, ota_info: Dict[str, Any]) -> None:
|
|
84
|
-
"""激活IoT设备"""
|
|
85
|
-
if not ota_info.get("activation"):
|
|
86
|
-
return
|
|
87
|
-
|
|
88
|
-
if not self.ota:
|
|
89
|
-
return
|
|
90
|
-
|
|
91
|
-
await self._send_demo_audio()
|
|
92
|
-
challenge = ota_info["activation"]["challenge"]
|
|
93
|
-
await asyncio.sleep(3)
|
|
94
|
-
|
|
95
|
-
for _ in range(10):
|
|
96
|
-
if await self.ota.check_activate(challenge, license_key):
|
|
97
|
-
break
|
|
98
|
-
await asyncio.sleep(3)
|
|
99
|
-
|
|
100
|
-
async def _send_demo_audio(self) -> None:
|
|
101
|
-
"""发送演示音频"""
|
|
102
|
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
103
|
-
wav_path = os.path.join(current_dir, "../file/audio/greet.wav")
|
|
104
|
-
framerate, channels = get_wav_info(wav_path)
|
|
105
|
-
audio_opus = AudioOpus(framerate, channels)
|
|
106
|
-
|
|
107
|
-
for pcm_data in read_audio_file(wav_path):
|
|
108
|
-
opus_data = await audio_opus.pcm_to_opus(pcm_data)
|
|
109
|
-
await self.websocket.send(opus_data)
|
|
110
|
-
await self.send_silence_audio()
|
|
111
|
-
|
|
112
|
-
async def send_silence_audio(self, duration_seconds: float = 1.2) -> None:
|
|
113
|
-
"""发送静音音频"""
|
|
114
|
-
frames_count = int(duration_seconds * 1000 / 60)
|
|
115
|
-
pcm_frame = b"\x00\x00" * int(INPUT_SERVER_AUDIO_SAMPLE_RATE / 1000 * 60)
|
|
116
|
-
|
|
117
|
-
for _ in range(frames_count):
|
|
118
|
-
await self.send_audio(pcm_frame)
|
|
119
|
-
|
|
120
|
-
async def _handle_websocket_message(self, message: Any) -> None:
|
|
121
|
-
"""处理接受到的WebSocket消息"""
|
|
122
|
-
|
|
123
|
-
# audio data
|
|
124
|
-
if isinstance(message, bytes):
|
|
125
|
-
pcm_array = await self.audio_opus.opus_to_pcm(message)
|
|
126
|
-
self.output_audio_queue.extend(pcm_array)
|
|
127
|
-
return
|
|
128
|
-
|
|
129
|
-
# json message
|
|
130
|
-
data = json.loads(message)
|
|
131
|
-
message_type = data["type"]
|
|
132
|
-
if message_type == "hello":
|
|
133
|
-
self.hello_received.set()
|
|
134
|
-
self.session_id = data["session_id"]
|
|
135
|
-
elif message_type == "mcp":
|
|
136
|
-
await self.mcp(data)
|
|
137
|
-
elif self.message_handler_callback:
|
|
138
|
-
await self.message_handler_callback(data)
|
|
139
|
-
|
|
140
|
-
async def _message_handler(self) -> None:
|
|
141
|
-
"""消息处理器"""
|
|
142
|
-
try:
|
|
143
|
-
async for message in self.websocket:
|
|
144
|
-
await self._handle_websocket_message(message)
|
|
145
|
-
except websockets.ConnectionClosed:
|
|
146
|
-
if self.message_handler_callback:
|
|
147
|
-
await self.message_handler_callback(
|
|
148
|
-
{"type": "websocket", "state": "close", "source": "sdk.message_handler"}
|
|
149
|
-
)
|
|
150
|
-
logger.info("[websocket] close")
|
|
151
|
-
|
|
152
|
-
async def set_mcp_tool_callback(self, tool_func: Dict[str, Callable[..., Any]]) -> None:
|
|
153
|
-
"""设置MCP工具回调函数"""
|
|
154
|
-
self.tool_func = tool_func
|
|
155
|
-
|
|
156
|
-
async def init_connection(
|
|
157
|
-
self, mac_addr: str, aec: bool = False, serial_number: str = "", license_key: str = ""
|
|
158
|
-
) -> None:
|
|
159
|
-
"""初始化连接"""
|
|
160
|
-
# 校验MAC地址格式 XX:XX:XX:XX:XX:XX
|
|
161
|
-
mac_pattern = r"^([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}$"
|
|
162
|
-
if not re.match(mac_pattern, mac_addr):
|
|
163
|
-
raise ValueError(f"无效的MAC地址格式: {mac_addr}。正确格式应为 XX:XX:XX:XX:XX:XX")
|
|
164
|
-
|
|
165
|
-
self.mac_addr = mac_addr.lower()
|
|
166
|
-
|
|
167
|
-
self.ota = OtaDevice(self.mac_addr, self.client_id, self.ota_url, serial_number)
|
|
168
|
-
ota_info = await self.ota.activate_device()
|
|
169
|
-
ws_url = ota_info.get("websocket", {}).get("url")
|
|
170
|
-
self.url = self.url or ws_url
|
|
171
|
-
|
|
172
|
-
if not self.url:
|
|
173
|
-
logger.warning("[websocket] 未找到websocket链接地址")
|
|
174
|
-
return
|
|
175
|
-
|
|
176
|
-
if "tenclass.net" not in self.url and "xiaozhi.me" not in self.url:
|
|
177
|
-
logger.warning("[websocket] 检测到非官方服务器,当前链接地址: %s", self.url)
|
|
178
|
-
|
|
179
|
-
headers = {
|
|
180
|
-
"Authorization": "Bearer {}".format(ota_info["websocket"]["token"]),
|
|
181
|
-
"Protocol-Version": "1",
|
|
182
|
-
"Device-Id": self.mac_addr,
|
|
183
|
-
"Client-Id": self.client_id,
|
|
184
|
-
}
|
|
185
|
-
try:
|
|
186
|
-
self.websocket = await websockets.connect(uri=self.url, additional_headers=headers)
|
|
187
|
-
except websockets.exceptions.InvalidMessage as e:
|
|
188
|
-
logger.error("[websocket] 连接失败,请检查网络连接或设备状态。当前链接地址: %s, 错误信息:%s", self.url, e)
|
|
189
|
-
return
|
|
190
|
-
self.message_handler_task = asyncio.create_task(self._message_handler())
|
|
191
|
-
|
|
192
|
-
await self._send_hello(aec)
|
|
193
|
-
await self._start_listen()
|
|
194
|
-
asyncio.create_task(self._activate_iot_device(license_key, ota_info))
|
|
195
|
-
await asyncio.sleep(0.5)
|
|
196
|
-
|
|
197
|
-
async def send_audio(self, pcm: bytes) -> None:
|
|
198
|
-
"""发送音频数据"""
|
|
199
|
-
if not self.websocket:
|
|
200
|
-
return
|
|
201
|
-
|
|
202
|
-
state = self.websocket.state
|
|
203
|
-
if state == websockets.protocol.State.OPEN:
|
|
204
|
-
opus_data = await self.audio_opus.pcm_to_opus(pcm)
|
|
205
|
-
await self.websocket.send(opus_data)
|
|
206
|
-
elif state in [websockets.protocol.State.CLOSED, websockets.protocol.State.CLOSING]:
|
|
207
|
-
if self.message_handler_callback:
|
|
208
|
-
await self.message_handler_callback({"type": "websocket", "state": "close", "source": "sdk.send_audio"})
|
|
209
|
-
self.websocket = None
|
|
210
|
-
logger.info("[websocket] close")
|
|
211
|
-
|
|
212
|
-
await asyncio.sleep(0.5)
|
|
213
|
-
else:
|
|
214
|
-
await asyncio.sleep(0.1)
|
|
215
|
-
|
|
216
|
-
async def close(self) -> None:
|
|
217
|
-
"""关闭连接"""
|
|
218
|
-
if self.message_handler_task and not self.message_handler_task.done():
|
|
219
|
-
self.message_handler_task.cancel()
|
|
220
|
-
try:
|
|
221
|
-
await self.message_handler_task
|
|
222
|
-
except asyncio.CancelledError:
|
|
223
|
-
pass
|
|
224
|
-
|
|
225
|
-
if self.websocket:
|
|
226
|
-
await self.websocket.close()
|
|
3
|
+
from xiaozhi_sdk.core import XiaoZhiWebsocket # noqa
|
xiaozhi_sdk/__main__.py
CHANGED
|
@@ -1,123 +1,11 @@
|
|
|
1
|
-
import argparse
|
|
2
|
-
import asyncio
|
|
3
1
|
import logging
|
|
4
|
-
import time
|
|
5
|
-
from collections import deque
|
|
6
|
-
from typing import Optional
|
|
7
2
|
|
|
8
|
-
|
|
9
|
-
import sounddevice as sd
|
|
3
|
+
from xiaozhi_sdk.cli import main
|
|
10
4
|
|
|
11
|
-
from xiaozhi_sdk import XiaoZhiWebsocket
|
|
12
|
-
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
13
|
-
|
|
14
|
-
# 配置logging
|
|
15
|
-
logging.basicConfig(
|
|
16
|
-
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
|
17
|
-
)
|
|
18
5
|
logger = logging.getLogger("xiaozhi_sdk")
|
|
19
6
|
|
|
20
|
-
# 全局状态
|
|
21
|
-
input_audio_buffer: deque[bytes] = deque()
|
|
22
|
-
is_playing_audio = False
|
|
23
|
-
is_end = False
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
async def handle_message(message):
|
|
27
|
-
"""处理接收到的消息"""
|
|
28
|
-
global is_end
|
|
29
|
-
logger.info("message received: %s", message)
|
|
30
|
-
if message["type"] == "websocket" and message["state"] == "close":
|
|
31
|
-
is_end = True
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
async def play_assistant_audio(audio_queue: deque[bytes]):
|
|
35
|
-
"""播放音频流"""
|
|
36
|
-
global is_playing_audio
|
|
37
|
-
|
|
38
|
-
stream = sd.OutputStream(samplerate=INPUT_SERVER_AUDIO_SAMPLE_RATE, channels=1, dtype=np.int16)
|
|
39
|
-
stream.start()
|
|
40
|
-
last_audio_time = None
|
|
41
|
-
|
|
42
|
-
while True:
|
|
43
|
-
if is_end:
|
|
44
|
-
return
|
|
45
|
-
|
|
46
|
-
if not audio_queue:
|
|
47
|
-
await asyncio.sleep(0.01)
|
|
48
|
-
if last_audio_time and time.time() - last_audio_time > 1:
|
|
49
|
-
is_playing_audio = False
|
|
50
|
-
continue
|
|
51
|
-
|
|
52
|
-
is_playing_audio = True
|
|
53
|
-
pcm_data = audio_queue.popleft()
|
|
54
|
-
stream.write(pcm_data)
|
|
55
|
-
last_audio_time = time.time()
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class XiaoZhiClient:
|
|
59
|
-
"""小智客户端类"""
|
|
60
|
-
|
|
61
|
-
def __init__(
|
|
62
|
-
self,
|
|
63
|
-
url: Optional[str] = None,
|
|
64
|
-
ota_url: Optional[str] = None,
|
|
65
|
-
):
|
|
66
|
-
self.xiaozhi: Optional[XiaoZhiWebsocket] = None
|
|
67
|
-
self.url = url
|
|
68
|
-
self.ota_url = ota_url
|
|
69
|
-
|
|
70
|
-
async def start(self, mac_address: str, serial_number: str = "", license_key: str = ""):
|
|
71
|
-
"""启动客户端连接"""
|
|
72
|
-
self.mac_address = mac_address
|
|
73
|
-
self.xiaozhi = XiaoZhiWebsocket(handle_message, url=self.url, ota_url=self.ota_url)
|
|
74
|
-
await self.xiaozhi.init_connection(
|
|
75
|
-
self.mac_address, aec=False, serial_number=serial_number, license_key=license_key
|
|
76
|
-
)
|
|
77
|
-
asyncio.create_task(play_assistant_audio(self.xiaozhi.output_audio_queue))
|
|
78
|
-
|
|
79
|
-
def audio_callback(self, indata, frames, time, status):
|
|
80
|
-
"""音频输入回调函数"""
|
|
81
|
-
pcm_data = (indata.flatten() * 32767).astype(np.int16).tobytes()
|
|
82
|
-
input_audio_buffer.append(pcm_data)
|
|
83
|
-
|
|
84
|
-
async def process_audio_input(self):
|
|
85
|
-
"""处理音频输入"""
|
|
86
|
-
while True:
|
|
87
|
-
|
|
88
|
-
if is_end:
|
|
89
|
-
return
|
|
90
|
-
|
|
91
|
-
if not input_audio_buffer:
|
|
92
|
-
await asyncio.sleep(0.02)
|
|
93
|
-
continue
|
|
94
|
-
|
|
95
|
-
pcm_data = input_audio_buffer.popleft()
|
|
96
|
-
if not is_playing_audio:
|
|
97
|
-
await self.xiaozhi.send_audio(pcm_data)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
async def main():
|
|
101
|
-
"""主函数"""
|
|
102
|
-
parser = argparse.ArgumentParser(description="小智SDK客户端")
|
|
103
|
-
parser.add_argument("device", help="设备的MAC地址 (格式: XX:XX:XX:XX:XX:XX)")
|
|
104
|
-
parser.add_argument("--url", help="服务端websocket地址")
|
|
105
|
-
parser.add_argument("--ota_url", help="OTA地址")
|
|
106
|
-
|
|
107
|
-
parser.add_argument("--serial_number", default="", help="设备的序列号")
|
|
108
|
-
parser.add_argument("--license_key", default="", help="设备的授权密钥")
|
|
109
|
-
|
|
110
|
-
args = parser.parse_args()
|
|
111
|
-
logger.info("Recording... Press Ctrl+C to stop.")
|
|
112
|
-
client = XiaoZhiClient(args.url, args.ota_url)
|
|
113
|
-
await client.start(args.device, args.serial_number, args.license_key)
|
|
114
|
-
|
|
115
|
-
with sd.InputStream(callback=client.audio_callback, channels=1, samplerate=16000, blocksize=960):
|
|
116
|
-
await client.process_audio_input()
|
|
117
|
-
|
|
118
|
-
|
|
119
7
|
if __name__ == "__main__":
|
|
120
8
|
try:
|
|
121
|
-
|
|
9
|
+
main()
|
|
122
10
|
except KeyboardInterrupt:
|
|
123
|
-
logger.
|
|
11
|
+
logger.debug("Stopping...")
|
xiaozhi_sdk/cli.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from collections import deque
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
import colorlog
|
|
9
|
+
import numpy as np
|
|
10
|
+
import sounddevice as sd
|
|
11
|
+
|
|
12
|
+
from xiaozhi_sdk import XiaoZhiWebsocket
|
|
13
|
+
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
14
|
+
|
|
15
|
+
# 配置彩色logging
|
|
16
|
+
handler = colorlog.StreamHandler()
|
|
17
|
+
handler.setFormatter(
|
|
18
|
+
colorlog.ColoredFormatter(
|
|
19
|
+
"%(log_color)s%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
20
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
21
|
+
log_colors={
|
|
22
|
+
"DEBUG": "green",
|
|
23
|
+
"INFO": "white",
|
|
24
|
+
"WARNING": "yellow",
|
|
25
|
+
"ERROR": "red",
|
|
26
|
+
"CRITICAL": "red,bg_white",
|
|
27
|
+
},
|
|
28
|
+
)
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger("xiaozhi_sdk")
|
|
32
|
+
logger.addHandler(handler)
|
|
33
|
+
logger.setLevel(logging.DEBUG)
|
|
34
|
+
|
|
35
|
+
# 全局状态
|
|
36
|
+
input_audio_buffer: deque[bytes] = deque()
|
|
37
|
+
is_playing_audio = False
|
|
38
|
+
is_end = False
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
async def handle_message(message):
|
|
42
|
+
"""处理接收到的消息"""
|
|
43
|
+
global is_end
|
|
44
|
+
logger.info("message received: %s", message)
|
|
45
|
+
if message["type"] == "websocket" and message["state"] == "close":
|
|
46
|
+
is_end = True
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
async def play_assistant_audio(audio_queue: deque[bytes]):
|
|
50
|
+
"""播放音频流"""
|
|
51
|
+
global is_playing_audio
|
|
52
|
+
|
|
53
|
+
stream = sd.OutputStream(samplerate=INPUT_SERVER_AUDIO_SAMPLE_RATE, channels=1, dtype=np.int16)
|
|
54
|
+
stream.start()
|
|
55
|
+
last_audio_time = None
|
|
56
|
+
|
|
57
|
+
while True:
|
|
58
|
+
if is_end:
|
|
59
|
+
return
|
|
60
|
+
|
|
61
|
+
if not audio_queue:
|
|
62
|
+
await asyncio.sleep(0.01)
|
|
63
|
+
if last_audio_time and time.time() - last_audio_time > 1:
|
|
64
|
+
is_playing_audio = False
|
|
65
|
+
continue
|
|
66
|
+
|
|
67
|
+
is_playing_audio = True
|
|
68
|
+
pcm_data = audio_queue.popleft()
|
|
69
|
+
stream.write(pcm_data)
|
|
70
|
+
last_audio_time = time.time()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class XiaoZhiClient:
|
|
74
|
+
"""小智客户端类"""
|
|
75
|
+
|
|
76
|
+
def __init__(
|
|
77
|
+
self,
|
|
78
|
+
url: Optional[str] = None,
|
|
79
|
+
ota_url: Optional[str] = None,
|
|
80
|
+
):
|
|
81
|
+
self.xiaozhi: Optional[XiaoZhiWebsocket] = None
|
|
82
|
+
self.url = url
|
|
83
|
+
self.ota_url = ota_url
|
|
84
|
+
self.mac_address = ""
|
|
85
|
+
|
|
86
|
+
async def start(self, mac_address: str, serial_number: str = "", license_key: str = ""):
|
|
87
|
+
"""启动客户端连接"""
|
|
88
|
+
self.mac_address = mac_address
|
|
89
|
+
self.xiaozhi = XiaoZhiWebsocket(handle_message, url=self.url, ota_url=self.ota_url, send_wake=True)
|
|
90
|
+
|
|
91
|
+
await self.xiaozhi.init_connection(
|
|
92
|
+
self.mac_address, aec=False, serial_number=serial_number, license_key=license_key
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
asyncio.create_task(play_assistant_audio(self.xiaozhi.output_audio_queue))
|
|
96
|
+
|
|
97
|
+
def audio_callback(self, indata, frames, time, status):
|
|
98
|
+
"""音频输入回调函数"""
|
|
99
|
+
pcm_data = (indata.flatten() * 32767).astype(np.int16).tobytes()
|
|
100
|
+
input_audio_buffer.append(pcm_data)
|
|
101
|
+
|
|
102
|
+
async def process_audio_input(self):
|
|
103
|
+
"""处理音频输入"""
|
|
104
|
+
while True:
|
|
105
|
+
|
|
106
|
+
if is_end:
|
|
107
|
+
return
|
|
108
|
+
|
|
109
|
+
if not input_audio_buffer:
|
|
110
|
+
await asyncio.sleep(0.02)
|
|
111
|
+
continue
|
|
112
|
+
|
|
113
|
+
pcm_data = input_audio_buffer.popleft()
|
|
114
|
+
if not is_playing_audio:
|
|
115
|
+
await self.xiaozhi.send_audio(pcm_data)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
async def run_client(mac_address: str, url: str, ota_url: str, serial_number: str, license_key: str):
|
|
119
|
+
"""运行客户端的异步函数"""
|
|
120
|
+
logger.debug("Recording... Press Ctrl+C to stop.")
|
|
121
|
+
client = XiaoZhiClient(url, ota_url)
|
|
122
|
+
await client.start(mac_address, serial_number, license_key)
|
|
123
|
+
|
|
124
|
+
with sd.InputStream(callback=client.audio_callback, channels=1, samplerate=16000, blocksize=960):
|
|
125
|
+
await client.process_audio_input()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@click.command()
|
|
129
|
+
@click.argument("mac_address")
|
|
130
|
+
@click.option("--url", help="服务端websocket地址")
|
|
131
|
+
@click.option("--ota_url", help="OTA地址")
|
|
132
|
+
@click.option("--serial_number", default="", help="设备的序列号")
|
|
133
|
+
@click.option("--license_key", default="", help="设备的授权密钥")
|
|
134
|
+
def main(mac_address: str, url: str, ota_url: str, serial_number: str, license_key: str):
|
|
135
|
+
"""小智SDK客户端
|
|
136
|
+
|
|
137
|
+
MAC_ADDRESS: 设备的MAC地址 (格式: XX:XX:XX:XX:XX:XX)
|
|
138
|
+
"""
|
|
139
|
+
asyncio.run(run_client(mac_address, url, ota_url, serial_number, license_key))
|
xiaozhi_sdk/core.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import uuid
|
|
7
|
+
from collections import deque
|
|
8
|
+
from typing import Any, Callable, Deque, Dict, Optional
|
|
9
|
+
|
|
10
|
+
import websockets
|
|
11
|
+
|
|
12
|
+
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
13
|
+
from xiaozhi_sdk.iot import OtaDevice
|
|
14
|
+
from xiaozhi_sdk.mcp import McpTool
|
|
15
|
+
from xiaozhi_sdk.utils import get_wav_info, read_audio_file, setup_opus
|
|
16
|
+
from xiaozhi_sdk.utils.mcp_tool import async_mcp_play_music, async_search_custom_music
|
|
17
|
+
|
|
18
|
+
setup_opus()
|
|
19
|
+
from xiaozhi_sdk.opus import AudioOpus
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger("xiaozhi_sdk")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class XiaoZhiWebsocket(McpTool):
|
|
25
|
+
|
|
26
|
+
def __init__(
|
|
27
|
+
self,
|
|
28
|
+
message_handler_callback: Optional[Callable] = None,
|
|
29
|
+
url: Optional[str] = None,
|
|
30
|
+
ota_url: Optional[str] = None,
|
|
31
|
+
audio_sample_rate: int = 16000,
|
|
32
|
+
audio_channels: int = 1,
|
|
33
|
+
send_wake: bool = False,
|
|
34
|
+
):
|
|
35
|
+
super().__init__()
|
|
36
|
+
self.url = url
|
|
37
|
+
self.ota_url = ota_url
|
|
38
|
+
self.send_wake = send_wake
|
|
39
|
+
self.audio_channels = audio_channels
|
|
40
|
+
self.audio_opus = AudioOpus(audio_sample_rate, audio_channels)
|
|
41
|
+
|
|
42
|
+
# 客户端标识
|
|
43
|
+
self.client_id = str(uuid.uuid4())
|
|
44
|
+
self.mac_addr: Optional[str] = None
|
|
45
|
+
self.aec = False
|
|
46
|
+
self.websocket_token = ""
|
|
47
|
+
|
|
48
|
+
# 回调函数
|
|
49
|
+
self.message_handler_callback = message_handler_callback
|
|
50
|
+
|
|
51
|
+
# 连接状态
|
|
52
|
+
self.hello_received = asyncio.Event()
|
|
53
|
+
self.session_id = ""
|
|
54
|
+
self.websocket = None
|
|
55
|
+
self.message_handler_task: Optional[asyncio.Task] = None
|
|
56
|
+
|
|
57
|
+
# 输出音频
|
|
58
|
+
self.output_audio_queue: Deque[bytes] = deque()
|
|
59
|
+
self.is_playing: bool = False
|
|
60
|
+
|
|
61
|
+
# OTA设备
|
|
62
|
+
self.ota: Optional[OtaDevice] = None
|
|
63
|
+
self.iot_task: Optional[asyncio.Task] = None
|
|
64
|
+
self.wait_device_activated: bool = False
|
|
65
|
+
self.tool_func = {
|
|
66
|
+
"async_play_custom_music": async_mcp_play_music,
|
|
67
|
+
"async_search_custom_music": async_search_custom_music,
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
async def _send_hello(self, aec: bool) -> None:
|
|
71
|
+
"""发送hello消息"""
|
|
72
|
+
hello_message = {
|
|
73
|
+
"type": "hello",
|
|
74
|
+
"version": 1,
|
|
75
|
+
"features": {"mcp": True, "aec": aec},
|
|
76
|
+
"transport": "websocket",
|
|
77
|
+
"audio_params": {
|
|
78
|
+
"format": "opus",
|
|
79
|
+
"sample_rate": 16000,
|
|
80
|
+
"channels": 1,
|
|
81
|
+
"frame_duration": 60,
|
|
82
|
+
},
|
|
83
|
+
}
|
|
84
|
+
await self.websocket.send(json.dumps(hello_message))
|
|
85
|
+
await asyncio.wait_for(self.hello_received.wait(), timeout=10.0)
|
|
86
|
+
|
|
87
|
+
async def _start_listen(self) -> None:
|
|
88
|
+
"""开始监听"""
|
|
89
|
+
listen_message = {"session_id": self.session_id, "type": "listen", "state": "start", "mode": "realtime"}
|
|
90
|
+
await self.websocket.send(json.dumps(listen_message))
|
|
91
|
+
|
|
92
|
+
async def is_activate(self, ota_info):
|
|
93
|
+
"""是否激活"""
|
|
94
|
+
if ota_info.get("activation"):
|
|
95
|
+
return False
|
|
96
|
+
|
|
97
|
+
return True
|
|
98
|
+
|
|
99
|
+
async def _activate_iot_device(self, license_key: str, ota_info: Dict[str, Any]) -> None:
|
|
100
|
+
"""激活IoT设备"""
|
|
101
|
+
if not self.ota:
|
|
102
|
+
return
|
|
103
|
+
|
|
104
|
+
challenge = ota_info["activation"]["challenge"]
|
|
105
|
+
await asyncio.sleep(3)
|
|
106
|
+
self.wait_device_activated = True
|
|
107
|
+
for _ in range(10):
|
|
108
|
+
if await self.ota.check_activate(challenge, license_key):
|
|
109
|
+
self.wait_device_activated = False
|
|
110
|
+
break
|
|
111
|
+
await asyncio.sleep(3)
|
|
112
|
+
|
|
113
|
+
async def _send_demo_audio(self) -> None:
|
|
114
|
+
"""发送演示音频"""
|
|
115
|
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
116
|
+
wav_path = os.path.join(current_dir, "../file/audio/greet.wav")
|
|
117
|
+
framerate, channels = get_wav_info(wav_path)
|
|
118
|
+
audio_opus = AudioOpus(framerate, channels)
|
|
119
|
+
|
|
120
|
+
for pcm_data in read_audio_file(wav_path):
|
|
121
|
+
opus_data = await audio_opus.pcm_to_opus(pcm_data)
|
|
122
|
+
await self.websocket.send(opus_data)
|
|
123
|
+
await self.send_silence_audio()
|
|
124
|
+
|
|
125
|
+
async def send_wake_word(self, wake_word: str = "你好,小智") -> None:
|
|
126
|
+
"""发送唤醒词"""
|
|
127
|
+
await self.websocket.send(
|
|
128
|
+
json.dumps({"session_id": self.session_id, "type": "listen", "state": "detect", "text": wake_word})
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
async def send_silence_audio(self, duration_seconds: float = 1.2) -> None:
|
|
132
|
+
"""发送静音音频"""
|
|
133
|
+
frames_count = int(duration_seconds * 1000 / 60)
|
|
134
|
+
pcm_frame = b"\x00\x00" * int(INPUT_SERVER_AUDIO_SAMPLE_RATE / 1000 * 60)
|
|
135
|
+
|
|
136
|
+
for _ in range(frames_count):
|
|
137
|
+
await self.send_audio(pcm_frame)
|
|
138
|
+
|
|
139
|
+
async def _handle_websocket_message(self, message: Any) -> None:
|
|
140
|
+
"""处理接受到的WebSocket消息"""
|
|
141
|
+
|
|
142
|
+
# audio data
|
|
143
|
+
if isinstance(message, bytes):
|
|
144
|
+
pcm_array = await self.audio_opus.opus_to_pcm(message)
|
|
145
|
+
self.output_audio_queue.extend(pcm_array)
|
|
146
|
+
return
|
|
147
|
+
|
|
148
|
+
# json message
|
|
149
|
+
data = json.loads(message)
|
|
150
|
+
message_type = data["type"]
|
|
151
|
+
if message_type == "hello":
|
|
152
|
+
self.hello_received.set()
|
|
153
|
+
self.session_id = data["session_id"]
|
|
154
|
+
elif message_type == "mcp":
|
|
155
|
+
await self.mcp(data)
|
|
156
|
+
elif self.message_handler_callback:
|
|
157
|
+
if data["type"] == "tts":
|
|
158
|
+
if data["state"] == "sentence_start":
|
|
159
|
+
self.is_playing = True
|
|
160
|
+
# self.output_audio_queue.clear()
|
|
161
|
+
else:
|
|
162
|
+
self.is_playing = False
|
|
163
|
+
|
|
164
|
+
await self.message_handler_callback(data)
|
|
165
|
+
|
|
166
|
+
async def _message_handler(self) -> None:
|
|
167
|
+
"""消息处理器"""
|
|
168
|
+
try:
|
|
169
|
+
async for message in self.websocket:
|
|
170
|
+
await self._handle_websocket_message(message)
|
|
171
|
+
except websockets.ConnectionClosed:
|
|
172
|
+
if self.message_handler_callback:
|
|
173
|
+
await self.message_handler_callback(
|
|
174
|
+
{"type": "websocket", "state": "close", "source": "sdk.message_handler"}
|
|
175
|
+
)
|
|
176
|
+
logger.debug("[websocket] close")
|
|
177
|
+
|
|
178
|
+
async def set_mcp_tool_callback(self, tool_func: Dict[str, Callable[..., Any]]) -> None:
|
|
179
|
+
"""设置MCP工具回调函数"""
|
|
180
|
+
self.tool_func.update(tool_func)
|
|
181
|
+
|
|
182
|
+
async def connect_websocket(self, websocket_token):
|
|
183
|
+
"""连接websocket"""
|
|
184
|
+
headers = {
|
|
185
|
+
"Authorization": "Bearer {}".format(websocket_token),
|
|
186
|
+
"Protocol-Version": "1",
|
|
187
|
+
"Device-Id": self.mac_addr,
|
|
188
|
+
"Client-Id": self.client_id,
|
|
189
|
+
}
|
|
190
|
+
try:
|
|
191
|
+
self.websocket = await websockets.connect(uri=self.url, additional_headers=headers)
|
|
192
|
+
except websockets.exceptions.InvalidMessage as e:
|
|
193
|
+
logger.error("[websocket] 连接失败,请检查网络连接或设备状态。当前链接地址: %s, 错误信息:%s", self.url, e)
|
|
194
|
+
return
|
|
195
|
+
self.message_handler_task = asyncio.create_task(self._message_handler())
|
|
196
|
+
|
|
197
|
+
await self._send_hello(self.aec)
|
|
198
|
+
await self._start_listen()
|
|
199
|
+
logger.debug("[websocket] Connection successful")
|
|
200
|
+
await asyncio.sleep(0.5)
|
|
201
|
+
|
|
202
|
+
async def init_connection(
|
|
203
|
+
self, mac_addr: str, aec: bool = False, serial_number: str = "", license_key: str = ""
|
|
204
|
+
) -> None:
|
|
205
|
+
"""初始化连接"""
|
|
206
|
+
mac_pattern = r"^([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}$"
|
|
207
|
+
if not re.match(mac_pattern, mac_addr):
|
|
208
|
+
raise ValueError(f"无效的MAC地址格式: {mac_addr}。正确格式应为 XX:XX:XX:XX:XX:XX")
|
|
209
|
+
|
|
210
|
+
self.mac_addr = mac_addr.lower()
|
|
211
|
+
self.aec = aec
|
|
212
|
+
|
|
213
|
+
self.ota = OtaDevice(self.mac_addr, self.client_id, self.ota_url, serial_number)
|
|
214
|
+
ota_info = await self.ota.activate_device()
|
|
215
|
+
ws_url = ota_info.get("websocket", {}).get("url")
|
|
216
|
+
self.url = self.url or ws_url
|
|
217
|
+
|
|
218
|
+
if not self.url:
|
|
219
|
+
logger.warning("[websocket] 未找到websocket链接地址")
|
|
220
|
+
return
|
|
221
|
+
|
|
222
|
+
if "tenclass.net" not in self.url and "xiaozhi.me" not in self.url:
|
|
223
|
+
logger.warning("[websocket] 检测到非官方服务器,当前链接地址: %s", self.url)
|
|
224
|
+
|
|
225
|
+
self.websocket_token = ota_info["websocket"]["token"]
|
|
226
|
+
await self.connect_websocket(self.websocket_token)
|
|
227
|
+
|
|
228
|
+
if not await self.is_activate(ota_info):
|
|
229
|
+
self.iot_task = asyncio.create_task(self._activate_iot_device(license_key, ota_info))
|
|
230
|
+
logger.debug("[IOT] 设备未激活")
|
|
231
|
+
|
|
232
|
+
if self.send_wake:
|
|
233
|
+
await self.send_wake_word()
|
|
234
|
+
|
|
235
|
+
async def send_audio(self, pcm: bytes) -> None:
|
|
236
|
+
"""发送音频数据"""
|
|
237
|
+
if not self.websocket:
|
|
238
|
+
return
|
|
239
|
+
|
|
240
|
+
state = self.websocket.state
|
|
241
|
+
if state == websockets.protocol.State.OPEN:
|
|
242
|
+
opus_data = await self.audio_opus.pcm_to_opus(pcm)
|
|
243
|
+
await self.websocket.send(opus_data)
|
|
244
|
+
elif state in [websockets.protocol.State.CLOSED, websockets.protocol.State.CLOSING]:
|
|
245
|
+
if self.wait_device_activated:
|
|
246
|
+
logger.debug("[websocket] Server actively disconnected, reconnecting...")
|
|
247
|
+
await self.connect_websocket(self.websocket_token)
|
|
248
|
+
elif self.message_handler_callback:
|
|
249
|
+
await self.message_handler_callback({"type": "websocket", "state": "close", "source": "sdk.send_audio"})
|
|
250
|
+
self.websocket = None
|
|
251
|
+
logger.debug("[websocket] Server actively disconnected")
|
|
252
|
+
|
|
253
|
+
await asyncio.sleep(0.5)
|
|
254
|
+
else:
|
|
255
|
+
await asyncio.sleep(0.1)
|
|
256
|
+
|
|
257
|
+
async def close(self) -> None:
|
|
258
|
+
"""关闭连接"""
|
|
259
|
+
if self.message_handler_task and not self.message_handler_task.done():
|
|
260
|
+
self.message_handler_task.cancel()
|
|
261
|
+
try:
|
|
262
|
+
await self.message_handler_task
|
|
263
|
+
except asyncio.CancelledError:
|
|
264
|
+
pass
|
|
265
|
+
|
|
266
|
+
if self.iot_task:
|
|
267
|
+
self.iot_task.cancel()
|
|
268
|
+
|
|
269
|
+
if self.websocket:
|
|
270
|
+
await self.websocket.close()
|
xiaozhi_sdk/iot.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import hashlib
|
|
2
2
|
import hmac
|
|
3
3
|
import json
|
|
4
|
+
import logging
|
|
4
5
|
from typing import Any, Dict, Optional
|
|
5
6
|
|
|
6
7
|
import aiohttp
|
|
@@ -13,6 +14,8 @@ BOARD_TYPE = "xiaozhi-sdk-box"
|
|
|
13
14
|
USER_AGENT = "xiaozhi-sdk/{}".format(__version__)
|
|
14
15
|
BOARD_NAME = "xiaozhi-sdk-{}".format(__version__)
|
|
15
16
|
|
|
17
|
+
logger = logging.getLogger("xiaozhi_sdk")
|
|
18
|
+
|
|
16
19
|
|
|
17
20
|
class OtaDevice:
|
|
18
21
|
"""
|
|
@@ -72,4 +75,7 @@ class OtaDevice:
|
|
|
72
75
|
|
|
73
76
|
async with aiohttp.ClientSession() as session:
|
|
74
77
|
async with session.post(url, headers=headers, data=json.dumps(payload)) as response:
|
|
75
|
-
|
|
78
|
+
is_ok = response.status == 200
|
|
79
|
+
if not is_ok:
|
|
80
|
+
logger.debug("[IOT] wait for activate device...")
|
|
81
|
+
return is_ok
|
xiaozhi_sdk/mcp.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import json
|
|
2
3
|
import logging
|
|
3
4
|
|
|
5
|
+
import numpy as np
|
|
4
6
|
import requests
|
|
5
7
|
|
|
6
|
-
from xiaozhi_sdk.
|
|
8
|
+
from xiaozhi_sdk.utils.mcp_data import mcp_initialize_payload, mcp_tool_conf, mcp_tools_payload
|
|
9
|
+
from xiaozhi_sdk.utils.mcp_tool import _get_random_music_info
|
|
7
10
|
|
|
8
11
|
logger = logging.getLogger("xiaozhi_sdk")
|
|
9
12
|
|
|
@@ -16,6 +19,7 @@ class McpTool(object):
|
|
|
16
19
|
self.explain_token = ""
|
|
17
20
|
self.websocket = None
|
|
18
21
|
self.tool_func = {}
|
|
22
|
+
self.is_playing = False
|
|
19
23
|
|
|
20
24
|
def get_mcp_json(self, payload: dict):
|
|
21
25
|
return json.dumps({"session_id": self.session_id, "type": "mcp", "payload": payload})
|
|
@@ -45,14 +49,46 @@ class McpTool(object):
|
|
|
45
49
|
return res_json, True
|
|
46
50
|
return res_json, False
|
|
47
51
|
|
|
52
|
+
async def play_custom_music(self, tool_func, arguments):
|
|
53
|
+
pcm_array, is_error = await tool_func(arguments)
|
|
54
|
+
while True:
|
|
55
|
+
if not self.is_playing:
|
|
56
|
+
break
|
|
57
|
+
await asyncio.sleep(0.1)
|
|
58
|
+
pcm_array = await self.audio_opus.change_sample_rate(np.array(pcm_array))
|
|
59
|
+
self.output_audio_queue.extend(pcm_array)
|
|
60
|
+
|
|
48
61
|
async def mcp_tool_call(self, mcp_json: dict):
|
|
49
62
|
tool_name = mcp_json["params"]["name"]
|
|
50
63
|
tool_func = self.tool_func[tool_name]
|
|
64
|
+
arguments = mcp_json["params"]["arguments"]
|
|
51
65
|
try:
|
|
52
|
-
|
|
66
|
+
if tool_name == "async_play_custom_music":
|
|
67
|
+
|
|
68
|
+
# v1 返回 url
|
|
69
|
+
music_info = await _get_random_music_info(arguments["id_list"])
|
|
70
|
+
if not music_info.get("url"):
|
|
71
|
+
tool_res, is_error = {"message": "播放失败"}, True
|
|
72
|
+
else:
|
|
73
|
+
tool_res, is_error = {"message": "正在为你播放: {}".format(arguments["music_name"])}, False
|
|
74
|
+
data = {
|
|
75
|
+
"type": "music", "state": "start",
|
|
76
|
+
"url": music_info["url"],
|
|
77
|
+
"text": arguments["music_name"],
|
|
78
|
+
"source": "sdk.mcp_music_tool"
|
|
79
|
+
}
|
|
80
|
+
await self.message_handler_callback(data)
|
|
81
|
+
|
|
82
|
+
# v2 音频放到输出
|
|
83
|
+
# asyncio.create_task(self.play_custom_music(tool_func, arguments))
|
|
84
|
+
|
|
85
|
+
elif tool_name.startswith("async_"):
|
|
86
|
+
tool_res, is_error = await tool_func(arguments)
|
|
87
|
+
else:
|
|
88
|
+
tool_res, is_error = tool_func(arguments)
|
|
53
89
|
except Exception as e:
|
|
54
90
|
logger.error("[MCP] tool_func error: %s", e)
|
|
55
|
-
return
|
|
91
|
+
return self._build_response(mcp_json["id"], "工具调用失败", True)
|
|
56
92
|
|
|
57
93
|
if tool_name == "take_photo":
|
|
58
94
|
tool_res, is_error = await self.analyze_image(tool_res, mcp_json["params"]["arguments"]["question"])
|
|
@@ -84,10 +120,11 @@ class McpTool(object):
|
|
|
84
120
|
for name, func in self.tool_func.items():
|
|
85
121
|
if func:
|
|
86
122
|
tool_list.append(name)
|
|
87
|
-
|
|
88
|
-
|
|
123
|
+
target_name = name.removeprefix("async_")
|
|
124
|
+
mcp_tool_conf[target_name]["name"] = name
|
|
125
|
+
mcp_tools_payload["result"]["tools"].append(mcp_tool_conf[target_name])
|
|
89
126
|
await self.websocket.send(self.get_mcp_json(mcp_tools_payload))
|
|
90
|
-
logger.
|
|
127
|
+
logger.debug("[MCP] 加载成功,当前可用工具列表为:%s", tool_list)
|
|
91
128
|
|
|
92
129
|
elif method == "tools/call":
|
|
93
130
|
tool_name = payload["params"]["name"]
|
|
@@ -97,6 +134,6 @@ class McpTool(object):
|
|
|
97
134
|
|
|
98
135
|
mcp_res = await self.mcp_tool_call(payload)
|
|
99
136
|
await self.websocket.send(mcp_res)
|
|
100
|
-
logger.
|
|
137
|
+
logger.debug("[MCP] Tool %s called", tool_name)
|
|
101
138
|
else:
|
|
102
139
|
logger.warning("[MCP] unknown method %s: %s", method, payload)
|
xiaozhi_sdk/opus.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
1
3
|
import av
|
|
2
4
|
import numpy as np
|
|
3
5
|
import opuslib
|
|
4
6
|
|
|
5
|
-
from xiaozhi_sdk import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
7
|
+
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class AudioOpus:
|
|
@@ -29,11 +31,16 @@ class AudioOpus:
|
|
|
29
31
|
pcm_bytes = pcm_array.tobytes()
|
|
30
32
|
return self.opus_encoder.encode(pcm_bytes, 960)
|
|
31
33
|
|
|
32
|
-
|
|
34
|
+
@staticmethod
|
|
35
|
+
def to_n_960(samples) -> np.ndarray:
|
|
36
|
+
n = math.ceil(samples.shape[0] / 960)
|
|
37
|
+
arr_padded = np.pad(samples, (0, 960 * n - samples.shape[0]), mode="constant", constant_values=0)
|
|
38
|
+
return arr_padded.reshape(n, 960)
|
|
39
|
+
|
|
40
|
+
async def change_sample_rate(self, pcm_array) -> np.ndarray:
|
|
33
41
|
if self.sample_rate == INPUT_SERVER_AUDIO_SAMPLE_RATE:
|
|
34
|
-
return
|
|
42
|
+
return self.to_n_960(pcm_array)
|
|
35
43
|
|
|
36
|
-
c = int(self.sample_rate / INPUT_SERVER_AUDIO_SAMPLE_RATE)
|
|
37
44
|
frame = av.AudioFrame.from_ndarray(np.array(pcm_array).reshape(1, -1), format="s16", layout="mono")
|
|
38
45
|
frame.sample_rate = INPUT_SERVER_AUDIO_SAMPLE_RATE # Assuming input is 16kHz
|
|
39
46
|
resampled_frames = self.resampler.resample(frame)
|
|
@@ -45,10 +52,9 @@ class AudioOpus:
|
|
|
45
52
|
)
|
|
46
53
|
new_frame.sample_rate = self.sample_rate
|
|
47
54
|
new_samples = new_frame.to_ndarray().flatten()
|
|
48
|
-
|
|
49
|
-
return arr_padded.reshape(c, 960)
|
|
55
|
+
return self.to_n_960(new_samples)
|
|
50
56
|
|
|
51
|
-
async def opus_to_pcm(self, opus):
|
|
57
|
+
async def opus_to_pcm(self, opus) -> np.ndarray:
|
|
52
58
|
pcm_data = self.opus_decoder.decode(opus, 960)
|
|
53
59
|
pcm_array = np.frombuffer(pcm_data, dtype=np.int16)
|
|
54
60
|
samples = await self.change_sample_rate(pcm_array)
|
|
@@ -45,11 +45,11 @@ def setup_opus():
|
|
|
45
45
|
arch = "x64"
|
|
46
46
|
|
|
47
47
|
if system == "darwin": # macOS
|
|
48
|
-
return f"{current_dir}
|
|
48
|
+
return f"{current_dir}/../../file/opus/macos-{arch}-libopus.dylib"
|
|
49
49
|
elif system == "windows": # Windows
|
|
50
|
-
return f"{current_dir}
|
|
50
|
+
return f"{current_dir}/../../file/opus/windows-opus.dll"
|
|
51
51
|
elif system == "linux": # Linux
|
|
52
|
-
return f"{current_dir}
|
|
52
|
+
return f"{current_dir}/../../file/opus/linux-{arch}-libopus.so"
|
|
53
53
|
else:
|
|
54
54
|
# 默认情况,尝试系统查找
|
|
55
55
|
return ctypes.util.find_library(name)
|
|
@@ -11,6 +11,22 @@ mcp_initialize_payload: Dict[str, Any] = {
|
|
|
11
11
|
}
|
|
12
12
|
|
|
13
13
|
mcp_tool_conf: Dict[str, Dict[str, Any]] = {
|
|
14
|
+
"search_custom_music": {
|
|
15
|
+
"description": "Search music and get music IDs. Use this tool when the user asks to search or play music. This tool returns a list of music with their IDs, which are required for playing music. Args:\n `music_name`: The name of the music to search\n `author_name`: The name of the music author (optional)",
|
|
16
|
+
"inputSchema": {
|
|
17
|
+
"type": "object",
|
|
18
|
+
"properties": {"music_name": {"type": "string"}, "author_name": {"type": "string"}},
|
|
19
|
+
"required": ["music_name"],
|
|
20
|
+
},
|
|
21
|
+
},
|
|
22
|
+
"play_custom_music": {
|
|
23
|
+
"description": "Play music using music IDs. IMPORTANT: You must call `search_custom_music` first to get the music IDs before using this tool. Use this tool after getting music IDs from search results. Args:\n `id_list`: The id list of the music to play (obtained from search_custom_music results). The list must contain more than 2 music IDs, and the system will randomly select one to play.\n `music_name`: The name of the music (obtained from search_custom_music results)",
|
|
24
|
+
"inputSchema": {
|
|
25
|
+
"type": "object",
|
|
26
|
+
"properties": {"music_name": {"type": "string"}, "id_list": {"type": "array", "items": {"type": "string"}, "minItems": 3}},
|
|
27
|
+
"required": ["music_name", "id_list"],
|
|
28
|
+
},
|
|
29
|
+
},
|
|
14
30
|
"get_device_status": {
|
|
15
31
|
"description": "Provides the real-time information of the device, including the current status of the audio speaker, screen, battery, network, etc.\nUse this tool for: \n1. Answering questions about current condition (e.g. what is the current volume of the audio speaker?)\n2. As the first step to control the device (e.g. turn up / down the volume of the audio speaker, etc.)",
|
|
16
32
|
"inputSchema": {"type": "object", "properties": {}},
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
import aiohttp
|
|
5
|
+
import numpy as np
|
|
6
|
+
from pydub import AudioSegment
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
async def async_search_custom_music(data) -> tuple[dict, bool]:
|
|
10
|
+
search_url = f"https://music-api.gdstudio.xyz/api.php?types=search&name={data['music_name']}&count=100&pages=1"
|
|
11
|
+
|
|
12
|
+
# 为搜索请求设置 10 秒超时
|
|
13
|
+
timeout = aiohttp.ClientTimeout(total=10)
|
|
14
|
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
15
|
+
async with session.get(search_url) as response:
|
|
16
|
+
response_json = await response.json()
|
|
17
|
+
|
|
18
|
+
music_list = []
|
|
19
|
+
first_music_list = []
|
|
20
|
+
other_music_list1 = []
|
|
21
|
+
other_music_list2 = []
|
|
22
|
+
for line in response_json:
|
|
23
|
+
if data.get("author_name") and data["author_name"] in line["artist"][0]:
|
|
24
|
+
first_music_list.append(line)
|
|
25
|
+
elif data.get("author_name") and (data["author_name"] in line["artist"] or data["author_name"] in line["name"]):
|
|
26
|
+
other_music_list1.append(line)
|
|
27
|
+
else:
|
|
28
|
+
other_music_list2.append(line)
|
|
29
|
+
|
|
30
|
+
if len(first_music_list) <= 10:
|
|
31
|
+
music_list = first_music_list
|
|
32
|
+
random.shuffle(other_music_list2)
|
|
33
|
+
music_list = music_list + other_music_list1[: 20 - len(music_list)]
|
|
34
|
+
music_list = music_list + other_music_list2[: 20 - len(music_list)]
|
|
35
|
+
|
|
36
|
+
# print(data)
|
|
37
|
+
# print("找到音乐,数量:", len(first_music_list), len(music_list))
|
|
38
|
+
|
|
39
|
+
if not music_list:
|
|
40
|
+
return {}, False
|
|
41
|
+
return {"message": "已找到歌曲", "music_list": music_list}, False
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
async def _get_random_music_info(id_list: list) -> dict:
|
|
45
|
+
timeout = aiohttp.ClientTimeout(total=10)
|
|
46
|
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
47
|
+
random.shuffle(id_list)
|
|
48
|
+
|
|
49
|
+
for music_id in id_list:
|
|
50
|
+
url = f"https://music-api.gdstudio.xyz/api.php?types=url&id={music_id}&br=128"
|
|
51
|
+
async with session.get(url) as response:
|
|
52
|
+
res_json = await response.json()
|
|
53
|
+
if res_json.get("url"):
|
|
54
|
+
break
|
|
55
|
+
|
|
56
|
+
return res_json
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def async_mcp_play_music(data) -> tuple[list, bool]:
|
|
60
|
+
id_list = data["id_list"]
|
|
61
|
+
res_json = await _get_random_music_info(id_list)
|
|
62
|
+
|
|
63
|
+
if not res_json:
|
|
64
|
+
return [], False
|
|
65
|
+
|
|
66
|
+
pcm_list = []
|
|
67
|
+
buffer = io.BytesIO()
|
|
68
|
+
# 为下载音乐文件设置 60 秒超时(音乐文件可能比较大)
|
|
69
|
+
download_timeout = aiohttp.ClientTimeout(total=60)
|
|
70
|
+
async with aiohttp.ClientSession(timeout=download_timeout) as session:
|
|
71
|
+
async with session.get(res_json["url"]) as resp:
|
|
72
|
+
async for chunk in resp.content.iter_chunked(1024):
|
|
73
|
+
buffer.write(chunk)
|
|
74
|
+
|
|
75
|
+
buffer.seek(0)
|
|
76
|
+
audio = AudioSegment.from_mp3(buffer)
|
|
77
|
+
audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2) # 2 bytes = 16 bits
|
|
78
|
+
pcm_data = audio.raw_data
|
|
79
|
+
|
|
80
|
+
chunk_size = 960 * 2
|
|
81
|
+
for i in range(0, len(pcm_data), chunk_size):
|
|
82
|
+
chunk = pcm_data[i: i + chunk_size]
|
|
83
|
+
|
|
84
|
+
if chunk: # 确保不添加空块
|
|
85
|
+
chunk = np.frombuffer(chunk, dtype=np.int16)
|
|
86
|
+
pcm_list.extend(chunk)
|
|
87
|
+
|
|
88
|
+
return pcm_list, False
|
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: xiaozhi-sdk
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: 一个用于连接和控制小智智能设备的Python SDK,支持实时音频通信、MCP工具集成和设备管理功能。
|
|
5
5
|
Author-email: dairoot <623815825@qq.com>
|
|
6
|
-
License: MIT
|
|
6
|
+
License-Expression: MIT
|
|
7
7
|
Project-URL: Homepage, https://github.com/dairoot/xiaozhi-sdk
|
|
8
8
|
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
10
9
|
Classifier: Operating System :: OS Independent
|
|
11
|
-
Requires-Python: >=3.
|
|
10
|
+
Requires-Python: >=3.9
|
|
12
11
|
Description-Content-Type: text/markdown
|
|
13
12
|
License-File: LICENSE
|
|
14
13
|
Requires-Dist: numpy
|
|
@@ -19,6 +18,10 @@ Requires-Dist: opuslib
|
|
|
19
18
|
Requires-Dist: requests
|
|
20
19
|
Requires-Dist: sounddevice
|
|
21
20
|
Requires-Dist: python-socks
|
|
21
|
+
Requires-Dist: click
|
|
22
|
+
Requires-Dist: colorlog
|
|
23
|
+
Requires-Dist: soundfile>=0.13.1
|
|
24
|
+
Requires-Dist: pydub>=0.25.1
|
|
22
25
|
Dynamic: license-file
|
|
23
26
|
|
|
24
27
|
# 小智SDK (XiaoZhi SDK)
|
|
@@ -53,21 +56,7 @@ pip install xiaozhi-sdk
|
|
|
53
56
|
#### 查看帮助信息
|
|
54
57
|
|
|
55
58
|
```bash
|
|
56
|
-
python -m xiaozhi_sdk
|
|
57
|
-
```
|
|
58
|
-
|
|
59
|
-
输出示例:
|
|
60
|
-
```text
|
|
61
|
-
positional arguments:
|
|
62
|
-
device 你的小智设备的MAC地址 (格式: XX:XX:XX:XX:XX:XX)
|
|
63
|
-
|
|
64
|
-
options:
|
|
65
|
-
-h, --help show this help message and exit
|
|
66
|
-
--url URL 服务端websocket地址
|
|
67
|
-
--ota_url OTA_URL OTA地址
|
|
68
|
-
--serial_number SERIAL_NUMBER 设备的序列号
|
|
69
|
-
--license_key LICENSE_KEY 设备的授权密钥
|
|
70
|
-
|
|
59
|
+
python -m xiaozhi_sdk --help
|
|
71
60
|
```
|
|
72
61
|
|
|
73
62
|
#### 连接设备(需要提供 MAC 地址)
|
|
@@ -76,14 +65,20 @@ options:
|
|
|
76
65
|
python -m xiaozhi_sdk 00:22:44:66:88:00
|
|
77
66
|
```
|
|
78
67
|
|
|
79
|
-
### 2. 编程使用
|
|
68
|
+
### 2. 编程使用 (高阶用法)
|
|
80
69
|
参考 [examples](examples/) 文件中的示例代码,可以快速开始使用 SDK。
|
|
81
70
|
|
|
82
71
|
|
|
83
|
-
|
|
72
|
+
---
|
|
73
|
+
|
|
74
|
+
## ✅ 运行测试
|
|
84
75
|
|
|
85
76
|
```bash
|
|
86
|
-
|
|
77
|
+
# 安装开发依赖
|
|
78
|
+
uv sync --group dev
|
|
79
|
+
|
|
80
|
+
# 运行测试
|
|
81
|
+
uv run pytest
|
|
87
82
|
```
|
|
88
83
|
|
|
89
84
|
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
file/audio/greet.wav,sha256=F60kKKFVQZyYh67_-9AJHMviuquSWHHqwGQewUSOAFg,32720
|
|
2
|
+
file/audio/play_music.wav,sha256=uqUIKz-3bqViDsjEZ2n6g_7xsggbRY6JwdZTCGS8b2E,61772
|
|
3
|
+
file/audio/say_hello.wav,sha256=RGo2MDUF7npGmjFPT4III0ibf7dIZ1c47jijrF0Yjaw,34146
|
|
4
|
+
file/audio/take_photo.wav,sha256=_DNWg31Q8NIxN3eUS4wBC7mn4MZCWLCNPuKfKPv1ojQ,51412
|
|
5
|
+
file/image/leijun.jpg,sha256=plhBvnB4O21RjLwH-HjNq0jH4Msy5ppA_IDWe5ieNg4,70814
|
|
6
|
+
file/opus/linux-arm64-libopus.so,sha256=D2H5VDUomaYuLetejCvLwCgf-iAVP0isg1yGwfsuvEE,493032
|
|
7
|
+
file/opus/linux-x64-libopus.so,sha256=FmXJqkxLpDzNFOHYkmOzmsp1hP0eIS5b6x_XfOs-IQA,623008
|
|
8
|
+
file/opus/macos-arm64-libopus.dylib,sha256=H7wXwkrGwb-hesMMZGFxWb0Ri1Y4m5GWiKsd8CfOhE8,357584
|
|
9
|
+
file/opus/macos-x64-libopus.dylib,sha256=MqyL_OjwSACF4Xs_-KrGbcScy4IEprr5Rlkk3ddZye8,550856
|
|
10
|
+
file/opus/windows-opus.dll,sha256=kLfhioMvbJhOgNMAldpWk3DCZqC5Xd70LRbHnACvAnw,463360
|
|
11
|
+
xiaozhi_sdk/__init__.py,sha256=42z4NTA7JPjhEk64Z2CmOkQ71dd3aPyLX02T6gQfe_4,77
|
|
12
|
+
xiaozhi_sdk/__main__.py,sha256=i0ZJdHUqAKg9vwZrK_w0TJkzdotTYTK8aUeSPcJc1ks,210
|
|
13
|
+
xiaozhi_sdk/cli.py,sha256=KCSTyH6ocMSNs3WYRSqQsvGlzJNRqBR88otz-w-yb9E,4241
|
|
14
|
+
xiaozhi_sdk/config.py,sha256=h4mpMeBf2vT9qYAqCCbGVGmMemkgk98pcXP2Rh4TEFc,89
|
|
15
|
+
xiaozhi_sdk/core.py,sha256=564SefCBus6qNRApWqwI113aIN1p4eYpci1mLeMExIs,10007
|
|
16
|
+
xiaozhi_sdk/iot.py,sha256=IO3SfiuQxucYl_917BCNCwIAv1dajCJI-IFTWwHnSDE,2580
|
|
17
|
+
xiaozhi_sdk/mcp.py,sha256=kFkyZjLrjJNwZuZtFueMdeFAjzm0DJa7GsXcIg5YYi4,5524
|
|
18
|
+
xiaozhi_sdk/opus.py,sha256=r3nnYg0ZKAJTreb_3nKgfHJh06MJiMvnNMPO1SWdoMM,2224
|
|
19
|
+
xiaozhi_sdk/utils/__init__.py,sha256=XKSHWoFmuSkpwaIr308HybRzfFIXoT1Fd-eUKo_im6Y,1705
|
|
20
|
+
xiaozhi_sdk/utils/mcp_data.py,sha256=xeZhqucYCeKPJiRbrTptkNeCd0ampQXgb5lA5UsJW3U,3851
|
|
21
|
+
xiaozhi_sdk/utils/mcp_tool.py,sha256=OU1mr3qr4XB0nvE8sX6-eCZyT2l-RdFe0caaTbOf4L8,3085
|
|
22
|
+
xiaozhi_sdk-0.0.7.dist-info/licenses/LICENSE,sha256=Vwgps1iODKl43cAtME_0dawTjAzNW-O2BWiN5BHggww,1085
|
|
23
|
+
xiaozhi_sdk-0.0.7.dist-info/METADATA,sha256=K4p1y2BqooqQ5gWqm4gf3iVQ5AbATsX-WgbBIM0jtHU,2091
|
|
24
|
+
xiaozhi_sdk-0.0.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
25
|
+
xiaozhi_sdk-0.0.7.dist-info/top_level.txt,sha256=nBpue4hU5Ykm5CtYPsAdxSa_yqbtZsIT_gF_EkBaJPM,12
|
|
26
|
+
xiaozhi_sdk-0.0.7.dist-info/RECORD,,
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
file/audio/greet.wav,sha256=F60kKKFVQZyYh67_-9AJHMviuquSWHHqwGQewUSOAFg,32720
|
|
2
|
-
file/audio/say_hello.wav,sha256=RGo2MDUF7npGmjFPT4III0ibf7dIZ1c47jijrF0Yjaw,34146
|
|
3
|
-
file/audio/take_photo.wav,sha256=_DNWg31Q8NIxN3eUS4wBC7mn4MZCWLCNPuKfKPv1ojQ,51412
|
|
4
|
-
file/image/leijun.jpg,sha256=plhBvnB4O21RjLwH-HjNq0jH4Msy5ppA_IDWe5ieNg4,70814
|
|
5
|
-
file/opus/linux-arm64-libopus.so,sha256=D2H5VDUomaYuLetejCvLwCgf-iAVP0isg1yGwfsuvEE,493032
|
|
6
|
-
file/opus/linux-x64-libopus.so,sha256=FmXJqkxLpDzNFOHYkmOzmsp1hP0eIS5b6x_XfOs-IQA,623008
|
|
7
|
-
file/opus/macos-arm64-libopus.dylib,sha256=H7wXwkrGwb-hesMMZGFxWb0Ri1Y4m5GWiKsd8CfOhE8,357584
|
|
8
|
-
file/opus/macos-x64-libopus.dylib,sha256=MqyL_OjwSACF4Xs_-KrGbcScy4IEprr5Rlkk3ddZye8,550856
|
|
9
|
-
file/opus/windows-x86_64-opus.dll,sha256=kLfhioMvbJhOgNMAldpWk3DCZqC5Xd70LRbHnACvAnw,463360
|
|
10
|
-
xiaozhi_sdk/__init__.py,sha256=psrVT0BBTsCj1wqJ7ZmKRhGfRj2y7NPGRwn_aiKLw-E,8146
|
|
11
|
-
xiaozhi_sdk/__main__.py,sha256=_Xh6v2oMYXYHsrAkw4PYMJpvi-0r3ujLNRLMxPNarTQ,3807
|
|
12
|
-
xiaozhi_sdk/config.py,sha256=h4mpMeBf2vT9qYAqCCbGVGmMemkgk98pcXP2Rh4TEFc,89
|
|
13
|
-
xiaozhi_sdk/data.py,sha256=8z8erOjBZFvPSBJlPoyTzRYZ3BuMvnPpAFQCbSxs-48,2522
|
|
14
|
-
xiaozhi_sdk/iot.py,sha256=w_DvEDcQmaP1JJAYElRDCJP5h0GhaoZ2lg5JBnmZBnU,2392
|
|
15
|
-
xiaozhi_sdk/mcp.py,sha256=JA-z6EjGqitEfwMlvxk6XUSjbmfAdyWJVZPjtjqo6Oo,3823
|
|
16
|
-
xiaozhi_sdk/opus.py,sha256=4O-kz-PcUVmpa27Vju6jv-sbwywuAXFvVL23R1-vv5o,2104
|
|
17
|
-
xiaozhi_sdk/utils.py,sha256=5qHAiI5Nrzeka3TofMPhAVmMovEJJa6QSrKcDM0OF4g,1703
|
|
18
|
-
xiaozhi_sdk-0.0.5.dist-info/licenses/LICENSE,sha256=Vwgps1iODKl43cAtME_0dawTjAzNW-O2BWiN5BHggww,1085
|
|
19
|
-
xiaozhi_sdk-0.0.5.dist-info/METADATA,sha256=M0BcMb2zM9SjHdJHfivECeUtOFKNuPDiBoe3BM8dkW0,2352
|
|
20
|
-
xiaozhi_sdk-0.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
21
|
-
xiaozhi_sdk-0.0.5.dist-info/top_level.txt,sha256=nBpue4hU5Ykm5CtYPsAdxSa_yqbtZsIT_gF_EkBaJPM,12
|
|
22
|
-
xiaozhi_sdk-0.0.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|