xiaozhi-sdk 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- file/audio/greet.wav +0 -0
- file/audio/play_music.wav +0 -0
- file/audio/say_hello.wav +0 -0
- file/audio/take_photo.wav +0 -0
- file/image/leijun.jpg +0 -0
- file/opus/linux-arm64-libopus.so +0 -0
- file/opus/linux-x64-libopus.so +0 -0
- file/opus/macos-arm64-libopus.dylib +0 -0
- file/opus/macos-x64-libopus.dylib +0 -0
- file/opus/windows-opus.dll +0 -0
- xiaozhi_sdk/__init__.py +2 -154
- xiaozhi_sdk/__main__.py +7 -86
- xiaozhi_sdk/cli.py +231 -0
- xiaozhi_sdk/config.py +1 -3
- xiaozhi_sdk/core.py +269 -0
- xiaozhi_sdk/iot.py +61 -27
- xiaozhi_sdk/mcp.py +128 -32
- xiaozhi_sdk/opus.py +13 -11
- xiaozhi_sdk/utils/__init__.py +57 -0
- xiaozhi_sdk/utils/mcp_tool.py +185 -0
- xiaozhi_sdk-0.2.0.dist-info/METADATA +90 -0
- xiaozhi_sdk-0.2.0.dist-info/RECORD +25 -0
- xiaozhi_sdk-0.2.0.dist-info/licenses/LICENSE +21 -0
- xiaozhi_sdk/data.py +0 -58
- xiaozhi_sdk/utils.py +0 -23
- xiaozhi_sdk-0.1.0.dist-info/METADATA +0 -58
- xiaozhi_sdk-0.1.0.dist-info/RECORD +0 -12
- {xiaozhi_sdk-0.1.0.dist-info → xiaozhi_sdk-0.2.0.dist-info}/WHEEL +0 -0
- {xiaozhi_sdk-0.1.0.dist-info → xiaozhi_sdk-0.2.0.dist-info}/top_level.txt +0 -0
xiaozhi_sdk/core.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import uuid
|
|
7
|
+
from collections import deque
|
|
8
|
+
from typing import Any, Callable, Deque, Dict, Optional
|
|
9
|
+
|
|
10
|
+
import websockets
|
|
11
|
+
|
|
12
|
+
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
13
|
+
from xiaozhi_sdk.iot import OtaDevice
|
|
14
|
+
from xiaozhi_sdk.mcp import McpTool
|
|
15
|
+
from xiaozhi_sdk.utils import get_wav_info, read_audio_file, setup_opus
|
|
16
|
+
|
|
17
|
+
# NOTE(review): setup_opus() is deliberately called before importing
# xiaozhi_sdk.opus (which imports opuslib); presumably it makes the opus shared
# library loadable for that import — keep this ordering.
setup_opus()
from xiaozhi_sdk.opus import AudioOpus

# Shared SDK logger (the same name is used by the other xiaozhi_sdk modules).
logger = logging.getLogger("xiaozhi_sdk")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class XiaoZhiWebsocket(McpTool):
    """Websocket client for a XiaoZhi voice server.

    ``init_connection`` validates the MAC address, queries the OTA server,
    opens the websocket, performs the ``hello``/``listen`` handshake and, when
    the OTA reply asks for it, polls for device activation in the background.

    Incoming binary frames are decoded from Opus and appended to
    ``output_audio_queue``; ``hello`` and ``mcp`` JSON messages are handled
    internally; other JSON messages are passed to ``message_handler_callback``
    when one is registered.
    """

    def __init__(
        self,
        message_handler_callback: Optional[Callable] = None,
        url: Optional[str] = None,
        ota_url: Optional[str] = None,
        audio_sample_rate: int = 16000,
        audio_channels: int = 1,
        wake_word: str = "",
    ):
        """Create an unconnected client.

        Args:
            message_handler_callback: async callable awaited with every server
                JSON message not handled internally, and with a
                ``{"type": "websocket", "state": "close", ...}`` dict on disconnect.
            url: websocket URL; when omitted, the URL from the OTA reply is used.
            ota_url: OTA server URL, passed through to ``OtaDevice``.
            audio_sample_rate: sample rate used to build ``self.audio_opus``.
            audio_channels: channel count used to build ``self.audio_opus``.
            wake_word: if non-empty, sent as a "detect" listen message after connecting.
        """
        super().__init__()
        self.url = url
        self.ota_url = ota_url
        self.audio_channels = audio_channels
        self.audio_opus = AudioOpus(audio_sample_rate, audio_channels)
        self.wake_word = wake_word

        # Client identity
        self.client_id = str(uuid.uuid4())
        self.mac_addr: Optional[str] = None
        self.aec = False
        self.websocket_token = ""

        # Callback
        self.message_handler_callback = message_handler_callback

        # Connection state
        self.hello_received = asyncio.Event()
        self.session_id = ""
        self.websocket = None
        self.message_handler_task: Optional[asyncio.Task] = None

        # Output audio (decoded server audio)
        self.output_audio_queue: Deque[bytes] = deque()
        self.is_playing: bool = False

        # OTA device / activation
        self.ota: Optional[OtaDevice] = None
        self.iot_task: Optional[asyncio.Task] = None
        self.wait_device_activated: bool = False

        # MCP tools: name -> tool spec
        self.mcp_tool_dict = {}

    async def _send_hello(self, aec: bool) -> None:
        """Send the ``hello`` handshake and wait up to 10 s for the server's reply.

        Raises:
            asyncio.TimeoutError: if no ``hello`` message arrives within 10 seconds.
        """
        hello_message = {
            "type": "hello",
            "version": 1,
            "features": {"mcp": True, "aec": aec},
            "transport": "websocket",
            "audio_params": {
                "format": "opus",
                "sample_rate": 16000,
                "channels": 1,
                "frame_duration": 60,
            },
        }
        # Clear first: on a reconnect the event is still set from the previous
        # session and would otherwise let us continue before the new hello
        # (and its session_id) has arrived.
        self.hello_received.clear()
        await self.websocket.send(json.dumps(hello_message))
        await asyncio.wait_for(self.hello_received.wait(), timeout=10.0)

    async def _start_listen(self) -> None:
        """Tell the server to start listening in realtime mode."""
        listen_message = {"session_id": self.session_id, "type": "listen", "state": "start", "mode": "realtime"}
        await self.websocket.send(json.dumps(listen_message))

    async def is_activate(self, ota_info):
        """Return False when the OTA reply contains an ``activation`` entry, else True."""
        if ota_info.get("activation"):
            return False

        return True

    async def _activate_iot_device(self, license_key: str, ota_info: Dict[str, Any]) -> None:
        """Poll the OTA server until activation of this device is confirmed.

        Waits 3 s, then checks up to 10 times, 3 s apart. While polling,
        ``wait_device_activated`` is True so ``send_audio`` reconnects a socket
        the server closed instead of treating the close as final.
        """
        if not self.ota:
            return

        challenge = ota_info["activation"]["challenge"]
        await asyncio.sleep(3)
        self.wait_device_activated = True
        try:
            for _ in range(10):
                if await self.ota.check_activate(challenge, license_key):
                    break
                await asyncio.sleep(3)
            else:
                logger.warning("[IOT] device activation not confirmed after 10 attempts")
        finally:
            # Reset on success, timeout or cancellation so send_audio does not
            # keep reconnecting indefinitely.
            self.wait_device_activated = False

    async def _send_demo_audio(self) -> None:
        """Stream the bundled ``greet.wav`` to the server, followed by silence."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        wav_path = os.path.join(current_dir, "../file/audio/greet.wav")
        framerate, channels = get_wav_info(wav_path)
        # A dedicated encoder matching the file's own rate/channels.
        audio_opus = AudioOpus(framerate, channels)

        for pcm_data in read_audio_file(wav_path):
            opus_data = await audio_opus.pcm_to_opus(pcm_data)
            await self.websocket.send(opus_data)
        await self.send_silence_audio()

    async def send_wake_word(self, wake_word: str) -> None:
        """Send ``wake_word`` as a listen/detect message."""
        await self.websocket.send(
            json.dumps({"session_id": self.session_id, "type": "listen", "state": "detect", "text": wake_word})
        )

    async def send_silence_audio(self, duration_seconds: float = 1.2) -> None:
        """Send roughly ``duration_seconds`` of silence as 60 ms frames."""
        frames_count = int(duration_seconds * 1000 / 60)
        # One 60 ms frame of 16-bit zero samples at the server input rate.
        pcm_frame = b"\x00\x00" * int(INPUT_SERVER_AUDIO_SAMPLE_RATE / 1000 * 60)

        for _ in range(frames_count):
            await self.send_audio(pcm_frame)

    async def _handle_websocket_message(self, message: Any) -> None:
        """Dispatch one websocket message (binary audio or a JSON text frame)."""
        # Audio data
        if isinstance(message, bytes):
            pcm_array = await self.audio_opus.opus_to_pcm(message)
            self.output_audio_queue.extend(pcm_array)
            return

        # JSON message
        data = json.loads(message)
        message_type = data["type"]
        if message_type == "hello":
            # Store the session before waking _send_hello's waiter.
            self.session_id = data["session_id"]
            self.hello_received.set()
            return
        if message_type == "mcp":
            await self.mcp(data)
            return

        # Track TTS playback state whether or not a callback is registered.
        if message_type == "tts":
            self.is_playing = data["state"] == "sentence_start"

        if self.message_handler_callback:
            await self.message_handler_callback(data)

    async def _message_handler(self) -> None:
        """Read messages from the current websocket until it closes.

        An error while handling a single message is logged and the loop
        continues, so one bad message (or a failing callback/tool) does not stop
        the reader. On ``ConnectionClosed`` the callback is notified.
        """
        try:
            async for message in self.websocket:
                try:
                    await self._handle_websocket_message(message)
                except websockets.ConnectionClosed:
                    # e.g. replying on a socket that just closed: treat as a close.
                    raise
                except Exception:
                    logger.exception("[websocket] failed to handle message")
        except websockets.ConnectionClosed:
            if self.message_handler_callback:
                await self.message_handler_callback(
                    {"type": "websocket", "state": "close", "source": "sdk.message_handler"}
                )
            logger.debug("[websocket] close")

    async def set_mcp_tool(self, mcp_tool_list) -> None:
        """Register MCP tool specs, keyed by each spec's ``name``."""
        for mcp_tool in mcp_tool_list:
            self.mcp_tool_dict[mcp_tool["name"]] = mcp_tool

    async def connect_websocket(self, websocket_token):
        """Open the websocket, start the reader task and run the handshake.

        Returns:
            True once connected and listening; False if the connection attempt
            failed with ``InvalidMessage`` (logged).

        Raises:
            asyncio.TimeoutError: if the server's hello does not arrive in time.
        """
        headers = {
            "Authorization": "Bearer {}".format(websocket_token),
            "Protocol-Version": "1",
            "Device-Id": self.mac_addr,
            "Client-Id": self.client_id,
        }
        try:
            self.websocket = await websockets.connect(uri=self.url, additional_headers=headers)
        except websockets.exceptions.InvalidMessage as e:
            logger.error("[websocket] 连接失败,请检查网络连接或设备状态。当前链接地址: %s, 错误信息:%s", self.url, e)
            return False
        self.message_handler_task = asyncio.create_task(self._message_handler())

        await self._send_hello(self.aec)
        await self._start_listen()
        logger.debug("[websocket] Connection successful")
        await asyncio.sleep(0.5)
        return True

    async def init_connection(
        self, mac_addr: str, aec: bool = False, serial_number: str = "", license_key: str = ""
    ) -> None:
        """Register with the OTA server and open the websocket session.

        Args:
            mac_addr: device MAC in ``XX:XX:XX:XX:XX:XX`` form (stored lower-cased).
            aec: value advertised as the ``aec`` feature in the hello message.
            serial_number: passed to ``OtaDevice``.
            license_key: HMAC key used if the device still needs activation.

        Raises:
            ValueError: if ``mac_addr`` is not a colon-separated MAC address.
        """
        mac_pattern = r"^([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}$"
        # fullmatch: with re.match, ``$`` would also accept a trailing newline.
        if not re.fullmatch(mac_pattern, mac_addr):
            raise ValueError(f"无效的MAC地址格式: {mac_addr}。正确格式应为 XX:XX:XX:XX:XX:XX")

        self.mac_addr = mac_addr.lower()
        self.aec = aec

        self.ota = OtaDevice(self.mac_addr, self.client_id, self.ota_url, serial_number)
        ota_info = await self.ota.activate_device()
        ws_info = ota_info.get("websocket") or {}
        self.url = self.url or ws_info.get("url")

        if not self.url:
            logger.warning("[websocket] 未找到websocket链接地址")
            return

        if "tenclass.net" not in self.url and "xiaozhi.me" not in self.url:
            logger.warning("[websocket] 检测到非官方服务器,当前链接地址: %s", self.url)

        # The token may be absent when the caller supplied ``url`` explicitly.
        self.websocket_token = ws_info.get("token", "")
        connected = await self.connect_websocket(self.websocket_token)

        if not await self.is_activate(ota_info):
            self.iot_task = asyncio.create_task(self._activate_iot_device(license_key, ota_info))
            logger.debug("[IOT] 设备未激活")

        # Only send over a live connection; after a failed connect, websocket may be None.
        if self.wake_word and connected:
            await self.send_wake_word(self.wake_word)

    async def send_audio(self, pcm: bytes) -> None:
        """Encode ``pcm`` to Opus and send it, handling a closed connection.

        If the socket is closed while activation is pending, reconnect;
        otherwise notify the callback (if any) and drop the socket so later
        calls return immediately.
        """
        if not self.websocket:
            return

        state = self.websocket.state
        if state == websockets.protocol.State.OPEN:
            opus_data = await self.audio_opus.pcm_to_opus(pcm)
            await self.websocket.send(opus_data)
        elif state in [websockets.protocol.State.CLOSED, websockets.protocol.State.CLOSING]:
            if self.wait_device_activated:
                logger.debug("[websocket] Server actively disconnected, reconnecting...")
                await self.connect_websocket(self.websocket_token)
            else:
                if self.message_handler_callback:
                    await self.message_handler_callback(
                        {"type": "websocket", "state": "close", "source": "sdk.send_audio"}
                    )
                # Drop the socket even without a callback, so we stop re-entering this branch.
                self.websocket = None
                logger.debug("[websocket] Server actively disconnected")

            await asyncio.sleep(0.5)
        else:
            # Still connecting: back off briefly.
            await asyncio.sleep(0.1)

    async def close(self) -> None:
        """Cancel background tasks and close the websocket."""
        if self.message_handler_task and not self.message_handler_task.done():
            self.message_handler_task.cancel()
            try:
                await self.message_handler_task
            except asyncio.CancelledError:
                pass

        if self.iot_task and not self.iot_task.done():
            self.iot_task.cancel()
            try:
                await self.iot_task
            except asyncio.CancelledError:
                pass

        if self.websocket:
            await self.websocket.close()
|
xiaozhi_sdk/iot.py
CHANGED
|
@@ -1,50 +1,84 @@
|
|
|
1
|
-
import
|
|
1
|
+
import hashlib
|
|
2
|
+
import hmac
|
|
2
3
|
import json
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
|
|
7
|
+
import aiohttp
|
|
3
8
|
|
|
9
|
+
from xiaozhi_sdk import __version__
|
|
4
10
|
from xiaozhi_sdk.config import OTA_URL
|
|
5
11
|
|
|
6
|
-
|
|
12
|
+
# Board/agent identity reported to the OTA server.
BOARD_NAME = "xiaozhi-sdk"
BOARD_TYPE = "xiaozhi-sdk-box"
USER_AGENT = f"xiaozhi-sdk/{__version__}"

logger = logging.getLogger("xiaozhi_sdk")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OtaDevice:
    """OTA device management: device registration and activation checks.

    Attributes:
        ota_url (str): OTA server URL, without a trailing slash.
        mac_addr (str): device MAC address, sent as the ``Device-Id`` header.
        client_id (str): client ID, sent as the ``Client-Id`` header.
        serial_number (str): device serial number.
    """

    def __init__(self, mac_addr: str, client_id: str, ota_url: Optional[str] = None, serial_number: str = "") -> None:
        # Fall back to the configured default; strip "/" so paths can be appended.
        self.ota_url = (ota_url or OTA_URL).rstrip("/")

        self.mac_addr = mac_addr
        self.client_id = client_id
        self.serial_number = serial_number

    def _get_base_headers(self) -> Dict[str, str]:
        """Return the HTTP headers shared by every OTA request."""
        return {
            "user-agent": USER_AGENT,
            "Device-Id": self.mac_addr,
            "Client-Id": self.client_id,
            "Content-Type": "application/json",
        }

    async def activate_device(self) -> Dict[str, Any]:
        """Report this device to the OTA server and return its decoded JSON reply.

        Raises:
            RuntimeError: (an ``Exception`` subclass) when the server answers
                with a non-200 status; the message is the response body.
            aiohttp.ClientError: on network failures.
        """
        headers = self._get_base_headers()
        headers["serial-number"] = self.serial_number

        payload = {
            "application": {"version": __version__},
            "board": {
                "type": BOARD_TYPE,
                "name": BOARD_NAME,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.ota_url + "/", headers=headers, data=json.dumps(payload)) as response:
                if response.status != 200:
                    raise RuntimeError(await response.text())
                return await response.json()

    async def check_activate(self, challenge: str, license_key: str = "") -> bool:
        """Ask the OTA server whether activation for ``challenge`` has completed.

        The challenge is signed with HMAC-SHA256 keyed by ``license_key``.
        Returns True only on HTTP 200. Network errors are logged and reported
        as False, so a polling caller simply retries.
        """
        url = f"{self.ota_url}/activate"

        headers = self._get_base_headers()

        hmac_result = hmac.new(license_key.encode(), challenge.encode(), hashlib.sha256).hexdigest()

        payload = {"serial_number": self.serial_number, "challenge": challenge, "hmac": hmac_result}

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, data=json.dumps(payload)) as response:
                    is_ok = response.status == 200
        except aiohttp.ClientError as e:
            logger.debug("[IOT] activate check request failed: %s", e)
            return False

        if not is_ok:
            logger.debug("[IOT] wait for activate device...")
        return is_ok
|
xiaozhi_sdk/mcp.py
CHANGED
|
@@ -1,75 +1,171 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import copy
|
|
1
3
|
import json
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from typing import Any, Dict
|
|
2
7
|
|
|
8
|
+
import numpy as np
|
|
3
9
|
import requests
|
|
4
10
|
|
|
5
|
-
from xiaozhi_sdk.
|
|
6
|
-
|
|
11
|
+
from xiaozhi_sdk.utils.mcp_tool import _get_random_music_info
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("xiaozhi_sdk")

# JSON-RPC result template answering the server's MCP "initialize" request.
mcp_initialize_payload: Dict[str, Any] = dict(
    jsonrpc="2.0",
    id=1,
    result=dict(
        protocolVersion="2024-11-05",
        capabilities=dict(tools={}),
        serverInfo=dict(name="", version="0.0.1"),
    ),
)

# JSON-RPC result template answering the server's MCP "tools/list" request.
mcp_tools_payload: Dict[str, Any] = dict(
    jsonrpc="2.0",
    id=2,
    result=dict(tools=[]),
)
|
|
7
30
|
|
|
8
31
|
|
|
9
32
|
class McpTool(object):
    """Handler for MCP (JSON-RPC 2.0) requests received over the websocket.

    Requests arrive as ``{"type": "mcp", "payload": {...}}`` messages and are
    dispatched by ``mcp``. Tools are registered in ``mcp_tool_dict`` as
    ``name -> spec`` dicts: ``spec["tool_func"]`` is the callable returning
    ``(result, is_error)``, the optional ``spec["is_async"]`` marks it as a
    coroutine function, and every other key (e.g. ``name``) is advertised
    verbatim by ``tools/list``.

    NOTE(review): ``self.websocket`` must be set to an object with an async
    ``send`` before ``mcp`` runs, and ``play_custom_music`` relies on
    ``self.audio_opus`` / ``self.output_audio_queue``, which this class does not
    define — presumably a subclass provides them.
    """

    def __init__(self):
        self.session_id = ""
        # Vision ("explain") endpoint and token announced by the server in "initialize".
        self.explain_url = ""
        self.explain_token = ""
        self.websocket = None
        self.mcp_tool_dict = {}
        self.is_playing = False
        self.message_handler_callback = None

    def get_mcp_json(self, payload: dict):
        """Wrap a JSON-RPC ``payload`` in an MCP envelope for the current session."""
        return json.dumps({"session_id": self.session_id, "type": "mcp", "payload": payload})

    def _build_response(self, request_id: Any, content: str, is_error: bool = False):
        """Serialize a JSON-RPC tool result with a single text ``content`` item."""
        return self.get_mcp_json(
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "content": [{"type": "text", "text": content}],
                    "isError": is_error,
                },
            }
        )

    async def analyze_image(self, img_byte: bytes, question: str = "这张图片里有什么?"):
        """Send a JPEG to the vision endpoint and return ``(result, is_error)``.

        On a request or JSON-decoding failure, returns ``("网络异常", True)``.
        A dict reply with a truthy ``error`` key is returned with ``is_error=True``.
        """
        headers = {"Authorization": f"Bearer {self.explain_token}"}
        files = {"file": ("camera.jpg", img_byte, "image/jpeg")}
        payload = {"question": question}
        init_time = time.time()
        loop = asyncio.get_running_loop()
        try:
            # requests is blocking: run it in the default executor so the event
            # loop (and the websocket reader) keeps running during the upload.
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(self.explain_url, files=files, data=payload, headers=headers, timeout=8),
            )
            res_json = response.json()
        except Exception as e:
            logger.error("[MCP] 图片解析 error: %s", e)
            return "网络异常", True
        if isinstance(res_json, dict) and res_json.get("error"):
            return res_json, True
        logger.debug("[MCP] 图片解析耗时:%s", time.time() - init_time)
        return res_json, False

    async def play_custom_music(self, tool_func, arguments):
        """Fetch PCM from ``tool_func`` and queue it once current TTS playback ends.

        Not called from this class (the call in ``mcp_tool_call`` is commented out).
        """
        pcm_array, is_error = await tool_func(arguments)
        if is_error:
            logger.error("[MCP] tool_name: %s, error: %s", "play_custom_music", pcm_array)
            return
        # Don't interleave music with speech that is still playing.
        while self.is_playing:
            await asyncio.sleep(0.1)
        pcm_array = await self.audio_opus.change_sample_rate(np.array(pcm_array))
        self.output_audio_queue.extend(pcm_array)

    async def mcp_tool_call(self, mcp_json: dict):
        """Run the tool named in a ``tools/call`` request and build its reply.

        Any exception raised by the tool is logged and answered with an error
        result rather than propagated.
        """
        params = mcp_json["params"]
        tool_name = params["name"]
        mcp_tool = self.mcp_tool_dict[tool_name]
        # JSON-RPC allows "arguments" to be omitted.
        arguments = params.get("arguments") or {}
        try:
            if tool_name == "play_custom_music":
                # v1: hand the URL to the application callback.
                music_info = await _get_random_music_info(arguments["id_list"])
                if not music_info.get("url"):
                    tool_res, is_error = {"message": "播放失败"}, True
                else:
                    tool_res, is_error = {"message": "正在为你播放: {}".format(arguments["music_name"])}, False
                    data = {
                        "type": "music",
                        "state": "start",
                        "url": music_info["url"],
                        "text": arguments["music_name"],
                        "source": "sdk.mcp_music_tool",
                    }
                    await self.message_handler_callback(data)

                # v2 (disabled): push decoded audio to the output queue instead:
                # asyncio.create_task(self.play_custom_music(tool_func, arguments))

            elif mcp_tool.get("is_async"):
                tool_res, is_error = await mcp_tool["tool_func"](arguments)
            else:
                tool_res, is_error = mcp_tool["tool_func"](arguments)
        except Exception as e:
            logger.error("[MCP] tool_name: %s, error: %s", tool_name, e)
            return self._build_response(mcp_json["id"], "工具调用失败", True)

        if is_error:
            logger.error("[MCP] tool_name: %s, error: %s", tool_name, tool_res)
            return self._build_response(mcp_json["id"], "工具调用失败, {}".format(tool_res), True)

        if tool_name == "take_photo":
            # take_photo returns the image bytes; answer with the vision analysis.
            if "question" in arguments:
                tool_res, is_error = await self.analyze_image(tool_res, arguments["question"])
            else:
                tool_res, is_error = await self.analyze_image(tool_res)

        content = json.dumps(tool_res, ensure_ascii=False)
        return self._build_response(mcp_json["id"], content, is_error)

    async def mcp(self, data: dict):
        """Handle one MCP message and send the JSON-RPC reply where one is due.

        Handles ``initialize``, ``notifications/initialized``,
        ``notifications/cancelled``, ``tools/list`` and ``tools/call``; other
        methods are logged and ignored.
        """
        payload = data["payload"]
        method = payload["method"]

        if method == "initialize":
            # The vision capability is optional; without it take_photo just fails.
            params = payload.get("params") or {}
            vision = (params.get("capabilities") or {}).get("vision") or {}
            self.explain_url = vision.get("url", "")
            self.explain_token = vision.get("token", "")

            # Copy the template rather than mutating the shared module-level dict.
            response = copy.deepcopy(mcp_initialize_payload)
            response["id"] = payload["id"]
            await self.websocket.send(self.get_mcp_json(response))

        elif method == "notifications/initialized":
            pass

        elif method == "notifications/cancelled":
            logger.error("[MCP] 工具加载失败")

        elif method == "tools/list":
            tools = []
            for mcp_tool in self.mcp_tool_dict.values():
                if not mcp_tool.get("tool_func"):
                    logger.error("[MCP] Tool %s has no tool_func", mcp_tool["name"])
                    return
                # Advertise only the schema fields; tool_func/is_async are SDK-private
                # (and are not deep-copied, since callables may hold live objects).
                tools.append({key: value for key, value in mcp_tool.items() if key not in ("tool_func", "is_async")})

            # Build a fresh reply each time so repeated tools/list calls (e.g. after
            # a reconnect) don't accumulate duplicate tools in a shared list.
            response = copy.deepcopy(mcp_tools_payload)
            response["id"] = payload["id"]
            response["result"]["tools"] = tools
            await self.websocket.send(self.get_mcp_json(response))
            logger.debug("[MCP] 加载成功,当前可用工具列表为:%s", [tool["name"] for tool in tools])

        elif method == "tools/call":
            tool_name = payload["params"]["name"]

            if not self.mcp_tool_dict.get(tool_name):
                logger.warning("[MCP] Tool not found: %s", tool_name)
                return

            mcp_res = await self.mcp_tool_call(payload)
            await self.websocket.send(mcp_res)
            logger.debug("[MCP] Tool %s called", tool_name)
        else:
            logger.warning("[MCP] unknown method %s: %s", method, payload)
|
xiaozhi_sdk/opus.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
|
1
|
-
import
|
|
2
|
-
from xiaozhi_sdk import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
1
|
+
import math
|
|
3
2
|
|
|
4
|
-
# 设置 opus 库路径
|
|
5
|
-
os.environ["DYLD_LIBRARY_PATH"] = "/opt/homebrew/lib:" + os.environ.get("DYLD_LIBRARY_PATH", "")
|
|
6
|
-
os.environ["LIBRARY_PATH"] = "/opt/homebrew/lib:" + os.environ.get("LIBRARY_PATH", "")
|
|
7
3
|
import av
|
|
8
4
|
import numpy as np
|
|
9
5
|
import opuslib
|
|
10
6
|
|
|
7
|
+
from xiaozhi_sdk.config import INPUT_SERVER_AUDIO_SAMPLE_RATE
|
|
8
|
+
|
|
11
9
|
|
|
12
10
|
class AudioOpus:
|
|
13
11
|
|
|
@@ -33,11 +31,16 @@ class AudioOpus:
|
|
|
33
31
|
pcm_bytes = pcm_array.tobytes()
|
|
34
32
|
return self.opus_encoder.encode(pcm_bytes, 960)
|
|
35
33
|
|
|
36
|
-
|
|
34
|
+
@staticmethod
|
|
35
|
+
def to_n_960(samples) -> np.ndarray:
|
|
36
|
+
n = math.ceil(samples.shape[0] / 960)
|
|
37
|
+
arr_padded = np.pad(samples, (0, 960 * n - samples.shape[0]), mode="constant", constant_values=0)
|
|
38
|
+
return arr_padded.reshape(n, 960)
|
|
39
|
+
|
|
40
|
+
async def change_sample_rate(self, pcm_array) -> np.ndarray:
|
|
37
41
|
if self.sample_rate == INPUT_SERVER_AUDIO_SAMPLE_RATE:
|
|
38
|
-
return
|
|
42
|
+
return self.to_n_960(pcm_array)
|
|
39
43
|
|
|
40
|
-
c = int(self.sample_rate / INPUT_SERVER_AUDIO_SAMPLE_RATE)
|
|
41
44
|
frame = av.AudioFrame.from_ndarray(np.array(pcm_array).reshape(1, -1), format="s16", layout="mono")
|
|
42
45
|
frame.sample_rate = INPUT_SERVER_AUDIO_SAMPLE_RATE # Assuming input is 16kHz
|
|
43
46
|
resampled_frames = self.resampler.resample(frame)
|
|
@@ -49,10 +52,9 @@ class AudioOpus:
|
|
|
49
52
|
)
|
|
50
53
|
new_frame.sample_rate = self.sample_rate
|
|
51
54
|
new_samples = new_frame.to_ndarray().flatten()
|
|
52
|
-
|
|
53
|
-
return arr_padded.reshape(c, 960)
|
|
55
|
+
return self.to_n_960(new_samples)
|
|
54
56
|
|
|
55
|
-
async def opus_to_pcm(self, opus):
|
|
57
|
+
async def opus_to_pcm(self, opus) -> np.ndarray:
|
|
56
58
|
pcm_data = self.opus_decoder.decode(opus, 960)
|
|
57
59
|
pcm_array = np.frombuffer(pcm_data, dtype=np.int16)
|
|
58
60
|
samples = await self.change_sample_rate(pcm_array)
|