AstrBot 4.1.7__py3-none-any.whl → 4.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/config/default.py +33 -1
- astrbot/core/conversation_mgr.py +12 -4
- astrbot/core/db/__init__.py +5 -0
- astrbot/core/db/sqlite.py +8 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +25 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -1
- astrbot/core/pipeline/waking_check/stage.py +10 -5
- astrbot/core/platform/astr_message_event.py +9 -5
- astrbot/core/platform/sources/wecom/wecom_adapter.py +1 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +1 -0
- astrbot/core/provider/manager.py +2 -0
- astrbot/core/provider/sources/coze_api_client.py +314 -0
- astrbot/core/provider/sources/coze_source.py +635 -0
- astrbot/core/star/filter/command.py +26 -13
- astrbot/core/star/filter/command_group.py +15 -5
- astrbot/core/star/session_llm_manager.py +0 -4
- astrbot/core/utils/dify_api_client.py +44 -57
- astrbot/dashboard/routes/chat.py +70 -36
- astrbot/dashboard/routes/session_management.py +235 -78
- {astrbot-4.1.7.dist-info → astrbot-4.2.0.dist-info}/METADATA +1 -1
- {astrbot-4.1.7.dist-info → astrbot-4.2.0.dist-info}/RECORD +24 -22
- {astrbot-4.1.7.dist-info → astrbot-4.2.0.dist-info}/WHEEL +0 -0
- {astrbot-4.1.7.dist-info → astrbot-4.2.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.1.7.dist-info → astrbot-4.2.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/config/default.py
CHANGED
|
@@ -6,7 +6,7 @@ import os
|
|
|
6
6
|
|
|
7
7
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
8
8
|
|
|
9
|
-
VERSION = "4.1.7"
|
|
9
|
+
VERSION = "4.2.0"
|
|
10
10
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
|
11
11
|
|
|
12
12
|
# 默认配置
|
|
@@ -869,6 +869,18 @@ CONFIG_METADATA_2 = {
|
|
|
869
869
|
"timeout": 60,
|
|
870
870
|
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
|
871
871
|
},
|
|
872
|
+
"Coze": {
|
|
873
|
+
"id": "coze",
|
|
874
|
+
"provider": "coze",
|
|
875
|
+
"provider_type": "chat_completion",
|
|
876
|
+
"type": "coze",
|
|
877
|
+
"enable": True,
|
|
878
|
+
"coze_api_key": "",
|
|
879
|
+
"bot_id": "",
|
|
880
|
+
"coze_api_base": "https://api.coze.cn",
|
|
881
|
+
"timeout": 60,
|
|
882
|
+
"auto_save_history": True,
|
|
883
|
+
},
|
|
872
884
|
"阿里云百炼应用": {
|
|
873
885
|
"id": "dashscope",
|
|
874
886
|
"provider": "dashscope",
|
|
@@ -1735,6 +1747,26 @@ CONFIG_METADATA_2 = {
|
|
|
1735
1747
|
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
|
1736
1748
|
"obvious": True,
|
|
1737
1749
|
},
|
|
1750
|
+
"coze_api_key": {
|
|
1751
|
+
"description": "Coze API Key",
|
|
1752
|
+
"type": "string",
|
|
1753
|
+
"hint": "Coze API 密钥,用于访问 Coze 服务。",
|
|
1754
|
+
},
|
|
1755
|
+
"bot_id": {
|
|
1756
|
+
"description": "Bot ID",
|
|
1757
|
+
"type": "string",
|
|
1758
|
+
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
|
|
1759
|
+
},
|
|
1760
|
+
"coze_api_base": {
|
|
1761
|
+
"description": "API Base URL",
|
|
1762
|
+
"type": "string",
|
|
1763
|
+
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
|
|
1764
|
+
},
|
|
1765
|
+
"auto_save_history": {
|
|
1766
|
+
"description": "由 Coze 管理对话记录",
|
|
1767
|
+
"type": "bool",
|
|
1768
|
+
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
|
|
1769
|
+
},
|
|
1738
1770
|
},
|
|
1739
1771
|
},
|
|
1740
1772
|
"provider_settings": {
|
astrbot/core/conversation_mgr.py
CHANGED
|
@@ -87,17 +87,25 @@ class ConversationManager:
|
|
|
87
87
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
88
88
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
89
89
|
"""
|
|
90
|
-
f = False
|
|
91
90
|
if not conversation_id:
|
|
92
91
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
|
93
|
-
if conversation_id:
|
|
94
|
-
f = True
|
|
95
92
|
if conversation_id:
|
|
96
93
|
await self.db.delete_conversation(cid=conversation_id)
|
|
97
|
-
|
|
94
|
+
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
|
95
|
+
if curr_cid == conversation_id:
|
|
98
96
|
self.session_conversations.pop(unified_msg_origin, None)
|
|
99
97
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
|
100
98
|
|
|
99
|
+
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
|
100
|
+
"""删除会话的所有对话
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
104
|
+
"""
|
|
105
|
+
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
|
106
|
+
self.session_conversations.pop(unified_msg_origin, None)
|
|
107
|
+
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
|
108
|
+
|
|
101
109
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
|
102
110
|
"""获取会话当前的对话 ID
|
|
103
111
|
|
astrbot/core/db/__init__.py
CHANGED
|
@@ -154,6 +154,11 @@ class BaseDatabase(abc.ABC):
|
|
|
154
154
|
"""Delete a conversation by its ID."""
|
|
155
155
|
...
|
|
156
156
|
|
|
157
|
+
@abc.abstractmethod
|
|
158
|
+
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
|
159
|
+
"""Delete all conversations for a specific user."""
|
|
160
|
+
...
|
|
161
|
+
|
|
157
162
|
@abc.abstractmethod
|
|
158
163
|
async def insert_platform_message_history(
|
|
159
164
|
self,
|
astrbot/core/db/sqlite.py
CHANGED
|
@@ -249,6 +249,14 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
249
249
|
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
|
250
250
|
)
|
|
251
251
|
|
|
252
|
+
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
|
253
|
+
async with self.get_db() as session:
|
|
254
|
+
session: AsyncSession
|
|
255
|
+
async with session.begin():
|
|
256
|
+
await session.execute(
|
|
257
|
+
delete(ConversationV2).where(ConversationV2.user_id == user_id)
|
|
258
|
+
)
|
|
259
|
+
|
|
252
260
|
async def insert_platform_message_history(
|
|
253
261
|
self,
|
|
254
262
|
platform_id,
|
|
astrbot/core/pipeline/process_stage/method/llm_request.py
CHANGED
|
@@ -291,13 +291,6 @@ async def run_agent(
|
|
|
291
291
|
else:
|
|
292
292
|
astr_event.set_result(MessageEventResult().message(err_msg))
|
|
293
293
|
return
|
|
294
|
-
asyncio.create_task(
|
|
295
|
-
Metric.upload(
|
|
296
|
-
llm_tick=1,
|
|
297
|
-
model_name=agent_runner.provider.get_model(),
|
|
298
|
-
provider_type=agent_runner.provider.meta().type,
|
|
299
|
-
)
|
|
300
|
-
)
|
|
301
294
|
|
|
302
295
|
|
|
303
296
|
class LLMRequestSubStage(Stage):
|
|
@@ -524,6 +517,14 @@ class LLMRequestSubStage(Stage):
|
|
|
524
517
|
if event.get_platform_name() == "webchat":
|
|
525
518
|
asyncio.create_task(self._handle_webchat(event, req, provider))
|
|
526
519
|
|
|
520
|
+
asyncio.create_task(
|
|
521
|
+
Metric.upload(
|
|
522
|
+
llm_tick=1,
|
|
523
|
+
model_name=agent_runner.provider.get_model(),
|
|
524
|
+
provider_type=agent_runner.provider.meta().type,
|
|
525
|
+
)
|
|
526
|
+
)
|
|
527
|
+
|
|
527
528
|
async def _handle_webchat(
|
|
528
529
|
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
|
529
530
|
):
|
|
@@ -536,7 +537,23 @@ class LLMRequestSubStage(Stage):
|
|
|
536
537
|
latest_pair = messages[-2:]
|
|
537
538
|
if not latest_pair:
|
|
538
539
|
return
|
|
539
|
-
|
|
540
|
+
content = latest_pair[0].get("content", "")
|
|
541
|
+
if isinstance(content, list):
|
|
542
|
+
# 多模态
|
|
543
|
+
text_parts = []
|
|
544
|
+
for item in content:
|
|
545
|
+
if isinstance(item, dict):
|
|
546
|
+
if item.get("type") == "text":
|
|
547
|
+
text_parts.append(item.get("text", ""))
|
|
548
|
+
elif item.get("type") == "image":
|
|
549
|
+
text_parts.append("[图片]")
|
|
550
|
+
elif isinstance(item, str):
|
|
551
|
+
text_parts.append(item)
|
|
552
|
+
cleaned_text = "User: " + " ".join(text_parts).strip()
|
|
553
|
+
elif isinstance(content, str):
|
|
554
|
+
cleaned_text = "User: " + content.strip()
|
|
555
|
+
else:
|
|
556
|
+
return
|
|
540
557
|
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
|
541
558
|
llm_resp = await prov.text_chat(
|
|
542
559
|
system_prompt="You are expert in summarizing user's query.",
|
|
astrbot/core/pipeline/session_status_check/stage.py
CHANGED
|
@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
|
|
|
11
11
|
"""检查会话是否整体启用"""
|
|
12
12
|
|
|
13
13
|
async def initialize(self, ctx: PipelineContext) -> None:
|
|
14
|
-
|
|
14
|
+
self.ctx = ctx
|
|
15
|
+
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
|
|
15
16
|
|
|
16
17
|
async def process(
|
|
17
18
|
self, event: AstrMessageEvent
|
|
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
|
|
|
19
20
|
# 检查会话是否整体启用
|
|
20
21
|
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
|
21
22
|
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
|
23
|
+
|
|
24
|
+
# workaround for #2309
|
|
25
|
+
conv_id = await self.conv_mgr.get_curr_conversation_id(
|
|
26
|
+
event.unified_msg_origin
|
|
27
|
+
)
|
|
28
|
+
if not conv_id:
|
|
29
|
+
await self.conv_mgr.new_conversation(
|
|
30
|
+
event.unified_msg_origin, platform_id=event.get_platform_id()
|
|
31
|
+
)
|
|
32
|
+
|
|
22
33
|
event.stop_event()
|
|
astrbot/core/pipeline/waking_check/stage.py
CHANGED
|
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
|
|
|
5
5
|
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
|
6
6
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
7
7
|
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
|
8
|
+
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
|
8
9
|
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
|
9
10
|
from astrbot.core.star.star import star_map
|
|
10
11
|
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
|
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
|
|
|
170
171
|
is_wake = True
|
|
171
172
|
event.is_wake = True
|
|
172
173
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
)
|
|
174
|
+
is_group_cmd_handler = any(
|
|
175
|
+
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
|
176
|
+
)
|
|
177
|
+
if not is_group_cmd_handler:
|
|
178
|
+
activated_handlers.append(handler)
|
|
179
|
+
if "parsed_params" in event.get_extra(default={}):
|
|
180
|
+
handlers_parsed_params[handler.handler_full_name] = (
|
|
181
|
+
event.get_extra("parsed_params")
|
|
182
|
+
)
|
|
178
183
|
|
|
179
184
|
event._extras.pop("parsed_params", None)
|
|
180
185
|
|
|
astrbot/core/platform/astr_message_event.py
CHANGED
|
@@ -4,7 +4,7 @@ import re
|
|
|
4
4
|
import hashlib
|
|
5
5
|
import uuid
|
|
6
6
|
|
|
7
|
-
from typing import List, Union, Optional, AsyncGenerator
|
|
7
|
+
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
|
8
8
|
|
|
9
9
|
from astrbot import logger
|
|
10
10
|
from astrbot.core.db.po import Conversation
|
|
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
|
|
|
26
26
|
from .platform_metadata import PlatformMetadata
|
|
27
27
|
from .message_session import MessageSession, MessageSesion # noqa
|
|
28
28
|
|
|
29
|
+
_VT = TypeVar("_VT")
|
|
30
|
+
|
|
29
31
|
|
|
30
32
|
class AstrMessageEvent(abc.ABC):
|
|
31
33
|
def __init__(
|
|
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
|
|
49
51
|
"""是否唤醒(是否通过 WakingStage)"""
|
|
50
52
|
self.is_at_or_wake_command = False
|
|
51
53
|
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
|
52
|
-
self._extras = {}
|
|
54
|
+
self._extras: dict[str, Any] = {}
|
|
53
55
|
self.session = MessageSesion(
|
|
54
56
|
platform_name=platform_meta.id,
|
|
55
57
|
message_type=message_obj.type,
|
|
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
|
|
|
57
59
|
)
|
|
58
60
|
self.unified_msg_origin = str(self.session)
|
|
59
61
|
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
|
60
|
-
self._result: MessageEventResult = None
|
|
62
|
+
self._result: MessageEventResult | None = None
|
|
61
63
|
"""消息事件的结果"""
|
|
62
64
|
|
|
63
65
|
self._has_send_oper = False
|
|
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
|
|
|
173
175
|
"""
|
|
174
176
|
self._extras[key] = value
|
|
175
177
|
|
|
176
|
-
def get_extra(
|
|
178
|
+
def get_extra(
|
|
179
|
+
self, key: str | None = None, default: _VT = None
|
|
180
|
+
) -> dict[str, Any] | _VT:
|
|
177
181
|
"""
|
|
178
182
|
获取额外的信息。
|
|
179
183
|
"""
|
|
180
184
|
if key is None:
|
|
181
185
|
return self._extras
|
|
182
|
-
return self._extras.get(key,
|
|
186
|
+
return self._extras.get(key, default)
|
|
183
187
|
|
|
184
188
|
def clear_extra(self):
|
|
185
189
|
"""
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -234,6 +234,8 @@ class ProviderManager:
|
|
|
234
234
|
)
|
|
235
235
|
case "dify":
|
|
236
236
|
from .sources.dify_source import ProviderDify as ProviderDify
|
|
237
|
+
case "coze":
|
|
238
|
+
from .sources.coze_source import ProviderCoze as ProviderCoze
|
|
237
239
|
case "dashscope":
|
|
238
240
|
from .sources.dashscope_source import (
|
|
239
241
|
ProviderDashscope as ProviderDashscope,
|
|
astrbot/core/provider/sources/coze_api_client.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import asyncio
|
|
3
|
+
import aiohttp
|
|
4
|
+
import io
|
|
5
|
+
from typing import Dict, List, Any, AsyncGenerator
|
|
6
|
+
from astrbot.core import logger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CozeAPIClient:
|
|
10
|
+
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
|
11
|
+
self.api_key = api_key
|
|
12
|
+
self.api_base = api_base
|
|
13
|
+
self.session = None
|
|
14
|
+
|
|
15
|
+
async def _ensure_session(self):
|
|
16
|
+
"""确保HTTP session存在"""
|
|
17
|
+
if self.session is None:
|
|
18
|
+
connector = aiohttp.TCPConnector(
|
|
19
|
+
ssl=False if self.api_base.startswith("http://") else True,
|
|
20
|
+
limit=100,
|
|
21
|
+
limit_per_host=30,
|
|
22
|
+
keepalive_timeout=30,
|
|
23
|
+
enable_cleanup_closed=True,
|
|
24
|
+
)
|
|
25
|
+
timeout = aiohttp.ClientTimeout(
|
|
26
|
+
total=120, # 默认超时时间
|
|
27
|
+
connect=30,
|
|
28
|
+
sock_read=120,
|
|
29
|
+
)
|
|
30
|
+
headers = {
|
|
31
|
+
"Authorization": f"Bearer {self.api_key}",
|
|
32
|
+
"Accept": "text/event-stream",
|
|
33
|
+
}
|
|
34
|
+
self.session = aiohttp.ClientSession(
|
|
35
|
+
headers=headers, timeout=timeout, connector=connector
|
|
36
|
+
)
|
|
37
|
+
return self.session
|
|
38
|
+
|
|
39
|
+
async def upload_file(
|
|
40
|
+
self,
|
|
41
|
+
file_data: bytes,
|
|
42
|
+
) -> str:
|
|
43
|
+
"""上传文件到 Coze 并返回 file_id
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
file_data (bytes): 文件的二进制数据
|
|
47
|
+
Returns:
|
|
48
|
+
str: 上传成功后返回的 file_id
|
|
49
|
+
"""
|
|
50
|
+
session = await self._ensure_session()
|
|
51
|
+
url = f"{self.api_base}/v1/files/upload"
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
file_io = io.BytesIO(file_data)
|
|
55
|
+
async with session.post(
|
|
56
|
+
url,
|
|
57
|
+
data={
|
|
58
|
+
"file": file_io,
|
|
59
|
+
},
|
|
60
|
+
timeout=aiohttp.ClientTimeout(total=60),
|
|
61
|
+
) as response:
|
|
62
|
+
if response.status == 401:
|
|
63
|
+
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
|
64
|
+
|
|
65
|
+
response_text = await response.text()
|
|
66
|
+
logger.debug(
|
|
67
|
+
f"文件上传响应状态: {response.status}, 内容: {response_text}"
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
if response.status != 200:
|
|
71
|
+
raise Exception(
|
|
72
|
+
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
try:
|
|
76
|
+
result = await response.json()
|
|
77
|
+
except json.JSONDecodeError:
|
|
78
|
+
raise Exception(f"文件上传响应解析失败: {response_text}")
|
|
79
|
+
|
|
80
|
+
if result.get("code") != 0:
|
|
81
|
+
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
|
82
|
+
|
|
83
|
+
file_id = result["data"]["id"]
|
|
84
|
+
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
|
85
|
+
return file_id
|
|
86
|
+
|
|
87
|
+
except asyncio.TimeoutError:
|
|
88
|
+
logger.error("文件上传超时")
|
|
89
|
+
raise Exception("文件上传超时")
|
|
90
|
+
except Exception as e:
|
|
91
|
+
logger.error(f"文件上传失败: {str(e)}")
|
|
92
|
+
raise Exception(f"文件上传失败: {str(e)}")
|
|
93
|
+
|
|
94
|
+
async def download_image(self, image_url: str) -> bytes:
|
|
95
|
+
"""下载图片并返回字节数据
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
image_url (str): 图片的URL
|
|
99
|
+
Returns:
|
|
100
|
+
bytes: 图片的二进制数据
|
|
101
|
+
"""
|
|
102
|
+
session = await self._ensure_session()
|
|
103
|
+
|
|
104
|
+
try:
|
|
105
|
+
async with session.get(image_url) as response:
|
|
106
|
+
if response.status != 200:
|
|
107
|
+
raise Exception(f"下载图片失败,状态码: {response.status}")
|
|
108
|
+
|
|
109
|
+
image_data = await response.read()
|
|
110
|
+
return image_data
|
|
111
|
+
|
|
112
|
+
except Exception as e:
|
|
113
|
+
logger.error(f"下载图片失败 {image_url}: {str(e)}")
|
|
114
|
+
raise Exception(f"下载图片失败: {str(e)}")
|
|
115
|
+
|
|
116
|
+
async def chat_messages(
|
|
117
|
+
self,
|
|
118
|
+
bot_id: str,
|
|
119
|
+
user_id: str,
|
|
120
|
+
additional_messages: List[Dict] | None = None,
|
|
121
|
+
conversation_id: str | None = None,
|
|
122
|
+
auto_save_history: bool = True,
|
|
123
|
+
stream: bool = True,
|
|
124
|
+
timeout: float = 120,
|
|
125
|
+
) -> AsyncGenerator[Dict[str, Any], None]:
|
|
126
|
+
"""发送聊天消息并返回流式响应
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
bot_id: Bot ID
|
|
130
|
+
user_id: 用户ID
|
|
131
|
+
additional_messages: 额外消息列表
|
|
132
|
+
conversation_id: 会话ID
|
|
133
|
+
auto_save_history: 是否自动保存历史
|
|
134
|
+
stream: 是否流式响应
|
|
135
|
+
timeout: 超时时间
|
|
136
|
+
"""
|
|
137
|
+
session = await self._ensure_session()
|
|
138
|
+
url = f"{self.api_base}/v3/chat"
|
|
139
|
+
|
|
140
|
+
payload = {
|
|
141
|
+
"bot_id": bot_id,
|
|
142
|
+
"user_id": user_id,
|
|
143
|
+
"stream": stream,
|
|
144
|
+
"auto_save_history": auto_save_history,
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
if additional_messages:
|
|
148
|
+
payload["additional_messages"] = additional_messages
|
|
149
|
+
|
|
150
|
+
params = {}
|
|
151
|
+
if conversation_id:
|
|
152
|
+
params["conversation_id"] = conversation_id
|
|
153
|
+
|
|
154
|
+
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
|
|
155
|
+
|
|
156
|
+
try:
|
|
157
|
+
async with session.post(
|
|
158
|
+
url,
|
|
159
|
+
json=payload,
|
|
160
|
+
params=params,
|
|
161
|
+
timeout=aiohttp.ClientTimeout(total=timeout),
|
|
162
|
+
) as response:
|
|
163
|
+
if response.status == 401:
|
|
164
|
+
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
|
165
|
+
|
|
166
|
+
if response.status != 200:
|
|
167
|
+
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
|
168
|
+
|
|
169
|
+
# SSE
|
|
170
|
+
buffer = ""
|
|
171
|
+
event_type = None
|
|
172
|
+
event_data = None
|
|
173
|
+
|
|
174
|
+
async for chunk in response.content:
|
|
175
|
+
if chunk:
|
|
176
|
+
buffer += chunk.decode("utf-8", errors="ignore")
|
|
177
|
+
lines = buffer.split("\n")
|
|
178
|
+
buffer = lines[-1]
|
|
179
|
+
|
|
180
|
+
for line in lines[:-1]:
|
|
181
|
+
line = line.strip()
|
|
182
|
+
|
|
183
|
+
if not line:
|
|
184
|
+
if event_type and event_data:
|
|
185
|
+
yield {"event": event_type, "data": event_data}
|
|
186
|
+
event_type = None
|
|
187
|
+
event_data = None
|
|
188
|
+
elif line.startswith("event:"):
|
|
189
|
+
event_type = line[6:].strip()
|
|
190
|
+
elif line.startswith("data:"):
|
|
191
|
+
data_str = line[5:].strip()
|
|
192
|
+
if data_str and data_str != "[DONE]":
|
|
193
|
+
try:
|
|
194
|
+
event_data = json.loads(data_str)
|
|
195
|
+
except json.JSONDecodeError:
|
|
196
|
+
event_data = {"content": data_str}
|
|
197
|
+
|
|
198
|
+
except asyncio.TimeoutError:
|
|
199
|
+
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
|
200
|
+
except Exception as e:
|
|
201
|
+
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
|
202
|
+
|
|
203
|
+
async def clear_context(self, conversation_id: str):
|
|
204
|
+
"""清空会话上下文
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
conversation_id: 会话ID
|
|
208
|
+
Returns:
|
|
209
|
+
dict: API响应结果
|
|
210
|
+
"""
|
|
211
|
+
session = await self._ensure_session()
|
|
212
|
+
url = f"{self.api_base}/v3/conversation/message/clear_context"
|
|
213
|
+
payload = {"conversation_id": conversation_id}
|
|
214
|
+
|
|
215
|
+
try:
|
|
216
|
+
async with session.post(url, json=payload) as response:
|
|
217
|
+
response_text = await response.text()
|
|
218
|
+
|
|
219
|
+
if response.status == 401:
|
|
220
|
+
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
|
221
|
+
|
|
222
|
+
if response.status != 200:
|
|
223
|
+
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
|
224
|
+
|
|
225
|
+
try:
|
|
226
|
+
return json.loads(response_text)
|
|
227
|
+
except json.JSONDecodeError:
|
|
228
|
+
raise Exception("Coze API 返回非JSON格式")
|
|
229
|
+
|
|
230
|
+
except asyncio.TimeoutError:
|
|
231
|
+
raise Exception("Coze API 请求超时")
|
|
232
|
+
except aiohttp.ClientError as e:
|
|
233
|
+
raise Exception(f"Coze API 请求失败: {str(e)}")
|
|
234
|
+
|
|
235
|
+
async def get_message_list(
|
|
236
|
+
self,
|
|
237
|
+
conversation_id: str,
|
|
238
|
+
order: str = "desc",
|
|
239
|
+
limit: int = 10,
|
|
240
|
+
offset: int = 0,
|
|
241
|
+
):
|
|
242
|
+
"""获取消息列表
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
conversation_id: 会话ID
|
|
246
|
+
order: 排序方式 (asc/desc)
|
|
247
|
+
limit: 限制数量
|
|
248
|
+
offset: 偏移量
|
|
249
|
+
Returns:
|
|
250
|
+
dict: API响应结果
|
|
251
|
+
"""
|
|
252
|
+
session = await self._ensure_session()
|
|
253
|
+
url = f"{self.api_base}/v3/conversation/message/list"
|
|
254
|
+
params = {
|
|
255
|
+
"conversation_id": conversation_id,
|
|
256
|
+
"order": order,
|
|
257
|
+
"limit": limit,
|
|
258
|
+
"offset": offset,
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
try:
|
|
262
|
+
async with session.get(url, params=params) as response:
|
|
263
|
+
response.raise_for_status()
|
|
264
|
+
return await response.json()
|
|
265
|
+
|
|
266
|
+
except Exception as e:
|
|
267
|
+
logger.error(f"获取Coze消息列表失败: {str(e)}")
|
|
268
|
+
raise Exception(f"获取Coze消息列表失败: {str(e)}")
|
|
269
|
+
|
|
270
|
+
async def close(self):
|
|
271
|
+
"""关闭会话"""
|
|
272
|
+
if self.session:
|
|
273
|
+
await self.session.close()
|
|
274
|
+
self.session = None
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
if __name__ == "__main__":
|
|
278
|
+
import os
|
|
279
|
+
import asyncio
|
|
280
|
+
|
|
281
|
+
async def test_coze_api_client():
|
|
282
|
+
api_key = os.getenv("COZE_API_KEY", "")
|
|
283
|
+
bot_id = os.getenv("COZE_BOT_ID", "")
|
|
284
|
+
client = CozeAPIClient(api_key=api_key)
|
|
285
|
+
|
|
286
|
+
try:
|
|
287
|
+
with open("README.md", "rb") as f:
|
|
288
|
+
file_data = f.read()
|
|
289
|
+
file_id = await client.upload_file(file_data)
|
|
290
|
+
print(f"Uploaded file_id: {file_id}")
|
|
291
|
+
async for event in client.chat_messages(
|
|
292
|
+
bot_id=bot_id,
|
|
293
|
+
user_id="test_user",
|
|
294
|
+
additional_messages=[
|
|
295
|
+
{
|
|
296
|
+
"role": "user",
|
|
297
|
+
"content": json.dumps(
|
|
298
|
+
[
|
|
299
|
+
{"type": "text", "text": "这是什么"},
|
|
300
|
+
{"type": "file", "file_id": file_id},
|
|
301
|
+
],
|
|
302
|
+
ensure_ascii=False,
|
|
303
|
+
),
|
|
304
|
+
"content_type": "object_string",
|
|
305
|
+
},
|
|
306
|
+
],
|
|
307
|
+
stream=True,
|
|
308
|
+
):
|
|
309
|
+
print(f"Event: {event}")
|
|
310
|
+
|
|
311
|
+
finally:
|
|
312
|
+
await client.close()
|
|
313
|
+
|
|
314
|
+
asyncio.run(test_coze_api_client())
|