beswarm 0.2.59__py3-none-any.whl → 0.2.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/agents/planact.py +19 -13
- beswarm/aient/aient/core/models.py +0 -1
- beswarm/aient/aient/core/request.py +10 -7
- beswarm/aient/aient/core/response.py +54 -48
- beswarm/aient/aient/models/chatgpt.py +53 -11
- beswarm/taskmanager.py +4 -3
- beswarm/tools/completion.py +1 -1
- {beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/METADATA +1 -1
- {beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/RECORD +11 -11
- {beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/WHEEL +0 -0
- {beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/top_level.txt +0 -0
beswarm/agents/planact.py
CHANGED
@@ -4,6 +4,7 @@ import copy
 import json
 import difflib
 import asyncio
+import tomllib
 import platform
 from pathlib import Path
 from datetime import datetime
@@ -11,6 +12,7 @@ from typing import List, Dict, Union

 from ..broker import MessageBroker
 from ..aient.aient.models import chatgpt
+from ..aient.aient.models.chatgpt import ModelNotFoundError, TaskComplete
 from ..aient.aient.plugins import get_function_call_list, registry
 from ..prompt import worker_system_prompt, instruction_system_prompt
 from ..utils import extract_xml_content, get_current_screen_image_message, replace_xml_content, register_mcp_tools, setup_logger
@@ -112,12 +114,6 @@ class InstructionAgent(BaseAgent):
     async def handle_message(self, message: Dict):
         """Receives a worker response, generates the next instruction, and publishes it."""

-        if len(message["conversation"]) > 1 and message["conversation"][-2]["role"] == "user" \
-        and "<task_complete_message>" in message["conversation"][-2]["content"]:
-            task_complete_message = extract_xml_content(message["conversation"][-2]["content"], "task_complete_message")
-            self.broker.publish({"status": "finished", "result": task_complete_message}, self.status_topic)
-            return
-
         instruction_prompt = "".join([
             "</work_agent_conversation_end>\n\n",
             f"任务目标: {self.goal}\n\n",
@@ -137,10 +133,11 @@ class InstructionAgent(BaseAgent):
         if "find_and_click_element" in json.dumps(self.tools_json):
             instruction_prompt = await get_current_screen_image_message(instruction_prompt)

-
+        try:
+            raw_response = await self.agent.ask_async(instruction_prompt)
+        except ModelNotFoundError as e:
+            raise Exception(str(e))

-        if "HTTP Error', 'status_code': 404" in raw_response:
-            raise Exception(f"Model: {self.config['engine']} not found!")
         if "'status_code': 413" in raw_response or \
         "'status_code': 400" in raw_response:
             self.broker.publish({"status": "error", "result": "The request body is too long, please try again."}, self.status_topic)
@@ -205,14 +202,16 @@ class WorkerAgent(BaseAgent):
         instruction = message["instruction"]
         if "find_and_click_element" in json.dumps(self.tools_json):
             instruction = await get_current_screen_image_message(instruction)
-
+
+        try:
+            response = await self.agent.ask_async(instruction)
+        except TaskComplete as e:
+            self.broker.publish({"status": "finished", "result": e.completion_message}, self.status_topic)
+            return

         if response.strip() == '':
             self.logger.error("\n❌ 工作智能体回复为空,请重新生成指令。")
             self.broker.publish(message, self.error_topic)
-        elif "HTTP Error', 'status_code': 524" in response:
-            self.logger.error("\n❌ 工作智能体回复超时 100 秒,请重新生成指令。")
-            self.broker.publish(message, self.error_topic)
         else:
             self.broker.publish({"status": "new_message", "result": "\n✅ 工作智能体:\n" + response}, self.status_topic)
             self.broker.publish({
@@ -258,6 +257,12 @@ class BrokerWorker:
         self.logger = setup_logger(f"task_{self.work_dir.name}", log_file_path)
         self.logger.info(f"Logger for task '{self.goal}' initialized. Log file: {log_file_path}")

+        beswarm_dir = Path(__file__).parent.parent.parent
+        with open(beswarm_dir / "pyproject.toml", "rb") as f:
+            pyproject_data = tomllib.load(f)
+        version = pyproject_data["project"]["version"]
+        self.logger.info(f"beswarm version: {version}")
+
     async def _configure_tools(self):
         mcp_list = [item for item in self.tools if isinstance(item, dict)]
         if mcp_list:
@@ -275,6 +280,7 @@ class BrokerWorker:
     def _task_status_subscriber(self, message: Dict):
         """Subscriber for task status changes."""
         if message.get("status") == "finished":
+            self.logger.info("Task completed: " + message.get("result"))
             self.final_result = message.get("result")
             self.task_completion_event.set()

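The planact.py changes replace string sniffing (scanning the conversation for `<task_complete_message>` tags and matching `'status_code': 404` substrings in responses) with typed exceptions raised by the model layer. A minimal runnable sketch of that signal pattern, with hypothetical stand-ins (`ask_async`, `handle_message`) for the real agent methods:

```python
import asyncio

class TaskComplete(Exception):
    """Signal that the task finished, carrying the final message (as in chatgpt.py)."""
    def __init__(self, message: str):
        self.completion_message = message
        super().__init__(f"Task completed with message: {message}")

async def ask_async(instruction: str) -> str:
    # Hypothetical stand-in for chatgpt.ask_async: raises TaskComplete
    # when the model invokes its task_complete tool.
    if "finish" in instruction:
        raise TaskComplete("All steps done.")
    return "still working"

async def handle_message(instruction: str) -> dict:
    # Mirrors WorkerAgent.handle_message: completion is caught as an
    # exception instead of being parsed out of the response text.
    try:
        response = await ask_async(instruction)
    except TaskComplete as e:
        return {"status": "finished", "result": e.completion_message}
    return {"status": "new_message", "result": response}

print(asyncio.run(handle_message("please finish")))
```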
beswarm/aient/aient/core/request.py
CHANGED
@@ -2,6 +2,7 @@ import re
 import json
 import httpx
 import base64
+import asyncio
 import urllib.parse
 from io import IOBase
 from typing import Tuple
@@ -336,11 +337,11 @@ def create_jwt(client_email, private_key):
     segments.append(base64.urlsafe_b64encode(signature).rstrip(b'='))
     return b'.'.join(segments).decode()

-def get_access_token(client_email, private_key):
-    jwt = create_jwt(client_email, private_key)
+async def get_access_token(client_email, private_key):
+    jwt = await asyncio.to_thread(create_jwt, client_email, private_key)

-    with httpx.Client() as client:
-        response = client.post(
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
             "https://oauth2.googleapis.com/token",
             data={
                 "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
@@ -356,7 +357,7 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
         'Content-Type': 'application/json'
     }
     if provider.get("client_email") and provider.get("private_key"):
-        access_token = get_access_token(provider['client_email'], provider['private_key'])
+        access_token = await get_access_token(provider['client_email'], provider['private_key'])
         headers['Authorization'] = f"Bearer {access_token}"
     if provider.get("project_id"):
         project_id = provider.get("project_id")
@@ -596,7 +597,7 @@ async def get_vertex_claude_payload(request, engine, provider, api_key=None):
         'Content-Type': 'application/json',
     }
     if provider.get("client_email") and provider.get("private_key"):
-        access_token = get_access_token(provider['client_email'], provider['private_key'])
+        access_token = await get_access_token(provider['client_email'], provider['private_key'])
         headers['Authorization'] = f"Bearer {access_token}"
     if provider.get("project_id"):
         project_id = provider.get("project_id")
@@ -972,7 +973,9 @@ async def get_aws_payload(request, engine, provider, api_key=None):

     if provider.get("aws_access_key") and provider.get("aws_secret_key"):
         ACCEPT_HEADER = "application/vnd.amazon.bedrock.payload+json" # 指定接受 Bedrock 流格式
-        amz_date, payload_hash, authorization_header = get_signature(payload, original_model, provider.get("aws_access_key"), provider.get("aws_secret_key"), AWS_REGION, HOST, CONTENT_TYPE, ACCEPT_HEADER)
+        amz_date, payload_hash, authorization_header = await asyncio.to_thread(
+            get_signature, payload, original_model, provider.get("aws_access_key"), provider.get("aws_secret_key"), AWS_REGION, HOST, CONTENT_TYPE, ACCEPT_HEADER
+        )
         headers = {
             'Accept': ACCEPT_HEADER,
             'Content-Type': CONTENT_TYPE,
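The get_access_token rewrite moves CPU-bound work (JWT signing, AWS request signing) off the event loop with `asyncio.to_thread` and swaps the synchronous HTTP client for `httpx.AsyncClient`. A minimal sketch of the pattern, assuming a placeholder `create_jwt` in place of the real RSA signer:

```python
import asyncio
import httpx

def create_jwt(client_email: str, private_key: str) -> str:
    # Placeholder for the real RSA-signed JWT builder; signing is
    # CPU-bound, which is why it is pushed to a worker thread below.
    return "header.payload.signature"

async def get_access_token(client_email: str, private_key: str) -> str:
    # Sign in a thread so the event loop keeps serving other requests.
    jwt = await asyncio.to_thread(create_jwt, client_email, private_key)
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://oauth2.googleapis.com/token",
            data={
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "assertion": jwt,
            },
        )
        response.raise_for_status()
        return response.json()["access_token"]
```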
beswarm/aient/aient/core/response.py
CHANGED
@@ -3,6 +3,7 @@ import json
 import random
 import string
 import base64
+import asyncio
 from datetime import datetime

 from .log_config import logger
@@ -14,19 +15,19 @@ async def check_response(response, error_log):
         error_message = await response.aread()
         error_str = error_message.decode('utf-8', errors='replace')
         try:
-            error_json = json.loads(error_str)
+            error_json = await asyncio.to_thread(json.loads, error_str)
         except json.JSONDecodeError:
             error_json = error_str
         return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
     return None

-def gemini_json_poccess(response_str):
+async def gemini_json_poccess(response_str):
     promptTokenCount = 0
     candidatesTokenCount = 0
     totalTokenCount = 0
     image_base64 = None

-    response_json = json.loads(response_str)
+    response_json = await asyncio.to_thread(json.loads, response_str)
     json_data = safe_get(response_json, "candidates", 0, "content", default=None)
     finishReason = safe_get(response_json, "candidates", 0 , "finishReason", default=None)
     if finishReason:
@@ -53,9 +54,9 @@ def gemini_json_poccess(response_str):

     return is_thinking, reasoning_content, content, image_base64, function_call_name, function_full_response, finishReason, blockReason, promptTokenCount, candidatesTokenCount, totalTokenCount

-async def fetch_gemini_response_stream(client, url, headers, payload, model):
+async def fetch_gemini_response_stream(client, url, headers, payload, model, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_gemini_response_stream")
         if error_message:
             yield error_message
@@ -75,7 +76,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
             if line.startswith("data: "):
                 parts_json = line.lstrip("data: ").strip()
                 try:
-                    json.loads(parts_json)
+                    await asyncio.to_thread(json.loads, parts_json)
                 except json.JSONDecodeError:
                     logger.error(f"JSON decode error: {parts_json}")
                     continue
@@ -83,12 +84,12 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
                 parts_json += line
                 parts_json = parts_json.lstrip("[,")
                 try:
-                    json.loads(parts_json)
+                    await asyncio.to_thread(json.loads, parts_json)
                 except json.JSONDecodeError:
                     continue

             # https://ai.google.dev/api/generate-content?hl=zh-cn#FinishReason
-            is_thinking, reasoning_content, content, image_base64, function_call_name, function_full_response, finishReason, blockReason, promptTokenCount, candidatesTokenCount, totalTokenCount = gemini_json_poccess(parts_json)
+            is_thinking, reasoning_content, content, image_base64, function_call_name, function_full_response, finishReason, blockReason, promptTokenCount, candidatesTokenCount, totalTokenCount = await gemini_json_poccess(parts_json)

             if is_thinking:
                 sse_string = await generate_sse_response(timestamp, model, reasoning_content=reasoning_content)
@@ -122,9 +123,9 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):

     yield "data: [DONE]" + end_of_line

-async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
+async def fetch_vertex_claude_response_stream(client, url, headers, payload, model, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_vertex_claude_response_stream")
         if error_message:
             yield error_message
@@ -159,7 +160,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):

             if line and '\"text\": \"' in line and is_finish == False:
                 try:
-                    json_data = json.loads("{" + line.strip().rstrip(",") + "}")
+                    json_data = await asyncio.to_thread(json.loads, "{" + line.strip().rstrip(",") + "}")
                     content = json_data.get('text', '')
                     sse_string = await generate_sse_response(timestamp, model, content=content)
                     yield sse_string
@@ -176,7 +177,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
                 function_full_response += line

             if need_function_call:
-                function_call = json.loads(function_full_response)
+                function_call = await asyncio.to_thread(json.loads, function_full_response)
                 function_call_name = function_call["name"]
                 function_call_id = function_call["id"]
                 sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
@@ -190,14 +191,14 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):

     yield "data: [DONE]" + end_of_line

-async def fetch_gpt_response_stream(client, url, headers, payload):
+async def fetch_gpt_response_stream(client, url, headers, payload, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
     random.seed(timestamp)
     random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
     is_thinking = False
     has_send_thinking = False
     ark_tag = False
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_gpt_response_stream")
         if error_message:
             yield error_message
@@ -213,7 +214,7 @@ async def fetch_gpt_response_stream(client, url, headers, payload):
             if line and not line.startswith(":") and (result:=line.lstrip("data: ").strip()):
                 if result.strip() == "[DONE]":
                     break
-                line = json.loads(result)
+                line = await asyncio.to_thread(json.loads, result)
                 line['id'] = f"chatcmpl-{random_str}"

                 # 处理 <think> 标签
@@ -306,12 +307,12 @@ async def fetch_gpt_response_stream(client, url, headers, payload):
             yield "data: " + json.dumps(line).strip() + end_of_line
     yield "data: [DONE]" + end_of_line

-async def fetch_azure_response_stream(client, url, headers, payload):
+async def fetch_azure_response_stream(client, url, headers, payload, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
     is_thinking = False
     has_send_thinking = False
     ark_tag = False
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_azure_response_stream")
         if error_message:
             yield error_message
@@ -327,7 +328,7 @@ async def fetch_azure_response_stream(client, url, headers, payload):
             if line and not line.startswith(":") and (result:=line.lstrip("data: ").strip()):
                 if result.strip() == "[DONE]":
                     break
-                line = json.loads(result)
+                line = await asyncio.to_thread(json.loads, result)
                 no_stream_content = safe_get(line, "choices", 0, "message", "content", default="")
                 content = safe_get(line, "choices", 0, "delta", "content", default="")

@@ -362,9 +363,9 @@ async def fetch_azure_response_stream(client, url, headers, payload):
             yield "data: " + json.dumps(line).strip() + end_of_line
     yield "data: [DONE]" + end_of_line

-async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
+async def fetch_cloudflare_response_stream(client, url, headers, payload, model, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_cloudflare_response_stream")
         if error_message:
             yield error_message
@@ -380,16 +381,16 @@ async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
             line = line.lstrip("data: ")
             if line == "[DONE]":
                 break
-            resp: dict = json.loads(line)
+            resp: dict = await asyncio.to_thread(json.loads, line)
             message = resp.get("response")
             if message:
                 sse_string = await generate_sse_response(timestamp, model, content=message)
                 yield sse_string
     yield "data: [DONE]" + end_of_line

-async def fetch_cohere_response_stream(client, url, headers, payload, model):
+async def fetch_cohere_response_stream(client, url, headers, payload, model, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_cohere_response_stream")
         if error_message:
             yield error_message
@@ -401,7 +402,7 @@ async def fetch_cohere_response_stream(client, url, headers, payload, model):
             while "\n" in buffer:
                 line, buffer = buffer.split("\n", 1)
                 # logger.info("line: %s", repr(line))
-                resp: dict = json.loads(line)
+                resp: dict = await asyncio.to_thread(json.loads, line)
                 if resp.get("is_finished") == True:
                     break
                 if resp.get("event_type") == "text-generation":
@@ -410,9 +411,9 @@ async def fetch_cohere_response_stream(client, url, headers, payload, model):
                     yield sse_string
     yield "data: [DONE]" + end_of_line

-async def fetch_claude_response_stream(client, url, headers, payload, model):
+async def fetch_claude_response_stream(client, url, headers, payload, model, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_claude_response_stream")
         if error_message:
             yield error_message
@@ -427,7 +428,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
             # logger.info(line)

             if line.startswith("data:") and (line := line.lstrip("data: ")):
-                resp: dict = json.loads(line)
+                resp: dict = await asyncio.to_thread(json.loads, line)

                 input_tokens = input_tokens or safe_get(resp, "message", "usage", "input_tokens", default=0)
                 # cache_creation_input_tokens = safe_get(resp, "message", "usage", "cache_creation_input_tokens", default=0)
@@ -463,9 +464,9 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):

     yield "data: [DONE]" + end_of_line

-async def fetch_aws_response_stream(client, url, headers, payload, model):
+async def fetch_aws_response_stream(client, url, headers, payload, model, timeout):
     timestamp = int(datetime.timestamp(datetime.now()))
-    async with client.stream('POST', url, headers=headers, json=payload) as response:
+    async with client.stream('POST', url, headers=headers, json=payload, timeout=timeout) as response:
         error_message = await check_response(response, "fetch_aws_response_stream")
         if error_message:
             yield error_message
@@ -486,7 +487,7 @@ async def fetch_aws_response_stream(client, url, headers, payload, model):
             if not json_match:
                 continue
             try:
-                chunk_data = json.loads(json_match.group(0).lstrip('event'))
+                chunk_data = await asyncio.to_thread(json.loads, json_match.group(0).lstrip('event'))
             except json.JSONDecodeError:
                 logger.error(f"DEBUG json.JSONDecodeError: {json_match.group(0).lstrip('event')!r}")
                 continue
@@ -496,7 +497,7 @@ async def fetch_aws_response_stream(client, url, headers, payload, model):
             # 解码 Base64 编码的字节
             decoded_bytes = base64.b64decode(chunk_data["bytes"])
             # 将解码后的字节再次解析为 JSON
-            payload_chunk = json.loads(decoded_bytes.decode('utf-8'))
+            payload_chunk = await asyncio.to_thread(json.loads, decoded_bytes.decode('utf-8'))
             # print(f"DEBUG payload_chunk: {payload_chunk!r}")

             text = safe_get(payload_chunk, "delta", "text", default="")
@@ -514,13 +515,13 @@ async def fetch_aws_response_stream(client, url, headers, payload, model):

     yield "data: [DONE]" + end_of_line

-async def fetch_response(client, url, headers, payload, engine, model):
+async def fetch_response(client, url, headers, payload, engine, model, timeout=200):
     response = None
     if payload.get("file"):
         file = payload.pop("file")
-        response = await client.post(url, headers=headers, data=payload, files={"file": file})
+        response = await client.post(url, headers=headers, data=payload, files={"file": file}, timeout=timeout)
     else:
-        response = await client.post(url, headers=headers, json=payload)
+        response = await client.post(url, headers=headers, json=payload, timeout=timeout)
     error_message = await check_response(response, "fetch_response")
     if error_message:
         yield error_message
@@ -530,7 +531,8 @@ async def fetch_response(client, url, headers, payload, engine, model):
         yield response.read()

    elif engine == "gemini" or engine == "vertex-gemini" or engine == "aws":
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))

         if isinstance(response_json, str):
@@ -585,7 +587,8 @@ async def fetch_response(client, url, headers, payload, engine, model):
         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens, reasoning_content=reasoning_content, image_base64=image_base64)

     elif engine == "claude":
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))

         content = safe_get(response_json, "content", 0, "text")
@@ -604,7 +607,8 @@ async def fetch_response(client, url, headers, payload, engine, model):
         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=tools_id, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=output_tokens)

     elif engine == "azure":
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         # 删除 content_filter_results
         if "choices" in response_json:
             for choice in response_json["choices"]:
@@ -618,34 +622,36 @@ async def fetch_response(client, url, headers, payload, engine, model):
         yield response_json

     elif "dashscope.aliyuncs.com" in url and "multimodal-generation" in url:
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         content = safe_get(response_json, "output", "choices", 0, "message", "content", 0, default=None)
         yield content
     else:
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         yield response_json

-async def fetch_response_stream(client, url, headers, payload, engine, model):
+async def fetch_response_stream(client, url, headers, payload, engine, model, timeout=200):
     if engine == "gemini" or engine == "vertex-gemini":
-        async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
+        async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model, timeout):
             yield chunk
     elif engine == "claude" or engine == "vertex-claude":
-        async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
+        async for chunk in fetch_claude_response_stream(client, url, headers, payload, model, timeout):
             yield chunk
     elif engine == "aws":
-        async for chunk in fetch_aws_response_stream(client, url, headers, payload, model):
+        async for chunk in fetch_aws_response_stream(client, url, headers, payload, model, timeout):
             yield chunk
     elif engine == "gpt" or engine == "openrouter" or engine == "azure-databricks":
-        async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+        async for chunk in fetch_gpt_response_stream(client, url, headers, payload, timeout):
             yield chunk
     elif engine == "azure":
-        async for chunk in fetch_azure_response_stream(client, url, headers, payload):
+        async for chunk in fetch_azure_response_stream(client, url, headers, payload, timeout):
             yield chunk
     elif engine == "cloudflare":
-        async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
+        async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model, timeout):
             yield chunk
     elif engine == "cohere":
-        async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
+        async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model, timeout):
             yield chunk
     else:
-        raise ValueError("Unknown response")
+        raise ValueError("Unknown response")
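Two patterns repeat through response.py: every `client.stream(...)` and `client.post(...)` call now takes an explicit per-request `timeout` (threaded down from `fetch_response_stream` / `fetch_response`, default 200 seconds), and `json.loads` on response data is wrapped in `asyncio.to_thread` so a large payload cannot block the event loop. A condensed sketch of both, with a hypothetical `fetch_stream` standing in for the `fetch_*_response_stream` helpers:

```python
import asyncio
import json
import httpx

async def fetch_stream(client: httpx.AsyncClient, url: str, payload: dict,
                       timeout: float = 200):
    # The per-request timeout overrides the client's default, matching
    # the new signature of the fetch_*_response_stream helpers.
    async with client.stream("POST", url, json=payload, timeout=timeout) as resp:
        async for line in resp.aiter_lines():
            if not line.startswith("data: "):
                continue
            chunk = line[len("data: "):].strip()
            if chunk == "[DONE]":
                break
            # Offload parsing so a large SSE chunk cannot stall the loop.
            yield await asyncio.to_thread(json.loads, chunk)
```

Whether the thread hop pays off depends on chunk size; for small deltas the dispatch overhead can exceed the parse time, a trade-off this diff accepts uniformly.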
beswarm/aient/aient/models/chatgpt.py
CHANGED
@@ -15,6 +15,31 @@ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml
 from ..core.request import prepare_request_payload
 from ..core.response import fetch_response_stream, fetch_response

+class APITimeoutError(Exception):
+    """Custom exception for API timeout errors."""
+    pass
+
+class ValidationError(Exception):
+    """Custom exception for response validation errors."""
+    def __init__(self, message, response_text):
+        super().__init__(message)
+        self.response_text = response_text
+
+class EmptyResponseError(Exception):
+    """Custom exception for empty API responses."""
+    pass
+
+class ModelNotFoundError(Exception):
+    """Custom exception for model not found (404) errors."""
+    pass
+
+class TaskComplete(Exception):
+    """Exception-like signal to indicate the task is complete."""
+    def __init__(self, message):
+        self.completion_message = message
+        super().__init__(f"Task completed with message: {message}")
+
+
 class chatgpt(BaseLLM):
     """
     Official ChatGPT API
@@ -436,7 +461,7 @@ class chatgpt(BaseLLM):
             yield chunk

         if not full_response.strip():
-            raise
+            raise EmptyResponseError("Response is empty")

         if self.print_log:
             self.logger.info(f"total_tokens: {total_tokens}")
@@ -450,7 +475,7 @@ class chatgpt(BaseLLM):
         if self.check_done:
             # self.logger.info(f"worker Response: {full_response}")
             if not full_response.strip().endswith('[done]'):
-                raise
+                raise ValidationError("Response is not ended with [done]", response_text=full_response)
             else:
                 full_response = full_response.strip().rstrip('[done]')
                 full_response = full_response.replace("<tool_code>", "").replace("</tool_code>", "")
@@ -494,6 +519,8 @@ class chatgpt(BaseLLM):
         # 删除 task_complete 跟其他工具一起调用的情况,因为 task_complete 必须单独调用
         if len(function_parameter) > 1:
             function_parameter = [tool_dict for tool_dict in function_parameter if tool_dict.get("function_name", "") != "task_complete"]
+        if len(function_parameter) == 1 and function_parameter[0].get("function_name", "") == "task_complete":
+            raise TaskComplete(safe_get(function_parameter, 0, "parameter", "message", default="The task has been completed."))

         if self.print_log and invalid_tools:
             self.logger.error(f"invalid_tools: {invalid_tools}")
@@ -739,13 +766,20 @@ class chatgpt(BaseLLM):
                 )

                 # 处理正常响应
+                index = 0
                 async for processed_chunk in self._process_stream_response(
                     generator, convo_id=convo_id, function_name=function_name,
                     total_tokens=total_tokens, function_arguments=function_arguments,
                     function_call_id=function_call_id, model=model, language=language,
                     system_prompt=system_prompt, pass_history=pass_history, is_async=True, stream=stream, **kwargs
                 ):
+                    if index == 0:
+                        if "HTTP Error', 'status_code': 524" in processed_chunk:
+                            raise APITimeoutError("Response timeout")
+                        if "HTTP Error', 'status_code': 404" in processed_chunk:
+                            raise ModelNotFoundError(f"Model: {model or self.engine} not found!")
                     yield processed_chunk
+                    index += 1

                 # 成功处理,跳出重试循环
                 break
@@ -754,17 +788,25 @@ class chatgpt(BaseLLM):
                 return  # Stop iteration
             except httpx.RemoteProtocolError:
                 continue
+            except APITimeoutError:
+                self.logger.warning("API response timeout (524), retrying...")
+                continue
+            except ValidationError as e:
+                self.logger.warning(f"Validation failed: {e}. Retrying with corrective prompt.")
+                need_done_prompt = [
+                    {"role": "assistant", "content": e.response_text},
+                    {"role": "user", "content": "你的消息没有以[done]结尾,请重新输出"}
+                ]
+                continue
+            except EmptyResponseError as e:
+                self.logger.warning(f"{e}, retrying...")
+                continue
+            except TaskComplete as e:
+                raise
+            except ModelNotFoundError as e:
+                raise
             except Exception as e:
                 self.logger.error(f"{e}")
-                if "validation_error" in str(e):
-                    bad_assistant_message = json.loads(str(e))["response"]
-                    need_done_prompt = [
-                        {"role": "assistant", "content": bad_assistant_message},
-                        {"role": "user", "content": "你的消息没有以[done]结尾,请重新输出"}
-                    ]
-                    continue
-                if "response_empty_error" in str(e):
-                    continue
                 import traceback
                 self.logger.error(traceback.format_exc())
                 if "Invalid URL" in str(e):
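The retry loop now dispatches on exception types instead of substring checks against stringified errors (`"validation_error" in str(e)`). A self-contained sketch of the control flow, with a fake `call_model` in place of the real streaming call:

```python
import asyncio

class APITimeoutError(Exception):
    """Retryable: upstream answered with a 524."""

class EmptyResponseError(Exception):
    """Retryable: the model returned nothing."""

class ModelNotFoundError(Exception):
    """Fatal: retrying cannot fix a 404 on the model name."""

async def call_model(attempt: int) -> str:
    # Fake model call: times out once, then succeeds.
    if attempt == 0:
        raise APITimeoutError("Response timeout")
    return "ok"

async def ask_with_retry(max_attempts: int = 3) -> str:
    for attempt in range(max_attempts):
        try:
            return await call_model(attempt)
        except (APITimeoutError, EmptyResponseError):
            continue          # transient: try again
        except ModelNotFoundError:
            raise             # permanent: surface immediately
    raise RuntimeError("all retries exhausted")

print(asyncio.run(ask_with_retry()))  # prints "ok"
```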
beswarm/taskmanager.py
CHANGED
@@ -23,6 +23,7 @@ class TaskManager:
     它管理任务的生命周期,并通过一个固定大小的工作者池来控制并发执行的任务数量。
     """
     def __init__(self, concurrency_limit=None):
+        self.raw_concurrency_limit = concurrency_limit
         self.concurrency_limit = concurrency_limit or int(os.getenv("BESWARM_CONCURRENCY_LIMIT", "3"))

         if self.concurrency_limit <= 0:
@@ -38,8 +39,6 @@ class TaskManager:
         self.cache_dir = None
         self.task_cache_file = None

-        print(f"TaskManager 初始化,并发限制为: {self.concurrency_limit}")
-
     def set_root_path(self, root_path):
         """设置工作根目录并加载持久化的任务状态。"""
         if self.root_path is not None:
@@ -52,6 +51,9 @@ class TaskManager:
         self._load_tasks_from_cache()
         self.set_task_cache("root_path", str(self.root_path))

+        if not self.raw_concurrency_limit:
+            self.concurrency_limit = int(os.getenv("BESWARM_CONCURRENCY_LIMIT", "3"))
+
         # 启动工作者池
         self.start()
         # 恢复中断的任务
@@ -86,7 +88,6 @@ class TaskManager:

     async def _worker_loop(self, worker_name: str):
         """每个工作者的主循环,从队列中拉取并执行任务。"""
-        print(f"[{worker_name}] 已就绪,等待任务...")
         while self._is_running:
             try:
                 task_id, coro = await self._pending_queue.get()
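The TaskManager change keeps the caller's explicit limit in `raw_concurrency_limit` and re-reads `BESWARM_CONCURRENCY_LIMIT` in `set_root_path`, so an environment variable exported after construction still takes effect. A reduced sketch of that deferred-configuration pattern:

```python
import os

class TaskManager:
    def __init__(self, concurrency_limit: int | None = None):
        # Remember whether the caller pinned a limit explicitly.
        self.raw_concurrency_limit = concurrency_limit
        self.concurrency_limit = concurrency_limit or int(
            os.getenv("BESWARM_CONCURRENCY_LIMIT", "3"))

    def set_root_path(self, root_path: str) -> None:
        # Re-read the env var unless the caller set a limit, so a value
        # exported after __init__ still wins.
        if not self.raw_concurrency_limit:
            self.concurrency_limit = int(
                os.getenv("BESWARM_CONCURRENCY_LIMIT", "3"))

tm = TaskManager()                              # env unset -> default 3
os.environ["BESWARM_CONCURRENCY_LIMIT"] = "5"   # exported later
tm.set_root_path(".")                           # re-read -> 5
print(tm.concurrency_limit)
```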
{beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/RECORD
CHANGED
@@ -3,16 +3,16 @@ beswarm/broker.py,sha256=64Y-djrKYaZfBQ8obwHOmr921QgZeu9BtScZWaYLfDo,9887
 beswarm/core.py,sha256=htssaaeIBZ_yOqvX9VtANoVWaZHt_7oWcxyDI1z0paQ,310
 beswarm/knowledge_graph.py,sha256=oiOMknAJzGrOHc2AyQgvrCcZAkGLhFnsnvSBdfFBWMw,14831
 beswarm/prompt.py,sha256=INVRWQZP6lysvGUcPOYI_er5-bi1gGe_qa6BTov7PmY,32362
-beswarm/taskmanager.py,sha256=
+beswarm/taskmanager.py,sha256=vMmcoZ4FlNvjEliRkv3AniPji50NcY4Q1_2HETzR0DU,12226
 beswarm/utils.py,sha256=0J-b38P5QGT-A_38co7FjzaUNJykaskI7mbbcQ4w_68,8215
 beswarm/agents/chatgroup.py,sha256=PzrmRcDKAbB7cxL16nMod_CzPosDV6bfTmXxQVuv-AQ,12012
-beswarm/agents/planact.py,sha256=
+beswarm/agents/planact.py,sha256=y6NtiiQC1SIkJCWeK4REoR8yRoMllRda5aiVeXhDZdE,20082
 beswarm/aient/aient/__init__.py,sha256=SRfF7oDVlOOAi6nGKiJIUK6B_arqYLO9iSMp-2IZZps,21
 beswarm/aient/aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
 beswarm/aient/aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
-beswarm/aient/aient/core/models.py,sha256=
-beswarm/aient/aient/core/request.py,sha256=
-beswarm/aient/aient/core/response.py,sha256=
+beswarm/aient/aient/core/models.py,sha256=KMlCRLjtq1wQHZTJGqnbWhPS2cHq6eLdnk7peKDrzR8,7490
+beswarm/aient/aient/core/request.py,sha256=vfwi3ZGYp2hQzSJ6mPXJVgcV_uu5AJ_NAL84mLfF8WA,76674
+beswarm/aient/aient/core/response.py,sha256=vQFuc3amHiD1hv_OiINRJnh33n79PnbdzMSBSRlqR5E,34309
 beswarm/aient/aient/core/utils.py,sha256=D98d5Cy1h4ejKtuxS0EEDtL4YqpaZLB5tuXoVP0IBWQ,28462
 beswarm/aient/aient/core/test/test_base_api.py,sha256=pWnycRJbuPSXKKU9AQjWrMAX1wiLC_014Qc9hh5C2Pw,524
 beswarm/aient/aient/core/test/test_geminimask.py,sha256=HFX8jDbNg_FjjgPNxfYaR-0-roUrOO-ND-FVsuxSoiw,13254
@@ -21,7 +21,7 @@ beswarm/aient/aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9
 beswarm/aient/aient/models/__init__.py,sha256=ZTiZgbfBPTjIPSKURE7t6hlFBVLRS9lluGbmqc1WjxQ,43
 beswarm/aient/aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
 beswarm/aient/aient/models/base.py,sha256=-nnihYnx-vHZMqeVO9ljjt3k4FcD3n-iMk4tT-10nRQ,7232
-beswarm/aient/aient/models/chatgpt.py,sha256=
+beswarm/aient/aient/models/chatgpt.py,sha256=q62B6cbtHqKrqsQjM24k_1wi_5-UiuxkXa7e2yG_Clg,44661
 beswarm/aient/aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
 beswarm/aient/aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
 beswarm/aient/aient/plugins/config.py,sha256=TGgZ5SnNKZ8MmdznrZ-TEq7s2ulhAAwTSKH89bci3dA,7079
@@ -104,7 +104,7 @@ beswarm/queries/tree-sitter-languages/scala-tags.scm,sha256=UxQjz80JIrrJ7Pm56uUn
 beswarm/queries/tree-sitter-languages/typescript-tags.scm,sha256=OMdCeedPiA24ky82DpgTMKXK_l2ySTuF2zrQ2fJAi9E,1253
 beswarm/tools/__init__.py,sha256=CPFj04Lm6TEol_2BFX-mVrpzvVIsmPcfTWhRg3xex7Q,2077
 beswarm/tools/click.py,sha256=wu-Ov5U2ZZLcU0gennDVh_2w_Td7F4dbVJcmu_dfHV4,20872
-beswarm/tools/completion.py,sha256=
+beswarm/tools/completion.py,sha256=AIHtEHfSp6fs4Xa_nTcw9fLgXtYHtYGRDTSy7_x_Fak,544
 beswarm/tools/edit_file.py,sha256=ZTJvbpsfRlp2t98kTn9XQ5qZBTdsWJVWv9t0lvK4RfU,9147
 beswarm/tools/graph.py,sha256=IJ-wExgSEDLGYXTRtCtyxMM2wkA1gg40Z-KI3iyQThE,5084
 beswarm/tools/planner.py,sha256=vsHd7rE8RQHJrZ7BQ0ZXhbt4Fjh3DeyxU4piA5R-VPM,1253
@@ -116,7 +116,7 @@ beswarm/tools/search_web.py,sha256=NYrb5KL_WUGPm-fOKT8Cyjon04lxBU-gaLdrVjeYgGo,1
 beswarm/tools/subtasks.py,sha256=mIjA2QrRy9Fos4rYm8fCfu2QrsE_MGnQI9IR8dOxsGs,9885
 beswarm/tools/worker.py,sha256=_cSkRUKRJMAiZiTfnBze_e9Kc7k7KvbB5hdxdvp4FW4,2009
 beswarm/tools/write_csv.py,sha256=u0Hq18Ksfheb52MVtyLNCnSDHibITpsYBPs2ub7USYA,1466
-beswarm-0.2.59.dist-info/METADATA,sha256=
-beswarm-0.2.59.dist-info/WHEEL,sha256=
-beswarm-0.2.59.dist-info/top_level.txt,sha256=
-beswarm-0.2.59.dist-info/RECORD,,
+beswarm-0.2.61.dist-info/METADATA,sha256=s14zUG7R6_2LWhXRMX-UZF0MQeh06d0bmX84pH5KlhU,3878
+beswarm-0.2.61.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+beswarm-0.2.61.dist-info/top_level.txt,sha256=pJw4O87wvt5882smuSO6DfByJz7FJ8SxxT8h9fHCmpo,8
+beswarm-0.2.61.dist-info/RECORD,,
{beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/WHEEL
File without changes
{beswarm-0.2.59.dist-info → beswarm-0.2.61.dist-info}/top_level.txt
File without changes